## Step 1.a: Import all required libraries and set up the environment

In [None]:
# PyTorch core libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Torchvision (datasets, transforms, ResNet backbone)
import torchvision
import torchvision.transforms as T
from torchvision.models import resnet50

# Utility libraries
import numpy as np
from tqdm import tqdm
import random
import math
import time

# ================================================================
# Device selection
# ================================================================
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# ================================================================
# Reproducibility
# ================================================================
# One seed shared by Python's `random`, NumPy's global RNG, and PyTorch's
# CPU generator plus every CUDA device generator (seeded in that order).
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn

# Request deterministic cuDNN kernels and disable cuDNN's autotuner,
# trading some speed for run-to-run repeatability.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

print("Environment setup complete.")

Using device: cuda
Environment setup complete.


## STEP 2 — Define MoCo-v2 Style Data Augmentations

MoCo-style self-supervised learning relies heavily on strong augmentations to create two different “views” of the same image.
This block creates the MoCo-v2-style augmentation pipeline (plus an extra random solarization step) that we’ll use during training.

In [None]:
# These augmentations generate *two strongly augmented views* of each image.
# They are essential for contrastive learning because the model must learn
# invariances to color, crop, blur, distortion, etc.

class MoCoTransform:
    """Two-view augmentation wrapper for contrastive pre-training.

    Calling an instance on one image returns a tuple ``(view_1, view_2)``.
    Each view is an independent draw from the same random pipeline, so the
    two views normally differ.

    Args:
        image_size: side length in pixels of the square crops produced by
            ``RandomResizedCrop``.
    """

    def __init__(self, image_size=96):
        imagenet_mean = (0.485, 0.456, 0.406)
        imagenet_std = (0.229, 0.224, 0.225)

        pipeline = [
            # Spatial: crop 20%-100% of the image area, resize to image_size,
            # then mirror half of the time.
            T.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            # Photometric: color jitter (brightness, contrast, saturation, hue)
            # with probability 0.8, grayscale with probability 0.2.
            T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            T.RandomGrayscale(p=0.2),
            # Solarization (threshold 128 on the 0-255 scale, applied before
            # ToTensor). NOTE: this is an extra step on top of the reference
            # MoCo-v2 augmentation recipe.
            T.RandomApply([T.RandomSolarize(128)], p=0.2),
            # Gaussian blur with a random sigma in [0.1, 2.0], half of the time.
            T.RandomApply([T.GaussianBlur(kernel_size=9, sigma=(0.1, 2.0))], p=0.5),
            # Tensor conversion followed by ImageNet mean/std normalization.
            T.ToTensor(),
            T.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
        self.base_transform = T.Compose(pipeline)

    def __call__(self, x):
        # Two separate calls -> two independent random augmentations.
        first_view = self.base_transform(x)
        second_view = self.base_transform(x)
        return first_view, second_view


print("MoCo-v2 augmentations initialized successfully.")

MoCo-v2 augmentations initialized successfully.


## STEP 3 — Build the Unlabeled Dataset Loader (Self-Supervised Dataset)

This dataset:
- Uses the MoCoTransform we defined
- Returns two augmented views per image
- Works for any folder of unlabeled images
- Creates a PyTorch DataLoader ready for training

# Mehi: delete this block later on. It is only for our tests.


In [None]:
# # ================================================================
# # TEMPORARY FIX: Create a dummy dataset for debugging the pipeline
# # ================================================================

# from PIL import Image
# import numpy as np
# import os

# # Create a small dummy directory
# dummy_root = "/content/unlabeled_dataset/"
# os.makedirs(dummy_root, exist_ok=True)

# # Create 20 small random images for testing the DataLoader
# for i in range(20):
#     img = (np.random.rand(96, 96, 3) * 255).astype(np.uint8)
#     Image.fromarray(img).save(f"{dummy_root}/dummy_{i}.jpg")

# print("Dummy dataset created.")


Dummy dataset created.


In [None]:
# # We assume the dataset directory contains raw images organized like:
# # root/
# #    img1.jpg
# #    img2.jpg
# #    img3.jpg
# #    ...
# #
# # There are NO labels in SSL pretraining.

# class UnlabeledImageDataset(torch.utils.data.Dataset):
#     """Dataset that returns two augmented views for self-supervised learning."""

#     def __init__(self, root_dir, transform):
#         """
#         root_dir: path to the directory with unlabeled images
#         transform: MoCoTransform (returns two views)
#         """
#         self.root_dir = root_dir
#         self.transform = transform
#         self.files = []

#         # Collect all image file paths
#         for ext in ["png", "jpg", "jpeg"]:
#             self.files.extend(list(Path(root_dir).glob(f"**/*.{ext}")))

#         print(f"Loaded {len(self.files)} unlabeled images from {root_dir}")

#     def __len__(self):
#         return len(self.files)

#     def __getitem__(self, idx):
#         # Load image
#         img_path = self.files[idx]
#         image = Image.open(img_path).convert("RGB")

#         # Return two augmented views
#         x1, x2 = self.transform(image)
#         return x1, x2


# # ================================================================
# # Example: create the dataset and dataloader
# # ================================================================
# from pathlib import Path
# from PIL import Image

# # UPDATE THIS PATH to your unlabeled dataset in Colab:
# unlabeled_root = "/content/unlabeled_dataset/"  # <-- I will change this later

# # Instantiate dataset with the MoCoTransform
# ssl_transform = MoCoTransform(image_size=96)
# ssl_dataset = UnlabeledImageDataset(unlabeled_root, ssl_transform)

# # Create dataloader (this will be used in training)
# ssl_loader = torch.utils.data.DataLoader(
#     ssl_dataset,
#     batch_size=4,           # adjust based on GPU memory
#     shuffle=True,
#     num_workers=4,
#     drop_last=True
# )

# print("SSL DataLoader is ready.")


Loaded 20 unlabeled images from /content/unlabeled_dataset/
SSL DataLoader is ready.


## STEP 4 — Build MoCo Encoders (ResNet-50 + Projection Head + Momentum Copy)

This step creates:
- ResNet-50 backbone (without its classification head)
- Projection head (MLP) used by MoCo-v2
- MoCoEncoder module that combines:
  - encoder_q (trainable)
  - encoder_k (momentum encoder)
  - a projection head for each

This is the backbone of our SSL method.

In [None]:
# # We will create:
# # - encoder_q: the main, trainable backbone
# # - encoder_k: the EMA/momentum backbone
# # - projection head: a 2-layer MLP as used in MoCo-v2

# # IMPORTANT:
# # We remove the classification head of ResNet-50 and keep only the
# # convolutional feature extractor.


# class ProjectionHead(nn.Module):
#     """
#     2-layer MLP projection head used in MoCo-v2.
#     It maps high-dimensional ResNet features into a smaller embedding space.
#     """
#     def __init__(self, in_dim=2048, hidden_dim=2048, out_dim=128):
#         super().__init__()
#         self.fc1 = nn.Linear(in_dim, hidden_dim)
#         self.relu = nn.ReLU(inplace=True)
#         self.fc2 = nn.Linear(hidden_dim, out_dim)

#     def forward(self, x):
#         # Applies a simple 2-layer MLP: FC -> ReLU -> FC
#         x = self.fc1(x)
#         x = self.relu(x)
#         x = self.fc2(x)
#         return x


# def build_resnet50_backbone():
#     """
#     Loads a ResNet-50 backbone WITHOUT the classification head.
#     This is the standard backbone for MoCo-v2.
#     """
#     backbone = resnet50(weights=None)  # randomly initialized
#     backbone.fc = nn.Identity()        # remove classification layer
#     return backbone


# class MoCoEncoder(nn.Module):
#     """
#     Encapsulates:
#     - encoder_q (trainable)
#     - encoder_k (momentum/EMA)
#     - projection heads for both encoders
#     """
#     def __init__(self, feature_dim=128, m=0.999):
#         super().__init__()

#         # Momentum coefficient
#         self.m = m

#         # ------------------------------------------------------------
#         # 1. Build query encoder (trainable)
#         # ------------------------------------------------------------
#         self.encoder_q = build_resnet50_backbone()
#         self.proj_q = ProjectionHead(in_dim=2048, out_dim=feature_dim)

#         # ------------------------------------------------------------
#         # 2. Build key encoder (momentum version)
#         # ------------------------------------------------------------
#         self.encoder_k = build_resnet50_backbone()
#         self.proj_k = ProjectionHead(in_dim=2048, out_dim=feature_dim)

#         # Key encoder should NOT update by gradient descent
#         for param in self.encoder_k.parameters():
#             param.requires_grad = False
#         for param in self.proj_k.parameters():
#             param.requires_grad = False

#         # ------------------------------------------------------------
#         # Initialize encoder_k with encoder_q weights
#         # ------------------------------------------------------------
#         self._momentum_update_key_encoder(initial=True)

#     @torch.no_grad()
#     def _momentum_update_key_encoder(self, initial=False):
#         """
#         Updates encoder_k parameters using exponential moving average (EMA).
#         Called every training step.
#         """
#         for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
#             if initial:
#                 # At initialization, copy weights directly
#                 param_k.data.copy_(param_q.data)
#             else:
#                 # Momentum update: k = m * k + (1-m) * q
#                 param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

#         for param_q, param_k in zip(self.proj_q.parameters(), self.proj_k.parameters()):
#             if initial:
#                 param_k.data.copy_(param_q.data)
#             else:
#                 param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

#     def forward(self, x_q, x_k):
#         """
#         Runs the query encoder and (optionally) the key encoder.
#         Returns:
#         - q: normalized query embedding
#         - k: normalized key embedding (no grad)
#         """
#         # Query branch (gradients flow here)
#         q = self.encoder_q(x_q)
#         q = self.proj_q(q)
#         q = F.normalize(q, dim=1)

#         # Key branch (momentum encoder, no grad)
#         with torch.no_grad():
#             k = self.encoder_k(x_k)
#             k = self.proj_k(k)
#             k = F.normalize(k, dim=1)

#         return q, k


# print("MoCo encoders (ResNet-50 + projection heads) defined successfully.")

MoCo encoders (ResNet-50 + projection heads) defined successfully.


## STEP 5 — Implement the MoCo Negative Queue (Memory Bank)

The queue stores thousands of negative keys used for contrastive learning.
MoCo depends heavily on this queue to approximate a large batch size.

We need:
- a fixed-size FIFO queue
- enqueue new keys (from each batch)
- dequeue oldest entries
- ensure keys are normalized
- prevent gradients from flowing through queue values

This class will be small but critical for stable training.

In [None]:
# # This queue stores negative embeddings for the contrastive loss.

# # Key ideas:
# # - It is a fixed-size FIFO queue.
# # - Each new batch of keys is enqueued.
# # - The oldest entries are removed to keep size fixed.
# # - Queue embeddings NEVER require gradients (detach safely).

# class MoCoQueue:
#     def __init__(self, feature_dim=128, queue_size=65536):
#         """
#         feature_dim: dimensionality of embedding (default 128)
#         queue_size: total number of negative keys stored
#         """
#         self.queue_size = queue_size

#         # Initialize queue as a (feature_dim x queue_size) tensor
#         self.queue = torch.randn(feature_dim, queue_size)
#         self.queue = F.normalize(self.queue, dim=0)  # normalize each key
#         self.queue_ptr = 0  # pointer to the oldest entry

#         # Queue does NOT require gradients
#         self.queue.requires_grad = False

#     @torch.no_grad()
#     def enqueue_dequeue(self, keys):
#         """
#         keys: tensor of shape (batch_size, feature_dim)
#         Inserts new keys and removes old ones in FIFO fashion.
#         """
#         batch_size = keys.shape[0]

#         # If batch is larger than queue size (rare), we trim
#         if batch_size > self.queue_size:
#             keys = keys[:self.queue_size]

#         # Compute positions in the queue to replace
#         end_ptr = (self.queue_ptr + batch_size) % self.queue_size

#         if end_ptr > self.queue_ptr:
#             # Simple case: no wrap-around
#             self.queue[:, self.queue_ptr:end_ptr] = keys.T
#         else:
#             # Wrap-around case
#             first_segment = self.queue_size - self.queue_ptr
#             self.queue[:, self.queue_ptr:] = keys[:first_segment].T
#             self.queue[:, :end_ptr] = keys[first_segment:].T

#         # Update pointer
#         self.queue_ptr = end_ptr

#     @torch.no_grad()
#     def get_negatives(self):
#         """
#         Returns all negative keys of shape (feature_dim, queue_size),
#         with no gradients.
#         """
#         return self.queue.clone().detach()


# print("MoCo negative queue initialized successfully.")

MoCo negative queue initialized successfully.


## STEP 6 — Implement the Contrastive Loss (InfoNCE) + Covariance Regularization Loss

This is the heart of our method. We will build:
- InfoNCE loss for MoCo
- Covariance regularization loss (VICReg-style)
- A clean loss module we can plug directly into the training loop

Our code supports both:
- Covariance loss on (q, k+) only
- Covariance loss on (q, k+, queue samples)
Controlled by a flag: use_queue_for_cov = True/False

This makes our method flexible, cleaner, and well suited to ablation studies.

In [None]:
# This includes two main components:
# 1. MoCo InfoNCE contrastive loss.
# 2. VICReg-style covariance loss to decorrelate embedding dimensions.

# We support:
#   - Using only (q, k+) for covariance loss
#   - OR including samples from the negative queue (optional)
# via the flag `use_queue_for_cov`.


class MoCoCovLoss(nn.Module):
    """MoCo InfoNCE loss plus a VICReg-style covariance regularizer.

    The total loss is

        loss = loss_contrast + lambda_cov * loss_cov

    - ``loss_contrast`` is the cross-entropy of each query against the
      logits [its positive key, every queue negative], scaled by 1 / tau.
    - ``loss_cov`` is the sum of squared off-diagonal entries of the (D, D)
      sample covariance of the stacked embeddings ``[q; k_pos]``. It
      optionally also includes a random subset of queue negatives.
    """

    def __init__(self, temperature=0.2, lambda_cov=1.0, use_queue_for_cov=False,
                 cov_queue_samples=1024):
        """
        temperature: contrastive temperature parameter (tau)
        lambda_cov: weight for covariance regularization
        use_queue_for_cov: whether to include queue samples in covariance loss
        cov_queue_samples: maximum number of queue negatives drawn (uniformly,
            without replacement) for the covariance term when
            use_queue_for_cov is True. The default of 1024 matches the
            previously hard-coded value.
        """
        super().__init__()
        self.tau = temperature
        self.lambda_cov = lambda_cov
        self.use_queue_for_cov = use_queue_for_cov
        self.cov_queue_samples = cov_queue_samples

    def forward(self, q, k_pos, queue_neg):
        """
        q: (B, D) query embeddings
        k_pos: (B, D) positive key embeddings
        queue_neg: (D, K) negative keys from the FIFO queue, one key per
            COLUMN. A queue stored one key per row, i.e. shape (K, D), must
            be transposed before being passed in.

        Returns:
            (loss, loss_contrast, loss_cov) as scalar tensors.
        """
        # ------------------------------------------------------------
        # 1. InfoNCE contrastive loss
        # ------------------------------------------------------------
        # Positive logits: q · k+ for each matching row pair -> (B, 1)
        pos_logits = torch.sum(q * k_pos, dim=1, keepdim=True)
        # Negative logits: q · every queue key -> (B, K)
        neg_logits = torch.einsum('nd,dk->nk', q, queue_neg)

        # Temperature-scaled logits with the positive in column 0 -> (B, 1+K)
        logits = torch.cat([pos_logits, neg_logits], dim=1) / self.tau

        # The positive is always column 0, so every target label is 0.
        labels = torch.zeros(q.size(0), dtype=torch.long, device=q.device)
        loss_contrast = F.cross_entropy(logits, labels)

        # ------------------------------------------------------------
        # 2. VICReg-style covariance regularization
        # ------------------------------------------------------------
        # Feature matrix Z: rows are samples, columns are embedding dims.
        if self.use_queue_for_cov:
            K = queue_neg.shape[1]
            num_samples = min(self.cov_queue_samples, K)
            idx = torch.randperm(K, device=q.device)[:num_samples]
            queue_subset = queue_neg[:, idx].T  # (num_samples, D)
            Z = torch.cat([q, k_pos, queue_subset], dim=0)  # (2B + num_samples, D)
        else:
            Z = torch.cat([q, k_pos], dim=0)  # (2B, D)

        # Unbiased sample covariance of the centered features -> (D, D)
        Z = Z - Z.mean(dim=0, keepdim=True)
        C = (Z.T @ Z) / (Z.size(0) - 1)

        # Penalize only cross-dimension correlations (off-diagonal entries).
        off_diag_mask = ~torch.eye(C.size(0), dtype=torch.bool, device=C.device)
        loss_cov = (C[off_diag_mask] ** 2).sum()

        # ------------------------------------------------------------
        # 3. Combine losses
        # ------------------------------------------------------------
        loss = loss_contrast + self.lambda_cov * loss_cov

        return loss, loss_contrast, loss_cov


print("MoCo InfoNCE + covariance loss module ready.")

MoCo InfoNCE + covariance loss module ready.


## Step 7 — Putting everything together: The Full Training Step (forward + loss + backward + momentum update + queue update).

This is our most important engineering step in this entire project.

We now have all components:
- MoCo encoders (query + momentum key)
- Negative queue
- Contrastive + covariance loss

In this step we wire all of these together into one training step. This is where everything comes together:
- Forward pass
- Compute q and k+
- Compute loss
- Backprop on encoder_q
- Momentum update of encoder_k
- Update the queue
- Return losses


In [None]:
# # This function performs a *single iteration* of the training loop:
# #   1) Compute q and k+ embeddings
# #   2) Compute contrastive + covariance loss
# #   3) Backprop only through encoder_q
# #   4) Momentum-update encoder_k
# #   5) Enqueue new k+ embeddings into the queue
# # ================================================================

# def training_step(model, queue, loss_fn, batch, optimizer):
#     """
#     model: MoCoEncoder
#     queue: MoCoQueue
#     loss_fn: MoCoCovLoss
#     batch: (x1, x2) two augmented views from dataloader
#     optimizer: optimizer for encoder_q only
#     """

#     # ------------------------------------------------------------
#     # 1. Move data to device
#     # ------------------------------------------------------------
#     x1, x2 = batch
#     x1 = x1.to(device, non_blocking=True)
#     x2 = x2.to(device, non_blocking=True)

#     # ------------------------------------------------------------
#     # 2. Forward pass:
#     #    - q = encoder_q(x1)
#     #    - k = encoder_k(x2)
#     # ------------------------------------------------------------
#     q, k_pos = model(x1, x2)  # both are (B, D)

#     # ------------------------------------------------------------
#     # 3. Get all negative keys from the queue
#     # ------------------------------------------------------------
#     queue_neg = queue.get_negatives()  # (D, K)
#     queue_neg = queue_neg.to(device)

#     # ------------------------------------------------------------
#     # 4. Compute total loss = contrastive + lambda * covariance
#     # ------------------------------------------------------------
#     loss, loss_contrast, loss_cov = loss_fn(q, k_pos, queue_neg)

#     # ------------------------------------------------------------
#     # 5. Backward pass: update encoder_q parameters
#     # ------------------------------------------------------------
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()

#     # ------------------------------------------------------------
#     # 6. Momentum update of encoder_k
#     # ------------------------------------------------------------
#     model._momentum_update_key_encoder(initial=False)

#     # ------------------------------------------------------------
#     # 7. Update the negative queue using the detached k+
#     # ------------------------------------------------------------
#     queue.enqueue_dequeue(k_pos.detach())

#     # ------------------------------------------------------------
#     # 8. Return losses for logging
#     # ------------------------------------------------------------
#     return loss.item(), loss_contrast.item(), loss_cov.item()

## STEP 8 — Full MoCo-Cov Training Loop

Now we move to the final major coding step before actual training.
This is where we build the full training loop that:
- iterates over epochs
- iterates over batches
- uses your training_step()
- logs the 3 losses
- saves checkpoints (MoCo encoder + queue)


In [None]:
# # This loop:
# #  - Runs for multiple epochs
# #  - Iterates through all batches
# #  - Calls training_step() each iteration
# #  - Logs losses
# #  - Saves periodic checkpoints
# # ================================================================

# def train_moco(
#     model,
#     queue,
#     loss_fn,
#     dataloader,
#     optimizer,
#     num_epochs=10,
#     log_every=50,
#     save_every=1,
#     checkpoint_path="/content/moco_cov_checkpoint.pth"
# ):
#     """
#     model: MoCoEncoder
#     queue: MoCoQueue
#     loss_fn: MoCoCovLoss
#     dataloader: unlabeled dataset loader
#     optimizer: optimizer for encoder_q
#     num_epochs: total training epochs
#     log_every: how often to print loss
#     save_every: save checkpoint every N epochs
#     checkpoint_path: file to save model+queue
#     """

#     model.train()
#     model.to(device)

#     for epoch in range(num_epochs):
#         epoch_loss = 0.0
#         epoch_contrast = 0.0
#         epoch_cov = 0.0

#         pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

#         for step, batch in enumerate(pbar):

#             # ------------------------------------------------------------
#             # Single training step
#             # ------------------------------------------------------------
#             loss, loss_contrast, loss_cov = training_step(
#                 model, queue, loss_fn, batch, optimizer
#             )

#             # ------------------------------------------------------------
#             # Track running averages
#             # ------------------------------------------------------------
#             epoch_loss += loss
#             epoch_contrast += loss_contrast
#             epoch_cov += loss_cov

#             # ------------------------------------------------------------
#             # Logging
#             # ------------------------------------------------------------
#             if (step + 1) % log_every == 0:
#                 pbar.set_postfix({
#                     "loss": f"{loss:.4f}",
#                     "contrast": f"{loss_contrast:.4f}",
#                     "cov": f"{loss_cov:.4f}"
#                 })

#         # ------------------------------------------------------------
#         # End of epoch logging
#         # ------------------------------------------------------------
#         print(f"\n[Epoch {epoch+1}] "
#               f"Avg Loss: {epoch_loss/len(dataloader):.4f}, "
#               f"Avg Contrast: {epoch_contrast/len(dataloader):.4f}, "
#               f"Avg Cov: {epoch_cov/len(dataloader):.4f}")

#         # ------------------------------------------------------------
#         # Save checkpoint periodically
#         # ------------------------------------------------------------
#         if (epoch + 1) % save_every == 0:
#             save_state = {
#                 "epoch": epoch + 1,
#                 "model_state": model.state_dict(),
#                 "queue_state": queue.queue,
#                 "optimizer_state": optimizer.state_dict()
#             }
#             torch.save(save_state, checkpoint_path)
#             print(f"Checkpoint saved: {checkpoint_path}")

#     print("Training complete.")

## STEP 9 — MoCo-ResNet50 + Queue Code Block

This code block defines the core architecture for our self-supervised learning method. It creates the MoCo-v2 model using ResNet-50 as the backbone and adds the required contrastive-learning components. The block contains four major parts:

1. ProjectionMLP — the MoCo-v2 projection head

MoCo-v2 requires projecting backbone features (2048-dim from ResNet-50) into a smaller space (128-dim) for contrastive learning.

This class builds the projection head:

2048 → 2048 → 128


with:
- a hidden layer (ReLU)
- an output layer
- L2 normalization of the output embedding

This is exactly the architecture recommended in the MoCo v2 paper.

2. MoCoResNet50 — query encoder and momentum encoder

MoCo requires two networks:

🔹 encoder_q

- the main trainable network
- receives one augmented view of each image
- updated by gradients

🔹 encoder_k
- a slowly updated “momentum encoder”
- receives the second augmented view
- no gradients
- updated via EMA:
encoder_k ← m * encoder_k + (1 - m) * encoder_q

Both encoders:
- use a randomly initialized ResNet-50 backbone
- remove the classification layer (fc = Identity())
- attach the same 128-dim projection MLP

The code also initializes encoder_k with the same weights as encoder_q.

3. MoCoQueue — the FIFO memory bank of negative samples

MoCo uses a large queue of previously encoded features (negatives), instead of relying only on the current batch.

This queue:
- stores 65,536 embeddings by default (the instantiation below uses a smaller 4,096-entry queue for testing)
- each embedding has 128 dimensions
- behaves like a circular buffer
- automatically overwrites the oldest entries
- is used every iteration for contrastive loss

The queue dramatically improves contrastive learning stability and performance.

4. Instantiate the model and queue on GPU

At the end, the code creates:

- model = MoCoResNet50(dim_feature=128), placed on the GPU
- queue = MoCoQueue(size=4096, dim=128), placed on the GPU (the class default size is 65,536)

and prints a confirmation message.

These objects will be used later in:
- the training step (contrastive_covariance_step)
- the optimizer setup
- the main SSL training loop

In [None]:
# ======================================================================
# FIXED & CORRECTED MoCo Model + Projection Head + Queue (128-dim)
# ======================================================================
# This block ensures that:
#   - encoder_q outputs 128-d features
#   - encoder_k outputs 128-d features
#   - queue stores 128-d features
#   - q, k, queue embeddings perfectly match in dimension
# ======================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


# ----------------------------------------------------------
# 1. Projection Head: 2048 ‚Üí 2048 ‚Üí 128 (MoCo-v2 style MLP)
# ----------------------------------------------------------
class ProjectionMLP(nn.Module):
    """
    Two-layer MLP head that maps backbone features to unit-length embeddings.

    Layout: Linear(dim_in, dim_hidden) -> ReLU -> Linear(dim_hidden, dim_out),
    then L2 normalization along dim 1 (the feature axis of an (N, dim_in)
    input). The defaults map ResNet-50 pooled features (2048-d) to 128-d
    embeddings.
    """
    def __init__(self, dim_in=2048, dim_hidden=2048, dim_out=128):
        super().__init__()
        layers = [
            nn.Linear(dim_in, dim_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(dim_hidden, dim_out),
        ]
        # A single Sequential named `mlp` keeps state_dict keys as
        # "mlp.0.*" / "mlp.2.*".
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.mlp(x)
        return F.normalize(projected, dim=1)


# ----------------------------------------------------------
# 2. MoCo Model (ResNet-50 backbone + projection head)
# ----------------------------------------------------------
class MoCoResNet50(nn.Module):
    """
    Query/key encoder pair for MoCo-v2 style pre-training.

    Attributes:
        encoder_q: randomly initialized ResNet-50 trunk (classification layer
            replaced by Identity) followed by a ProjectionMLP that emits
            L2-normalized ``dim_feature``-d embeddings. It is the trainable
            branch.
        encoder_k: architecturally identical copy whose parameters start as
            exact copies of encoder_q's and have requires_grad=False.

    This class defines no forward() and no momentum/EMA update method.
    Callers invoke encoder_q / encoder_k directly and must perform the EMA
    step themselves (TODO confirm where the training step does this).
    """
    def __init__(self, dim_feature=128):
        super().__init__()
        # Build order (query trunk, query head, key trunk, key head) fixes
        # the order in which random initialization consumes the RNG.
        self.encoder_q = self._make_encoder(dim_feature)
        self.encoder_k = self._make_encoder(dim_feature)

        # Only parameters are copied; BatchNorm running-stat buffers are not,
        # though both encoders were freshly constructed just above.
        for q_param, k_param in zip(self.encoder_q.parameters(),
                                    self.encoder_k.parameters()):
            k_param.data.copy_(q_param.data)
            k_param.requires_grad = False

    @staticmethod
    def _make_encoder(dim_feature):
        """Return Sequential(ResNet-50 without fc, ProjectionMLP 2048 -> 2048 -> dim_feature)."""
        trunk = models.resnet50(weights=None)   # random init, no pretrained weights
        trunk.fc = nn.Identity()                # drop the classification layer
        head = ProjectionMLP(dim_in=2048, dim_hidden=2048, dim_out=dim_feature)
        return nn.Sequential(trunk, head)


# ----------------------------------------------------------
# 3. FIFO Queue for Negative Samples (Dimension = 128)
# ----------------------------------------------------------
class MoCoQueue(nn.Module):
    """
    FIFO memory queue of L2-normalized negative embeddings (128-d by default).

    ``queue`` is a registered buffer of shape (size, dim), holding one key per
    ROW, so it follows ``.to()`` / ``.cuda()`` and is included in
    ``state_dict()``. Loss code that expects negatives laid out as (dim, size),
    such as ``MoCoCovLoss``'s ``queue_neg`` argument, needs ``queue.T``.
    ``ptr`` (a 1-element long buffer) is the index of the oldest entry, i.e.
    where the next key will be written.
    """
    def __init__(self, size=65536, dim=128):
        super().__init__()
        self.size = size
        self.dim = dim

        # Random unit vectors act as placeholder negatives until real keys
        # have overwritten them.
        self.register_buffer("queue", F.normalize(torch.randn(size, dim), dim=1))
        self.register_buffer("ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def enqueue(self, keys):
        """
        Add new keys to the queue, overwriting the oldest entries.

        keys: (batch_size, dim). The final state equals writing the keys one
        at a time in order, including when batch_size exceeds the queue size;
        in that case only the newest ``size`` keys survive.
        """
        batch_size = keys.shape[0]
        ptr = int(self.ptr)

        if batch_size > self.size:
            # Earlier keys would be overwritten anyway. Skip them and start
            # where the sequential writes of the surviving keys would begin.
            ptr = (ptr + batch_size - self.size) % self.size
            keys = keys[-self.size:]
            batch_size = self.size

        if ptr + batch_size <= self.size:
            # Contiguous write, no wrap-around.
            self.queue[ptr:ptr + batch_size] = keys
            self.ptr[0] = (ptr + batch_size) % self.size
        else:
            # Wrap around: fill to the end, then continue from index 0.
            n1 = self.size - ptr
            n2 = batch_size - n1
            self.queue[ptr:] = keys[:n1]
            self.queue[:n2] = keys[n1:]
            self.ptr[0] = n2


# ----------------------------------------------------------
# 4. Instantiate Model + Queue (DIM=128) on the selected device
# ----------------------------------------------------------
# `device` comes from the environment-setup cell (CUDA when available,
# otherwise CPU), so this cell no longer crashes on CPU-only runtimes.
#
# NOTE: MoCoQueue defaults to MoCo-v2's 65,536 entries. A 4,096-entry queue
# is used here for the small test subset. A queue larger than the dataset
# would hold several stale keys of the same image. Restore 65536 for
# full-scale pre-training.
QUEUE_SIZE = 4096  # full-scale value: 65536
model = MoCoResNet50(dim_feature=128).to(device)
queue = MoCoQueue(size=QUEUE_SIZE, dim=128).to(device)


print(">>> MoCo-ResNet50 model + 128-d queue initialized successfully!")

>>> MoCo-ResNet50 model + 128-d queue initialized successfully!


## STEP 10 — Load only a subset of the 500k Unlabeled Pretraining Dataset from Hugging Face (here we test on 30k)

### Step 10.0.

In [None]:
# ===============================================================
# DATA SOURCE CONFIGURATION
# ===============================================================
# Choose where training images come from.
#   "local_30k" : download the Hugging Face zip shards and copy a random
#                 30,000-image subset into LOCAL_DATA_DIR. This is the only
#                 value handled by the download cell below.
#   "hf_full"   : intended for the full dataset. It is not handled by the
#                 download cell below.

DATA_SOURCE = "local_30k"

# Path to local subset in /content or Drive
LOCAL_DATA_DIR = "/content/local_pretrain_30k"

print("DATA SOURCE =", DATA_SOURCE)


DATA SOURCE = local_30k


In [None]:
# ===============================================================
# STEP 10 — DOWNLOAD USING snapshot_download()
# ===============================================================
# Downloads the dataset's .zip shards, indexes every .jpg inside them, and
# copies a reproducible random sample (seed 42) of NUM_PRETRAIN_IMAGES
# images into LOCAL_DATA_DIR as img_00000.jpg, img_00001.jpg, ...

from huggingface_hub import snapshot_download
from contextlib import ExitStack
import zipfile
import os
from tqdm import tqdm

LOCAL_DATA_DIR = "/content/local_pretrain_30k"
NUM_PRETRAIN_IMAGES = 30000
os.makedirs(LOCAL_DATA_DIR, exist_ok=True)

if DATA_SOURCE == "local_30k":

    print("‚úì Using snapshot_download() to fetch dataset shards...")

    repo_dir = snapshot_download(
        repo_id="tsbpp/fall2025_deeplearning",
        repo_type="dataset",
        local_dir="/content/hf_raw",
        allow_patterns=["*.zip"],       # only zip files
    )

    print("‚úì Download complete. Extracting shards...")

    zip_files = sorted(f for f in os.listdir(repo_dir) if f.endswith(".zip"))
    zip_paths = [os.path.join(repo_dir, z) for z in zip_files]

    # Index (zip_path, member_name) for every .jpg in every shard.
    # Nothing is extracted yet; only the member lists are read.
    extracted_images = []
    for zip_path in tqdm(zip_paths, desc="Extracting ZIPs"):
        with zipfile.ZipFile(zip_path, "r") as zf:
            extracted_images.extend(
                (zip_path, name) for name in zf.namelist() if name.endswith(".jpg")
            )

    # Pick a reproducible random subset (raises ValueError if the shards
    # contain fewer than NUM_PRETRAIN_IMAGES images).
    import random
    random.seed(42)
    selected = random.sample(extracted_images, NUM_PRETRAIN_IMAGES)

    print(f"Saving {NUM_PRETRAIN_IMAGES} images locally...")
    # Opening a ZipFile parses the archive's whole central directory, which
    # is expensive for large shards. Open each shard once and reuse it for
    # every selected image instead of reopening it per image.
    with ExitStack() as stack:
        archives = {
            zip_path: stack.enter_context(zipfile.ZipFile(zip_path, "r"))
            for zip_path in zip_paths
        }
        for i, (zip_path, name) in enumerate(tqdm(selected)):
            out_path = os.path.join(LOCAL_DATA_DIR, f"img_{i:05d}.jpg")
            with open(out_path, "wb") as f:
                f.write(archives[zip_path].read(name))

    print(f"‚úì Saved {NUM_PRETRAIN_IMAGES} images to {LOCAL_DATA_DIR}")


‚úì Using snapshot_download() to fetch dataset shards...


Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

‚úì Download complete. Extracting shards...


Extracting ZIPs: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:03<00:00,  1.51it/s]


Saving 15000 images locally...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 30000/30000 [5:25:22<00:00,  1.54it/s]

‚úì Saved 30000 images to /content/local_pretrain_30k





Step 10.1.

In [None]:
# ================================================================
# STEP — Create dataset & dataloader for the LOCAL 30k images
# ================================================================
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import glob
import os

class LocalMoCoDataset(Dataset):
    """
    Dataset over the *.jpg files directly inside ``root_dir`` (for example the
    img_00000.jpg, img_00001.jpg, ... files written by the download cell).

    Each item is the pair returned by ``transform`` for the RGB-converted
    image; for MoCo training, the two augmented views (x1, x2).

    Raises:
        FileNotFoundError: if ``root_dir`` contains no *.jpg files, so a wrong
            path fails immediately with a clear message.
    """
    def __init__(self, root_dir, transform):
        self.root_dir = root_dir
        self.transform = transform
        # Sorted so the index -> file mapping is deterministic across runs.
        self.files = sorted(glob.glob(os.path.join(root_dir, "*.jpg")))
        print(f"Local dataset: {len(self.files)} images found.")
        if not self.files:
            raise FileNotFoundError(f"No .jpg images found in {root_dir!r}")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path = self.files[idx]
        # Context manager closes the file handle right away instead of leaving
        # it open until garbage collection (relevant with several loader workers).
        with Image.open(img_path) as img:
            rgb = img.convert("RGB")
        x1, x2 = self.transform(rgb)
        return x1, x2


# ================================================================
# Instantiate dataset + loader
# ================================================================
# Each sample yields two 96x96 MoCo views of one local image.
ssl_transform = MoCoTransform(image_size=96)
ssl_dataset = LocalMoCoDataset(root_dir=LOCAL_DATA_DIR, transform=ssl_transform)

# batch_size is kept in sync by hand with config["batch_size"] (config is
# defined in a later cell). drop_last=True means every batch has exactly 128
# samples; the incomplete final batch is discarded.
ssl_loader = DataLoader(
    ssl_dataset,
    batch_size=128,
    shuffle=True,
    drop_last=True,
    num_workers=2,
    pin_memory=True,
)

print("‚úì Local SSL DataLoader created.")
print("Batches per epoch:", len(ssl_loader))


Local dataset: 30000 images found.
‚úì Local SSL DataLoader created.
Batches per epoch: 234


In [None]:
# # The dataset is hosted at: tsbpp/fall2025_deeplearning
# # It contains:
# #   (1) pretraining split: ~500,000 unlabeled images
# #   (2) eval_public split: labeled dataset for downstream classification
# #
# # We will:
# #   1. Install Hugging Face "datasets" library
# #   2. Load the unlabeled dataset directly (no manual download needed)
# #   3. Wrap it into our self-supervised MoCo dataset class
# #   4. Create the final DataLoader used for SSL training
# # ================================================================

# # ------------------------------------------------
# # 1. Install Hugging Face Datasets
# # ------------------------------------------------
# !pip install datasets --quiet

# from datasets import load_dataset
# from torchvision.transforms import functional as TF

# print("Hugging Face 'datasets' library installed.")

# # ------------------------------------------------
# # 2. Load the unlabeled pretraining dataset
# # ------------------------------------------------
# # The "train" split of this dataset corresponds to the pretrain/ folder.
# # These are ALL unlabeled images used for MoCo pretraining.
# print("Loading unlabeled 500k dataset from Hugging Face...")
# hf_pretrain = load_dataset("tsbpp/fall2025_deeplearning", split="train")

# print("Unlabeled dataset loaded!")
# print("Number of images:", len(hf_pretrain))
# # Each element is like: {"image": PIL.Image}

# # ------------------------------------------------
# # 3. Define a dataset wrapper compatible with our MoCo code
# # ------------------------------------------------
# class HFMoCoDataset(torch.utils.data.Dataset):
#     """
#     Wraps the Hugging Face dataset so that it:
#       - returns two augmented views of each image
#       - matches the expected format for the MoCo training step
#     """
#     def __init__(self, hf_dataset, transform):
#         self.ds = hf_dataset
#         self.transform = transform

#     def __len__(self):
#         return len(self.ds)

#     def __getitem__(self, idx):
#         # HF dataset returns a dict: {"image": PIL.Image}
#         img = self.ds[idx]["image"].convert("RGB")

#         # Return two MoCo-style augmentations
#         x1, x2 = self.transform(img)
#         return x1, x2


# # ------------------------------------------------
# # 4. Create SSL dataset and DataLoader
# # ------------------------------------------------
# print("Preparing MoCo-style dataset...")

# # Use your existing MoCo augmentations
# ssl_transform = MoCoTransform(image_size=96)

# ssl_dataset = HFMoCoDataset(hf_pretrain, ssl_transform)

# # IMPORTANT NOTES:
# #   - batch_size=128 works well for T4 GPUs
# #   - drop_last=True is recommended for MoCo training stability
# ssl_loader = torch.utils.data.DataLoader(
#     ssl_dataset,
#     batch_size=128,        # adjust later depending on GPU memory
#     shuffle=True,
#     num_workers=0,
#     pin_memory=True,
#     drop_last=True
# )

# print("SSL DataLoader ready.")
# print("Number of batches per epoch:", len(ssl_loader))


***

## STEP 11 — Configure Full SSL Pretraining (the “Real Training Loop”)

Goal: now that we have the MoCo-v2 architecture, the VICReg covariance loss, a working SSL dataloader, and a validated dummy training run, it's time to build the actual training loop that we will run on the full dataset.

This step will:
- configure training hyperparameters
- set the optimizer
- set the learning rate schedule
- prepare checkpointing
- create the full train_moco call for real training
- carefully tune parameters to run on Colab T4 / A100
- ensure the loop is safe, resumable, and stable

This step does NOT start training yet — it only defines and prepares everything.

Once Step 11 is ready and stable, Step 12 will actually launch the long pretraining run.

What the code does:
1. Define recommended hyperparameters (batch size, epochs, temperature τ,
EMA momentum m, queue size, λ for covariance)
2. Build the optimizer (SGD with momentum)
3. Add a cosine learning rate scheduler
4. Prepare checkpoint saving every N epochs
5. Bind all components together in a training configuration object
6. Leave you ready to run:

train_moco(model, queue, loss_fn, ssl_loader, optimizer, ...)


In [None]:
# Goal:
#   - Define all hyperparameters for the REAL MoCo-Cov training
#   - Check that config agrees with the already-built queue and DataLoader
#   - Create optimizer (SGD)
#   - Create cosine LR scheduler
#   - Prepare checkpointing
#   - DO NOT start training yet (training will be Step 12)

# This block sets up everything needed for long SSL pretraining.
# ================================================================

import torch
import torch.nn as nn
import torch.optim as optim
import math
from torch.optim.lr_scheduler import CosineAnnealingLR

print("Configuring MoCo-Cov training ...")

# --------------------------------------------------------
# 1. Hyperparameters for sanity-check SSL training
# --------------------------------------------------------
config = {
    "epochs": 20,             # short sanity check
    "batch_size": 128,        # must equal ssl_loader.batch_size (checked below)
    "lr": 0.03,               # MoCo-v2 reference LR for batch 256; NOT rescaled for batch 128
    "momentum": 0.9,          # SGD momentum
    "weight_decay": 1e-4,     # standard regularization for ResNet-50
    "tau": 0.2,               # InfoNCE temperature
    "m": 0.999,               # EMA momentum for encoder_k
    "lambda_cov": 1.0,        # strength of VICReg covariance loss
    "queue_size": 4096,       # must equal queue.size (checked below); typical MoCo uses 65536
    "save_every": 1,          # save a checkpoint every N epochs
    "checkpoint_path": "/content/moco_cov_checkpoint.pth"
}

# Print configuration
for k, v in config.items():
    print(f"{k}: {v}")

# The queue and DataLoader are built in earlier cells with hard-coded sizes, so
# editing config alone would not change them. Fail fast rather than train with
# values that silently differ from the printed configuration.
if queue.size != config["queue_size"]:
    raise ValueError(
        f"config['queue_size'] is {config['queue_size']} but the existing queue "
        f"has size {queue.size}; rebuild the queue or update config."
    )
if ssl_loader.batch_size != config["batch_size"]:
    raise ValueError(
        f"config['batch_size'] is {config['batch_size']} but ssl_loader uses "
        f"batch_size={ssl_loader.batch_size}; rebuild the loader or update config."
    )


# --------------------------------------------------------
# 2. Re-create optimizer (SGD)
# --------------------------------------------------------
# We assume 'model' already exists (encoder_q + proj head)

optimizer = optim.SGD(
    model.parameters(),
    lr=config["lr"],
    momentum=config["momentum"],
    weight_decay=config["weight_decay"],
)

print("\nOptimizer ready (SGD with momentum).")


# --------------------------------------------------------
# 3. Cosine learning rate schedule
# --------------------------------------------------------
# Cosine annealing is very standard for MoCo-v2:
# LR(t) = 0.5 * lr * (1 + cos(pi * t / T))
# T_max is in epochs, so the training loop must call scheduler.step() once per
# epoch (not once per batch).

scheduler = CosineAnnealingLR(
    optimizer,
    T_max=config["epochs"],   # one full cosine cycle over training
    eta_min=0.0               # final LR goes to zero
)

print("Cosine LR scheduler ready.")


# --------------------------------------------------------
# 4. Final print-out
# --------------------------------------------------------
print("\n*** Step 11 completed. ***")
print("You can now proceed to Step 12 ‚Äî Running full SSL pretraining.")
print("We will start training ONLY when you request it.")

Configuring MoCo-Cov training ...
epochs: 20
batch_size: 128
lr: 0.03
momentum: 0.9
weight_decay: 0.0001
tau: 0.2
m: 0.999
lambda_cov: 1.0
queue_size: 4096
save_every: 1
checkpoint_path: /content/moco_cov_checkpoint.pth

Optimizer ready (SGD with momentum).
Cosine LR scheduler ready.

*** Step 11 completed. ***
You can now proceed to Step 12 ‚Äî Running full SSL pretraining.
We will start training ONLY when you request it.


## STEP 12 — SSL pretraining with the MoCo-Cov architecture (currently run on the local 30k-image subset; the full 500k dataset later)

1. Each iteration receives a batch:
x1, x2 = augmented views of each image


These are strong MoCo-v2 augmentations.

2. Compute embeddings

We compute:

Query embedding (trainable)
q = model.encoder_q(x1)     # B × 128

Key embedding (momentum encoder)
k = model.encoder_k(x2)     # B × 128


These are both normalized to unit length.

3. Compute contrastive loss (InfoNCE)

We compute:

- Positive similarity → q·k
- Negative similarities → q·queue
- Temperature scaling
- Cross-entropy classification (positive is class 0)

This trains q to match its own positive (k) and be different from all negatives in the queue.

4. Compute covariance regularization (VICReg's covariance loss)

We concatenate q and k:

Z = [q; k]


Then compute covariance matrix C:

C = (Zᵀ Z) / (N − 1), where Z is first mean-centered per dimension and N = 2B is its number of rows


The covariance penalty is:

- Sum of squared off-diagonal elements of C only (penalizes redundancy between dimensions)
- Makes embeddings less correlated and more stable

5. Total Loss
loss = contrastive_loss + lambda_cov * covariance_loss

6. SGD update on encoder_q

Only encoder_q updates with gradients:

loss.backward()
optimizer.step()

7. Momentum update of encoder_k

This is MoCo's main trick:

param_k = m * param_k + (1-m) * param_q


This keeps encoder_k stable.

8. Add new keys to queue (FIFO)
queue.enqueue(k.detach())


This provides thousands of negative examples cheaply.

9. Cosine LR Scheduler

Learning rate follows a cosine decay; the scheduler is stepped once per epoch.

10. Logging + checkpoint saving

Every 100 steps → print loss
Every epoch → save model


In [None]:
# # Prep code for step 12
# # ================================================================
# # Function: contrastive_covariance_step
# # This performs one full MoCo-Cov training step:
# #   1. Forward pass (encoder_q and encoder_k)
# #   2. Compute contrastive InfoNCE loss
# #   3. Compute VICReg covariance regularization
# #   4. Momentum update for encoder_k
# #   5. Update the negative queue
# # ================================================================

# import torch
# import torch.nn.functional as F

# def contrastive_covariance_step(model, queue, loss_fn,
#                                 x1, x2, m, lambda_cov, temperature):

#     # ------------------------------------------------------------
#     # 1. Compute queries from encoder_q
#     # ------------------------------------------------------------
#     q = model.encoder_q(x1)         # (B, dim)
#     q = F.normalize(q, dim=1)

#     # ------------------------------------------------------------
#     # 2. Compute positive keys from encoder_k (EMA network)
#     # ------------------------------------------------------------
#     with torch.no_grad():
#         k = model.encoder_k(x2)
#         k = F.normalize(k, dim=1)

#     # ------------------------------------------------------------
#     # 3. Compute contrastive logits using the queue
#     # ------------------------------------------------------------
#     # Positive logit: (B, 1)
#     l_pos = torch.einsum('nc,nc->n', q, k).unsqueeze(-1)

#     # Negative logits: (B, K)
#     l_neg = torch.einsum('nc,kc->nk', q, queue.queue.clone().detach())

#     # Concatenate [positive | negatives]
#     logits = torch.cat([l_pos, l_neg], dim=1)

#     # Apply temperature
#     logits /= temperature

#     # Labels: positive key is index 0
#     labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

#     # Contrastive loss (InfoNCE)
#     loss_contrast = loss_fn(logits, labels)

#     # ------------------------------------------------------------
#     # 4. VICReg covariance regularizer
#     # ------------------------------------------------------------
#     # We compute covariance across embedding dimensions.
#     # Option: concat q and k (this is the typical batch-based approach)
#     Z = torch.cat([q, k], dim=0)             # shape: (2B, dim)
#     Z = Z - Z.mean(dim=0, keepdim=True)

#     # Covariance matrix: dim x dim
#     C = (Z.T @ Z) / (Z.shape[0] - 1)

#     # Penalize off-diagonal entries (reduce redundancy)
#     cov_loss = (C ** 2).sum() - (C.diag() ** 2).sum()

#     # Rescale covariance loss
#     loss_cov = lambda_cov * cov_loss

#     # ------------------------------------------------------------
#     # 5. Total loss
#     # ------------------------------------------------------------
#     loss = loss_contrast + loss_cov

#     # ------------------------------------------------------------
#     # 6. Momentum update of encoder_k
#     # ------------------------------------------------------------
#     with torch.no_grad():
#         for param_q, param_k in zip(model.encoder_q.parameters(),
#                                     model.encoder_k.parameters()):
#             param_k.data = param_k.data * m + param_q.data * (1.0 - m)

#     # ------------------------------------------------------------
#     # 7. Update the queue
#     # ------------------------------------------------------------
#     queue.enqueue(k)

#     return loss, float(loss_contrast.item()), float(loss_cov.item())

In [None]:
# # ================================================================
# # STEP 12 ‚Äî Run Sanity-Check SSL Pretraining (5 epochs)
# # Option B: Recommended logging (print every 500 steps)
# # ================================================================
# # What this block does:
# #   - runs full MoCo-Cov training on the real 500k dataset
# #   - prints logs every 500 training steps
# #   - saves checkpoint every epoch
# #   - uses the optimizer and scheduler from Step 11
# # ================================================================

# print("Starting SANITY-CHECK MoCo-Cov training on full 500k dataset...\n")

# num_epochs = config["epochs"]
# log_every = 500          # recommended logging frequency
# save_every = config["save_every"]

# for epoch in range(num_epochs):
#     model.train()
#     epoch_loss = 0.0
#     epoch_contrast = 0.0
#     epoch_cov = 0.0

#     print(f"\n==== Epoch {epoch+1}/{num_epochs} ====")

#     for step, (x1, x2) in enumerate(ssl_loader):

#         # ----------------------------------------------------------
#         # Move batch to GPU
#         # ----------------------------------------------------------
#         x1 = x1.cuda(non_blocking=True)
#         x2 = x2.cuda(non_blocking=True)

#         # ----------------------------------------------------------
#         # 1. Forward pass through MoCo-Cov step
#         # ----------------------------------------------------------
#         loss, loss_contrast, loss_cov = contrastive_covariance_step(
#             model=model,
#             queue=queue,
#             loss_fn=loss_fn,
#             x1=x1,
#             x2=x2,
#             m=config["m"],               # EMA momentum
#             lambda_cov=config["lambda_cov"],
#             temperature=config["tau"],
#         )

#         # ----------------------------------------------------------
#         # 2. Backprop + optimizer update
#         # ----------------------------------------------------------
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         # ----------------------------------------------------------
#         # 3. Update learning rate with cosine scheduler
#         # ----------------------------------------------------------
#         scheduler.step()

#         # ----------------------------------------------------------
#         # 4. Logging accumulator
#         # ----------------------------------------------------------
#         epoch_loss += loss.item()
#         epoch_contrast += loss_contrast
#         epoch_cov += loss_cov

#         # ----------------------------------------------------------
#         # 5. Print logs every 500 steps
#         # ----------------------------------------------------------
#         if (step + 1) % log_every == 0:
#             print(
#                 f"[Epoch {epoch+1}/{num_epochs}] "
#                 f"Step {step+1}/{len(ssl_loader)} | "
#                 f"Loss: {loss.item():.4f} | "
#                 f"Contrast: {loss_contrast:.4f} | "
#                 f"Cov: {loss_cov:.4f} | "
#                 f"LR: {scheduler.get_last_lr()[0]:.6f}"
#             )

#     # --------------------------------------------------------------
#     # 6. Epoch summary
#     # --------------------------------------------------------------
#     avg_loss = epoch_loss / len(ssl_loader)
#     avg_contrast = epoch_contrast / len(ssl_loader)
#     avg_cov = epoch_cov / len(ssl_loader)

#     print(
#         f"\n>>> Epoch {epoch+1} Summary: "
#         f"AvgLoss={avg_loss:.4f}, "
#         f"AvgContrast={avg_contrast:.4f}, "
#         f"AvgCov={avg_cov:.4f}"
#     )

#     # --------------------------------------------------------------
#     # 7. Save checkpoint
#     # --------------------------------------------------------------
#     if (epoch + 1) % save_every == 0:
#         torch.save(
#             {
#                 "epoch": epoch+1,
#                 "state_dict": model.state_dict(),
#                 "optimizer": optimizer.state_dict(),
#                 "scheduler": scheduler.state_dict(),
#                 "queue": queue.queue.clone(),
#             },
#             config["checkpoint_path"]
#         )
#         print(f"Checkpoint saved to {config['checkpoint_path']}")

# print("\n=== SANITY-CHECK TRAINING COMPLETE ===")

In [None]:
# # ================================================================
# # STEP ‚Äî Create dataset & dataloader for LOCAL 30k images
# # ================================================================

# from torch.utils.data import Dataset, DataLoader
# from PIL import Image
# import glob
# import os

# class LocalMoCoDataset(Dataset):
#     """
#     Loads images saved locally in Drive, e.g.
#         img_00000.jpg
#         img_00001.jpg
#         ...
#     Returns two augmented views for MoCo training.
#     """

#     def __init__(self, root_dir, transform):
#         self.root_dir = root_dir
#         self.transform = transform
#         self.files = sorted(glob.glob(os.path.join(root_dir, "*.jpg")))
#         print(f"Local dataset: {len(self.files)} images found.")

#     def __len__(self):
#         return len(self.files)

#     def __getitem__(self, idx):
#         img_path = self.files[idx]
#         img = Image.open(img_path).convert("RGB")
#         x1, x2 = self.transform(img)
#         return x1, x2


# # ================================================================
# # Instantiate dataset + loader
# # ================================================================
# local_dataset_dir = "/content/drive/MyDrive/moco_pretrain_30k"

# ssl_transform = MoCoTransform(image_size=96)

# ssl_dataset = LocalMoCoDataset(
#     root_dir=local_dataset_dir,
#     transform=ssl_transform
# )

# ssl_loader = DataLoader(
#     ssl_dataset,
#     batch_size=128,
#     shuffle=True,
#     num_workers=2,
#     pin_memory=True,
#     drop_last=True
# )

# print("‚úì Local SSL DataLoader created.")
# print("Batches per epoch:", len(ssl_loader))


In [None]:
# ================================================================
# STEP 12 — SANITY-CHECK SSL PRETRAINING WITH MoCo-RESNET50
# ================================================================
# This training loop:
#   - trains model.encoder_q with SGD and updates model.encoder_k as an EMA of it
#   - uses `queue` (128-d keys) as the bank of negatives
#   - computes InfoNCE + VICReg-style covariance loss inline
#   - trains for config["epochs"], stepping the cosine LR schedule once per epoch
#   - logs every `log_every` steps
#   - saves a checkpoint every config["save_every"] epochs and after the last one
# ================================================================

print(f"Starting SANITY-CHECK MoCo-Cov training on local {len(ssl_dataset)}-image subset of the data...\n")

# NOTE(review): loss_fn is constructed but NOT used by the loop below, which
# computes the loss inline. It is kept so the `loss_fn` name stays available to
# other cells; confirm whether it should replace the inline loss.
loss_fn = MoCoCovLoss(
    temperature=config["tau"],
    lambda_cov=config["lambda_cov"],
    use_queue_for_cov=False            # best for stability
).cuda()

num_epochs = config["epochs"]
log_every = 100
save_every = config["save_every"]
m = config["m"]                        # EMA momentum for encoder_k
num_batches = len(ssl_loader)

for epoch in range(num_epochs):

    model.train()
    epoch_loss = 0.0
    epoch_contrast = 0.0
    epoch_cov = 0.0

    print(f"\n==== Epoch {epoch+1}/{num_epochs} ====")

    for step, (x1, x2) in enumerate(ssl_loader):

        # Move batch to GPU
        x1 = x1.cuda(non_blocking=True)
        x2 = x2.cuda(non_blocking=True)

        # 1. Embeddings: queries carry gradients; keys come from the EMA encoder.
        q = F.normalize(model.encoder_q(x1), dim=1)
        with torch.no_grad():
            k = F.normalize(model.encoder_k(x2), dim=1)

        # 2. InfoNCE: the positive (q·k) is class 0, queue entries are negatives.
        l_pos = torch.einsum('nc,nc->n', q, k).unsqueeze(1)                   # (B, 1)
        l_neg = torch.einsum('nc,kc->nk', q, queue.queue.clone().detach())    # (B, K)
        logits = torch.cat([l_pos, l_neg], dim=1) / config["tau"]
        labels = torch.zeros(q.shape[0], dtype=torch.long, device=q.device)
        loss_contrast = F.cross_entropy(logits, labels)

        # 3. Covariance penalty: squared off-diagonal entries of the covariance
        #    of the mean-centered [q; k] embeddings.
        Z = torch.cat([q, k], dim=0)
        Z = Z - Z.mean(dim=0, keepdim=True)
        C = (Z.T @ Z) / (Z.size(0) - 1)
        cov_loss = (C ** 2).sum() - (C.diag() ** 2).sum()
        loss_cov = config["lambda_cov"] * cov_loss

        # 4. Total loss
        loss = loss_contrast + loss_cov

        # 5. Backprop: gradients flow only through q (k was computed under no_grad)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 6. Momentum (EMA) update of encoder_k, done in place:
        #    param_k = m * param_k + (1 - m) * param_q
        with torch.no_grad():
            for param_q, param_k in zip(model.encoder_q.parameters(), model.encoder_k.parameters()):
                param_k.mul_(m).add_(param_q, alpha=1 - m)

        # 7. Enqueue this batch's keys (after backward, which used the old queue)
        queue.enqueue(k.detach())

        # 8. Track metrics
        epoch_loss += loss.item()
        epoch_contrast += loss_contrast.item()
        epoch_cov += loss_cov.item()

        # 9. Logging every `log_every` steps
        if (step + 1) % log_every == 0:
            print(
                f"[Epoch {epoch+1}/{num_epochs}] "
                f"Step {step+1}/{num_batches} | "
                f"Loss: {loss.item():.4f} | "
                f"Contrast: {loss_contrast.item():.4f} | "
                f"Cov: {loss_cov.item():.4f} | "
                f"LR: {scheduler.get_last_lr()[0]:.6f}"
            )

    # ==================================================
    # Epoch summary
    # ==================================================
    print(
        f"\n>>> Epoch {epoch+1} Summary: "
        f"AvgLoss={epoch_loss/num_batches:.4f}, "
        f"AvgContrast={epoch_contrast/num_batches:.4f}, "
        f"AvgCov={epoch_cov/num_batches:.4f}"
    )
    # The cosine schedule's T_max is config["epochs"], so step it ONCE per epoch.
    scheduler.step()

    # ==================================================
    # Save checkpoint every `save_every` epochs and always after the final epoch.
    # The queue pointer is stored too, so a resumed run continues filling the
    # queue where it left off.
    # ==================================================
    if (epoch + 1) % save_every == 0 or epoch + 1 == num_epochs:
        torch.save(
            {
                "epoch": epoch+1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "queue": queue.queue.clone(),
                "queue_ptr": queue.ptr.clone(),
            },
            config["checkpoint_path"]
        )
        print(f"Checkpoint saved to {config['checkpoint_path']}")

print("\n=== SANITY-CHECK TRAINING COMPLETE ===")


Starting SANITY-CHECK MoCo-Cov training on local 15k subset of the data...


==== Epoch 1/20 ====
[Epoch 1/20] Step 100/234 | Loss: 8.1277 | Contrast: 8.0951 | Cov: 0.0326 | LR: 0.030000
[Epoch 1/20] Step 200/234 | Loss: 8.2089 | Contrast: 8.1866 | Cov: 0.0223 | LR: 0.030000

>>> Epoch 1 Summary: AvgLoss=8.0504, AvgContrast=8.0313, AvgCov=0.0191
Checkpoint saved to /content/moco_cov_checkpoint.pth

==== Epoch 2/20 ====
[Epoch 2/20] Step 100/234 | Loss: 8.2489 | Contrast: 8.2340 | Cov: 0.0149 | LR: 0.029815
[Epoch 2/20] Step 200/234 | Loss: 8.2528 | Contrast: 8.2403 | Cov: 0.0125 | LR: 0.029815

>>> Epoch 2 Summary: AvgLoss=8.2503, AvgContrast=8.2356, AvgCov=0.0148
Checkpoint saved to /content/moco_cov_checkpoint.pth

==== Epoch 3/20 ====
[Epoch 3/20] Step 100/234 | Loss: 8.2534 | Contrast: 8.2426 | Cov: 0.0107 | LR: 0.029266
[Epoch 3/20] Step 200/234 | Loss: 8.2216 | Contrast: 8.2098 | Cov: 0.0117 | LR: 0.029266

>>> Epoch 3 Summary: AvgLoss=8.2585, AvgContrast=8.2470, AvgCov=0.0114
Ch

## Step. CIFAR-10 Loader

What we do here?

We will use only simple transforms:
- Resize to 96×96
- Convert to tensor
- Normalize using ImageNet stats


In [None]:
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

# Evaluation-time preprocessing only (no SSL augmentations): resize to the
# 96x96 resolution used for pretraining, then normalize with ImageNet mean/std.
cifar_transform = T.Compose([
    T.Resize((96, 96)),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# CIFAR-10 train split first, then test split, both stored under /content/cifar.
cifar_train, cifar_test = (
    torchvision.datasets.CIFAR10(
        root="/content/cifar",
        train=is_train,
        download=True,
        transform=cifar_transform,
    )
    for is_train in (True, False)
)

# shuffle=False keeps a fixed order, so extracted features line up with dataset
# indices. NOTE: the CIFAR-100 cell assigns these same variable names; whichever
# cell ran last is what the feature-extraction cell embeds.
cifar_train_loader, cifar_test_loader = (
    DataLoader(split, batch_size=256, shuffle=False, num_workers=2, pin_memory=True)
    for split in (cifar_train, cifar_test)
)

print("CIFAR-10 loaded successfully!")
print("Train samples:", len(cifar_train))
print("Test samples:", len(cifar_test))





100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 170M/170M [00:19<00:00, 8.65MB/s]


CIFAR-10 loaded successfully!
Train samples: 50000
Test samples: 10000


## CIFAR-100 Loader

In [None]:
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

# Evaluation-time preprocessing only (no SSL augmentations): resize to the
# 96x96 resolution used for pretraining, then normalize with ImageNet mean/std.
cifar_transform = T.Compose([
    T.Resize((96, 96)),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# CIFAR-100 train split first, then test split, both stored under /content/cifar100.
cifar_train, cifar_test = (
    torchvision.datasets.CIFAR100(
        root="/content/cifar100",
        train=is_train,
        download=True,
        transform=cifar_transform,
    )
    for is_train in (True, False)
)

# shuffle=False keeps a fixed order, so extracted features line up with dataset
# indices. NOTE: this reuses the variable names of the CIFAR-10 cell and
# overwrites them; the feature-extraction cell embeds whichever ran last.
cifar_train_loader, cifar_test_loader = (
    DataLoader(split, batch_size=256, shuffle=False, num_workers=2, pin_memory=True)
    for split in (cifar_train, cifar_test)
)

print("CIFAR-100 loaded successfully!")
print("Train samples:", len(cifar_train))
print("Test samples:", len(cifar_test))


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 169M/169M [00:13<00:00, 12.6MB/s]


CIFAR-100 loaded successfully!
Train samples: 50000
Test samples: 10000


## Step: Feature Extraction Function

What this function does:
- It puts the encoder in eval() mode
- Disables gradient tracking (torch.no_grad), so no parameters are updated
- Processes CIFAR batches
- Computes the 128-dim MoCo embedding for each image (L2-normalized)
- Returns two CPU tensors:
1) features → shape [N, 128]
2) labels → shape [N]


In [None]:
import torch
import torch.nn.functional as F

@torch.no_grad()
def extract_features(encoder_q, dataloader, device="cuda"):
    """
    Embed every batch of ``dataloader`` with ``encoder_q`` without tracking gradients.

    Args:
        encoder_q: module mapping a batch of images to (B, D) embeddings.
        dataloader: iterable of ``(images, labels)`` batches.
        device: device the images are moved to before the forward pass; it
            should be the device ``encoder_q`` lives on.

    Returns:
        features: Tensor of shape [N, D] on CPU, each row L2-normalized
                  (D = 128 for the MoCo encoder in this notebook).
        labels:   Tensor of shape [N] on CPU.

    The encoder is run in eval() mode, and its previous train/eval mode is
    restored before returning (even if an exception is raised).

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    was_training = encoder_q.training
    encoder_q.eval()
    features_list = []
    labels_list = []

    try:
        for images, labels in dataloader:
            images = images.to(device, non_blocking=True)

            # Forward pass through frozen encoder
            feats = encoder_q(images)

            # Re-normalizing already unit-length outputs is harmless, so this is
            # safe whether or not encoder_q normalizes internally.
            feats = F.normalize(feats, dim=1)

            features_list.append(feats.cpu())
            labels_list.append(labels.cpu())
    finally:
        # Don't leave the caller's model stuck in eval mode.
        encoder_q.train(was_training)

    if not features_list:
        raise ValueError("dataloader yielded no batches; nothing to extract")

    # Concatenate all batches into big tensors
    all_features = torch.cat(features_list, dim=0)
    all_labels = torch.cat(labels_list, dim=0)

    print("Finished extracting features.")
    print("Feature tensor shape:", all_features.shape)
    print("Labels tensor shape:", all_labels.shape)

    return all_features, all_labels


In [None]:
# Embed both CIFAR splits with the frozen query encoder. The notebook-wide
# `device` (chosen during environment setup, falls back to CPU) is passed
# explicitly because extract_features defaults to "cuda".
train_feats, train_labels = extract_features(
    model.encoder_q,
    cifar_train_loader,
    device=device,
)

test_feats, test_labels = extract_features(
    model.encoder_q,
    cifar_test_loader,
    device=device,
)


Finished extracting features.
Feature tensor shape: torch.Size([50000, 128])
Labels tensor shape: torch.Size([50000])
Finished extracting features.
Feature tensor shape: torch.Size([10000, 128])
Labels tensor shape: torch.Size([10000])


## STEP 3 — k-NN Classifier for SSL Evaluation

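This implementation uses:
- cosine similarity (dot products of L2-normalized features)
- top-k = 20 neighbors (passed explicitly in the evaluation call below)
- similarity weighting: each neighbor votes with weight exp(similarity / temperature), the weighted k-NN variant commonly used in SSL papers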

In [None]:
import torch
import torch.nn.functional as F

def knn_classifier(train_feats, train_labels, test_feats, test_labels, k=10, temperature=0.1, batch_size=100):
    """
    Similarity-weighted k-NN classifier for evaluating frozen SSL features.

    Features are L2-normalized, so dot products are cosine similarities. Each
    test sample takes its ``k`` most similar training samples. Every neighbor
    votes for its label with weight ``exp(similarity / temperature)``. The
    class with the largest total vote is the prediction.

    Args:
        train_feats: [N_train, D] tensor
        train_labels: [N_train] integer tensor of class indices
        test_feats: [N_test, D] tensor
        test_labels: [N_test] integer tensor of class indices
        k: number of neighbors (default 10; clamped to N_train). The
            evaluation cell in this notebook passes k=20.
        temperature: temperature applied to similarities before exponentiation
        batch_size: number of test samples scored per similarity matmul
    Returns:
        top-1 accuracy (%) as a float
    Raises:
        ValueError: if either feature set is empty or ``k < 1``.
    """
    num_test = test_feats.size(0)
    num_train = train_feats.size(0)
    if num_test == 0:
        raise ValueError("test_feats is empty")
    if num_train == 0:
        raise ValueError("train_feats is empty")
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    # topk() fails when k exceeds the number of candidates.
    k = min(k, num_train)

    # Normalize features (important for cosine similarity).
    train_feats = F.normalize(train_feats, dim=1)
    test_feats = F.normalize(test_feats, dim=1)

    # scatter_add_ needs int64 indices; the class count is loop-invariant.
    train_labels = train_labels.long()
    num_classes = int(train_labels.max().item()) + 1

    correct = 0

    print("Running k-NN evaluation...")

    for start in range(0, num_test, batch_size):
        end = min(start + batch_size, num_test)
        batch = test_feats[start:end]  # [B, D]

        # Cosine similarity: [B, D] x [D, N_train] -> [B, N_train]
        sim = batch @ train_feats.t()

        # Top-k neighbors per test sample and their labels: [B, k]
        sim_val, sim_idx = sim.topk(k=k, dim=1)
        neighbor_labels = train_labels[sim_idx]

        # Unnormalized exp weights. A per-row softmax normalization would not
        # change the argmax, so it is skipped.
        weights = torch.exp(sim_val / temperature)

        # Accumulate each neighbor's weight into its class column: [B, C]
        class_scores = torch.zeros(
            batch.size(0), num_classes, dtype=weights.dtype, device=weights.device
        )
        class_scores.scatter_add_(1, neighbor_labels, weights)

        preds = class_scores.argmax(dim=1)
        correct += (preds.to(test_labels.device) == test_labels[start:end]).sum().item()

    acc = 100.0 * correct / num_test
    print(f"k-NN accuracy (k={k}): {acc:.2f}%")
    return acc


Running the k-NN evaluation

In [None]:
# Weighted k-NN evaluation of the frozen features (20 neighbors, temperature 0.1).
knn_acc = knn_classifier(train_feats, train_labels, test_feats, test_labels, k=20, temperature=0.1)

Running k-NN evaluation...
k-NN accuracy (k=20): 12.82%


## STEP — Linear Probe on Frozen Features
This block does the following.
It uses the CIFAR features already extracted above:
- train_feats (50k × 128)
- train_labels
- test_feats
- test_labels

Train a small linear classifier:
- 128 → number of classes (10 for CIFAR-10, 100 for CIFAR-100)
- No hidden layers
- No dropout
- This is the standard evaluation protocol in SSL papers.
- Train for 50 epochs (quick and enough for CIFAR)

Optimization:
- CrossEntropyLoss
- Adam with a cosine learning-rate schedule (SGD with momentum is a common alternative)
- Batch size = 1024 (fast — the features fit into RAM)
- Shuffle training batches

In [None]:
def _train_linear_epoch(classifier, loader, optimizer, criterion, device):
    """Run one optimization pass over ``loader`` and return the summed batch losses."""
    classifier.train()
    total_loss = 0
    for feats, labels in loader:
        feats = feats.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(classifier(feats), labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss


@torch.no_grad()
def _linear_accuracy(classifier, loader, device):
    """Return the top-1 accuracy (%) of ``classifier`` over ``loader``."""
    classifier.eval()
    correct = 0
    total = 0
    for feats, labels in loader:
        feats = feats.to(device)
        labels = labels.to(device)
        preds = classifier(feats).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return 100 * correct / total


def linear_probe(
    train_feats, train_labels,
    test_feats, test_labels,
    num_epochs=50,
    batch_size=1024,
    lr=1e-3,
    device="cuda",
    num_classes=None,
    standardize=False,
):
    """
    Train a single linear layer on frozen features and report test accuracy.

    Uses Adam (weight decay 1e-4) with a cosine learning-rate schedule over
    ``num_epochs`` and cross-entropy loss. Test accuracy is evaluated and
    printed after every epoch. The printed loss is the sum of batch losses
    for that epoch. SGD with momentum is a common alternative optimizer for
    linear probes.

    Args:
        train_feats, train_labels: [N_train, D] features and [N_train] labels.
        test_feats, test_labels: [N_test, D] features and [N_test] labels.
        num_epochs: number of training epochs (>= 1).
        batch_size: mini-batch size for both training and evaluation.
        lr: Adam learning rate. The default 1e-3 matches the value the
            previous implementation always used.
        device: device the classifier and batches are placed on.
        num_classes: output dimension. If None, inferred as
            ``max(label) + 1`` over train and test labels (10 for CIFAR-10,
            100 for CIFAR-100).
        standardize: if True, z-score each feature dimension using train-set
            mean/std before training. NOTE(review): the logged loss in this
            notebook stays near chance level (~ln(100) per batch). Conditioning
            the inputs this way is worth trying when normalized features have
            low per-dimension variance.
    Returns:
        Test accuracy (%) after the final epoch.
    Raises:
        ValueError: if ``num_epochs < 1`` or the test set is empty.
    """
    from torch.utils.data import DataLoader, TensorDataset

    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")
    if test_feats.size(0) == 0:
        raise ValueError("test_feats is empty")

    if num_classes is None:
        # Infer from the labels so the same function serves CIFAR-10 and CIFAR-100.
        num_classes = int(max(train_labels.max().item(), test_labels.max().item())) + 1

    if standardize:
        # Statistics come from the training split only, to avoid test leakage.
        mean = train_feats.mean(dim=0, keepdim=True)
        std = (train_feats - mean).pow(2).mean(dim=0, keepdim=True).sqrt().clamp_min(1e-6)
        train_feats = (train_feats - mean) / std
        test_feats = (test_feats - mean) / std

    train_loader = DataLoader(TensorDataset(train_feats, train_labels), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(TensorDataset(test_feats, test_labels), batch_size=batch_size, shuffle=False)

    classifier = nn.Linear(train_feats.size(1), num_classes).to(device)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    criterion = nn.CrossEntropyLoss()

    acc = 0.0
    for epoch in range(num_epochs):
        total_loss = _train_linear_epoch(classifier, train_loader, optimizer, criterion, device)
        acc = _linear_accuracy(classifier, test_loader, device)
        print(f"Epoch {epoch+1}/{num_epochs} | Loss: {total_loss:.4f} | Test Acc: {acc:.2f}%")
        scheduler.step()

    return acc


Running the linear probe

In [None]:
# Linear probe on the frozen features. The notebook-wide `device` (chosen
# during environment setup, falls back to CPU) is passed explicitly because
# linear_probe defaults to "cuda".
lp_acc = linear_probe(
    train_feats, train_labels,
    test_feats, test_labels,
    num_epochs=50,
    batch_size=1024,
    lr=1e-3,
    device=device,
)

Epoch 1/50 | Loss: 225.7438 | Test Acc: 1.00%
Epoch 2/50 | Loss: 225.6652 | Test Acc: 1.05%
Epoch 3/50 | Loss: 225.6352 | Test Acc: 1.54%
Epoch 4/50 | Loss: 225.6112 | Test Acc: 1.99%
Epoch 5/50 | Loss: 225.5855 | Test Acc: 1.62%
Epoch 6/50 | Loss: 225.5641 | Test Acc: 1.59%
Epoch 7/50 | Loss: 225.5410 | Test Acc: 1.60%
Epoch 8/50 | Loss: 225.5192 | Test Acc: 1.88%
Epoch 9/50 | Loss: 225.5122 | Test Acc: 2.27%
Epoch 10/50 | Loss: 225.4850 | Test Acc: 1.87%
Epoch 11/50 | Loss: 225.4779 | Test Acc: 2.22%
Epoch 12/50 | Loss: 225.4595 | Test Acc: 2.02%
Epoch 13/50 | Loss: 225.4494 | Test Acc: 2.18%
Epoch 14/50 | Loss: 225.4356 | Test Acc: 2.67%
Epoch 15/50 | Loss: 225.4223 | Test Acc: 2.59%
Epoch 16/50 | Loss: 225.4124 | Test Acc: 2.78%
Epoch 17/50 | Loss: 225.3978 | Test Acc: 2.85%
Epoch 18/50 | Loss: 225.3915 | Test Acc: 2.65%
Epoch 19/50 | Loss: 225.3806 | Test Acc: 2.26%
Epoch 20/50 | Loss: 225.3695 | Test Acc: 2.84%
Epoch 21/50 | Loss: 225.3616 | Test Acc: 2.97%
Epoch 22/50 | Loss: 22