# Prerequisite

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.preprocessing import LabelEncoder
from tqdm.auto import tqdm

In [2]:
# Set random seed for reproducibility
torch.manual_seed(42)

# Hyperparameters
batch_size = 64
temperature = 0.5
num_pretrain_epochs = 40
num_downstream_epochs = 10
encoder_dim = 512
projection_dim = 128

In [3]:
# Transformations for data augmentation
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
])

augmentation_transform = transforms.Compose([
    transforms.RandomApply([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
        ], p=0.5),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=5)], p=0.2),
    transforms.RandomRotation(degrees=90),
])

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dataset

In [5]:
# Load MNIST dataset
mnist_train = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
mnist_test = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Dataloader

In [6]:
# Create a DataLoader for training and testing
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=2)

# Network

In [11]:
# Encoder network (small and simple for demonstration)
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv_encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc = nn.Linear(128 * 3 * 3, encoder_dim)

        self.projection_head = nn.Sequential(
            nn.Linear(encoder_dim, 2*projection_dim),
            nn.ReLU(),
            nn.Linear(2*projection_dim, projection_dim)
        )


    def forward(self, x, project_head = False):
        x = self.conv_encoder(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        if project_head:
            x = self.projection_head(x)
        return x

if True:
    print('testing:')
    m = Encoder()
    x = torch.randn((10, 1, 28, 28))
    print(m(x).shape)

testing:
torch.Size([10, 512])


In [12]:
class SimCLRLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super(SimCLRLoss, self).__init__()
        self.temperature = temperature

    def forward(self, z1, z2):
        """
        SimCLR loss function.

        Parameters:
            z1 (torch.Tensor): Embeddings of the first set of augmented samples.
            z2 (torch.Tensor): Embeddings of the second set of augmented samples.

        Returns:
            torch.Tensor: SimCLR loss.
        """
        # Normalize the embeddings
        z1 = F.normalize(z1, dim=-1, p=2)
        z2 = F.normalize(z2, dim=-1, p=2)

        # Compute cosine similarity
        similarity_matrix = torch.matmul(z1, z2.t()) / self.temperature

        # Diagonal elements represent positive pairs, off-diagonal elements are negative pairs
        # For each row, the diagonal element is treated as a positive sample, and the rest are treated as negative samples
        positive_pairs = torch.diag(similarity_matrix, 0)
        negative_pairs = torch.exp(similarity_matrix - torch.eye(similarity_matrix.size(0), device=z1.device) * 1e9)

        # Compute loss
        numerator = -torch.log(torch.exp(positive_pairs / self.temperature) / torch.sum(negative_pairs, dim=1))
        loss = torch.mean(numerator)

        return loss

if True:
    m = SimCLRLoss()
    x1 = torch.randn((10, 128))
    x2 = torch.randn((10, 128))
    print(m(x1, x1))


tensor(-1.8336)


# Training

In [13]:
encoder = Encoder().to(device)
loss_fn = SimCLRLoss().to(device)
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

In [14]:
for epoch_num in range(num_pretrain_epochs):
    progress_bar = tqdm(total=len(train_loader), desc=f'Training Progress Epoch {epoch_num}/{num_pretrain_epochs}')
    total_loss = 0
    for iter_num, (imgs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        imgs = imgs.to(device)
        imgs_repr = encoder(imgs, project_head=True)
        imgs_augmented_repr = encoder(augmentation_transform(imgs), project_head=True)
        loss = loss_fn(imgs_repr, imgs_augmented_repr)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item(), avg_loss=(total_loss/(iter_num+1)))
        progress_bar.update(1)

Training Progress Epoch 0/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 1/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 2/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 3/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 4/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 5/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 6/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 7/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 8/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 9/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 10/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 11/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 12/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 13/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 14/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 15/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 16/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 17/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 18/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 19/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 20/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 21/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 22/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 23/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 24/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 25/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 26/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 27/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 28/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 29/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 30/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 31/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 32/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 33/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 34/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 35/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 36/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 37/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 38/40:   0%|          | 0/938 [00:00<?, ?it/s]

Training Progress Epoch 39/40:   0%|          | 0/938 [00:00<?, ?it/s]

# Evaluation

In [15]:
from torch.utils.data import DataLoader, Subset, random_split

class_indices = [[] for _ in range(10)]  # 10 classes in MNIST
shuffled_mnist_dataset = torch.utils.data.SubsetRandomSampler(range(len(mnist_train)))

for i, (img, label) in enumerate(mnist_train):
    class_indices[label].append(i)

In [16]:
# Step 3: Select a small number of samples from each class
few_shot_samples_per_class = 100
few_shot_indices = []
for indices in class_indices:
    few_shot_indices.extend(indices[:few_shot_samples_per_class])

# Step 4: Create a few-shot dataset
few_shot_dataset = torch.utils.data.Subset(mnist_train, few_shot_indices)
few_shot_dataloader = DataLoader(few_shot_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

In [17]:
classifier = nn.Sequential(encoder, nn.Linear(encoder_dim, 10)).to(device)
for param in encoder.parameters():
    param.requires_grad = False
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

In [18]:
for epoch in range(num_downstream_epochs):
    classifier.train()  # Set the network to training mode
    total_loss = 0
    for images, labels in tqdm(few_shot_dataloader):
        images, labels = images.to(device), labels.to(device)
        # Forward pass
        optimizer.zero_grad()
        outputs = classifier(images)
        loss = loss_fn(outputs, labels)
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch {epoch + 1}/{num_downstream_epochs}, Loss: {total_loss/len(few_shot_dataloader):.4f}')

  0%|          | 0/16 [00:00<?, ?it/s]

Epoch 1/10, Loss: 18.4426


  0%|          | 0/16 [00:00<?, ?it/s]

Epoch 2/10, Loss: 7.4948


  0%|          | 0/16 [00:00<?, ?it/s]

Epoch 3/10, Loss: 4.7874


  0%|          | 0/16 [00:00<?, ?it/s]

Epoch 4/10, Loss: 3.5882


  0%|          | 0/16 [00:00<?, ?it/s]

Epoch 5/10, Loss: 3.6899


  0%|          | 0/16 [00:00<?, ?it/s]

Epoch 6/10, Loss: 3.8976


  0%|          | 0/16 [00:00<?, ?it/s]

Epoch 7/10, Loss: 3.3727


  0%|          | 0/16 [00:00<?, ?it/s]

Epoch 8/10, Loss: 2.8603


  0%|          | 0/16 [00:00<?, ?it/s]

Epoch 9/10, Loss: 2.8167


  0%|          | 0/16 [00:00<?, ?it/s]

Epoch 10/10, Loss: 3.1413


In [19]:
with torch.no_grad():
    total_samples, total_correct = 0, 0
    classifier.eval()
    for images, labels in tqdm(test_loader):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = classifier(images)

        # Calculate accuracy
        _, predicted = torch.max(outputs, 1)
        total_samples += labels.size(0)
        total_correct += (predicted == labels).sum().item()

# Calculate accuracy and average loss
accuracy = (total_correct / total_samples) * 100.0

print(f'Test Accuracy: {accuracy:.2f}%')

  0%|          | 0/157 [00:00<?, ?it/s]

Test Accuracy: 51.03%
