# Contrastive Learning with MNIST Using Pytorch

In [1]:
import os
import time

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
from tqdm import tqdm

import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import StepLR

## GPU Setup (Optional)

In [2]:
torch.__version__
# '1.13.0+cu117'

'1.13.0+cu117'

In [3]:
# Get info of all GPU devices
!nvidia-smi

Thu Jul  6 16:13:35 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 536.23                 Driver Version: 536.23       CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3060      WDDM  | 00000000:22:00.0 Off |                  N/A |
|  0%   40C    P8              14W / 170W |      0MiB / 12288MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [4]:
# Set environment variable with possible device ids
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
print(os.environ["CUDA_VISIBLE_DEVICES"])
# Set device: 0 or 1
# NOTE: indices are not necessarily the ones shown by nvidia-smi
# We need to try them with the cell below
torch.cuda.set_device("cuda:0")

0,1


In [5]:
# Check that the selected device is the desired one
print("Torch version?", torch.__version__)
print("Torchvision version?", torchvision.__version__)
print("Is cuda available?", torch.cuda.is_available())
print("Is cuDNN version:", torch.backends.cudnn.version())
print("cuDNN enabled? ", torch.backends.cudnn.enabled)
print("Device count?", torch.cuda.device_count())
print("Current device?", torch.cuda.current_device())
print("Device name? ", torch.cuda.get_device_name(torch.cuda.current_device()))
x = torch.rand(5, 3)
print(x)

Torch version? 1.13.0+cu117
Torchvision version? 0.14.0+cu117
Is cuda available? True
Is cuDNN version: 8500
cuDNN enabled?  True
Device count? 1
Current device? 0
Device name?  NVIDIA GeForce RTX 3060
tensor([[0.9684, 0.1714, 0.1893],
        [0.9239, 0.8133, 0.0158],
        [0.5785, 0.6900, 0.7949],
        [0.8359, 0.5665, 0.9882],
        [0.5788, 0.4752, 0.3651]])


## Config and Dataset

In [6]:
# Configuration class
class Config:
    def __init__(self):
        # Hyperparameters
        self.learning_rate = 0.001
        self.num_epochs = 100
        self.batch_size = 64
        self.patience = 5 # For early stopping
        self.dropout_p = 0.3
        self.embedding_size = 48 # Size of the embedding/feature vectors
        self.scheduler_step_size = 30  # Step size for the learning rate scheduler
        self.scheduler_gamma = 0.1  # Gamma for the learning rate scheduler: every step_size lr is multiplied by gamma
        self.img_shape = (28, 28)  # Not used in this application
        # Other application variables
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.base_path = "./output"
        os.makedirs(self.base_path, exist_ok=True)  # Create the base_path directory if it doesn't exist
        self.best_model_path = os.path.join(self.base_path, "best_model.pth")
        self.last_model_path = os.path.join(self.base_path, "last_model.pth")
        self.learning_plot_path = os.path.join(self.base_path, "learning_curves.png")
        self.threshold_plot_path = os.path.join(self.base_path, "threshold_histogram.png")

In [7]:
# Dataset generator class
class PairDataset(Dataset):
    def __init__(self, mnist_dataset):
        self.mnist_dataset = mnist_dataset
        self.transform = transforms.ToTensor()

    def __getitem__(self, index):
        img1, label1 = self.mnist_dataset[index]
        index2 = torch.randint(len(self.mnist_dataset), size=(1,)).item()
        img2, label2 = self.mnist_dataset[index2]
        return img1, img2, torch.tensor(int(label1 == label2), dtype=torch.float32)

    def __len__(self):
        return len(self.mnist_dataset)

## Definitions: Siamese Network, Train, Predict

In [14]:
# Siamese Network
class SiameseNetworkResnet(nn.Module):
    def __init__(self, embedding_size=128, dropout_p=0.3, freeze_backbone=False):
        super(SiameseNetwork, self).__init__()
        #self.backbone = models.resnet18(pretrained=True)
        self.backbone = models.resnet18(weights='ResNet18_Weights.DEFAULT')
        self.embedding_size = embedding_size
        self.dropout_p = dropout_p

        # Remove the fully connected layer
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])

        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

        self.head = nn.Sequential(
            #nn.Linear(self.backbone[-1].out_features, 512),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(self.dropout_p),
            nn.Linear(256, self.embedding_size)
        )

    def forward_one(self, x):
        x = self.backbone(x)
        x = x.view(x.size()[0], -1)
        x = self.head(x)
        return x

    def forward(self, input1, input2):
        output1 = self.forward_one(input1)
        output2 = self.forward_one(input2)
        return output1, output2

class SiameseNetwork(nn.Module):
    def __init__(self, embedding_size=48, dropout_p=0.3):
        super(SiameseNetwork, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)  # Changed input channels to 1 for grayscale images
        self.pool1 = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(0.3)

        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.dropout2 = nn.Dropout(0.3)

        #self.fc = nn.Linear(64, embedding_size)  # if adaptive_avg_pool2d used below
        self.fc = nn.Linear(64*7*7, embedding_size) # 28/2/2 = 7

    def forward_one(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = self.dropout1(x)

        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = self.dropout2(x)

        #x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = self.fc(x)
        return x

    def forward(self, input1, input2):
        output1 = self.forward_one(input1)
        output2 = self.forward_one(input2)
        return output1, output2
        
# Contrastive loss
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # pairwise_distance(): equivalent to euclidean_distance:
        # torch.sqrt(((output1 - output2) ** 2).sum(dim=1))
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive

# Save model function
def save_model(model, save_path):
    torch.save(model.state_dict(), save_path)

# Load model function
def load_model(model, load_path, device):
    model.load_state_dict(torch.load(load_path, map_location=device))
    return model
        
# Training function
def train(model, train_loader, val_loader, criterion, optimizer, scheduler, config, output_freq=2):
    model.train()
    train_loss_history = []
    val_loss_history = []
    best_val_loss = float('inf')
    no_improve_epochs = 0
    total_batches = len(train_loader)
    print_every = total_batches // output_freq  # Print every 1/output_freq of total batches

    for epoch in range(config.num_epochs):
        start_time = time.time()
        train_loss = 0
        model.train()
        for i, (img1, img2, labels) in enumerate(train_loader):
            img1, img2, labels = img1.to(config.device), img2.to(config.device), labels.to(config.device)
            output1, output2 = model(img1, img2)
            loss = criterion(output1, output2, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            # Print training loss
            if i % print_every == 0:
                print(f"Epoch: {epoch+1}, Batch: {i+1}, Loss: {loss.item()}")

        scheduler.step()
        train_loss_history.append(train_loss / len(train_loader))

        val_loss = validate(model, val_loader, criterion, config)
        val_loss_history.append(val_loss)
        end_time = time.time()
        epoch_time = end_time - start_time

        print(f"Epoch: {epoch+1}, Loss: {train_loss_history[-1]}, Val Loss: {val_loss}, Time: {epoch_time}s, Learning Rate: {scheduler.get_last_lr()[0]}")

        # Save last model
        save_model(model, config.last_model_path)    

        # Save best model & early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_model(model, config.best_model_path)
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= config.patience:
                print("Early stopping")
                break

    return train_loss_history, val_loss_history


# Validation function
def validate(model, val_loader, criterion, config):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for i, (img1, img2, labels) in enumerate(val_loader):
            img1, img2, labels = img1.to(config.device), img2.to(config.device), labels.to(config.device)
            output1, output2 = model(img1, img2)
            loss = criterion(output1, output2, labels)
            total_loss += loss.item()
    return total_loss / len(val_loader)

In [15]:
# Plot training function
def plot_training(train_loss_history, val_loss_history, config):
    plt.figure(figsize=(10, 5))
    plt.plot(train_loss_history, label='Train Loss')
    plt.plot(val_loss_history, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(config.learning_plot_path)
    plt.show()
    
# Prediction function
def predict(model, img1, img2, device="cpu"):
    model.eval()
    model.to(device)
    img1, img2 = img1.to(device), img2.to(device)
    with torch.no_grad():
        output1, output2 = model(img1, img2)
        distances = F.pairwise_distance(output1, output2).to("cpu").numpy()
        return distances

# Plot prediction function
def plot_prediction(img1, img2, distances, limit=None):
    if limit is not None:
        img1, img2, distances = img1[:limit], img2[:limit], distances[:limit]

    fig, axs = plt.subplots(len(img1), 2, figsize=(5, 3*len(img1)))
    for i in range(len(img1)):
        img1_i = img1[i].squeeze().permute(1, 2, 0) if img1[i].shape[0] == 3 else img1[i].squeeze()
        img2_i = img2[i].squeeze().permute(1, 2, 0) if img2[i].shape[0] == 3 else img2[i].squeeze()
        cmap1 = 'gray' if img1[i].shape[0] == 1 else None
        cmap2 = 'gray' if img2[i].shape[0] == 1 else None
        axs[i, 0].imshow(img1_i, cmap=cmap1)
        axs[i, 1].imshow(img2_i, cmap=cmap2)
        axs[i, 1].set_title(f"Distance: {distances[i].item():.2f}")
        #axs[i, 0].imshow(img1[i].squeeze()[0], cmap='gray')  # Select the first channel
        #axs[i, 1].imshow(img2[i].squeeze()[0], cmap='gray')  # Select the first channel
        axs[i, 1].set_title(f"Distance: {distances[i].item():.2f}")
    plt.show()

In [16]:
# Evaluation
def evaluate(model, test_loader, config, limit=None):
    model = model.to('cpu')
    model.eval()
    positive_distances = []
    negative_distances = []
    labels_list = []

    with torch.no_grad():
        count = 0
        for img1, img2, labels in tqdm(test_loader, desc="Evaluating", unit="batch"):
            count += 1
            output1, output2 = model(img1, img2)
            distances = F.pairwise_distance(output1, output2).detach().numpy()
            labels = labels.numpy()

            positive_distances.extend(distances[labels == 1])
            negative_distances.extend(distances[labels == 0])
            labels_list.extend(labels)

            if limit is not None:
                if count > limit:
                    break

    # Compute best threshold
    distances = positive_distances + negative_distances
    labels = np.array([1]*len(positive_distances) + [0]*len(negative_distances))
    fpr, tpr, thresholds = roc_curve(labels, distances, pos_label=0)
    best_threshold = thresholds[np.argmax(tpr - fpr)]

    # Compute histograms
    plt.hist(positive_distances, bins=30, alpha=0.5, color='r', label='Positive pairs')
    plt.hist(negative_distances, bins=30, alpha=0.5, color='b', label='Negative pairs')
    
    # Plot best threshold
    plt.axvline(x=best_threshold, color='g', linestyle='--', label=f'Best threshold: {best_threshold:.2f}')
    plt.legend()
    plt.savefig(config.threshold_plot_path)
    plt.show()

    return best_threshold

## Main Application

In [17]:
# Main function
def main(do_train=True):
    config = Config()

    # Choose the model and the data transformations
    resnet = False

    # Define the transformations for the training set
    train_transform = transforms.Compose([
        transforms.Resize((30, 30)),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(28),
    ])
    
    # Define the transformations for the validation and test sets
    val_test_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    if resnet:
        # Define the transformations for the training set
        train_transform = transforms.Compose([
            transforms.Resize((230, 230)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # Convert grayscale to RGB
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(224),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        # Define the transformations for the validation and test sets
        val_test_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # Convert grayscale to RGB
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    # Load MNIST dataset
    mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=train_transform)
    print("Dataset obtained!")

    # Split the dataset into train, validation, and test sets
    train_size = int(0.7 * len(mnist_dataset))
    val_size = int(0.15 * len(mnist_dataset))
    test_size = len(mnist_dataset) - train_size - val_size
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(mnist_dataset, [train_size, val_size, test_size])
    print("Dataset splits created!")

    # Apply the appropriate transformations to the validation and test sets
    val_dataset.dataset.transform = val_test_transform
    test_dataset.dataset.transform = val_test_transform

    # Create PairDataset for each split
    train_dataset = PairDataset(train_dataset)
    val_dataset = PairDataset(val_dataset)
    test_dataset = PairDataset(test_dataset)

    # Create DataLoader for each split
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)
    print("Dataset loaders created!")

    # Instantiate the model, criterion, optimizer, and scheduler
    model = SiameseNetwork(embedding_size=config.embedding_size,
                           dropout_p=config.dropout_p).to(config.device)
    if resnet:
        model = SiameseNetworkResnet(embedding_size=config.embedding_size,
                                     dropout_p=config.dropout_p).to(config.device)    
    criterion = ContrastiveLoss()
    optimizer = Adam(model.parameters(), lr=config.learning_rate)
    scheduler = StepLR(optimizer, step_size=config.scheduler_step_size, gamma=config.scheduler_gamma)
    print("Model instantiated!")
    
    # Train the model
    if do_train:
        print("Starting training...")
        train_loss_history, val_loss_history = train(model, train_loader, val_loader, criterion, optimizer, scheduler, config)
        print("Training completed!")
    
        # Plot training history
        plot_training(train_loss_history, val_loss_history, config)

    # Load the best model
    model = SiameseNetwork(embedding_size=config.embedding_size,
                           dropout_p=config.dropout_p).to(config.device)
    if resnet:
        model = SiameseNetworkResnet(embedding_size=config.embedding_size,
                                     dropout_p=config.dropout_p).to(config.device)    
    model = load_model(model, config.best_model_path, config.device)

    # Evaluate
    print("Evaluating model...")
    best_threshold = evaluate(model, test_loader, config, limit=None)
    print(f"Best threshold: {best_threshold}")
    print("Evaluation completed!")
    
    # Test the model
    print("Predicting random batch...")
    test_img1, test_img2, _ = next(iter(test_loader))
    distances = predict(model, test_img1, test_img2)

    # Plot predictions
    plot_prediction(test_img1.to("cpu"), test_img2.to("cpu"), distances, limit=10)

In [18]:
#if __name__ == "__main__":
#    main()

In [None]:
main(do_train=True)
#main(do_train=False)

Dataset obtained!
Dataset splits created!
Dataset loaders created!
Model instantiated!
Starting training...
Epoch: 1, Batch: 1, Loss: 0.4373529553413391
Epoch: 1, Batch: 329, Loss: 0.07190398871898651
Epoch: 1, Batch: 657, Loss: 0.05930083245038986
Epoch: 1, Loss: 0.09279891693060212, Val Loss: 0.09919741970199203, Time: 35.2789409160614s, Learning Rate: 0.001
Epoch: 2, Batch: 1, Loss: 0.08687381446361542
Epoch: 2, Batch: 329, Loss: 0.12331025302410126
Epoch: 2, Batch: 657, Loss: 0.1026899442076683
Epoch: 2, Loss: 0.09080783464581546, Val Loss: 0.09652836104286901, Time: 31.365543127059937s, Learning Rate: 0.001
Epoch: 3, Batch: 1, Loss: 0.12313893437385559
Epoch: 3, Batch: 329, Loss: 0.12194577604532242
