In [1]:
import random
import torch
from torchvision import transforms
from torchvision.datasets import STL10
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [2]:
#Getting dataset

root = "./data"

# Build the three STL-10 splits in a fixed order (train, unlabeled, test).
# No transform is attached, so every split yields raw PIL images.
train_set, unlabeled_set, test_set = (
    STL10(root=root, split=split_name, download=True, transform=None)
    for split_name in ("train", "unlabeled", "test")
)

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./data/stl10_binary.tar.gz


  8%|▊         | 200M/2.64G [00:12<02:36, 15.6MB/s]


KeyboardInterrupt: 

In [None]:
def show_images(images, titles=None, nrows=2, ncols=5, figsize=(12, 6)):
    """
    Display a grid of images.

    Images fill the grid row by row. Unused subplots are hidden, and images
    beyond ``nrows * ncols`` are not drawn (a warning is printed instead).

    Args:
        images (list): List of images to display.
        titles (list): Optional list of titles; may be shorter than ``images``,
            in which case the remaining images are shown without a title.
        nrows (int): Number of rows in the grid.
        ncols (int): Number of columns in the grid.
        figsize (tuple): Size of the figure.
    """
    # squeeze=False always returns a 2D array of Axes, so a 1x1 grid (which
    # would otherwise come back as a bare Axes object) can be flattened too.
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    num_images = len(images)
    num_slots = nrows * ncols
    if num_images < num_slots:
        print(f"Warning: Only {num_images} images provided, but {nrows * ncols} subplots created.")
    elif num_images > num_slots:
        print(f"Warning: {num_images} images provided, but only {num_slots} subplots created; extra images are not shown.")

    num_titles = len(titles) if titles else 0

    for i, ax in enumerate(axes):
        if i < num_images:
            ax.imshow(images[i])
            # Guard against a titles list that is shorter than the image list
            if i < num_titles:
                ax.set_title(titles[i])
        ax.axis('off')  # Hide ticks on images and blank out empty subplots

    plt.tight_layout()
    plt.show()
torch.manual_seed(42)

# Display a 2x5 grid of the first ten training images, titled with their class names.
train_images = [image for image, _ in (train_set[idx] for idx in range(10))]
train_labels = [
    train_set.classes[target]
    for _, target in (train_set[idx] for idx in range(10))
]
show_images(train_images, titles=train_labels, nrows=2, ncols=5)

In [None]:
# Display a 2x5 grid of the first ten images of the unlabeled split (no titles).
torch.manual_seed(42)
unlabeled_images = [image for image, _ in (unlabeled_set[idx] for idx in range(10))]
show_images(unlabeled_images, nrows=2, ncols=5)

In [None]:
import numpy as np
def compute_mean_std(dataset):
    """
    Compute the per-channel mean and standard deviation of a dataset.

    Statistics are taken over every pixel of every image. Sums are
    accumulated in float64 so that the ``E[x^2] - E[x]^2`` variance formula
    does not lose precision on large datasets.

    Args:
        dataset: Iterable of ``(image, label)`` pairs (e.g., STL10). Images may
            be PIL images / arrays accepted by ``transforms.ToTensor`` or
            tensors already laid out as ``[channels, ...]``.
    Returns:
        mean (list): Mean for each channel (R, G, B for STL10).
        std (list): Standard deviation for each channel (R, G, B for STL10).
    Raises:
        ValueError: If the dataset contains no pixels.
    """
    to_tensor = None        # Created lazily; only needed for non-tensor images
    channel_sum = None      # Running sum of pixel values per channel
    channel_sq_sum = None   # Running sum of squared pixel values per channel
    total_pixels = 0

    for image, _ in dataset:
        if not torch.is_tensor(image):
            if to_tensor is None:
                to_tensor = transforms.ToTensor()
            image = to_tensor(image)  # Convert PIL Image to a [C, H, W] tensor

        # Flatten to [channels, pixels] and accumulate on CPU in double precision
        pixels = image.detach().reshape(image.shape[0], -1).to(device="cpu", dtype=torch.float64)

        if channel_sum is None:
            # The channel count is taken from the first image instead of being hard-coded
            channel_sum = torch.zeros(pixels.shape[0], dtype=torch.float64)
            channel_sq_sum = torch.zeros_like(channel_sum)

        channel_sum += pixels.sum(dim=1)
        channel_sq_sum += (pixels ** 2).sum(dim=1)
        total_pixels += pixels.shape[1]

    if total_pixels == 0:
        raise ValueError("Cannot compute mean/std of an empty dataset.")

    mean = channel_sum / total_pixels
    # Rounding can push the variance marginally below zero for (near-)constant
    # data; clamp so the square root never yields NaN.
    variance = (channel_sq_sum / total_pixels - mean ** 2).clamp(min=0.0)
    std = torch.sqrt(variance)

    return mean.tolist(), std.tolist()

# Per-channel statistics of the labeled training split; `mean` and `std` are
# passed to the augmentation pipeline further down for normalization.
mean, std = compute_mean_std(train_set)

print(f"Mean: {mean}")
print(f"Std: {std}")

In [None]:
class STL10Augmentation:
    """Random augmentation pipeline producing one contrastive view of an STL-10 image."""

    def __init__(self, image_size=96, mean=None, std=None):
        """
        Augmentation pipeline for STL-10 transformations.
        Args:
            image_size (int): Size of the output image (default: 96x96).
            mean (list): Mean for each channel (R, G, B). If omitted together
                with ``std``, the output is left un-normalized.
            std (list): Standard deviation for each channel (R, G, B).
        Raises:
            ValueError: If only one of ``mean`` / ``std`` is given.
        """
        if (mean is None) != (std is None):
            raise ValueError("mean and std must be provided together (or both omitted).")

        steps = [
            transforms.RandomResizedCrop(image_size, scale=(0.2, 1.0)),  # Random crop and resize
            transforms.RandomHorizontalFlip(p=0.5),  # Random horizontal flip
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),  # Color jitter
            # NOTE(review): kernel_size=23 looks inherited from 224px ImageNet recipes
            # (~10% of the image side); ~9 would be proportional for 96px - TODO confirm.
            transforms.RandomApply([transforms.GaussianBlur(kernel_size=23)], p=0.5),  # Gaussian blur
            transforms.RandomGrayscale(p=0.2),  # Random grayscale
            transforms.ToTensor(),  # Convert PIL Image to tensor
        ]
        # Normalize(mean=None, std=None) can be constructed but fails on its
        # first call, so the step is only added when statistics are available.
        if mean is not None:
            steps.append(transforms.Normalize(mean=mean, std=std))  # Normalize

        self.augment = transforms.Compose(steps)

    def __call__(self, image):
        """
        Apply augmentations to the input image.
        Args:
            image (PIL.Image): Input image.
        Returns:
            torch.Tensor: Augmented (and, if statistics were given, normalized) image.
        """
        return self.augment(image)

In [None]:
# Augmentation pipeline instance, normalized with the `mean`/`std` computed from the training split.
augmentation = STL10Augmentation(mean=mean, std=std)

### Defining the MoCo Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

In [None]:
class MoCo(nn.Module):
    """Momentum Contrast model: a query encoder, a momentum key encoder and a key queue."""

    def __init__(self, base_encoder, dim=128, K=32768, m=0.99, T=0.07, mlp=False):
        """
        MoCo model for contrastive learning.
        Args:
            base_encoder: Callable building the encoder network; it is called
                as ``base_encoder(num_classes=dim)`` (e.g. a ResNet factory).
            dim (int): Dimension of the output feature vector.
            K (int): Size of the queue (number of negative samples).
            m (float): Momentum for updating the key encoder.
            T (float): Temperature for the contrastive loss.
            mlp (bool): If True, replace each encoder's linear ``fc`` head with
                a two-layer MLP head (Linear -> ReLU -> original ``fc``).
        Raises:
            ValueError: If ``mlp=True`` but the encoder has no ``nn.Linear`` ``fc`` head.
        """
        super().__init__()

        self.K = K
        self.m = m
        self.T = T

        # Create the query and key encoders
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        if mlp:
            # Must happen before the weight copy below so both heads start identical
            self._add_mlp_head(self.encoder_q)
            self._add_mlp_head(self.encoder_k)

        # Initialize the key encoder with the query encoder weights and freeze it;
        # it is only ever changed through the momentum update.
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        # Queue of negative keys, stored column-wise as unit vectors: [dim, K]
        self.register_buffer("queue", F.normalize(torch.randn(dim, K), dim=0))
        # Index of the next queue column to overwrite
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @staticmethod
    def _add_mlp_head(encoder):
        """Wrap ``encoder.fc`` into Linear -> ReLU -> fc, keeping the output dimension."""
        head = getattr(encoder, "fc", None)
        if not isinstance(head, nn.Linear):
            raise ValueError("mlp=True requires the base encoder to expose an nn.Linear `fc` head.")
        hidden = head.in_features
        encoder.fc = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), head)

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder:
        ``param_k = m * param_k + (1 - m) * param_q``.
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.mul_(self.m).add_(param_q.data, alpha=1.0 - self.m)

    @torch.no_grad()
    def dequeue_and_enqueue(self, keys):
        """
        Overwrite the oldest queue entries with ``keys`` and advance the pointer.

        The queue is circular, so a batch may wrap around its end. If the batch
        is larger than the queue, only the most recent ``K`` keys are kept; the
        result matches writing every key one after another.

        Args:
            keys (torch.Tensor): Key features of shape ``[batch, dim]``.
        """
        batch_size = keys.shape[0]
        if batch_size == 0:
            return
        ptr = int(self.queue_ptr)

        if batch_size > self.K:
            # Earlier keys would be overwritten within this same call anyway
            skipped = batch_size - self.K
            keys = keys[skipped:]
            ptr = (ptr + skipped) % self.K
            batch_size = self.K

        # Modular column indices handle the wrap-around in a single assignment
        columns = (ptr + torch.arange(batch_size, device=self.queue.device)) % self.K
        self.queue[:, columns] = keys.T

        self.queue_ptr[0] = (ptr + batch_size) % self.K

    def forward(self, im_q, im_k):
        """
        Forward pass for MoCo.

        Args:
            im_q (torch.Tensor): Batch of query images.
            im_k (torch.Tensor): Batch of key images (another view of the same samples).
        Returns:
            logits (torch.Tensor): ``[batch, 1 + K]`` similarities divided by ``T``;
                column 0 is the positive pair, the rest are queue negatives.
            labels (torch.Tensor): ``[batch]`` zeros, i.e. the positive's column index.
        """
        # Query features
        q = self.encoder_q(im_q)
        q = F.normalize(q, dim=1)

        # Key features: no gradient flows to the key encoder
        with torch.no_grad():
            self._momentum_update_key_encoder()
            k = self.encoder_k(im_k)
            k = F.normalize(k, dim=1)

        # Compute logits
        logit_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
        logit_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])
        logits = torch.cat([logit_pos, logit_neg], dim=1) / self.T  # Apply temperature

        # Labels: the first column (positive key) is the ground truth
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

        # Update the queue only after the logits were computed against the old negatives
        self.dequeue_and_enqueue(k)

        return logits, labels

### Defining the Base Encoder

In [None]:
def resnet18(num_classes=128):
    """
    Build a randomly initialised ResNet-18 whose final layer emits ``num_classes`` values.

    Args:
        num_classes (int): Dimension of the output feature vector.
    Returns:
        The torchvision ResNet-18 with its ``fc`` layer replaced.
    """
    backbone = models.resnet18(weights=None)
    feature_dim = backbone.fc.in_features
    # Swap the stock classification head for one with the requested width
    backbone.fc = nn.Linear(feature_dim, num_classes)
    return backbone

### Defining the Contrastive Loss, InfoNCE

In [None]:
class InfoNCE(nn.Module):
    """InfoNCE contrastive loss, computed as cross-entropy over the given logits."""

    def __init__(self):
        super().__init__()

    def forward(self, logits, labels):
        """
        Args:
            logits: Similarity scores, one row per sample.
            labels: Index of the positive column for each row.
        Returns:
            Cross-entropy loss between ``logits`` and ``labels``.
        """
        loss = F.cross_entropy(input=logits, target=labels)
        return loss


### Training the MoCo Model

In [None]:
from torch.utils.data import DataLoader
from torch.optim import Adam
from tqdm import tqdm

# Number of samples per batch for all three DataLoaders below
batch_size = 256
# Number of passes over the unlabeled set during MoCo pre-training
num_epochs = 5
# Learning rate handed to the Adam optimizer below
# NOTE(review): 0.03 looks like an SGD-style value; it may be too high for Adam - TODO confirm
learning_rate = 0.03
# Passed to MoCo as `m` (key-encoder momentum)
momentum = 0.99
# Passed to MoCo as `T` (logit temperature)
temperature = 0.07
# Passed to MoCo as `K` (number of queued negative keys)
queue_size = 32768

# to_tensor_transform = transforms.ToTensor()

# Step 8: Define the DataLoader with Two Augmented Views
def collate_fn(batch):
    """
    Turn a batch of ``(image, label)`` samples into two augmented views.

    Labels are discarded. All query views are generated first, then all key
    views, each by an independent call to the module-level ``augmentation``.

    Returns:
        (im_q, im_k): Two stacked tensors, one augmented view per input image each.
    """
    pil_images = [sample[0] for sample in batch]

    def make_view():
        # One freshly augmented tensor per image, stacked along a new batch axis
        return torch.stack([augmentation(picture) for picture in pil_images])

    im_q = make_view()  # First augmented view
    im_k = make_view()  # Second augmented view
    return im_q, im_k

# Creating the dataloaders.
# The datasets are built with transform=None and therefore yield PIL images,
# which the default collate function cannot batch. The labeled splits get a
# deterministic ToTensor + Normalize collate; the unlabeled split keeps the
# two-view augmentation collate.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])


def labeled_collate_fn(batch):
    """
    Collate ``(PIL image, int label)`` samples into normalized tensors.

    Returns:
        (images, targets): Stacked image tensor and a ``torch.long`` label tensor.
    """
    images = torch.stack([eval_transform(image) for image, _ in batch])
    targets = torch.tensor([target for _, target in batch], dtype=torch.long)
    return images, targets


train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=labeled_collate_fn)
unlabeled_loader = DataLoader(unlabeled_set, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=collate_fn)
# Evaluation does not benefit from shuffling; a fixed order keeps runs comparable
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2, collate_fn=labeled_collate_fn)

# Build the MoCo model, its contrastive loss and the optimizer from the
# hyperparameters defined at the top of this cell.
model = MoCo(
    base_encoder=resnet18,
    dim=128,
    K=queue_size,
    m=momentum,
    T=temperature,
)
criterion = InfoNCE()
optimizer = Adam(params=model.parameters(), lr=learning_rate)

In [None]:
def count_trainable_parameters(model):
    """Return the total number of elements in the parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Only parameters with requires_grad=True are counted.
num_trainable_params = count_trainable_parameters(model)
print(f"Number of trainable parameters: {num_trainable_params}")

In [None]:
# MoCo pre-training on the unlabeled split.
# The device is read from the model itself so batches always end up where the
# weights live, whether the model sits on CPU or was moved to a GPU earlier.
device = next(model.parameters()).device

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    # im_q / im_k are two differently augmented views of the same images
    for im_q, im_k in tqdm(unlabeled_loader):
        im_q = im_q.to(device, non_blocking=True)
        im_k = im_k.to(device, non_blocking=True)

        # Forward pass
        logits, labels = model(im_q, im_k)
        loss = criterion(logits, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # max(..., 1) avoids a ZeroDivisionError when the loader yields no batches
    print(f"Epoch[{epoch +1} /{num_epochs}], Loss: {total_loss / max(len(unlabeled_loader), 1)}")

### Evaluating the model with a Linear Classifier on the Labeled Training Set

In [None]:
# Freeze the query encoder: only the linear classifier is trained from here on
for param in model.encoder_q.parameters():
    param.requires_grad = False

# Keep the classifier on the same device as the frozen encoder
device = next(model.encoder_q.parameters()).device

# Linear classifier on top of the 128-d encoder output (must match MoCo's `dim`),
# with one output per class of the labeled set
linear_classifier = nn.Linear(128, len(train_set.classes)).to(device)
optimizer = Adam(linear_classifier.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# Training loop for the linear probe
for epoch in range(5):
    model.eval()  # Frozen encoder: keep normalization layers in inference mode
    linear_classifier.train()
    total_loss = 0
    correct = 0

    for im, labels in tqdm(train_loader):
        im = im.to(device)
        labels = labels.to(device)

        # Extract features without building a graph through the encoder
        with torch.no_grad():
            features = model.encoder_q(im)

        # Forward pass; the loss must use the classifier outputs
        outputs = linear_classifier(features)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        correct += (predicted == labels).sum().item()

    print(f"Epoch [{epoch + 1}/5], Loss: {total_loss / len(train_loader)}, Accuracy: {100 * correct / len(train_set)}")

#### Evaluate on the Test Set and Print Confusion Matrix


In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
def plot_confusion_matrix(y_true, y_pred, classes):
    """
    Plot the confusion matrix of ``y_pred`` against ``y_true`` as an annotated heatmap.

    Drawn with matplotlib only, so no additional plotting library is required.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        classes: Class names used as tick labels. They are applied only when
            their count matches the matrix size.
    """
    cm = confusion_matrix(y_true, y_pred)
    n_rows, n_cols = cm.shape

    fig, ax = plt.subplots(figsize=(10, 8))
    heatmap = ax.imshow(cm, cmap="Blues")
    fig.colorbar(heatmap, ax=ax)

    # Write the count into every cell; switch to white text on dark cells
    threshold = cm.max() / 2.0 if cm.size else 0
    for row in range(n_rows):
        for col in range(n_cols):
            ax.text(
                col, row, format(int(cm[row, col]), "d"),
                ha="center", va="center",
                color="white" if cm[row, col] > threshold else "black",
            )

    ax.set_xticks(range(n_cols))
    ax.set_yticks(range(n_rows))
    if len(classes) == n_cols == n_rows:
        ax.set_xticklabels(classes, rotation=45, ha="right")
        ax.set_yticklabels(classes)

    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title("Confusion Matrix")
    plt.show()

y_true = []
y_pred = []

model.eval()
linear_classifier.eval()

# Follow the devices the modules actually live on instead of assuming CUDA
encoder_device = next(model.encoder_q.parameters()).device
classifier_device = next(linear_classifier.parameters()).device

with torch.no_grad():
    for im, labels in tqdm(test_loader):
        features = model.encoder_q(im.to(encoder_device))
        outputs = linear_classifier(features.to(classifier_device))
        predicted = outputs.argmax(dim=1)

        # Labels are only collected, so they never need to leave the CPU
        y_true.extend(labels.tolist())
        y_pred.extend(predicted.tolist())

# Plot confusion matrix
plot_confusion_matrix(y_true, y_pred, classes=train_set.classes)