In [1]:
!pip install lightly



In [2]:
from google.colab import drive

drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import os
import sys

GOOGLE_DRIVE_PATH_AFTER_MYDRIVE = "DL"

GOOGLE_DRIVE_PATH = os.path.join("drive", "My Drive", GOOGLE_DRIVE_PATH_AFTER_MYDRIVE)
print(os.listdir(GOOGLE_DRIVE_PATH))

sys.path.append(GOOGLE_DRIVE_PATH)

['dataset', '__pycache__', 'dataset.py', 'dataset2']


In [4]:
import numpy as np
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as tr

from dataset import TrajectoryDataset

from lightly.models.modules.heads import VICRegProjectionHead

In [5]:
dataset = TrajectoryDataset(
    data_dir = "/content/drive/My Drive/DL/dataset2",
    states_filename = "states_5000.npy",
    actions_filename = "actions_5000.npy",
    s_transform = None,
    a_transform = None,
)

dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

first_datapoint = next(iter(dataloader))
state, action = first_datapoint

print(f"Number of data_points {len(dataloader)}")
print(f"Shape of state: {state.shape}")
print(f"Shape of action: {action.shape}")

  return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)


Number of data_points 79
Shape of state: torch.Size([64, 17, 2, 65, 65])
Shape of action: torch.Size([64, 16, 2])


In [6]:
class VICReg(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.projection_head = VICRegProjectionHead(
            input_dim=512,
            hidden_dim=1024,
            output_dim=1024,
            num_layers=3,
        )

    def forward(self, x):
        x = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(x)
        return z

In [7]:
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = VICReg(backbone)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum= 0.9, weight_decay=1.5e-4)

In [8]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(device)

cuda


In [9]:
def get_byol_transforms(mean, std):
    # Define the first augmentation pipeline
    transformT = tr.Compose([
        tr.RandomHorizontalFlip(p=0.5),  # Random horizontal flip
        tr.RandomRotation(degrees=90),  # Random rotation
        tr.GaussianBlur(kernel_size=(23, 23), sigma=(0.1, 2.0)),  # Gaussian blur
        tr.Normalize(mean, std),  # Normalize for 2 channels
    ])

    # Define a slightly different second augmentation pipeline
    transformT1 = tr.Compose([
        tr.RandomVerticalFlip(p=0.5),  # Random vertical flip
        tr.RandomRotation(degrees=45),  # Different random rotation
        tr.GaussianBlur(kernel_size=(15, 15), sigma=(0.1, 1.5)),  # Gaussian blur with smaller kernel
        tr.Normalize(mean, std),  # Normalize for 2 channels
    ])

    return transformT, transformT1


def off_diagonal(matrix):
    """
    Extracts the off-diagonal elements of a square matrix.

    Args:
        matrix (torch.Tensor): A square matrix of shape (D, D).

    Returns:
        torch.Tensor: A tensor containing all off-diagonal elements.
    """
    # Create a mask for off-diagonal elements
    n = matrix.shape[0]
    off_diag_mask = ~torch.eye(n, dtype=bool, device=matrix.device)

    # Use the mask to extract off-diagonal elements
    off_diag_elements = matrix[off_diag_mask]
    # print(off_diag_elements)
    return off_diag_elements


def criterion(x, y, lmbd = 5e-3, invar = 25, mu = 25, nu = 1, epsilon = 1e-4):
    bs = x.size(0)
    emb = x.size(1)

    std_x = torch.sqrt(x.var(dim=0) + epsilon)
    std_y = torch.sqrt(y.var(dim=0) + epsilon)
    var_loss = torch.mean(F.relu(1 - std_x)) + torch.mean(F.relu(1 - std_y))

    invar_loss = F.mse_loss(x, y)

    x = x - x.mean(dim=0)
    y = y - y.mean(dim=0)
    cov_z_a = (x.T @ x) / (bs - 1)
    cov_z_b = (y.T @ y) / (bs - 1)
    cov_loss = off_diagonal(cov_z_a).pow_(2).sum() / emb + off_diagonal(cov_z_b).pow_(2).sum() / emb

    # print(f"invar_loss: {invar_loss.item()}")
    # print(f"var_loss: {var_loss.item()}")
    # print(f"cov_loss: {cov_loss.item()}")

    loss = invar*invar_loss + mu*var_loss + nu*cov_loss

    return loss

In [10]:
def compute_mean_and_std():
  num_channels = 2  # Assuming you have 2 channels
  pixel_sum = [0] * num_channels
  pixel_squared_sum = [0] * num_channels
  total_pixels = 0

  # Iterate through the dataset
  for state, _ in dataloader:
      # Iterate through each channel
      for channel in range(num_channels):
          channel_data = state[:, :, channel, :, :].reshape(-1)  # Flatten the current channel
          pixel_sum[channel] += channel_data.sum().item()
          pixel_squared_sum[channel] += (channel_data ** 2).sum().item()

      # Total number of pixels per channel (all images combined)
      total_pixels += state.size(0) * state.size(1) * state.size(3) * state.size(4)

  # Calculate mean and std for each channel
  mean = [pixel_sum[c] / total_pixels for c in range(num_channels)]
  std = [
      np.sqrt((pixel_squared_sum[c] / total_pixels) - (mean[c] ** 2))
      for c in range(num_channels)
  ]

  print(f"Mean per channel: {mean}")
  print(f"Std per channel: {std}")
  mean.append(mean[1])
  std.append(std[1])
  return mean, std

In [11]:
mean, std = compute_mean_and_std()
transformation1, transformation2  = get_byol_transforms(mean, std)

Mean per channel: [0.00023668267603448322, 0.009316061526496694]
Std per channel: [0.003329996277483632, 0.028114841590122994]


In [12]:
print("Starting Training")
def train(dataloader, epochs):
  for epoch in range(epochs):
      total_loss = 0
      ind = 0
      for batch in dataloader:
          state, action = batch
          img = state[:, 0, :, :, :]
          img = torch.cat([img, img[:, 1:2, :, :]], dim=1)

          x0 = transformation1(img)
          x1 = transformation2(img)

          x0 = x0.to(device)
          x1 = x1.to(device)

          z0 = model(x0)
          z1 = model(x1)

          loss = criterion(z0, z1)
          total_loss += loss.detach()
          loss.backward()
          optimizer.step()
          optimizer.zero_grad()
          # print(f"batch: {ind}")
          avg_loss = total_loss / len(dataloader)
          ind = ind + 1
      print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

Starting Training


In [13]:
train(dataloader, 30)

epoch: 00, loss: 37.53720
epoch: 01, loss: 34.67337
epoch: 02, loss: 33.28022
epoch: 03, loss: 32.07617
epoch: 04, loss: 31.24450
epoch: 05, loss: 30.76028
epoch: 06, loss: 30.49474
epoch: 07, loss: 30.22291
epoch: 08, loss: 29.99336
epoch: 09, loss: 29.96262
epoch: 10, loss: 29.49220
epoch: 11, loss: 29.47420
epoch: 12, loss: 29.21940
epoch: 13, loss: 28.93957
epoch: 14, loss: 28.87402
epoch: 15, loss: 28.73059
epoch: 16, loss: 28.52197
epoch: 17, loss: 28.65028
epoch: 18, loss: 28.29223
epoch: 19, loss: 28.20359
epoch: 20, loss: 28.36916
epoch: 21, loss: 28.24691
epoch: 22, loss: 28.01368
epoch: 23, loss: 27.89042
epoch: 24, loss: 27.83872
epoch: 25, loss: 27.75111
epoch: 26, loss: 27.69475
epoch: 27, loss: 27.79646
epoch: 28, loss: 27.54712
epoch: 29, loss: 27.38757
