# SimCLR Training 

## Installing requirements and loading libraries

In [None]:
# Hugging Face `datasets` provides load_dataset / concatenate_datasets used below.
!pip install datasets

In [None]:
# TensorBoard is needed for the `%load_ext tensorboard` magic in the setup cell.
!pip install tensorboard

In [None]:

import os
from copy import deepcopy


import matplotlib.pyplot as plt
# Default colormap for images drawn with matplotlib (e.g. plt.imshow)
plt.set_cmap('cividis')
%matplotlib inline
# NOTE(review): IPython.display.set_matplotlib_formats is deprecated in recent IPython
# in favour of matplotlib_inline.backend_inline.set_matplotlib_formats — confirm for
# the installed version.
from IPython.display import set_matplotlib_formats
# Request SVG and PDF representations for inline figures
set_matplotlib_formats('svg', 'pdf') 
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
# Apply seaborn's default theme to subsequent plots
sns.set()

from tqdm.notebook import tqdm

# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
from torch.utils.data import Subset

## Torchvision
import torchvision
from torchvision.datasets import STL10
from torchvision import transforms


# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint


from datasets import load_dataset
from datasets import concatenate_datasets
from datasets import DatasetDict
from datasets import Value

# Import tensorboard
%load_ext tensorboard


DATASET_PATH = "../data"
CHECKPOINT_PATH = "/content/drive/MyDrive/ms_nist/brain"
NUM_WORKERS = os.cpu_count()

# Setting the seed
pl.seed_everything(42)






## Data preprocessing and augmentation

In [None]:


# Download the brain-MRI collections from the Hugging Face Hub.
ds = load_dataset("youngp5/BrainMRI")
ds2 = load_dataset("Mahadih534/brain-tumor-MRI-dataset")
# The Alzheimer dataset is loaded split by split (train first, then test).
ds2_t, ds2_test = (
    load_dataset("Falah/Alzheimer_MRI", split=split_name)
    for split_name in ("train", "test")
)
ds3 = load_dataset("BTX24/tekno21-brain-stroke-dataset-binary")


In [None]:
# Inspect the train split of the brain-stroke dataset
ds3["train"]

In [None]:
# Inspect the Alzheimer MRI test split
ds2_test

In [None]:
# Preview one image from the first dataset's train split
ds['train'][5]["image"]

In [None]:


# Convert labels to int64 (remove ClassLabel) for EVERY source, including ds2,
# so all feature schemas agree and concatenate_datasets can align them.
# Casting a column that is already int64 is a no-op.
ds = ds.cast_column("label", Value("int64"))
ds2 = ds2.cast_column("label", Value("int64"))
ds3 = ds3.cast_column("label", Value("int64"))
ds2_t = ds2_t.cast_column("label", Value("int64"))
ds2_test = ds2_test.cast_column("label", Value("int64"))

# Pool all images into one training set (labels are discarded later for SimCLR).
merged_train = concatenate_datasets([
    ds['train'], ds2_t, ds2_test, ds2['train'], ds3['train'],
])

merged_dataset = DatasetDict({'train': merged_train})


In [None]:


resize_transform = transforms.Resize((128, 128))
grayscale = transforms.Grayscale(num_output_channels=3)


def resize_example(example):
    """Convert ``example['image']`` to 3-channel grayscale, resize it to 128x128, and return the record.

    The record is updated in place and also returned, as ``Dataset.map`` expects.
    """
    gray_image = grayscale(example['image'])
    example['image'] = resize_transform(gray_image)
    return example

merged_dataset = merged_dataset.map(resize_example)
merged_dataset = merged_dataset.shuffle(seed=42)
# Hold out 10% of the images for validation. Without an explicit seed the split
# depends on global RNG state, so re-running cells could reshuffle train/test.
merged_dataset = merged_dataset['train'].train_test_split(test_size=0.1, seed=42)

In [None]:
def _mask_label(example):
    # SimCLR is self-supervised: overwrite every label with -1 ("unlabeled").
    return {'label': -1}


merged_dataset = merged_dataset.map(_mask_label)

In [None]:
class ContrastiveTransformations(object):
    """Callable producing several independently transformed views of one input.

    Args:
        base_transforms: callable applied to the input once per view; if it is
            random, the views differ from each other.
        n_views: number of views returned per call.
    """

    def __init__(self, base_transforms, n_views=2):
        self.base_transforms = base_transforms
        self.n_views = n_views

    def __call__(self, x):
        views = []
        for _ in range(self.n_views):
            views.append(self.base_transforms(x))
        return views

In [None]:

class AddGaussianNoise(object):
    """Tensor transform that adds element-wise Gaussian noise.

    Every element receives ``randn * std + mean``, with ``randn`` drawn from a
    standard normal of the same shape as the input.

    Args:
        mean: constant offset added to the noise.
        std: scale of the noise.
    """

    def __init__(self, mean=0., std=0.01):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        noise = torch.randn_like(tensor)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return "{}(mean={}, std={})".format(type(self).__name__, self.mean, self.std)

# Augmentation pipeline applied independently to each SimCLR view.
contrast_transforms = transforms.Compose([
    # Geometric augmentations on the PIL image
    transforms.RandomHorizontalFlip(),
    # Random crop covering 20-100% of the image area, resized to 128x128
    transforms.RandomResizedCrop(size=128, scale=(0.2, 1)),
    transforms.RandomRotation(degrees=30),
    transforms.RandomAffine(degrees=10, translate=(0.05, 0.05), scale=(0.95, 1.05)),
    # 9x9 Gaussian blur, applied to every view
    transforms.GaussianBlur(kernel_size=9),
    transforms.ToTensor(),
    # Pass mean/std by keyword: a bare positional 0.01 would set the *mean*,
    # adding a constant offset instead of controlling the noise level.
    AddGaussianNoise(mean=0.0, std=0.01),
    # Shift/scale pixel values; the single-value mean/std apply to every channel
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])


In [None]:


# Two-view wrapper around the augmentation pipeline
contrastive_transform = ContrastiveTransformations(base_transforms=contrast_transforms, n_views=2)
class SimCLRDataset(torch.utils.data.Dataset):
    """Wrap a dataset of ``{'image', 'label'}`` records as SimCLR samples.

    Each item is ``(views, label)``, where ``views`` is a list of ``n_views``
    independently transformed copies of the same image.

    Args:
        dataset: indexable, sized collection whose items expose ``'image'`` and
            ``'label'`` keys (e.g. a Hugging Face ``Dataset`` split).
        transform: callable applied to the image once per view.
        n_views: number of views per sample.
    """

    def __init__(self, dataset, transform, n_views=2):
        self.dataset = dataset
        self.transform = transform
        self.n_views = n_views

    def __getitem__(self, idx):
        # Fetch the record once: row access on a Hugging Face Dataset can re-read
        # and decode the stored image, so indexing twice doubles that work.
        record = self.dataset[idx]
        img = record['image']
        views = [self.transform(img) for _ in range(self.n_views)]
        return views, record['label']

    def __len__(self):
        return len(self.dataset)


In [None]:
# Show the train/test DatasetDict produced by train_test_split
merged_dataset

In [None]:
# Training set: each item is (list of augmented views, label)
simclr_dataset = SimCLRDataset(merged_dataset['train'], contrast_transforms)


In [None]:
# Held-out set (used as the validation loader), with the same random augmentations
simclr_test=SimCLRDataset(merged_dataset['test'], contrast_transforms)

In [None]:
# Display the augmented views of held-out sample 5
simclr_test[5][0]

In [None]:
import os
import torch

from torch.utils.data import DataLoader


In [None]:
# Visualize the augmented views of the first few held-out images.
# Seeding first makes the random augmentations shown here repeatable.
pl.seed_everything(42)
NUM_IMAGES = 3
view_tensors = []
for sample_idx in range(NUM_IMAGES):
    sample_views = simclr_test[sample_idx][0]
    view_tensors.extend(sample_views)
imgs = torch.stack(view_tensors, dim=0)

# Lay the views out in a grid and move channels last for imshow
img_grid = torchvision.utils.make_grid(imgs, nrow=6, normalize=True, pad_value=0.9).permute(1, 2, 0)

plt.figure(figsize=(10, 5))
plt.title('Augmented image examples of the MRI dataset')
plt.imshow(img_grid)
plt.axis('off')
plt.show()
plt.close()

## SimCLR architecture and training

In [None]:
class SimCLR(pl.LightningModule):
    """SimCLR: ResNet-18 encoder with an MLP projection head, trained with InfoNCE.

    Args:
        hidden_dim: output width of the projection head; its intermediate layer
            is ``4 * hidden_dim`` wide.
        lr: initial learning rate for AdamW.
        temperature: InfoNCE softmax temperature; must be positive.
        weight_decay: AdamW weight decay.
        max_epochs: period ``T_max`` of the cosine learning-rate schedule.

    Raises:
        ValueError: if ``temperature`` is not a positive number.
    """

    def __init__(self, hidden_dim, lr, temperature, weight_decay, max_epochs=200):
        super().__init__()
        self.save_hyperparameters()
        # Explicit check rather than `assert`, which `python -O` strips out.
        if not (self.hparams.temperature > 0.0):
            raise ValueError('The temperature must be a positive float!')
        # Base model f(.): ResNet-18 whose final linear layer outputs 4*hidden_dim features
        self.convnet = torchvision.models.resnet18(num_classes=4*hidden_dim)
        # Projection head g(.): Linear -> ReLU -> Linear, reusing the ResNet fc layer as the first Linear
        self.convnet.fc = nn.Sequential(
            self.convnet.fc,  # Linear(ResNet output, 4*hidden_dim)
            nn.ReLU(inplace=True),
            nn.Linear(4*hidden_dim, hidden_dim)
        )

    def configure_optimizers(self):
        """Return AdamW with cosine annealing from ``lr`` to ``lr / 50`` over ``max_epochs``."""
        optimizer = optim.AdamW(self.parameters(),
                                lr=self.hparams.lr,
                                weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                            T_max=self.hparams.max_epochs,
                                                            eta_min=self.hparams.lr/50)
        return [optimizer], [lr_scheduler]

    def info_nce_loss(self, batch, mode='train'):
        """Compute the InfoNCE loss for a batch of views and log ranking metrics.

        Args:
            batch: ``(views, labels)``; ``views`` is a list of image tensors with
                equal batch size, labels are ignored.
            mode: prefix for logged metric names (e.g. ``'train'`` or ``'val'``).

        Returns:
            Scalar mean InfoNCE loss.
        """
        imgs, _ = batch
        # Concatenate the views along the batch axis: with two views of B images,
        # rows [0, B) hold the first view and rows [B, 2B) the second.
        imgs = torch.cat(imgs, dim=0)

        # Encode all images
        feats = self.convnet(imgs)
        # Pairwise cosine similarity as a matrix product of L2-normalized features.
        # F.cosine_similarity on broadcast (N,1,D)/(1,N,D) inputs materializes an
        # N x N x D intermediate (~0.5 GiB for N=1024, D=128 in fp32); the matmul
        # only allocates the N x N result.
        feats = F.normalize(feats, dim=-1)
        cos_sim = feats @ feats.T
        # Mask out cosine similarity to itself
        self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
        cos_sim.masked_fill_(self_mask, -9e15)
        # Positive example is N//2 rows away (the other view of the same image);
        # this pairing assumes exactly two views per image.
        pos_mask = self_mask.roll(shifts=cos_sim.shape[0]//2, dims=0)
        # InfoNCE loss: -sim(positive)/tau + logsumexp over all non-self similarities
        cos_sim = cos_sim / self.hparams.temperature
        nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
        nll = nll.mean()

        # Logging loss
        self.log(mode+'_loss', nll)
        # Rank of the positive among all candidates (0 = most similar)
        comb_sim = torch.cat([cos_sim[pos_mask][:,None],  # First position positive example
                              cos_sim.masked_fill(pos_mask, -9e15)],
                             dim=-1)
        sim_argsort = comb_sim.argsort(dim=-1, descending=True).argmin(dim=-1)
        # Logging ranking metrics
        self.log(mode+'_acc_top1', (sim_argsort == 0).float().mean())
        self.log(mode+'_acc_top5', (sim_argsort < 5).float().mean())
        self.log(mode+'_acc_mean_pos', 1+sim_argsort.float().mean())

        return nll

    def training_step(self, batch, batch_idx):
        """Return the training InfoNCE loss for Lightning to optimize."""
        return self.info_nce_loss(batch, mode='train')

    def validation_step(self, batch, batch_idx):
        """Compute and log validation metrics (no return value needed)."""
        self.info_nce_loss(batch, mode='val')

In [None]:
def train_simclr(batch_size, max_epochs=500, num_workers=None, **kwargs):
    """Train a SimCLR model on ``simclr_dataset``, validating on ``simclr_test``.

    Args:
        batch_size: number of dataset items per batch (each item yields several views).
        max_epochs: number of training epochs (also the cosine schedule period).
        num_workers: DataLoader worker processes per loader. Defaults to
            ``NUM_WORKERS`` (the OS CPU count), or 0 if that is unknown.
        **kwargs: forwarded to ``SimCLR`` (hidden_dim, lr, temperature, weight_decay).

    Returns:
        The model restored from the checkpoint with the best ``val_acc_top1``, or
        the final in-memory model if no checkpoint was written.
    """
    if num_workers is None:
        # Size the worker pools to the machine; a large fixed count (e.g. 94)
        # oversubscribes small hosts such as 2-vCPU Colab instances.
        num_workers = NUM_WORKERS or 0

    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, 'brain'),
                         accelerator="auto",
                         devices=1,
                         log_every_n_steps=37,
                         max_epochs=max_epochs,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode='max', monitor='val_acc_top1'),
                                    LearningRateMonitor('epoch')])
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    train_loader = data.DataLoader(simclr_dataset, batch_size=batch_size, shuffle=True,
                                   drop_last=True, pin_memory=True, num_workers=num_workers)
    val_loader = data.DataLoader(simclr_test, batch_size=batch_size, shuffle=False,
                                 drop_last=False, pin_memory=True, num_workers=num_workers)

    pl.seed_everything(42)  # To be reproducible
    model = SimCLR(max_epochs=max_epochs, **kwargs)
    trainer.fit(model, train_loader, val_loader)

    best_path = trainer.checkpoint_callback.best_model_path
    if best_path:
        # Load best checkpoint after training
        model = SimCLR.load_from_checkpoint(best_path)
    return model

In [None]:
# Model/optimizer hyper-parameters forwarded to SimCLR
simclr_hparams = dict(
    hidden_dim=128,
    lr=5e-4,
    temperature=0.07,
    weight_decay=1e-4,
)
simclr_model = train_simclr(batch_size=512, max_epochs=200, **simclr_hparams)