# SimCLR Example with MNIST

This notebook demonstrates how to use the SimCLR estimator from the nidl library to train a small encoder on the MNIST dataset.

## Introduction

SimCLR (Simple Framework for Contrastive Learning of Visual Representations) is a self-supervised learning framework that learns useful features without labels. It pulls the representations of different augmented views of the same image close together in representation space, while pushing apart the representations of different images.
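
The pull/push objective described above is commonly expressed as the NT-Xent (normalized temperature-scaled cross-entropy) loss. The block below is a minimal NumPy sketch of that loss, included only to make the idea concrete. It is not the nidl implementation, and the function name `nt_xent_loss` and the toy embeddings are illustrative assumptions.

```python
import numpy as np


def nt_xent_loss(z1, z2, temperature=0.1):
    """Illustrative NT-Xent loss for two batches of embeddings, each of shape (N, D).

    Row i of z1 and row i of z2 are assumed to be two views of the same image.
    """
    n = z1.shape[0]
    z = np.concatenate([z1, z2], axis=0)              # (2N, D)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # unit norm, so dot product = cosine
    sim = z @ z.T / temperature                       # (2N, 2N) scaled similarities
    np.fill_diagonal(sim, -np.inf)                    # a sample is never its own candidate

    # The positive for row i is row i + N, and vice versa.
    pos_idx = np.concatenate([np.arange(n, 2 * n), np.arange(0, n)])

    # Numerically stable log-softmax over each row.
    row_max = sim.max(axis=1, keepdims=True)
    log_prob = sim - row_max - np.log(np.exp(sim - row_max).sum(axis=1, keepdims=True))

    # Cross-entropy with the positive as the target class, averaged over all 2N rows.
    return float(-log_prob[np.arange(2 * n), pos_idx].mean())


# Toy check: nearly identical views should give a lower loss than unrelated ones.
rng = np.random.default_rng(0)
views_a = rng.normal(size=(8, 16))
aligned = nt_xent_loss(views_a, views_a + 0.01 * rng.normal(size=(8, 16)))
unrelated = nt_xent_loss(views_a, rng.normal(size=(8, 16)))
print(f"aligned: {aligned:.3f}, unrelated: {unrelated:.3f}")
```

A lower temperature sharpens the softmax, so the hardest negatives dominate the loss. The `temperature=0.1` default here simply mirrors the value passed to the estimator later in this notebook.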

## Setup and Imports

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision.datasets import MNIST
from torchvision.models import googlenet
from tqdm import tqdm

import sys
sys.path.append('../')
from nidl.estimators.ssl import SimCLR
from nidl.transforms import MultiViewsTransform


## Data Preparation

### Load MNIST Dataset

In [3]:
# Define transformations
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
# Augmentations applied independently to produce each contrastive view
contrast_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.GaussianBlur(kernel_size=5),
    transforms.Grayscale(num_output_channels=3),  # replicate to 3 channels for GoogLeNet
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Load MNIST dataset
train_dataset = MNIST(root='./data', train=True, download=True,
                      transform=MultiViewsTransform(contrast_transforms,
                                                        n_views=2))
test_dataset = MNIST(root='./data', train=False, download=True,
                     transform=transform)

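# Sample 10,000 distinct images and split them 9,000 / 1,000
indices = np.random.choice(np.arange(len(train_dataset)),
                           size=10000, replace=False)
train_indices = indices[:9000]
val_indices = indices[9000:]
# The indices refer to the full dataset, so build the validation subset
# before train_dataset is rebound to its own 9,000-sample subset.
val_dataset = Subset(train_dataset, indices=val_indices)
train_dataset = Subset(train_dataset, indices=train_indices)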


## Model Definition

In [5]:
# Define the encoder (GoogLeNet)
class GoogLeNetEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = googlenet(weights=None)
        self.model.fc = nn.Identity()  # Remove the final fully connected layer
        self.latent_size = 1024

    def forward(self, x):
        return self.model(x)

# Initialize the encoder
encoder = GoogLeNetEncoder()

# Define the SimCLR model
simclr = SimCLR(
    encoder=encoder,
    hidden_dims=[2*1024],
    lr=1e-4,
    temperature=0.1,
    weight_decay=0.001,
    max_epochs=10,
    random_state=42
)

## Training Loop

In [8]:
# Define data loaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4, drop_last=False)

# Fit
simclr.fit(train_loader, val_loader)

Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.


💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
/home_local/pa284280/nidl/.pixi/envs/default/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
Seed set to 42
You are using a CUDA device ('NVIDIA RTX 4500 Ada Generation') that has Tensor Cores. To properly utilize them, you should set `torch.set_f

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home_local/pa284280/nidl/.pixi/envs/default/lib/python3.12/site-packages/pytorch_lightning/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.


IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home_local/pa284280/nidl/.pixi/envs/default/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 358, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home_local/pa284280/nidl/.pixi/envs/default/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home_local/pa284280/nidl/.pixi/envs/default/lib/python3.12/site-packages/torch/utils/data/dataset.py", line 411, in __getitems__
    return self.dataset.__getitems__([self.indices[idx] for idx in indices])  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home_local/pa284280/nidl/.pixi/envs/default/lib/python3.12/site-packages/torch/utils/data/dataset.py", line 413, in __getitems__
    return [self.dataset[self.indices[idx]] for idx in indices]
                         ~~~~~~~~~~~~^^^^^
IndexError: index 37263 is out of bounds for axis 0 with size 9000
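
The traceback above is what one would expect when a `Subset` is built on top of another `Subset` using indices drawn from the full dataset: an index such as 37263 is then used as a position inside the inner subset's 9,000-element index array. The sketch below reproduces only that index arithmetic with NumPy. It mimics the lookup rather than calling `torch.utils.data.Subset`, and all variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(42)
n_total = 60_000  # size of the MNIST training set

# Sample 10,000 distinct indices into the full dataset and split them 9,000 / 1,000.
indices = rng.choice(n_total, size=10_000, replace=False)
train_indices, val_indices = indices[:9_000], indices[9_000:]

# Nested case: a subset of the training subset looks up
# train_indices[val_indices[i]], so a full-dataset index is used as a
# position in an array that only has 9,000 entries.
out_of_range = val_indices[val_indices >= len(train_indices)]
try:
    train_indices[out_of_range[0]]
    error_message = None
except IndexError as exc:
    error_message = str(exc)
print(error_message)

# Flat case: a subset of the full dataset uses val_indices[i] directly,
# which is always a valid position in the full dataset.
all_valid = bool(np.all((val_indices >= 0) & (val_indices < n_total)))

# Sampling without replacement also keeps the two splits disjoint.
no_overlap = np.intersect1d(train_indices, val_indices).size == 0
print(all_valid, no_overlap)
```

Creating both subsets from the full dataset, with indices sampled without replacement, avoids the out-of-bounds lookup and keeps validation samples out of the training split.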


## Evaluation and Visualization

### Learning Curves

In [None]:
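# Plot learning curves
# NOTE: train_losses and val_losses are not defined earlier in this notebook;
# they are assumed to be sequences of per-epoch loss values collected during training.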
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Learning Curves')
plt.legend()
plt.show()

### Test Set Evaluation

In [None]:
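# Evaluate the contrastive loss on held-out data.
# The loop uses val_loader because test_dataset is built with the single-view
# `transform`, so it does not yield the two views the loss needs.
simclr.eval()
val_loss = 0.0
with torch.no_grad():
    for batch in val_loader:
        # Assumes each batch unpacks into two views; adjust if the loader also yields labels.
        V1, V2 = batch
        outputs = simclr.validation_step((V1, V2), 0)
        val_loss += outputs['loss'].item()
val_loss /= len(val_loader)
print(f"Validation Loss: {val_loss:.4f}")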