In [1]:
from typing import Union

import numpy as np

In [5]:
def calculate_returns(rewards: np.ndarray,
                      dones: np.ndarray,
                      next_value: Union[np.ndarray, float],
                      discount_factor: float) -> np.ndarray:
    """Compute bootstrapped discounted returns with a backward Bellman backup.

    Iterating from the last step to the first:

        G_t = r_t + discount_factor * (1 - done_t) * G_{t+1},   with G_T = next_value

    so a step whose ``done`` flag is 1 does not bootstrap from anything after it.

    Args:
        rewards: rewards of shape ``(steps,)`` or ``(batch_size, steps)``.
        dones: flags with the same shape as ``rewards``; a value of 1 stops
            bootstrapping past that step, 0 lets it continue.
        next_value: bootstrap value used after the final step. A numeric scalar
            (Python number, NumPy scalar or 0-d array) when ``rewards`` is 1-D,
            or an array of shape ``(batch_size, 1)`` when ``rewards`` is 2-D.
        discount_factor: per-step discount factor (gamma).

    Returns:
        Array with the same shape as ``rewards``. Floating-point rewards keep
        their dtype; integer rewards are promoted to float64 so the discounted
        returns are not truncated.

    Raises:
        ValueError: if ``rewards`` and ``dones`` differ in shape, ``rewards`` is
            not 1-D or 2-D, or ``next_value`` has the wrong shape/type.
    """
    rewards = np.asarray(rewards)
    dones = np.asarray(dones)
    if rewards.shape != dones.shape:
        raise ValueError("rewards and dones must have same shape; either (steps, ) or (batch_size, steps)")

    # Writing float returns into an integer buffer would silently truncate them.
    dtype = rewards.dtype if np.issubdtype(rewards.dtype, np.floating) else np.float64

    if rewards.ndim == 1:
        # Accept Python ints/floats, NumPy scalars (e.g. np.float32) and 0-d arrays.
        if np.ndim(next_value) != 0 or np.asarray(next_value).dtype.kind not in "iuf":
            raise ValueError("next_value must be a float scalar")
        next_value = float(next_value)
    elif rewards.ndim == 2:
        batch_size = rewards.shape[0]
        if np.shape(next_value) != (batch_size, 1):
            raise ValueError(f"next_value's shape must be ({batch_size}, 1)")
        # Flatten to (batch_size,) so it lines up with one time column rewards[:, t]
        # instead of broadcasting to (batch_size, batch_size).
        next_value = np.asarray(next_value, dtype=dtype).reshape(batch_size)
    else:
        raise ValueError("rewards must be 1-D (steps, ) or 2-D (batch_size, steps)")

    # Bellman backup for Q function
    # Q(s_t,a_t) = R_t + gamma * V(s_t+1)
    returns = np.zeros(rewards.shape, dtype=dtype)
    # Index the last (time) axis so the 2-D case walks over steps, not batch rows.
    for t in reversed(range(rewards.shape[-1])):
        returns[..., t] = rewards[..., t] + discount_factor * (1 - dones[..., t]) * next_value
        next_value = returns[..., t]
    return returns

In [6]:
# Example six-step episode in which only the final transition is terminal.
discount_factor = 0.9
next_value = 0.0  # value bootstrapped after the last step
rewards = np.array([-1.0, -1.0, -1.0, 1.0, -1.0, 1.0], dtype=np.float32)
dones = np.zeros_like(rewards)
dones[-1] = 1.0

In [7]:
calculate_returns(rewards=rewards, dones=dones, next_value=next_value, discount_factor=discount_factor)

array([-2.0466099 , -1.1629    , -0.18099998,  0.91      , -0.1       ,
        1.        ], dtype=float32)

In [8]:
import scipy.signal

def discount_cumsum(x, discount):
    """Discounted suffix sums along axis 0 (the rllab ``discount_cumsum`` trick).

    For an input ``[x0, x1, x2]`` the result is::

        [x0 + discount * x1 + discount**2 * x2,
         x1 + discount * x2,
         x2]

    ``scipy.signal.lfilter`` with numerator ``[1]`` and denominator
    ``[1, -discount]`` evaluates the recurrence ``y[n] = x[n] + discount * y[n-1]``.
    Running it over the sequence reversed along axis 0, and reversing the
    output back, yields the suffix sums shown above.

    Args:
        x: array-like of values; the sum runs along axis 0.
        discount: per-step discount factor.

    Returns:
        ndarray with the same shape as ``x``.
    """
    denominator = [1, -float(discount)]
    reversed_input = np.flip(x, axis=0)
    filtered = scipy.signal.lfilter([1], denominator, reversed_input, axis=0)
    return np.flip(filtered, axis=0)

In [9]:
# Append the bootstrap value so the last real step's discounted sum includes it.
# NOTE: discount_cumsum ignores `dones`; its output matches calculate_returns only
# because the sole terminal step here is the final one and next_value is 0.0.
# The trailing element of the result is the bootstrap value itself.
x = np.concatenate([rewards, [next_value]])
discount_cumsum(x, discount_factor)

array([-2.04661, -1.1629 , -0.181  ,  0.91   , -0.1    ,  1.     ,
        0.     ])

In [10]:
# Cross-check the two approaches. scipy.signal.lfilter(b=[1], a=[1, -gamma], ...)
# evaluates y[n] = x[n] + gamma * y[n-1]; applied to the reversed sequence this is
# the discounted suffix sum. After dropping the trailing bootstrap element, it must
# agree with the explicit Bellman backup (valid here because the only terminal
# step is the last one).
np.testing.assert_allclose(
    discount_cumsum(np.append(rewards, next_value), discount_factor)[:-1],
    calculate_returns(rewards, dones, next_value, discount_factor),
    rtol=1e-5,
)