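# SimCLR Using PyTorch

Contrastive self-supervised learning on CIFAR-10: two random augmentations of each image are embedded by a frozen, ImageNet-pretrained ResNet-18 with a trainable projection head, and trained so that the two views of the same image end up close together.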

In [21]:
import os
import time
import random
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
from tqdm import tqdm

import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import StepLR
from torchvision.datasets import CIFAR10

## GPU Setup (Optional)

In [22]:
# Display the installed PyTorch version. A bare expression is only echoed in a
# notebook; in a plain script this line has no visible effect.
torch.__version__
# Recorded output of the original run: '1.13.0+cu117'

'1.13.0+cu117'

In [23]:
# Get info on all GPU devices. "!" is IPython/Jupyter syntax that runs the
# nvidia-smi shell command; this line is not valid in a plain Python script.
!nvidia-smi

Fri Jul  7 13:54:39 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 536.23                 Driver Version: 536.23       CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3060      WDDM  | 00000000:22:00.0 Off |                  N/A |
|  0%   49C    P3              36W / 170W |    269MiB / 12288MiB |     37%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [24]:
# Set environment variables with the possible device ids.
# CUDA only honors these variables if they are set before CUDA is initialized
# in this process, so run this cell before anything else touches torch.cuda.
# CUDA's default device order (FASTEST_FIRST) can differ from the order shown
# by nvidia-smi; PCI_BUS_ID makes the indices match. setdefault keeps any
# ordering the user exported explicitly.
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
print(os.environ["CUDA_VISIBLE_DEVICES"])
# Select the default CUDA device: "cuda:0" is the first *visible* GPU.
# Guarded so the notebook still runs on CPU-only machines, where selecting a
# CUDA device cannot succeed.
if torch.cuda.is_available():
    torch.cuda.set_device("cuda:0")

0,1


In [6]:
# Check that the selected device is the desired one.
# Prints library versions and CUDA/cuDNN status, then runs a tiny tensor op as
# a smoke test. Queries that need an active CUDA device are skipped when CUDA
# is unavailable, because torch.cuda.current_device() fails on CPU-only hosts.
print("Torch version?", torch.__version__)
print("Torchvision version?", torchvision.__version__)
print("Is cuda available?", torch.cuda.is_available())
print("Is cuDNN version:", torch.backends.cudnn.version())
print("cuDNN enabled? ", torch.backends.cudnn.enabled)
print("Device count?", torch.cuda.device_count())
if torch.cuda.is_available():
    print("Current device?", torch.cuda.current_device())
    print("Device name? ", torch.cuda.get_device_name(torch.cuda.current_device()))
x = torch.rand(5, 3)
print(x)

Torch version? 1.13.0+cu117
Torchvision version? 0.14.0+cu117
Is cuda available? True
Is cuDNN version: 8500
cuDNN enabled?  True
Device count? 1
Current device? 0
Device name?  NVIDIA GeForce RTX 3060
tensor([[0.8968, 0.3288, 0.3419],
        [0.7293, 0.7448, 0.0672],
        [0.2199, 0.6071, 0.0307],
        [0.6126, 0.1665, 0.2358],
        [0.4057, 0.3819, 0.6767]])


## Config and Dataset

In [7]:
class Config:
    """Hyper-parameters and output paths for SimCLR training.

    Instantiating the class creates ``base_path`` on disk (if missing) so the
    checkpoint and plot paths derived from it are writable.
    """

    def __init__(self):
        # Optimisation
        self.learning_rate = 0.001
        # Passed to Adam as weight_decay; 0.0 is Adam's own default (no decay).
        self.weight_decay = 0.0
        self.num_epochs = 100
        self.batch_size = 64
        # Epochs without validation improvement before early stopping.
        self.patience = 10
        # StepLR: multiply the learning rate by scheduler_gamma every
        # scheduler_step_size epochs.
        self.scheduler_step_size = 70
        self.scheduler_gamma = 0.1
        # Model
        self.dropout_p = 0.3
        self.embedding_size = 128
        # Other application variables
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.base_path = "./output"
        os.makedirs(self.base_path, exist_ok=True)  # Create the base_path directory if it doesn't exist
        self.best_model_path = os.path.join(self.base_path, "best_model.pth")
        self.last_model_path = os.path.join(self.base_path, "last_model.pth")
        self.learning_plot_path = os.path.join(self.base_path, "learning_curves.png")
        # Histogram + ROC threshold figure written by evaluate().
        self.threshold_plot_path = os.path.join(self.base_path, "threshold.png")

In [19]:
class ImageDataset(torch.utils.data.Dataset):
    """Yield two independently augmented views of each image, dropping labels.

    Args:
        data: Indexable collection of ``(image, label)`` pairs (``main`` passes
            a torchvision ``CIFAR10`` dataset). Labels are ignored because
            SimCLR only needs the two views.
        transform: Optional callable applied twice per item. A random transform
            therefore yields two different views of the same image. If ``None``,
            the untransformed image is returned twice.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image, _ = self.data[idx]  # We don't need labels for SimCLR

        if self.transform is None:
            # Without a transform the original code hit an UnboundLocalError.
            return image, image

        # Two separate calls so each view draws its own random augmentation.
        return self.transform(image), self.transform(image)

## Model

In [20]:
class SiameseNetwork(nn.Module):
    """Shared-weight encoder: a frozen ResNet-18 backbone plus a trainable projection head.

    Both inputs of ``forward`` pass through the same backbone and head, so the
    two outputs live in the same embedding space.

    Args:
        dropout_p: Dropout probability inside the projection head.
        embedding_size: Size of the output embedding.
        backbone_weights: ``weights`` argument for ``models.resnet18``. The
            default loads the pretrained torchvision weights; ``None`` gives a
            randomly initialised backbone (e.g. for offline tests).
    """

    def __init__(self, dropout_p=0.5, embedding_size=128,
                 backbone_weights='ResNet18_Weights.DEFAULT'):
        super(SiameseNetwork, self).__init__()
        # For a ResNet-50 backbone use models.resnet50(weights='ResNet50_Weights.DEFAULT').
        self.backbone = models.resnet18(weights=backbone_weights)

        # Size of the pooled features that feed the original classifier.
        backbone_dim = self.backbone.fc.in_features
        # Remove the classifier so the backbone returns those pooled features.
        # Leaving it in fed the 1000 class logits into a head that expects
        # backbone_dim inputs, which fails with a shape mismatch.
        self.backbone.fc = nn.Identity()

        # Only the projection head is trained.
        # NOTE: requires_grad=False does not freeze BatchNorm running statistics;
        # in train() mode the backbone's BN layers still update them.
        for param in self.backbone.parameters():
            param.requires_grad = False

        self.head = nn.Sequential(
            nn.Linear(backbone_dim, 512),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Linear(512, embedding_size)
        )

    def encode(self, x):
        """Embed one batch of images with the shared backbone and head."""
        return self.head(self.backbone(x))

    def forward(self, x1, x2):
        """Embed two batches with shared weights and return ``(z1, z2)``."""
        return self.encode(x1), self.encode(x2)

class ContrastiveLoss(nn.Module):
    """NT-Xent loss (normalised temperature-scaled cross-entropy) from SimCLR.

    For ``N`` pairs, the ``2N`` embeddings are L2-normalised and compared by
    cosine similarity divided by ``temperature``. The positive of each embedding
    is the other view of the same sample (index ``i`` <-> ``i + N``). The other
    ``2N - 2`` embeddings act as negatives, and self-similarity is excluded.

    Args:
        temperature: Softmax temperature; smaller values sharpen the distribution.
    """

    def __init__(self, temperature=0.5):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, z1, z2):
        """Return the mean NT-Xent loss for two ``(N, D)`` batches of paired embeddings.

        Row ``i`` of ``z1`` and row ``i`` of ``z2`` must be two views of the same sample.
        """
        batch_size = z1.shape[0]

        # Normalise every embedding over its feature dimension (dim=1) so the
        # dot products below are cosine similarities.
        representations = torch.cat([F.normalize(z1, dim=1), F.normalize(z2, dim=1)], dim=0)
        logits = representations @ representations.t() / self.temperature

        # Each embedding is trivially most similar to itself. Mask the diagonal so
        # it counts as neither the positive nor a negative. (The previous targets
        # arange(2N) pointed at exactly this diagonal, which made the loss meaningless.)
        self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=logits.device)
        logits = logits.masked_fill(self_mask, float("-inf"))

        # The positive for row i is the other view: i + N in the first half,
        # i - N in the second half.
        targets = torch.arange(2 * batch_size, device=logits.device)
        targets = (targets + batch_size) % (2 * batch_size)

        return F.cross_entropy(logits, targets)

## Utils

In [15]:
# Prediction function
def predict(model, img1, img2, device="cpu"):
    """Return the cosine similarity of each (img1[i], img2[i]) embedding pair.

    Side effects: ``model`` is switched to eval mode and moved to ``device``;
    both changes persist after the call.

    Args:
        model: Module whose ``forward(a, b)`` returns two embedding batches
            (assumed ``(N, D)`` — similarities are summed over dim 1).
        img1, img2: Paired input batches.
        device: Device on which inference runs.

    Returns:
        NumPy array with one cosine similarity per pair.
    """
    model.eval()
    model.to(device)
    first = img1.to(device)
    second = img2.to(device)

    with torch.no_grad():
        emb_a, emb_b = model(first, second)
        # Unit-length embeddings make the row-wise dot product a cosine similarity.
        unit_a = F.normalize(emb_a, p=2, dim=1)
        unit_b = F.normalize(emb_b, p=2, dim=1)
        per_pair = torch.mul(unit_a, unit_b).sum(dim=1)

    return per_pair.cpu().numpy()

In [14]:
# Plot training function
def plot_training(train_loss_history, val_loss_history, config):
    """Plot per-epoch train/validation losses, save to ``config.learning_plot_path``, then show."""
    plt.figure(figsize=(10, 5))
    curves = ((train_loss_history, 'Train Loss'), (val_loss_history, 'Validation Loss'))
    for history, caption in curves:
        plt.plot(history, label=caption)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    # Save before show(): the figure may be closed once it has been displayed.
    plt.savefig(config.learning_plot_path)
    plt.show()

# Plot prediction function
def plot_prediction(img1, img2, similarities, limit=None):
    """Show image pairs side by side, titling each pair with its similarity.

    Args:
        img1, img2: Equal-length sequences/batches of image tensors in
            ``(C, H, W)`` layout. Images with 3 channels are drawn as RGB;
            anything else is squeezed and drawn in grayscale.
        similarities: One score per pair (e.g. the output of ``predict``).
        limit: If given, only the first ``limit`` pairs are plotted.
    """
    if limit is not None:
        # Slice the scores together with the images (the original sliced an
        # undefined name, `distances`, and raised NameError).
        img1, img2, similarities = img1[:limit], img2[:limit], similarities[:limit]

    n_pairs = len(img1)
    if n_pairs == 0:
        return  # plt.subplots needs at least one row

    # squeeze=False keeps `axs` 2-D even for a single pair, so axs[i, col] works.
    _, axs = plt.subplots(n_pairs, 2, figsize=(5, 3 * n_pairs), squeeze=False)
    for i in range(n_pairs):
        for col, img in enumerate((img1[i], img2[i])):
            is_rgb = img.shape[0] == 3
            # imshow expects channels last for colour images.
            shown = img.permute(1, 2, 0) if is_rgb else img.squeeze()
            axs[i, col].imshow(shown.cpu(), cmap=None if is_rgb else 'gray')
            axs[i, col].axis('off')
        axs[i, 1].set_title(f"Similarity: {similarities[i]:.2f}")

    plt.tight_layout()
    plt.show()


In [16]:
# Save model function
def save_model(model, save_path):
    """Write ``model``'s state_dict (parameters and buffers only) to ``save_path``."""
    weights = model.state_dict()
    torch.save(weights, save_path)

# Load model function
def load_model(model, load_path, device):
    """Load a state_dict from ``load_path`` into ``model`` in place and return ``model``.

    Checkpoint tensors are mapped onto ``device`` while being read.
    """
    state = torch.load(load_path, map_location=device)
    model.load_state_dict(state)
    return model

## Training

In [13]:
# Training function
def train(model, train_loader, val_loader, criterion, optimizer, scheduler, config, output_freq=2):
    """Train ``model`` with per-epoch checkpointing and early stopping.

    Each epoch makes one pass over ``train_loader``, steps ``scheduler`` once,
    and computes a validation loss. The latest weights are saved to
    ``config.last_model_path`` every epoch. The weights with the lowest
    validation loss so far are saved to ``config.best_model_path``. Training
    stops after ``config.patience`` consecutive epochs without improvement.

    Args:
        model: Module whose ``forward(img1, img2)`` returns two outputs.
        train_loader: Sized iterable of ``(img1, img2)`` batches.
        val_loader: Iterable of ``(img1, img2)`` batches, or ``None`` to use the
            epoch's mean training loss as the validation loss.
        criterion: Callable ``(output1, output2) -> scalar loss tensor``.
        optimizer: Optimizer over the model's trainable parameters.
        scheduler: LR scheduler, stepped once per epoch.
        config: Provides ``num_epochs``, ``device``, ``patience``,
            ``last_model_path`` and ``best_model_path``.
        output_freq: Approximate number of batch-progress lines per epoch.

    Returns:
        ``(train_loss_history, val_loss_history)``: per-epoch mean losses.
    """
    train_loss_history = []
    val_loss_history = []
    best_val_loss = float('inf')
    no_improve_epochs = 0
    total_batches = len(train_loader)
    # Print every 1/output_freq of total batches. The max(1, ...) guards avoid
    # `i % 0` when there are fewer batches than output_freq (or output_freq is 0).
    print_every = max(1, total_batches // max(1, output_freq))

    for epoch in range(config.num_epochs):
        start_time = time.time()
        train_loss = 0.0
        model.train()  # re-enable train mode; validation below switches to eval
        for i, (img1, img2) in enumerate(train_loader):
            img1, img2 = img1.to(config.device), img2.to(config.device)
            output1, output2 = model(img1, img2)
            loss = criterion(output1, output2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            # Print training loss
            if i % print_every == 0:
                print(f"Epoch: {epoch+1}, Batch: {i+1}, Loss: {loss.item()}")

        scheduler.step()
        epoch_train_loss = train_loss / total_batches if total_batches else float('nan')
        train_loss_history.append(epoch_train_loss)

        # Previously val_loader was ignored and the training loss was reported
        # as "Val Loss"; that fallback is kept only when no loader is given.
        if val_loader is not None:
            val_loss = _mean_loss(model, val_loader, criterion, config.device)
        else:
            val_loss = epoch_train_loss
        val_loss_history.append(val_loss)
        epoch_time = time.time() - start_time

        print(f"Epoch: {epoch+1}, Loss: {train_loss_history[-1]}, Val Loss: {val_loss}, Time: {epoch_time}s, Learning Rate: {scheduler.get_last_lr()[0]}")

        # Save last model
        save_model(model, config.last_model_path)

        # Save best model & early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_model(model, config.best_model_path)
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= config.patience:
                print("Early stopping")
                break

    return train_loss_history, val_loss_history


def _mean_loss(model, loader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``(img1, img2)`` batches, without updating it.

    Leaves the model in eval mode. Returns ``nan`` if ``loader`` yields no batches.
    """
    model.eval()
    total, n_batches = 0.0, 0
    with torch.no_grad():
        for img1, img2 in loader:
            output1, output2 = model(img1.to(device), img2.to(device))
            total += criterion(output1, output2).item()
            n_batches += 1
    return total / n_batches if n_batches else float('nan')


# Validation function
def validate(model, val_loader, criterion, config):
    """Return the mean per-batch loss of ``model`` on ``val_loader``.

    Runs in eval mode under ``torch.no_grad()`` and leaves the model in eval
    mode; callers that keep training must call ``model.train()`` again.

    Args:
        model: Module whose ``forward(img1, img2)`` returns two outputs.
        val_loader: Iterable of ``(img1, img2)`` batches, as produced by
            ``ImageDataset`` (there are no labels).
        criterion: Callable ``(output1, output2) -> scalar loss tensor``.
        config: Provides ``device``.

    Returns:
        Mean of the per-batch losses, or ``nan`` if the loader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for img1, img2 in val_loader:
            # Only the two views are moved; the original also referenced an
            # undefined `labels` and unpacked three values into two names.
            img1, img2 = img1.to(config.device), img2.to(config.device)
            output1, output2 = model(img1, img2)
            total_loss += criterion(output1, output2).item()
            num_batches += 1
    return total_loss / num_batches if num_batches else float('nan')

## Evaluate

In [17]:
# Evaluation
def evaluate(model, test_loader, config, limit=None):
    """Pick the embedding-distance threshold that best separates positive and negative pairs.

    Embeds each labelled pair on CPU and takes the Euclidean distance between
    the two embeddings. It then chooses the ROC threshold that maximises
    TPR - FPR (Youden's J). Finally it saves and shows histograms of the
    positive and negative distances with that threshold marked.

    NOTE(review): ``test_loader`` must yield ``(img1, img2, labels)`` with label
    1 for positive pairs and 0 for negative pairs. ``ImageDataset`` yields only
    ``(img1, img2)``, so it cannot be used here as-is.

    Args:
        model: Module whose ``forward(img1, img2)`` returns two embeddings;
            moved to CPU as a side effect.
        test_loader: Iterable of ``(img1, img2, labels)`` batches.
        config: Provides ``threshold_plot_path`` for the saved figure.
        limit: If given, process at most ``limit`` batches.

    Returns:
        The selected distance threshold. ``roc_curve`` is called with
        ``pos_label=0``, so distances at or above it are scored as negative pairs.
    """
    model = model.to('cpu')
    model.eval()
    positive_distances = []
    negative_distances = []

    with torch.no_grad():
        progress = tqdm(test_loader, desc="Evaluating", unit="batch")
        for batch_idx, (img1, img2, labels) in enumerate(progress):
            # Stop after exactly `limit` batches (the old check processed limit + 1).
            if limit is not None and batch_idx >= limit:
                break
            output1, output2 = model(img1, img2)
            distances = F.pairwise_distance(output1, output2).numpy()
            labels = labels.numpy()

            positive_distances.extend(distances[labels == 1])
            negative_distances.extend(distances[labels == 0])

    # Compute best threshold. Distances are the scores and class 0 (negative pair)
    # is treated as "positive", because larger distances indicate non-matching pairs.
    distances = positive_distances + negative_distances
    labels = np.array([1] * len(positive_distances) + [0] * len(negative_distances))
    fpr, tpr, thresholds = roc_curve(labels, distances, pos_label=0)
    best_threshold = thresholds[np.argmax(tpr - fpr)]

    # Compute histograms
    plt.hist(positive_distances, bins=30, alpha=0.5, color='r', label='Positive pairs')
    plt.hist(negative_distances, bins=30, alpha=0.5, color='b', label='Negative pairs')

    # Plot best threshold
    plt.axvline(x=best_threshold, color='g', linestyle='--', label=f'Best threshold: {best_threshold:.2f}')
    plt.legend()
    plt.savefig(config.threshold_plot_path)
    plt.show()

    return best_threshold

## Main Application

In [25]:
# Main function
def main(do_train=False):
    """Run the SimCLR pipeline on CIFAR-10 and plot sample similarities.

    Builds augmented CIFAR-10 datasets and the model, and optionally trains it
    (writing checkpoints and a learning-curve plot). It then reloads the best
    checkpoint and plots predicted similarities for the first test batch.

    Args:
        do_train: Train before predicting. If False, a checkpoint must already
            exist at ``config.best_model_path``.

    Raises:
        FileNotFoundError: If there is no best-model checkpoint to load.
    """
    config = Config()

    # CIFAR10 Mean and Std Dev for normalization.
    # NOTE(review): these std values are the widely copied ones; the per-pixel
    # std of the CIFAR-10 training set is closer to (0.247, 0.243, 0.261) — confirm.
    cifar10_mean = [0.4914, 0.4822, 0.4465]
    cifar10_std = [0.2023, 0.1994, 0.2010]

    # Define SimCLR-style augmentations. The transform classes live in
    # torchvision.transforms; the bare names were never imported.
    data_transforms = transforms.Compose([
        transforms.RandomResizedCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        # GaussianBlur's first argument is an (odd) kernel size; [.1, 2.] was
        # meant as the sigma range.
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
    ])

    # Initialize datasets: every item is two augmented views of one image.
    trainset = ImageDataset(CIFAR10(root='./data', train=True, download=True), transform=data_transforms)
    testset = ImageDataset(CIFAR10(root='./data', train=False, download=True), transform=data_transforms)

    # Prepare DataLoader
    train_loader = DataLoader(trainset, batch_size=config.batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(testset, batch_size=config.batch_size, shuffle=False, drop_last=True)

    # Instantiate the model, criterion, optimizer, and scheduler.
    # We need a model even in the case in which we don't train.
    model = SiameseNetwork(embedding_size=config.embedding_size,
                           dropout_p=config.dropout_p).to(config.device)

    criterion = ContrastiveLoss()
    # getattr: fall back to no weight decay (Adam's default) if Config lacks the field.
    optimizer = Adam(model.parameters(), lr=config.learning_rate,
                     weight_decay=getattr(config, "weight_decay", 0.0))
    scheduler = StepLR(optimizer, step_size=config.scheduler_step_size, gamma=config.scheduler_gamma)

    # TRAIN!
    if do_train:
        # NOTE: the test split doubles as validation data for checkpoint
        # selection / early stopping; hold out part of the training set instead
        # if the test split is later used for final metrics.
        train_loss_history, val_loss_history = train(
            model, train_loader, test_loader, criterion, optimizer, scheduler, config
        )
        plot_training(train_loss_history, val_loss_history, config)

    # Load the best model into a fresh instance
    if not os.path.exists(config.best_model_path):
        raise FileNotFoundError(
            f"No trained model at {config.best_model_path}; run main(do_train=True) first."
        )
    model = SiameseNetwork(embedding_size=config.embedding_size,
                           dropout_p=config.dropout_p).to(config.device)
    model = load_model(model, config.best_model_path, config.device)

    # NOTE: evaluate() needs labelled (img1, img2, label) pairs, which
    # ImageDataset does not produce, so no threshold evaluation is run here.

    # Test the model
    print("Predicting random batch...")
    # ImageDataset yields (view1, view2) — there is no third (label) element.
    test_img1, test_img2 = next(iter(test_loader))
    similarities = predict(model, test_img1, test_img2, device=config.device)

    # Plot predictions
    plot_prediction(test_img1.to("cpu"), test_img2.to("cpu"), similarities, limit=10)