
## Section 1: Problem Formulation

### Introduction
In this project, we explore the effectiveness of contrastive learning, specifically the SimCLR algorithm,
for image classification when only limited labeled data is available. Contrastive learning is a self-supervised
technique that learns representations by pulling the embeddings of similar items (in SimCLR, two augmented views
of the same image) closer together in feature space while pushing the embeddings of dissimilar items further apart.
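
To make the idea concrete, here is a minimal NumPy sketch of the NT-Xent (normalized temperature-scaled cross-entropy) loss described in the SimCLR paper. It assumes a batch of 2N embeddings in which rows `i` and `i + N` are the two augmented views of the same image. `nt_xent_loss` is a hypothetical name; it is not the project's `info_nce_loss` from `eval.py`, whose implementation is not shown here.

```python
import numpy as np


def nt_xent_loss(z: np.ndarray, temperature: float = 0.07) -> float:
    """NT-Xent loss for a (2N, d) batch where rows i and i + N are two views
    of the same image (an assumed layout)."""
    n2 = z.shape[0]
    if n2 % 2 != 0:
        raise ValueError("expected an even number of rows (two views per image)")
    n = n2 // 2
    # L2-normalise so that dot products are cosine similarities
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    logits = z @ z.T / temperature
    # Remove each row's similarity with itself from the softmax denominator
    np.fill_diagonal(logits, -np.inf)
    # Column index of the positive partner for every row
    pos = np.concatenate([np.arange(n, n2), np.arange(0, n)])
    # Numerically stable log-softmax over each row
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_prob = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Average the negative log-probability of the positive over all 2N anchors
    return float(-log_prob[np.arange(n2), pos].mean())


rng = np.random.default_rng(0)
views = rng.normal(size=(8, 16))
aligned = np.concatenate([views, views + 0.01 * rng.normal(size=views.shape)])
unrelated = rng.normal(size=(16, 16))
print(nt_xent_loss(aligned) < nt_xent_loss(unrelated))  # True
```

Pairs whose two views are nearly identical give a much lower loss than unrelated embeddings. A small temperature such as the `0.07` used in `args` sharpens the softmax, so negatives that are close to the anchor dominate the denominator.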


In [None]:
import sys
# Used below to enable torch.compile only on Linux
on_linux = sys.platform.startswith('linux')

In [None]:
# Import necessary libraries
import torch

# Check if CUDA is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
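args = {
    "dataset": "cifar10",        # dataset name passed to SimCLRDataset
    "model": "resnet50",         # encoder backbone passed to SimCLRCNN
    "batch_size": 1024,
    "sample_rate": 1,
    "epochs": 100,
    "n_views": 2,                # augmented views requested per image
    "out_dim": 128,              # embedding size in contrastive mode
    "lr": 1.2e-3,                # learning rate
    "wd": 1e-6,                  # weight decay
    "log_every_n_steps": 5,
    "n_workers": 16,             # DataLoader worker processes
    "temperature": 0.07,         # softmax temperature for the contrastive loss
    "learning": "contrastive",   # "contrastive"; any other value runs supervised training
    "val_split": 0.2,            # fraction of the training data held out for validation
}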


## Section 2: Dataset Preparations

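In this section, we prepare the dataset selected by `args["dataset"]` (CIFAR-10, CIFAR-100, or MedMNIST) for training.
We apply the necessary transformations and split the data into training, validation, and test sets, where `args["val_split"]` sets the fraction of the training data held out for validation.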
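
How `SimCLRDataset.get_train_val_datasets` performs the split is not shown here. A minimal sketch of one common approach, assuming a shuffled index split with a fixed seed, is below; `split_indices` is a hypothetical helper, not part of the project's `dataset` module.

```python
import random
from typing import List, Tuple


def split_indices(n_samples: int, val_split: float, seed: int = 0) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: shuffle 0..n_samples-1 with a fixed seed and hold
    out a val_split fraction of the indices for validation."""
    if not 0.0 <= val_split < 1.0:
        raise ValueError("val_split must be in [0, 1)")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # fixed seed -> reproducible split
    n_val = round(n_samples * val_split)
    return indices[n_val:], indices[:n_val]


# CIFAR-10 has 50,000 training images; val_split=0.2 as in args
train_idx, val_idx = split_indices(50_000, 0.2)
print(len(train_idx), len(val_idx))  # 40000 10000
```

In PyTorch, index lists like these can wrap an existing dataset via `torch.utils.data.Subset(dataset, train_idx)`. Fixing the seed keeps the validation set identical across runs, so curves from different runs stay comparable.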


In [None]:
from dataset import SimCLRDataset

data = SimCLRDataset(args["dataset"])

def build_dataloader(dataset, shuffle):
    # drop_last=True keeps every batch at exactly batch_size, which in-batch
    # contrastive losses typically assume; it also drops the final partial batch
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args["batch_size"],
        shuffle=shuffle,
        drop_last=True,
        num_workers=args["n_workers"],
    )

# n_views: number of augmented views requested per sample
train_dataset, val_dataset = data.get_train_val_datasets(args["n_views"], args["val_split"])
test_dataset = data.get_test_dataset(args["n_views"])
train_loader = build_dataloader(train_dataset, shuffle=True)
val_loader = build_dataloader(val_dataset, shuffle=False)  # evaluation does not need shuffling
test_loader = build_dataloader(test_dataset, shuffle=False)
num_classes = data.num_classes
print("# Classes:", num_classes)
print("# Train, Val, Test:", len(train_dataset), len(val_dataset), len(test_dataset))
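
The datasets above are built with `args["n_views"]`, which suggests each sample is returned as several augmented views of one image. A common way to get that behavior is a wrapper that calls a random augmentation `n_views` times. The sketch below shows the idea in plain Python with a toy numeric "augmentation"; `MultiViewTransform` and `jitter` are hypothetical names, and how `SimCLRDataset` actually produces its views is not shown here.

```python
import random
from typing import Any, Callable, List


class MultiViewTransform:
    """Hypothetical wrapper: apply a stochastic transform n_views times to the
    same input so each call returns independently augmented views."""

    def __init__(self, transform: Callable[[Any], Any], n_views: int = 2):
        self.transform = transform
        self.n_views = n_views

    def __call__(self, x: Any) -> List[Any]:
        return [self.transform(x) for _ in range(self.n_views)]


# Toy stand-in for an image augmentation: add uniform jitter in [-0.1, 0.1]
_rng = random.Random(0)


def jitter(values: List[float]) -> List[float]:
    return [v + _rng.uniform(-0.1, 0.1) for v in values]


views = MultiViewTransform(jitter, n_views=2)([1.0, 2.0, 3.0])
print(len(views))  # 2
```

In an image pipeline, `transform` would instead be a `torchvision.transforms.Compose` of random augmentations such as `RandomResizedCrop`, `RandomHorizontalFlip`, and `ColorJitter`, in the spirit of the cropping, color distortion, and blur used in SimCLR.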

## Section 3: Deep Learning Model

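In this section, we build the model. The encoder backbone is selected by `args["model"]` (ResNet-50 here; other popular choices include ResNet-18 and VGG-16). In contrastive mode the network outputs `args["out_dim"]`-dimensional embeddings; otherwise it outputs one logit per class.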
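
The internals of `SimCLRCNN` are not shown here. In the SimCLR paper, the encoder output `h` is passed through a small MLP projection head `z = W2 ReLU(W1 h)` before the contrastive loss; we assume, without having seen `model.py`, that `out_dim` sets the size of such a head's output. The NumPy sketch below shows only the forward pass, with assumed layer sizes and random weights.

```python
import numpy as np


def projection_head(h: np.ndarray, w1: np.ndarray, w2: np.ndarray) -> np.ndarray:
    """SimCLR-style projection head z = W2 ReLU(W1 h), applied row-wise to a
    batch of encoder features h with shape (batch, feat_dim)."""
    hidden = np.maximum(h @ w1, 0.0)  # ReLU non-linearity
    return hidden @ w2


rng = np.random.default_rng(0)
# Assumed sizes: 2048-d ResNet-50 features, 2048-d hidden layer, out_dim=128 as in args
feat_dim, hidden_dim, out_dim = 2048, 2048, 128
w1 = rng.normal(scale=feat_dim ** -0.5, size=(feat_dim, hidden_dim))
w2 = rng.normal(scale=hidden_dim ** -0.5, size=(hidden_dim, out_dim))
h = rng.normal(size=(4, feat_dim))  # stand-in for backbone features of 4 images
z = projection_head(h, w1, w2)
print(z.shape)  # (4, 128)
```

In SimCLR, the loss is computed on `z`, while the representation `h` before the head is what gets reused for downstream classification.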


In [None]:
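from model import SimCLRCNN

model_args = {
    "backbone": args["model"],
    # Contrastive mode outputs out_dim-dimensional embeddings; supervised mode outputs class logits
    "out_dim": args["out_dim"] if args["learning"] == "contrastive" else num_classes,
    "mod": args["learning"] == "contrastive",  # True when training contrastively
}
model = SimCLRCNN(**model_args).to(device)
if on_linux:
    model = torch.compile(model)  # requires PyTorch 2.0+
    # Allow faster reduced-precision (e.g. TF32) float32 matmuls on supported GPUs
    torch.set_float32_matmul_precision('high')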

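## Section 4: Contrastive Training

If `args["learning"]` is `"contrastive"`, the model is trained with `info_nce_loss` together with a `CrossEntropyLoss` criterion; otherwise, it falls back to supervised training with cross-entropy on the class labels. In both cases, `val_loader` is passed as the evaluation loader.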

In [None]:
from train import contrastive_training, supervised_training
from eval import info_nce_loss
if args["learning"] == "contrastive":
    loss_fn = info_nce_loss
    criterion = torch.nn.CrossEntropyLoss()
    train_records, test_records = contrastive_training(model, train_loader, val_loader, loss_fn, criterion, device, args)
else:
    loss_fn = torch.nn.CrossEntropyLoss()
    train_records, test_records = supervised_training(model, train_loader, val_loader, loss_fn, device, args)
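
The contrastive branch pairs `info_nce_loss` with a plain `CrossEntropyLoss`. A common way to make those two pieces fit together, and only an assumption about what `eval.py` does, is for the loss function to return logits in which each row's positive pair sits in column 0, followed by its negatives, so the cross-entropy target is always 0. The NumPy sketch below builds such logits for 2N embeddings where rows `i` and `i + N` are paired; `info_nce_logits` and `cross_entropy` are hypothetical names.

```python
import numpy as np


def info_nce_logits(z: np.ndarray, temperature: float = 0.07):
    """Arrange similarities so each row's positive is in column 0, followed by
    its 2N - 2 negatives; the cross-entropy targets are then all zeros."""
    n2 = z.shape[0]
    n = n2 // 2
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # cosine similarity
    sim = z @ z.T
    idx = np.arange(n2)
    pos_idx = (idx + n) % n2  # partner view of each row
    positives = sim[idx, pos_idx][:, None]
    # Keep every column except the row itself and its positive partner
    keep = (idx[None, :] != idx[:, None]) & (idx[None, :] != pos_idx[:, None])
    negatives = sim[keep].reshape(n2, n2 - 2)
    logits = np.concatenate([positives, negatives], axis=1) / temperature
    targets = np.zeros(n2, dtype=np.int64)
    return logits, targets


def cross_entropy(logits: np.ndarray, targets: np.ndarray) -> float:
    """Mean cross-entropy, as torch.nn.CrossEntropyLoss computes by default."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_prob = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_prob[np.arange(len(targets)), targets].mean())


rng = np.random.default_rng(0)
a = rng.normal(size=(4, 8))
z = np.concatenate([a, a])  # identical views: every positive has cosine similarity 1
logits, targets = info_nce_logits(z)
print(logits.shape)  # (8, 7)
print((logits.argmax(axis=1) == targets).all())  # True
```

Putting the positive first turns "which candidate is my other view?" into an ordinary classification problem, which is why a standard cross-entropy criterion can be reused unchanged.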

In [None]:
import os
import pandas as pd
from IPython.display import display

timestamp = pd.Timestamp.now().strftime("%m%d%H%M")
os.makedirs("logs", exist_ok=True)  # to_csv raises if the logs/ directory is missing

df = pd.DataFrame.from_records(train_records)
test_df = pd.DataFrame.from_records(test_records)
df.to_csv(f"logs/{args['model']}_{args['dataset']}_{timestamp}_train.csv", index=False)
test_df.to_csv(f"logs/{args['model']}_{args['dataset']}_{timestamp}_test.csv", index=False)
display(test_df)

In [None]:
import matplotlib.pyplot as plt
plt.plot(df['loss'])
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.show()

In [None]:
plt.plot(test_df['top1'])
plt.plot(test_df['top5'])
plt.legend(['Top1', 'Top5'])
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Validation Accuracy')  # test_records are produced with val_loader in the training cell
plt.show()
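
The plotted `top1` and `top5` columns are top-k accuracies: a prediction counts as correct if the target index is among the k highest-scoring outputs. The NumPy sketch below shows the computation; `topk_accuracy` is a hypothetical helper, and we assume, without having seen `eval.py`, that the records are computed this way. In contrastive mode, if the logits are the pair-similarity logits rather than class scores, the "target" is the positive view, so top-1 measures how often the matching view is ranked first rather than class accuracy.

```python
import numpy as np
from typing import Dict, Sequence


def topk_accuracy(logits: np.ndarray, targets: np.ndarray, ks: Sequence[int] = (1, 5)) -> Dict[int, float]:
    """Percentage of rows whose target index is among the k largest logits."""
    k_max = max(ks)
    # Column indices of the k_max largest logits per row, best first
    top = np.argsort(-logits, axis=1)[:, :k_max]
    hits = top == targets[:, None]
    return {k: float(100.0 * hits[:, :k].any(axis=1).mean()) for k in ks}


logits = np.array([
    [0.1, 0.9, 0.0, 0.0, 0.0, 0.0],  # target 1 is ranked 1st
    [0.5, 0.1, 0.2, 0.3, 0.4, 0.0],  # target 2 is ranked 4th
    [0.5, 0.1, 0.2, 0.3, 0.4, 0.0],  # target 5 is ranked 6th
])
targets = np.array([1, 2, 5])
acc = topk_accuracy(logits, targets)
print({k: round(v, 2) for k, v in acc.items()})  # {1: 33.33, 5: 66.67}
```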
