## STEP 1: Import All Required Libraries and Set Up the Environment

In [None]:
# PyTorch core libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Torchvision (datasets, transforms, ResNet backbone)
import torchvision
import torchvision.transforms as T
from torchvision.models import resnet50

# Utility libraries
import numpy as np
from tqdm import tqdm
import random
import math
import time

# ----------------------------------------------------------------
# Device selection: use CUDA when a GPU is visible, otherwise the CPU
# ----------------------------------------------------------------
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# ----------------------------------------------------------------
# Reproducibility: seed every RNG this notebook draws from
# (the four generators are independent, so call order is irrelevant)
# ----------------------------------------------------------------
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

# Ask cuDNN for deterministic behaviour and switch off its autotuner.
# This may slow training down slightly.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

print("Environment setup complete.")

## STEP 2: Define MoCo-v2 Style Data Augmentations

This class builds the MoCo-v2 data augmentation pipeline and applies it twice to the same input image.
The result is two different strongly-augmented views of the same image, which are used as the positive pair in contrastive self-supervised learning.

In other words:

- Take one image
- Apply strong random augmentations twice
- Produce two correlated but different views
- Feed them to the contrastive model (query/key)

This encourages the model to learn invariance to color, crop, blur, distortion, etc., which is essential for MoCo-style SSL.

In [None]:
# These augmentations generate *two strongly augmented views* of each image.
# They are essential for contrastive learning because the model must learn
# invariances to color, crop, blur, distortion, etc.

class MoCoTransform:
    """Callable that maps one image to a pair of independently augmented views."""

    def __init__(self, image_size=96):
        """
        image_size: output size handed to RandomResizedCrop
        """
        # Photometric stages that only fire with a given probability.
        color_jitter = T.RandomApply(
            [T.ColorJitter(0.4, 0.4, 0.4, 0.1)],  # brightness, contrast, saturation, hue
            p=0.8,
        )
        # NOTE(review): T.RandomSolarize also takes its own probability `p`;
        # if its library default is below 1, the effective solarize rate is
        # lower than the 0.2 used here — confirm this is intended.
        solarize = T.RandomApply([T.RandomSolarize(128)], p=0.2)
        blur = T.RandomApply(
            [T.GaussianBlur(kernel_size=9, sigma=(0.1, 2.0))],
            p=0.5,
        )

        # Stage order: geometry -> colour -> blur -> tensor -> normalise.
        stages = [
            T.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            color_jitter,
            T.RandomGrayscale(p=0.2),
            solarize,
            blur,
            T.ToTensor(),
            # Normalise with the ImageNet channel statistics.
            T.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            ),
        ]
        self.base_transform = T.Compose(stages)

    def __call__(self, x):
        # Two separate passes through the same random pipeline give two
        # different views of the same input.
        first_view = self.base_transform(x)
        second_view = self.base_transform(x)
        return first_view, second_view


print("MoCo-v2 augmentations initialized successfully.")

## STEP 3: Implement the Contrastive Loss (InfoNCE) + Covariance Regularization Loss

This module defines a hybrid loss for our self-supervised model that combines:

1. MoCo-style InfoNCE contrastive loss

This part forces the model to pull together the embeddings of two augmented views of the same image (q and k⁺) and push them away from a large set of negative samples stored in a FIFO queue.
This teaches invariance to augmentations and builds discriminative features.

2. VICReg-style covariance regularization

In addition to contrastive learning, this part penalizes correlations between embedding dimensions so each feature dimension carries unique information.
This reduces redundancy and improves representation quality.

We can choose whether this covariance term uses:

- only the positive pair (q, k⁺), or
- (optionally) samples from the negative queue as well.

3. Final output

The function returns:
- the combined loss used for backprop,
- the pure contrastive component,
- the pure covariance regularization component.

So overall, this block creates a hybrid MoCo + VICReg loss that encourages:
- invariance to augmentations (via contrastive loss)
- diversity across feature dimensions (via covariance regularization)

This helps our model learn richer, more stable representations.

In [None]:
# This includes two main components:
# 1. MoCo InfoNCE contrastive loss.
# 2. VICReg-style covariance loss to decorrelate embedding dimensions.

# We support:
#   - Using only (q, k+) for covariance loss
#   - OR including samples from the negative queue (optional)
# via the flag `use_queue_for_cov`.


class MoCoCovLoss(nn.Module):
    """InfoNCE contrastive loss plus a penalty on off-diagonal embedding covariance."""

    def __init__(self, temperature=0.2, lambda_cov=1.0, use_queue_for_cov=False):
        """
        temperature: divisor applied to all logits (tau)
        lambda_cov: multiplier on the covariance penalty in the total loss
        use_queue_for_cov: if True, a random subset of queue columns is added
                           to the samples used for the covariance estimate
        """
        super().__init__()
        self.tau = temperature
        self.lambda_cov = lambda_cov
        self.use_queue_for_cov = use_queue_for_cov

    def forward(self, q, k_pos, queue_neg):
        """
        q: (B, D) query embeddings
        k_pos: (B, D) positive key embeddings
        queue_neg: (D, K) negative keys, one key per column

        Returns (total loss, contrastive term, unweighted covariance term).
        """
        batch = q.size(0)

        # ---- InfoNCE ------------------------------------------------
        # One positive logit per row, followed by K negative logits.
        positives = torch.sum(q * k_pos, dim=1, keepdim=True) / self.tau   # (B, 1)
        negatives = torch.einsum('nd,dk->nk', q, queue_neg) / self.tau     # (B, K)
        logits = torch.cat([positives, negatives], dim=1)                  # (B, 1+K)

        # The positive sits in column 0, so every target class is 0.
        targets = q.new_zeros(batch, dtype=torch.long)
        loss_contrast = F.cross_entropy(logits, targets)

        # ---- Covariance penalty -------------------------------------
        samples = [q, k_pos]
        if self.use_queue_for_cov:
            # Add at most 1024 randomly chosen queue columns.
            queue_len = queue_neg.shape[1]
            take = min(1024, queue_len)
            chosen = torch.randperm(queue_len, device=q.device)[:take]
            samples.append(queue_neg[:, chosen].T)                         # (take, D)

        stacked = torch.cat(samples, dim=0)
        centered = stacked - stacked.mean(dim=0, keepdim=True)
        cov = (centered.T @ centered) / (centered.size(0) - 1)             # (D, D)

        # Sum of squared off-diagonal entries only.
        on_diagonal = torch.eye(cov.size(0), device=cov.device).bool()
        loss_cov = cov.masked_select(torch.logical_not(on_diagonal)).pow(2).sum()

        # ---- Total --------------------------------------------------
        loss = loss_contrast + self.lambda_cov * loss_cov
        return loss, loss_contrast, loss_cov


print("MoCo InfoNCE + covariance loss module ready.")

## STEP 4 — MoCo-v2 Model Setup with 128-Dim Projection Head and Negative Queue

This block builds the full MoCo-v2 backbone we will use for self-supervised learning. It assembles a fully functional MoCo-v2 encoder pair (query + key), builds a matching 128-dim projection head, and initializes the negative queue used for contrastive learning. Specifically, it creates:

1. A Projection Head (2048 → 128)

A small MLP that takes ResNet-50 features (2048-dim) and maps them into a 128-dim contrastive space, normalized for InfoNCE training.
This is the standard MoCo-v2 projection design.

2. A Dual-Encoder MoCo Model (query + key)

The model contains:

- encoder_q – the normal ResNet-50 backbone + projection head (trainable)

- encoder_k – the momentum encoder (no gradient updates)

At initialization, encoder_k is copied from encoder_q, ensuring they start identical.
This is essential for stability in MoCo-style contrastive learning.

Both encoders output 128-dim embeddings, guaranteeing shape consistency with the queue and the loss function.

3. A 128-Dim FIFO Queue of Negative Samples

A memory queue that stores thousands of past key embeddings and acts as the large negative sample bank for contrastive learning.

As new batches arrive:

- new key embeddings are enqueued
- the oldest ones are removed
This keeps a constantly refreshed pool of negatives.

4. GPU Instantiation

Finally, the model and queue are moved to GPU, ready for SSL training.



In [None]:
# ======================================================================
# FIXED & CORRECTED MoCo Model + Projection Head + Queue (128-dim)
# ======================================================================
# This block ensures that:
#   - encoder_q outputs 128-d features
#   - encoder_k outputs 128-d features
#   - queue stores 128-d features
#   - q, k, queue embeddings perfectly match in dimension
# ======================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


# ----------------------------------------------------------
# 1. Projection Head: 2048 → 2048 → 128 (MoCo-v2 style MLP)
# ----------------------------------------------------------
class ProjectionMLP(nn.Module):
    """
    Two-layer projection head (Linear -> ReLU -> Linear).
        - Input: dim_in features per sample (2048 by default)
        - Output: dim_out features per sample, L2-normalised along dim 1
    """
    def __init__(self, dim_in=2048, dim_hidden=2048, dim_out=128):
        super().__init__()
        # Layers are kept inside a Sequential stored as `mlp`, so parameter
        # names in state_dict are mlp.0.* and mlp.2.*.
        layers = [
            nn.Linear(dim_in, dim_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(dim_hidden, dim_out),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.mlp(x)
        # Unit-length rows, so dot products between outputs are cosine similarities.
        return F.normalize(projected, dim=1)


# ----------------------------------------------------------
# 2. MoCo Model (ResNet-50 backbone + projection head)
# ----------------------------------------------------------
class MoCoResNet50(nn.Module):
    """
    Pair of identically shaped encoders for MoCo training:
      - encoder_q: ResNet-50 trunk + projection head, trainable
      - encoder_k: same architecture, initialised as a copy of encoder_q
                   with gradients disabled on every parameter
    """
    def __init__(self, dim_feature=128):
        """
        dim_feature: output dimension of each encoder's projection head
        """
        super().__init__()

        # Build the query encoder first, then the key encoder.
        self.encoder_q = self._make_encoder(dim_feature)
        self.encoder_k = self._make_encoder(dim_feature)

        # Start the key encoder from the query encoder's weights and freeze it;
        # it is meant to be updated outside of backprop.
        param_pairs = zip(self.encoder_q.parameters(), self.encoder_k.parameters())
        for query_param, key_param in param_pairs:
            key_param.data.copy_(query_param.data)
            key_param.requires_grad = False

    @staticmethod
    def _make_encoder(dim_feature):
        """Return a randomly initialised ResNet-50 (fc removed) followed by a projection head."""
        trunk = models.resnet50(weights=None)   # no pretrained weights
        trunk.fc = nn.Identity()                # expose the 2048-d pooled features
        head = ProjectionMLP(dim_in=2048, dim_hidden=2048, dim_out=dim_feature)
        return nn.Sequential(trunk, head)

    # No forward(): callers invoke encoder_q and encoder_k directly.


# ----------------------------------------------------------
# 3. FIFO Queue for Negative Samples (Dimension = 128)
# ----------------------------------------------------------
class MoCoQueue(nn.Module):
    """
    Fixed-size FIFO buffer of key embeddings used as contrastive negatives.

    Buffers:
        queue: (size, dim) tensor, initialised with L2-normalised Gaussian rows
        ptr:   (1,) long tensor holding the index of the next row to overwrite
    """
    def __init__(self, size=65536, dim=128):
        """
        size: number of embeddings the queue holds
        dim:  dimensionality of each embedding

        Raises:
            ValueError: if size or dim is not positive.
        """
        super().__init__()
        if size <= 0 or dim <= 0:
            raise ValueError(f"size and dim must be positive, got size={size}, dim={dim}")
        self.size = size
        self.dim = dim

        # queue: (size, dim), registered already normalised
        self.register_buffer("queue", F.normalize(torch.randn(size, dim), dim=1))
        self.register_buffer("ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def enqueue(self, keys):
        """
        Write new keys at the current pointer, overwriting the oldest entries
        and wrapping around the end of the buffer.

        keys: (batch_size, dim)

        If batch_size exceeds the queue size, only the newest `size` keys are
        stored; the result matches what enqueuing the keys one at a time
        would leave behind.

        Raises:
            ValueError: if keys is not 2-D with `dim` columns (otherwise a
                        (B, 1) tensor would silently broadcast into the queue).
        """
        if keys.dim() != 2 or keys.shape[1] != self.dim:
            raise ValueError(
                f"keys must have shape (batch_size, {self.dim}), got {tuple(keys.shape)}"
            )

        batch_size = keys.shape[0]
        if batch_size == 0:
            return

        ptr = int(self.ptr)

        # Oversize batch: the first (batch_size - size) keys would be
        # overwritten within this same call, so skip them and advance the
        # pointer as if they had been written.
        if batch_size > self.size:
            skipped = batch_size - self.size
            keys = keys[skipped:]
            ptr = (ptr + skipped) % self.size
            batch_size = self.size

        end = ptr + batch_size
        if end <= self.size:
            self.queue[ptr:end] = keys
        else:
            # wrap around: fill the tail, then continue from row 0
            n_tail = self.size - ptr
            self.queue[ptr:] = keys[:n_tail]
            self.queue[:end - self.size] = keys[n_tail:]

        self.ptr[0] = end % self.size


# ----------------------------------------------------------
# 4. Instantiate Model + Queue (DIM=128) on the selected device
# ----------------------------------------------------------
# Move both modules to the `device` chosen during environment setup
# (CUDA when available, otherwise CPU) instead of calling .cuda()
# unconditionally, which raises on machines without a GPU.
model = MoCoResNet50(dim_feature=128).to(device)

# The queue currently holds 4096 negatives; size=65536 was the earlier
# setting kept here as a reference for larger runs.
queue = MoCoQueue(size=4096, dim=128).to(device)


print(">>> MoCo-ResNet50 model + 128-d queue initialized successfully!")

## STEP 5: Local Dataset Preparation Using Snapshot Download and Random Subsampling


This block configures and prepares the unlabeled training dataset for our self-supervised pretraining. Instead of using the full 500k–700k Hugging Face dataset (which is large and slow to load), it downloads only the ZIP shards, lists the images they contain, and saves a random subset of images (30k in the current configuration) locally for fast pipeline testing.

Here’s what this code accomplishes at a high level:

1. Choose the data source

We specify that training images should come from a local 30k-image subset rather than the full Hugging Face dataset. This makes early debugging and model testing much faster.

2. Download raw dataset ZIP files

Using snapshot_download(), the code fetches only the zipped shards from the Hugging Face repo. This avoids downloading unnecessary metadata and keeps the process efficient.

3. Parse the ZIP files to see all image names

Instead of extracting everything, the code looks inside each ZIP and gathers a list of every .jpg in the dataset.

4. Randomly select 30k images

From the full list of available images, the script chooses a random sample of 30k images.
This subset is used to quickly test our pipeline before scaling up to the full dataset.

5. Extract and save the selected images locally

For each selected image:

- read it directly from its ZIP file
- write it to the local data folder as a sequentially numbered .jpg file

This block gives us a clean folder of 30k .jpg files ready for our dataloader.


### Step 5.a

In [None]:
# ===============================================================
# DATA SOURCE CONFIGURATION
# ===============================================================
# Choose where training images come from.
# For pipeline testing, use "local_30k" (a randomly sampled local subset).
# NOTE(review): "hf_full" / "local_500k" were mentioned as future options,
# but the visible download code only acts when DATA_SOURCE == "local_30k".

DATA_SOURCE = "local_30k"    # the only value the visible download code handles

# Directory that receives the extracted subset of .jpg files
LOCAL_DATA_DIR = "/content/local_pretrain_30k"

print("DATA SOURCE =", DATA_SOURCE)


### Step 5.b

In [None]:
# ===============================================================
# STEP 5.b — FAST DOWNLOAD USING snapshot_download()
# ===============================================================
# Downloads the dataset's ZIP shards, lists every .jpg inside them, draws a
# reproducible random subset, and writes that subset to LOCAL_DATA_DIR as
# img_00000.jpg, img_00001.jpg, ...

from collections import defaultdict
import os
import random
import zipfile

from huggingface_hub import snapshot_download
from tqdm import tqdm

LOCAL_DATA_DIR = "/content/local_pretrain_30k"
NUM_LOCAL_IMAGES = 30000   # target subset size
os.makedirs(LOCAL_DATA_DIR, exist_ok=True)

if DATA_SOURCE == "local_30k":

    print("✓ Using snapshot_download() to fetch dataset shards...")

    repo_dir = snapshot_download(
        repo_id="tsbpp/fall2025_deeplearning",
        repo_type="dataset",
        local_dir="/content/hf_raw",
        allow_patterns=["*.zip"],       # only zip files
    )

    print("✓ Download complete. Extracting shards...")

    zip_files = sorted(f for f in os.listdir(repo_dir) if f.endswith(".zip"))

    # Index every .jpg member as (archive path, member name) without
    # extracting anything yet.
    extracted_images = []
    for z in tqdm(zip_files, desc="Extracting ZIPs"):
        zip_path = os.path.join(repo_dir, z)
        with zipfile.ZipFile(zip_path, "r") as zf:
            extracted_images.extend(
                (zip_path, name) for name in zf.namelist() if name.endswith(".jpg")
            )

    if not extracted_images:
        raise RuntimeError(f"No .jpg files found in the ZIP shards under {repo_dir}")

    # random.sample raises if asked for more items than exist, so cap the
    # request at the number of images actually available.
    num_selected = min(NUM_LOCAL_IMAGES, len(extracted_images))
    if num_selected < NUM_LOCAL_IMAGES:
        print(
            f"Only {len(extracted_images)} images available; "
            f"using all of them instead of {NUM_LOCAL_IMAGES}."
        )

    # Fixed seed so the same subset is drawn on every run.
    random.seed(42)
    selected = random.sample(extracted_images, num_selected)

    # Group the picks by archive so each ZIP is opened once instead of once
    # per image; the output index i still follows the sampled order.
    members_by_zip = defaultdict(list)
    for i, (zip_path, name) in enumerate(selected):
        members_by_zip[zip_path].append((i, name))

    print(f"Saving {num_selected} images locally...")
    with tqdm(total=num_selected) as progress:
        for zip_path, members in members_by_zip.items():
            with zipfile.ZipFile(zip_path, "r") as zf:
                for i, name in members:
                    out_path = os.path.join(LOCAL_DATA_DIR, f"img_{i:05d}.jpg")
                    with open(out_path, "wb") as f:
                        f.write(zf.read(name))
                    progress.update(1)

    print(f"✓ Saved {num_selected} images to {LOCAL_DATA_DIR}")


## STEP 6: Building a Local SSL Dataset and DataLoader for MoCo Training


This block constructs the dataset and dataloader that feed training images into our MoCo pipeline. It loads the locally saved JPEG images (e.g., the 30k subset we created earlier), applies the MoCo augmentation pipeline, and prepares batches for contrastive training.

Here’s the high-level purpose:

1. Custom Dataset for Local Images

The LocalMoCoDataset class:
- scans a directory of .jpg files,
- loads each image from disk,
- applies our MoCoTransform to produce two augmented views of the same image,
- returns these paired views for contrastive learning.

This turns our folder of unlabeled images into the correct supervised-by-augmentations structure needed for MoCo-style SSL.

2. Create a DataLoader for Efficient Training

The DataLoader:
- batches the augmented pairs (e.g., 128 images per batch),
- shuffles the dataset,
- preloads data with workers (num_workers=2),
- drops the last incomplete batch for consistency.

This ensures that our SSL training loop receives a steady stream of (x1, x2) pairs efficiently and with proper batching.

In [None]:
# ================================================================
# STEP — Create dataset & dataloader for the local image subset
# ================================================================
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import glob
import os

class LocalMoCoDataset(Dataset):
    """
    Dataset over the *.jpg files located directly inside `root_dir`.
    Each item is the pair of views that `transform` produces for one image.
    """
    def __init__(self, root_dir, transform):
        """
        root_dir: directory that is searched for *.jpg files
        transform: callable taking an RGB PIL image and returning two views
        """
        self.root_dir = root_dir
        self.transform = transform

        # Sort the matches so index -> file mapping is stable across runs.
        found = glob.glob(os.path.join(root_dir, "*.jpg"))
        found.sort()
        self.files = found
        print(f"Local dataset: {len(self.files)} images found.")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # Decode as 3-channel RGB regardless of the file's original mode.
        rgb = Image.open(self.files[idx]).convert("RGB")
        view_a, view_b = self.transform(rgb)
        return view_a, view_b


# ================================================================
# Build the augmentation, the dataset, and the training DataLoader
# ================================================================
# Two-view augmentation at 96x96.
ssl_transform = MoCoTransform(image_size=96)

# Dataset over the locally saved .jpg files.
ssl_dataset = LocalMoCoDataset(root_dir=LOCAL_DATA_DIR, transform=ssl_transform)

# Shuffled loader; the last incomplete batch is dropped so every batch
# has exactly 128 pairs.
ssl_loader = DataLoader(
    ssl_dataset,
    **{
        "batch_size": 128,
        "shuffle": True,
        "num_workers": 2,
        "pin_memory": True,
        "drop_last": True,
    },
)

print("✓ Local SSL DataLoader created.")
print("Batches per epoch:", len(ssl_loader))


***

## STEP 7 — Setting Up Full MoCo-Cov Training Configuration (Hyperparameters, Optimizer, Scheduler, Checkpointing)

This block prepares everything needed for real MoCo-Cov self-supervised training, but does not start training yet. It sets up hyperparameters, optimizer, learning-rate schedule, and checkpoint paths—essentially the full training environment so our next step can immediately begin pretraining.

Here’s the high-level purpose:

1. Define All Training Hyperparameters

A single config dictionary stores every important setting for MoCo-Cov training, including:

- number of epochs
- batch size
- learning rate & momentum
- InfoNCE temperature
- EMA momentum for the key encoder
- covariance-loss strength
- queue size
- checkpoint file path

This gives our training loop a clear and centralized configuration we can easily adjust later.

2. Create the Optimizer (SGD + Momentum)

The model parameters are attached to an SGD optimizer, which is the standard optimizer used in MoCo-v2 and similar contrastive SSL methods.
This sets momentum and weight decay exactly as in the original papers.

3. Set Up a Cosine Learning-Rate Scheduler

We initialize a cosine-annealing schedule, which slowly decreases the LR from its initial value down to zero across all epochs.
This is a standard and effective schedule for contrastive pretraining.

4. Prepare Checkpointing Settings

Defines where to save training checkpoints and how often to save them (every epoch), ensuring training can resume if interrupted.

In [None]:
# Goal:
#   - Define all hyperparameters for the REAL MoCo-Cov training
#   - Create optimizer (SGD)
#   - Create cosine LR scheduler
#   - Prepare checkpointing
#   - DO NOT start training yet (training will be Step 12)

# This block sets up everything needed for long SSL pretraining.
# ================================================================

import torch
import torch.nn as nn
import torch.optim as optim
import math
from torch.optim.lr_scheduler import CosineAnnealingLR

print("Configuring MoCo-Cov training ...")

# --------------------------------------------------------
# 1. Training hyperparameters (single source for the settings below)
# --------------------------------------------------------
config = dict(
    epochs=20,                 # number of passes over the loader
    batch_size=128,            # same value the DataLoader was built with
    # 0.03 is the MoCo-v2 reference LR for batch size 256.
    # NOTE(review): it is used as-is with batch size 128 — confirm whether
    # it should be rescaled.
    lr=0.03,
    momentum=0.9,              # SGD momentum
    weight_decay=1e-4,         # L2 regularisation passed to SGD
    tau=0.2,                   # InfoNCE temperature
    m=0.999,                   # EMA coefficient for encoder_k
    lambda_cov=1.0,            # weight of the covariance penalty
    queue_size=4096,           # 65536 was the earlier, larger setting
    save_every=1,              # checkpoint interval in epochs
    checkpoint_path="/content/moco_cov_checkpoint.pth",
)

# Print one "key: value" line per setting, in insertion order
print("\n".join(f"{name}: {value}" for name, value in config.items()))


# --------------------------------------------------------
# 2. Optimizer: SGD over the model's parameters
# --------------------------------------------------------
# `model` must already exist. lr / momentum / weight_decay are read from
# `config` under the same names SGD expects.
optimizer = optim.SGD(
    model.parameters(),
    **{name: config[name] for name in ("lr", "momentum", "weight_decay")},
)

print("\nOptimizer ready (SGD with momentum).")


# --------------------------------------------------------
# 3. Learning-rate schedule: a single cosine decay
# --------------------------------------------------------
# With eta_min = 0 the schedule is
#   LR(t) = 0.5 * lr * (1 + cos(pi * t / T)),  T = config["epochs"]
scheduler = CosineAnnealingLR(optimizer, T_max=config["epochs"], eta_min=0.0)

print("Cosine LR scheduler ready.")


# --------------------------------------------------------
# 4. Closing status messages (nothing is trained in this cell)
# --------------------------------------------------------
print("\n*** Step 11 completed. ***")
print("We can now proceed to Step 12 — Running full SSL pretraining.")
print("We will start training ONLY when we request it.")

## STEP 8: MoCo-Cov Sanity-Check Pretraining Loop

This block runs a complete self-supervised training loop using the MoCo-ResNet50 model on our local subset. Each iteration:

1. Computes q and k embeddings from two augmented views.
2. Builds the contrastive logits using the positive pair and the queue of negatives.
3. Adds a covariance penalty to reduce redundancy in the embeddings.
4. Backpropagates through encoder_q and updates encoder_k with momentum.
5. Updates the negative queue with the new keys.
6. Logs metrics and saves a checkpoint at the end of each epoch.

In short, this block performs the actual MoCo-Cov self-supervised pretraining, combining contrastive learning + covariance regularization in a full end-to-end training loop.

In [None]:
# ================================================================
# STEP 12 — SANITY-CHECK SSL PRETRAINING WITH MoCo-RESNET50
# ================================================================
# This training loop:
#   - Uses the MoCoResNet50 model and the MoCoQueue of negatives
#   - Computes the loss with MoCoCovLoss (InfoNCE + covariance reg.)
#   - Trains for config["epochs"] on the `device` selected at setup
#   - Logs every `log_every` steps
#   - Saves a checkpoint every config["save_every"] epochs and after
#     the final epoch
# ================================================================

# NOTE: the message text below is kept unchanged; the actual subset size is
# whatever `ssl_loader` was built from.
print("Starting SANITY-CHECK MoCo-Cov training on local 15k subset of the data...\n")

loss_fn = MoCoCovLoss(
    temperature=config["tau"],
    lambda_cov=config["lambda_cov"],
    use_queue_for_cov=False            # covariance from (q, k) only
).to(device)

num_epochs = config["epochs"]
log_every = 100
save_every = config["save_every"]
m = config["m"]                        # EMA momentum
num_batches = len(ssl_loader)

# Fail early with a clear message instead of dividing by zero in the
# per-epoch averages.
if num_batches == 0:
    raise RuntimeError(
        "ssl_loader yields no batches; check that the dataset contains at "
        "least one full batch of images."
    )

for epoch in range(num_epochs):

    model.train()
    epoch_loss = 0.0
    epoch_contrast = 0.0
    epoch_cov = 0.0

    print(f"\n==== Epoch {epoch+1}/{num_epochs} ====")

    for step, (x1, x2) in enumerate(ssl_loader):

        # ---------------------------------------------
        # Move batch to the training device
        # ---------------------------------------------
        x1 = x1.to(device, non_blocking=True)
        x2 = x2.to(device, non_blocking=True)

        # ---------------------------------------------
        # 1. Compute embeddings
        # ---------------------------------------------
        # The projection head already normalises its output; normalising
        # again keeps the loop correct if the head is ever swapped.
        q = F.normalize(model.encoder_q(x1), dim=1)

        with torch.no_grad():
            k = F.normalize(model.encoder_k(x2), dim=1)

        # ---------------------------------------------
        # 2. InfoNCE + covariance penalty via the shared loss module
        # ---------------------------------------------
        # MoCoCovLoss expects negatives as (D, K); the queue stores (K, D).
        # Clone so the later in-place enqueue cannot touch this tensor.
        negatives = queue.queue.clone().detach().T
        loss, loss_contrast, cov_loss = loss_fn(q, k, negatives)

        # Weighted covariance term, tracked separately for logging.
        loss_cov = config["lambda_cov"] * cov_loss

        # ---------------------------------------------
        # 3. Backprop on encoder_q
        # ---------------------------------------------
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ---------------------------------------------
        # 4. Momentum update of encoder_k (in place, no reallocation)
        #    param_k <- m * param_k + (1 - m) * param_q
        # ---------------------------------------------
        with torch.no_grad():
            for param_q, param_k in zip(model.encoder_q.parameters(), model.encoder_k.parameters()):
                param_k.mul_(m).add_(param_q, alpha=1 - m)

        # ---------------------------------------------
        # 5. Update the queue with this batch's keys
        # ---------------------------------------------
        queue.enqueue(k.detach())

        # ---------------------------------------------
        # 6. Track metrics
        # ---------------------------------------------
        epoch_loss += loss.item()
        epoch_contrast += loss_contrast.item()
        epoch_cov += loss_cov.item()

        # ---------------------------------------------
        # 7. Logging every `log_every` steps
        # ---------------------------------------------
        if (step + 1) % log_every == 0:
            print(
                f"[Epoch {epoch+1}/{num_epochs}] "
                f"Step {step+1}/{num_batches} | "
                f"Loss: {loss.item():.4f} | "
                f"Contrast: {loss_contrast.item():.4f} | "
                f"Cov: {loss_cov.item():.4f} | "
                f"LR: {scheduler.get_last_lr()[0]:.6f}"
            )

    # ==================================================
    # Epoch summary
    # ==================================================
    print(
        f"\n>>> Epoch {epoch+1} Summary: "
        f"AvgLoss={epoch_loss/num_batches:.4f}, "
        f"AvgContrast={epoch_contrast/num_batches:.4f}, "
        f"AvgCov={epoch_cov/num_batches:.4f}"
    )
    # Step LR scheduler ONCE per epoch (T_max is expressed in epochs)
    scheduler.step()

    # ==================================================
    # Save checkpoint (honours save_every; always after the last epoch)
    # ==================================================
    is_last_epoch = (epoch + 1 == num_epochs)
    if is_last_epoch or (save_every > 0 and (epoch + 1) % save_every == 0):
        torch.save(
            {
                "epoch": epoch+1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "queue": queue.queue.clone(),
                # Write position, needed to resume the FIFO order exactly.
                "queue_ptr": queue.ptr.clone(),
            },
            config["checkpoint_path"]
        )
        print(f"Checkpoint saved to {config['checkpoint_path']}")

print("\n=== SANITY-CHECK TRAINING COMPLETE ===")


## STEP 9 a. Preparing CIFAR-10 for Downstream Evaluation  

This block loads the CIFAR-10 dataset and prepares it for downstream evaluation of our SSL encoder. It applies simple, non-contrastive transforms (resize + normalize), then builds train/test DataLoaders.



In [None]:
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

# Deterministic evaluation preprocessing: resize to the pretraining
# resolution, convert to tensor, normalise with ImageNet statistics.
cifar_transform = T.Compose([
    T.Resize((96, 96)),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# CIFAR-10 splits: train first, then test (downloaded on first use)
cifar_train, cifar_test = (
    torchvision.datasets.CIFAR10(
        root="/content/cifar",
        train=is_train,
        download=True,
        transform=cifar_transform,
    )
    for is_train in (True, False)
)

# One unshuffled loader per split, so sample order matches dataset order
cifar_train_loader, cifar_test_loader = (
    DataLoader(
        split,
        batch_size=256,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )
    for split in (cifar_train, cifar_test)
)

print("CIFAR-10 loaded successfully!")
print("Train samples:", len(cifar_train))
print("Test samples:", len(cifar_test))





## STEP 9 b. Preparing CIFAR-100 for Downstream Evaluation  

This block loads the CIFAR-100 dataset and prepares it for downstream evaluation of our SSL encoder. It applies simple, non-contrastive transforms (resize + normalize), then builds train/test DataLoaders.

In [None]:
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

# Evaluation-time preprocessing only: a deterministic resize, tensor conversion
# and per-channel normalization. No random augmentation is applied, so every
# pass over the data yields the same inputs.
cifar_transform = T.Compose(
    [
        T.Resize((96, 96)),  # match our backbone resolution
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)

# Build the CIFAR-100 train split first, then the test split. Both share the
# same root directory and transform; only the `train` flag differs.
cifar_train, cifar_test = (
    torchvision.datasets.CIFAR100(
        root="/content/cifar100",
        train=is_train,
        download=True,
        transform=cifar_transform,
    )
    for is_train in (True, False)
)

# One loader per split with identical settings. shuffle=False keeps the sample
# order fixed, so extracted features line up with dataset order.
cifar_train_loader, cifar_test_loader = (
    DataLoader(
        split,
        batch_size=256,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )
    for split in (cifar_train, cifar_test)
)

print("CIFAR-100 loaded successfully!")
print("Train samples:", len(cifar_train))
print("Test samples:", len(cifar_test))


## STEP 10: Extracting Frozen Features for Downstream Evaluation


This block defines a helper function that runs our pretrained encoder_q over an evaluation dataset (like CIFAR-10) to extract 128-dim feature vectors for every image. It keeps the encoder frozen, normalizes the embeddings, and returns:

- one big tensor containing all the features

- one big tensor containing the corresponding labels

These extracted features are later used for downstream tasks such as k-NN classification or a linear probe.

In [None]:
import torch
import torch.nn.functional as F

@torch.no_grad()
def extract_features(encoder_q, dataloader, device="cuda"):
    """
    Run a frozen encoder over a dataloader and collect its embeddings.

    The encoder is switched to eval mode and gradients are disabled for the
    whole call. Every batch of embeddings is L2-normalized along dim 1 and
    moved to the CPU before being stored.

    Args:
        encoder_q:  module mapping a batch of images to a 2-D batch of embeddings
        dataloader: iterable yielding (images, labels) batches
        device:     device the image batches are moved to before the forward pass
    Returns:
        features: CPU tensor of shape [N, D], where D is the encoder's output
                  dimension (described as 128 in this notebook)
        labels:   CPU tensor of shape [N]
    """
    encoder_q.eval()

    feature_chunks, label_chunks = [], []
    for batch_images, batch_labels in dataloader:
        # Forward pass through the frozen encoder on the target device.
        embeddings = encoder_q(batch_images.to(device, non_blocking=True))

        # Re-normalizing is harmless if the encoder already normalizes its output.
        feature_chunks.append(F.normalize(embeddings, dim=1).cpu())
        label_chunks.append(batch_labels.cpu())

    # Stitch the per-batch chunks together along the sample dimension.
    all_features = torch.cat(feature_chunks, dim=0)
    all_labels = torch.cat(label_chunks, dim=0)

    print("Finished extracting features.")
    print("Feature tensor shape:", all_features.shape)
    print("Labels tensor shape:", all_labels.shape)

    return all_features, all_labels


In [None]:
# Extract frozen features for both splits. Pass the notebook-wide `device`
# (chosen in STEP 1) explicitly instead of relying on extract_features'
# hard-coded "cuda" default, so this cell also runs on CPU-only machines.
train_feats, train_labels = extract_features(
    model.encoder_q,
    cifar_train_loader,
    device=device,
)

test_feats, test_labels = extract_features(
    model.encoder_q,
    cifar_test_loader,
    device=device,
)


## STEP 11: k-NN Classifier for Evaluating SSL Representations


This block implements a k-nearest-neighbors classifier used to evaluate how good our SSL features are. It compares test features to all training features using cosine similarity, finds each sample’s top-k neighbors, weights their votes by similarity, and predicts the class.

In [None]:
import torch
import torch.nn.functional as F

def knn_classifier(train_feats, train_labels, test_feats, test_labels, k=10, temperature=0.1):
    """
    Similarity-weighted k-NN classifier for evaluating SSL features.

    Every test feature is compared with every training feature by cosine
    similarity. Its k most similar training samples each vote for their own
    label with weight exp(similarity / temperature), and the class with the
    largest total weight is the prediction.

    Args:
        train_feats: [N_train, D] tensor
        train_labels: [N_train] integer tensor of class indices; the number of
            classes is taken as train_labels.max() + 1
        test_feats: [N_test, D] tensor
        test_labels: [N_test] integer tensor
        k: number of neighbors (default 10); clamped to N_train when larger
        temperature: temperature applied to similarities before exponentiation
    Returns:
        top-1 accuracy (%) as a Python float
    """

    # Keep every tensor on the training-feature device so that indexing,
    # scatter and the final comparison never mix devices.
    feat_device = train_feats.device

    # Normalize features (important for cosine similarity)
    train_feats = F.normalize(train_feats, dim=1)
    test_feats = F.normalize(test_feats.to(feat_device), dim=1)

    # scatter_add_ requires int64 indices.
    train_labels = train_labels.to(feat_device).long()
    test_labels = test_labels.to(feat_device)

    num_test = test_feats.size(0)
    batch_size = 100

    # Loop-invariant quantities, computed once.
    num_neighbors = min(k, train_feats.size(0))  # topk raises if k > N_train
    num_classes = int(train_labels.max().item()) + 1

    correct = 0

    print("Running k-NN evaluation...")

    for start in range(0, num_test, batch_size):
        end = min(start + batch_size, num_test)
        batch = test_feats[start:end]   # shape [B, D]

        # Cosine similarity: [B, D] x [D, N_train] = [B, N_train]
        sim = torch.mm(batch, train_feats.t())

        # Top-k most similar training samples for each test sample
        sim_val, sim_idx = sim.topk(k=num_neighbors, dim=1)

        # Labels of those neighbors, shape [B, k]
        neighbor_labels = train_labels[sim_idx]

        # Unnormalized vote weights; a shared normalizer would not change argmax
        weights = torch.exp(sim_val / temperature)

        # Accumulate each neighbor's weight into its class column in one
        # vectorized call instead of a Python loop over samples and neighbors.
        class_scores = torch.zeros(
            batch.size(0), num_classes, dtype=weights.dtype, device=weights.device
        )
        class_scores.scatter_add_(1, neighbor_labels, weights)

        # Prediction = the class with highest score
        preds = class_scores.argmax(dim=1)

        # Count correct predictions
        correct += (preds == test_labels[start:end]).sum().item()

    acc = 100.0 * correct / num_test
    print(f"k-NN accuracy (k={k}): {acc:.2f}%")
    return acc


Running the k-NN evaluation

In [None]:
# Score the extracted features with a 20-neighbor weighted k-NN vote.
knn_acc = knn_classifier(train_feats, train_labels, test_feats, test_labels, k=20, temperature=0.1)

## STEP 12: Linear Probe Classifier for Evaluating SSL Features


This block trains a simple linear classifier on top of frozen SSL features to measure how well the MoCo encoder supports supervised tasks. It:

- Wraps the features/labels into DataLoaders.
- Trains a single linear layer (no nonlinearities) for a small number of epochs.
- Evaluates accuracy on a test set after each epoch.
- Returns the final accuracy.

In [None]:
def linear_probe(
    train_feats, train_labels,
    test_feats, test_labels,
    num_epochs=50,
    batch_size=1024,
    lr=0.1,
    device="cuda",
    num_classes=None,
):
    """
    Train a single linear layer on frozen features and report test accuracy.

    The classifier is optimized with Adam (weight decay 1e-4) under a cosine
    learning-rate schedule spanning `num_epochs`. Test accuracy is measured
    and printed after every epoch.

    Args:
        train_feats: [N_train, D] feature tensor
        train_labels: [N_train] integer class-index tensor
        test_feats: [N_test, D] feature tensor
        test_labels: [N_test] integer class-index tensor
        num_epochs: number of training epochs
        batch_size: mini-batch size for both training and evaluation
        lr: initial learning rate handed to Adam. NOTE: the default of 0.1 is
            high for Adam; the notebook's own call passes 1e-3.
        device: device the classifier and batches are placed on
        num_classes: output dimension of the classifier; when None it is
            inferred as train_labels.max() + 1
    Returns:
        test accuracy (%) after the final epoch, or 0.0 if num_epochs is 0
    """
    # Imported locally so the function does not depend on another cell's imports.
    from torch.utils.data import DataLoader, TensorDataset

    train_dataset = TensorDataset(train_feats, train_labels)
    test_dataset = TensorDataset(test_feats, test_labels)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Infer the number of classes from the labels so the same code serves
    # CIFAR-10 (10) and CIFAR-100 (100) without editing the function.
    if num_classes is None:
        num_classes = int(train_labels.max().item()) + 1

    # Linear layer (no nonlinearity): features -> class logits
    classifier = nn.Linear(train_feats.size(1), num_classes).to(device)

    optimizer = torch.optim.Adam(classifier.parameters(), lr=lr, weight_decay=1e-4)

    # Cosine LR schedule; T_max is kept >= 1 so num_epochs=0 cannot divide by zero.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=max(num_epochs, 1)
    )

    criterion = nn.CrossEntropyLoss()

    acc = 0.0  # returned unchanged when no epoch is run

    for epoch in range(num_epochs):
        classifier.train()
        total_loss = 0  # sum of per-batch mean losses over this epoch

        for feats, labels in train_loader:
            feats = feats.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = classifier(feats)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Evaluate
        classifier.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for feats, labels in test_loader:
                feats = feats.to(device)
                labels = labels.to(device)
                logits = classifier(feats)
                preds = logits.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        acc = 100 * correct / total
        print(f"Epoch {epoch+1}/{num_epochs} | Loss: {total_loss:.4f} | Test Acc: {acc:.2f}%")

        # Step once per epoch, after the optimizer updates of that epoch.
        scheduler.step()

    return acc


Running the linear probe

In [None]:
# Train and evaluate the linear probe on the extracted features. Pass the
# notebook-wide `device` (chosen in STEP 1) explicitly instead of relying on
# linear_probe's hard-coded "cuda" default, so this also runs without a GPU.
lp_acc = linear_probe(
    train_feats, train_labels,
    test_feats, test_labels,
    num_epochs=50,
    batch_size=1024,
    lr=1e-3,
    device=device,
)