## Import Libraries

In [1]:
import torch
import torchvision.models
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms, models
from sklearn.model_selection import train_test_split
import torch.nn as nn

import random
import os
import numpy as np

## Configuration

In [2]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything()

## Load and Preprocess the MNIST Dataset

In [3]:
# Load and preprocess the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_data, eval_data = train_test_split(mnist_data, train_size=0.3, random_state=42, stratify=mnist_data.targets)

# Instance-level dataset
class InstanceDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, label = self.data[idx]
        return img, label

# Create instance datasets
train_dataset = InstanceDataset(train_data)
eval_dataset = InstanceDataset(eval_data)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, pin_memory=True, num_workers=4)
eval_loader = DataLoader(eval_dataset, batch_size=8, shuffle=False, pin_memory=True, num_workers=4)


## NT-Xent Loss

In [4]:
from pytorch_metric_learning import losses
import torch.nn.functional as F

class NTXentLoss(losses.NTXentLoss):
    def __init__(self, temperature, **kwargs):
        super().__init__(temperature=temperature, **kwargs)
        self.temperature = temperature

    def forward(self, embeddings1, embeddings2, labels=None):
        # Concatenate the embeddings
        embeddings = torch.cat([embeddings1, embeddings2], dim=0)
        # Normalize feature vectors
        feature_vectors_normalized = F.normalize(embeddings, p=2, dim=1)

        if labels == None:
            # Self-supervised labels
            labels = torch.arange(feature_vectors_normalized.size(0))
        else:
            # Supervised labels
            labels = torch.cat([labels, labels], dim=0)

        # Compute logits
        logits = torch.div(
            torch.matmul(
                feature_vectors_normalized, torch.transpose(feature_vectors_normalized, 0, 1)
            ),
            self.temperature,
        )

        if labels == None:
            return losses.NTXentLoss(temperature=self.temperature)(logits, torch.squeeze(labels))
        else:
            return losses.SupConLoss(temperature=self.temperature)(logits, torch.squeeze(labels))

# NT-Xent loss
criterion = NTXentLoss(0.5)

## Model Architecture

In [5]:
class Encoder(nn.Module):
    def __init__(self, outputs_dim):
        super().__init__()
        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.model.fc = nn.Identity()

        self.projection = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 128)
        )

        self.fc = nn.Linear(512, outputs_dim)

    def forward(self, x):
        x = self.model(x)
        projections = self.projection(x)

        # Flatten the output
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x, projections

## Training

In [6]:
# Version 2: Avg time taken: 0.8 seconds for 2 augmentations (w/o ResizedCrop)
def augment_batch(batch_images):
    batch_size, channels, height, width = batch_images.shape

    # Define augmentation transformations using GPU-compatible operations
    aug_transform = transforms.Compose([
        transforms.RandomResizedCrop(28, scale=(0.75, 1.2), ratio=(0.75, 4.0/3.0)),
        transforms.RandomApply([transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4)], p=0.6),
        transforms.RandomGrayscale(p=0.2),
    ])

    # Apply transformations directly on the tensor without converting to PIL
    augmented_batch = torch.empty_like(batch_images)  # Preallocate memory for augmented images

    for i in range(batch_size):
        augmented_batch[i] = aug_transform(batch_images[i])
    return augmented_batch.cuda()  # Move the augmented batch to GPU

In [7]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Encoder(outputs_dim=10).to(device)
criterion_cl = NTXentLoss(0.5)
criterion_sl = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Training loop
for epoch in range(10):
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        images = images.cuda()
        labels = labels.cuda()

        aug1 = augment_batch(images).cuda()
        aug2 = augment_batch(images).cuda()

        outputs1, proj1 = model(aug1)
        outputs2, proj2 = model(aug2)

        loss_cl = criterion_cl(proj1, proj2, labels)

        loss_sl_1 = criterion_sl(outputs1, labels)
        loss_sl_2 = criterion_sl(outputs2, labels)

        loss = 0.6 * loss_cl + 0.2 * loss_sl_1 + 0.2 * loss_sl_2
        # loss = 0.5 * loss_sl_1 + 0.5 * loss_sl_2
        loss = loss.mean()

        loss.backward()
        optimizer.step()

    # Evaluation phase
    model.eval()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for images, labels in eval_loader:
            images = images.cuda()
            labels = labels.cuda()
            outputs, _ = model(images)
            _, preds = torch.max(outputs, dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')

    print(f'Epoch [{epoch}/{10}], Loss: {loss.item():.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')

## Results

In [8]:
# Stage 1: Pre-training with Supervised Contrastive Learning
for epoch in range(10):  # Adjust epochs as needed for pre-training
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        images = images.cuda()
        labels = labels.cuda()

        aug1 = augment_batch(images).cuda()
        aug2 = augment_batch(images).cuda()

        outputs1, proj1 = model(aug1)
        outputs2, proj2 = model(aug2)

        loss_cl = criterion_cl(proj1, proj2)

        loss_cl.backward()
        optimizer.step()

    # Evaluation phase for contrastive learning
    model.eval()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for images, labels in eval_loader:
            images = images.cuda()
            labels = labels.cuda()
            outputs, _ = model(images)
            _, preds = torch.max(outputs, dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')

    print(f'Stage 1 - Epoch [{epoch}/{10}], Contrastive Loss: {loss_cl.item():.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')


# Stage 2: Fine-tuning with CrossEntropy Loss
for epoch in range(10):  # Adjust epochs as needed for fine-tuning
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        images = images.cuda()
        labels = labels.cuda()

        outputs, _ = model(images)  # Only use the main output for classification

        loss_sl = criterion_sl(outputs, labels)

        loss_sl.backward()
        optimizer.step()

    # Evaluation phase for classification
    model.eval()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for images, labels in eval_loader:
            images = images.cuda()
            labels = labels.cuda()
            outputs, _ = model(images)
            _, preds = torch.max(outputs, dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')

    print(f'Stage 2 - Epoch [{epoch}/{10}], Classification Loss: {loss_sl.item():.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')

Stage 1 - Epoch [0/10], Contrastive Loss: 0.0000, Accuracy: 0.1078, Precision: 0.1048, Recall: 0.1078, F1 Score: 0.0788
Stage 1 - Epoch [1/10], Contrastive Loss: 0.0000, Accuracy: 0.1076, Precision: 0.0985, Recall: 0.1076, F1 Score: 0.0754
Stage 1 - Epoch [2/10], Contrastive Loss: 0.0000, Accuracy: 0.1066, Precision: 0.1020, Recall: 0.1066, F1 Score: 0.0740
Stage 1 - Epoch [3/10], Contrastive Loss: 0.0000, Accuracy: 0.1089, Precision: 0.0963, Recall: 0.1089, F1 Score: 0.0776
Stage 1 - Epoch [4/10], Contrastive Loss: 0.0000, Accuracy: 0.1072, Precision: 0.1025, Recall: 0.1072, F1 Score: 0.0774
Stage 1 - Epoch [5/10], Contrastive Loss: 0.0000, Accuracy: 0.1090, Precision: 0.1009, Recall: 0.1090, F1 Score: 0.0768
Stage 1 - Epoch [6/10], Contrastive Loss: 0.0000, Accuracy: 0.1135, Precision: 0.1164, Recall: 0.1135, F1 Score: 0.0825
Stage 1 - Epoch [7/10], Contrastive Loss: 0.0000, Accuracy: 0.1109, Precision: 0.1143, Recall: 0.1109, F1 Score: 0.0801
Stage 1 - Epoch [8/10], Contrastive Loss