<a href="https://colab.research.google.com/github/dhruvi-05/contrastive_learning/blob/main/contrastive_learning_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models, datasets
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import numpy as np
from google.colab import files

In [None]:
from google.colab import files

# Upload the dataset archive from the local machine via google.colab's
# `files.upload()`; its return value is kept in `uploaded` but not used below.
uploaded = files.upload()

# Extract the archive into /content/images/ (`-o` overwrites existing files
# without prompting). The archive name is hard-coded, so the uploaded file is
# expected to be saved as images.zip in the current working directory.
!unzip -o images.zip -d /content/images/

Saving images.zip to images.zip
Archive:  images.zip
   creating: /content/images/images/
  inflating: /content/images/images/1.JPG  
  inflating: /content/images/images/10.JPG  
  inflating: /content/images/images/100.JPG  
  inflating: /content/images/images/1000.JPG  
  inflating: /content/images/images/1001.JPG  
  inflating: /content/images/images/1002.JPG  
  inflating: /content/images/images/1003.JPG  
  inflating: /content/images/images/1004.JPG  
  inflating: /content/images/images/1005.JPG  
  inflating: /content/images/images/1006.JPG  
  inflating: /content/images/images/1007.JPG  
  inflating: /content/images/images/1008.JPG  
  inflating: /content/images/images/1009.JPG  
  inflating: /content/images/images/101.JPG  
  inflating: /content/images/images/1010.JPG  
  inflating: /content/images/images/1011.JPG  
  inflating: /content/images/images/1012.JPG  
  inflating: /content/images/images/1013.JPG  
  inflating: /content/images/images/1014.JPG  
  inflating: /content/im

In [None]:
# Define the augmentation for SimCLR
class SimCLRTransform:
    """Callable that turns one image into two independently augmented views.

    The pipeline is, in order: random resized crop, random horizontal flip,
    colour jitter (brightness/contrast/saturation 0.4, hue 0.1, applied with
    probability 0.8), random grayscale (probability 0.2), Gaussian blur,
    conversion to a tensor and per-channel normalisation.

    Args:
        size: Output size passed to ``RandomResizedCrop`` (an int, or a
            sequence of ints). Defaults to 224.
        blur_kernel_size: Kernel size passed to ``GaussianBlur``; must be a
            positive odd integer. When ``None`` it is derived as roughly 10%
            of the shortest side of ``size``, forced to be odd, which gives
            23 for the default size of 224.
    """

    def __init__(self, size=224, blur_kernel_size=None):
        if blur_kernel_size is None:
            shortest_side = size if isinstance(size, int) else min(size)
            # `| 1` sets the lowest bit, so the result is always odd and >= 1.
            blur_kernel_size = int(0.1 * shortest_side) | 1

        # NOTE(review): the blur is applied to every view here; the SimCLR
        # paper applies it with probability 0.5 - confirm this is intended.
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=blur_kernel_size, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            # These values match the commonly used ImageNet channel mean/std.
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, x):
        """Return two views of ``x``; each call to the pipeline draws its own
        random augmentation parameters, so the two views differ."""
        return self.transform(x), self.transform(x)

In [None]:
# Builds the positive pairs: two augmented versions of every dataset image.
class ContrastiveLearningDataset(Dataset):
    """Dataset yielding a pair of views for each image under ``root_dir``.

    Images are read through ``datasets.ImageFolder``; the class label that
    ``ImageFolder`` attaches to each image is discarded.

    Args:
        root_dir: Directory handed to ``datasets.ImageFolder`` as ``root``.
        transform: Optional callable that maps one image to a 2-tuple of
            views. When it is falsy, the unmodified image is returned twice.
    """

    def __init__(self, root_dir, transform=None):
        self.dataset = datasets.ImageFolder(root=root_dir)
        self.transform = transform

    def __len__(self):
        """Number of images found by the underlying ``ImageFolder``."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the two views for the image at position ``idx``."""
        image, _label = self.dataset[idx]

        # No transform configured: hand back the same image as both views.
        if not self.transform:
            return image, image

        first_view, second_view = self.transform(image)
        return first_view, second_view



In [None]:
# Parameters
batch_size = 64
dataset_path = '/content/images'  # Update this path if different

# Initialize the dataset and dataloader
simclr_transform = SimCLRTransform()
dataset = ContrastiveLearningDataset(root_dir=dataset_path, transform=simclr_transform)

# With drop_last=True the loader discards an incomplete final batch, so a
# dataset smaller than one batch would produce no batches at all. Fail here
# with a clear message instead of later, when averaging over zero batches.
assert len(dataset) >= batch_size, (
    f'Found {len(dataset)} images under {dataset_path!r}, '
    f'but at least batch_size={batch_size} are needed to form one batch.'
)

# shuffle=True reshuffles every epoch; drop_last=True keeps every batch at
# exactly `batch_size` samples.
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)




In [None]:
class NTXentLoss(nn.Module):
    """Normalized temperature-scaled cross-entropy (NT-Xent) contrastive loss.

    Given two batches of projections where row ``k`` of ``zi`` and row ``k``
    of ``zj`` come from the same source image, every row of the stacked batch
    is trained to pick out its partner among all other rows, using cosine
    similarity divided by ``temperature`` as the logit.

    Args:
        batch_size: Nominal number of pairs per batch. Kept for interface
            compatibility and as the default size for
            ``_get_correlated_mask``; ``forward`` uses the actual size of its
            inputs.
        temperature: Divisor applied to the cosine similarities.
        device: Default device for ``_get_correlated_mask``. ``forward`` works
            on whatever device its inputs live on, so nothing is allocated on
            ``device`` at construction time.
    """

    def __init__(self, batch_size, temperature=0.5, device='cuda'):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device

    def _get_correlated_mask(self, n=None, device=None):
        """Return an ``(n, n)`` float mask: 0 on the diagonal, 1 elsewhere.

        ``n`` defaults to ``2 * self.batch_size`` and ``device`` to
        ``self.device``.
        """
        if n is None:
            n = 2 * self.batch_size
        if device is None:
            device = self.device
        return torch.ones((n, n), device=device) - torch.eye(n, device=device)

    def forward(self, zi, zj):
        """Compute the mean NT-Xent loss over all ``2 * B`` rows.

        Args:
            zi: Tensor of shape ``(B, D)`` holding the first view's projections.
            zj: Tensor of shape ``(B, D)`` holding the second view's projections.

        Returns:
            Scalar tensor with the loss averaged over the ``2 * B`` rows.

        Raises:
            ValueError: If ``zi`` and ``zj`` do not have the same shape.
        """
        if zi.shape != zj.shape:
            raise ValueError(
                f'zi and zj must have the same shape, got {tuple(zi.shape)} and {tuple(zj.shape)}'
            )

        # Use the real batch size so a batch of any size is handled.
        batch_size = zi.shape[0]
        n = 2 * batch_size

        # Rows are unit-normalised, so the Gram matrix is the pairwise cosine
        # similarity; this avoids an (n, n, D) broadcast intermediate.
        representations = F.normalize(torch.cat([zi, zj], dim=0), dim=1)
        logits = representations @ representations.t() / self.temperature

        # Exclude each row's similarity with itself from the softmax.
        not_self = self._get_correlated_mask(n, logits.device).type(torch.bool)
        logits = logits.masked_fill(~not_self, float('-inf'))

        # Row k's positive is row k + B (first half) or row k - B (second half).
        targets = (torch.arange(n, device=logits.device) + batch_size) % n

        # cross_entropy evaluates -log(exp(pos) / sum(exp(row))) through
        # log-softmax, which stays finite even for very small temperatures.
        return F.cross_entropy(logits, targets)


In [None]:
class SimCLR(nn.Module):
    """Backbone encoder followed by a two-layer MLP projection head.

    Args:
        base_encoder: Callable that builds the backbone; it is invoked as
            ``base_encoder(pretrained=True)``. Its last child module is
            dropped, so it is expected to end in a classification layer
            (as torchvision ResNets do with ``fc``).
        projection_dim: Output width of the projection head.
    """

    def __init__(self, base_encoder, projection_dim=128):
        super(SimCLR, self).__init__()
        backbone = base_encoder(pretrained=True)

        # Take the feature width from the classifier that is about to be
        # removed; fall back to 512 when the backbone has no `fc` layer.
        classifier = getattr(backbone, 'fc', None)
        feature_dim = getattr(classifier, 'in_features', 512)

        # Remove the final fully connected layer
        self.encoder = nn.Sequential(*list(backbone.children())[:-1])

        # Projection head
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, projection_dim)
        )

    def forward(self, x):
        """Encode a batch ``x`` and return projections of shape ``(B, projection_dim)``."""
        h = self.encoder(x)
        # Flatten everything except the batch dimension. A bare squeeze()
        # would also drop the batch dimension when the batch holds one sample.
        h = torch.flatten(h, start_dim=1)
        z = self.projection_head(h)
        return z


In [None]:
# Select the device: CUDA when PyTorch reports it as available, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

# Build the SimCLR model on a ResNet-18 backbone with a 128-dimensional
# projection and place it on the selected device.
model = SimCLR(base_encoder=models.resnet18, projection_dim=128).to(device)

# Contrastive loss, configured with the same batch size as the data loader.
criterion = NTXentLoss(batch_size=batch_size, temperature=0.5, device=device)

# Adam over every parameter of the model (encoder and projection head).
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)


Using device: cpu


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 63.7MB/s]


In [None]:
# Number of epochs
epochs = 10

# Train for `epochs` passes over `train_loader`, printing the mean batch loss
# after each pass.
for epoch in range(1, epochs + 1):
    model.train()
    # Running sum of the per-batch loss values for this epoch.
    total_loss = 0

    # Each batch is a pair (xi, xj): two augmented views of the same images.
    for step, (xi, xj) in enumerate(train_loader):
        xi = xi.to(device)
        xj = xj.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Get projections (each view goes through the model in its own pass)
        zi = model(xi)
        zj = model(xj)

        # Compute loss
        loss = criterion(zi, zj)

        # Backpropagation
        loss.backward()
        optimizer.step()

        # .item() extracts a Python number, so no autograd graph is retained.
        total_loss += loss.item()

    # Mean loss per batch; this division raises ZeroDivisionError if the
    # loader produced no batches.
    avg_loss = total_loss / len(train_loader)
    print(f'Epoch [{epoch}/{epochs}], Loss: {avg_loss:.4f}')


Epoch [1/10], Loss: 4.0980
Epoch [2/10], Loss: 3.8101
Epoch [3/10], Loss: 3.6511
Epoch [4/10], Loss: 3.5984
Epoch [5/10], Loss: 3.5636
Epoch [6/10], Loss: 3.5303
Epoch [7/10], Loss: 3.4991


KeyboardInterrupt: 