In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as tr

from torch.utils.data import DataLoader
from utils import TrajectoryDataset
from lightly.models.modules.heads import VICRegProjectionHead
from utils import save_model, compute_mean_and_std, get_byol_transforms, get_vicreg_loss
from utils import criterion as VICReg_criterion
from tqdm import tqdm
from models import JEPAModelv2

import numpy as np
import math
import matplotlib.pyplot as plt

In [2]:
embed_dim = 432 
epochs = 20
learning_rate = 0.001
use_expander = False
batch_size = 16 

dataset_directory = "./dataset"
states_filename = "states.npy"
actions_filename = "actions.npy"

In [3]:
def train_joint(model, dataloader, criterion_encoder, criterion_pred, 
                optimizer, transformation1, transformation2, 
                device, epochs=10, use_expander=False):
    model.to(device)
    model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        for batch in tqdm(dataloader, desc="Processing Batch"):
            state, action = batch
            state, action = state.to(device), action.to(device)
            B, L, D = state.shape[0], action.shape[1], model.repr_dim

            loss, loss1, loss2, loss3 = 0, 0, 0, 0

            o = state[:, 0, :, :, :]
            model.set_init_embedding(o)

            for i in range(L):
                # inference of encoder(next state) and predictor(action) 
                sy_hat, sy = model(action[:, i, :], state[:, i+1, :, :, :])

                # compute loss2 (distance btw sy and sy_hat)
                loss2 += criterion_pred(sy_hat, sy)
            
            # adding all loss and doing back propagation
            loss = loss2 
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()

        print(f"Epoch: {epoch}, total_loss: {total_loss}, the avg loss = {total_loss/len(dataloader)}")
        save_model(model, epoch, file_name="join_modelv2")

    return model   

In [4]:
dataset = TrajectoryDataset(
    data_dir = dataset_directory,
    states_filename = states_filename,
    actions_filename = actions_filename,
    s_transform = None,
    a_transform = None,
    length = 2000 
)

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
first_datapoint = next(iter(dataloader))
state, action = first_datapoint
print(f"Number of data_points {len(dataloader)}")
print(f"Shape of state: {state.shape}")
print(f"Shape of action: {action.shape}")

Number of data_points 125
Shape of state: torch.Size([16, 17, 2, 65, 65])
Shape of action: torch.Size([16, 16, 2])


  return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)


In [5]:
mean, std = compute_mean_and_std(dataloader, is_channelsize3=False)
transformation1, transformation2 = get_byol_transforms(mean, std)

In [6]:
model = JEPAModelv2(embed_dim, 2048, 2, 2)

# joint_optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1.5e-4)
joint_optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion_predictor = nn.MSELoss()
criterion_encoder = VICReg_criterion
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

train_joint(model, dataloader, criterion_encoder, criterion_predictor, 
            joint_optimizer, transformation1, transformation2, device,
            epochs=epochs, use_expander=use_expander)

Processing Batch: 100%|██████████| 125/125 [00:11<00:00, 11.32it/s]


Epoch: 0, total_loss: 39.8578193967478, the avg loss = 0.3188625551739824
Model saved to checkpoints/join_modelv2_0.pth


Processing Batch: 100%|██████████| 125/125 [00:10<00:00, 12.11it/s]


Epoch: 1, total_loss: 0.006402654329576762, the avg loss = 5.1221234636614096e-05
Model saved to checkpoints/join_modelv2_1.pth


Processing Batch: 100%|██████████| 125/125 [00:14<00:00,  8.62it/s]


Epoch: 2, total_loss: 0.002673478740689461, the avg loss = 2.1387829925515688e-05
Model saved to checkpoints/join_modelv2_2.pth


Processing Batch: 100%|██████████| 125/125 [00:12<00:00, 10.08it/s]


Epoch: 3, total_loss: 0.001401989046826202, the avg loss = 1.1215912374609616e-05
Model saved to checkpoints/join_modelv2_3.pth


Processing Batch: 100%|██████████| 125/125 [00:13<00:00,  9.23it/s]


Epoch: 4, total_loss: 0.0008633216498310503, the avg loss = 6.9065731986484025e-06
Model saved to checkpoints/join_modelv2_4.pth


Processing Batch: 100%|██████████| 125/125 [00:10<00:00, 11.57it/s]


Epoch: 5, total_loss: 0.0006030483618815197, the avg loss = 4.824386895052157e-06
Model saved to checkpoints/join_modelv2_5.pth


Processing Batch: 100%|██████████| 125/125 [00:11<00:00, 10.72it/s]


Epoch: 6, total_loss: 0.0004600449135523377, the avg loss = 3.6803593084187012e-06
Model saved to checkpoints/join_modelv2_6.pth


Processing Batch: 100%|██████████| 125/125 [00:08<00:00, 15.14it/s]


Epoch: 7, total_loss: 0.00037129764655219333, the avg loss = 2.9703811724175464e-06
Model saved to checkpoints/join_modelv2_7.pth


Processing Batch: 100%|██████████| 125/125 [00:08<00:00, 14.82it/s]


Epoch: 8, total_loss: 0.0003095641968684504, the avg loss = 2.476513574947603e-06
Model saved to checkpoints/join_modelv2_8.pth


Processing Batch: 100%|██████████| 125/125 [00:08<00:00, 14.83it/s]


Epoch: 9, total_loss: 0.0002638307944380358, the avg loss = 2.1106463555042864e-06
Model saved to checkpoints/join_modelv2_9.pth


Processing Batch: 100%|██████████| 125/125 [00:08<00:00, 14.51it/s]


Epoch: 10, total_loss: 0.00022779921721394203, the avg loss = 1.8223937377115362e-06
Model saved to checkpoints/join_modelv2_10.pth


Processing Batch: 100%|██████████| 125/125 [00:07<00:00, 16.21it/s]


Epoch: 11, total_loss: 0.00019860146608152718, the avg loss = 1.5888117286522174e-06
Model saved to checkpoints/join_modelv2_11.pth


Processing Batch: 100%|██████████| 125/125 [00:07<00:00, 16.64it/s]


Epoch: 12, total_loss: 0.00017412097747637745, the avg loss = 1.3929678198110195e-06
Model saved to checkpoints/join_modelv2_12.pth


Processing Batch: 100%|██████████| 125/125 [00:06<00:00, 18.04it/s]


Epoch: 13, total_loss: 0.00015350115984347212, the avg loss = 1.228009278747777e-06
Model saved to checkpoints/join_modelv2_13.pth


Processing Batch: 100%|██████████| 125/125 [00:07<00:00, 17.67it/s]


Epoch: 14, total_loss: 0.00013581261953277135, the avg loss = 1.0865009562621709e-06
Model saved to checkpoints/join_modelv2_14.pth


Processing Batch: 100%|██████████| 125/125 [00:06<00:00, 19.08it/s]


Epoch: 15, total_loss: 0.0001204506979775033, the avg loss = 9.636055838200264e-07
Model saved to checkpoints/join_modelv2_15.pth


Processing Batch: 100%|██████████| 125/125 [00:06<00:00, 19.12it/s]


Epoch: 16, total_loss: 0.00010720207785652747, the avg loss = 8.576166228522198e-07
Model saved to checkpoints/join_modelv2_16.pth


Processing Batch: 100%|██████████| 125/125 [00:07<00:00, 17.13it/s]


Epoch: 17, total_loss: 9.543350824969821e-05, the avg loss = 7.634680659975856e-07
Model saved to checkpoints/join_modelv2_17.pth


Processing Batch: 100%|██████████| 125/125 [00:07<00:00, 17.03it/s]


Epoch: 18, total_loss: 8.519333783851835e-05, the avg loss = 6.815467027081468e-07
Model saved to checkpoints/join_modelv2_18.pth


Processing Batch: 100%|██████████| 125/125 [00:07<00:00, 17.03it/s]


Epoch: 19, total_loss: 7.620421030196667e-05, the avg loss = 6.096336824157334e-07
Model saved to checkpoints/join_modelv2_19.pth


JEPAModelv2(
  (encoder): SimpleEncoderv2(
    (conv1): Conv2d(2, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv3): Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn3): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (pool1): MaxPool2d(kernel_size=(5, 5), stride=2, padding=0, dilation=1, ceil_mode=False)
    (pool2): MaxPool2d(kernel_size=(5, 5), stride=5, padding=0, dilation=1, ceil_mode=False)
    (fc1): Linear(in_features=432, out_features=1024, bias=True)
    (fc2): Linear(in_features=1024, out_features=432, bias=True)
  )
  (predictor): Predictorv2(
    (linear1): Linear(in_features=434, out_features=2048, bias=True)
    (relu): 