
## Section 1: Problem Formulation

### Introduction
In this project, we evaluate how well contrastive learning, specifically the SimCLR algorithm, works for image classification when only a limited amount of labeled data is available. Contrastive learning is a self-supervised technique that learns a feature space in which similar items (for example, two augmented views of the same image) are encoded close together, while dissimilar items are pushed further apart.
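In SimCLR, this idea is formalized by the NT-Xent (normalized temperature-scaled cross-entropy) loss. For a batch of $N$ images augmented into $2N$ views with projected embeddings $z_1, \dots, z_{2N}$, the loss for a positive pair $(i, j)$ (two views of the same image) is

$$
\ell_{i,j} = -\log \frac{\exp\big(\mathrm{sim}(z_i, z_j)/\tau\big)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\big(\mathrm{sim}(z_i, z_k)/\tau\big)},
\qquad
\mathrm{sim}(u, v) = \frac{u^\top v}{\lVert u \rVert \, \lVert v \rVert},
$$

and the total loss averages $\ell_{i,j}$ over all $2N$ ordered positive pairs. Every other view in the batch acts as a negative for $z_i$. The temperature $\tau$ presumably corresponds to `args["temperature"]` in the configuration below.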


In [3]:
# Import necessary libraries
import torch
import torchvision
from torch import nn

# Check if CUDA is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
args = {
    "dataset": "cifar10",
    "model": "resnet18",
    "batch_size": 1024,
    "epochs": 100,
    "n_views": 2,
    "out_dim": 128,
    "lr": 12e-4,
    "log_every_n_steps": 50,
    "n_workers": 12,
    "temperature": 0.5,
}
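A few quantities follow directly from this configuration and are useful for sanity-checking the training log. The sketch below assumes CIFAR-10's 50,000-image training split and that `info_nce_loss` follows the common SimCLR construction in which each anchor is scored against every other embedding in the batch (one positive, the rest negatives). The resulting 48 steps per epoch match the 48 iterations shown in the training progress bar.

```python
import math

# Values copied from the args dictionary above
batch_size = 1024
n_views = 2
cifar10_train_size = 50_000  # size of the CIFAR-10 training split

# drop_last=True in the DataLoader discards the final partial batch
steps_per_epoch = cifar10_train_size // batch_size

# Each step embeds n_views * batch_size images; every anchor is compared
# against all other embeddings in the batch (itself excluded)
embeddings_per_step = n_views * batch_size
candidates_per_anchor = embeddings_per_step - 1
positives_per_anchor = n_views - 1
negatives_per_anchor = candidates_per_anchor - positives_per_anchor

# Cross-entropy of a uniform guess over all candidates: a "chance-level" loss
chance_level_loss = math.log(candidates_per_anchor)

print(steps_per_epoch, candidates_per_anchor, negatives_per_anchor,
      round(chance_level_loss, 3))  # 48 2047 2046 7.624
```

A contrastive loss close to `chance_level_loss` means the model is barely distinguishing positives from negatives; the loss should fall well below it as training progresses.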


## Section 2: Dataset Preparations

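In this section, we prepare the CIFAR-10, CIFAR-100, and MedMNIST datasets for training; the dataset used in a given run is selected by `args["dataset"]` (CIFAR-10 here). We apply the necessary augmentations and split the data into training, validation, and test sets. The cell below builds only the training loader used for contrastive pretraining, in which each image is returned as `n_views` augmented views.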


In [5]:
from dataset import SimCLRDataset
data = SimCLRDataset(args["dataset"])
train_dataset = data.get_dataset(n_views=args["n_views"])
train_loader = torch.utils.data.DataLoader(
    train_dataset, 
    batch_size=args["batch_size"], 
    num_workers=args["n_workers"],
    shuffle=True, 
    drop_last=True, 
    pin_memory=True,
)
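`SimCLRDataset` comes from the project's own `dataset` module, so its internals are not shown. Given how it is used (`get_dataset(n_views=...)` here, and `torch.cat(images, dim=0)` in the training loop), each sample presumably yields a list of `n_views` independently augmented copies of one image. A minimal sketch of that idea, using the hypothetical names `MultiViewTransform` and `simclr_cifar_transform`: the augmentation list follows the SimCLR recipe (random resized crop, horizontal flip, color jitter, random grayscale), while the jitter strength `s=0.5` and the omission of Gaussian blur are common choices for 32×32 CIFAR images, not settings confirmed for this project.

```python
import random
from typing import Callable, List


class MultiViewTransform:
    """Hypothetical helper: apply one random augmentation pipeline n_views times."""

    def __init__(self, base_transform: Callable, n_views: int = 2):
        self.base_transform = base_transform
        self.n_views = n_views

    def __call__(self, x) -> List:
        # Each call re-samples the random augmentation parameters, so the views differ
        return [self.base_transform(x) for _ in range(self.n_views)]


def simclr_cifar_transform(size: int = 32, s: float = 0.5):
    """SimCLR-style augmentations for CIFAR-sized images (assumed settings)."""
    # Imported here so MultiViewTransform can be used without torchvision installed
    from torchvision import transforms

    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    return transforms.Compose([
        transforms.RandomResizedCrop(size=size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([color_jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
    ])


# Toy check with a non-image "augmentation": each view gets its own random offset
two_views = MultiViewTransform(lambda x: x + random.random(), n_views=2)
print(len(two_views(1.0)))  # 2
```

Passing `MultiViewTransform(simclr_cifar_transform(), n_views=2)` as the `transform` of `torchvision.datasets.CIFAR10` would make each sample a list of two image tensors. The default `DataLoader` collation then turns each batch into a list of two `(batch_size, 3, 32, 32)` tensors, which is the structure `torch.cat(images, dim=0)` expects.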



## Section 3: Deep Learning Model

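In this section, we build the model. `SimCLRCNN` wraps a standard CNN backbone such as ResNet18 or VGG16 (selected by `args["model"]`) and outputs `out_dim`-dimensional embeddings for the contrastive loss.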


In [6]:
from model import SimCLRCNN 
model = SimCLRCNN(backbone=args["model"], out_dim=args["out_dim"])
model = model.to(device)
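`SimCLRCNN` comes from the project's `model` module and is not shown. In the SimCLR paper, the encoder is a backbone network followed by a small MLP projection head, and the contrastive loss is applied to the head's output. A minimal sketch of that structure for the ResNet18 case, using a hypothetical class `SimCLRHeadSketch` and assuming torchvision 0.13 or newer (for the `weights=` argument; older versions use `pretrained=False`):

```python
import torch
from torch import nn
import torchvision


class SimCLRHeadSketch(nn.Module):
    """Hypothetical encoder: ResNet18 backbone + 2-layer MLP projection head."""

    def __init__(self, out_dim: int = 128):
        super().__init__()
        # Randomly initialized ResNet18; its ImageNet classifier is replaced below
        self.backbone = torchvision.models.resnet18(weights=None)
        feat_dim = self.backbone.fc.in_features  # 512 for ResNet18
        # Identity makes the backbone return the pooled 512-d features
        self.backbone.fc = nn.Identity()
        # Projection head mapping representations h to embeddings z
        self.projector = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.backbone(x)      # representation kept for downstream tasks
        return self.projector(h)  # embedding fed to the contrastive loss


model_sketch = SimCLRHeadSketch(out_dim=128)
z = model_sketch(torch.randn(4, 3, 32, 32))
print(z.shape)  # torch.Size([4, 128])
```

The backbone output `h` is what is usually kept for downstream classification, while `z` only feeds the contrastive loss. Many CIFAR implementations also replace ResNet's first 7×7 stride-2 convolution with a 3×3 stride-1 one and remove the max-pool; whether `SimCLRCNN` does so is not shown here.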

## Section 4: Contrastive Training

In [7]:
from utils import metric, info_nce_loss
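The `utils` module is not shown either. As a hedged reference, here is a sketch of the `(logits, labels)` pair that a SimCLR-style `info_nce_loss(features, device, args)` commonly returns, written as a standalone hypothetical function `info_nce_logits` with explicit arguments. It follows the widely used construction that places each anchor's positive in column 0, so that a cross-entropy loss with target 0 computes the NT-Xent loss.

```python
import torch
import torch.nn.functional as F


def info_nce_logits(features: torch.Tensor, batch_size: int, n_views: int,
                    temperature: float):
    """Hypothetical sketch of SimCLR-style InfoNCE logits.

    features: (n_views * batch_size, dim); rows k and k + batch_size are views
    of the same image (the order produced by torch.cat(images, dim=0)).
    Returns logits of shape (N, N - 1), N = n_views * batch_size, with the
    positive(s) first, so for n_views=2 the target class is always 0.
    """
    device = features.device
    n = n_views * batch_size
    # Image index of every row: [0, 1, ..., B-1, 0, 1, ..., B-1, ...]
    ids = torch.arange(batch_size, device=device).repeat(n_views)
    same_image = ids.unsqueeze(0) == ids.unsqueeze(1)  # (N, N) bool

    # Cosine similarity between all pairs of embeddings
    z = F.normalize(features, dim=1)
    sim = z @ z.T

    # Drop self-similarities on the diagonal
    not_self = ~torch.eye(n, dtype=torch.bool, device=device)
    same_image = same_image[not_self].view(n, n - 1)
    sim = sim[not_self].view(n, n - 1)

    positives = sim[same_image].view(n, -1)   # (N, n_views - 1)
    negatives = sim[~same_image].view(n, -1)  # (N, N - n_views)
    logits = torch.cat([positives, negatives], dim=1) / temperature
    labels = torch.zeros(n, dtype=torch.long, device=device)
    return logits, labels


# Perfectly aligned toy embeddings: each positive pair has cosine similarity 1
toy = torch.cat([torch.eye(4), torch.eye(4)], dim=0)
toy_logits, toy_labels = info_nce_logits(toy, batch_size=4, n_views=2, temperature=0.5)
print(toy_logits.shape)  # torch.Size([8, 7])
```

In the training loop this would pair with the existing criterion as `logits, labels = info_nce_logits(features, args["batch_size"], args["n_views"], args["temperature"])` followed by `loss = criterion(logits, labels)`. With more than two views, only the first positive would be used as the target, so the sketch is meant for `n_views=2`.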

In [8]:
from tqdm import tqdm

n_iter = 0
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])

model.train()
for epoch_counter in range(args['epochs']):

    for images, _ in tqdm(train_loader):
        # Stack the n_views augmented batches into one (n_views * batch_size, C, H, W) tensor
        images = torch.cat(images, dim=0)
        images = images.to(device)
        features = model(images)
        logits, labels = info_nce_loss(features, device, args)
        loss = criterion(logits, labels)

        # Clear gradients from the previous step before backpropagating
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if n_iter % args['log_every_n_steps'] == 0:
            top1, top5 = metric(logits, labels, topk=(1, 5))

        n_iter += 1

    print(f"Epoch: {epoch_counter}\tLoss: {loss.item()}\tTop1 accuracy: {top1[0]}")

print("Training has finished.")
# Save the model checkpoint (the filename is a placeholder)
torch.save(model.state_dict(), "simclr_checkpoint.pt")

100%|██████████| 48/48 [01:52<00:00,  2.34s/it]  


Epoch: 0	Loss: 7.526223182678223	Top1 accuracy: 0.29296875
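The `Top1 accuracy` in this log is computed by `metric` from `utils`, whose code is not shown. For contrastive logits, top-1 accuracy is the fraction of anchors whose positive view scores higher than every negative. A minimal sketch under that reading, using a hypothetical `topk_accuracy` that returns fractions in [0, 1]; the real `metric` may instead return percentages or tensors.

```python
import torch


def topk_accuracy(logits: torch.Tensor, labels: torch.Tensor, topk=(1,)):
    """Hypothetical sketch: fraction of rows whose target is among the k largest logits."""
    with torch.no_grad():
        maxk = max(topk)
        # Indices of the maxk largest logits per row, shape (N, maxk)
        _, pred = logits.topk(maxk, dim=1)
        correct = pred.eq(labels.unsqueeze(1))  # (N, maxk) bool
        # A row is correct at k if its target appears among its first k predictions
        return [correct[:, :k].any(dim=1).float().mean().item() for k in topk]


example_logits = torch.tensor([[3.0, 1.0, 2.0], [0.0, 5.0, 1.0]])
example_labels = torch.tensor([0, 2])
print(topk_accuracy(example_logits, example_labels, topk=(1, 2)))  # [0.5, 1.0]
```

In the first row the target (column 0) has the largest logit; in the second row it is only second-largest, so it counts at k=2 but not at k=1.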

