## 1. Setup and Imports

In [None]:
import os
import random
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision import models

# Seed every RNG the notebook touches (Python, NumPy, PyTorch on the CPU and on
# all CUDA devices) and force deterministic cuDNN kernels so runs repeat.
SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # autotuning may pick different kernels per run

# Prefer the first visible GPU; otherwise run on the CPU.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')
print(f"Using device: {device}")
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration

In [None]:
# Filesystem layout: images and labels live under data/, checkpoints in models/.
DATA_DIR = Path('data')
TRAIN_IMG_DIR = DATA_DIR.joinpath('train_img')
TRAIN_LABELS_PATH = DATA_DIR.joinpath('train_labels.csv')
MODEL_SAVE_DIR = Path('models')
MODEL_SAVE_DIR.mkdir(exist_ok=True)

# Phase 1 - contrastive (SimCLR) pretraining
PRETRAIN_EPOCHS = 200
PRETRAIN_BATCH_SIZE = 64   # the other views in a batch act as negatives, so bigger batches help
PRETRAIN_LR = 3e-4
TEMPERATURE = 0.5          # NT-Xent softmax temperature
PROJECTION_DIM = 128       # output width of the projection head

# Phase 2 - supervised fine-tuning
FINETUNE_EPOCHS = 100
FINETUNE_BATCH_SIZE = 32
FINETUNE_LR = 1e-4

# Input resolution and label space
IMG_SIZE = 224
NUM_CLASSES = 8  # labels must lie in 0..NUM_CLASSES-1 (CrossEntropyLoss targets); adjust for your dataset

for summary in (
    f"Pretraining: {PRETRAIN_EPOCHS} epochs, batch size {PRETRAIN_BATCH_SIZE}",
    f"Fine-tuning: {FINETUNE_EPOCHS} epochs, batch size {FINETUNE_BATCH_SIZE}",
    f"Temperature: {TEMPERATURE}, Projection dim: {PROJECTION_DIM}",
):
    print(summary)

## 3. Strong Augmentation Pipeline for Contrastive Learning

The key to contrastive learning is creating diverse augmented views of the same image.
We use strong augmentations to force the model to learn invariant features.

In [None]:
class ContrastiveTransform:
    """Produce a pair of independently augmented views of one image.

    The pipeline follows the SimCLR recipe (random resized crop, flips,
    rotation, colour jitter, random grayscale, Gaussian blur). Unlike the
    reference recipe, colour jitter and blur are applied to every sample
    rather than with some probability. Outputs are ImageNet-normalised tensors.
    """

    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, img_size=224):
        stages = [
            T.Resize((img_size, img_size)),
            T.RandomResizedCrop(img_size, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
            T.RandomRotation(degrees=30),
            T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            T.RandomGrayscale(p=0.2),
            T.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
            T.ToTensor(),
            T.Normalize(mean=self._MEAN, std=self._STD),
        ]
        self.transform = T.Compose(stages)

    def __call__(self, x):
        """Return ``(view1, view2)``; each view draws fresh random augmentations."""
        first = self.transform(x)
        second = self.transform(x)
        return first, second


class SimpleTransform:
    """Single-view preprocessing for fine-tuning and evaluation.

    Always resizes to ``img_size`` x ``img_size`` and ImageNet-normalises.
    With ``augment=True`` it also applies light augmentation (horizontal
    flip, +/-15 degree rotation, mild brightness/contrast jitter) before the
    tensor conversion.
    """

    def __init__(self, img_size=224, augment=False):
        augmentations = [
            T.RandomHorizontalFlip(p=0.5),
            T.RandomRotation(degrees=15),
            T.ColorJitter(brightness=0.2, contrast=0.2),
        ] if augment else []
        self.transform = T.Compose(
            [T.Resize((img_size, img_size))]
            + augmentations
            + [
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def __call__(self, x):
        """Apply the configured pipeline to a PIL image."""
        return self.transform(x)

## 4. Dataset Classes

In [None]:
class ContrastiveDataset(Dataset):
    """Unlabeled image dataset for contrastive pretraining.

    Every file directly inside ``img_dir`` that matches one of ``patterns`` is
    used; no labels are read. Each item is the pair of augmented views that
    ``transform`` returns for the image.

    Args:
        img_dir: Directory containing the images (not searched recursively).
        transform: Callable mapping a PIL RGB image to ``(view1, view2)``.
            Defaults to ``ContrastiveTransform()``.
        patterns: Glob patterns that select image files. The default keeps the
            original ``*.png`` / ``*.jpg`` selection. Matching is
            case-sensitive on case-sensitive filesystems, so add e.g.
            ``'*.JPG'`` or ``'*.jpeg'`` if your data needs it.
    """

    def __init__(self, img_dir, transform=None, patterns=('*.png', '*.jpg')):
        self.img_dir = Path(img_dir)
        # A set guards against listing a file twice when patterns overlap.
        matches = {path for pattern in patterns for path in self.img_dir.glob(pattern)}
        self.image_files = sorted(matches)
        self.transform = transform or ContrastiveTransform()
        print(f"Loaded {len(self.image_files)} images for contrastive learning")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # The context manager closes the file handle as soon as the pixels are
        # decoded instead of relying on PIL's lazy-close behaviour.
        with Image.open(self.image_files[idx]) as img:
            image = img.convert('RGB')

        view1, view2 = self.transform(image)
        return view1, view2


class SupervisedDataset(Dataset):
    """Labelled image dataset for supervised fine-tuning.

    Row ``k`` of ``labels_df`` describes item ``k``. Its image is
    ``img_dir / f"{id}{img_ext}"`` and its target is ``int(label)``.

    Args:
        img_dir: Directory holding the images.
        labels_df: DataFrame with at least ``id`` and ``label`` columns.
        transform: Callable applied to the PIL RGB image. Defaults to
            ``SimpleTransform()`` (resize + normalise, no augmentation).
        img_ext: Filename extension appended to each id (default ``'.png'``).
    """

    def __init__(self, img_dir, labels_df, transform=None, img_ext='.png'):
        self.img_dir = Path(img_dir)
        self.labels_df = labels_df
        self.transform = transform or SimpleTransform()
        self.img_ext = img_ext
        # Extract the two columns once. Reading them per item through
        # ``labels_df.iloc[idx]`` builds a row Series with one shared dtype, so
        # any float column upcasts an integer id 12 to 12.0 and the lookup
        # goes to "12.0.png". Column-wise extraction keeps each column's own
        # dtype and avoids building a Series for every sample.
        self.ids = labels_df['id'].tolist()
        self.labels = [int(label) for label in labels_df['label'].tolist()]
        print(f"Loaded {len(self.labels_df)} labeled images")

    def __len__(self):
        return len(self.labels_df)

    def _image_path(self, idx):
        """Return the file path of item ``idx``."""
        return self.img_dir / f"{self.ids[idx]}{self.img_ext}"

    def __getitem__(self, idx):
        # Close the file as soon as the pixels are decoded.
        with Image.open(self._image_path(idx)) as img:
            image = img.convert('RGB')
        return self.transform(image), self.labels[idx]

## 5. SimCLR Model Architecture

SimCLR consists of:
1. **Encoder (f)**: Backbone network (ResNet) that extracts features
2. **Projection Head (g)**: MLP that projects features to contrastive space

During pretraining, we train both f and g.
During fine-tuning, we discard g and add a classification head to f.

In [None]:
class SimCLR(nn.Module):
    """ResNet encoder ``f`` plus MLP projection head ``g`` (SimCLR).

    ``forward`` returns both the pooled backbone features ``h`` (reused later
    for classification) and the projections ``z`` (fed to the contrastive
    loss). The backbone is randomly initialised, not ImageNet-pretrained.

    Args:
        base_encoder: ``'resnet18'`` or ``'resnet50'``.
        projection_dim: Width of the projection head output.

    Raises:
        ValueError: If ``base_encoder`` is not a supported name.
    """

    # Width of the globally pooled feature vector of each supported backbone.
    _FEATURE_DIMS = {'resnet18': 512, 'resnet50': 2048}

    def __init__(self, base_encoder='resnet50', projection_dim=128):
        super().__init__()

        if base_encoder not in self._FEATURE_DIMS:
            raise ValueError(f"Unsupported encoder: {base_encoder}")
        feature_dim = self._FEATURE_DIMS[base_encoder]

        # weights=None -> train from scratch. Dropping the last child (the fc
        # layer) leaves the trunk ending at global average pooling.
        backbone = getattr(models, base_encoder)(weights=None)
        self.encoder = nn.Sequential(*list(backbone.children())[:-1])

        # Projection head: a single ReLU hidden layer.
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, projection_dim),
        )

    def forward(self, x):
        """Return ``(h, z)``: ``[batch, feature_dim]`` features and ``[batch, projection_dim]`` projections."""
        h = self.encoder(x).flatten(1)
        return h, self.projection_head(h)


class Classifier(nn.Module):
    """An encoder followed by a two-layer MLP classification head.

    Args:
        encoder: Module whose output flattens to ``[batch, feature_dim]``
            (e.g. a ResNet trunk ending in global average pooling).
        feature_dim: Width of the flattened encoder output.
        num_classes: Number of output logits.
        dropout: Dropout probability applied before each Linear layer.
    """

    def __init__(self, encoder, feature_dim, num_classes, dropout=0.3):
        super().__init__()
        self.encoder = encoder
        # Attribute names and layer order define the saved state_dict keys
        # (classifier.1.*, classifier.4.*), so keep them stable.
        hidden_units = 256
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(feature_dim, hidden_units),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x):
        """Return unnormalised class logits of shape ``[batch, num_classes]``."""
        features = torch.flatten(self.encoder(x), start_dim=1)
        return self.classifier(features)

## 6. NT-Xent Loss (Normalized Temperature-scaled Cross Entropy)

This is the core of contrastive learning:
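- For each image, we create two augmented views; together they form a positive pair
- For a batch of N images (2N views), the other 2(N−1) views act as negatives for each view
- Goal: make each view's similarity to its positive high relative to its similarities with the negatives (a softmax over temperature-scaled cosine similarities)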

In [None]:
class NTXentLoss(nn.Module):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss from SimCLR.

    For ``N`` images the two views give ``2N`` embeddings. Each embedding's
    positive is the other view of the same image, and the remaining ``2N - 2``
    embeddings are its negatives. For anchor ``a`` with positive ``p`` the
    loss is::

        -log( exp(sim(a, p) / t) / sum_{k != a} exp(sim(a, k) / t) )

    averaged over all ``2N`` anchors, where ``sim`` is cosine similarity and
    ``t`` is the temperature.

    Args:
        temperature: Softmax temperature ``t`` (must be > 0).
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        """Compute the mean NT-Xent loss over both views.

        Args:
            z_i: Projections of the first views, ``[batch_size, projection_dim]``.
            z_j: Projections of the second views, same shape as ``z_i``. Row
                ``k`` of ``z_j`` must come from the same image as row ``k``
                of ``z_i``.

        Returns:
            Scalar loss tensor.
        """
        batch_size = z_i.shape[0]
        n_views = 2 * batch_size

        # Cosine similarity = dot product of L2-normalised vectors.
        z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)  # [2N, D]
        logits = torch.mm(z, z.T) / self.temperature          # [2N, 2N]

        # An embedding is neither its own positive nor a negative. Exclude the
        # diagonal from the softmax by setting it to -inf. The positive then
        # appears exactly once in each row's denominator.
        self_mask = torch.eye(n_views, dtype=torch.bool, device=z.device)
        logits = logits.masked_fill(self_mask, float('-inf'))

        # Row k < N pairs with row k + N and vice versa, so the positive's
        # column is the target class of an ordinary cross-entropy.
        idx = torch.arange(batch_size, device=z.device)
        targets = torch.cat([idx + batch_size, idx], dim=0)

        return F.cross_entropy(logits, targets)

## 7. Contrastive Pretraining Loop

In [None]:
def pretrain_contrastive(model, dataloader, optimizer, criterion, epochs, device):
    """Run contrastive pretraining and return the per-epoch loss history.

    Each batch from ``dataloader`` is a ``(view1, view2)`` pair of image
    tensors. Both views are passed through ``model``, which returns
    ``(h, z)``. ``criterion(z_i, z_j)`` on the projections is the loss
    being minimised.

    Returns:
        ``{'loss': [...]}`` holding the mean batch loss of every epoch.
    """
    model.train()
    history = {'loss': []}

    for epoch_no in range(1, epochs + 1):
        running_loss = 0.0

        progress = tqdm(dataloader, desc=f"Epoch {epoch_no}/{epochs}")
        for first_view, second_view in progress:
            first_view = first_view.to(device)
            second_view = second_view.to(device)

            _, z_first = model(first_view)
            _, z_second = model(second_view)
            loss = criterion(z_first, z_second)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            progress.set_postfix({'loss': f"{batch_loss:.4f}"})

        mean_loss = running_loss / len(dataloader)
        history['loss'].append(mean_loss)
        print(f"Epoch {epoch_no}/{epochs} - Loss: {mean_loss:.4f}")

    return history

## 8. Fine-tuning Loop

In [None]:
def _train_one_epoch(model, loader, optimizer, criterion, device, desc):
    """Run one optimisation pass over ``loader``; return (mean batch loss, accuracy %)."""
    model.train()
    loss_sum, correct, total = 0.0, 0, 0

    pbar = tqdm(loader, desc=desc)
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        pbar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'acc': f"{100. * correct / total:.2f}%",
        })

    return loss_sum / len(loader), 100. * correct / total


def _evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` in eval mode without gradients; return (mean batch loss, accuracy %)."""
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss_sum += criterion(outputs, labels).item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    return loss_sum / len(loader), 100. * correct / total


def train_classifier(model, train_loader, val_loader, optimizer, criterion, epochs, device,
                     save_path=None):
    """Fine-tune ``model`` on labelled data, checkpointing the best validation epoch.

    Each epoch trains on ``train_loader`` and then evaluates on ``val_loader``.
    Whenever validation accuracy improves, ``model.state_dict()`` is written
    to ``save_path``.

    Args:
        model: Classifier mapping an image batch to logits.
        train_loader: Yields ``(images, labels)`` training batches.
        val_loader: Yields ``(images, labels)`` validation batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss taking ``(logits, labels)``.
        epochs: Number of epochs to run.
        device: Device the batches are moved to.
        save_path: Checkpoint file. Defaults to
            ``MODEL_SAVE_DIR / 'best_contrastive_model.pt'`` (the previous
            hard-coded location). Pass a distinct path when training several
            models so their checkpoints don't overwrite one another.

    Returns:
        Dict of per-epoch lists ``train_loss``, ``train_acc``, ``val_loss``
        and ``val_acc`` (accuracies in percent).
    """
    if save_path is None:
        save_path = MODEL_SAVE_DIR / 'best_contrastive_model.pt'

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint, even when its validation accuracy is 0%.
    best_val_acc = float('-inf')

    for epoch in range(epochs):
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, optimizer, criterion, device,
            desc=f"Train Epoch {epoch+1}/{epochs}",
        )
        val_loss, val_acc = _evaluate(model, val_loader, criterion, device)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"Epoch {epoch+1}/{epochs} - "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% - "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), save_path)
            print(f"✓ Best model saved! Val Acc: {val_acc:.2f}%")

    return history

## 9. Load Data

In [None]:
from sklearn.model_selection import train_test_split

# Load labels for supervised fine-tuning. Read `id` as a string so it maps
# verbatim onto the image filename. Parsed as a number, an id like "0042"
# would become 42 and one like "1e5" a float, and the dataset would then look
# for image files that don't exist.
labels_df = pd.read_csv(TRAIN_LABELS_PATH, dtype={'id': str})
print(f"Total labeled samples: {len(labels_df)}")
print(f"Label distribution:\n{labels_df['label'].value_counts().sort_index()}")

# Hold out 20% for validation, preserving per-class proportions.
train_df, val_df = train_test_split(
    labels_df,
    test_size=0.2,
    random_state=SEED,
    stratify=labels_df['label']
)

print(f"Training samples: {len(train_df)}")
print(f"Validation samples: {len(val_df)}")

## 10. Phase 1: Contrastive Pretraining

**"Before the name is given, the difference is felt."**

We pretrain the encoder without using labels, learning representations through contrastive loss.

In [None]:
# Contrastive dataset: every image file in TRAIN_IMG_DIR, labels ignored.
# This includes the images later held out in val_df, because the validation
# dataset reads from the same directory.
contrastive_dataset = ContrastiveDataset(
    img_dir=TRAIN_IMG_DIR,
    transform=ContrastiveTransform(IMG_SIZE)
)

contrastive_loader = DataLoader(
    contrastive_dataset,
    batch_size=PRETRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    # NT-Xent takes its negatives from the batch, so a tiny leftover batch
    # (size 1 has no negatives at all) gives a degenerate loss. Drop it,
    # unless the dataset is smaller than one batch, where dropping would
    # leave no batches at all.
    drop_last=len(contrastive_dataset) >= PRETRAIN_BATCH_SIZE,
)

print(f"Contrastive batches: {len(contrastive_loader)}")

In [None]:
# SimCLR with a randomly initialised ResNet18 trunk ('resnet50' is the
# higher-capacity alternative) and a PROJECTION_DIM-wide projection head.
simclr_model = SimCLR(base_encoder='resnet18', projection_dim=PROJECTION_DIM)
simclr_model = simclr_model.to(device)

# NT-Xent objective, optimised with Adam over encoder + projection head.
contrastive_criterion = NTXentLoss(temperature=TEMPERATURE)
pretrain_optimizer = torch.optim.Adam(simclr_model.parameters(), lr=PRETRAIN_LR)

total_params = sum(param.numel() for param in simclr_model.parameters())
print(f"SimCLR model initialized with ResNet18 backbone")
print(f"Total parameters: {total_params:,}")

In [None]:
print("Starting contrastive pretraining...\n")

pretrain_history = pretrain_contrastive(
    model=simclr_model,
    dataloader=contrastive_loader,
    optimizer=pretrain_optimizer,
    criterion=contrastive_criterion,
    epochs=PRETRAIN_EPOCHS,
    device=device,
)

# Persist only the encoder (f); the projection head is not reused downstream.
encoder_ckpt = MODEL_SAVE_DIR / 'contrastive_encoder.pt'
torch.save(simclr_model.encoder.state_dict(), encoder_ckpt)
print(f"\n✓ Pretrained encoder saved to {encoder_ckpt}")

In [None]:
# Mean NT-Xent loss per pretraining epoch.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(pretrain_history['loss'], label='Contrastive Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Contrastive Pretraining Loss')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

## 11. Phase 2: Supervised Fine-tuning

**"Upon this geometry of understanding, the house of classification, sturdy it stands."**

Now we use the pretrained encoder for classification with labels.

In [None]:
# Both splits read from TRAIN_IMG_DIR; only the training split is augmented.
train_dataset = SupervisedDataset(
    TRAIN_IMG_DIR, train_df, transform=SimpleTransform(IMG_SIZE, augment=True)
)
val_dataset = SupervisedDataset(
    TRAIN_IMG_DIR, val_df, transform=SimpleTransform(IMG_SIZE, augment=False)
)

shared_loader_kwargs = dict(batch_size=FINETUNE_BATCH_SIZE, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_kwargs)

for split_name, split_loader in (("Training", train_loader), ("Validation", val_loader)):
    print(f"{split_name} batches: {len(split_loader)}")

In [None]:
# Build the classifier on top of the contrastively pretrained encoder.
# This is the same module object as simclr_model.encoder, so fine-tuning
# updates it in place.
# Read the feature width from the projection head's first Linear layer
# instead of hard-coding 512 (ResNet18), so switching the SimCLR backbone
# to resnet50 keeps working.
pretrained_feature_dim = simclr_model.projection_head[0].in_features

classifier = Classifier(
    encoder=simclr_model.encoder,
    feature_dim=pretrained_feature_dim,
    num_classes=NUM_CLASSES,
    dropout=0.3
).to(device)

# Optimizer and loss for fine-tuning
finetune_optimizer = torch.optim.Adam(classifier.parameters(), lr=FINETUNE_LR)
classification_criterion = nn.CrossEntropyLoss()

print(f"Classifier initialized with pretrained encoder")
print(f"Total parameters: {sum(p.numel() for p in classifier.parameters()):,}")

In [None]:
# Supervised fine-tuning of the pretrained encoder + new head. The best
# validation epoch is checkpointed inside train_classifier.
print("Starting supervised fine-tuning...\n")

finetune_history = train_classifier(
    classifier, train_loader, val_loader,
    finetune_optimizer, classification_criterion,
    FINETUNE_EPOCHS, device,
)

print("\n✓ Fine-tuning complete!")

In [None]:
# Side-by-side train/val curves: loss on the left, accuracy on the right.
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
panels = (
    ('loss', 'Loss', 'Loss', 'Fine-tuning Loss'),
    ('acc', 'Accuracy', 'Accuracy (%)', 'Fine-tuning Accuracy'),
)
for ax, (metric, legend_name, ylabel, title) in zip(axes, panels):
    ax.plot(finetune_history[f'train_{metric}'], label=f'Train {legend_name}')
    ax.plot(finetune_history[f'val_{metric}'], label=f'Val {legend_name}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Best validation accuracy: {max(finetune_history['val_acc']):.2f}%")

## 12. Comparison: Pretrained vs From-Scratch

To verify the benefit of contrastive pretraining, we train the same architecture from scratch, with the same data splits and fine-tuning settings, and compare the validation curves.

In [None]:
# Baseline: the same ResNet18 trunk + head, but randomly initialised
# (no contrastive pretraining).
scratch_encoder = models.resnet18(weights=None)
scratch_encoder = nn.Sequential(*list(scratch_encoder.children())[:-1])

scratch_classifier = Classifier(
    encoder=scratch_encoder,
    feature_dim=512,
    num_classes=NUM_CLASSES,
    dropout=0.3
).to(device)

scratch_optimizer = torch.optim.Adam(scratch_classifier.parameters(), lr=FINETUNE_LR)

print("Training classifier from scratch (no pretraining)...\n")

# train_classifier writes its best epoch to 'best_contrastive_model.pt', so
# running it again here would overwrite the pretrained model's best
# checkpoint. Park that file while the baseline trains. Afterwards, move the
# baseline's checkpoint to its own name and restore the pretrained one,
# even if training is interrupted.
contrastive_ckpt = MODEL_SAVE_DIR / 'best_contrastive_model.pt'
parked_ckpt = MODEL_SAVE_DIR / 'best_contrastive_model.parked.pt'
scratch_ckpt = MODEL_SAVE_DIR / 'best_scratch_model.pt'

if contrastive_ckpt.exists():
    contrastive_ckpt.replace(parked_ckpt)
try:
    scratch_history = train_classifier(
        model=scratch_classifier,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=scratch_optimizer,
        criterion=classification_criterion,
        epochs=FINETUNE_EPOCHS,
        device=device
    )
finally:
    if contrastive_ckpt.exists():
        contrastive_ckpt.replace(scratch_ckpt)
    if parked_ckpt.exists():
        parked_ckpt.replace(contrastive_ckpt)

print("\n✓ From-scratch training complete!")

In [None]:
# Overlay the validation curves of both runs, then report best accuracies.
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
runs = (
    ('Contrastive Pretrained', finetune_history, {}),
    ('From Scratch', scratch_history, {'linestyle': '--'}),
)
panels = (
    ('val_loss', 'Validation Loss', 'Validation Loss Comparison'),
    ('val_acc', 'Validation Accuracy (%)', 'Validation Accuracy Comparison'),
)
for ax, (metric, ylabel, title) in zip(axes, panels):
    for run_label, run_history, line_style in runs:
        ax.plot(run_history[metric], label=run_label, linewidth=2, **line_style)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

best_pretrained_acc = max(finetune_history['val_acc'])
best_scratch_acc = max(scratch_history['val_acc'])
print("\nFinal Results:")
print(f"Contrastive Pretrained - Best Val Acc: {best_pretrained_acc:.2f}%")
print(f"From Scratch - Best Val Acc: {best_scratch_acc:.2f}%")
print(f"Improvement: {best_pretrained_acc - best_scratch_acc:.2f}%")

## 13. Summary

**Contrastive Learning Benefits:**
1. Learns robust, domain-specific features without labels
2. Better generalization compared to training from scratch
3. More sample-efficient when labeled data is limited
4. Creates meaningful geometric structure in embedding space

**Next Steps:**
- Experiment with different augmentation strategies
- Try different backbone architectures (ResNet50, EfficientNet)
- Adjust temperature and projection dimension
- Use larger batch sizes if GPU memory allows (helps contrastive learning)
- Consider adding mask information as additional augmentation