In [1]:
import random
import torch
from torchvision import transforms
from torchvision.datasets import STL10
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "C:\Python39\lib\runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\Python39\lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "C:\Python39\lib\site-packages\ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "C:\Python39\lib\site-packages\traitlets\config\application.py", line 1046, in launch_instance
    app.start()
  File "C:\Python39\lib\site-packages\ipykernel\kernelapp.py", line 712, in start
    self.io_loo

In [2]:
# Load the three STL-10 splits (downloaded on first run). transform=None keeps the
# samples as PIL images so the augmentation pipeline can be applied later.
root = "./data"
train_set, unlabeled_set, test_set = (
    STL10(root=root, split=split_name, download=True, transform=None)
    for split_name in ("train", "unlabeled", "test")
)

Files already downloaded and verified
Files already downloaded and verified


KeyboardInterrupt: 

In [None]:
def show_images(images, titles=None, nrows=2, ncols=5, figsize=(12, 6)):
    """
    Display images in an ``nrows`` x ``ncols`` grid.

    Args:
        images (list): Images to display (anything ``Axes.imshow`` accepts, e.g. PIL images).
        titles (list): Optional titles matched to images by position; images past the
            end of ``titles`` are shown without a title.
        nrows (int): Number of rows in the grid.
        ncols (int): Number of columns in the grid.
        figsize (tuple): Size of the figure.
    """
    # squeeze=False always returns a 2-D array of Axes, so flattening also works
    # for 1x1 and single-row/column grids.
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    num_slots = nrows * ncols
    num_images = len(images)
    if num_images < num_slots:
        print(f"Warning: Only {num_images} images provided, but {nrows * ncols} subplots created.")
    elif num_images > num_slots:
        print(f"Warning: {num_images} images provided, but only {num_slots} subplots; extra images are not shown.")

    titles = titles or []
    for i, ax in enumerate(axes):
        if i < num_images:
            ax.imshow(images[i])
            # Guard against a titles list shorter than the images list.
            if i < len(titles):
                ax.set_title(titles[i])
        ax.axis('off')  # no ticks on image cells; unused cells stay blank

    plt.tight_layout()
    plt.show()
torch.manual_seed(42)

# Display the first 10 training images in a 2x5 grid, titled with their class names.
# Each sample is fetched once and then split into image and label.
first_train_samples = [train_set[idx] for idx in range(10)]
train_images = [image for image, _ in first_train_samples]
train_labels = [train_set.classes[label] for _, label in first_train_samples]
show_images(train_images, titles=train_labels, nrows=2, ncols=5)

In [None]:
torch.manual_seed(42)

# Preview the first 10 images of the unlabeled split; there are no class names to show,
# so the grid is drawn without titles.
preview_indices = range(10)
unlabeled_images = [unlabeled_set[idx][0] for idx in preview_indices]
show_images(unlabeled_images, nrows=2, ncols=5)

In [None]:
def compute_mean_std(dataset):
    """
    Compute the per-channel mean and standard deviation of a dataset.

    Samples that are not tensors (e.g. PIL images) are converted with
    ``transforms.ToTensor``; tensor samples of shape (C, H, W) are used as-is.

    Args:
        dataset: Iterable of ``(image, label)`` pairs, e.g. an STL10 dataset.
    Returns:
        mean (list): Mean for each channel (R, G, B for RGB images).
        std (list): Standard deviation for each channel.
    Raises:
        ValueError: If the dataset yields no pixels.
    """
    to_tensor = None
    channel_sum = None
    channel_sq_sum = None
    total_pixels = 0

    for image, _ in dataset:
        if not isinstance(image, torch.Tensor):
            if to_tensor is None:
                to_tensor = transforms.ToTensor()  # built once, not per image
            image = to_tensor(image)
        # Flatten to [channels, height*width]. Accumulate in float64: float32 sums over
        # tens of millions of pixels lose precision.
        flat = image.reshape(image.shape[0], -1).to(torch.float64)
        if channel_sum is None:
            channel_sum = torch.zeros(flat.shape[0], dtype=torch.float64, device=flat.device)
            channel_sq_sum = torch.zeros_like(channel_sum)
        channel_sum += flat.sum(dim=1)
        channel_sq_sum += (flat ** 2).sum(dim=1)
        total_pixels += flat.shape[1]

    if total_pixels == 0:
        raise ValueError("compute_mean_std: dataset contains no pixels")

    mean = channel_sum / total_pixels
    # Var = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding so sqrt never yields NaN.
    variance = (channel_sq_sum / total_pixels - mean ** 2).clamp_min(0.0)
    return mean.tolist(), variance.sqrt().tolist()

# Per-channel normalization statistics of the labeled training split
mean, std = compute_mean_std(train_set)

for stat_name, stat_value in (("Mean", mean), ("Std", std)):
    print(f"{stat_name}: {stat_value}")

In [None]:
class STL10Augmentation:
    """Random augmentation pipeline that turns an STL-10 PIL image into one augmented tensor view."""

    def __init__(self, image_size=96, mean=None, std=None):
        """
        Build the augmentation pipeline.

        Args:
            image_size (int): Side length of the output image. Inputs are first resized to
                (image_size, image_size); with the default of 96 this leaves native 96x96
                STL-10 images at their original size.
            mean (list): Per-channel mean for normalization (R, G, B).
            std (list): Per-channel standard deviation for normalization (R, G, B).
                If either ``mean`` or ``std`` is None, normalization is skipped instead of
                failing at call time.
        """
        steps = [
            transforms.Resize((image_size, image_size)),  # honor the requested output size
            transforms.RandomHorizontalFlip(p=0.3),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
            # NOTE(review): kernel_size=23 is large relative to 96x96 inputs — confirm intended.
            transforms.RandomApply([transforms.GaussianBlur(kernel_size=23)], p=0.5),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),  # PIL image -> float tensor
        ]
        if mean is not None and std is not None:
            steps.append(transforms.Normalize(mean=mean, std=std))
        self.augment = transforms.Compose(steps)

    def __call__(self, image):
        """
        Apply the augmentations to one image.

        Args:
            image (PIL.Image): Input image.
        Returns:
            torch.Tensor: Augmented (and, if statistics were given, normalized) image.
        """
        return self.augment(image)

In [None]:
# Shared augmentation pipeline, normalized with the training-split statistics computed above.
augmentation = STL10Augmentation(image_size=96, mean=mean, std=std)

### Defining the MoCo Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

In [None]:
class MoCo(nn.Module):
    """
    Momentum Contrast (MoCo) model for contrastive self-supervised learning.

    The query encoder is trained by backpropagation. The key encoder has the same
    architecture, is excluded from gradient updates, and is updated as an exponential
    moving average of the query encoder. Keys from previous batches are stored in a
    fixed-size circular queue and serve as negatives.
    """

    def __init__(self, base_encoder, dim=128, queue_size=32768, momentum=0.99, temperature=0.07, mlp=False):
        """
        Args:
            base_encoder: Factory called as ``base_encoder(num_classes=dim)`` to build each
                encoder (e.g. a ResNet constructor).
            dim (int): Dimension of the output feature vector.
            queue_size (int): Number of negative keys stored in the queue.
            momentum (float): Momentum for the key-encoder moving-average update.
            temperature (float): Temperature the logits are divided by.
            mlp (bool): If True, insert a hidden ``Linear -> ReLU`` layer before each
                encoder's ``fc`` output layer (MoCo v2-style projection head). Requires the
                encoder to expose its output layer as an ``nn.Linear`` named ``fc``.
        Raises:
            ValueError: If ``queue_size`` is not positive.
        """
        super(MoCo, self).__init__()

        if queue_size <= 0:
            raise ValueError(f"queue_size must be positive, got {queue_size}")

        self.queue_size = queue_size
        self.momentum = momentum
        self.temperature = temperature

        # Query and key encoders with identical architecture.
        self.encoder_query = base_encoder(num_classes=dim)
        self.encoder_key = base_encoder(num_classes=dim)

        if mlp:
            # Modify both encoders identically so their parameters still pair up below.
            for encoder in (self.encoder_query, self.encoder_key):
                hidden_dim = encoder.fc.weight.shape[1]
                encoder.fc = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), encoder.fc)

        # Start the key encoder as an exact copy of the query encoder and freeze it:
        # it only changes through _momentum_update_key_encoder, never through gradients.
        for param_query, param_key in zip(self.encoder_query.parameters(), self.encoder_key.parameters()):
            param_key.data.copy_(param_query.data)
            param_key.requires_grad = False

        # Queue of negative keys, shape (dim, queue_size): one L2-normalized key per column.
        self.register_buffer("queue", F.normalize(torch.randn(dim, queue_size), dim=0))
        # Column index where the next batch of keys will be written.
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Moving-average update: key = momentum * key + (1 - momentum) * query."""
        for param_query, param_key in zip(self.encoder_query.parameters(), self.encoder_key.parameters()):
            param_key.data.mul_(self.momentum).add_(param_query.data, alpha=1.0 - self.momentum)

    @torch.no_grad()
    def dequeue_and_enqueue(self, keys):
        """
        Write a batch of keys (shape (batch, dim)) into the circular queue, overwriting the oldest entries.

        Batches that do not divide the queue size evenly wrap around to column 0. If a
        batch is larger than the whole queue, only its last ``queue_size`` keys are kept.
        """
        batch_size = keys.shape[0]
        if batch_size == 0:
            return
        if batch_size > self.queue_size:
            keys = keys[-self.queue_size:]
            batch_size = self.queue_size

        ptr = int(self.queue_ptr)
        remaining = self.queue_size - ptr
        if batch_size <= remaining:
            self.queue[:, ptr:ptr + batch_size] = keys.T
        else:
            # Fill the tail of the queue, then continue from the first column.
            self.queue[:, ptr:] = keys[:remaining].T
            self.queue[:, :batch_size - remaining] = keys[remaining:].T

        self.queue_ptr[0] = (ptr + batch_size) % self.queue_size

    def forward(self, query_image, key_image):
        """
        Compute contrastive logits and targets for a batch of image pairs.

        Args:
            query_image: Batch of views fed to the query encoder.
            key_image: Batch of views fed to the key encoder (a second view of the same images).
        Returns:
            (logits, labels): ``logits`` has shape (N, 1 + queue_size) and is already divided
            by the temperature; column 0 holds the positive pair. ``labels`` is an all-zero
            long tensor of shape (N,) on the same device as ``logits``.
        """
        query_features = F.normalize(self.encoder_query(query_image), dim=1)

        with torch.no_grad():
            self._momentum_update_key_encoder()
            key_features = F.normalize(self.encoder_key(key_image), dim=1)

        # Positive logits: similarity of each query with its own key -> (N, 1).
        logit_positive = torch.einsum("nc,nc->n", [query_features, key_features]).unsqueeze(-1)
        # Negative logits against the queue -> (N, queue_size). The queue is cloned because
        # dequeue_and_enqueue below modifies it in place after this product is formed.
        logit_negative = torch.einsum("nc,ck->nk", [query_features, self.queue.clone().detach()])
        logits = torch.cat([logit_positive, logit_negative], dim=1) / self.temperature

        # The positive key is always column 0; create targets on the logits' device.
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

        # Enqueue after computing the logits so the current keys are not used as their own negatives.
        self.dequeue_and_enqueue(key_features)

        return logits, labels

### Defining the Base Encoder

In [None]:
def resnet18(num_classes=128):
    """
    Build an untrained torchvision ResNet-18 whose final ``fc`` layer outputs ``num_classes`` features.

    Args:
        num_classes (int): Dimension of the output feature vector.
    Returns:
        The ResNet-18 model with a freshly initialized ``nn.Linear`` head.
    """
    backbone = models.resnet18(weights=None)  # random initialization, no pretrained weights
    in_dim = backbone.fc.in_features
    backbone.fc = nn.Linear(in_dim, num_classes)
    return backbone

### Defining the Contrastive Loss, InfoNCE

In [None]:
class InfoNCE(nn.Module):
    """
    InfoNCE contrastive loss over precomputed logits.

    This is cross-entropy in which ``labels`` gives, for each row of ``logits``, the
    column of the positive pair (the MoCo model places the positive in column 0).
    """

    def forward(self, logits, labels):
        """Return the mean cross-entropy of ``logits`` against the target indices ``labels``."""
        loss = F.cross_entropy(logits, labels)
        return loss


### Training the MoCo Model

In [None]:
from torch.utils.data import DataLoader
from torch.optim import Adam
from tqdm import tqdm

# Pretraining hyperparameters
batch_size = 256        # images per batch
num_epochs = 2          # passes over the unlabeled split
learning_rate = 1e-3    # Adam learning rate
momentum = 0.999        # moving-average momentum for the MoCo key encoder
temperature = 0.07      # InfoNCE temperature
queue_size = 2 ** 16    # negative keys kept in the MoCo queue (65536)

# Collate for the unlabeled split: two independently augmented views of every image.
def collate_fn(batch):
    """Return ``(query_images, key_images)``, two stacked batches of random augmentations of the same images."""
    pil_images = [sample[0] for sample in batch]
    # All query views are generated first, then all key views.
    query_images, key_images = (
        torch.stack([augmentation(img) for img in pil_images])
        for _ in range(2)
    )
    return query_images, key_images

# Collate for the labeled split: one augmented view per image plus its class label.
def collate_fn_training(batch):
    """Return ``(images, labels)``: a stacked tensor of augmented images and a tensor of their labels."""
    pil_images, raw_labels = zip(*batch)
    img_tensor = torch.stack([augmentation(img) for img in pil_images])
    return img_tensor, torch.tensor(raw_labels)

# DataLoaders: the labeled split yields (augmented image, label) batches; the unlabeled split
# yields two augmented views per image. Both use the shared `batch_size` hyperparameter.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_training, drop_last=True)
unlabeled_loader = DataLoader(unlabeled_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# MoCo model, InfoNCE loss, and optimizer for self-supervised pretraining.
model = MoCo(resnet18, dim=128, queue_size=queue_size, momentum=momentum, temperature=temperature)
criterion = InfoNCE()
optimizer = Adam(model.parameters(), lr=learning_rate)

In [None]:
from torchsummary import summary

# Layer-by-layer summary of the query encoder for one 3x96x96 STL-10 image, computed on CPU.
input_shape = (3, 96, 96)
print("Summary for encoder_query:")
summary(model.encoder_query, input_size=input_shape, device="cpu")

In [None]:
# Sanity check: pull a single batch from train_loader and confirm the tensor shapes.
img, labels = next(iter(train_loader))
print(f"im shape: {img.shape}")
print(f"labels shape: {labels.shape}")  # Should be [batch_size]

In [None]:
def count_trainable_parameters(model):
    """Return the total number of elements across all parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# MoCo.__init__ froze the key encoder, so this counts the query encoder's parameters.
num_trainable_params = count_trainable_parameters(model)
print("Number of trainable parameters: " + f"{num_trainable_params:,}")

In [None]:
from torchsummary import summary

# NOTE(review): this cell repeats the earlier encoder_query summary verbatim and can be removed.
print("Summary for encoder_query:")
summary(
    model.encoder_query,
    input_size=(3, 96, 96),
    device="cpu",
)

In [None]:
import time
total_training_time = 0

for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    start_time = time.perf_counter()

    # Each batch holds two independently augmented views (query, key) of the same images.
    for query_images, key_images in tqdm(unlabeled_loader):
        # Forward pass: MoCo returns contrastive logits and their targets.
        logits, labels = model(query_images, key_images)
        loss = criterion(logits, labels)

        # Backward pass and optimization (only the query encoder receives gradients).
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    epoch_time = time.perf_counter() - start_time
    total_training_time += epoch_time
    print(f"Epoch[{epoch +1} /{num_epochs}], Loss: {total_loss / len(unlabeled_loader)}")

# Total pretraining time (this run is on CPU), reported as whole hours/minutes/seconds;
# converting to int first avoids output like "0.0 hours, 3.0 minutes".
hours, remainder = divmod(int(total_training_time), 3600)
minutes, seconds = divmod(remainder, 60)
print(f"Total training time: {hours} hours, {minutes} minutes, {seconds} seconds")

### Evaluating the model with a Linear Classifier on the Labeled Training Set

In [None]:
# Test-time preprocessing: no random augmentation, but the same ToTensor + Normalize
# (training-split mean/std) that the augmentation pipeline applies. This keeps test
# inputs in the distribution the encoder and linear classifier were trained on.
to_tensor_transform = transforms.ToTensor()
test_transform = transforms.Compose([
    to_tensor_transform,
    transforms.Normalize(mean=mean, std=std),
])


def collate_fn_test(batch):
    """Stack deterministically preprocessed test images and their labels into tensors."""
    images = torch.stack([test_transform(img) for img, _ in batch])
    labels = torch.tensor([label for _, label in batch])
    return images, labels


test_loader = DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,  # evaluation does not need shuffling; keep the order deterministic
    collate_fn=collate_fn_test,
)

In [None]:
# Linear evaluation: freeze the pretrained query encoder so that only the linear head is trained.
model.encoder_query.requires_grad_(False)

# Linear classifier from the 128-d encoder output to 10 output classes,
# with its own optimizer and a standard cross-entropy loss.
linear_classifier = nn.Linear(128, 10)
optimizer = Adam(linear_classifier.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

In [None]:
num_linear_epochs = 5
total_training_time = 0

for epoch in range(num_linear_epochs):
    # The encoder stays in eval mode: BatchNorm uses its running statistics and is not updated.
    model.eval()
    total_loss = 0
    correct = 0
    seen = 0

    start_time = time.time()

    for im, labels in tqdm(train_loader):
        labels = labels.long()
        # The encoder is frozen, so extract features without building an autograd graph.
        with torch.no_grad():
            features = model.encoder_query(im)
        outputs = linear_classifier(features)
        loss = criterion(outputs, labels)

        # Backward pass and optimization (updates only the linear classifier).
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        seen += labels.size(0)

    epoch_time = time.time() - start_time
    total_training_time += epoch_time

    # Divide by the samples actually seen: train_loader drops its last incomplete batch,
    # so dividing by len(train_set) would under-report accuracy.
    accuracy = 100 * correct / seen if seen else 0.0
    print(f"Epoch [{epoch + 1}/{num_linear_epochs}], Loss: {total_loss / len(train_loader):.3f}, Accuracy: {accuracy}")

print(f"Total training time: {total_training_time} seconds")

#### Evaluate on the Test Set and Print Confusion Matrix


In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

In [None]:
def plot_confusion_matrix(y_true, y_pred, classes):
    """
    Plot a confusion-matrix heatmap.

    Args:
        y_true: True class indices.
        y_pred: Predicted class indices.
        classes: Class names; ``classes[i]`` labels class index ``i``.
    """
    # Fix the label set to 0..len(classes)-1. The matrix is then always
    # len(classes) x len(classes) and lines up with the tick labels, even when a
    # class never occurs in y_true or y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=classes, yticklabels=classes, ax=ax)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title("Confusion Matrix")
    plt.show()

y_true = []
y_pred = []

model.eval()
for im, labels in tqdm(test_loader):
    with torch.no_grad():
        # Features come from the frozen query encoder (MoCo attribute `encoder_query`).
        features = model.encoder_query(im)
        outputs = linear_classifier(features)
        _, predicted = torch.max(outputs, 1)

    y_true.extend(labels.cpu().numpy())
    y_pred.extend(predicted.cpu().numpy())

# Plot confusion matrix and per-class metrics. Passing `labels` explicitly keeps the report
# aligned with `target_names` even if some class is never predicted.
class_indices = list(range(len(train_set.classes)))
plot_confusion_matrix(y_true, y_pred, classes=train_set.classes)
print(classification_report(y_true, y_pred, labels=class_indices, target_names=train_set.classes))

In [None]:

# Convert a training duration in seconds to hours, minutes, and seconds.
# NOTE(review): 23542 is a hard-coded value, not read from any variable in this notebook.
total_seconds = 23542
total_minutes, seconds = divmod(total_seconds, 60)
hours, minutes = divmod(total_minutes, 60)

# Print the total training time in a human-readable format
print(f"Total training time: {hours} hours, {minutes} minutes, {seconds} seconds")