# SimCLR Example with MNIST

This notebook demonstrates how to use the SimCLR estimator from the nidl library to train a GoogLeNet encoder on a 10,000-image subset of the MNIST dataset.

## Introduction

SimCLR (A Simple Framework for Contrastive Learning of Visual Representations) is a self-supervised learning framework that learns useful features without labels. It works by pulling the representations of different augmented views of the same image together in a representation space, while pushing apart the representations of different images.

## Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

from nidl.estimators.ssl import SimCLR
from nidl.transforms import MultiViewsTransform
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision.datasets import MNIST
from torchvision.models import googlenet
from torchvision.transforms import Compose, RandomHorizontalFlip, GaussianBlur, ToTensor, Normalize
from tqdm import tqdm

## Data Preparation

### Load MNIST Dataset

In [None]:
# Define transformations
# Deterministic preprocessing for single-view use: convert the PIL image to a
# tensor in [0, 1], then shift/scale it to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Random augmentation pipeline applied to the PIL image; it is wrapped in
# MultiViewsTransform(n_views=2) below for the contrastive splits.
contrast_transforms = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.GaussianBlur(kernel_size=5),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Load MNIST dataset
# Keep the full training set under its own name so `train_dataset` and
# `val_dataset` are both derived from it without rebinding a variable mid-cell.
full_train_dataset = MNIST(
    root='./data',
    train=True,
    download=True,
    transform=MultiViewsTransform(contrast_transforms, n_views=2),
)

# Draw a 10,000-image subset WITHOUT replacement: sampling with replacement
# (NumPy's default) would duplicate images and let the same image land in both
# the training and the validation split. The seed makes the split reproducible.
rng = np.random.default_rng(42)
indices = rng.choice(len(full_train_dataset), size=10000, replace=False)
train_indices = indices[:9000]
val_indices = indices[9000:]
train_dataset = Subset(full_train_dataset, indices=train_indices)
val_dataset = Subset(full_train_dataset, indices=val_indices)

# The test split uses the single-view `transform` pipeline defined above
# (not the `transforms` module, which is not a valid image transform).
test_dataset = MNIST(root='./data', train=False, download=True,
                     transform=transform)

## Model Definition

In [None]:
# Define the encoder (GoogLeNet)
class GoogLeNetEncoder(nn.Module):
    """Randomly initialised GoogLeNet used as a feature extractor.

    The classification layer is replaced by ``nn.Identity`` so the forward
    pass returns the features that would otherwise be fed to it.

    Attributes
    ----------
    model : nn.Module
        The underlying torchvision GoogLeNet.
    latent_size : int
        Number of features returned per sample (1024, the input width of
        GoogLeNet's original ``fc`` layer).
    """

    def __init__(self):
        super().__init__()
        # aux_logits=False: with the auxiliary classifiers enabled,
        # torchvision's GoogLeNet returns a (logits, aux_logits2, aux_logits1)
        # namedtuple in training mode instead of a single tensor.
        self.model = googlenet(weights=None, aux_logits=False)
        self.model.fc = nn.Identity()  # Remove the final fully connected layer
        self.latent_size = 1024  # Set the latent size

    def forward(self, x):
        """Return the feature tensor for a batch of images.

        Single-channel batches (e.g. MNIST) are broadcast to three channels
        because GoogLeNet's first convolution expects 3-channel input.
        """
        if x.dim() == 4 and x.size(1) == 1:
            # expand() creates a broadcast view; no data is copied.
            x = x.expand(-1, 3, -1, -1)
        return self.model(x)

# Initialize the encoder
encoder = GoogLeNetEncoder()

# Width of the encoder output. torchvision's GoogLeNet feeds 1024 features
# into its `fc` layer, which GoogLeNetEncoder replaces with nn.Identity, so
# the encoder emits 1024 features per image. This name was previously used
# below without ever being defined.
latent_size = 1024

# Define the SimCLR model
simclr = SimCLR(
    encoder=encoder,
    # Projection head sizes are all expressed relative to the encoder width.
    projection_head_kwargs={
        "input_dim": latent_size,
        "hidden_dim": 2 * latent_size,
        "output_dim": latent_size,
    },
    lr=1e-4,
    temperature=0.1,
    weight_decay=0.001,
    max_epochs=10,
    random_state=42
)

## Training Loop

In [None]:
# Define data loaders
# Options shared by the three loaders.
_shared_loader_args = {"batch_size": 16, "num_workers": 4}

# Training batches are reshuffled every epoch; the last incomplete batch is dropped.
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    drop_last=True,
    **_shared_loader_args,
)
# Validation keeps a fixed order and also drops the last incomplete batch.
val_loader = DataLoader(
    val_dataset,
    shuffle=False,
    drop_last=True,
    **_shared_loader_args,
)
# The test loader keeps every sample, including a final partial batch.
test_loader = DataLoader(
    test_dataset,
    shuffle=False,
    drop_last=False,
    **_shared_loader_args,
)

# Fit
simclr.fit(train_loader, val_loader)

## Evaluation and Visualization

### Learning Curves

In [None]:
# Plot learning curves
# NOTE(review): `train_losses` and `val_losses` are not assigned in any cell
# visible in this notebook — confirm where the per-epoch losses are collected
# before relying on this cell.
fig, ax = plt.subplots(figsize=(10, 5))
for curve, curve_label in (
    (train_losses, 'Train Loss'),
    (val_losses, 'Validation Loss'),
):
    ax.plot(curve, label=curve_label)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Learning Curves')
ax.legend()
plt.show()

### Test Set Evaluation

In [None]:
# Evaluate on the test set
# This cell previously iterated over `val_loader` while reporting a "Test
# Loss". `test_loader` is not built with the multi-view transform, so a
# dedicated loader over the MNIST test split is created here with the same
# two-view augmentation used for the training and validation splits.
contrastive_test_dataset = MNIST(
    root='./data',
    train=False,
    download=True,
    transform=MultiViewsTransform(contrast_transforms, n_views=2),
)
contrastive_test_loader = DataLoader(
    contrastive_test_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=4,
    drop_last=True,
)

simclr.eval()
test_loss = 0.0
with torch.no_grad():
    for batch in contrastive_test_loader:
        # NOTE(review): torchvision's MNIST yields (transformed image, label)
        # pairs, so these two items are presumably (views, labels) rather than
        # two separate views — confirm the batch layout validation_step expects.
        V1, V2 = batch
        outputs = simclr.validation_step((V1, V2), 0)
        test_loss += outputs['loss'].item()
# Average the per-batch losses over the number of batches.
test_loss /= len(contrastive_test_loader)
print(f"Test Loss: {test_loss:.4f}")