In [1]:
import numpy as np

import torch 
import torch.nn as nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
grid_sz = 28

world = torch.zeros(grid_sz,grid_sz)
pos_gen = lambda grid_sz : torch.tensor((torch.randint(0,grid_sz,(1,)),torch.randint(0,grid_sz,(1,))))
start_pos = pos_gen(grid_sz)
end_pos = pos_gen(grid_sz)
start_pos, end_pos

(tensor([18,  6]), tensor([4, 1]))

In [3]:
# calculate actions from start to end, action_space: 0:up, 1:down, 2:left, 3:right
def get_actions(start_pos, end_pos):
    actions = []
    grid_positions = []
    i_dif = end_pos[0] - start_pos[0]
    j_dif = end_pos[1] - start_pos[1]

    for i in range(abs(i_dif)+1):
        if i_dif < 0:
            actions.append(0)
            grid_positions.append(torch.tensor((start_pos[0]-i,start_pos[1])))
        elif i_dif > 0:
            actions.append(1)
            grid_positions.append(torch.tensor((start_pos[0]+i,start_pos[1])))
    
    for j in range(abs(j_dif)+1):
        grid_i = grid_positions[-1][0] if len(grid_positions)>0 else start_pos[0]
        if j_dif < 0:
            actions.append(2)
            grid_positions.append(torch.tensor((grid_i,start_pos[1]-j)))
        elif j_dif > 0:
            actions.append(3)
            grid_positions.append(torch.tensor((grid_i,start_pos[1]+j)))
    
    return torch.tensor((actions)), torch.stack(grid_positions)

In [4]:
actions, grid_positions = get_actions(start_pos.clone(), end_pos)
start_pos, end_pos, grid_positions

(tensor([18,  6]),
 tensor([4, 1]),
 tensor([[18,  6],
         [17,  6],
         [16,  6],
         [15,  6],
         [14,  6],
         [13,  6],
         [12,  6],
         [11,  6],
         [10,  6],
         [ 9,  6],
         [ 8,  6],
         [ 7,  6],
         [ 6,  6],
         [ 5,  6],
         [ 4,  6],
         [ 4,  6],
         [ 4,  5],
         [ 4,  4],
         [ 4,  3],
         [ 4,  2],
         [ 4,  1]]))

In [5]:
# get batch of data 
def get_batch(grid_sz, batch_size):
    acts, grid_pos = [], []
    start_positions, end_positions = [], []
    for i in range(batch_size):
        start_pos = pos_gen(grid_sz)
        end_pos = pos_gen(grid_sz)
        # ensure start != end_pos
        while start_pos[0] == end_pos[0] and start_pos[1] == end_pos[1]:
            end_pos = pos_gen(grid_sz)
        actions, grid_positions = get_actions(start_pos.clone(), end_pos)
        acts.append(actions)
        grid_pos.append(grid_positions)
        start_positions.append(start_pos)
        end_positions.append(end_pos)
    
    return torch.stack(start_positions), torch.stack(end_positions), acts, grid_pos

In [6]:
def make_grids(start_positions, end_positions, grid_sz, b_sz):
    grid = torch.zeros(b_sz, grid_sz, grid_sz)
    b_idx = torch.arange(b_sz)
    grid[b_idx, start_positions[:,0], start_positions[:,1]] = 0.5
    grid[b_idx, end_positions[:,0], end_positions[:,1]] = 1.0
    return grid

def make_grid(positions, grid_sz, b_sz):
    grid = torch.zeros(b_sz, grid_sz, grid_sz)
    b_idx = torch.arange(b_sz)
    grid[b_idx, positions[:,0], positions[:,1]] = 1.0
    return grid

In [7]:
def sample_s_sn_act(grid_positions, actions, b_sz):
    # sample state and next state and act along path toward end_pos 
    sampled_states, next_states, sampled_acts, end_positions = [], [], [], []
    for i in range(b_sz):
        if len(actions[i]) == 1:
            continue 
        idx = torch.randint(0, len(actions[i])-1, (1,))
        sampled_states.append(grid_positions[i][idx])
        next_states.append(grid_positions[i][idx+1])
        sampled_acts.append(actions[i][idx])
        end_positions.append(grid_positions[i][-1])
    sampled_states = torch.stack(sampled_states).squeeze(1)
    next_states = torch.stack(next_states).squeeze(1)
    sampled_acts = torch.stack(sampled_acts)
    end_positions = torch.stack(end_positions)
    return sampled_states, next_states, sampled_acts, end_positions

In [22]:
b_sz = 32
start_positions, end_positions, actions, grid_positions = get_batch(grid_sz, b_sz)
start_positions.shape, end_positions.shape, len(actions), len(grid_positions)
sampled_states, next_states, sampled_acts, end_positions_ = sample_s_sn_act(grid_positions, actions, b_sz)
sampled_states.shape, next_states.shape, sampled_acts.shape, end_positions.shape
grids_input = make_grids(sampled_states, end_positions, grid_sz, b_sz)

In [23]:
# map grid pos to 1d index
def grid_pos_to_idx(pos, grid_sz):
    return pos[:,0]*grid_sz + pos[:,1]

# map 1d index to grid pos
def idx_to_grid_pos(idx, grid_sz):
    return torch.stack((idx//grid_sz, idx%grid_sz)).T

In [None]:
# simulate next state, when given a state and action 
def simulate_next_state(states, action, grid_sz):
    # states: (b_sz, 2)
    # action: (b_sz, 1)
    # return: (b_sz, 2)
        
    # act_idx to 2d tensor 0: up, 1: down, 2: left, 3: right
    action_2d = torch.zeros((action.shape[0], 2))
    action_2d[action == 0, 0] = -1
    action_2d[action == 1, 0] = 1
    action_2d[action == 2, 1] = -1
    action_2d[action == 3, 1] = 1
    
    print(action_2d.shape, states.shape)
    next_states = states + action_2d
    next_states = torch.clamp(next_states, 0, grid_sz-1)
    return next_states

In [24]:
idxs = grid_pos_to_idx(next_states, grid_sz)
next_states_ = idx_to_grid_pos(idxs, grid_sz)
(next_states == next_states_).all(), idxs[1]

(tensor(True), tensor(6))

In [103]:
act_dim = 4
grid_sz = 10
# Train step 
sampled_states, next_states, sampled_acts, end_positions_ = sample_s_sn_act(grid_positions, actions, b_sz)
grids_input = make_grids(sampled_states, end_positions_, grid_sz, b_sz).unsqueeze(1)
targets = grid_pos_to_idx(next_states, grid_sz)
grids_input.shape, targets.shape

(torch.Size([32, 1, 10, 10]), torch.Size([32]))

In [104]:
dim = 64
state_embedding = nn.Embedding(grid_sz*grid_sz, dim)

state_encoder = nn.Sequential(
    nn.Flatten(),
    nn.Linear(grid_sz*grid_sz, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, dim),
)

state_predictor = nn.Sequential(
    nn.Linear(dim+dim, 64),
    nn.ReLU(),
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.Linear(128, dim),
)

action_predictor = nn.Sequential(
    nn.Linear(dim+dim, 16),
    nn.ReLU(),
    nn.Linear(16, 16),
    nn.ReLU(),
    nn.Linear(16, act_dim),
)

In [105]:
grids_input = make_grid(sampled_states, grid_sz, b_sz)
sg = state_encoder(grids_input)
# s_next_start_idx_pred = state_predictor(sg)
# # s_next_start preds to next state predictions 
# s_next_start_idx = torch.argmax(s_next_start_idx_pred, dim=1)
# s_next_start = idx_to_grid_pos(s_next_start_idx, grid_sz)
# s_next_grid = make_grids(s_next_start, end_positions_, grid_sz, b_sz).unsqueeze(1)
# sg_next = state_goal_encoder(s_next_grid)

# a = action_predictor(torch.cat((sg, sg_next), dim=1))
sg.shape#, s_next_start_idx.shape, sg_next.shape, a.shape

torch.Size([32, 64])

In [106]:
# Train encoder decoder
def train_encoder_decoder_step(encoder, state_embedding, optim, grid_sz, b_sz, epochs):
    for e in epochs:
        start_positions, end_positions, actions, grid_positions = get_batch(grid_sz, b_sz)
        sampled_states, next_states, sampled_acts, end_positions_ = sample_s_sn_act(grid_positions, actions, b_sz)
        grids_input = make_grid(sampled_states, grid_sz, b_sz).unsqueeze(1)
        sampled_states_idx = grid_pos_to_idx(sampled_states, grid_sz)

        optim.zero_grad()
        enc = encoder(grids_input)
        logits = enc @ state_embedding.weight.T
        
        loss = F.cross_entropy(logits, sampled_states_idx)
        loss.backward()
        optim.step()

        acc = (torch.argmax(logits, dim=1) == sampled_states_idx).float().mean()

        print(f'Epoch: {e}, Loss: {loss.item()}, Acc: {acc.item()}')

optim = torch.optim.Adam(list(state_encoder.parameters()) + list(state_embedding.parameters()), lr=1e-3)
train_encoder_decoder_step(state_encoder, state_embedding, optim, grid_sz, b_sz, range(300))

Epoch: 0, Loss: 4.8872880935668945, Acc: 0.0
Epoch: 1, Loss: 4.923231601715088, Acc: 0.0
Epoch: 2, Loss: 4.7670135498046875, Acc: 0.03125
Epoch: 3, Loss: 4.667690753936768, Acc: 0.03125
Epoch: 4, Loss: 4.724980354309082, Acc: 0.0
Epoch: 5, Loss: 4.738255023956299, Acc: 0.0
Epoch: 6, Loss: 4.664534568786621, Acc: 0.0
Epoch: 7, Loss: 4.786640167236328, Acc: 0.03125
Epoch: 8, Loss: 4.672857284545898, Acc: 0.0
Epoch: 9, Loss: 4.635302543640137, Acc: 0.03125
Epoch: 10, Loss: 4.679807186126709, Acc: 0.0
Epoch: 11, Loss: 4.508739948272705, Acc: 0.03125
Epoch: 12, Loss: 4.64932918548584, Acc: 0.0
Epoch: 13, Loss: 4.555769443511963, Acc: 0.0
Epoch: 14, Loss: 4.496523857116699, Acc: 0.03125
Epoch: 15, Loss: 4.536107063293457, Acc: 0.03125
Epoch: 16, Loss: 4.781723499298096, Acc: 0.0
Epoch: 17, Loss: 4.590867042541504, Acc: 0.0
Epoch: 18, Loss: 4.653238296508789, Acc: 0.0
Epoch: 19, Loss: 4.415218353271484, Acc: 0.03125
Epoch: 20, Loss: 4.534533977508545, Acc: 0.0
Epoch: 21, Loss: 4.4090070724487

In [108]:
# save state_encoder 
torch.save(state_encoder.state_dict(), 'state_encoder.pt')
torch.save(state_embedding.state_dict(), 'state_embedding.pt')

In [94]:
s = state_encoder(grids_input)
logits = s @ state_embedding.weight.T
pred_s = torch.argmax(logits, dim=1)
sampled_states_idx = grid_pos_to_idx(sampled_states, grid_sz)
(pred_s == sampled_states_idx).float().mean()

tensor(1.)

In [107]:
# Train state predictor
def train_state_predictor_step(encoder, state_predictor, state_embedding, optim, grid_sz, b_sz, epochs):
    for e in epochs: 
        start_positions, end_positions, actions, grid_positions = get_batch(grid_sz, b_sz)
        sampled_states, next_states, sampled_acts, end_positions_ = sample_s_sn_act(grid_positions, actions, b_sz)
        s_grids_input = make_grid(sampled_states, grid_sz, b_sz).unsqueeze(1)
        g_grids_input = make_grid(end_positions_, grid_sz, b_sz).unsqueeze(1)

        targets = grid_pos_to_idx(next_states, grid_sz)

        optim.zero_grad()
        with torch.no_grad():
            s = encoder(s_grids_input)
            s_idx = torch.argmax(s @ state_embedding.weight.T, dim=1)
            s_tokens = state_embedding(s_idx)
            g = encoder(g_grids_input)
            g_idx = torch.argmax(g @ state_embedding.weight.T, dim=1)
            g_tokens = state_embedding(g_idx)
        s_next_pred = state_predictor(torch.cat((s_tokens, g_tokens), dim=1))
        logits = s_next_pred @ state_embedding.weight.T
        loss = F.cross_entropy(logits, targets)
        loss.backward()
        optim.step()

        # accuracy 
        acc = (torch.argmax(logits, dim=1) == targets).float().mean()

        print(f'Epoch: {e}, Loss: {loss.item()}, Acc: {acc.item()}')
optim = torch.optim.Adam(list(state_predictor.parameters()), lr=1e-3)
train_state_predictor_step(state_encoder, state_predictor, state_embedding, optim, grid_sz, b_sz, range(3000))

Epoch: 0, Loss: 5.076756477355957, Acc: 0.0
Epoch: 1, Loss: 4.919412612915039, Acc: 0.0
Epoch: 2, Loss: 4.929455757141113, Acc: 0.03125
Epoch: 3, Loss: 4.898471832275391, Acc: 0.0
Epoch: 4, Loss: 4.580976963043213, Acc: 0.125
Epoch: 5, Loss: 4.871441841125488, Acc: 0.03125
Epoch: 6, Loss: 4.670561790466309, Acc: 0.03125
Epoch: 7, Loss: 5.0955305099487305, Acc: 0.03125
Epoch: 8, Loss: 4.646744728088379, Acc: 0.0
Epoch: 9, Loss: 4.495335102081299, Acc: 0.0
Epoch: 10, Loss: 4.7847771644592285, Acc: 0.03125
Epoch: 11, Loss: 4.506476402282715, Acc: 0.03125
Epoch: 12, Loss: 4.66901969909668, Acc: 0.03125
Epoch: 13, Loss: 4.503342628479004, Acc: 0.0625
Epoch: 14, Loss: 4.592710494995117, Acc: 0.03125
Epoch: 15, Loss: 4.5716872215271, Acc: 0.0625
Epoch: 16, Loss: 4.649235248565674, Acc: 0.09375
Epoch: 17, Loss: 4.446005821228027, Acc: 0.0625
Epoch: 18, Loss: 4.638891696929932, Acc: 0.03125
Epoch: 19, Loss: 4.4446868896484375, Acc: 0.0625
Epoch: 20, Loss: 4.3152289390563965, Acc: 0.0625
Epoch: 

In [109]:
# save state_predictor
torch.save(state_predictor.state_dict(), 'state_predictor.pt')

In [None]:
# Train action predictor
def train_action_predictor_step(encoder, state_predictor, action_predictor, state_embedding, optim, grid_sz, b_sz, epochs):
    for e in epochs:
        start_positions, end_positions, actions, grid_positions = get_batch(grid_sz, b_sz)
        sampled_states, next_states, sampled_acts, end_positions_ = sample_s_sn_act(grid_positions, actions, b_sz)
        s_grids_input = make_grid(sampled_states, grid_sz, b_sz).unsqueeze(1)
        g_grids_input = make_grid(end_positions_, grid_sz, b_sz).unsqueeze(1)

        optim.zero_grad()
        with torch.no_grad():
            s = encoder(s_grids_input)
            g = encoder(g_grids_input)
            s_next_pred = state_predictor(torch.cat((s, g), dim=1))
            s_next_pred_logits = s_next_pred @ state_embedding.weight.T
            targets = torch.argmax(s_next_pred_logits, dim=1)

        act_pred = action_predictor(torch.cat((s, s_next_pred), dim=1))
        
        with torch.no_grad():
            next_sim_states = simulate_next_state(sampled_states, torch.argmax(act_pred,dim=1), grid_sz).long()
            next_sim_states = encoder(make_grid(next_sim_states, grid_sz, b_sz).unsqueeze(1))
            next_sim_logits = next_sim_states @ state_embedding.weight.T
        
        loss = F.cross_entropy(next_sim_logits, targets)
        loss.backward()
        optim.step()

        # accuracy
        acc = (torch.argmax(next_sim_logits, dim=1) == targets).float().mean()

        print(f'Epoch: {e}, Loss: {loss.item()}, Acc: {acc.item()}')

optim = torch.optim.Adam(list(action_predictor.parameters()), lr=1e-3)
train_action_predictor_step(state_encoder, state_predictor, action_predictor, state_embedding, optim, grid_sz, b_sz, range(3000))

torch.Size([32, 2]) torch.Size([32, 2])


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [79]:
act_logits = action_predictor(torch.cat((s, s), dim=1))
acts = torch.argmax(act_logits, dim=1)

In [72]:
next_sim_states = simulate_next_state(sampled_states, acts, grid_sz)

torch.Size([32, 2]) torch.Size([32, 2])


In [74]:
sampled_states[:5], next_sim_states[:5]

(tensor([[0, 4],
         [3, 6],
         [5, 1],
         [2, 7],
         [2, 3]]),
 tensor([[1., 4.],
         [4., 6.],
         [6., 1.],
         [3., 7.],
         [3., 3.]]))

torch.Size([32])

In [162]:
# Load data
obs_data = np.load('obs_data_cart.npy')
next_obs_data = np.load('next_obs_data_cart.npy')


In [163]:
obs_data[:50]

array([[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00],
       [-3.79725099e-02,  2.27987677e-01,  1.60504263e-02,
        -3.10386598e-01],
       [-3.34127583e-02,  3.26407701e-02,  9.84269381e-03,
        -1.26853632e-02],
       [-3.27599421e-02,  2.27620184e-01,  9.58898570e-03,
        -3.02246630e-01],
       [-2.82075368e-02,  3.23628969e-02,  3.54405376e-03,
        -6.55502686e-03],
       [-2.75602788e-02,  2.27433845e-01,  3.41295311e-03,
        -2.98117667e-01],
       [-2.30116025e-02,  3.22634056e-02, -2.54940009e-03,
        -4.36030375e-03],
       [-2.23663338e-02,  2.27421820e-01, -2.63660611e-03,
        -2.97846496e-01],
       [-1.78178977e-02,  3.23375575e-02, -8.59353598e-03,
        -5.99628314e-03],
       [-1.71711463e-02, -1.62660107e-01, -8.71346146e-03,
         2.83962935e-01],
       [-2.04243492e-02,  3.25850360e-02, -3.03420308e-03,
        -1.14553766e-02],
       [-1.97726488e-02,  2.27750376e-01, -3.26331053e-03,
      

In [164]:
next_obs_data

array([[-0.03797251,  0.22798768,  0.01605043, -0.3103866 ],
       [-0.03341276,  0.03264077,  0.00984269, -0.01268536],
       [-0.03275994,  0.22762018,  0.00958899, -0.30224663],
       ...,
       [-0.9658    , -0.32799113,  0.02331792,  0.29227886],
       [-0.97235984, -0.13320926,  0.0291635 ,  0.00704035],
       [ 0.        ,  0.        ,  0.        ,  0.        ]],
      dtype=float32)

In [143]:
next_obs_data.shape

(250000, 4)