<a href="https://colab.research.google.com/github/nasselm4i/Deep-Theoretical/blob/main/SSL_Methods_V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Requirements



In [1]:
%%capture
!git clone https://github.com/htdt/self-supervised.git && pip install lightly
%cd self-supervised
!pip install wandb --upgrade

In [2]:
import torch
import torchvision
from torch import nn

from lightly.data import LightlyDataset
from lightly.data.multi_view_collate import MultiViewCollate
from lightly.loss import BarlowTwinsLoss
from lightly.models.modules import BarlowTwinsProjectionHead
from lightly.transforms.simclr_transform import SimCLRTransform
from sklearn.neighbors import KernelDensity

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import time
import wandb

####################################################################################

import copy

from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.utils.scheduler import cosine_schedule

####################################################################################

from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead
from lightly.transforms import SimSiamTransform


from lightly.loss import VICRegLoss

from lightly.loss.vicreg_loss import VICRegLoss
from lightly.transforms.vicreg_transform import VICRegTransform

from lightly.transforms.simclr_transform import SimCLRTransform

# Metrics for tracking Collapse

## 1.1 Singular Values

In [8]:
def compute_singular_values(z0, z1):
    """
    Compute the singular values of the cross-correlation matrix between the embeddings z0 and z1.

    Args:
    - z0 (torch.Tensor): Embeddings of the first view, with shape (n_samples, embedding_dim).
    - z1 (torch.Tensor): Embeddings of the second view, with shape (n_samples, embedding_dim).

    Returns:
    - singular_values (np.ndarray): The singular values of the cross-correlation matrix, sorted in descending order.
    """
    with torch.no_grad():
        # Compute cross-correlation matrix
        c = cross_correlation_matrix(z0, z1)

        # Only the singular values are needed, so use svdvals instead of the
        # deprecated torch.svd, which also builds the U and V factors.
        singular_values = torch.linalg.svdvals(c)
    return singular_values.detach().cpu().numpy()

## 1.2 Mean Off-Diagonal Cross-Correlation

In [11]:
def cross_correlation_matrix(z0, z1):
    """
    Compute the cross-correlation matrix between the embeddings z0 and z1.

    Each embedding dimension is standardized along the sample axis (zero mean,
    unit standard deviation) before the product is taken, so the result does
    not depend on the offset or scale of the inputs.

    Args:
    - z0 (torch.Tensor): Embeddings of the first view, with shape (n_samples, embedding_dim).
    - z1 (torch.Tensor): Embeddings of the second view, with shape (n_samples, embedding_dim).

    Returns:
    - c (torch.Tensor): The cross-correlation matrix, with shape (embedding_dim, embedding_dim).
    """
    with torch.no_grad():
        # Parenthesize the centering: `z - z.mean() / z.std()` would divide only
        # the mean by the standard deviation and leave z itself unscaled.
        # clamp_min keeps a constant (zero-variance) dimension at 0 instead of NaN/inf.
        z0_normalized = (z0 - z0.mean(dim=0)) / z0.std(dim=0).clamp_min(1e-12)
        z1_normalized = (z1 - z1.mean(dim=0)) / z1.std(dim=0).clamp_min(1e-12)
        n_samples = z0_normalized.shape[0]
        return torch.mm(z0_normalized.T, z1_normalized) / n_samples

def cross_covariance_matrix(z0, z1):
  """Placeholder for the cross-covariance matrix between z0 and z1.

  TODO: not implemented yet. The arguments are ignored and None is returned.
  """
  return None

def compute_average_off_correlation_matrix(z0,z1):
    """
    Compute the average of the off-diagonal entries of the cross-correlation matrix between the embeddings z0 and z1.

    Args:
    - z0 (torch.Tensor): Embeddings of the first view, with shape (n_samples, embedding_dim).
    - z1 (torch.Tensor): Embeddings of the second view, with shape (n_samples, embedding_dim).

    Returns:
    - corr (float): The mean over every off-diagonal entry of the cross-correlation matrix
      (NaN when the matrix is 1x1 and therefore has no off-diagonal entry).
    """
    with torch.no_grad():
        c = cross_correlation_matrix(z0, z1)
        # Select every entry except the main diagonal. The earlier strided slice
        # `c.flatten()[D::D+1]` only walked the first sub-diagonal.
        off_diagonal = ~torch.eye(c.shape[0], dtype=torch.bool, device=c.device)
        return c[off_diagonal].mean().item()

# Methods

In [None]:
import torch
import torchvision
from torch import nn

from lightly.data import LightlyDataset
from lightly.data.multi_view_collate import MultiViewCollate
from lightly.loss import BarlowTwinsLoss
from lightly.models.modules import BarlowTwinsProjectionHead
from lightly.transforms.simclr_transform import SimCLRTransform


class BarlowTwins(nn.Module):
    """Backbone followed by a lightly BarlowTwinsProjectionHead."""

    def __init__(self, backbone):
        """Store `backbone` and create the projection head (512, 2048, 2048)."""
        super().__init__()
        self.backbone = backbone
        self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)

    def forward(self, x):
        """Return the projection-head output for the flattened backbone features of `x`."""
        features = self.backbone(x)
        # Collapse everything after the batch dimension before the head.
        features = features.flatten(start_dim=1)
        return self.projection_head(features)


# Reference Barlow Twins training run (no wandb logging, hyper-parameters inlined).

# Encoder: ResNet-18 with its last child module removed, wrapped with the projection head.
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = BarlowTwins(backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# CIFAR-10 wrapped in a LightlyDataset with the SimCLR augmentation pipeline.
cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
transform = SimCLRTransform(input_size=32)
dataset = LightlyDataset.from_torch_dataset(cifar10, transform=transform)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

collate_fn = MultiViewCollate()

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

criterion = BarlowTwinsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.06)

print("Starting Training")
for epoch in range(10):
    total_loss = 0
    # Each batch unpacks into a pair of views (x0, x1) plus two fields that are ignored here.
    for (x0, x1), _, _ in dataloader:
        x0 = x0.to(device)
        x1 = x1.to(device)
        z0 = model(x0)
        z1 = model(x1)
        loss = criterion(z0, z1)
        # detach() so the running sum does not keep the autograd graph alive.
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        # Gradients are cleared after the step, ready for the next batch.
        optimizer.zero_grad()
    # Mean of the per-batch losses for this epoch (a 0-dim tensor).
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

## 2.1 Barlow Twins

In [None]:
class BarlowTwins(nn.Module):
    """Encoder for Barlow Twins: a backbone plus a lightly projection head."""

    def __init__(self, backbone):
        """Keep the given backbone and attach BarlowTwinsProjectionHead(512, 2048, 2048)."""
        super().__init__()
        self.backbone = backbone
        self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)

    def forward(self, x):
        """Embed `x`: backbone -> flatten all non-batch dimensions -> projection head."""
        flat_features = self.backbone(x).flatten(start_dim=1)
        return self.projection_head(flat_features)

In [None]:
# Barlow Twins training run tracked with Weights & Biases.
wandb.init(
    project="SSL-Methods",
    name="BarlowTwins-Vanilla-VICTransform",
    config={
        "max_epochs": 10,
        "batch_size": 256,
        "lr": 0.06
    })

# Hyper-parameters are read back from the run config below.
config = wandb.config

# Encoder: ResNet-18 with its last child module removed, wrapped with the projection head.
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = BarlowTwins(backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
# transform = SimCLRTransform(input_size=32)
# This run uses the VICReg augmentation pipeline instead of the SimCLR one.
transform = VICRegTransform(input_size=32)
dataset = LightlyDataset.from_torch_dataset(cifar10, transform=transform)

collate_fn = MultiViewCollate()

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config.batch_size,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

criterion = BarlowTwinsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=config.lr)

print("Starting Training")
start = time.time()
# Per-epoch history of the metrics that are also sent to wandb.
log_dict_BarlowTwins = {"avg_loss": [], "avg_corr": [], "entropy_z0": [], "entropy_z1": []}
for epoch in tqdm(range(config.max_epochs)):
    total_loss = 0
    # corr and corr_count are initialised here but not used anywhere else in this cell.
    corr = 0
    corr_count = 0
    for (x0, x1), _, _ in dataloader:
        x0 = x0.to(device)
        x1 = x1.to(device)
        z0 = model(x0)
        z1 = model(x1)
        loss = criterion(z0, z1)
        # detach() so the running sum does not keep the autograd graph alive.
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    avg_loss = total_loss / len(dataloader)
    # The collapse metrics below are computed on z0/z1 of the LAST batch of the epoch only.
    avg_corr = compute_average_off_correlation_matrix(z0,z1)
    # NOTE(review): `entropy` is not defined in the part of the notebook visible here;
    # make sure the cell that defines it has been executed before this one.
    entropy_z0 = entropy(z0)
    entropy_z1 = entropy(z1)
    log_dict_BarlowTwins["avg_loss"].append(avg_loss)
    log_dict_BarlowTwins["avg_corr"].append(avg_corr)
    log_dict_BarlowTwins["entropy_z0"].append(entropy_z0)
    log_dict_BarlowTwins["entropy_z1"].append(entropy_z1)
    # print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}, avg cross-correlation: {avg_corr}, mutual information: {mut_info} ")
    wandb.log({"avg_loss": avg_loss, "avg_corr": avg_corr, "entropy_z0": entropy_z0, "entropy_z1": entropy_z1})

# Wall-clock training time, formatted as HH:MM:SS.
end = time.time()
train_time = time.strftime("%H:%M:%S", time.gmtime(end - start))
print("Time for the training :", train_time)
# Singular values of the cross-correlation matrix of the final batch's embeddings.
singular_values_BarlowTwins = compute_singular_values(z0, z1)

# Plot the singular values on a logarithmic y-axis and send the figure to wandb.
fig, ax = plt.subplots()
ax.plot(range(singular_values_BarlowTwins.size),singular_values_BarlowTwins, label=f'Singular Values')
ax.set_yscale('log')
ax.set_xlabel('Singular Value Index')
ax.set_ylabel('Singular Value')
wandb.log({"Log Singular Values ": fig})

# Save only the weights (state_dict), not the pickled module.
torch.save(model.state_dict(), 'BarlowTwin-Vanilla.pth')

## 2.2 VICReg

In [13]:
class VICReg(nn.Module): # Same as Barlow Twin
    """Backbone plus projection head; same architecture as the BarlowTwins module."""

    def __init__(self, backbone):
        """Store `backbone` and create BarlowTwinsProjectionHead(512, 2048, 2048)."""
        super().__init__()
        self.backbone = backbone
        self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)

    def forward(self, x):
        """Return the projected embedding of `x`."""
        backbone_out = self.backbone(x)
        # Keep the batch dimension and merge the remaining ones.
        flat = backbone_out.flatten(start_dim=1)
        return self.projection_head(flat)

In [None]:
# VICReg training run tracked with Weights & Biases.
wandb.init(
    project="SSL-Methods",
    name="VICReg-Vanilla",
    config={
        "max_epochs": 40,
        "batch_size": 256,
        "lr": 0.06
    })

# Hyper-parameters are read back from the run config below.
config = wandb.config

# Encoder: ResNet-18 with its last child module removed, wrapped with the projection head.
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = VICReg(backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
transform = VICRegTransform(input_size=32)
dataset = LightlyDataset.from_torch_dataset(cifar10, transform=transform)

collate_fn = MultiViewCollate()

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config.batch_size,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

criterion = VICRegLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=config.lr)

print("Starting Training")
start = time.time()
# Per-epoch history of the metrics that are also sent to wandb.
log_dict_VICReg = {"avg_loss": [], "avg_corr": [], "entropy_z0": [], "entropy_z1": []}
for epoch in tqdm(range(config.max_epochs)):
    total_loss = 0
    for (x0, x1), _, _ in dataloader:
        x0 = x0.to(device)
        x1 = x1.to(device)
        z0 = model(x0)
        z1 = model(x1)
        loss = criterion(z0, z1)
        # detach() so the running sum does not keep the autograd graph alive.
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    
    avg_loss = total_loss / len(dataloader)
    # The collapse metrics below are computed on z0/z1 of the LAST batch of the epoch only.
    avg_corr = compute_average_off_correlation_matrix(z0,z1)
    # NOTE(review): `entropy` is not defined in the part of the notebook visible here;
    # make sure the cell that defines it has been executed before this one.
    entropy_z0 = entropy(z0)
    entropy_z1 = entropy(z1)
    log_dict_VICReg["avg_loss"].append(avg_loss)
    log_dict_VICReg["avg_corr"].append(avg_corr)
    log_dict_VICReg["entropy_z0"].append(entropy_z0)
    log_dict_VICReg["entropy_z1"].append(entropy_z1)
    # print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}, avg cross-correlation: {avg_corr}, mutual information: {mut_info} ")
    wandb.log({"avg_loss": avg_loss, "avg_corr": avg_corr, "entropy_z0": entropy_z0, "entropy_z1": entropy_z1})

# Wall-clock training time, formatted as HH:MM:SS.
end = time.time()
train_time = time.strftime("%H:%M:%S", time.gmtime(end - start))
print("Time for the training :", train_time)
# Singular values of the cross-correlation matrix of the final batch's embeddings.
singular_values_VICReg = compute_singular_values(z0, z1)

# Log the spectrum as a wandb line plot. The x/y arguments of wandb.plot.line must be
# column names of the table, so the same strings are used for both.
index_column = "Singular Value Index"
value_column = "Log of Singular Values"
# The column and title announce log values, so take the log here. Values are floored
# at the smallest positive float32 so an exactly-zero singular value does not become -inf.
log_singular_values = np.log(np.maximum(singular_values_VICReg, np.finfo(np.float32).tiny))
# Plain Python numbers keep the table serialisable.
data = [[int(index), float(value)] for index, value in enumerate(log_singular_values)]
table = wandb.Table(data=data, columns=[index_column, value_column])
wandb.log(
    {"my_custom_plot_id": wandb.plot.line(
        table, index_column, value_column,
        title="Log of Singular Values for Tracking Dimensional Collapse")})

# Save only the weights (state_dict), not the pickled module.
torch.save(model.state_dict(), 'VICReg-Vanilla.pth')


## 2.4 SimSiam

In [None]:
class SimSiam(nn.Module):
    """Backbone with a SimSiam projection head and prediction head."""

    def __init__(self, backbone):
        """Store `backbone` and create the heads: projection (512, 512, 128), prediction (128, 64, 128)."""
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimSiamProjectionHead(512, 512, 128)
        self.prediction_head = SimSiamPredictionHead(128, 64, 128)

    def forward(self, x):
        """Return `(z, p)`: the detached projection and the prediction computed from it."""
        features = self.backbone(x).flatten(start_dim=1)
        projection = self.projection_head(features)
        prediction = self.prediction_head(projection)
        # The returned projection is cut from the graph; gradients flow through the prediction only.
        return projection.detach(), prediction

Training

In [None]:
# SimSiam training run tracked with Weights & Biases.
wandb.init(
    project="SSL-Methods",
    name="SimSiam-Vanilla",
    config={
        "max_epochs": 10,
        "batch_size": 256,
        "lr": 0.06
    })

# Hyper-parameters are read back from the run config below.
config = wandb.config

# Encoder: ResNet-18 with its last child module removed, wrapped with the SimSiam heads.
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = SimSiam(backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
transform = SimSiamTransform(input_size=32)
dataset = LightlyDataset.from_torch_dataset(cifar10, transform=transform)

collate_fn = MultiViewCollate()

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config.batch_size,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

criterion = NegativeCosineSimilarity()
optimizer = torch.optim.SGD(model.parameters(), lr=config.lr)

print("Starting Training")
start = time.time()
# Per-epoch history of the metrics that are also sent to wandb.
log_dict_SimSiam = {"avg_loss": [], "avg_corr": [], "entropy_z0": [], "entropy_z1": []}
for epoch in tqdm(range(config.max_epochs)):
    total_loss = 0
    for (x0, x1), _, _ in dataloader:
        x0 = x0.to(device)
        x1 = x1.to(device)
        # model returns (z, p): the detached projection and the prediction.
        z0, p0 = model(x0)
        z1, p1 = model(x1)
        # Symmetrised loss: each view's projection is compared with the other view's prediction.
        loss = 0.5 * (criterion(z0, p1) + criterion(z1, p0))
        # detach() so the running sum does not keep the autograd graph alive.
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    # The collapse metrics below are computed on the projections z0/z1 of the LAST batch only.
    avg_corr = compute_average_off_correlation_matrix(z0,z1)
    # NOTE(review): `entropy` is not defined in the part of the notebook visible here;
    # make sure the cell that defines it has been executed before this one.
    entropy_z0 = entropy(z0)
    entropy_z1 = entropy(z1)
    log_dict_SimSiam["avg_loss"].append(avg_loss)
    log_dict_SimSiam["avg_corr"].append(avg_corr)
    log_dict_SimSiam["entropy_z0"].append(entropy_z0)
    log_dict_SimSiam["entropy_z1"].append(entropy_z1)
    # print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}, avg cross-correlation: {avg_corr}, mutual information: {mut_info} ")
    wandb.log({"avg_loss": avg_loss, "avg_corr": avg_corr, "entropy_z0": entropy_z0, "entropy_z1": entropy_z1})

# Wall-clock training time, formatted as HH:MM:SS.
end = time.time()
train_time = time.strftime("%H:%M:%S", time.gmtime(end - start))
print("Time for the training :", train_time)
# Singular values of the cross-correlation matrix of the final batch's projections.
singular_values_SimSiam = compute_singular_values(z0, z1)

# Plot the singular values on a logarithmic y-axis and send the figure to wandb.
fig, ax = plt.subplots()
ax.plot(range(singular_values_SimSiam.size),singular_values_SimSiam, label=f'Singular Values')
ax.set_yscale('log')
ax.set_xlabel('Singular Value Index')
ax.set_ylabel('Singular Value')
wandb.log({"Log Singular Values ": fig})

# Save only the weights (state_dict), not the pickled module.
torch.save(model.state_dict(), 'SimSiam-Vanilla.pth')

## 2.5 BYOL

In [None]:
class BYOL(nn.Module):
    """Online network (backbone, projection head, prediction head) plus frozen momentum copies."""

    def __init__(self, backbone):
        """Build the online heads and deep-copy backbone and projection head as momentum branches."""
        super().__init__()

        self.backbone = backbone
        self.projection_head = BYOLProjectionHead(512, 1024, 256)
        self.prediction_head = BYOLPredictionHead(256, 1024, 256)

        # The momentum branch starts as an exact copy of the online branch.
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)

        # The momentum copies are excluded from gradient computation.
        for momentum_module in (self.backbone_momentum, self.projection_head_momentum):
            deactivate_requires_grad(momentum_module)

    def forward(self, x):
        """Return the online prediction for `x`."""
        features = self.backbone(x).flatten(start_dim=1)
        projection = self.projection_head(features)
        return self.prediction_head(projection)

    def forward_momentum(self, x):
        """Return the detached projection of `x` from the momentum branch."""
        features = self.backbone_momentum(x).flatten(start_dim=1)
        projection = self.projection_head_momentum(features)
        return projection.detach()

In [None]:
# BYOL training run tracked with Weights & Biases.
wandb.init(
    project="SSL-Methods",
    name="BYOL-Vanilla",
    config={
        "max_epochs": 10,
        "batch_size": 256,
        "lr": 0.06
    })

# Hyper-parameters are read back from the run config below.
config = wandb.config

# Encoder: ResNet-18 with its last child module removed, wrapped with the BYOL heads.
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = BYOL(backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
transform = SimCLRTransform(input_size=32)
dataset = LightlyDataset.from_torch_dataset(cifar10, transform=transform)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

collate_fn = MultiViewCollate()

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config.batch_size,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

criterion = NegativeCosineSimilarity()
optimizer = torch.optim.SGD(model.parameters(), lr=config.lr)


print("Starting Training")
start = time.time()
# Per-epoch history of the metrics that are also sent to wandb.
log_dict_BYOL = {"avg_loss": [], "avg_corr": [], "entropy_z0": [], "entropy_z1": []}
for epoch in tqdm(range(config.max_epochs)):
    total_loss = 0
    # Momentum coefficient for this epoch, from lightly's cosine_schedule called with 0.996 and 1.
    momentum_val = cosine_schedule(epoch, config.max_epochs, 0.996, 1)
    for (x0, x1), _, _ in dataloader:
        # Refresh the momentum copies from the online modules before the forward passes.
        update_momentum(model.backbone, model.backbone_momentum, m=momentum_val)
        update_momentum(
            model.projection_head, model.projection_head_momentum, m=momentum_val
        )
        x0 = x0.to(device)
        x1 = x1.to(device)
        # p*: online predictions; z*: detached projections from the momentum branch.
        p0 = model(x0)
        z0 = model.forward_momentum(x0)
        p1 = model(x1)
        z1 = model.forward_momentum(x1)
        # Symmetrised loss: each view's prediction is compared with the other view's momentum projection.
        loss = 0.5 * (criterion(p0, z1) + criterion(p1, z0))
        # detach() so the running sum does not keep the autograd graph alive.
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    # The collapse metrics below use the momentum-branch projections z0/z1 of the LAST batch only.
    avg_corr = compute_average_off_correlation_matrix(z0,z1)
    # NOTE(review): `entropy` is not defined in the part of the notebook visible here;
    # make sure the cell that defines it has been executed before this one.
    entropy_z0 = entropy(z0)
    entropy_z1 = entropy(z1)
    log_dict_BYOL["avg_loss"].append(avg_loss)
    log_dict_BYOL["avg_corr"].append(avg_corr)
    log_dict_BYOL["entropy_z0"].append(entropy_z0)
    log_dict_BYOL["entropy_z1"].append(entropy_z1)
    # print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}, avg cross-correlation: {avg_corr}, mutual information: {mut_info} ")
    wandb.log({"avg_loss": avg_loss, "avg_corr": avg_corr, "entropy_z0": entropy_z0, "entropy_z1": entropy_z1})

# Wall-clock training time, formatted as HH:MM:SS.
end = time.time()
train_time = time.strftime("%H:%M:%S", time.gmtime(end - start))
print("Time for the training :", train_time)
# Singular values of the cross-correlation matrix of the final batch's momentum projections.
singular_values_BYOL = compute_singular_values(z0, z1)

# Plot the singular values on a logarithmic y-axis and send the figure to wandb.
fig, ax = plt.subplots()
ax.plot(range(singular_values_BYOL.size),singular_values_BYOL, label=f'Singular Values')
ax.set_yscale('log')
ax.set_xlabel('Singular Value Index')
ax.set_ylabel('Singular Value')
wandb.log({"Log Singular Values ": fig})

# Save only the weights (state_dict), not the pickled module.
torch.save(model.state_dict(), 'BYOL-Vanilla.pth')

# Test

## 3.1 Barlow Twins

In [None]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.models import resnet18

# Linear-probe run: a trainable linear head on top of the frozen Barlow Twins encoder.
wandb.init(
    project="BarlowTwin-Vanilla-Test",
    config={
        "max_epochs": 800,
        "batch_size": 256,
        "lr": 0.2
    })

config = wandb.config

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the trained BarlowTwin weights. map_location lets a checkpoint that was
# saved from the GPU be restored on a CPU-only machine as well.
state_dict = torch.load('BarlowTwin-Vanilla.pth', map_location=device)

# Create a new model and load the state_dict
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
backbone = BarlowTwins(backbone)
# strict=False tolerates missing/unexpected keys in the checkpoint.
backbone.load_state_dict(state_dict, strict=False)

# Linear classifier on the 2048-dimensional encoder output, 10 classes.
head = nn.Linear(2048, 10)
head.weight.data.normal_(mean=0.0, std=0.01)
head.bias.data.zero_()
model = nn.Sequential(backbone, head)
model.to(device)

# Only the head receives gradients.
backbone.requires_grad_(False)
head.requires_grad_(True)

model.train()
# requires_grad_(False) does not stop BatchNorm layers from updating their running
# statistics or from normalising with batch statistics. Keep the frozen encoder in
# eval mode so the probe is trained on the same features it is later tested on.
backbone.eval()

# Load the CIFAR10 train set (no augmentation, images are only converted to tensors)
cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
dataset = LightlyDataset.from_torch_dataset(cifar10, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config.batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)
# Train the linear classifier: only the head's parameters are given to the optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(head.parameters(), lr=config.lr)

print("Start Training Linear Classifier")
start = time.time()
# Per-epoch history of the mean training loss.
log_dict_BarlowTwins_LC = {"Loss": []}
for epoch in range(config.max_epochs):
    running_loss = 0.0
    # Each batch unpacks into inputs, labels and a third field that is ignored here.
    for i, (inputs, labels, _) in enumerate(train_loader, 0):
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accumulate the batch loss for the epoch average
        running_loss += loss.item()
    # i is the index of the last batch, so i + 1 is the number of batches in the epoch.
    log_dict_BarlowTwins_LC["Loss"].append(running_loss/(i+1))
    wandb.log({"Loss": running_loss / (i+1)})

In [None]:
# Load the CIFAR10 test set
cifar10_test = torchvision.datasets.CIFAR10("datasets/cifar10", train=False, download=True)
dataset_test = LightlyDataset.from_torch_dataset(cifar10_test, transform=transforms.ToTensor())
dataloader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=config.batch_size,
    shuffle=False,
    # Keep the final partial batch: dropping it would leave test images unscored.
    drop_last=False,
    num_workers=8,
)
# Set the model to evaluation mode
model.eval()
model.to(device)

# Test the model
with torch.no_grad():
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    for inputs, targets, _ in dataloader_test:
        inputs, targets = inputs.to(device), targets.to(device)

        # Compute the model's predictions
        outputs = model(inputs)
        n_in_batch = targets.size(0)

        # criterion returns the batch mean; weight it by the batch size so the
        # smaller final batch does not skew the average.
        loss = criterion(outputs, targets)
        total_loss += loss.item() * n_in_batch

        # Compute the number of correctly classified samples
        predictions = outputs.argmax(dim=1)
        total_correct += (predictions == targets).sum().item()
        total_samples += n_in_batch

    # Compute the accuracy and the per-sample average loss
    accuracy = total_correct / total_samples
    avg_loss = total_loss / total_samples

    print(f'Test Accuracy: {accuracy:.4f}')
    print(f'Average Test Loss: {avg_loss:.4f}')

## 3.2 VICReg

In [None]:
model_path = 'VICReg-Vanilla.pth'
# Number of passes over the training set used to fit the linear head; raise for a tighter estimate.
probe_epochs = 10

# Define the linear classifier (2048-dimensional embedding -> 10 classes)
linear_classifier = torch.nn.Linear(2048, 10)

# The checkpoint was written with torch.save(model.state_dict(), ...), so torch.load
# returns a dict of tensors, not a module. Rebuild the architecture and load the weights.
resnet = torchvision.models.resnet18()
barlowtwins_model = VICReg(nn.Sequential(*list(resnet.children())[:-1]))
barlowtwins_model.load_state_dict(torch.load(model_path, map_location=device))

# Freeze the pretrained encoder
for param in barlowtwins_model.parameters():
    param.requires_grad = False

barlowtwins_model.to(device)
linear_classifier.to(device)

# The encoder stays in evaluation mode for both probe training and testing.
barlowtwins_model.eval()

# Load the CIFAR10 train and test sets as plain torchvision datasets ((image, label) pairs).
probe_train_dataset = torchvision.datasets.CIFAR10(
    "datasets/cifar10", train=True, download=True, transform=torchvision.transforms.ToTensor())
probe_train_loader = torch.utils.data.DataLoader(probe_train_dataset, batch_size=256, shuffle=True)
test_dataset = torchvision.datasets.CIFAR10(
    "datasets/cifar10", train=False, download=True, transform=torchvision.transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False)

# Define the loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(linear_classifier.parameters(), lr=0.001)

# Fit the linear head on frozen features; without this step the reported accuracy
# would be that of a randomly initialised classifier.
linear_classifier.train()
for _ in range(probe_epochs):
    for inputs, targets in probe_train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        with torch.no_grad():
            features = barlowtwins_model(inputs)
        loss = criterion(linear_classifier(features), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

linear_classifier.eval()

# Test the model
with torch.no_grad():
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    for inputs, targets in test_loader:
        # Move the inputs and targets to the device (GPU or CPU)
        inputs, targets = inputs.to(device), targets.to(device)

        # Compute the model's predictions
        outputs = linear_classifier(barlowtwins_model(inputs))

        # Compute the loss
        loss = criterion(outputs, targets)
        total_loss += loss.item()

        # Compute the number of correctly classified samples
        _, predictions = torch.max(outputs, 1)
        total_correct += (predictions == targets).sum().item()
        total_samples += targets.size(0)

    # Compute the accuracy and average loss
    accuracy = total_correct / total_samples
    avg_loss = total_loss / len(test_loader)

    print(f'Test Accuracy: {accuracy:.4f}')
    print(f'Average Test Loss: {avg_loss:.4f}')


## 3.4 SimSiam

In [None]:
# NOTE(review): this section is titled SimSiam, but the path below is not written by any
# cell visible in this part of the notebook. The SimSiam run saves 'SimSiam-Vanilla.pth'
# as a state_dict, which would need SimSiam(...).load_state_dict(...) instead of the
# plain torch.load used here. Confirm which checkpoint is intended.
model_path = 'BarlowTwinV0.pth'

# Define the linear classifier
linear_classifier = torch.nn.Linear(2048, 10)

# Load the trained model (this code expects a pickled nn.Module, not a state_dict)
barlowtwins_model = torch.load(model_path, map_location=device)

# Freeze the pretrained layers
for param in barlowtwins_model.parameters():
    param.requires_grad = False

# Attach the linear classifier as `fc`
# NOTE(review): this only changes the output if the loaded model's forward() uses self.fc;
# the model classes defined in this notebook do not. The head is also never trained here.
barlowtwins_model.fc = linear_classifier

# Inputs are moved to `device` below, so the model (and attached head) must live there too.
barlowtwins_model.to(device)

# Load the CIFAR10 test set. `dataset` and `dataloader` are instances created by earlier
# cells, so the torchvision / torch.utils.data classes are used explicitly, with the
# dataset root directory as first argument.
test_dataset = torchvision.datasets.CIFAR10(
    "datasets/cifar10", train=False, download=True, transform=torchvision.transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False)

# Define the loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(linear_classifier.parameters(), lr=0.001)

# Set the model to evaluation mode
barlowtwins_model.eval()

# Test the model
with torch.no_grad():
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    for inputs, targets in test_loader:
        # Move the inputs and targets to the device (GPU or CPU)
        inputs, targets = inputs.to(device), targets.to(device)

        # Compute the model's predictions
        outputs = barlowtwins_model(inputs)

        # Compute the loss
        loss = criterion(outputs, targets)
        total_loss += loss.item()

        # Compute the number of correctly classified samples
        _, predictions = torch.max(outputs, 1)
        total_correct += (predictions == targets).sum().item()
        total_samples += targets.size(0)

    # Compute the accuracy and average loss
    accuracy = total_correct / total_samples
    avg_loss = total_loss / len(test_loader)

    print(f'Test Accuracy: {accuracy:.4f}')
    print(f'Average Test Loss: {avg_loss:.4f}')


## 3.5 BYOL

In [None]:
# NOTE(review): this section is titled BYOL, but the path below is not written by any
# cell visible in this part of the notebook. The BYOL run saves 'BYOL-Vanilla.pth'
# as a state_dict, which would need BYOL(...).load_state_dict(...) instead of the
# plain torch.load used here. Confirm which checkpoint is intended.
model_path = 'BarlowTwinV0.pth'

# Define the linear classifier
linear_classifier = torch.nn.Linear(2048, 10)

# Load the trained model (this code expects a pickled nn.Module, not a state_dict)
barlowtwins_model = torch.load(model_path, map_location=device)

# Freeze the pretrained layers
for param in barlowtwins_model.parameters():
    param.requires_grad = False

# Attach the linear classifier as `fc`
# NOTE(review): this only changes the output if the loaded model's forward() uses self.fc;
# the model classes defined in this notebook do not. The head is also never trained here.
barlowtwins_model.fc = linear_classifier

# Inputs are moved to `device` below, so the model (and attached head) must live there too.
barlowtwins_model.to(device)

# Load the CIFAR10 test set. `dataset` and `dataloader` are instances created by earlier
# cells, so the torchvision / torch.utils.data classes are used explicitly, with the
# dataset root directory as first argument.
test_dataset = torchvision.datasets.CIFAR10(
    "datasets/cifar10", train=False, download=True, transform=torchvision.transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False)

# Define the loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(linear_classifier.parameters(), lr=0.001)

# Set the model to evaluation mode
barlowtwins_model.eval()

# Test the model
with torch.no_grad():
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    for inputs, targets in test_loader:
        # Move the inputs and targets to the device (GPU or CPU)
        inputs, targets = inputs.to(device), targets.to(device)

        # Compute the model's predictions
        outputs = barlowtwins_model(inputs)

        # Compute the loss
        loss = criterion(outputs, targets)
        total_loss += loss.item()

        # Compute the number of correctly classified samples
        _, predictions = torch.max(outputs, 1)
        total_correct += (predictions == targets).sum().item()
        total_samples += targets.size(0)

    # Compute the accuracy and average loss
    accuracy = total_correct / total_samples
    avg_loss = total_loss / len(test_loader)

    print(f'Test Accuracy: {accuracy:.4f}')
    print(f'Average Test Loss: {avg_loss:.4f}')


# Methods - Change 1

## 2.1 Barlow Twins

In [30]:
import torch
import torch.distributed as dist


class BarlowTwinsLossChanged(torch.nn.Module):
    """Barlow Twins loss with switches for batch normalization and off-diagonal weighting."""

    def __init__(self, lambda_param: float = 5e-3, gather_distributed: bool = False, normalized:bool = True, redundancy:bool = True ):
        """Configure the loss.

        Args:
            lambda_param:
                Parameter for importance of redundancy reduction term.
                Defaults to 5e-3.
            gather_distributed:
                If True then the cross-correlation matrices from all gpus are
                gathered and summed before the loss calculation.
            normalized:
                If True, both inputs are standardized along the batch dimension
                before the cross-correlation matrix is computed. If False, the
                raw embeddings are used.
            redundancy:
                If True, the off-diagonal terms of the squared difference are
                multiplied by lambda_param. If False, they keep full weight.
        """
        super(BarlowTwinsLossChanged, self).__init__()
        self.lambda_param = lambda_param
        self.gather_distributed = gather_distributed
        self.normalized = normalized
        self.redundancy = redundancy

    def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
        """Return the scalar loss for two batches of embeddings of shape (N, D)."""
        device = z_a.device

        # Both branches need the batch size and the embedding dimension, so read
        # them up front. Binding them inside the `normalized` branch made
        # normalized=False fail with an unbound local.
        N = z_a.size(0)
        D = z_a.size(1)

        if self.normalized:
            # normalize repr. along the batch dimension
            z_a = (z_a - z_a.mean(0)) / z_a.std(0)  # NxD
            z_b = (z_b - z_b.mean(0)) / z_b.std(0)  # NxD

        # cross-correlation matrix
        c = torch.mm(z_a.T, z_b) / N  # DxD

        # sum cross-correlation matrix between multiple gpus
        if self.gather_distributed and dist.is_initialized():
            world_size = dist.get_world_size()
            if world_size > 1:
                c = c / world_size
                dist.all_reduce(c)

        # loss
        c_diff = (c - torch.eye(D, device=device)).pow(2)  # DxD
        if self.redundancy:
            # multiply off-diagonal elems of c_diff by lambda; the mask is built on
            # the same device as c_diff
            off_diagonal = ~torch.eye(D, dtype=torch.bool, device=device)
            c_diff[off_diagonal] *= self.lambda_param
        loss = c_diff.sum()

        return loss

class BarlowTwins(nn.Module):
    """Backbone followed by BarlowTwinsProjectionHead(512, 2048, 2048)."""

    def __init__(self, backbone):
        """Register the backbone and the projection head as submodules."""
        super().__init__()
        self.backbone = backbone
        self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)

    def forward(self, x):
        """Project the flattened backbone features of `x`."""
        hidden = self.backbone(x)
        hidden = hidden.flatten(start_dim=1)  # keep only the batch axis separate
        projected = self.projection_head(hidden)
        return projected


In [None]:
# Barlow Twins variant trained with BarlowTwinsLossChanged(normalized=False), tracked with wandb.
wandb.init(
    project="SSL-Methods",
    name="BarlowTwins-WO-BN",
    notes="...",
    config={
        "max_epochs": 20,
        "batch_size": 256,
        "lr": 0.06
    })

# Hyper-parameters are read back from the run config below.
config = wandb.config

# Encoder: ResNet-18 with its last child module removed, wrapped with the projection head.
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = BarlowTwins(backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
transform = SimCLRTransform(input_size=32)
dataset = LightlyDataset.from_torch_dataset(cifar10, transform=transform)

collate_fn = MultiViewCollate()

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config.batch_size,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

# normalized=False selects the loss branch that uses the raw embeddings.
criterion = BarlowTwinsLossChanged(normalized=False)
optimizer = torch.optim.SGD(model.parameters(), lr=config.lr)

print("Starting Training")
start = time.time()
# Per-epoch history of the metrics that are also sent to wandb.
# This reuses the name of the log dict from the vanilla Barlow Twins run.
log_dict_BarlowTwins = {"avg_loss": [], "avg_corr": [], "entropy_z0": [], "entropy_z1": []}
for epoch in tqdm(range(config.max_epochs)):
    total_loss = 0
    # corr and corr_count are initialised here but not used anywhere else in this cell.
    corr = 0
    corr_count = 0
    for (x0, x1), _, _ in dataloader:
        x0 = x0.to(device)
        x1 = x1.to(device)
        z0 = model(x0)
        z1 = model(x1)
        loss = criterion(z0, z1)
        # detach() so the running sum does not keep the autograd graph alive.
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    avg_loss = total_loss / len(dataloader)
    # The collapse metrics below are computed on z0/z1 of the LAST batch of the epoch only.
    avg_corr = compute_average_off_correlation_matrix(z0,z1)
    # NOTE(review): `entropy` is not defined in the part of the notebook visible here;
    # make sure the cell that defines it has been executed before this one.
    entropy_z0 = entropy(z0)
    entropy_z1 = entropy(z1)
    log_dict_BarlowTwins["avg_loss"].append(avg_loss)
    log_dict_BarlowTwins["avg_corr"].append(avg_corr)
    log_dict_BarlowTwins["entropy_z0"].append(entropy_z0)
    log_dict_BarlowTwins["entropy_z1"].append(entropy_z1)
    # print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}, avg cross-correlation: {avg_corr}, mutual information: {mut_info} ")
    wandb.log({"avg_loss": avg_loss, "avg_corr": avg_corr, "entropy_z0": entropy_z0, "entropy_z1": entropy_z1})

# Wall-clock training time, formatted as HH:MM:SS.
end = time.time()
train_time = time.strftime("%H:%M:%S", time.gmtime(end - start))
print("Time for the training :", train_time)
# Singular values of the cross-correlation matrix of the final batch's embeddings.
singular_values_BarlowTwins = compute_singular_values(z0, z1)

# Plot the singular values on a logarithmic y-axis and send the figure to wandb.
fig, ax = plt.subplots()
ax.plot(range(singular_values_BarlowTwins.size), singular_values_BarlowTwins, label='Singular Values')
ax.set_yscale('log')
ax.set_xlabel('Singular Value Index')
ax.set_ylabel('Singular Value')
wandb.log({"Log Singular Values ": fig})

# Save the weights under this run's own name. Writing to 'BarlowTwin-Vanilla.pth'
# would overwrite the checkpoint of the vanilla run that the linear probe loads.
torch.save(model.state_dict(), 'BarlowTwins-WO-BN.pth')

## 2.2 VICReg

In [18]:
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor

from lightly.utils.dist import gather


class VICRegLossChanged(torch.nn.Module):
    """VICReg loss with switches that ablate the variance and/or covariance term.

    The full loss is::

        lambda_param * invariance + mu_param * variance + nu_param * covariance

    NOTE: ``variance`` and ``covariance`` are *removal* switches. Passing
    ``variance=True`` drops the variance term, passing ``covariance=True`` drops
    the covariance term, and passing both leaves only the invariance term.

    Args:
        lambda_param: Weight of the invariance term.
        mu_param: Weight of the variance term.
        nu_param: Weight of the covariance term.
        gather_distributed: If True and ``torch.distributed`` is initialised with
            a world size > 1, the embeddings are gathered from all processes
            before the variance and covariance terms are computed.
        eps: Epsilon forwarded to ``variance_loss`` for numerical stability.
        covariance: If True, the covariance term is REMOVED from the loss.
        variance: If True, the variance term is REMOVED from the loss.
    """

    def __init__(
        self,
        lambda_param: float = 25.0,
        mu_param: float = 25.0,
        nu_param: float = 1.0,
        gather_distributed: bool = False,
        eps: float = 0.0001,
        covariance: bool = False,
        variance: bool = False,
    ):
        super().__init__()
        self.lambda_param = lambda_param
        self.mu_param = mu_param
        self.nu_param = nu_param
        self.gather_distributed = gather_distributed
        self.eps = eps
        self.covariance = covariance
        self.variance = variance

    def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
        """Returns the (possibly ablated) VICReg loss.

        Args:
            z_a:
                Tensor with shape (batch_size, ..., dim).
            z_b:
                Tensor with shape (batch_size, ..., dim).

        Raises:
            AssertionError: If either batch has fewer than two samples or the two
                tensors differ in shape.
        """
        assert (
            z_a.shape[0] > 1 and z_b.shape[0] > 1
        ), f"z_a and z_b must have batch size > 1 but found {z_a.shape[0]} and {z_b.shape[0]}"
        assert (
            z_a.shape == z_b.shape
        ), f"z_a and z_b must have same shape but found {z_a.shape} and {z_b.shape}."

        # The flags mean "remove this term", so invert them once for readability.
        use_variance = not self.variance
        use_covariance = not self.covariance

        # The invariance term is always present and is computed on the local
        # (un-gathered) batch.
        loss = self.lambda_param * invariance_loss(x=z_a, y=z_b)

        # Both regularisers ablated: nothing else to compute. Previously the
        # covariance flag was silently ignored whenever the variance flag was set.
        if not (use_variance or use_covariance):
            return loss

        # gather all batches
        if self.gather_distributed and dist.is_initialized():
            if dist.get_world_size() > 1:
                z_a = torch.cat(gather(z_a), dim=0)
                z_b = torch.cat(gather(z_b), dim=0)

        # Terms are added in the same order as the original expression
        # (invariance, then variance, then covariance).
        if use_variance:
            var_loss = 0.5 * (
                variance_loss(x=z_a, eps=self.eps) + variance_loss(x=z_b, eps=self.eps)
            )
            loss = loss + self.mu_param * var_loss
        if use_covariance:
            cov_loss = covariance_loss(x=z_a) + covariance_loss(x=z_b)
            loss = loss + self.nu_param * cov_loss
        return loss


def invariance_loss(x: Tensor, y: Tensor) -> Tensor:
    """Invariance term of VICReg: mean squared error between the two embeddings.

    Args:
        x:
            Tensor with shape (batch_size, ..., dim).
        y:
            Tensor with shape (batch_size, ..., dim).

    Returns:
        Scalar tensor holding the element-wise mean of ``(x - y) ** 2``.
    """
    return F.mse_loss(input=x, target=y, reduction="mean")


def variance_loss(x: Tensor, eps: float = 0.0001) -> Tensor:
    """Variance term of VICReg: hinge on the per-dimension standard deviation.

    Every embedding dimension whose standard deviation over the batch axis is
    below 1 contributes ``1 - std``; dimensions at or above 1 contribute 0.

    Args:
        x:
            Tensor with shape (batch_size, ..., dim).
        eps:
            Epsilon added to the variance before the square root, for
            numerical stability.

    Returns:
        Scalar tensor: the hinge averaged over all non-batch positions.
    """
    centered = x - x.mean(dim=0)
    per_dim_std = (centered.var(dim=0) + eps).sqrt()
    hinge = F.relu(1.0 - per_dim_std)
    return hinge.mean()


def covariance_loss(x: Tensor) -> Tensor:
    """Covariance term of VICReg: penalises off-diagonal covariance entries.

    Works for tensors with more than two dimensions; the covariance matrix is
    built over the batch axis for every intermediate position. Adapted from VICRegL:
    https://github.com/facebookresearch/VICRegL/blob/803ae4c8cd1649a820f03afb4793763e95317620/main_vicregl.py#L299

    Args:
        x:
            Tensor with shape (batch_size, ..., dim).

    Returns:
        Scalar tensor: sum of squared off-diagonal covariances divided by
        ``dim``, averaged over the intermediate positions.
    """
    centered = x - x.mean(dim=0)
    n_samples, n_features = centered.size(0), centered.size(-1)
    # Covariance over the batch axis; result has shape (..., dim, dim).
    cov = torch.einsum("b...c,b...d->...cd", centered, centered) / (n_samples - 1)
    # Boolean (dim, dim) mask that is True everywhere except on the diagonal.
    off_diagonal = ~torch.eye(n_features, device=centered.device, dtype=torch.bool)
    per_position = cov[..., off_diagonal].pow(2).sum(-1) / n_features
    return per_position.mean()

class VICReg(nn.Module):
    """Backbone followed by a projection head, used for the VICReg runs.

    The architecture reuses ``BarlowTwinsProjectionHead`` with sizes
    (512, 2048, 2048); the head therefore expects 512 flattened backbone features.
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)

    def forward(self, x):
        """Return the projected embedding of ``x``."""
        # Collapse everything after the batch axis before projecting.
        features = self.backbone(x)
        features = features.flatten(start_dim=1)
        return self.projection_head(features)


In [None]:
# --- VICReg training run -------------------------------------------------------
# Start a W&B run; the hyper-parameters below are read back through `wandb.config`.
wandb.init(
    project="SSL-Methods",
    name="VICReg-WO-Variance-20",
    notes="Without Variance",
    config={
        "max_epochs": 20,
        "batch_size": 256,
        "lr": 0.06
    })

config = wandb.config

# ResNet-18 with its last child module dropped (`[:-1]`), wrapped by the VICReg model.
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = VICReg(backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# CIFAR-10 (downloaded on first use) wrapped so that each sample goes through VICRegTransform.
cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
transform = VICRegTransform(input_size=32)
dataset = LightlyDataset.from_torch_dataset(cifar10, transform=transform)

collate_fn = MultiViewCollate()

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config.batch_size,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

# NOTE: the flag is a removal switch - `variance=True` makes VICRegLossChanged
# leave the variance term OUT of the loss (matches the run name / notes above).
criterion = VICRegLossChanged(variance=True) # True = removed
optimizer = torch.optim.SGD(model.parameters(), lr=config.lr)

print("Starting Training")
start = time.time()
# In-memory history of the per-epoch metrics that are also sent to W&B.
log_dict_VICReg = {"avg_loss": [], "avg_corr": [], "entropy_z0": [], "entropy_z1": []}
for epoch in tqdm(range(config.max_epochs)):
    total_loss = 0
    # Each batch unpacks into the two augmented views plus two items that are ignored here.
    for (x0, x1), _, _ in dataloader:
        x0 = x0.to(device)
        x1 = x1.to(device)
        z0 = model(x0)
        z1 = model(x1)
        loss = criterion(z0, z1)
        # Accumulate a detached copy so the running total does not hold on to the graph.
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    
    avg_loss = total_loss / len(dataloader)
    # NOTE: z0 / z1 are whatever the final iteration left behind, so the collapse
    # metrics below describe the LAST batch of the epoch only, not the whole epoch.
    # `compute_average_off_correlation_matrix` and `entropy` are defined outside this cell.
    avg_corr = compute_average_off_correlation_matrix(z0,z1)
    entropy_z0 = entropy(z0)
    entropy_z1 = entropy(z1)
    log_dict_VICReg["avg_loss"].append(avg_loss)
    log_dict_VICReg["avg_corr"].append(avg_corr)
    log_dict_VICReg["entropy_z0"].append(entropy_z0)
    log_dict_VICReg["entropy_z1"].append(entropy_z1)
    # print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}, avg cross-correlation: {avg_corr}, mutual information: {mut_info} ")
    wandb.log({"avg_loss": avg_loss, "avg_corr": avg_corr, "entropy_z0": entropy_z0, "entropy_z1": entropy_z1})

# Wall-clock training time, formatted as HH:MM:SS.
end = time.time()
train_time = time.strftime("%H:%M:%S", time.gmtime(end - start))
print("Time for the training :", train_time)
# Singular values computed from the embeddings of the last training batch.
singular_values_VICReg = compute_singular_values(z0, z1)

# Plot the singular values on a log-scaled y axis and send the figure to W&B.
fig, ax = plt.subplots()
ax.plot(range(singular_values_VICReg.size),singular_values_VICReg, label=f'Singular Values')
ax.set_yscale('log')
ax.set_xlabel('Singular Value Index')
ax.set_ylabel('Singular Value')
wandb.log({"Log Singular Values ": fig})

# Save the model
torch.save(model.state_dict(), 'VICReg-Changed.pth')


## 2.4 SimSiam

In [None]:
class SimSiam(nn.Module):
    """Backbone with a projection head and a prediction head, for the SimSiam runs.

    The projection head is built as ``SimSiamProjectionHead(512, 512, 128)`` and
    the prediction head as ``SimSiamPredictionHead(128, 64, 128)``, so the
    backbone must yield 512 flattened features.
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimSiamProjectionHead(512, 512, 128)
        self.prediction_head = SimSiamPredictionHead(128, 64, 128)

    def forward(self, x):
        """Return ``(z, p)``: the detached projection and the prediction made from it."""
        features = self.backbone(x).flatten(start_dim=1)
        projection = self.projection_head(features)
        # The prediction is computed from the attached projection; only the
        # returned projection is cut from the graph (stop-gradient).
        prediction = self.prediction_head(projection)
        return projection.detach(), prediction

Training

In [None]:
# --- SimSiam training run ------------------------------------------------------
# Start a W&B run; the hyper-parameters below are read back through `wandb.config`.
wandb.init(
    project="SSL-Methods",
    name="SimSiam-Change-1",
    notes="...",
    config={
        "max_epochs": 10,
        "batch_size": 256,
        "lr": 0.06
    })

config = wandb.config

# ResNet-18 with its last child module dropped (`[:-1]`), wrapped by the SimSiam model.
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = SimSiam(backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# CIFAR-10 (downloaded on first use) wrapped so that each sample goes through SimSiamTransform.
cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
transform = SimSiamTransform(input_size=32)
dataset = LightlyDataset.from_torch_dataset(cifar10, transform=transform)

collate_fn = MultiViewCollate()

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config.batch_size,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

criterion = NegativeCosineSimilarity()
optimizer = torch.optim.SGD(model.parameters(), lr=config.lr)

print("Starting Training")
start = time.time()
# In-memory history of the per-epoch metrics that are also sent to W&B.
log_dict_SimSiam = {"avg_loss": [], "avg_corr": [], "entropy_z0": [], "entropy_z1": []}
for epoch in tqdm(range(config.max_epochs)):
    total_loss = 0
    # Each batch unpacks into the two augmented views plus two items that are ignored here.
    for (x0, x1), _, _ in dataloader:
        x0 = x0.to(device)
        x1 = x1.to(device)
        # The model returns (detached projection z, prediction p) for each view.
        z0, p0 = model(x0)
        z1, p1 = model(x1)
        # Symmetrised loss: each view's projection is paired with the other view's prediction.
        loss = 0.5 * (criterion(z0, p1) + criterion(z1, p0))
        # Accumulate a detached copy so the running total does not hold on to the graph.
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    # NOTE: z0 / z1 are whatever the final iteration left behind, so the collapse
    # metrics below describe the LAST batch of the epoch only, not the whole epoch.
    # `compute_average_off_correlation_matrix` and `entropy` are defined outside this cell.
    avg_corr = compute_average_off_correlation_matrix(z0,z1)
    entropy_z0 = entropy(z0)
    entropy_z1 = entropy(z1)
    log_dict_SimSiam["avg_loss"].append(avg_loss)
    log_dict_SimSiam["avg_corr"].append(avg_corr)
    log_dict_SimSiam["entropy_z0"].append(entropy_z0)
    log_dict_SimSiam["entropy_z1"].append(entropy_z1)
    # print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}, avg cross-correlation: {avg_corr}, mutual information: {mut_info} ")
    wandb.log({"avg_loss": avg_loss, "avg_corr": avg_corr, "entropy_z0": entropy_z0, "entropy_z1": entropy_z1})

# Wall-clock training time, formatted as HH:MM:SS.
end = time.time()
train_time = time.strftime("%H:%M:%S", time.gmtime(end - start))
print("Time for the training :", train_time)
# Singular values computed from the projections of the last training batch.
singular_values_SimSiam = compute_singular_values(z0, z1)

# Plot the singular values on a log-scaled y axis and send the figure to W&B.
fig, ax = plt.subplots()
ax.plot(range(singular_values_SimSiam.size),singular_values_SimSiam, label=f'Singular Values')
ax.set_yscale('log')
ax.set_xlabel('Singular Value Index')
ax.set_ylabel('Singular Value')
wandb.log({"Log Singular Values ": fig})

# Save the model
torch.save(model.state_dict(), 'SimSiam-Vanilla.pth')

## 2.5 BYOL

In [None]:
class BYOL(nn.Module):
    """Online network (backbone + projection + prediction) with a momentum copy.

    The momentum branch is a deep copy of the online backbone and projection
    head, taken at construction time and passed through
    ``deactivate_requires_grad``. Head sizes are fixed to
    ``BYOLProjectionHead(512, 1024, 256)`` and ``BYOLPredictionHead(256, 1024, 256)``.
    """

    def __init__(self, backbone):
        super().__init__()

        # Online branch.
        self.backbone = backbone
        self.projection_head = BYOLProjectionHead(512, 1024, 256)
        self.prediction_head = BYOLPredictionHead(256, 1024, 256)

        # Momentum branch: independent copies of the online modules (no
        # prediction head on this side).
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)

        for momentum_module in (self.backbone_momentum, self.projection_head_momentum):
            deactivate_requires_grad(momentum_module)

    def forward(self, x):
        """Return the online prediction for ``x``."""
        features = self.backbone(x).flatten(start_dim=1)
        projection = self.projection_head(features)
        return self.prediction_head(projection)

    def forward_momentum(self, x):
        """Return the detached momentum-branch projection for ``x``."""
        features = self.backbone_momentum(x).flatten(start_dim=1)
        return self.projection_head_momentum(features).detach()

In [None]:
# --- BYOL training run ---------------------------------------------------------
# Start a W&B run; the hyper-parameters below are read back through `wandb.config`.
wandb.init(
    project="SSL-Methods",
    name="BYOL-Change-1",
    notes="...",
    config={
        "max_epochs": 10,
        "batch_size": 256,
        "lr": 0.06
    })

config = wandb.config

# ResNet-18 with its last child module dropped (`[:-1]`), wrapped by the BYOL model.
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = BYOL(backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# CIFAR-10 (downloaded on first use) wrapped so that each sample goes through SimCLRTransform.
cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
transform = SimCLRTransform(input_size=32)
dataset = LightlyDataset.from_torch_dataset(cifar10, transform=transform)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

collate_fn = MultiViewCollate()

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config.batch_size,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

criterion = NegativeCosineSimilarity()
optimizer = torch.optim.SGD(model.parameters(), lr=config.lr)


print("Starting Training")
start = time.time()
# In-memory history of the per-epoch metrics that are also sent to W&B.
log_dict_BYOL = {"avg_loss": [], "avg_corr": [], "entropy_z0": [], "entropy_z1": []}
for epoch in tqdm(range(config.max_epochs)):
    total_loss = 0
    # Momentum coefficient is recomputed once per epoch; presumably it moves from
    # 0.996 towards 1 over training - confirm against lightly's `cosine_schedule`.
    momentum_val = cosine_schedule(epoch, config.max_epochs, 0.996, 1)
    # Each batch unpacks into the two augmented views plus two items that are ignored here.
    for (x0, x1), _, _ in dataloader:
        # Refresh the momentum backbone and projection head from their online
        # counterparts before the forward passes of this step.
        update_momentum(model.backbone, model.backbone_momentum, m=momentum_val)
        update_momentum(
            model.projection_head, model.projection_head_momentum, m=momentum_val
        )
        x0 = x0.to(device)
        x1 = x1.to(device)
        # p*: online predictions; z*: detached momentum-branch projections.
        p0 = model(x0)
        z0 = model.forward_momentum(x0)
        p1 = model(x1)
        z1 = model.forward_momentum(x1)
        # Symmetrised loss: each view's prediction is paired with the other view's momentum projection.
        loss = 0.5 * (criterion(p0, z1) + criterion(p1, z0))
        # Accumulate a detached copy so the running total does not hold on to the graph.
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    # NOTE: z0 / z1 are the momentum-branch projections of the LAST batch of the
    # epoch, so the collapse metrics below do not cover the whole epoch.
    # `compute_average_off_correlation_matrix` and `entropy` are defined outside this cell.
    avg_corr = compute_average_off_correlation_matrix(z0,z1)
    entropy_z0 = entropy(z0)
    entropy_z1 = entropy(z1)
    log_dict_BYOL["avg_loss"].append(avg_loss)
    log_dict_BYOL["avg_corr"].append(avg_corr)
    log_dict_BYOL["entropy_z0"].append(entropy_z0)
    log_dict_BYOL["entropy_z1"].append(entropy_z1)
    # print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}, avg cross-correlation: {avg_corr}, mutual information: {mut_info} ")
    wandb.log({"avg_loss": avg_loss, "avg_corr": avg_corr, "entropy_z0": entropy_z0, "entropy_z1": entropy_z1})

# Wall-clock training time, formatted as HH:MM:SS.
end = time.time()
train_time = time.strftime("%H:%M:%S", time.gmtime(end - start))
print("Time for the training :", train_time)
# Singular values computed from the momentum projections of the last training batch.
singular_values_BYOL = compute_singular_values(z0, z1)

# Plot the singular values on a log-scaled y axis and send the figure to W&B.
fig, ax = plt.subplots()
ax.plot(range(singular_values_BYOL.size),singular_values_BYOL, label=f'Singular Values')
ax.set_yscale('log')
ax.set_xlabel('Singular Value Index')
ax.set_ylabel('Singular Value')
wandb.log({"Log Singular Values ": fig})

# Save the model
torch.save(model.state_dict(), 'BYOL-Vanilla.pth')