In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as tr

from torch.utils.data import DataLoader
from dataset import TrajectoryDataset
from lightly.models.modules.heads import VICRegProjectionHead
from encoder_train import save_model, compute_mean_and_std, get_byol_transforms, get_encoder_loss
from encoder_train import criterion as VICReg_criterion
from tqdm import tqdm

import numpy as np
import math
import matplotlib.pyplot as plt

### 1. Getting the Dataset and Dataloader

In [2]:
dataset = TrajectoryDataset(
    data_dir = "../dataset/",
    states_filename = "states.npy",
    actions_filename = "actions.npy",
    s_transform = None,
    a_transform = None,
    length = 992    
)

# TODO: create two dataset for train and test

dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

first_datapoint = next(iter(dataloader))
state, action = first_datapoint

print(f"Number of data_points {len(dataloader)}")
print(f"Shape of state: {state.shape}")
print(f"Shape of action: {action.shape}")

Number of data_points 62
Shape of state: torch.Size([16, 17, 2, 65, 65])
Shape of action: torch.Size([16, 16, 2])


  return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)


### 2. Defining the Model

1. `Encoder`: which will be a simple CNN network.
2. `Predictor`: which will be a simple LSTM Cell.

In [3]:
class SimpleEncoder(nn.Module):
    def __init__(self, embed_size, input_channel=3):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channel, 12, padding=1, kernel_size=3)
        self.conv2 = nn.Conv2d(12, 12, padding=1, kernel_size=3)
        self.conv3 = nn.Conv2d(12, 12, padding=1, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(12)
        self.bn2 = nn.BatchNorm2d(12)
        self.bn3 = nn.BatchNorm2d(12)
        self.relu = nn.ReLU()
        self.pool1 = nn.MaxPool2d((5, 5), stride=2)
        self.pool2 = nn.MaxPool2d((5, 5), stride=5)
        # h -> (5, 5, stride=1) -> (3, 3)
        # h = 65 -> 8748
        self.fc1 = nn.Linear(432, 4096)
        self.fc2 = nn.Linear(4096, embed_size)

    def forward(self, x):
        # h,w = 65
        x = self.conv1(x)        
        x = self.bn1(x)
        x = self.relu(x)
        x1 = x

        x2 = self.conv2(x1)
        x2 = self.bn2(x2)
        x2 = self.relu(x2)
        x2 = x2 + x1
        x2 = self.pool1(x2)
        # h,w = 31 

        x3 = self.conv3(x2)
        x3 = self.bn3(x3)
        x3 = self.relu(x3)
        x3 = x3 + x2
        x3 = self.pool2(x3)
        # h,w = 6

        x3 = x3.view(x3.size(0), -1)
        # (b,12*6*6)
        x3 = self.fc1(x3)
        x3 = self.relu(x3)
        x3 = self.fc2(x3)
        return x3

In [4]:
class VICRegModel(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.projection_head = VICRegProjectionHead(
            input_dim=1024,
            hidden_dim=2048,
            output_dim=1024,
            num_layers=3,
        )
    def forward(self, x):
        x = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(x)
        return x, z

In [5]:
class Predictor(nn.Module):
    def __init__(self, input_size, hidden_size) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)

        self.h = None
        self.c = None

    def set_hc(self, h, c):
        self.h = h
        self.c = c 
    
    def reset_hc(self):
        self.h = self.h.zero_() 
        self.c = self.c.zero_()

    def forward(self, action):
        self.h, self.c = self.lstm_cell(action, (self.h, self.c))
        return self.h

### 3. Training 
We define `train_separate()` function, which does the training step when encoder is trained separately.

In [6]:
def train_encoder(dataloader, model, optimizer, criterion, epochs, device, transformation1, 
                  transformation2, step = 1):
    model.to(device)
    model.train()

    for epoch in range(epochs):
        total_loss = 0
        for batch in tqdm(dataloader, desc="Processing batches"):
            state, _ = batch
            state = state.to(device)
            for i in range(state.size(1)):
                img = state[:, i, :, :, :]

                x0 = transformation1(img)
                x1 = transformation2(img)

                _, z0 = model(x0)
                _, z1 = model(x1)

                loss = criterion(z0, z1)
                total_loss += loss.detach()
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                avg_loss = total_loss / (len(dataloader)*state.size(1))

        # Save model checkpoint
        if epoch % step == 0:
            save_model(model, epoch)
        print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
    print("Training completed.")
    return model

In [7]:
# todo: have to remove the use of `use expander`, it for testing
def train_predictor(pred, enc, dataloader, criterion, optimizer, device, 
                    use_expander=False, epochs=10):
    # keeping encoder in eval mode
    pred, enc = pred.to(device), enc.to(device)

    # freezing the encoder and setting it to evaluation mode
    enc.eval()
    for param in enc.parameters():
        param.requires_grad = False

    for epoch in range(epochs):
        total_loss = 0.0
        for batch in tqdm(dataloader, desc="Processing batch"):
            ## shape of [ s = (B, L+1, C, H, W)]  [a = (B, L, 2)]
            s, a = batch
            s, a = s.to(device), a.to(device)

            ## initial observation
            o = s[:, 0, :, :, :]
            with torch.no_grad():
                x, z = enc(o)
                so = z if use_expander else x
            
            ## initializing predictor h,c
            ## check randn instead of zeros
            co = torch.zeros(so.shape).to(device)
            pred.set_hc(so, co)
            
            ## forward inference for training.
            loss ,L = 0, a.shape[1]
            for i in range(L):
                sy_hat = pred(a[:, i, :])
                si = s[:, i+1, :, :, :]
                with torch.no_grad():
                    x, z = enc(si)
                    sy = z if use_expander else x
                loss += criterion(sy_hat, sy)
            
            ## back-propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ## clearing h,c in lstm 
            pred.reset_hc()

            total_loss += loss

        avg_loss = total_loss / len(dataloader)
        save_model(pred, epoch, file_name="pred")
        print(f"epoch: {epoch:>02}, loss: {avg_loss:.9f}")
    print("Training completed..")
    return pred

In [8]:
hidden_size = 1024

# defining model
encoder = SimpleEncoder(hidden_size, 2) 
encoder = VICRegModel(encoder)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
optimizer = optim.SGD(encoder.parameters(), lr=0.01, momentum=0.9, weight_decay=1.5e-4)

# defining transformations
mean, std = compute_mean_and_std(dataloader, is_channelsize3=False)
transformation1, transformation2 = get_byol_transforms(mean, std)

In [9]:
encoder = train_encoder(dataloader, encoder, optimizer, VICReg_criterion, 10, 
              device, transformation1, transformation2)

save_model(encoder, "encoder")

Processing batches: 100%|██████████| 62/62 [00:28<00:00,  2.18it/s]


Model saved to checkpoints/encoder__0.pth
epoch: 00, loss: 37.94731


Processing batches: 100%|██████████| 62/62 [00:29<00:00,  2.08it/s]


Model saved to checkpoints/encoder__1.pth
epoch: 01, loss: 37.04359


Processing batches: 100%|██████████| 62/62 [00:32<00:00,  1.90it/s]


Model saved to checkpoints/encoder__2.pth
epoch: 02, loss: 36.08323


Processing batches: 100%|██████████| 62/62 [00:34<00:00,  1.79it/s]


Model saved to checkpoints/encoder__3.pth
epoch: 03, loss: 35.82449


Processing batches: 100%|██████████| 62/62 [00:34<00:00,  1.80it/s]


Model saved to checkpoints/encoder__4.pth
epoch: 04, loss: 35.56228


Processing batches: 100%|██████████| 62/62 [00:34<00:00,  1.78it/s]


Model saved to checkpoints/encoder__5.pth
epoch: 05, loss: 35.44170


Processing batches: 100%|██████████| 62/62 [00:34<00:00,  1.80it/s]


Model saved to checkpoints/encoder__6.pth
epoch: 06, loss: 35.37407


Processing batches: 100%|██████████| 62/62 [00:34<00:00,  1.80it/s]


Model saved to checkpoints/encoder__7.pth
epoch: 07, loss: 35.31825


Processing batches: 100%|██████████| 62/62 [00:34<00:00,  1.80it/s]


Model saved to checkpoints/encoder__8.pth
epoch: 08, loss: 35.15565


Processing batches: 100%|██████████| 62/62 [00:34<00:00,  1.81it/s]

Model saved to checkpoints/encoder__9.pth
epoch: 09, loss: 35.18352
Training completed.
Model saved to checkpoints/encoder__encoder.pth





In [10]:
input_size = 2
hidden_size = 1024

# loading encoder
encoder.load_state_dict(torch.load("./checkpoints/encoder__9.pth"))

# loading the model
predictor = Predictor(input_size=input_size, hidden_size=hidden_size)
predictor_optimizer = optim.SGD(predictor.parameters(), lr=0.001, momentum=0.9, weight_decay=1.5e-4)
predictor_criterion = nn.MSELoss()

  encoder.load_state_dict(torch.load("./checkpoints/encoder__9.pth"))


In [11]:
train_predictor(predictor, encoder, dataloader, predictor_criterion,
                predictor_optimizer, device, use_expander=False)

Processing batch: 100%|██████████| 62/62 [00:02<00:00, 24.49it/s]


Model saved to checkpoints/pred_0.pth
epoch: 00, loss: 6091.095214844


Processing batch: 100%|██████████| 62/62 [00:02<00:00, 24.39it/s]


Model saved to checkpoints/pred_1.pth
epoch: 01, loss: 5925.533691406


Processing batch: 100%|██████████| 62/62 [00:02<00:00, 24.32it/s]


Model saved to checkpoints/pred_2.pth
epoch: 02, loss: 5924.496582031


Processing batch: 100%|██████████| 62/62 [00:02<00:00, 24.37it/s]


Model saved to checkpoints/pred_3.pth
epoch: 03, loss: 5924.109863281


Processing batch: 100%|██████████| 62/62 [00:02<00:00, 24.85it/s]


Model saved to checkpoints/pred_4.pth
epoch: 04, loss: 5923.988281250


Processing batch: 100%|██████████| 62/62 [00:02<00:00, 24.88it/s]


Model saved to checkpoints/pred_5.pth
epoch: 05, loss: 5923.827636719


Processing batch: 100%|██████████| 62/62 [00:02<00:00, 24.87it/s]


Model saved to checkpoints/pred_6.pth
epoch: 06, loss: 5923.610351562


Processing batch: 100%|██████████| 62/62 [00:02<00:00, 24.97it/s]


Model saved to checkpoints/pred_7.pth
epoch: 07, loss: 5923.518554688


Processing batch: 100%|██████████| 62/62 [00:02<00:00, 24.92it/s]


Model saved to checkpoints/pred_8.pth
epoch: 08, loss: 5923.479003906


Processing batch: 100%|██████████| 62/62 [00:02<00:00, 24.98it/s]


Model saved to checkpoints/pred_9.pth
epoch: 09, loss: 5923.376953125
Training completed..


Predictor(
  (lstm_cell): LSTMCell(2, 1024)
)

In [19]:
train_predictor(predictor, encoder, dataloader, predictor_criterion,
                predictor_optimizer, device, use_expander=True)

Processing batch: 100%|██████████| 62/62 [00:02<00:00, 28.41it/s]


Model saved to checkpoints/pred_0.pth
epoch: 00, loss: 8.385624886


Processing batch: 100%|██████████| 62/62 [00:02<00:00, 28.71it/s]


Model saved to checkpoints/pred_1.pth
epoch: 01, loss: 2.296003103


Processing batch: 100%|██████████| 62/62 [00:02<00:00, 28.60it/s]


Model saved to checkpoints/pred_2.pth
epoch: 02, loss: 2.250522852


Processing batch: 100%|██████████| 62/62 [00:02<00:00, 28.25it/s]


Model saved to checkpoints/pred_3.pth
epoch: 03, loss: 2.204568386


Processing batch: 100%|██████████| 62/62 [00:02<00:00, 28.27it/s]


Model saved to checkpoints/pred_4.pth
epoch: 04, loss: 2.158132315


Processing batch: 100%|██████████| 62/62 [00:02<00:00, 28.22it/s]


Model saved to checkpoints/pred_5.pth
epoch: 05, loss: 2.110890388


Processing batch: 100%|██████████| 62/62 [00:02<00:00, 28.13it/s]


Model saved to checkpoints/pred_6.pth
epoch: 06, loss: 2.062494516


Processing batch: 100%|██████████| 62/62 [00:02<00:00, 28.35it/s]


Model saved to checkpoints/pred_7.pth
epoch: 07, loss: 2.012619734


Processing batch: 100%|██████████| 62/62 [00:02<00:00, 28.44it/s]


Model saved to checkpoints/pred_8.pth
epoch: 08, loss: 1.960875630


Processing batch: 100%|██████████| 62/62 [00:02<00:00, 27.54it/s]

Model saved to checkpoints/pred_9.pth
epoch: 09, loss: 1.906860113
Training completed..





Predictor(
  (lstm_cell): LSTMCell(2, 1024)
)

### Training and Inference if the Encoder and Decoder is part of the model.
If the encoder is trained together with JEPA, we define the forward inference and training step.  
The pending step is the defining the loss function and how to do backward.

In [13]:
class JEPAModel(nn.Module):
    def __init__(self, embed_size, input_channel_size):
        super().__init__()
        self.encoder = VICRegModel(SimpleEncoder(embed_size, input_channel_size))
        self.predictor = Predictor(2, 1024)
        
    def set_predictor(self, o, co, use_expander=False):
        x, z = self.encoder.forward(o)
        so = z if use_expander else x
        self.predictor.set_hc(so, co)
        return so
    
    def reset_predictor(self):
        self.predictor.reset_hc()

    # sy_hat is state repr from predictor using actions
    # sy = (sy_enc, sy_exp), is state repr from encoder using states
    def forward(self, action=None, state=None):
        sy_hat, sy = None, None
        if action is not None:
            sy_hat = self.predictor(action)
        if state is not None:
            sy = self.encoder(state)

        return sy_hat, sy
    

In [14]:
def forward_inference(model, actions, states):
    # shape of states = (b, L+1, c, h, w)
    # shape of action = (b, L, 2)
    B, L, D = state.shape[0], actions.shape[1], model.predictor.hidden_size

    o = states[:, 0, :, :, :]
    co = torch.zeros((B, D)).to(o.device)
    model.set_predictor(o, co, use_expander=False)

    result = torch.empty((B, L, D))
    for i in range(L):
        sy_hat, _ = model(actions[:, i, :], states[:, i+1, :, :, :])
        result[:, i, :] = sy_hat

    return result

In [15]:
model = JEPAModel(1024, 2)

# first_datapoint = next(iter(dataloader))
states, actions = first_datapoint
model = model.to(device)
states = states.to(device)
actions = actions.to(device)

print(f"states shape: {states.shape}")
print(f"actions shape: {actions.shape}")

result = forward_inference(model, actions, states)
print(result.shape)

states shape: torch.Size([16, 17, 2, 65, 65])
actions shape: torch.Size([16, 16, 2])
torch.Size([16, 16, 1024])


If we are doing training of encoder and predictor together, then we need to handle all the different losses, defined in the figure

![loss diagram](../../assets/loss_diagram.png)

In [16]:
def train_joint(model, dataloader, criterion_encoder, criterion_pred, optimizer, 
                device, epochs=10, use_expander=False):
    model.to(device)
    model.train()
    
    # clipping the gradient to handle gradient explosions in LSTM
    max_val = 5.0
    for param in model.parameters():
        if param.grad is not None:
            param.grad.data = torch.clamp(param.grad.data, -max_val, max_val)

    for epoch in range(epochs):
        total_loss = 0
        for batch in tqdm(dataloader, desc="Processing Batch"):
            state, action = batch
            state, action = state.to(device), action.to(device)
            B, L, D = state.shape[0], action.shape[1], model.predictor.hidden_size

            loss, loss1, loss2, loss3 = 0, 0, 0, 0

            ## initializing the h,c of predictor 
            o = states[:, 0, :, :, :]
            c0 = torch.zeros((B, D)).to(device)
            model.set_predictor(o, c0, use_expander)

            # compute loss1
            loss1 = get_encoder_loss(model, o, transformation1, transformation2, 
                                     criterion_encoder)
            for i in range(L):
                # inference of encoder(next state) and predictor(action) 
                sy_hat, (sy_enc, sy_exp) = model(action[:, i, :], state[:, i+1, :, :, :])
                sy = sy_exp if use_expander else sy_enc

                # compute loss2 (distance btw sy and sy_hat)
                loss2 += criterion_pred(sy_hat, sy)
                # vic_reg loss for encoder (for encoding next state)
                loss3 += get_encoder_loss(model, state[:, i, :, :, :], 
                                          transformation1, transformation2, 
                                          criterion_encoder) 
            
            # adding all loss and doing back propagation
            loss = loss1 + loss2 + loss3
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()

        print(f"Epoch: {epoch}, total_loss: {total_loss}, the avg loss = {total_loss/len(dataloader)}")
        save_model(model, epoch, file_name="join_model")

    return model        

In [17]:
joint_optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1.5e-4)
criterion_predictor = nn.MSELoss()
criterion_encoder = VICReg_criterion

In [18]:
# Training the joint model
train_joint(model, dataloader, criterion_encoder, criterion_predictor, 
            joint_optimizer, device, 3, use_expander=False)

Processing Batch: 100%|██████████| 62/62 [00:36<00:00,  1.69it/s]


Epoch: 0, total_loss: 41382.16864013672, the avg loss = 667.454332905431
Model saved to checkpoints/join_model_0.pth


Processing Batch: 100%|██████████| 62/62 [00:35<00:00,  1.73it/s]


Epoch: 1, total_loss: 39694.914123535156, the avg loss = 640.2405503795993
Model saved to checkpoints/join_model_1.pth


Processing Batch: 100%|██████████| 62/62 [00:36<00:00,  1.71it/s]

Epoch: 2, total_loss: 38815.533447265625, the avg loss = 626.0569910849295
Model saved to checkpoints/join_model_2.pth





JEPAModel(
  (encoder): VICRegModel(
    (backbone): SimpleEncoder(
      (conv1): Conv2d(2, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn3): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (pool1): MaxPool2d(kernel_size=(5, 5), stride=2, padding=0, dilation=1, ceil_mode=False)
      (pool2): MaxPool2d(kernel_size=(5, 5), stride=5, padding=0, dilation=1, ceil_mode=False)
      (fc1): Linear(in_features=432, out_features=4096, bias=True)
      (fc2): Linear(in_features=4096, out_features=1024, bias=True)
    )
    (projection_head): VICRegProjectionHead(
      (layers)