# Trajectory replay

## Initial setup

In [3]:
import sys
sys.path.insert(0, "../python/")
from memory.ReplayMemory import ReplayMemory
import numpy as np
import random
import copy

## NumPy indexing

In [19]:
# Demonstrates the gather used by TrajectoryReplayMemory.get_sample:
# 1. draw random start indices,
# 2. expand each start into a run of consecutive indices,
# 3. wrap indices that run past the end of the buffer,
# 4. reshape the gathered rows into [trajectory, step, state_shape...].
mem_capacity = 100          # number of slots in the mock memory
num_traj = 7                # trajectories per sample
traj_len = 5                # consecutive transitions per trajectory
frame_shape = (12, 84, 84)  # screen component shape of one state

idx = np.random.randint(0, mem_capacity, num_traj)
print("idx:", idx)

# Row j of the broadcast sum is [idx[j], idx[j] + 1, ..., idx[j] + traj_len - 1].
# Flattening in row-major order keeps each trajectory's indices contiguous.
slices = (idx[:, np.newaxis] + np.arange(traj_len)).flatten()
print("Slices:\n", slices)
slices %= mem_capacity  # wrap indices that ran past the end of the buffer
print("Corrected slices:\n", slices)

s1_0 = np.random.rand(mem_capacity, *frame_shape)  # mimics memory with capacity mem_capacity
s1_1 = np.random.rand(2)
s1 = [s1_0, s1_1]

# Fancy indexing copies the selected rows, so gather only once.
gathered = s1[0][slices]
print("Original:")
print("s1[0].shape:", s1[0].shape)
print("s1[0][slices][0, ...]:")
print(gathered[0, 0, :3, :3])
# Flat row traj_len is the first step of the second trajectory.
print("s1[0][slices][{}, ...]:".format(traj_len))
print(gathered[traj_len, 0, :3, :3])

# [num of traj, traj len, state_shape...]
rs = np.reshape(gathered, (num_traj, traj_len) + frame_shape)
print("Reshape:")
print("s1[0][slices].shape:", rs.shape)
print("s1[0][slices][0][...]:")
print(rs[0, 0, 0, :3, :3])  # equals flat row 0
print("s1[0][slices][1][...]:")
print(rs[1, 0, 0, :3, :3])  # equals flat row traj_len

idx: [58  8 24 44 72 70 19]
Slices:
 [58 59 60 61 62  8  9 10 11 12 24 25 26 27 28 44 45 46 47 48 72 73 74 75 76
 70 71 72 73 74 19 20 21 22 23]
Corrected slices:
 [58 59 60 61 62  8  9 10 11 12 24 25 26 27 28 44 45 46 47 48 72 73 74 75 76
 70 71 72 73 74 19 20 21 22 23]
Original:
s1[0].shape: (100, 12, 84, 84)
s1[0][slices][0, ...]:
[[ 0.53546728  0.34858874  0.99254931]
 [ 0.98171926  0.22050952  0.32387461]
 [ 0.50426754  0.38082718  0.53164911]]
s1[0][slices][5, ...]:
[[ 0.90396477  0.21205504  0.47445172]
 [ 0.63509333  0.78279289  0.96300232]
 [ 0.17780579  0.22307001  0.1404238 ]]
Reshape:
s1[0][slices].shape: (7, 5, 12, 84, 84)
s1[0][slices][0][...]:
[[ 0.53546728  0.34858874  0.99254931]
 [ 0.98171926  0.22050952  0.32387461]
 [ 0.50426754  0.38082718  0.53164911]]
s1[0][slices][1][...]:
[[ 0.90396477  0.21205504  0.47445172]
 [ 0.63509333  0.78279289  0.96300232]
 [ 0.17780579  0.22307001  0.1404238 ]]


## Class definition

In [42]:
class TrajectoryReplayMemory(ReplayMemory):
    """Replay memory that samples runs of consecutive transitions.

    Storage is inherited from ReplayMemory. Instead of independent transitions,
    get_sample returns ``sample_size`` trajectories of ``trajectory_length``
    consecutive buffer slots each.
    """

    def __init__(self, capacity, state_shape, num_game_var, input_overlap=0,
                 trajectory_length=5):
        """Create the memory.

        Args:
            capacity: number of transition slots, forwarded to ReplayMemory.
            state_shape: shape of the screen component of a state.
            num_game_var: number of game variables per state.
            input_overlap: forwarded to ReplayMemory; get_sample uses it to
                rebuild full s2 screens from s1 frames.
            trajectory_length: number of consecutive transitions per trajectory.
        """
        # Initialize base replay memory
        ReplayMemory.__init__(self, capacity, state_shape, num_game_var, input_overlap)

        # Initialize trajectory parameters
        self.tr_len = trajectory_length

    def get_sample(self, sample_size):
        """Sample ``sample_size`` trajectories of ``self.tr_len`` steps.

        Returns:
            Tuple ``(s1, a, s2, isterminal, r, w, idx)`` where:

            - ``s1`` and ``s2`` are ``[screens, game_vars]``. Screens are shaped
              ``[sample_size, tr_len] + state_shape``; game vars are shaped
              ``[sample_size, tr_len, num_game_var]``.
            - ``a``, ``isterminal`` and ``r`` are shaped ``[sample_size, tr_len]``.
            - ``w`` holds importance-sampling weights, all ones, with the same
              shape.
            - ``idx`` is the flat array of buffer indices, trajectory-major
              (``sample_size * tr_len`` entries).

        Raises:
            ValueError: if fewer than ``self.tr_len`` transitions are stored.
        """
        t = self.tr_len

        # Pick start indices so that every step of every trajectory lands on a
        # filled slot. This assumes ReplayMemory fills slots 0..size-1 before
        # wrapping around.
        if self.size >= self.capacity:
            # Full buffer: any start is valid; steps wrap past the end.
            high = self.size
        else:
            # Partially filled: never step into unfilled slots [size, capacity).
            high = self.size - t + 1
        if high <= 0:
            raise ValueError(
                "need at least {} stored transitions to sample a trajectory, "
                "have {}".format(t, self.size))
        starts = np.random.randint(0, high, sample_size)

        # Flat index order is [i, i+1, ..., i+t-1, j, j+1, ..., j+t-1, k, ...].
        idx = (starts[:, np.newaxis] + np.arange(t)).flatten()
        idx %= self.capacity  # wrap end cases
        # TODO: find isterminal in sequences and cut short
        # NOTE(review): once the buffer is full, a trajectory may straddle the
        # write position and join the newest and oldest transitions.

        batch_shape = [sample_size, t]
        screen_shape = batch_shape + list(self.state_shape)
        gv_shape = batch_shape + [self.num_game_var]

        # Screen component
        s1_screens = self.s1[0][idx]
        if self.overlap > 0:
            # Stack overlapping frames from s1 onto the stored frames of s2 to
            # recreate the full s2 state. The index must be a tuple: NumPy
            # >= 1.23 treats a list as a single array index, not a
            # multidimensional one.
            overlap_index = ((idx,) + (slice(None),) * self.chdim
                             + (slice(None, self.overlap),))
            s2_screens = np.concatenate((self.s1[0][overlap_index],
                                         self.s2[0][idx]),
                                        axis=self.chdim + 1)
        else:
            s2_screens = self.s2[0][idx]

        # [screen[traj id, traj step, state_shape...], gv[traj id, traj step, num_gv]]
        s1_sample = [np.reshape(s1_screens, screen_shape),
                     np.reshape(self.s1[1][idx], gv_shape)]
        s2_sample = [np.reshape(s2_screens, screen_shape),
                     np.reshape(self.s2[1][idx], gv_shape)]

        # Other transition parameters
        a_sample = np.reshape(self.a[idx], batch_shape)
        isterminal_sample = np.reshape(self.isterminal[idx], batch_shape)
        r_sample = np.reshape(self.r[idx], batch_shape)

        # Importance sampling weights of one (stochastic distribution)
        w = np.ones(batch_shape)

        return s1_sample, a_sample, s2_sample, isterminal_sample, r_sample, w, idx

In [43]:
# Fill a TrajectoryReplayMemory with synthetic transitions. Each next state
# drops the oldest `input_overlap` screen frames of the current state and
# appends the same number of fresh random frames.
capacity = 100
state_shape = [12, 84, 84]
num_game_var = 2
input_overlap = 3
trajectory_length = 10
memory = TrajectoryReplayMemory(capacity=capacity,
                                state_shape=state_shape,
                                num_game_var=num_game_var,
                                input_overlap=input_overlap,
                                trajectory_length=trajectory_length)
# A set makes the per-step terminal check O(1).
terminal_states = set(random.sample(range(capacity), 7))
# Independent values per game variable, so axis mix-ups in sampling show up.
s2 = [np.random.rand(*state_shape),
      [random.random() for _ in range(num_game_var)]]
for i in range(capacity):
    s1 = s2
    new_frames = np.random.rand(input_overlap, *state_shape[1:])
    # Build a fresh list rather than mutating the one passed to add_transition.
    s2 = [np.concatenate((s1[0][input_overlap:], new_frames), axis=0),
          [random.random() for _ in range(num_game_var)]]
    a = random.randrange(4)
    r = random.random()
    isterminal = i in terminal_states
    memory.add_transition(s1, a, s2, isterminal, r)

In [48]:
s1, a, s2, isterminal, r, w, idx = memory.get_sample(7)
print("idx:", idx)
# Print the shape of every component of the sampled trajectory batch.
components = (
    ("s1[0]", s1[0]),
    ("s1[1]", s1[1]),
    ("a", a),
    ("s2[0]", s2[0]),
    ("s2[1]", s2[1]),
    ("isterminal", isterminal),
    ("r", r),
    ("w", w),
)
for label, component in components:
    print(label + ".shape:", component.shape)

idx: [86 87 88 89 90 91 92 93 94 95 82 83 84 85 86 87 88 89 90 91  2  3  4  5  6
  7  8  9 10 11 17 18 19 20 21 22 23 24 25 26  8  9 10 11 12 13 14 15 16 17
 32 33 34 35 36 37 38 39 40 41 90 91 92 93 94 95 96 97 98 99]
s1[0].shape: (7, 10, 12, 84, 84)
s1[1].shape: (7, 10, 2)
a.shape: (7, 10)
s2[0].shape: (7, 10, 12, 84, 84)
s2[1].shape: (7, 10, 2)
isterminal.shape: (7, 10)
r.shape: (7, 10)
w.shape: (7, 10)
