In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as tr

from torch.utils.data import DataLoader
from utils import TrajectoryDataset
from lightly.models.modules.heads import VICRegProjectionHead
from utils import save_model, compute_mean_and_std, get_byol_transforms, get_vicreg_loss
from utils import criterion as VICReg_criterion
from tqdm import tqdm
from models import JEPAModel

import numpy as np
import math
import matplotlib.pyplot as plt

In [2]:
embed_dim = 1024
epochs = 10
learning_rate = 0.001
use_expander = False
batch_size = 16 

dataset_directory = "./dataset"
states_filename = "states.npy"
actions_filename = "actions.npy"

In [3]:
def train_joint(model, dataloader, criterion_encoder, criterion_pred, 
                optimizer, transformation1, transformation2, 
                device, epochs=10, use_expander=False):
    model.to(device)
    model.train()
    
    # clipping the gradient to handle gradient explosions in LSTM
    max_val = 5.0
    for param in model.predictor.parameters():
        if param.grad is not None:
            param.grad.data = torch.clamp(param.grad.data, -max_val, max_val)

    for epoch in range(epochs):
        total_loss = 0
        for batch in tqdm(dataloader, desc="Processing Batch"):
            state, action = batch
            state, action = state.to(device), action.to(device)
            B, L, D = state.shape[0], action.shape[1], model.predictor.hidden_size

            loss, loss1, loss2, loss3 = 0, 0, 0, 0

            o = state[:, 0, :, :, :]
            c0 = torch.zeros((B, D)).to(device)
            model.set_predictor(o, c0, use_expander)

            # compute loss1
            loss1 = get_vicreg_loss(model, o, transformation1, transformation2, 
                                     criterion_encoder)
            for i in range(L):
                # inference of encoder(next state) and predictor(action) 
                sy_hat, (sy_enc, sy_exp) = model(action[:, i, :], state[:, i+1, :, :, :])
                sy = sy_exp if use_expander else sy_enc

                # compute loss2 (distance btw sy and sy_hat)
                loss2 += criterion_pred(sy_hat, sy)
                # vic_reg loss for encoder (for encoding next state)
                loss3 += get_vicreg_loss(model, state[:, i, :, :, :], 
                                          transformation1, transformation2, 
                                          criterion_encoder) 
            
            # adding all loss and doing back propagation
            loss = loss1 + loss2 + loss3
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()

        print(f"Epoch: {epoch}, total_loss: {total_loss}, the avg loss = {total_loss/len(dataloader)}")
        save_model(model, epoch, file_name="join_model")

    return model   

In [4]:
dataset = TrajectoryDataset(
    data_dir = dataset_directory,
    states_filename = states_filename,
    actions_filename = actions_filename,
    s_transform = None,
    a_transform = None,
    length = 1000 
)

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
first_datapoint = next(iter(dataloader))
state, action = first_datapoint
print(f"Number of data_points {len(dataloader)}")
print(f"Shape of state: {state.shape}")
print(f"Shape of action: {action.shape}")

Number of data_points 63
Shape of state: torch.Size([16, 17, 2, 65, 65])
Shape of action: torch.Size([16, 16, 2])


  return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)


In [5]:
mean, std = compute_mean_and_std(dataloader, is_channelsize3=False)
transformation1, transformation2 = get_byol_transforms(mean, std)

In [6]:
model = JEPAModel(embed_dim, 2)

joint_optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1.5e-4)
# joint_optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion_predictor = nn.MSELoss()
criterion_encoder = VICReg_criterion
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

train_joint(model, dataloader, criterion_encoder, criterion_predictor, 
            joint_optimizer, transformation1, transformation2, device,
            epochs=epochs, use_expander=use_expander)

Processing Batch: 100%|██████████| 63/63 [00:30<00:00,  2.09it/s]


Epoch: 0, total_loss: 41872.14825439453, the avg loss = 664.6372738792783
Model saved to checkpoints/join_model_0.pth


Processing Batch: 100%|██████████| 63/63 [00:29<00:00,  2.13it/s]


Epoch: 1, total_loss: 40472.96569824219, the avg loss = 642.4280269562252
Model saved to checkpoints/join_model_1.pth


Processing Batch: 100%|██████████| 63/63 [00:29<00:00,  2.11it/s]


Epoch: 2, total_loss: 39556.43377685547, the avg loss = 627.8799012199281
Model saved to checkpoints/join_model_2.pth


Processing Batch: 100%|██████████| 63/63 [00:29<00:00,  2.13it/s]


Epoch: 3, total_loss: 39138.952575683594, the avg loss = 621.2532154870412
Model saved to checkpoints/join_model_3.pth


Processing Batch: 100%|██████████| 63/63 [00:30<00:00,  2.08it/s]


Epoch: 4, total_loss: 38889.23504638672, the avg loss = 617.2894451807416
Model saved to checkpoints/join_model_4.pth


Processing Batch: 100%|██████████| 63/63 [00:30<00:00,  2.08it/s]


Epoch: 5, total_loss: 38840.490173339844, the avg loss = 616.5157170371403
Model saved to checkpoints/join_model_5.pth


Processing Batch: 100%|██████████| 63/63 [00:30<00:00,  2.09it/s]


Epoch: 6, total_loss: 38448.87774658203, the avg loss = 610.2996467711433
Model saved to checkpoints/join_model_6.pth


Processing Batch: 100%|██████████| 63/63 [00:31<00:00,  2.01it/s]


Epoch: 7, total_loss: 38130.42510986328, the avg loss = 605.2448430137028
Model saved to checkpoints/join_model_7.pth


Processing Batch: 100%|██████████| 63/63 [00:31<00:00,  1.97it/s]


Epoch: 8, total_loss: 37982.05352783203, the avg loss = 602.8897385370163
Model saved to checkpoints/join_model_8.pth


Processing Batch: 100%|██████████| 63/63 [00:32<00:00,  1.93it/s]

Epoch: 9, total_loss: 37877.866455078125, the avg loss = 601.2359754774305
Model saved to checkpoints/join_model_9.pth





JEPAModel(
  (encoder): VICRegModel(
    (backbone): SimpleEncoder(
      (conv1): Conv2d(2, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn3): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (pool1): MaxPool2d(kernel_size=(5, 5), stride=2, padding=0, dilation=1, ceil_mode=False)
      (pool2): MaxPool2d(kernel_size=(5, 5), stride=5, padding=0, dilation=1, ceil_mode=False)
      (fc1): Linear(in_features=432, out_features=4096, bias=True)
      (fc2): Linear(in_features=4096, out_features=1024, bias=True)
    )
    (projection_head): VICRegProjectionHead(
      (layers)