In [87]:
import numpy as np

import torch 
import torch.nn as nn
import torch.nn.functional as F
from vicreg import variance_loss, covariance_loss

In [2]:
grid_sz = 28

world = torch.zeros(grid_sz,grid_sz)
pos_gen = lambda grid_sz : torch.tensor((torch.randint(0,grid_sz,(1,)),torch.randint(0,grid_sz,(1,))))
start_pos = pos_gen(grid_sz)
end_pos = pos_gen(grid_sz)
start_pos, end_pos

(tensor([10, 13]), tensor([13,  1]))

In [3]:
# calculate actions from start to end, action_space: 0:up, 1:down, 2:left, 3:right
def get_actions(start_pos, end_pos):
    actions = []
    grid_positions = []
    i_dif = end_pos[0] - start_pos[0]
    j_dif = end_pos[1] - start_pos[1]

    for i in range(abs(i_dif)+1):
        if i_dif < 0:
            actions.append(0)
            grid_positions.append(torch.tensor((start_pos[0]-i,start_pos[1])))
        elif i_dif > 0:
            actions.append(1)
            grid_positions.append(torch.tensor((start_pos[0]+i,start_pos[1])))
    
    for j in range(abs(j_dif)+1):
        grid_i = grid_positions[-1][0] if len(grid_positions)>0 else start_pos[0]
        if j_dif < 0:
            actions.append(2)
            grid_positions.append(torch.tensor((grid_i,start_pos[1]-j)))
        elif j_dif > 0:
            actions.append(3)
            grid_positions.append(torch.tensor((grid_i,start_pos[1]+j)))
    
    return torch.tensor((actions)), torch.stack(grid_positions)

In [4]:
actions, grid_positions = get_actions(start_pos.clone(), end_pos)
start_pos, end_pos, grid_positions

(tensor([10, 13]),
 tensor([13,  1]),
 tensor([[10, 13],
         [11, 13],
         [12, 13],
         [13, 13],
         [13, 13],
         [13, 12],
         [13, 11],
         [13, 10],
         [13,  9],
         [13,  8],
         [13,  7],
         [13,  6],
         [13,  5],
         [13,  4],
         [13,  3],
         [13,  2],
         [13,  1]]))

In [5]:
# get batch of data 
def get_batch(grid_sz, batch_size):
    acts, grid_pos = [], []
    start_positions, end_positions = [], []
    for i in range(batch_size):
        start_pos = pos_gen(grid_sz)
        end_pos = pos_gen(grid_sz)
        # ensure start != end_pos
        while start_pos[0] == end_pos[0] and start_pos[1] == end_pos[1]:
            end_pos = pos_gen(grid_sz)
        actions, grid_positions = get_actions(start_pos.clone(), end_pos)
        acts.append(actions)
        grid_pos.append(grid_positions)
        start_positions.append(start_pos)
        end_positions.append(end_pos)
    
    return torch.stack(start_positions), torch.stack(end_positions), acts, grid_pos

In [6]:
def make_grids(start_positions, end_positions, grid_sz, b_sz):
    grid = torch.zeros(b_sz, grid_sz, grid_sz)
    b_idx = torch.arange(b_sz)
    grid[b_idx, start_positions[:,0], start_positions[:,1]] = 0.5
    grid[b_idx, end_positions[:,0], end_positions[:,1]] = 1.0
    return grid

def make_grid(positions, grid_sz, b_sz):
    grid = torch.zeros(b_sz, grid_sz, grid_sz)
    b_idx = torch.arange(b_sz)
    grid[b_idx, positions[:,0], positions[:,1]] = 1.0
    return grid

In [7]:
def sample_s_sn_act(grid_positions, actions, b_sz):
    # sample state and next state and act along path toward end_pos 
    sampled_states, next_states, sampled_acts, end_positions = [], [], [], []
    for i in range(b_sz):
        if len(actions[i]) == 1:
            continue 
        idx = torch.randint(0, len(actions[i])-1, (1,))
        sampled_states.append(grid_positions[i][idx])
        next_states.append(grid_positions[i][idx+1])
        sampled_acts.append(actions[i][idx])
        end_positions.append(grid_positions[i][-1])
    sampled_states = torch.stack(sampled_states).squeeze(1)
    next_states = torch.stack(next_states).squeeze(1)
    sampled_acts = torch.stack(sampled_acts)
    end_positions = torch.stack(end_positions)
    return sampled_states, next_states, sampled_acts, end_positions

In [13]:
b_sz = 32
start_positions, end_positions, actions, grid_positions = get_batch(grid_sz, b_sz)
start_positions.shape, end_positions.shape, len(actions), len(grid_positions)
sampled_states, next_states, sampled_acts, end_positions_ = sample_s_sn_act(grid_positions, actions, b_sz)
sampled_states.shape, next_states.shape, sampled_acts.shape, end_positions.shape
grids_input = make_grids(sampled_states, end_positions, grid_sz, b_sz)

In [9]:
# map grid pos to 1d index
def grid_pos_to_idx(pos, grid_sz):
    return pos[:,0]*grid_sz + pos[:,1]

# map 1d index to grid pos
def idx_to_grid_pos(idx, grid_sz):
    return torch.stack((idx//grid_sz, idx%grid_sz)).T

In [138]:
# simulate next state, when given a state and action 
def simulate_next_state(states, action, grid_sz):
    # states: (b_sz, 2)
    # action: (b_sz, 1)
    # return: (b_sz, 2)
        
    # act_idx to 2d tensor 0: up, 1: down, 2: left, 3: right
    action_2d = torch.zeros((action.shape[0], 2))
    action_2d[action == 0, 0] = -1
    action_2d[action == 1, 0] = 1
    action_2d[action == 2, 1] = -1
    action_2d[action == 3, 1] = 1
    # penalzie no op 4 action by doing random action in 
    action_2d[action == 4, 0] = torch.randint(-1, 2, (1,))
    action_2d[action == 4, 1] = torch.randint(-1, 2, (1,))
    
    next_states = states + action_2d
    next_states = torch.clamp(next_states, 0, grid_sz-1)
    return next_states

In [11]:
idxs = grid_pos_to_idx(next_states, grid_sz)
next_states_ = idx_to_grid_pos(idxs, grid_sz)
(next_states == next_states_).all(), idxs[1]

(tensor(True), tensor(682))

In [262]:
act_dim = 4
grid_sz = 10
# Train step 
sampled_states, next_states, sampled_acts, end_positions_ = sample_s_sn_act(grid_positions, actions, b_sz)
grids_input = make_grids(sampled_states, end_positions_, grid_sz, b_sz).unsqueeze(1)
targets = grid_pos_to_idx(next_states, grid_sz)
grids_input.shape, targets.shape

(torch.Size([32, 1, 10, 10]), torch.Size([32]))

In [263]:
dim = 64
state_embedding = nn.Embedding(grid_sz*grid_sz, dim)
act_embedding = nn.Embedding(act_dim+1, dim)

state_encoder = nn.Sequential(
    nn.Flatten(),
    nn.Linear(grid_sz*grid_sz, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, dim),
)

state_predictor = nn.Sequential(
    nn.Linear(dim+dim+dim, 64),
    nn.ReLU(),
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.Linear(128, dim),
)

action_predictor = nn.Sequential(
    nn.Linear(dim+dim, 16),
    nn.ReLU(),
    nn.Linear(16, 16),
    nn.ReLU(),
    nn.Linear(16, dim),
)

In [264]:
grids_input = make_grid(sampled_states, grid_sz, b_sz)
s = state_encoder(grids_input)
g = state_encoder(grids_input)
a = action_predictor(torch.cat((s, g), dim=1))
a_logits = a @ act_embedding.weight.T

no_op_idx = torch.full((b_sz,), act_dim, dtype=torch.long)
no_op_tokens = act_embedding(no_op_idx)
next_states = state_predictor(torch.cat((s, no_op_tokens, g), dim=1))



# a = action_predictor(torch.cat((sg, sg_next), dim=1))
s.shape, a.shape, a_logits.shape, next_states.shape

(torch.Size([32, 64]),
 torch.Size([32, 64]),
 torch.Size([32, 5]),
 torch.Size([32, 64]))

In [265]:
# Train encoder decoder
def train_encoder_decoder_step(encoder, state_embedding, optim, grid_sz, b_sz, epochs):
    for e in epochs:
        start_positions, end_positions, actions, grid_positions = get_batch(grid_sz, b_sz)
        sampled_states, next_states, sampled_acts, end_positions_ = sample_s_sn_act(grid_positions, actions, b_sz)
        grids_input = make_grid(sampled_states, grid_sz, b_sz).unsqueeze(1)
        sampled_states_idx = grid_pos_to_idx(sampled_states, grid_sz)

        optim.zero_grad()
        enc = encoder(grids_input)
        logits = enc @ state_embedding.weight.T
        
        loss = F.cross_entropy(logits, sampled_states_idx)
        loss.backward()
        optim.step()

        acc = (torch.argmax(logits, dim=1) == sampled_states_idx).float().mean()

        print(f'Epoch: {e}, Loss: {loss.item()}, Acc: {acc.item()}')

optim = torch.optim.Adam(list(state_encoder.parameters()) + list(state_embedding.parameters()), lr=1e-3)
train_encoder_decoder_step(state_encoder, state_embedding, optim, grid_sz, b_sz, range(300))

Epoch: 0, Loss: 4.634663105010986, Acc: 0.0
Epoch: 1, Loss: 4.8573808670043945, Acc: 0.0
Epoch: 2, Loss: 4.580707550048828, Acc: 0.0625
Epoch: 3, Loss: 4.804654121398926, Acc: 0.0
Epoch: 4, Loss: 4.697710990905762, Acc: 0.0
Epoch: 5, Loss: 4.675345420837402, Acc: 0.0
Epoch: 6, Loss: 4.697463035583496, Acc: 0.0
Epoch: 7, Loss: 4.574417591094971, Acc: 0.03125
Epoch: 8, Loss: 4.658871650695801, Acc: 0.0
Epoch: 9, Loss: 4.633896350860596, Acc: 0.03125
Epoch: 10, Loss: 4.568658351898193, Acc: 0.03125
Epoch: 11, Loss: 4.454195499420166, Acc: 0.0625
Epoch: 12, Loss: 4.545584201812744, Acc: 0.0
Epoch: 13, Loss: 4.666504383087158, Acc: 0.0
Epoch: 14, Loss: 4.288885116577148, Acc: 0.09375
Epoch: 15, Loss: 4.6952433586120605, Acc: 0.0
Epoch: 16, Loss: 4.407009601593018, Acc: 0.0625
Epoch: 17, Loss: 4.519698619842529, Acc: 0.03125
Epoch: 18, Loss: 4.452103137969971, Acc: 0.03125
Epoch: 19, Loss: 4.430048942565918, Acc: 0.09375
Epoch: 20, Loss: 4.457878112792969, Acc: 0.0625
Epoch: 21, Loss: 4.4635

In [266]:
# save state_encoder 
torch.save(state_encoder.state_dict(), 'state_encoder.pt')
# load state_encoder
state_encoder.load_state_dict(torch.load('state_encoder.pt'))

In [50]:
s = state_encoder(grids_input)
logits = s @ state_embedding.weight.T
pred_s = torch.argmax(logits, dim=1)
sampled_states_idx = grid_pos_to_idx(sampled_states, grid_sz)
(pred_s == sampled_states_idx).float().mean()

tensor(1.)

In [238]:
# Train state predictor
def train_state_predictor_step(encoder, state_predictor, state_embedding, optim, grid_sz, b_sz, epochs):
    for e in epochs: 
        start_positions, end_positions, actions, grid_positions = get_batch(grid_sz, b_sz)
        sampled_states, next_states, sampled_acts, end_positions_ = sample_s_sn_act(grid_positions, actions, b_sz)
        s_grids_input = make_grid(sampled_states, grid_sz, b_sz).unsqueeze(1)
        g_grids_input = make_grid(end_positions_, grid_sz, b_sz).unsqueeze(1)

        targets = grid_pos_to_idx(next_states, grid_sz)

        optim.zero_grad()
        with torch.no_grad():
            s = encoder(s_grids_input)
            s_idx = torch.argmax(s @ state_embedding.weight.T, dim=1)
            s_tokens = state_embedding(s_idx)
            g = encoder(g_grids_input)
            g_idx = torch.argmax(g @ state_embedding.weight.T, dim=1)
            g_tokens = state_embedding(g_idx)


        no_op_idx = torch.full((b_sz,), act_dim, dtype=torch.long)
        no_op_tokens = act_embedding(no_op_idx)
        s_next_pred = state_predictor(torch.cat((s_tokens, no_op_tokens, g_tokens), dim=1))
        logits = s_next_pred @ state_embedding.weight.T
        
        loss = F.cross_entropy(logits, targets)
        loss.backward()
        optim.step()

        # accuracy 
        acc = (torch.argmax(logits, dim=1) == targets).float().mean()

        print(f'Epoch: {e}, Loss: {loss.item()}, Acc: {acc.item()}')
optim = torch.optim.Adam(list(state_predictor.parameters()), lr=1e-3)
train_state_predictor_step(state_encoder, state_predictor, state_embedding, optim, grid_sz, b_sz, range(500))

Epoch: 0, Loss: 4.58441686630249, Acc: 0.03125
Epoch: 1, Loss: 4.739172458648682, Acc: 0.0
Epoch: 2, Loss: 4.39930534362793, Acc: 0.0
Epoch: 3, Loss: 4.785521507263184, Acc: 0.0
Epoch: 4, Loss: 4.290663719177246, Acc: 0.0
Epoch: 5, Loss: 4.4196672439575195, Acc: 0.03125
Epoch: 6, Loss: 4.254199028015137, Acc: 0.03125
Epoch: 7, Loss: 4.286580562591553, Acc: 0.0625
Epoch: 8, Loss: 4.2257304191589355, Acc: 0.03125
Epoch: 9, Loss: 4.189031600952148, Acc: 0.09375
Epoch: 10, Loss: 4.129340648651123, Acc: 0.125
Epoch: 11, Loss: 3.9673187732696533, Acc: 0.125
Epoch: 12, Loss: 4.097968578338623, Acc: 0.09375
Epoch: 13, Loss: 4.275415897369385, Acc: 0.03125
Epoch: 14, Loss: 4.097780227661133, Acc: 0.0
Epoch: 15, Loss: 4.127450942993164, Acc: 0.0625
Epoch: 16, Loss: 3.797480344772339, Acc: 0.1875
Epoch: 17, Loss: 4.036776065826416, Acc: 0.09375
Epoch: 18, Loss: 3.988205671310425, Acc: 0.15625
Epoch: 19, Loss: 4.1202826499938965, Acc: 0.03125
Epoch: 20, Loss: 3.9263389110565186, Acc: 0.0
Epoch: 21

In [261]:
# Train action predictor
def train_action_predictor_step(encoder, state_predictor, action_predictor, state_embedding, optim, grid_sz, b_sz, epochs):
    for e in epochs:
        start_positions, end_positions, actions, grid_positions = get_batch(grid_sz, b_sz)
        sampled_states, next_states, sampled_acts, end_positions_ = sample_s_sn_act(grid_positions, actions, b_sz)
        s_grids_input = make_grid(sampled_states, grid_sz, b_sz).unsqueeze(1)
        g_grids_input = make_grid(end_positions_, grid_sz, b_sz).unsqueeze(1)

        optim.zero_grad()
        with torch.no_grad():
            s = encoder(s_grids_input)
            s_idx = torch.argmax(s @ state_embedding.weight.T, dim=1)
            s_tokens = state_embedding(s_idx)
            
            g = encoder(g_grids_input)
            g_idx = torch.argmax(g @ state_embedding.weight.T, dim=1)
            g_tokens = state_embedding(g_idx)

            no_op_idx = torch.full((b_sz,), act_dim, dtype=torch.long)
            no_op_tokens = act_embedding(no_op_idx)
            
            s_next_pred = state_predictor(torch.cat((s_tokens, no_op_tokens, g_tokens), dim=1)) 
            s_next_idx = torch.argmax(s_next_pred @ state_embedding.weight.T, dim=1)
            s_next_tokens = state_embedding(s_next_idx)
            imagined_targets = s_next_idx


        a = action_predictor(torch.cat((s, s_next_tokens), dim=1))
        a_logits = a @ act_embedding.weight.T
        a_tokens = a#act_embedding(torch.argmax(a_logits, dim=1))
    
        s_next_pred = state_predictor(torch.cat((s, a_tokens, torch.zeros_like(s_next_tokens)), dim=1))
        s_next_pred_logits = s_next_pred @ state_embedding.weight.T
        
        with torch.no_grad():
            next_sim_states = simulate_next_state(sampled_states, torch.argmax(a_logits, dim=1), grid_sz).long()
            # next_sim_states = encoder(make_grid(next_sim_states, grid_sz, b_sz).unsqueeze(1))
            actual_targets = grid_pos_to_idx(next_sim_states, grid_sz)

        std_loss = variance_loss(a) #+ variance_loss(s_next_pred)
        cov_loss = covariance_loss(a)# + covariance_loss(s_next_pred)
        loss = F.cross_entropy(s_next_pred_logits, actual_targets)
        loss += F.cross_entropy(s_next_pred_logits, imagined_targets)
        # loss = 10*loss + 10*std_loss + 1*cov_loss
        loss.backward()
        optim.step()

        # accuracy
        acc = (torch.argmax(s_next_pred_logits, dim=1) == actual_targets).float().mean()

        print(f'Epoch: {e}, Loss: {loss.item()}, Acc: {acc.item()}')

# optimize both action_predictor + state_predictor
optim = torch.optim.Adam(list(action_predictor.parameters()) + list(state_predictor.parameters()), lr=1e-3)
train_action_predictor_step(state_encoder, state_predictor, action_predictor, state_embedding, optim, grid_sz, b_sz, range(3000))

Epoch: 0, Loss: 2.1371874809265137, Acc: 0.3125
Epoch: 1, Loss: 2.4398863315582275, Acc: 0.375
Epoch: 2, Loss: 1.9336374998092651, Acc: 0.375
Epoch: 3, Loss: 2.2648234367370605, Acc: 0.375
Epoch: 4, Loss: 2.045513153076172, Acc: 0.4375
Epoch: 5, Loss: 2.98299503326416, Acc: 0.3125
Epoch: 6, Loss: 2.087742805480957, Acc: 0.4375
Epoch: 7, Loss: 2.5734808444976807, Acc: 0.5
Epoch: 8, Loss: 2.101989269256592, Acc: 0.4375
Epoch: 9, Loss: 2.1736867427825928, Acc: 0.46875
Epoch: 10, Loss: 2.7691309452056885, Acc: 0.3125
Epoch: 11, Loss: 2.283095121383667, Acc: 0.34375
Epoch: 12, Loss: 2.1122679710388184, Acc: 0.34375
Epoch: 13, Loss: 1.8672713041305542, Acc: 0.53125
Epoch: 14, Loss: 1.8730909824371338, Acc: 0.3125
Epoch: 15, Loss: 2.1342391967773438, Acc: 0.34375
Epoch: 16, Loss: 1.7459166049957275, Acc: 0.53125
Epoch: 17, Loss: 1.7389471530914307, Acc: 0.53125
Epoch: 18, Loss: 2.1037724018096924, Acc: 0.375
Epoch: 19, Loss: 2.2526509761810303, Acc: 0.5
Epoch: 20, Loss: 2.877552032470703, Acc

In [259]:
# act space: 0: up, 1: down, 2: left, 3: right, 4: no-op
s_test = torch.tensor([[0,6]])
g_test = torch.tensor([[7,7]])
s_grid = make_grid(s_test, grid_sz, 1).unsqueeze(1)
g_grid = make_grid(g_test, grid_sz, 1).unsqueeze(1)

s = state_encoder(s_grid)
g = state_encoder(g_grid)

s_idx = torch.argmax(s @ state_embedding.weight.T, dim=1)
s_tokens = state_embedding(s_idx)
g_idx = torch.argmax(g @ state_embedding.weight.T, dim=1)
g_tokens = state_embedding(g_idx)

no_op_idx = torch.full((1,), act_dim, dtype=torch.long)
no_op_tokens = act_embedding(no_op_idx)

s_next_pred = state_predictor(torch.cat((s_tokens, no_op_tokens, g_tokens), dim=1))

a = action_predictor(torch.cat((s, s_next_pred), dim=1))
a_logits = a @ act_embedding.weight.T

a_logits, torch.argmax(a_logits, dim=1)

(tensor([[ 3.3962, -0.4909,  1.5634, -1.2217, -0.1489]], grad_fn=<MmBackward0>),
 tensor([0]))