<a href="https://colab.research.google.com/github/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D2_ComparingTasks/student/W1D2_Tutorial2.ipynb" target="_blank"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"/></a>   <a href="https://kaggle.com/kernels/welcome?src=https://raw.githubusercontent.com/neuromatch/NeuroAI_Course/main/tutorials/W1D2_ComparingTasks/student/W1D2_Tutorial2.ipynb" target="_blank"><img alt="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"/></a>

# Tutorial 2: Contrastive learning for object recognition

**Week 1, Day 2: Comparing Tasks**

**By Neuromatch Academy**

__Content creators:__ Andrew F. Luo, Leila Wehbe

__Content reviewers:__ Samuele Bolotta, Yizhou Chen, RyeongKyung Yoon, Ruiyi Zhang, Lily Chamakura

__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk


___


# Tutorial Objectives

*Estimated timing of tutorial: 20 minutes*

By the end of this tutorial, participants will be able to:
1. Understand why we want to do contrastive learning.
2. Understand the losses in contrastive learning.
3. Run an example on contrastive learning using MNIST.



In [None]:
# @markdown
from IPython.display import IFrame
from ipywidgets import widgets
out = widgets.Output()
with out:
    print(f"If you want to download the slides: https://osf.io/download/d4r6g/")
    display(IFrame(src=f"https://mfr.ca-1.osf.io/render?url=https://osf.io/d4r6g/?direct%26mode=render%26action=download%26mode=render", width=730, height=410))
display(out)

---
# Setup



##  Install and import feedback gadget




In [None]:
# @title Install and import feedback gadget

!pip install vibecheck numpy matplotlib torch torchvision tqdm scikit-learn ipysankeywidget ipywidgets --quiet

from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
    return DatatopsContentReviewContainer(
        "",  # No text prompt - leave this as is
        notebook_section,
        {
        "url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
        "name": "sciencematch_sm", # change the name of the course : neuromatch_dl, climatematch_ct, etc
        "user_key": "y1x3mpx5",
        },
    ).render()

feedback_prefix = "W1D2_T2"

###  Import dependencies


In [None]:
# @title Import dependencies
# @markdown

import logging
import gc

# PyTorch and related libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision

# Set up PyTorch backend configurations
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Numpy for numerical operations
import numpy as np

# Matplotlib for plotting
import matplotlib.pyplot as plt

# Scikit-learn for machine learning utilities
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

## Figure settings




In [None]:
# @title Figure settings
# @markdown

logging.getLogger('matplotlib.font_manager').disabled = True

%matplotlib inline
%config InlineBackend.figure_format = 'retina' # perform high definition rendering for images and plots
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle")

##  Plotting functions




In [None]:
# @title Plotting functions
# @markdown

##  Helper functions



In [None]:
#@title Helper functions
# @markdown

# Adapted from the PyTorch Metric Learning package (pytorch-metric-learning)

def neg_inf(dtype):
    # Returns the smallest possible value for the given data type
    return torch.finfo(dtype).min

def small_val(dtype):
    # Returns the smallest positive normal value for the given data type
    return torch.finfo(dtype).tiny

def to_dtype(x, tensor=None, dtype=None):
    # Converts tensor `x` to the specified `dtype`, or to the same dtype as `tensor`
    if not torch.is_autocast_enabled():
        dt = dtype if dtype is not None else tensor.dtype
        if x.dtype != dt:
            x = x.type(dt)
    return x

def get_matches_and_diffs(labels, ref_labels=None):
    # Returns tensors indicating matches and differences between pairs of labels
    if ref_labels is None:
        ref_labels = labels
    labels1 = labels.unsqueeze(1)  # Expand dimensions for comparison
    labels2 = ref_labels.unsqueeze(0)  # Expand dimensions for comparison
    matches = (labels1 == labels2).byte()  # Byte tensor of matches
    diffs = matches ^ 1  # Byte tensor of differences (inverse of matches)
    if ref_labels is labels:
        matches.fill_diagonal_(0)  # Remove self-matches
    return matches, diffs

def get_all_pairs_indices(labels, ref_labels=None):
    """
    Given a tensor of labels, this will return 4 tensors.
    The first 2 tensors are the indices which form all positive pairs
    The second 2 tensors are the indices which form all negative pairs
    """
    matches, diffs = get_matches_and_diffs(labels, ref_labels)
    a1_idx, p_idx = torch.where(matches)  # Indices for positive pairs
    a2_idx, n_idx = torch.where(diffs)  # Indices for negative pairs
    return a1_idx, p_idx, a2_idx, n_idx

def cos_sim(input_embeddings):
    # Computes cosine similarity matrix for input embeddings
    normed_embeddings = torch.nn.functional.normalize(input_embeddings, dim=-1)  # Normalize embeddings
    return normed_embeddings @ normed_embeddings.t()  # Cosine similarity matrix

def dcl_loss(pos_pairs, neg_pairs, indices_tuple, temperature=0.07):
    # Modified InfoNCE loss from "Decoupled Contrastive Learning" (DCL), suited to small batch sizes.
    # The only change: the positive pair (numerator) is removed from the sum in the denominator.

    a1, p, a2, _ = indices_tuple  # Unpack indices

    if len(a1) > 0 and len(a2) > 0:
        dtype = neg_pairs.dtype
        pos_pairs = pos_pairs.unsqueeze(1) / temperature  # Scale positive pairs by temperature
        neg_pairs = neg_pairs / temperature  # Scale negative pairs by temperature
        n_per_p = to_dtype(a2.unsqueeze(0) == a1.unsqueeze(1), dtype=dtype)  # Indicator matrix for matching pairs
        neg_pairs = neg_pairs * n_per_p  # Zero out non-matching pairs
        neg_pairs[n_per_p == 0] = neg_inf(dtype)  # Replace non-matching pairs with negative infinity

        # Compute the maximum value for numerical stability
        max_val = torch.max(
            pos_pairs, torch.max(neg_pairs, dim=1, keepdim=True)[0]
        ).detach()
        # Compute numerator and denominator for the loss
        numerator = torch.exp(pos_pairs - max_val).squeeze(1)
        denominator = torch.sum(torch.exp(neg_pairs - max_val), dim=1)
        log_exp = torch.log((numerator / denominator) + small_val(dtype))
        return -log_exp  # Return the negative log of the exponential
    return 0

def pair_based_loss(mat, indices_tuple, lossfunc):
    # Computes pair-based loss using the provided loss function
    a1, p, a2, n = indices_tuple  # Unpack indices
    pos_pair, neg_pair = [], []
    if len(a1) > 0:
        pos_pair = mat[a1, p]  # Extract positive pairs
    if len(a2) > 0:
        neg_pair = mat[a2, n]  # Extract negative pairs
    return lossfunc(pos_pair, neg_pair, indices_tuple)  # Apply loss function

## Section 1: Building the model

### What is contrastive learning?

Contrastive learning is widely used for "self-supervised learning (SSL)" and is closely related to what has historically been called "metric learning." The essence of contrastive/metric learning is that instead of outputting a one-hot/softmax classification vector or a regression value, the network directly outputs a high-dimensional embedding.

Here is an example: given multiple data points from a single class (for example, three photos of you from different viewpoints) and different classes (for example, 10 photos from one or multiple people who are not you), you want the three embeddings from your photos to be closer to each other while being farther away from the ten embeddings from the different classes.

Hence the name "metric learning," where you seek to learn a metric/distance that fits the constraints of the data.

### Why contrastive learning?

It may not be immediately obvious why you would want to use contrastive or metric learning. Why not just use a large 1000-class ImageNet-trained classifier to recognize every image? The answer is that metric learning is useful when the set of classes is not known ahead of time. For example, if you wanted a network to recognize human faces, there are roughly 8 billion people on this planet, so it is impractical to train a classification network with one output neuron per person. Instead, you can train a network to output a high-dimensional embedding for each image. Given a reference image of a person, the network can then decide whether a new photo shows the same person by checking whether its embedding is close to (similar) or far from (different) the reference embedding.
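A minimal sketch of this verification step, assuming the embeddings have already been computed by a trained network. The `is_same_identity` helper and its similarity threshold of 0.8 are hypothetical and purely illustrative; in practice the threshold is tuned on held-out pairs.

```python
import numpy as np

def cosine_similarity(a, b):
    """Cosine similarity between two 1-D embedding vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def is_same_identity(reference_embedding, query_embedding, threshold=0.8):
    """Hypothetical verification rule: 'same person' if the cosine similarity
    between the two embeddings exceeds `threshold` (an illustrative value)."""
    return cosine_similarity(reference_embedding, query_embedding) > threshold

# Toy 4-D embeddings standing in for the output of a trained network
reference = np.array([0.9, 0.1, 0.0, 0.4])
same_person = np.array([0.8, 0.2, 0.1, 0.5])
other_person = np.array([-0.1, 0.9, 0.4, 0.0])

print(is_same_identity(reference, same_person))   # True
print(is_same_identity(reference, other_person))  # False
```

Adding a new person only requires storing one more reference embedding; the network itself does not change.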

### Analysis of the results

From here on, we will use PCA (Principal Component Analysis) and t-SNE (t-Distributed Stochastic Neighbor Embedding) as our main tools for visualizing embeddings. Both reduce high-dimensional data to two dimensions, so we can see whether embeddings from the same class cluster together, something that is hard to judge directly in a 128-dimensional space.

### Mini residual block

Our initial focus will be on creating a `mini_residual` block. This block follows the modern "pre-activation" residual design from Kaiming He and colleagues' "Identity Mappings in Deep Residual Networks", in which normalization and the activation are applied before each linear layer rather than after it. We use LayerNorm instead of BatchNorm and the LeakyReLU activation instead of ReLU. LeakyReLU, which is popular in generative adversarial networks (GANs), keeps a small non-zero gradient for negative inputs; this avoids "dead" units whose gradient is always zero, as can happen with ReLU.


In [None]:
class mini_residual(nn.Module):
    # Follows "Identity Mappings in Deep Residual Networks", uses LayerNorm instead of BatchNorm, and LeakyReLU instead of ReLU
    def __init__(self, feat_in=128, feat_out=128, feat_hidden=256, use_norm=True):
        super().__init__()
        # Define the residual block with or without normalization
        if use_norm:
            self.block = nn.Sequential(
                nn.LayerNorm(feat_in),  # Layer normalization on input features
                nn.LeakyReLU(negative_slope=0.1),  # LeakyReLU activation
                nn.Linear(feat_in, feat_hidden),  # Linear layer transforming input to hidden features
                nn.LayerNorm(feat_hidden),  # Layer normalization on hidden features
                nn.LeakyReLU(negative_slope=0.1),  # LeakyReLU activation
                nn.Linear(feat_hidden, feat_out)  # Linear layer transforming hidden to output features
            )
        else:
            self.block = nn.Sequential(
                nn.LeakyReLU(negative_slope=0.1),  # LeakyReLU activation
                nn.Linear(feat_in, feat_hidden),  # Linear layer transforming input to hidden features
                nn.LeakyReLU(negative_slope=0.1),  # LeakyReLU activation
                nn.Linear(feat_hidden, feat_out)  # Linear layer transforming hidden to output features
            )

        # Define the bypass connection
        if feat_in != feat_out:
            self.bypass = nn.Linear(feat_in, feat_out)  # Linear layer to match dimensions if they differ
        else:
            self.bypass = nn.Identity()  # Identity layer if input and output dimensions are the same

    def forward(self, input_data):
        # Forward pass: apply the block and add the bypass connection
        return self.block(input_data) + self.bypass(input_data)
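To see the gradient difference concretely, here is a small NumPy sketch of LeakyReLU with the same `negative_slope=0.1` used in the block above, compared with ReLU. These helper functions are illustrative only; the model itself uses `nn.LeakyReLU`.

```python
import numpy as np

def leaky_relu(x, negative_slope=0.1):
    """LeakyReLU: identity for x > 0, scaled by negative_slope otherwise."""
    return np.where(x > 0, x, negative_slope * x)

def leaky_relu_grad(x, negative_slope=0.1):
    """Derivative of LeakyReLU (using the negative branch at x == 0)."""
    return np.where(x > 0, 1.0, negative_slope)

def relu_grad(x):
    """Derivative of ReLU (using 0 at x == 0): zero for every non-positive input."""
    return np.where(x > 0, 1.0, 0.0)

x = np.array([-2.0, -0.5, 0.5, 2.0])
print(leaky_relu(x))       # negative inputs are scaled by 0.1, not zeroed
print(relu_grad(x))        # gradient is exactly 0 for negative inputs
print(leaky_relu_grad(x))  # gradient stays 0.1 for negative inputs
```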

### Full model construction

Following the mini_residual block, we will construct the full model: a linear input projection, a stack of residual blocks, and a linear output projection. The residual blocks are chained with `nn.Sequential`, which runs its submodules in order from first to last, so the forward pass only needs to call the stack once. Stacking several blocks gives the model enough depth to learn a non-linear mapping from raw pixels to embeddings.

In [None]:
class Model(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, num_blocks=4):
        super().__init__()
        # Initial linear projection from input dimension to hidden dimension
        self.in_proj = nn.Linear(in_dim, hidden_dim)
        # Sequence of residual blocks
        self.hidden = nn.Sequential(
            *[mini_residual(feat_in=hidden_dim, feat_out=hidden_dim, feat_hidden=hidden_dim) for i in range(num_blocks)]
        )
        # Output linear projection from hidden dimension to output dimension
        self.out = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        # Forward pass: input projection, passing through residual blocks, and final output projection
        in_proj_out = self.in_proj(x)
        hidden_out = self.hidden(in_proj_out)
        return self.out(hidden_out)
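As a sanity check on the architecture, the following plain-Python sketch counts trainable parameters layer by layer, assuming the `mini_residual` and `Model` definitions above with their defaults (`use_norm=True`, `num_blocks=4`). The helper names are hypothetical; once the model is created later in this tutorial, the count should agree with `sum(p.numel() for p in mynet.parameters())`.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of an nn.Linear(n_in, n_out) layer."""
    return n_in * n_out + n_out

def layernorm_params(n):
    """Scale and shift of an nn.LayerNorm(n) with elementwise_affine=True (the default)."""
    return 2 * n

def mini_residual_params(feat_in, feat_out, feat_hidden, use_norm=True):
    """Parameters of one mini_residual block as defined above."""
    total = linear_params(feat_in, feat_hidden) + linear_params(feat_hidden, feat_out)
    if use_norm:
        total += layernorm_params(feat_in) + layernorm_params(feat_hidden)
    if feat_in != feat_out:
        total += linear_params(feat_in, feat_out)  # bypass projection
    return total

def model_params(in_dim, out_dim, hidden_dim, num_blocks=4):
    """Input projection + residual stack + output projection."""
    return (linear_params(in_dim, hidden_dim)
            + num_blocks * mini_residual_params(hidden_dim, hidden_dim, hidden_dim)
            + linear_params(hidden_dim, out_dim))

print(model_params(in_dim=784, out_dim=128, hidden_dim=256))  # 764288
```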

Now, let's define the loss function for our model, using code adapted from the PyTorch Metric Learning package. We will implement a variant of the InfoNCE loss, one of the most widely used contrastive (metric learning) losses. An InfoNCE-style loss is used, for example, to train OpenAI's CLIP; it learns discriminative features by contrasting positive pairs against negative pairs.

To clarify, positive pairs are two data points that should be close together in embedding space, for example two photos of you in different lighting conditions. Negative pairs are two data points that should be far apart in embedding space, such as a photo of you and a photo of a dog (assuming you are not a dog). Positive and negative pairs do not have to be images: a picture and its corresponding text caption can also form a positive pair. Some methods also compute the embeddings on one side of each pair with a slowly updated, exponential-moving-average (EMA) copy of the encoder, as in Momentum Contrast (MoCo) from Facebook AI Research and related EMA-based contrastive methods.
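A momentum (EMA) encoder is not trained by gradients; after each step its weights are blended toward the trained encoder's weights. A minimal NumPy sketch of that update, assuming each network's parameters are stored in a dict of arrays (a hypothetical stand-in for PyTorch state dicts). MoCo-style methods typically use a momentum close to 1, such as 0.999, so the key encoder changes slowly.

```python
import numpy as np

def momentum_update(key_params, query_params, momentum=0.999):
    """EMA update: key <- momentum * key + (1 - momentum) * query, per parameter array.

    key_params, query_params: dicts of NumPy arrays with matching keys
    (hypothetical stand-ins for the two encoders' parameters).
    """
    for name, q in query_params.items():
        key_params[name] = momentum * key_params[name] + (1.0 - momentum) * q
    return key_params

query_encoder = {"w": np.ones(3)}
key_encoder = {"w": np.zeros(3)}

# With momentum=0.9 the key weights move 10% of the way toward the query weights
momentum_update(key_encoder, query_encoder, momentum=0.9)
print(key_encoder["w"])
```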

InfoNCE is one of the most common contrastive losses. It is essentially a cross-entropy loss for picking the correct positive out of a pool of candidates (one positive and many negatives). Variants such as MIL-NCE allow multiple positives per example. This loss typically needs large batch sizes (commonly 128 or larger) to work well, because the other examples in the batch serve as the negatives, and the loss learns best from many diverse negatives. Large batches, however, can be impractical when compute or data is limited.
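To make the cross-entropy view concrete, here is a minimal NumPy sketch of InfoNCE for a single query: the positive similarity is placed at index 0 of a logit vector, and the loss is the cross-entropy for class 0. The function names and toy data are illustrative; the temperature of 0.07 matches the default in `dcl_loss` above.

```python
import numpy as np

def l2_normalize(x):
    """Scale vectors (along the last axis) to unit L2 norm."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def info_nce(query, positive_key, negative_keys, temperature=0.07):
    """InfoNCE loss for one query, written as cross-entropy with target class 0.

    query:         (D,) embedding
    positive_key:  (D,) embedding that should match the query
    negative_keys: (K, D) embeddings that should not match the query
    """
    q, k_pos, k_neg = l2_normalize(query), l2_normalize(positive_key), l2_normalize(negative_keys)

    # Logits: cosine similarity with the positive at index 0, then the K negatives
    logits = np.concatenate([[q @ k_pos], k_neg @ q]) / temperature

    # Cross-entropy for "class 0", using the log-sum-exp trick for numerical stability
    max_logit = logits.max()
    log_sum_exp = max_logit + np.log(np.sum(np.exp(logits - max_logit)))
    return log_sum_exp - logits[0]

rng = np.random.default_rng(0)
q = rng.normal(size=8)
pos = q + 0.1 * rng.normal(size=8)   # slightly perturbed copy of q: a positive pair
negs = rng.normal(size=(16, 8))      # unrelated random vectors: negatives

loss_good = info_nce(q, pos, negs)   # positive is the most similar key: low loss
loss_bad = info_nce(q, -q, negs)     # "positive" points the opposite way: high loss
print(loss_good, loss_bad)
```

When every key is equally similar to the query, the loss equals log(K + 1), the cross-entropy of a uniform guess among the K + 1 candidates.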

To address this, we will implement a modified version of InfoNCE from the ["Decoupled Contrastive Learning"](https://link.springer.com/chapter/10.1007/978-3-031-19809-0_38) paper. It changes only the denominator of the InfoNCE formula: the positive example is removed, so the denominator sums over negatives only. The authors argue that keeping the positive in the denominator couples the positive and negative terms of the gradient, which makes learning less efficient, especially with small batches; removing it makes training less sensitive to batch size while still learning discriminative features.
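The two losses differ only in the denominator, which gives a simple relationship between them. Writing P for the exponentiated positive logit and N for the sum of the exponentiated negative logits, InfoNCE is log(1 + N/P) and DCL is log(N/P), so InfoNCE is the softplus of DCL. A minimal NumPy check with illustrative logit values (similarities already divided by the temperature); the function names are hypothetical:

```python
import numpy as np

def info_nce_from_logits(pos_logit, neg_logits):
    """InfoNCE: the positive appears in both the numerator and the denominator."""
    all_logits = np.concatenate([[pos_logit], neg_logits])
    return np.log(np.sum(np.exp(all_logits))) - pos_logit

def dcl_from_logits(pos_logit, neg_logits):
    """Decoupled contrastive loss: the positive is removed from the denominator."""
    return np.log(np.sum(np.exp(neg_logits))) - pos_logit

pos_logit = 5.0
neg_logits = np.array([1.0, 2.0, 0.5])

nce = info_nce_from_logits(pos_logit, neg_logits)
dcl = dcl_from_logits(pos_logit, neg_logits)

# InfoNCE = softplus(DCL) = log(1 + exp(DCL)); DCL can be negative, InfoNCE cannot
print(nce, dcl, np.log1p(np.exp(dcl)))
```

Differentiating, the InfoNCE gradient equals the DCL gradient multiplied by sigmoid(DCL) = N / (P + N), a factor between 0 and 1 that shrinks when the positive already dominates the negatives, which is more likely when there are few negatives. As we read the paper, this multiplier is the positive-negative coupling that DCL removes.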

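Here is the standard InfoNCE loss for a query embedding $q$, its positive key $k_{+}$, and keys $k_0, \dots, k_K$ (the positive plus $K$ negatives), with temperature $\tau$. All vectors are normalized to unit norm before the dot product, so $q \cdot k$ is a cosine similarity.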

$$ \mathcal{L}_q = -\log \left( \frac{\exp(q \cdot k_{+} / \tau)}{\sum_{i=0}^{K} \exp(q \cdot k_i / \tau)} \right) $$

Remember, the goal is to minimize the loss. The numerator $\exp(q \cdot k_{+} / \tau)$ grows with the similarity between the query and the positive key, so maximizing it pulls positive pairs together in the embedding space. The denominator $\sum_{i=0}^{K} \exp(q \cdot k_i / \tau)$ sums the same quantity over all keys, the positive and the negatives. Because the loss depends on the ratio, the model is rewarded both for pulling the positive pair together and for pushing the negative pairs apart: the similarity of the positive pair must be high relative to the similarities with all keys. The Decoupled Contrastive Learning (DCL) loss modifies this slightly by removing the positive pair from the denominator, as detailed in their paper.
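For reference, here is the DCL loss written in the same notation as the InfoNCE formula above (our transcription; see the paper for the authors' exact notation). The denominator sums only over the negative keys, and expanding the logarithm shows the loss is the negative scaled positive similarity plus a log-sum-exp over the negative similarities:

$$ \mathcal{L}^{\mathrm{DCL}}_q = -\log \left( \frac{\exp(q \cdot k_{+} / \tau)}{\sum_{i=0,\; k_i \neq k_{+}}^{K} \exp(q \cdot k_i / \tau)} \right) = -\frac{q \cdot k_{+}}{\tau} + \log \sum_{i=0,\; k_i \neq k_{+}}^{K} \exp(q \cdot k_i / \tau) $$

In `dcl_loss` in the helper functions, each positive pair of an anchor contributes one such term, with that anchor's negatives from the same batch in the denominator.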

Now, we will create the PyTorch dataset object. A dataset defines how each individual sample is loaded (for example, from disk) and which transformations are applied to it; the `DataLoader` then groups samples into batches. You are not limited to torchvision transforms: it is quite common to write custom transformation code within the dataset object.
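A minimal sketch of a custom map-style dataset with its own transformation, written with plain NumPy so it runs without PyTorch. A map-style dataset only needs `__len__` and `__getitem__`, which `torch.utils.data.DataLoader` uses to fetch individual samples; in real code you would normally subclass `torch.utils.data.Dataset`. The `ArrayDataset` and `AddGaussianNoise` names and the in-memory toy images are hypothetical stand-ins for data read from disk.

```python
import numpy as np

class AddGaussianNoise:
    """Hypothetical custom transform: adds Gaussian noise to an image array."""
    def __init__(self, std=0.1, seed=0):
        self.std = std
        self.rng = np.random.default_rng(seed)

    def __call__(self, image):
        return image + self.rng.normal(0.0, self.std, size=image.shape)

class ArrayDataset:
    """Minimal map-style dataset: returns one (image, label) sample per index."""
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx].astype(np.float32) / 255.0  # scale pixels to [0, 1]
        if self.transform is not None:
            image = self.transform(image)  # custom per-sample transformation
        return image, int(self.labels[idx])

# Toy data standing in for images read from disk
images = np.zeros((5, 28, 28), dtype=np.uint8)
labels = np.arange(5)
dset = ArrayDataset(images, labels, transform=AddGaussianNoise(std=0.1))
image, label = dset[3]
print(len(dset), image.shape, label)  # 5 (28, 28) 3
```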

In [None]:
# Define the transformations for the MNIST dataset
mnist_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),  # Convert images to tensor
    torchvision.transforms.Normalize((0.1307,), (0.3081,))  # Normalize with the MNIST training-set mean and standard deviation
])

# Load the MNIST test dataset with the defined transformations
test_dset = torchvision.datasets.MNIST("./", train=False, transform=mnist_transforms, download=True)

# Calculate the height and width of the MNIST images (28x28)
height = int(784**0.5)
width = height

# Select the first image from the test dataset
idx = 0
data_point = test_dset[idx]

# Display the image using matplotlib
plt.imshow(data_point[0][0].numpy(), cmap='gray')  # Display the image in grayscale
plt.show()

# Print the label of the selected image
print(data_point[1])

Now we will create the model using the definition we wrote previously and move it to the desired device. Note that in PyTorch, calling `.to(device)` on a module (such as a neural network model) modifies the module in place (and also returns it). Calling `.to(device)` on a tensor, however, is not in place: it returns a tensor on the target device, so you must assign the result, e.g. `x = x.to(device)`.

In [None]:
# Initialize the model with specified input, output, and hidden dimensions
mynet = Model(in_dim=784, out_dim=128, hidden_dim=256)

# Automatically select the device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Output the device that will be used
print(f"Using device: {device}")

# Move the model to the selected device
_ = mynet.to(device)

Let's create a test DataLoader and examine the representations produced by the untrained network for each digit. We will compute the cosine similarity between handwritten digits within the same class, setting the diagonal of each similarity matrix to np.nan to exclude each image's similarity with itself.

Additionally, we will compute the cosine similarity between handwritten digits from different classes.

Remember to call network.eval() before evaluating the network. This in-place call makes layers such as batch normalization use their stored running statistics instead of batch statistics (and stop updating them), and it disables dropout. Our model uses LayerNorm and no dropout, so its outputs do not change, but it is a good habit.

We will use torch.inference_mode() to disable gradient computation and speed up testing. Tensors created in inference mode cannot be used in autograd later; if this causes errors, you can replace it with torch.no_grad(). Note that inference_mode does not automatically switch the model to eval mode.

In [None]:
# First try with untrained network, find the cosine similarities within a class and across classes

# Create a DataLoader for the test dataset with a batch size of 50
test_loader = DataLoader(test_dset, batch_size=50, shuffle=False)  # enable persistent_workers=True if more than 1 worker to save CPU

# Set the model to evaluation mode
mynet.eval()

# Initialize lists to store test embeddings and labels
test_embeddings = []
test_labels = []

# Initialize a similarity matrix of size 10x10 for 10 classes
sim_matrix = np.zeros((10, 10))

# Disable gradient computation for inference
with torch.inference_mode():
    for data_batch in test_loader:
        test_img, test_label = data_batch  # Get images and labels from the batch
        batch_size = test_img.shape[0]  # Get the batch size
        flat = test_img.reshape(batch_size, -1).to(device, non_blocking=True)  # Flatten the images and move to device
        pred_embeddings = mynet(flat).cpu().numpy().tolist()  # Get embeddings from the model and move to CPU
        test_embeddings.extend(pred_embeddings)  # Store the embeddings
        test_labels.extend(test_label.numpy().tolist())  # Store the labels

# Convert the embeddings to a numpy array
test_embeddings = np.array(test_embeddings)

# Normalize the embeddings
test_embeddings_normed = test_embeddings / np.linalg.norm(test_embeddings, axis=1, keepdims=True)

# Convert test labels to numpy array
test_labels = np.array(test_labels)

### Visualizing the cosine similarity of embeddings within the same class and across different classes before training

Ideally, you should observe a very high cosine similarity for images within the same class (along the diagonal) and very low cosine similarity for images from different classes (off-diagonal).

However, since our network is untrained, you will notice that there isn't much difference in the cosine similarities. This lack of clear structure in the similarity matrix is expected at this stage because the network has not yet learned to distinguish between different classes.

In [None]:
# Dictionary to store normalized embeddings for each class
embeddings = {}
for i in range(10):
    embeddings[i] = test_embeddings_normed[test_labels == i]

# Within class cosine similarity:
for i in range(10):
    sims = embeddings[i] @ embeddings[i].T  # Compute cosine similarity matrix within the class
    np.fill_diagonal(sims, np.nan)  # Ignore diagonal values (self-similarity)
    cur_sim = np.nanmean(sims)  # Calculate the mean similarity excluding diagonal
    sim_matrix[i, i] = cur_sim  # Store the within-class similarity in the matrix

    print("Within class {} cosine similarity {}".format(i, cur_sim))

print("==================")

# Between class cosine similarity:
for i in range(10):
    for j in range(10):
        if i == j:
            continue  # Skip if same class (already computed)
        elif i > j:
            continue  # Skip if already computed (matrix symmetry)
        else:
            sims = embeddings[i] @ embeddings[j].T  # Compute cosine similarity between different classes
            cur_sim = np.mean(sims)  # Calculate the mean similarity
            sim_matrix[i, j] = cur_sim  # Store the similarity in the matrix
            sim_matrix[j, i] = cur_sim  # Ensure symmetry in the matrix
            print("{} and {} cosine similarity {}".format(i, j, cur_sim))

# Plotting the similarity matrix
plt.imshow(sim_matrix, vmin=0.0, vmax=1.0)
plt.title("Untrained Network Cosine Similarity Matrix")
plt.colorbar()
plt.show()

## Section 2: Training the model and visualizing feature similarity 

Now we will train the network!

Notice how we decay the learning rate exponentially, once per epoch, so that the learning rate in the final epoch is half of the initial learning rate. We will use the AdamW optimizer, which is the Adam optimizer with decoupled weight decay. A learning rate of 3e-4 and a weight decay of 1e-2 are typical settings for AdamW.
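The schedule used in the training loop below can be previewed in plain Python. This sketch reuses the same `init_lr`, `lr_decay_factor`, and `epochs` values. Because `epoch_id` starts at 1, even the first epoch uses a slightly reduced rate (`0.5 ** 0.1`, about 0.93 times `init_lr`), and the last epoch uses exactly half.

```python
init_lr = 3e-4
lr_decay_factor = 0.5
epochs = 10

# Same formula as the training loop: lr = init_lr * decay ** (epoch / epochs), epoch = 1..epochs
schedule = [init_lr * lr_decay_factor ** (epoch_id / epochs) for epoch_id in range(1, epochs + 1)]

for epoch_id, lr in enumerate(schedule, start=1):
    print(f"epoch {epoch_id:2d}: lr = {lr:.3e}")
```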

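Weight decay works differently in PyTorch's AdamW and SGD. In AdamW, the decay is decoupled from the gradient: each step shrinks the weights directly by lr × weight_decay × weights, independently of Adam's adaptive gradient scaling. In SGD, weight decay is added to the gradient as an L2 penalty term (weight_decay × weights), and the whole update is then multiplied by the learning rate. In both cases the effective per-step decay is proportional to the learning rate, and because Adam-family learning rates are typically much smaller than SGD learning rates, AdamW is commonly used with higher weight decay values than SGD.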
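A minimal NumPy sketch of a single AdamW step, following the update rule described in the PyTorch `AdamW` documentation (without the `amsgrad` option), makes the decoupling visible: the decay term multiplies the weights directly and never passes through Adam's moment estimates. The function name and signature are illustrative, not PyTorch's API.

```python
import numpy as np

def adamw_step(param, grad, m, v, step, lr=3e-4, weight_decay=1e-2,
               betas=(0.9, 0.999), eps=1e-8):
    """One AdamW update on a NumPy array. Returns (param, m, v); `step` counts from 1."""
    beta1, beta2 = betas
    # Decoupled weight decay: shrink the weights directly, scaled by lr
    param = param * (1.0 - lr * weight_decay)
    # Standard Adam first/second moment estimates with bias correction
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad ** 2
    m_hat = m / (1.0 - beta1 ** step)
    v_hat = v / (1.0 - beta2 ** step)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

w = np.array([1.0, -2.0])
zeros = np.zeros_like(w)

# With a zero gradient, the only change is the decay: w * (1 - lr * weight_decay)
w_new, _, _ = adamw_step(w, grad=zeros, m=zeros, v=zeros, step=1)
print(w_new)
```

With a zero gradient, one step shrinks the weights by exactly the factor `1 - lr * weight_decay`; with the typical AdamW values above that is only 3e-6 per step.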

Additionally, remember to call mynet.train() before starting the training process. This puts mynet in training mode, so that layers such as dropout and batch normalization (if present in the architecture) behave as they should during training.

In [None]:
# Number of epochs for training
epochs = 10

# Automatically select the device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Output the device that will be used
print(f"Using device: {device}")

# Load the MNIST training dataset with the defined transformations
train_dset = torchvision.datasets.MNIST("./", train=True, transform=mnist_transforms, download=True)
train_loader = DataLoader(train_dset, batch_size=50, shuffle=True)  # Enable persistent_workers=True if more than 1 worker to save CPU

# Cleanup: delete the optimizer and free up memory if this block is re-run
try:
    del optimizer
    gc.collect()
    torch.cuda.empty_cache()
except:
    pass

# Cleanup: delete the network and free up memory if this block is re-run
try:
    del mynet
    gc.collect()
    torch.cuda.empty_cache()
except:
    pass

# Initialize the model with specified input, output, and hidden dimensions
mynet = Model(in_dim=784, out_dim=128, hidden_dim=256)
_ = mynet.to(device)  # Move the model to the selected device

# Enable training mode, which may affect dropout and other layers
mynet.train(mode=True)
print("Is the network in training mode?", mynet.training)

# Initial learning rate and decay factor for the optimizer
init_lr = 3e-4
lr_decay_factor = 0.5

# Initialize the optimizer with model parameters and learning rate
optimizer = torch.optim.AdamW(mynet.parameters(), lr=init_lr, weight_decay=1e-2)

# Tracker to keep track of loss values during training
loss_tracker = []

# Training loop over the specified number of epochs
for epoch_id in range(1, epochs+1):
    loss_epoch_tracker = 0
    batch_counter = 0

    # Adjust learning rate for the current epoch
    new_lrate = init_lr * (lr_decay_factor ** (epoch_id / epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lrate

    batches_in_epoch = len(train_loader)
    for data_batch in train_loader:
        optimizer.zero_grad()  # Zero out gradients

        # Get images and labels from the batch
        train_img, train_label = data_batch
        batch_size = train_img.shape[0]

        # Flatten images and move data to the selected device
        flat = train_img.reshape(batch_size, -1).to(device, non_blocking=True)
        train_label = train_label.to(device, non_blocking=True)

        # Forward pass through the network
        predicted_results = mynet(flat)

        # Compute cosine similarity matrix for the batch
        similarities = cos_sim(predicted_results)

        # Get pairs of indices for positive and negative pairs
        label_pos_neg = get_all_pairs_indices(train_label)

        # Compute the loss using the decoupled contrastive learning loss function
        final_loss = torch.mean(pair_based_loss(similarities, label_pos_neg, dcl_loss))

        # Compute gradients from the loss
        final_loss.backward()

        # Update the model parameters using the optimizer
        optimizer.step()

        # Convert the loss to a single CPU scalar
        loss_cpu_number = final_loss.item()

        # Keep track of the losses for visualization
        loss_epoch_tracker += loss_cpu_number
        batch_counter += 1

        # Print the current epoch, batch number, and loss every 500 batches
        if batch_counter % 500 == 0:
            print("Epoch {}, Batch {}/{}, loss: {}".format(epoch_id, batch_counter, batches_in_epoch, loss_cpu_number))

    # Print the average loss for the epoch
    print("Epoch average loss {}".format(loss_epoch_tracker / batch_counter))

# Set the model to test mode (optional, not used here)

Let us now extract the features from the trained network!

Again, please make it a habit to set the network into eval mode.

In [None]:
# DataLoader for the test dataset with a batch size of 50
test_loader = DataLoader(test_dset, batch_size=50, shuffle=False)  # Enable persistent_workers=True if more than 1 worker to save CPU

# Set the model to evaluation mode
mynet.eval()

# Initialize lists to store test embeddings and labels
test_embeddings = []
test_labels = []

# Disable gradient computation for inference
with torch.inference_mode():
    for data_batch in test_loader:
        test_img, test_label = data_batch  # Get images and labels from the batch
        batch_size = test_img.shape[0]  # Get the batch size
        flat = test_img.reshape(batch_size, -1).to(device, non_blocking=True)  # Flatten images and move to device
        pred_embeddings = mynet(flat).cpu().numpy().tolist()  # Get embeddings from the model and move to CPU
        test_embeddings.extend(pred_embeddings)  # Store the embeddings
        test_labels.extend(test_label.numpy().tolist())  # Store the labels

# Convert test labels to numpy array for further processing
test_labels = np.array(test_labels)

# Indicate that feature extraction is complete
print("Feature extraction done!")

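Since the network was trained with a cosine-similarity-based loss (the DCL variant of InfoNCE), only the direction of each embedding matters, so we first normalize each embedding to unit norm. PCA works on centered data (scikit-learn's `PCA` centers each feature internally); here we also standardize each embedding dimension to zero mean and unit standard deviation across the test set, so that no single dimension dominates.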

In [None]:
# Convert list of embeddings to a numpy array
test_embeddings = np.array(test_embeddings)

# Normalize the embeddings
test_embeddings_normed = test_embeddings / np.linalg.norm(test_embeddings, axis=1, keepdims=True)

# Center each embedding dimension (column) across the test set by subtracting its mean
test_embeddings_normed = test_embeddings_normed - np.mean(test_embeddings_normed, axis=0, keepdims=True)

# Standardize each embedding dimension to unit standard deviation
test_embeddings_normed = test_embeddings_normed / np.std(test_embeddings_normed, axis=0, keepdims=True)

# Initialize PCA with 2 components
pca = PCA(n_components=2)

# Fit PCA on the normalized, standardized embeddings and transform them to 2D
pca_embeddings = pca.fit_transform(test_embeddings_normed)

# Optional: Print the shape of the resulting PCA embeddings to verify
print("PCA embeddings shape:", pca_embeddings.shape)
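For intuition, here is a minimal NumPy sketch of what this PCA step computes, assuming standardization per embedding dimension (column): center and scale each column over all samples, then project onto the top two right singular vectors of the data matrix. The `pca_2d` helper is hypothetical; its output matches scikit-learn's `PCA(n_components=2).fit_transform` on the same standardized data up to the sign of each component.

```python
import numpy as np

def pca_2d(features):
    """Project (N, D) features onto their top-2 principal components via SVD.

    Each column is centered to zero mean and scaled to unit standard deviation
    first, i.e. statistics are taken over samples (axis=0).
    """
    centered = features - features.mean(axis=0, keepdims=True)
    standardized = centered / centered.std(axis=0, keepdims=True)
    # Rows of vt are the principal directions, sorted by decreasing singular value
    _, singular_values, vt = np.linalg.svd(standardized, full_matrices=False)
    return standardized @ vt[:2].T, singular_values

rng = np.random.default_rng(0)
toy = rng.normal(size=(200, 16))   # toy stand-in for 200 embeddings of dimension 16
projected, singular_values = pca_2d(toy)
print(projected.shape)             # (200, 2)
```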

For t-SNE, we simply normalize each feature to unit norm due to InfoNCE. We will not perform any additional centering.

In [None]:
# Convert list of embeddings to a numpy array
test_embeddings = np.array(test_embeddings)

# Normalize the embeddings to unit length by dividing each embedding by its L2 norm
test_embeddings_normed = test_embeddings / np.linalg.norm(test_embeddings, axis=1, keepdims=True)

# Initialize t-SNE with 2 components for dimensionality reduction
tsne = TSNE(n_components=2)

# Notify that the t-SNE transformation may take some time
print("t-SNE transformation in progress... This may take a minute. Go grab a coffee or something.")

# Fit t-SNE on the normalized embeddings and transform them to 2D
tsne_embeddings = tsne.fit_transform(test_embeddings_normed)

# Optional: Print the shape of the resulting t-SNE embeddings to verify
print("t-SNE embeddings shape:", tsne_embeddings.shape)

In [None]:
test_labels.shape, tsne_embeddings.shape

Observe the distribution of embeddings for the first four digits (0 to 3)! Notice how well separated the clusters for different digits are.

In [None]:
# Use t-SNE embeddings for visualization
my_embeddings = tsne_embeddings
# TSNE or PCA? TSNE is nicer to look at; set my_embeddings = pca_embeddings to compare.

# Plot embeddings for digits 0-3, one color per digit
for num, color in zip(range(4), ["red", "green", "blue", "orange"]):
    mask = test_labels == num
    plt.scatter(my_embeddings[mask, 0], my_embeddings[mask, 1], c=color, label=str(num))

plt.legend(title="Digit")
plt.show()

### Visualizing the cosine similarity after training

Observe that the diagonal elements are much higher than the off-diagonal elements. This indicates that embeddings within the same class are much more similar to each other than embeddings from different classes.

In [None]:
# Create DataLoader for the test dataset with a batch size of 50
test_loader = DataLoader(test_dset, batch_size=50, shuffle=False)  # Enable persistent_workers=True if more than 1 worker to save CPU

# Set the model to evaluation mode
mynet.eval()

# Initialize lists to store test embeddings and labels
test_embeddings = []
test_labels = []

# Initialize a similarity matrix of size 10x10 for 10 classes
sim_matrix = np.zeros((10, 10))

# Disable gradient computation for inference
with torch.inference_mode():
    for data_batch in test_loader:
        # Get images and labels from the batch
        test_img, test_label = data_batch
        batch_size = test_img.shape[0]  # Get the batch size

        # Flatten images and move data to the selected device
        flat = test_img.reshape(batch_size, -1).to(device, non_blocking=True)

        # Get embeddings from the model and move to CPU
        pred_embeddings = mynet(flat).cpu().numpy().tolist()

        # Store the embeddings and labels
        test_embeddings.extend(pred_embeddings)
        test_labels.extend(test_label.numpy().tolist())

# Convert embeddings and labels to numpy arrays for further processing
test_embeddings = np.array(test_embeddings)

# Normalize the embeddings to unit length by dividing each embedding by its L2 norm
test_embeddings_normed = test_embeddings / np.linalg.norm(test_embeddings, axis=1, keepdims=True)

# Convert test labels to a numpy array
test_labels = np.array(test_labels)

# Dictionary to store normalized embeddings for each class
embeddings = {}
for i in range(10):
    embeddings[i] = test_embeddings_normed[test_labels == i]

# Calculate within-class cosine similarity
for i in range(10):
    # Compute cosine similarity matrix within the class
    sims = embeddings[i] @ embeddings[i].T

    # Ignore diagonal values (self-similarity)
    np.fill_diagonal(sims, np.nan)

    # Calculate the mean similarity excluding diagonal
    cur_sim = np.nanmean(sims)

    # Store the within-class similarity in the matrix
    sim_matrix[i, i] = cur_sim

    # Print the within-class cosine similarity
    print("Within class {} cosine similarity {}".format(i, cur_sim))

print("==================")

# Calculate between-class cosine similarity (upper triangle only, since the matrix is symmetric)
for i in range(10):
    for j in range(i + 1, 10):
        # Compute cosine similarity between all embeddings of classes i and j
        sims = embeddings[i] @ embeddings[j].T

        # Calculate the mean similarity
        cur_sim = np.mean(sims)

        # Store the similarity in both halves of the matrix to keep it symmetric
        sim_matrix[i, j] = cur_sim
        sim_matrix[j, i] = cur_sim

        # Print the between-class cosine similarity
        print("{} and {} cosine similarity {}".format(i, j, cur_sim))

# Plot the similarity matrix using matplotlib
plt.imshow(sim_matrix, vmin=0.0, vmax=1.0)
plt.title("Trained Network Cosine Similarity Matrix")
plt.colorbar()
plt.show()
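The two loops above can also be written in vectorized form. The sketch below is our own illustration, not the tutorial's code: `class_similarity_matrix` is a hypothetical name, and it assumes integer labels and at least two samples per class. It sums all pairwise cosine similarities between classes using a one-hot label matrix, then removes the self-pairs (each with similarity 1) from the within-class entries, which matches what the loop version computes with `np.nanmean`.

```python
import numpy as np


def class_similarity_matrix(embeddings, labels, num_classes):
    """Mean cosine similarity between every pair of classes (hypothetical helper).

    Diagonal entries exclude self-similarity, mirroring the loop-based version.
    Assumes integer labels in [0, num_classes) and at least two samples per class.
    """
    # Normalize each embedding to unit L2 norm so dot products are cosine similarities
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

    # One-hot label matrix of shape (num_samples, num_classes)
    one_hot = np.eye(num_classes)[labels]

    # Sum of normalized embeddings per class, shape (num_classes, dim)
    class_sums = one_hot.T @ normed

    # pair_sums[i, j] = sum of cosine similarities over all pairs (a in class i, b in class j)
    pair_sums = class_sums @ class_sums.T

    # Number of such pairs
    counts = one_hot.sum(axis=0)
    pair_counts = np.outer(counts, counts)

    # Remove the self-pairs (similarity 1 each) from the within-class entries
    pair_sums[np.diag_indices(num_classes)] -= counts
    pair_counts[np.diag_indices(num_classes)] -= counts

    return pair_sums / pair_counts


# Toy usage on random data: 4 classes with 10 samples each
rng = np.random.default_rng(0)
demo_embeddings = rng.normal(size=(40, 8))
demo_labels = np.repeat(np.arange(4), 10)
print(class_similarity_matrix(demo_embeddings, demo_labels, num_classes=4).round(3))

# With the tutorial's variables: class_similarity_matrix(test_embeddings, test_labels, 10)
```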

### Using the network to identify nearest neighbors in the test set

How do people actually use a contrastive learning network? In person re-identification (Person Re-ID), a network computes embeddings for two images and decides that they depict the same person if the cosine similarity between the embeddings exceeds a threshold (or, if Euclidean distance is used instead, if the distance falls below a threshold).
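A minimal NumPy sketch of this threshold test, under stated assumptions: the helper names `same_identity` and `same_identity_euclidean`, the toy embeddings, and the threshold of 0.8 are all hypothetical; in practice the threshold would be tuned on validation data. It also illustrates why the two criteria agree for unit-length embeddings: the squared Euclidean distance equals 2 minus twice the cosine similarity, so a cosine threshold $t$ corresponds to a distance threshold $\sqrt{2 - 2t}$.

```python
import numpy as np


def normalize(v):
    """Scale a vector to unit L2 norm."""
    return v / np.linalg.norm(v)


def same_identity(emb_a, emb_b, threshold=0.8):
    """Same-identity decision via cosine similarity (hypothetical helper, illustrative threshold)."""
    cos_sim = float(np.dot(normalize(emb_a), normalize(emb_b)))
    return cos_sim > threshold


def same_identity_euclidean(emb_a, emb_b, threshold=0.8):
    """Equivalent decision using Euclidean distance between unit-normalized embeddings.

    For unit vectors ||a - b||^2 = 2 - 2 * cos(a, b), so cos > t  <=>  dist < sqrt(2 - 2t).
    """
    dist = float(np.linalg.norm(normalize(emb_a) - normalize(emb_b)))
    return dist < np.sqrt(2.0 - 2.0 * threshold)


# Toy embeddings: a query, a slightly perturbed copy of it, and an unrelated vector
rng = np.random.default_rng(0)
query = rng.normal(size=128)
same_person = query + 0.1 * rng.normal(size=128)
other_person = rng.normal(size=128)

print("perturbed copy:", same_identity(query, same_person))
print("unrelated vector:", same_identity(query, other_person))
```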

For foundation models such as CLIP, the typical way to adapt them to a downstream task is to fine-tune the entire network, or to keep it frozen and train a linear probe or small network on the outputs of its last layer.
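To make the linear-probe idea concrete, here is a minimal NumPy sketch that assumes the encoder is frozen and its last-layer embeddings are precomputed. The names `train_linear_probe` and `predict`, the synthetic clustered "embeddings", and the hyperparameters (learning rate, number of steps) are illustrative assumptions. The probe is a single softmax-regression layer trained with full-batch gradient descent on the cross-entropy loss; only the probe's weights are updated, never the encoder.

```python
import numpy as np


def train_linear_probe(features, labels, num_classes, lr=0.5, steps=300):
    """Fit a softmax-regression (linear) probe on frozen features with gradient descent."""
    n, d = features.shape
    W = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    targets = np.eye(num_classes)[labels]  # one-hot labels, shape (n, num_classes)

    for _ in range(steps):
        logits = features @ W + b
        logits -= logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)

        # Gradient of the mean cross-entropy loss with respect to the logits
        grad_logits = (probs - targets) / n
        W -= lr * features.T @ grad_logits
        b -= lr * grad_logits.sum(axis=0)
    return W, b


def predict(features, W, b):
    """Predicted class = index of the largest logit."""
    return np.argmax(features @ W + b, axis=1)


# Synthetic stand-in for frozen embeddings: 3 clusters in 16-D, unit-normalized
rng = np.random.default_rng(0)
centers = rng.normal(size=(3, 16))
labels = np.repeat(np.arange(3), 50)
features = centers[labels] + 0.3 * rng.normal(size=(150, 16))
features /= np.linalg.norm(features, axis=1, keepdims=True)

W, b = train_linear_probe(features, labels, num_classes=3)
accuracy = np.mean(predict(features, W, b) == labels)
print(f"Linear probe training accuracy: {accuracy:.2f}")
```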

Here, we will follow the Person Re-ID setup: for a query image, we find the most similar image in the test set and check whether both show the same digit.

In [None]:
# Calculate the cosine similarity matrix for all (unit-normalized) embeddings
sims_all = test_embeddings_normed @ test_embeddings_normed.T

# Set diagonal elements to a large negative value so that argmax never returns the query itself
np.fill_diagonal(sims_all, -1000.0)

# Index of the query image to check
idx_to_check = 3029

# Find the index of the most similar embedding (excluding the query itself)
best_idx = np.argmax(sims_all[idx_to_check])

# Check whether the query and its nearest neighbor show the same digit
print("Query label:", test_labels[idx_to_check], "| Nearest neighbor label:", test_labels[best_idx])

# Plot the query image
plt.imshow(test_dset[idx_to_check][0][0].cpu().numpy())
plt.title("Query image")
plt.show()

# Plot the image corresponding to the most similar embedding
plt.imshow(test_dset[best_idx][0][0].cpu().numpy())
plt.title("Nearest neighbor")
plt.show()
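The cell above checks a single query by eye. A natural extension, sketched here as our own assumption rather than part of the original tutorial, is to score every test image the same way: the leave-one-out 1-nearest-neighbor accuracy is the fraction of images whose most similar other image (by cosine similarity) has the same label. The function name `nearest_neighbor_accuracy` and its `chunk_size` parameter are hypothetical. Processing queries in chunks avoids holding the full N x N similarity matrix in memory; for 10,000 test images, that matrix alone takes about 800 MB in float64.

```python
import numpy as np


def nearest_neighbor_accuracy(embeddings, labels, chunk_size=1000):
    """Leave-one-out 1-NN accuracy under cosine similarity (hypothetical helper)."""
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    n = normed.shape[0]
    correct = 0
    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        # Cosine similarities between this chunk of queries and all embeddings
        sims = normed[start:stop] @ normed.T
        # Exclude each query's similarity with itself
        rows = np.arange(stop - start)
        sims[rows, start + rows] = -np.inf
        # Count queries whose nearest neighbor shares their label
        best = np.argmax(sims, axis=1)
        correct += np.sum(labels[best] == labels[start:stop])
    return correct / n


# Toy example: two tight clusters pointing in opposite directions
rng = np.random.default_rng(0)
toy_labels = np.repeat([0, 1], 20)
directions = np.array([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]])
toy_embeddings = directions[toy_labels] + 0.1 * rng.normal(size=(40, 3))
print(nearest_neighbor_accuracy(toy_embeddings, toy_labels, chunk_size=16))

# With the tutorial's variables: nearest_neighbor_accuracy(test_embeddings, test_labels)
```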

### How is contrastive learning used in practice?

Many vision foundation models, such as DINO, DINOv2, CLIP, and their derivatives (including OpenCLIP and EVA-CLIP), are trained with contrastive or closely related self-supervised losses. DINO and DINOv2 are trained solely on images, while CLIP is trained on paired images and text.

When only images are used, the contrastive loss is applied to augmentations of the same image, such as crops, flips, and rotations; this approach is referred to as a "pretext task." Augmented views of the same image are typically treated as positive pairs whose embeddings should be the same, or at least highly similar. For example, a network should recognize a photo of you, and the same photo flipped, with altered brightness, with added noise, or converted to black and white, as showing the same person.
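A minimal NumPy sketch of how two augmented "views" of one image could be produced for such a pretext task. The specific augmentations (random crop after zero padding, horizontal flip, brightness scaling, additive Gaussian noise), their parameter ranges, and the names `augment` and `make_positive_pair` are our own illustrative choices; libraries such as torchvision provide ready-made image transforms for real training. The two returned views form a positive pair whose embeddings the contrastive loss pulls together.

```python
import numpy as np


def augment(image, rng, pad=2):
    """Return one randomly augmented copy of a 2-D grayscale image with values in [0, 1]."""
    h, w = image.shape

    # Random crop: zero-pad the borders, then cut out an h x w window at a random offset
    padded = np.pad(image, pad)
    top = rng.integers(0, 2 * pad + 1)
    left = rng.integers(0, 2 * pad + 1)
    view = padded[top:top + h, left:left + w]

    # Random horizontal flip with probability 0.5
    if rng.random() < 0.5:
        view = view[:, ::-1]

    # Random brightness change and additive Gaussian noise, then clip back to [0, 1]
    view = view * rng.uniform(0.7, 1.3)
    view = view + rng.normal(scale=0.05, size=view.shape)
    return np.clip(view, 0.0, 1.0)


def make_positive_pair(image, rng):
    """Two independent augmentations of the same image form a positive pair."""
    return augment(image, rng), augment(image, rng)


rng = np.random.default_rng(0)
image = rng.random((28, 28))  # stand-in for a real 28 x 28 grayscale image
view_a, view_b = make_positive_pair(image, rng)
print(view_a.shape, view_b.shape)
```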

When images and text are used together, as in CLIP, the training data consists of images paired with their captions. For example, the caption "A photo of a dog" might be paired with a picture of a blue heeler puppy. These captions are typically scraped from the web; datasets such as LAION-2B and COYO-700M were built this way from Common Crawl data. Although caption quality varies widely, the sheer volume of data helps to mitigate this issue.

In this case, contrastive learning typically uses a dual-encoder system: one encoder for text and one for images. The network is trained with a loss that pulls the embeddings of matching text-image pairs together while pushing those of mismatched pairs apart. For example, the embedding of the caption "A photo of a dog" should be close to that of the blue heeler puppy image and far from that of a cat image. To measure the "distance" between embeddings, common choices include the normalized dot product (cosine similarity), angular distance (as in the Universal Sentence Encoder), Euclidean distance, and squared Euclidean distance.
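The dual-encoder objective can be sketched as a symmetric cross-entropy over a batch of image-text pairs, in the spirit of CLIP. This NumPy version is a minimal sketch under stated assumptions: the two encoders are replaced by random embedding matrices, `clip_style_loss` is a hypothetical name, and the temperature is fixed at 0.07 (CLIP instead learns it as a parameter). Row i of the image batch and row i of the text batch form the matching pair, so the correct "class" for each row and column of the similarity matrix is its diagonal entry, and every other entry acts as a negative.

```python
import numpy as np


def log_softmax(x, axis):
    """Numerically stable log-softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def clip_style_loss(image_emb, text_emb, temperature=0.07):
    """Symmetric contrastive (InfoNCE-style) loss for a batch of matching image-text pairs."""
    # L2-normalize so that dot products are cosine similarities
    image_emb = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    text_emb = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)

    # logits[i, j] = cos(image_i, text_j) / temperature
    logits = image_emb @ text_emb.T / temperature
    diag = np.arange(logits.shape[0])

    # Image-to-text: each image should pick its own caption among all captions in the batch
    loss_i2t = -log_softmax(logits, axis=1)[diag, diag].mean()
    # Text-to-image: each caption should pick its own image among all images in the batch
    loss_t2i = -log_softmax(logits, axis=0)[diag, diag].mean()
    return (loss_i2t + loss_t2i) / 2


rng = np.random.default_rng(0)
images = rng.normal(size=(8, 32))
aligned_texts = images + 0.1 * rng.normal(size=(8, 32))  # captions that match their images
random_texts = rng.normal(size=(8, 32))                  # captions unrelated to the images

print("aligned pairs:", clip_style_loss(images, aligned_texts))
print("random pairs: ", clip_style_loss(images, random_texts))
```

With uninformative embeddings the loss sits near log(batch size); it drops toward zero as matching pairs become much more similar than mismatched ones.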

### References

[1] Unsupervised feature learning via non-parametric instance discrimination (2018)

[2] Representation learning with contrastive predictive coding (2018)

[3] A simple framework for contrastive learning of visual representations (2020)

[4] Improved Deep Metric Learning with Multi-class N-pair Loss Objective (2016)

[5] Noise-contrastive estimation: A new estimation principle for unnormalized statistical models (2010)