In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from dataset import TrajectoryDataset

import numpy as np
import matplotlib.pyplot as plt

### 1. Getting the Dataset and Dataloader

In [2]:
dataset = TrajectoryDataset(
    data_dir = "../dataset/",
    states_filename = "states.npy",
    actions_filename = "actions.npy",
    s_transform = None,
    a_transform = None,
)

# TODO: create two dataset for train and test

dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

first_datapoint = next(iter(dataloader))
state, action = first_datapoint

print(f"Number of data_points {len(dataloader)}")
print(f"Shape of state: {state.shape}")
print(f"Shape of action: {action.shape}")

Number of data_points 15000
Shape of state: torch.Size([1, 17, 2, 65, 65])
Shape of action: torch.Size([1, 16, 2])


  return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)


### 2. Defining the Model

1. `Encoder`: which will be a simple CNN network.
2. `Predictor`: which will be a simple LSTM Cell.

In [3]:
class SimpleEncoder(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 12, padding=1, kernel_size=3)
        self.conv2 = nn.Conv2d(12, 12, padding=1, kernel_size=3)
        self.conv3 = nn.Conv2d(12, 12, padding=1, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(12)
        self.bn2 = nn.BatchNorm2d(12)
        self.bn3 = nn.BatchNorm2d(12)
        self.relu = nn.ReLU()
        self.pool1 = nn.MaxPool2d((6, 5), stride=2)
        self.pool2 = nn.MaxPool2d((5, 5), stride=5)
        # h -> (5, 5, stride=1) -> (3, 3)
        # h = 65 -> 8748
        self.fc1 = nn.Linear(432, 4096)
        self.fc2 = nn.Linear(4096, embed_size)

    def forward(self, x):
        # h,w = 65
        x = self.conv1(x)        
        x = self.bn1(x)
        x = self.relu(x)
        x1 = x

        x2 = self.conv2(x1)
        x2 = self.bn2(x2)
        x2 = self.relu(x2)
        x2 = x2 + x1
        x2 = self.pool1(x2)
        # h,w = 31 

        x3 = self.conv3(x2)
        x3 = self.bn3(x3)
        x3 = self.relu(x3)
        x3 = x3 + x2
        x3 = self.pool2(x3)
        # h,w = 6

        x3 = x3.view(x3.size(0), -1)
        # (b,12*6*6)
        x3 = self.fc1(x3)
        x3 = self.relu(x3)
        x3 = self.fc2(x3)
        return x3

In [4]:
class Predictor(nn.Module):
    def __init__(self, input_size, hidden_size) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)

        self.h = None
        self.c = None

    def set_hc(self, h, c):
        self.h = h
        self.c = c 
    
    def reset_hc(self):
        self.h = self.h.zero_() 
        self.c = self.c.zero_()
    """
    Forward Method
    Input: 
        action tensor (a) of shape: (B, input_size)
    Output:   
        predicted repr. tensor (s_yhat) of shape: (B, hidden_size)
    """
    def forward(self, action):
        self.h, self.c = self.lstm_cell(action, (self.h, self.c))
        return self.h

### 3. Training 
We define `train_separate()` function, which does the training step when encoder is trained separately.

In [5]:
def train_separate(pred, enc, dataloader, criterion, optimizer, device):
    # keeping encoder in eval mode
    pred, enc = pred.to(device), enc.to(device)
    enc.eval()

    for batch in dataloader:
        ## shape of [ s = (b, L+1, c, h, w)]  [a = (b, L, 2)]
        s, a = batch
        s, a = s.to(device), a.to(device)

        ## initial observation
        o = s[:, 0, :, :, :]
        so = enc(o)
        co = torch.zeros(so.shape).to(device)
        pred.set_hc(so, co)
        
        loss ,L = 0, a.shape[1]
        for i in range(L):
            sy_hat = pred(a[:, i, :])
            sy = enc(s[:, i+1, :, :, :])
            loss += criterion(sy_hat, sy)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(loss)
        ## clearing the hidden state and cell state
        pred.reset_hc()
        break     

In [6]:
input_size = 2
hidden_size = 1024 

encoder = SimpleEncoder(hidden_size) 
predictor = Predictor(input_size=input_size, hidden_size=hidden_size)
criterion = nn.MSELoss()
optimizer = optim.Adam(predictor.parameters(), lr=0.001)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [7]:
train_separate(predictor, encoder, dataloader, criterion, optimizer, device)

tensor(0.0308, device='cuda:0', grad_fn=<AddBackward0>)


### Training and Inference if the Encoder and Decoder is part of the model.
If the encoder is trained together with JEPA, we define the forward inference and training step.  
The pending step is the defining the loss function and how to do backward.

In [8]:
class JEPAModel(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        self.encoder = SimpleEncoder(embed_size)
        self.predictor = Predictor(2, 1024)
        
    def set_predictor(self, o, co):
        so = self.encoder.forward(o)
        self.predictor.set_hc(so, co)
        return so
    
    def reset_predictor(self):
        self.predictor.reset_hc()

    def forward(self, action, state):
        sy_hat = self.predictor(action)
        sy = self.encoder(state)
        return sy_hat, sy

In [9]:
def forward_inference(model, actions, states):
    # shape of states = (b, L+1, c, h, w)
    # shape of action = (b, L, 2)
    B, L, D = state.shape[0], actions.shape[1], model.predictor.hidden_size

    o = states[:, 0, :, :, :]
    co = torch.zeros((B, D)).to(o.device)
    model.set_predictor(o, co)

    result = torch.empty((B, L, D))
    for i in range(L):
        sy_hat, _ = model(actions[:, i, :], states[:, i+1, :, :, :])
        result[:, i, :] = sy_hat

    return result

In [10]:
model = JEPAModel(1024)

# first_datapoint = next(iter(dataloader))
states, actions = first_datapoint
model = model.to(device)
states = states.to(device)
actions = actions.to(device)

print(f"states shape: {states.shape}")
print(f"actions shape: {actions.shape}")

result = forward_inference(model, actions, states)
print(result.shape)

states shape: torch.Size([1, 17, 2, 65, 65])
actions shape: torch.Size([1, 16, 2])
torch.Size([1, 16, 1024])


If we are doing training of encoder and predictor together, then we need to handle all the different losses, defined in the figure

![loss diagram](../../assets/loss_diagram.png)

In [11]:
def forward_train_step(model, actions, states, optimizer):
    B, L, D = state.shape[0], actions.shape[1], model.predictor.hidden_size

    loss, loss1, loss2, loss3 = 0, 0, 0, 0

    o = states[:, 0, :, :, :]
    co = torch.zeros((B, D)).to(o.device)
    so = model.set_predictor(o, co)
    ## TODO: compute loss1 using `so`

    for i in range(L):
        sy_hat, sy = model(actions[:, i, :], states[:, i+1, :, :, :])
        ## TODO: compute loss2 per iteration using `sy_hat`, `sy`
        ## TODO: compute loss3 per iteration using `sy`
        ## note: if we are using latent variables: we have to compute lossz
        
    loss = loss1 + loss2 + loss3
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return