# Replay Buffer and n-steps bootstrapping

The goal of this notebook is to test my old code for (at most) n-step value bootstrapping, where each sample in a batch may contain more than one episode.

In [2]:
import numpy as np
import torch
import torch.nn as nn

For precisely n-steps and no terminal states, we want to compute this formula
$$V^{(n)}(t) = \sum_{k=0}^{n-1} \gamma^{k} r_{t+k+1} + \gamma^n V(s_{t+n})$$
Where the notation for a state transition is $(s_t, a_t, r_{t+1}, s_{t+1})$.

For n=3 it looks like this:
$$V^{(3)}(t) = r_{t+1} + \gamma r_{t+2} + \gamma^2 r_{t+3} + \gamma^3 V(s_{t+3})$$

In [45]:
def compute_n_step_V_trg_v0(n_steps, discount, rewards, done, bootstrap, states, value_net, device="cpu"):
    r"""
    Compute m-steps value target, with m = min(n, steps-to-episode-end).
    Formula (for precisely n-steps):
        V^{(n)}(t) = \sum_{k=0}^{n-1} gamma^k r_{t+k+1} + gamma^n * V(s_{t+n})

    Input
    -----
    n_steps: int
        Maximum number of rewards summed before bootstrapping
    discount: float
        Discount factor gamma
    rewards: np.array of shape (B,T)
    done: np.array of shape (B,T), type bool
        Old convention: also True at the last step of every trajectory. It is NOT modified
        by this function (a copy is edited instead).
    bootstrap: boolean np.array of shape (B,T)
        Entries of `done` that are switched off before looking up the target states, so that
        the value of the state reached there is still bootstrapped
    states: np.array of shape (B,T+1,...)
        One step longer than the other signals; states[:,1:] are the candidate target states
    value_net: callable
        Called on the target states reshaped as (B*T,...)
    device: str
        Device on which the tensors are created

    Return
    ------
    V_trg: torch tensor of shape (B*T,)
    """
    # The episode masks are built with the trajectory ends still marked as done
    n_step_rewards, episode_mask, n_steps_mask_b = compute_n_step_rewards(rewards, done, n_steps, discount)

    # Edit a copy: the caller's `done` array must not change as a side effect
    done = done.copy()
    done[bootstrap] = False
    trg_states = states[:, 1:]
    new_states, Gamma_V, done = compute_n_step_states(trg_states, done, episode_mask, n_steps_mask_b, discount)

    new_states = torch.tensor(new_states).float().to(device).reshape((-1,) + states.shape[2:])
    done = torch.LongTensor(done.astype(int)).to(device).reshape(-1)
    n_step_rewards = torch.tensor(n_step_rewards).float().to(device).reshape(-1)
    Gamma_V = torch.tensor(Gamma_V).float().to(device).reshape(-1)

    with torch.no_grad():
        # reshape(-1) instead of squeeze(): a single-element batch keeps its (1,) shape
        V_pred = value_net(new_states).reshape(-1)
        # No bootstrapping where the target step is terminal
        V_trg = (1 - done) * Gamma_V * V_pred + n_step_rewards
    return V_trg

In [3]:
def compute_n_step_rewards(rewards, done, n_steps, discount):
    """
    Discounted sum of (at most) n_steps rewards for every step of every trajectory.

    The sum never crosses an episode boundary, so steps closer than n_steps to the end of
    their episode (or of the trajectory) accumulate fewer than n_steps rewards.

    Return
    ------
    n_steps_r: float array of shape (B,T)
    episode_mask: int array of shape (B,T,T), entry [b,i,j] is 1 if steps i and j of sample b
        belong to the same episode
    n_steps_mask_b: int array of shape (B,T,T), entry [b,i,j] is 1 if i <= j < i + n_steps
    """
    batch_size, horizon = done.shape[0], done.shape[1]
    timesteps = range(horizon)

    # Episode ends, plus one artificial end at the last step of every sample so that
    # every step is assigned to some episode
    ends_b, ends_t = np.nonzero(done)
    ends_b = np.concatenate([ends_b, np.arange(batch_size)])
    ends_t = np.concatenate([ends_t, np.full(batch_size, horizon - 1)])

    prev_end = [-1] * batch_size
    per_sample_rows = [[] for _ in range(batch_size)]
    for sample, end in zip(ends_b, ends_t):
        start = prev_end[sample]
        # Same membership row for all the steps in (start, end]
        membership = [int(start < t <= end) for t in timesteps]
        per_sample_rows[sample].extend(membership for _ in range(end - start))
        prev_end[sample] = end
    episode_mask = np.array(per_sample_rows)

    # Window of n_steps columns starting on the diagonal, tiled over the batch
    t_idx = np.arange(horizon)
    ahead = t_idx[np.newaxis, :] - t_idx[:, np.newaxis]
    window = ((ahead >= 0) & (ahead < n_steps)).astype(int)
    n_steps_mask_b = np.tile(window, (batch_size, 1, 1))

    # Discount every reward by gamma**j, keep the ones inside episode and window,
    # then rescale row i by gamma**i so that the first reward of the window has weight 1
    Gamma = np.array([discount**k for k in timesteps])[np.newaxis, :]
    weighted = Gamma * rewards[:, np.newaxis, :] * episode_mask * n_steps_mask_b
    n_steps_r = weighted.sum(axis=2) / Gamma
    return n_steps_r, episode_mask, n_steps_mask_b


In [12]:
def compute_n_step_states(trg_states, done, episode_mask, n_steps_mask_b, discount):
    """
    Computes n-steps target states (to be used by the critic as target values together with the
    n-steps discounted reward). For every step the target is the last step that is both inside
    the n-steps window and inside the same episode, so near the end of an episode the target is
    the last step available.
    Also gathers the `done` flag at the target step (used for disabling the bootstrapping in the
    case of terminal states) and returns Gamma_V, the discount factors for the target
    state-values: gamma**m with m the number of steps between a step and its target state.

    Input
    -----
    trg_states: np.array of shape (B,T,...)
        Candidate target states, i.e. the states shifted one step ahead
    done: np.array of shape (B,T), type bool
    episode_mask, n_steps_mask_b: int arrays of shape (B,T,T)
        As returned by compute_n_step_rewards
    discount: float

    Return
    ------
    new_states: array with the same shape as trg_states
    Gamma_V: float array of shape (B,T)
    shifted_done: bool array of shape (B,T)

    Raise
    -----
    IndexError if some step has no valid target (empty window, e.g. n_steps < 1)
    """
    B = done.shape[0]
    V_mask = episode_mask * n_steps_mask_b
    in_window = V_mask != 0
    if not in_window.any(axis=2).all():
        raise IndexError("every step needs at least one valid target index (is n_steps >= 1?)")

    # Last valid column of every row: first hit when scanning the row right-to-left
    V_trg_index = V_mask.shape[2] - 1 - np.argmax(in_window[:, :, ::-1], axis=2)

    # Flat (row, col) pairs to gather along the first two axes.
    # Builtin-sized integer dtype: the `np.int` alias no longer exists in recent NumPy.
    rows = np.repeat(np.arange(B, dtype=np.intp), V_trg_index.shape[1])
    cols = V_trg_index.reshape(-1).astype(np.intp)

    new_states = trg_states[rows, cols].reshape(trg_states.shape)

    # Number of steps between t and the state whose value is bootstrapped (at least 1)
    pw = V_trg_index - np.arange(V_trg_index.shape[1]) + 1
    Gamma_V = discount**pw
    shifted_done = done[rows, cols].reshape(done.shape)
    return new_states, Gamma_V, shifted_done

In [34]:
discount = 0.9
n_steps = 3  # number of rewards to sum up at most before adding the discounted value of the next state
B = 2  # batch_size
T = 10  # trajectory length for each sample in the batch

# Builtin float/bool are used as dtypes: the `np.float` / `np.bool` aliases
# were removed from NumPy and raise AttributeError on recent versions.
rewards = np.array([
    [0,1,0,0,0,0,1,0,0,0],
    [0,0,0,1,0,0,0,1,0,0]
], dtype=float)

done = np.array([
    [0,1,0,0,0,0,1,0,0,1],
    [0,0,0,1,0,0,0,1,0,1]
], dtype=bool)

# Old notation that I was using
bootstrap = np.array([
    [0,0,0,0,0,0,0,0,0,1],
    [0,0,0,0,0,0,0,0,0,1]
], dtype=bool)
# In our current notation we don't have a bootstrap variable and our done can be obtained by done[bootstrap] = 0

# Let's say we have a binary state and the value of the state is 0.5 x state (so either 0 or 0.5)
states = np.zeros((B,T+1,1), dtype=float)  # states are 1 step longer than the rest of the signals
states[0,-1,0] = 1.  # let's just keep it simple and check the value bootstrapping on a single state

# n-step discounted rewards expected for the batch above (n_steps=3, discount=0.9)
expected_n_step_rewards = np.array([
    [0.9,1,0,0,0.81,0.9,1,0,0,0],
    [0,0.81,0.9,1,0,0.81,0.9,1,0,0]
], dtype=float)


# Same as above plus the bootstrapped value: only the last three steps of the first
# sample can reach the single state with non-zero value (states[0,-1])
expected_v_trg = torch.tensor([
    [0.9,1,0,0,0.81,0.9,1,0.3645,0.405,0.45],
    [0,0.81,0.9,1,0,0.81,0.9,1,0,0]
]).float()

In [23]:
class ValueNet(nn.Module):
    """Parameter-free stand-in for a critic: the value of a state is half the state."""

    def forward(self, x):
        """Return the input scaled by 0.5, element-wise."""
        halved = x * 0.5
        return halved


value_net = ValueNet()

In [39]:
# Sanity check: the n-step rewards computed on the toy batch must match expected_n_step_rewards.
# The two masks are kept because the following cells inspect and reuse them.
n_step_rewards, episode_mask, n_steps_mask_b = compute_n_step_rewards(rewards, done, n_steps, discount)
assert np.allclose(n_step_rewards, expected_n_step_rewards), "n-step-rewards do not match the expected values"
print("Success")

Success


In [7]:
# Row i of each sample is 1 on the columns that belong to the same episode as step i
episode_mask

array([[[1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 1, 1, 1, 1, 0, 0, 0],
        [0, 0, 1, 1, 1, 1, 1, 0, 0, 0],
        [0, 0, 1, 1, 1, 1, 1, 0, 0, 0],
        [0, 0, 1, 1, 1, 1, 1, 0, 0, 0],
        [0, 0, 1, 1, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]],

       [[1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 1, 1, 0, 0],
        [0, 0, 0, 0, 1, 1, 1, 1, 0, 0],
        [0, 0, 0, 0, 1, 1, 1, 1, 0, 0],
        [0, 0, 0, 0, 1, 1, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]]])

In [8]:
# Row i of each sample is 1 on columns i .. i+n_steps-1 (clipped at the end of the trajectory)
n_steps_mask_b

array([[[1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 1, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 1, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]],

       [[1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 1, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 1, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]])

In [13]:
# Candidate target states: the states shifted one step ahead (drops the first state)
trg_states = states[:,1:]
new_states, Gamma_V, shifted_done = compute_n_step_states(trg_states, done, episode_mask, n_steps_mask_b, discount)

In [14]:
new_states # target state of every step: at most n_steps ahead, never beyond the end of its episode

array([[[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [1.]],

       [[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]]])

In [15]:
Gamma_V # discount applied to the bootstrapped value: between gamma (target 1 step ahead) and gamma**n_steps

array([[0.81 , 0.9  , 0.729, 0.729, 0.729, 0.81 , 0.9  , 0.729, 0.81 ,
        0.9  ],
       [0.729, 0.729, 0.81 , 0.9  , 0.729, 0.729, 0.81 , 0.9  , 0.81 ,
        0.9  ]])

In [16]:
shifted_done # `done` flag read at the target step of every step, i.e. after (at most) n_steps

array([[ True,  True, False, False,  True,  True,  True,  True,  True,
         True],
       [False,  True,  True,  True, False,  True,  True,  True,  True,
         True]])

In [46]:
# End-to-end check of the v0 target computation (the version that still takes the `bootstrap` mask).
# The targets come back flattened, hence the reshape to (B,T) before comparing.
v_trg = compute_n_step_V_trg_v0(n_steps, discount, rewards, done, bootstrap, states, value_net, device="cpu")
print("Target values (reshaped): \n", v_trg.reshape(B,T))
print("Expected: \n", expected_v_trg)
assert torch.allclose(v_trg.reshape(B,T), expected_v_trg), "Wrong vtarget values"
print("Success")

Target values (reshaped): 
 tensor([[0.9000, 1.0000, 0.0000, 0.0000, 0.8100, 0.9000, 1.0000, 0.3645, 0.4050,
         0.4500],
        [0.0000, 0.8100, 0.9000, 1.0000, 0.0000, 0.8100, 0.9000, 1.0000, 0.0000,
         0.0000]])
Expected: 
 tensor([[0.9000, 1.0000, 0.0000, 0.0000, 0.8100, 0.9000, 1.0000, 0.3645, 0.4050,
         0.4500],
        [0.0000, 0.8100, 0.9000, 1.0000, 0.0000, 0.8100, 0.9000, 1.0000, 0.0000,
         0.0000]])
Success


## Removing bootstrap signal

How it was used: by default we set done=True at the last step of every trajectory, even when the trajectory had only been truncated while the episode was still going on. That done=True signalled that the n-step sum must not look beyond that timestep. Before computing the bootstrapped value it had to be switched off again (this is what the `bootstrap` mask was for), so that the final step is not treated as terminal and its value is still bootstrapped.

In [41]:
# `done` in the new notation: True only at real episode ends, no forced True at the last step.
# Builtin bool as dtype: the `np.bool` alias was removed from NumPy.
new_done = np.array([
    [0,1,0,0,0,0,1,0,0,0],
    [0,0,0,1,0,0,0,1,0,0]
], dtype=bool)

# get the old done as
old_done = new_done.copy()

In [42]:
# Old convention: force a done at the last step of every trajectory
old_done[:,-1] = True
old_done

array([[False,  True, False, False, False, False,  True, False, False,
         True],
       [False, False, False,  True, False, False, False,  True, False,
         True]])

In [44]:
new_done # unchanged: old_done is a copy, so editing it does not touch the original variable

array([[False,  True, False, False, False, False,  True, False, False,
        False],
       [False, False, False,  True, False, False, False,  True, False,
        False]])

In [49]:
def compute_n_step_V_trg(n_steps, discount, rewards, done, states, value_net, device="cpu"):
    r"""
    Compute m-steps value target, with m = min(n_steps, steps-to-episode-end).
    Formula (for precisely n-steps):
        V^{(n)}(t) = \sum_{k=0}^{n-1} gamma^k r_{t+k+1} + gamma^n * V(s_{t+n})

    Input
    -----
    n_steps: int
        How many steps in the future to consider before bootstrapping while computing the value target
    discount: float in (0,1)
        Discount factor of the MDP
    rewards: np.array of shape (B,T), type float
    done: np.array of shape (B,T), type bool
        True only at real episode ends; a trajectory that is merely truncated keeps False at
        its last step, so that the value of the last state is bootstrapped. Not modified.
    states: np.array of shape (B,T+1,...)
        One step longer than rewards and done, since states[:,1:] are used as target states
    value_net: instance of nn.Module
        outputs values of shape (B*T,) or (B*T,1) given states reshaped as (B*T,...)
    device: str
        Device on which the tensors are created

    Return
    ------
    V_trg: torch tensor of shape (B*T,)
    """
    # A done at the end of every trajectory is needed only to delimit the episodes when
    # summing the rewards; the original `done` decides whether to bootstrap.
    done_plus_ending = done.copy()
    done_plus_ending[:, -1] = True
    n_step_rewards, episode_mask, n_steps_mask_b = compute_n_step_rewards(rewards, done_plus_ending, n_steps, discount)

    trg_states = states[:, 1:]
    new_states, Gamma_V, done = compute_n_step_states(trg_states, done, episode_mask, n_steps_mask_b, discount)

    new_states = torch.tensor(new_states).float().to(device).reshape((-1,) + states.shape[2:])
    done = torch.LongTensor(done.astype(int)).to(device).reshape(-1)
    n_step_rewards = torch.tensor(n_step_rewards).float().to(device).reshape(-1)
    Gamma_V = torch.tensor(Gamma_V).float().to(device).reshape(-1)

    with torch.no_grad():
        # reshape(-1) instead of squeeze(): a single-element batch keeps its (1,) shape
        V_pred = value_net(new_states).reshape(-1)
        # No bootstrapping where the target step is terminal
        V_trg = (1 - done) * Gamma_V * V_pred + n_step_rewards
    return V_trg

In [48]:
# Same check as for the v0 version, now passing `new_done` (no forced done at the trajectory end)
# and no `bootstrap` mask: the targets must be identical.
v_trg = compute_n_step_V_trg(n_steps, discount, rewards, new_done, states, value_net, device="cpu")
print("Target values (reshaped): \n", v_trg.reshape(B,T))
print("Expected: \n", expected_v_trg)
assert torch.allclose(v_trg.reshape(B,T), expected_v_trg), "Wrong vtarget values"
print("Success")

Target values (reshaped): 
 tensor([[0.9000, 1.0000, 0.0000, 0.0000, 0.8100, 0.9000, 1.0000, 0.3645, 0.4050,
         0.4500],
        [0.0000, 0.8100, 0.9000, 1.0000, 0.0000, 0.8100, 0.9000, 1.0000, 0.0000,
         0.0000]])
Expected: 
 tensor([[0.9000, 1.0000, 0.0000, 0.0000, 0.8100, 0.9000, 1.0000, 0.3645, 0.4050,
         0.4500],
        [0.0000, 0.8100, 0.9000, 1.0000, 0.0000, 0.8100, 0.9000, 1.0000, 0.0000,
         0.0000]])
Success


## Dealing with a state that is a dictionary of tensors

In [61]:
def compute_n_step_V_trg(n_steps, discount, rewards, done, states, value_net, device="cpu"):
    r"""
    Compute m-steps value target, with m = min(n_steps, steps-to-episode-end).
    Formula (for precisely n-steps):
        V^{(n)}(t) = \sum_{k=0}^{n-1} gamma^k r_{t+k+1} + gamma^n * V(s_{t+n})

    Input
    -----
    n_steps: int
        How many steps in the future to consider before bootstrapping while computing the value target
    discount: float in (0,1)
        Discount factor of the MDP
    rewards: np.array of shape (B,T), type float
    done: np.array of shape (B,T), type bool
        True only at real episode ends; a trajectory that is merely truncated keeps False at
        its last step, so that the value of the last state is bootstrapped. Not modified.
    states: dictionary of tensors all of shape (B,T+1,...)
        One step longer than rewards and done, since the entries [:,1:] are used as target states
    value_net: instance of nn.Module
        outputs values of shape (B*T,) or (B*T,1) given a dictionary of states reshaped as (B*T,...)
    device: str
        Device on which all the tensors (target states included) are placed

    Return
    ------
    V_trg: torch tensor of shape (B*T,)
    """
    # A done at the end of every trajectory is needed only to delimit the episodes when
    # summing the rewards; the original `done` decides whether to bootstrap.
    done_plus_ending = done.copy()
    done_plus_ending[:, -1] = True
    n_step_rewards, episode_mask, n_steps_mask_b = compute_n_step_rewards(rewards, done_plus_ending, n_steps, discount)

    trg_states = {k: v[:, 1:] for k, v in states.items()}
    new_states, Gamma_V, done = compute_n_step_states(trg_states, done, episode_mask, n_steps_mask_b, discount)

    # Flatten batch and time; move to `device` like every other operand of the target
    new_states_reshaped = {
        k: v.reshape((-1,) + tuple(v.shape[2:])).to(device) for k, v in new_states.items()
    }
    done = torch.LongTensor(done.astype(int)).to(device).reshape(-1)
    n_step_rewards = torch.tensor(n_step_rewards).float().to(device).reshape(-1)
    Gamma_V = torch.tensor(Gamma_V).float().to(device).reshape(-1)

    with torch.no_grad():
        # reshape(-1) instead of squeeze(): a single-element batch keeps its (1,) shape
        V_pred = value_net(new_states_reshaped).reshape(-1)
        # No bootstrapping where the target step is terminal
        V_trg = (1 - done) * Gamma_V * V_pred + n_step_rewards
    return V_trg

In [53]:
def compute_n_step_states(trg_states, done, episode_mask, n_steps_mask_b, discount):
    """
    Computes n-steps target states (to be used by the critic as target values together with the
    n-steps discounted reward). For every step the target is the last step that is both inside
    the n-steps window and inside the same episode, so near the end of an episode the target is
    the last step available.
    Also gathers the `done` flag at the target step (used for disabling the bootstrapping in the
    case of terminal states) and returns Gamma_V, the discount factors for the target
    state-values: gamma**m with m the number of steps between a step and its target state.

    Input
    -----
    trg_states: dictionary of tensors all of shape (B,T,...)
        Candidate target states, i.e. the states shifted one step ahead
    done: np.array of shape (B,T), type bool
    episode_mask, n_steps_mask_b: int arrays of shape (B,T,T)
        As returned by compute_n_step_rewards
    discount: float

    Return
    ------
    new_states: dictionary with the same keys and shapes as trg_states
    Gamma_V: float array of shape (B,T)
    shifted_done: bool array of shape (B,T)

    Raise
    -----
    IndexError if some step has no valid target (empty window, e.g. n_steps < 1)
    """
    B = done.shape[0]
    V_mask = episode_mask * n_steps_mask_b
    in_window = V_mask != 0
    if not in_window.any(axis=2).all():
        raise IndexError("every step needs at least one valid target index (is n_steps >= 1?)")

    # Last valid column of every row: first hit when scanning the row right-to-left
    V_trg_index = V_mask.shape[2] - 1 - np.argmax(in_window[:, :, ::-1], axis=2)

    # Flat (row, col) pairs to gather along the first two axes.
    # Builtin-sized integer dtype: the `np.int` alias no longer exists in recent NumPy.
    rows = np.repeat(np.arange(B, dtype=np.intp), V_trg_index.shape[1])
    cols = V_trg_index.reshape(-1).astype(np.intp)

    new_states = {k: v[rows, cols].reshape(v.shape) for k, v in trg_states.items()}

    # Number of steps between t and the state whose value is bootstrapped (at least 1)
    pw = V_trg_index - np.arange(V_trg_index.shape[1]) + 1
    Gamma_V = discount**pw
    shifted_done = done[rows, cols].reshape(done.shape)
    return new_states, Gamma_V, shifted_done

In [63]:
# assuming all entries are tensors of shape (B,T,...)
# State as a dictionary of tensors (a single "frame" entry here), built from the same
# `states` array used in the previous checks.
frames = {"frame":torch.tensor(states).float()} # .float() matters: the array is float64, the other tensors in the target are float32
frames

{'frame': tensor([[[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [1.]],
 
         [[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]]])}

In [64]:
class ValueNetFrames(nn.Module):
    """Parameter-free stand-in critic for dictionary states: half of the "frame" entry."""

    def forward(self, x):
        """Return the "frame" entry of `x` scaled by 0.5, element-wise."""
        frame = x["frame"]
        return frame * 0.5


value_net = ValueNetFrames()

In [65]:
# Same check again, with the state given as a dictionary of tensors: the targets must not change.
v_trg = compute_n_step_V_trg(n_steps, discount, rewards, new_done, frames, value_net, device="cpu")
print("Target values (reshaped): \n", v_trg.reshape(B,T))
print("Expected: \n", expected_v_trg)
assert torch.allclose(v_trg.reshape(B,T), expected_v_trg), "Wrong vtarget values"
print("Success")

Target values (reshaped): 
 tensor([[0.9000, 1.0000, 0.0000, 0.0000, 0.8100, 0.9000, 1.0000, 0.3645, 0.4050,
         0.4500],
        [0.0000, 0.8100, 0.9000, 1.0000, 0.0000, 0.8100, 0.9000, 1.0000, 0.0000,
         0.0000]])
Expected: 
 tensor([[0.9000, 1.0000, 0.0000, 0.0000, 0.8100, 0.9000, 1.0000, 0.3645, 0.4050,
         0.4500],
        [0.0000, 0.8100, 0.9000, 1.0000, 0.0000, 0.8100, 0.9000, 1.0000, 0.0000,
         0.0000]])
Success


## How to integrate this with the replay buffer

1. Make sure that everything except frames is a numpy array
2. Use functions above instead of get_cumulative_reward and use them during get_batch instead of batch_episode

In [None]:
### Replay buffer stuff ###
def get_cumulative_rewards(rewards, discount, dones):
    """
    Discounted return of every step, accumulated backwards over the trajectory.

    The running return is reset at every step flagged in `dones`, so that rewards of a
    later episode never leak into an earlier one. Returns a torch tensor with one entry
    per reward, in the original (forward) order.
    """
    returns_backwards = []
    running = 0
    for back, reward in enumerate(reversed(rewards), start=1):
        # 0 at a terminal step: whatever was accumulated after it is dropped
        continues = 1 - dones[-back]
        running = continues*discount*running + reward
        returns_backwards.append(running)
    returns_backwards.reverse()
    return torch.tensor(returns_backwards)

In [4]:
class nStepsReplayBuffer:
    r"""
    FIFO buffer of whole trajectories that serves batches of (state, n-step value target) pairs.

    Every stored trajectory is made of T rewards, T done flags and T+1 frames (the state reached
    after the last reward is needed to bootstrap its value). All the trajectories that are
    batched together must have the same length T.

    Formula of the target (for precisely n-steps):
        V^{(n)}(t) = \sum_{k=0}^{n-1} gamma^k r_{t+k+1} + gamma^n * V(s_{t+n})
    """

    def __init__(self, mem_size, discount):
        """
        mem_size: int
            Maximum number of trajectories kept; the oldest one is dropped first
        discount: float
            Default discount factor, used by get_batch when called with discount=None
        """
        self.mem_size = mem_size
        self.discount = discount
        self.frame_buffer = []
        self.reward_buffer = []
        self.done_buffer = []

    def store_episode(self, frame_lst, reward_lst, done_lst):
        """Batch a trajectory (see batch_episode) and append it, evicting the oldest one if full."""
        frames, rewards, done = self.batch_episode(frame_lst, reward_lst, done_lst)
        self.frame_buffer.append(frames)
        self.reward_buffer.append(rewards)
        self.done_buffer.append(done)
        if len(self.frame_buffer) > self.mem_size:
            self.frame_buffer.pop(0)
            self.reward_buffer.pop(0)
            self.done_buffer.pop(0)

    def batch_episode(self, frame_lst, reward_lst, done_lst):
        """
        Unifies the time dimension of the data and adds a batch dimension of 1 in front

        Input
        -----
        frame_lst: list of T+1 dictionaries of tensors
            Every tensor is concatenated along its first axis, which is presumably a time
            axis of size 1 - verify against the caller
        reward_lst: list of T floats
        done_lst: list of T booleans

        Return
        ------
        frames: dictionary of tensors of shape (1,T+1,...)
        rewards: float np.array of shape (1,T)
        done: bool np.array of shape (1,T)

        Raise
        -----
        ValueError if the lengths of the three lists are not T+1, T and T
        """
        episode_len = len(reward_lst)
        if len(done_lst) != episode_len:
            raise ValueError("reward_lst and done_lst must have the same length")
        if len(frame_lst) != episode_len + 1:
            # The last frame is the state whose value is bootstrapped at the end of the trajectory
            raise ValueError("frame_lst must contain one more element than reward_lst")

        frames = {
            k: torch.cat([frame[k] for frame in frame_lst], dim=0).unsqueeze(0)
            for k in frame_lst[0]
        }
        # Builtin float/bool as dtypes: the `np.float` / `np.bool` aliases were removed from NumPy
        rewards = np.array(reward_lst, dtype=float).reshape(1, -1)
        done = np.array(done_lst, dtype=bool).reshape(1, -1)

        return frames, rewards, done

    def get_batch(self, batch_size, n_steps, discount, target_net, device="cpu"):
        """
        Sample `batch_size` distinct trajectories and compute their n-step value targets.

        Input
        -----
        batch_size: int
        n_steps: int
        discount: float or None
            None falls back to the discount given at construction
        target_net: callable used to evaluate the target states
        device: str

        Return
        ------
        reshaped_frames: dictionary of tensors of shape (B*T,...)
            The states s_t the targets refer to (the last frame of every trajectory is only
            used for bootstrapping and is not returned)
        sampled_targets: torch tensor of shape (B*T,)
        """
        if discount is None:
            discount = self.discount

        # Decide which indexes to sample
        id_range = len(self.frame_buffer)
        assert id_range >= batch_size, "Not enough samples stored to get this batch size"
        sampled_ids = np.random.choice(id_range, size=batch_size, replace=False)

        # Stored items already carry a leading batch axis of size 1: concatenate along it
        # (stacking them would produce a (B,1,T) array)
        sampled_rewards = np.concatenate([self.reward_buffer[i] for i in sampled_ids], axis=0)
        sampled_done = np.concatenate([self.done_buffer[i] for i in sampled_ids], axis=0)
        sampled_frames = {
            k: torch.cat([self.frame_buffer[i][k] for i in sampled_ids], dim=0)
            for k in self.frame_buffer[0]
        }

        # sampled_targets of shape (B*T,)
        sampled_targets = self.compute_n_step_V_trg(n_steps, discount, sampled_rewards, sampled_done,
                                                    sampled_frames, target_net, device)

        # Flatten the first T frames of every trajectory, so that row i matches target i
        reshaped_frames = {}
        for k, v in sampled_frames.items():
            current = v[:, :-1]
            reshaped_frames[k] = current.reshape(-1, *current.shape[2:]).to(device)

        return reshaped_frames, sampled_targets

    def compute_n_step_V_trg(self, n_steps, discount, rewards, done, states, value_net, device="cpu"):
        r"""
        Compute m-steps value target, with m = min(n_steps, steps-to-episode-end).
        Formula (for precisely n-steps):
            V^{(n)}(t) = \sum_{k=0}^{n-1} gamma^k r_{t+k+1} + gamma^n * V(s_{t+n})

        Input
        -----
        n_steps: int
            How many steps in the future to consider before bootstrapping while computing the value target
        discount: float in (0,1)
            Discount factor of the MDP
        rewards: np.array of shape (B,T), type float
        done: np.array of shape (B,T), type bool
            True only at real episode ends. Not modified.
        states: dictionary of tensors all of shape (B,T+1,...)
        value_net: instance of nn.Module
            outputs values of shape (B*T,) or (B*T,1) given states reshaped as (B*T,...)
        device: str
            Device on which all the tensors (target states included) are placed

        Return
        ------
        V_trg: torch tensor of shape (B*T,)
        """
        # A done at the end of every trajectory is needed only to delimit the episodes when
        # summing the rewards; the original `done` decides whether to bootstrap.
        done_plus_ending = done.copy()
        done_plus_ending[:, -1] = True
        n_step_rewards, episode_mask, n_steps_mask_b = self.compute_n_step_rewards(
            rewards, done_plus_ending, n_steps, discount
        )

        trg_states = {k: v[:, 1:] for k, v in states.items()}
        new_states, Gamma_V, shifted_done = self.compute_n_step_states(
            trg_states, done, episode_mask, n_steps_mask_b, discount
        )

        new_states_reshaped = {
            k: v.reshape((-1,) + tuple(v.shape[2:])).to(device) for k, v in new_states.items()
        }
        done_t = torch.LongTensor(shifted_done.astype(int)).to(device).reshape(-1)
        n_step_rewards = torch.tensor(n_step_rewards).float().to(device).reshape(-1)
        Gamma_V = torch.tensor(Gamma_V).float().to(device).reshape(-1)

        with torch.no_grad():
            # reshape(-1) instead of squeeze(): a single-element batch keeps its (1,) shape
            V_pred = value_net(new_states_reshaped).reshape(-1)
            # No bootstrapping where the target step is terminal
            V_trg = (1 - done_t) * Gamma_V * V_pred + n_step_rewards
        return V_trg

    def compute_n_step_rewards(self, rewards, done, n_steps, discount):
        """
        Computes n-steps discounted reward.
        Note: the rewards considered are AT MOST n, but can be less for the last n-1 elements
        of every episode.

        Return
        ------
        n_steps_r: float array of shape (B,T)
        episode_mask: int array of shape (B,T,T), entry [b,i,j] is 1 if steps i and j of
            sample b belong to the same episode
        n_steps_mask_b: int array of shape (B,T,T), entry [b,i,j] is 1 if i <= j < i + n_steps
        """
        B, T = done.shape[0], done.shape[1]

        # Episode ends, plus one artificial end at the last step of every sample so that
        # every step is assigned to some episode
        xs, ys = np.nonzero(done)
        xs = np.concatenate([xs, np.arange(B)])
        ys = np.concatenate([ys, np.full(B, T - 1)])

        episode_mask = [[] for _ in range(B)]
        last = [-1] * B
        for x, y in zip(xs, ys):
            # Same membership row for all the steps in (last[x], y]
            m = [int(last[x] < i <= y) for i in range(T)]
            episode_mask[x].extend(m for _ in range(y - last[x]))
            last[x] = y
        episode_mask = np.array(episode_mask)

        # Window of n_steps columns starting on the diagonal, tiled over the batch
        t_idx = np.arange(T)
        ahead = t_idx[np.newaxis, :] - t_idx[:, np.newaxis]
        n_steps_mask = ((ahead >= 0) & (ahead < n_steps)).astype(int)
        n_steps_mask_b = np.tile(n_steps_mask, (B, 1, 1))

        # Discount every reward by gamma**j, keep the ones inside episode and window,
        # then rescale row i by gamma**i so that the first reward of the window has weight 1
        Gamma = np.array([discount**i for i in range(T)]).reshape(1, -1)
        weighted = Gamma * rewards[:, np.newaxis, :] * episode_mask * n_steps_mask_b
        n_steps_r = weighted.sum(axis=2) / Gamma
        return n_steps_r, episode_mask, n_steps_mask_b

    def compute_n_step_states(self, trg_states, done, episode_mask, n_steps_mask_b, discount):
        """
        Computes n-steps target states (to be used by the critic as target values together with the
        n-steps discounted reward). For every step the target is the last step that is both inside
        the n-steps window and inside the same episode.
        Also gathers the `done` flag at the target step (used for disabling the bootstrapping in
        the case of terminal states) and returns Gamma_V, the discount factors for the target
        state-values: gamma**m with m the number of steps between a step and its target state.

        Return
        ------
        new_states: dictionary with the same keys and shapes as trg_states
        Gamma_V: float array of shape (B,T)
        shifted_done: bool array of shape (B,T)

        Raise
        -----
        IndexError if some step has no valid target (empty window, e.g. n_steps < 1)
        """
        B = done.shape[0]
        V_mask = episode_mask * n_steps_mask_b
        in_window = V_mask != 0
        if not in_window.any(axis=2).all():
            raise IndexError("every step needs at least one valid target index (is n_steps >= 1?)")

        # Last valid column of every row: first hit when scanning the row right-to-left
        V_trg_index = V_mask.shape[2] - 1 - np.argmax(in_window[:, :, ::-1], axis=2)

        # Flat (row, col) pairs to gather along the first two axes.
        # Builtin-sized integer dtype: the `np.int` alias no longer exists in recent NumPy.
        rows = np.repeat(np.arange(B, dtype=np.intp), V_trg_index.shape[1])
        cols = V_trg_index.reshape(-1).astype(np.intp)

        new_states = {k: v[rows, cols].reshape(v.shape) for k, v in trg_states.items()}

        # Number of steps between t and the state whose value is bootstrapped (at least 1)
        pw = V_trg_index - np.arange(V_trg_index.shape[1]) + 1
        Gamma_V = discount**pw
        shifted_done = done[rows, cols].reshape(done.shape)
        return new_states, Gamma_V, shifted_done

In [None]:
def compute_update_v1(value_net, frames, targets, loss_fn, optimizer):
    """
    Run one optimization step of the critic on a batch and return the loss as a float.

    Input
    -----
    value_net: instance of nn.Module
        Must output values of shape (N,1) given `frames`
    frames: batch of states accepted by value_net
    targets: torch tensor of shape (N,)
    loss_fn: callable taking (values, targets) and returning a scalar tensor
    optimizer: torch optimizer over the parameters of value_net
    """
    # Evaluate the `frames` argument (the original body read an undefined `reshaped_frames`)
    values = value_net(frames).squeeze(1)

    loss = loss_fn(values, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()

In [3]:
# Same reward pattern as the toy batch, rebuilt from a plain list.
# No dtype is given, so the array takes the type of the literals (see the output below).
r = [
    [0,1,0,0,0,0,1,0,0,0],
    [0,0,0,1,0,0,0,1,0,0]
]

rewards = np.array(r)
rewards

array([[0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
       [0, 0, 0, 1, 0, 0, 0, 1, 0, 0]])