In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as tr

from torch.utils.data import DataLoader
from dataset import TrajectoryDataset
from lightly.models.modules.heads import VICRegProjectionHead
from encoder_train import save_model, compute_mean_and_std, get_byol_transforms
from encoder_train import criterion as VICReg_criterion
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

### 1. Getting the Dataset and Dataloader

In [2]:
dataset = TrajectoryDataset(
    data_dir = "../dataset/",
    states_filename = "states.npy",
    actions_filename = "actions.npy",
    s_transform = None,
    a_transform = None,
)

# TODO: create two dataset for train and test

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

first_datapoint = next(iter(dataloader))
state, action = first_datapoint

print(f"Number of data_points {len(dataloader)}")
print(f"Shape of state: {state.shape}")
print(f"Shape of action: {action.shape}")

Number of data_points 469
Shape of state: torch.Size([32, 17, 2, 65, 65])
Shape of action: torch.Size([32, 16, 2])


  return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)


### 2. Defining the Model

1. `Encoder`: which will be a simple CNN network.
2. `Predictor`: which will be a simple LSTM Cell.

In [3]:
class SimpleEncoder(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 12, padding=1, kernel_size=3)
        self.conv2 = nn.Conv2d(12, 12, padding=1, kernel_size=3)
        self.conv3 = nn.Conv2d(12, 12, padding=1, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(12)
        self.bn2 = nn.BatchNorm2d(12)
        self.bn3 = nn.BatchNorm2d(12)
        self.relu = nn.ReLU()
        self.pool1 = nn.MaxPool2d((5, 5), stride=2)
        self.pool2 = nn.MaxPool2d((5, 5), stride=5)
        # h -> (5, 5, stride=1) -> (3, 3)
        # h = 65 -> 8748
        self.fc1 = nn.Linear(432, 4096)
        self.fc2 = nn.Linear(4096, embed_size)

    def forward(self, x):
        # h,w = 65
        x = self.conv1(x)        
        x = self.bn1(x)
        x = self.relu(x)
        x1 = x

        x2 = self.conv2(x1)
        x2 = self.bn2(x2)
        x2 = self.relu(x2)
        x2 = x2 + x1
        x2 = self.pool1(x2)
        # h,w = 31 

        x3 = self.conv3(x2)
        x3 = self.bn3(x3)
        x3 = self.relu(x3)
        x3 = x3 + x2
        x3 = self.pool2(x3)
        # h,w = 6

        x3 = x3.view(x3.size(0), -1)
        # (b,12*6*6)
        x3 = self.fc1(x3)
        x3 = self.relu(x3)
        x3 = self.fc2(x3)
        return x3

In [4]:
class VICRegModel(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.projection_head = VICRegProjectionHead(
            input_dim=1024,
            hidden_dim=1024,
            output_dim=1024,
            num_layers=3,
        )
    def forward(self, x):
        x = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(x)
        return z

In [5]:
class Predictor(nn.Module):
    def __init__(self, input_size, hidden_size) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)

        self.h = None
        self.c = None

    def set_hc(self, h, c):
        self.h = h
        self.c = c 
    
    def reset_hc(self):
        self.h = self.h.zero_() 
        self.c = self.c.zero_()

    def forward(self, action):
        self.h, self.c = self.lstm_cell(action, (self.h, self.c))
        return self.h

### 3. Training 
We define `train_separate()` function, which does the training step when encoder is trained separately.

In [6]:
def train_encoder(dataloader, model, optimizer, criterion, epochs, device, transformation1, 
                  transformation2, step = 1):
    model.to(device)
    for epoch in range(epochs):
        total_loss = 0
        for batch in tqdm(dataloader, desc="Processing batches"):
            state, _ = batch
            state = state.to(device)
            for i in range(state.size(1)):
                img = state[:, i, :, :, :]
                img = torch.cat([img, img[:, 1:2, :, :]], dim=1)

                x0 = transformation1(img)
                x1 = transformation2(img)

                z0 = model(x0)
                z1 = model(x1)

                loss = criterion(z0, z1)
                total_loss += loss.detach()
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                avg_loss = total_loss / (len(dataloader)*state.size(1))

        # Save model checkpoint
        if epoch % step == 0:
            save_model(model, epoch)
        print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
    print("Training completed.")
    return model

In [14]:
def train_predictor(pred, enc, dataloader, criterion, optimizer, device, epochs=10):
    # keeping encoder in eval mode
    pred, enc = pred.to(device), enc.to(device)
    enc.eval()

    for epoch in range(epochs):
        total_loss = 0
        for batch in tqdm(dataloader, desc="Processing batch"):
            ## shape of [ s = (b, L+1, c, h, w)]  [a = (b, L, 2)]
            s, a = batch
            s, a = s.to(device), a.to(device)

            ## initial observation
            o = s[:, 0, :, :, :]
            o = torch.cat([o, o[:, 1:2, :, :]], dim=1)
            so = enc(o)
            co = torch.zeros(so.shape).to(device)
            pred.set_hc(so, co)
            
            loss ,L = 0, a.shape[1]
            for i in range(L):
                sy_hat = pred(a[:, i, :])
                temp = s[:, i+1, :, :, :]
                temp = torch.cat([temp, temp[:, 1:2, :, :]], dim=1)
                sy = enc(temp)
                loss += criterion(sy_hat, sy)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ## clearing the hidden state and cell state
            pred.reset_hc()
            total_loss += loss

        avg_loss = total_loss / len(dataloader)
        save_model(pred, epoch, file_name="pred")
        print(f"epoch: {epoch:>02}, loss: {avg_loss:.9f}")
    print("Training completed..")
    return pred

In [15]:
hidden_size = 1024

encoder = SimpleEncoder(hidden_size) 
encoder = VICRegModel(encoder)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
optimizer = optim.SGD(encoder.parameters(), lr=0.01, momentum=0.9, weight_decay=1.5e-4)

# defining transformations
mean, std = compute_mean_and_std(dataloader)
transformation1, transformation2 = get_byol_transforms(mean, std)

In [9]:
# encoder = train_encoder(dataloader, encoder, optimizer, VICReg_criterion, 10, 
#               device, transformation1, transformation2)

# save_model(encoder, "encoder")

In [17]:
input_size = 2
hidden_size = 1024
encoder.load_state_dict(torch.load("./checkpoints/enoder_9.pth"))

predictor = Predictor(input_size=input_size, hidden_size=hidden_size)
predictor_optimizer = optim.SGD(predictor.parameters(), lr=0.00001, momentum=0.9, weight_decay=1.5e-4)
predictor_criterion = nn.MSELoss()

  encoder.load_state_dict(torch.load("./checkpoints/enoder_9.pth"))


In [None]:
train_predictor(predictor, encoder, dataloader, predictor_criterion, optimizer, device)

Processing batch: 100%|██████████| 469/469 [01:49<00:00,  4.29it/s]


Model saved to checkpoints/pred_0.pth
epoch: 00, loss: 0.098045424


Processing batch: 100%|██████████| 469/469 [01:25<00:00,  5.50it/s]


Model saved to checkpoints/pred_1.pth
epoch: 01, loss: 0.001317678


Processing batch: 100%|██████████| 469/469 [01:51<00:00,  4.22it/s]


Model saved to checkpoints/pred_2.pth
epoch: 02, loss: 0.001225349


Processing batch: 100%|██████████| 469/469 [03:00<00:00,  2.59it/s]


Model saved to checkpoints/pred_3.pth
epoch: 03, loss: 0.001220523


Processing batch: 100%|██████████| 469/469 [05:12<00:00,  1.50it/s]


Model saved to checkpoints/pred_4.pth
epoch: 04, loss: 0.001220162


Processing batch: 100%|██████████| 469/469 [05:15<00:00,  1.49it/s]


Model saved to checkpoints/pred_5.pth
epoch: 05, loss: 0.001220171


Processing batch:  99%|█████████▉| 466/469 [04:54<00:01,  2.68it/s]

### Training and Inference if the Encoder and Decoder is part of the model.
If the encoder is trained together with JEPA, we define the forward inference and training step.  
The pending step is the defining the loss function and how to do backward.

In [None]:
class JEPAModel(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        self.encoder = SimpleEncoder(embed_size)
        self.predictor = Predictor(2, 1024)
        
    def set_predictor(self, o, co):
        so = self.encoder.forward(o)
        self.predictor.set_hc(so, co)
        return so
    
    def reset_predictor(self):
        self.predictor.reset_hc()

    def forward(self, action, state):
        sy_hat = self.predictor(action)
        sy = self.encoder(state)
        return sy_hat, sy

In [None]:
def forward_inference(model, actions, states):
    # shape of states = (b, L+1, c, h, w)
    # shape of action = (b, L, 2)
    B, L, D = state.shape[0], actions.shape[1], model.predictor.hidden_size

    o = states[:, 0, :, :, :]
    co = torch.zeros((B, D)).to(o.device)
    model.set_predictor(o, co)

    result = torch.empty((B, L, D))
    for i in range(L):
        sy_hat, _ = model(actions[:, i, :], states[:, i+1, :, :, :])
        result[:, i, :] = sy_hat

    return result

In [None]:
model = JEPAModel(1024)

# first_datapoint = next(iter(dataloader))
states, actions = first_datapoint
model = model.to(device)
states = states.to(device)
actions = actions.to(device)

print(f"states shape: {states.shape}")
print(f"actions shape: {actions.shape}")

result = forward_inference(model, actions, states)
print(result.shape)

If we are doing training of encoder and predictor together, then we need to handle all the different losses, defined in the figure

![loss diagram](../../assets/loss_diagram.png)

In [None]:
def forward_train_step(model, actions, states, optimizer):
    B, L, D = state.shape[0], actions.shape[1], model.predictor.hidden_size

    loss, loss1, loss2, loss3 = 0, 0, 0, 0

    o = states[:, 0, :, :, :]
    co = torch.zeros((B, D)).to(o.device)
    so = model.set_predictor(o, co)
    ## TODO: compute loss1 using `so`

    for i in range(L):
        sy_hat, sy = model(actions[:, i, :], states[:, i+1, :, :, :])
        ## TODO: compute loss2 per iteration using `sy_hat`, `sy`
        ## TODO: compute loss3 per iteration using `sy`
        ## note: if we are using latent variables: we have to compute lossz
        
    loss = loss1 + loss2 + loss3
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return