# SimCLR
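PyTorch implementation of SimCLR (*A Simple Framework for Contrastive Learning of Visual Representations*, T. Chen et al.), applied here to IFCB plankton images. The underlying implementation supports the LARS (Layer-wise Adaptive Rate Scaling) optimizer and global batch norm; this notebook trains with Adam on a single device.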

[Link to paper](https://arxiv.org/pdf/2002.05709.pdf)


## Set up the environment

In [1]:
# Install/upgrade PyYAML, presumably needed by `utils.yaml_config_hook` to read config/config.yaml.
# NOTE(review): unpinned; pin a version (e.g. `pyyaml==X.Y`) for reproducible runs.
%pip install  pyyaml --upgrade

Note: you may need to restart the kernel to use updated packages.


# Part 1:
## SimCLR pre-training

In [2]:
import os
import torch
import numpy as np  
import argparse

# NVIDIA apex is optional: it is only needed for fp16 (mixed-precision) training.
# `apex` records whether the import succeeded.
try:
    from apex import amp
except ImportError:
    apex = False
    print(
        "Install the apex package from https://www.github.com/nvidia/apex to use fp16 for training"
    )
else:
    apex = True


from simclr import SimCLR
from simclr.modules import get_resnet, NT_Xent
from simclr.modules.transformations import TransformsSimCLR

Install the apex package from https://www.github.com/nvidia/apex to use fp16 for training


### Load arguments from `config/config.yaml`

In [3]:
from pprint import pprint
import argparse
from utils import yaml_config_hook

# Expose every key of config/config.yaml as a command-line option whose default is the YAML value.
# NOTE: with `type=type(v)`, a bool option given on a real command line would turn any non-empty
# string (even "False") into True; harmless here because argv is empty.
parser = argparse.ArgumentParser(description="SimCLR")
config = yaml_config_hook("./config/config.yaml")
for k, v in config.items():
    parser.add_argument(f"--{k}", default=v, type=type(v))

# Parse an empty argv: inside a notebook sys.argv belongs to the kernel, so only the YAML defaults apply.
args = parser.parse_args([])

# Use the first GPU when available, otherwise the CPU. Later cells read `args.device`,
# so the same device is also stored on `args`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
args.device = device

In [4]:
### Override configuration values here (e.g. to fit the batch on a Colab GPU):
vars(args).update(batch_size=128, resnet="resnet18")
pprint(vars(args))

{'batch_size': 154,
 'dataparallel': 0,
 'dataset': '',
 'dataset_dir': './data',
 'epoch_num': 10,
 'epochs': 10,
 'gpus': ['0 1'],
 'image_size': 64,
 'logistic_batch_size': 256,
 'logistic_epochs': 500,
 'model_path': 'save',
 'nodes': 1,
 'nr': 1,
 'optimizer': 'Adam',
 'pretrain': True,
 'projection_dim': 64,
 'reload': False,
 'resnet': 'resnet18',
 'seed': 42,
 'start_epoch': 0,
 'temperature': 0.5,
 'weight_decay': 1e-06,
 'workers': 12}


### Load the IFCB metadata and build the train/test loaders

In [5]:
import pandas as pd

if not os.path.isfile('IFCB.csv.zip'):
    print("CSV data do not exist. Downloading...")
    !wget -O IFCB.csv.zip "https://unioviedo-my.sharepoint.com/:u:/g/personal/gonzalezgpablo_uniovi_es/EfsVLhFsYJpPjO0KZlpWUq0BU6LaqJ989Re4XzatS9aG4Q?download=1"

data = pd.read_csv('IFCB.csv.zip',compression='infer', header=0,sep=',',quotechar='"')

#Compute sample and year information
data['year'] = data['Sample'].str[6:10].astype(str) #Compute the year
samples=data.groupby('Sample').first()
samples = samples[["year"]]
print(data)

                        Sample  roi_number        OriginalClass  \
0        IFCB1_2006_158_000036           1                  mix   
1        IFCB1_2006_158_000036           2  Tontonia_gracillima   
2        IFCB1_2006_158_000036           3                  mix   
3        IFCB1_2006_158_000036           4                  mix   
4        IFCB1_2006_158_000036           5                  mix   
...                        ...         ...                  ...   
3457814  IFCB5_2014_353_205141        6850       Leptocylindrus   
3457815  IFCB5_2014_353_205141        6852                  mix   
3457816  IFCB5_2014_353_205141        6855                  mix   
3457817  IFCB5_2014_353_205141        6856                  mix   
3457818  IFCB5_2014_353_205141        6857                  mix   

              AutoClass FunctionalGroup  year  
0                   mix      Flagellate  2006  
1           ciliate_mix         Ciliate  2006  
2                   mix      Flagellate  2006  
3  

In [6]:
import progressbar
from tqdm import tqdm
from shutil import copyfile
import numpy as np

tqdm.pandas()

classcolumn = "AutoClass"  # AutoClass means 51 classes
yearstraining = ['2006', '2007']  # Years to consider as training
yearstest = ['2008']  # Years to consider as test
# (A validation split by year, e.g. ['2008'], is currently disabled.)

# Sample ids (the index of `samples`) whose year belongs to each split.
samplestraining = samples.index[samples['year'].isin(yearstraining).to_numpy()].tolist()
samplestest = samples.index[samples['year'].isin(yearstest).to_numpy()].tolist()

# np.unique already returns the distinct labels in sorted order.
classes = np.unique(data[classcolumn])


In [7]:
import torchvision.transforms as T
from h5ifcbdataset import H5IFCBDataset
from torch.utils.data import DataLoader

# Directory holding one <sample>.hdf5 file per IFCB sample. Set the IFCB_HDF5_DIR
# environment variable to override this machine-specific default instead of editing it.
hdf5_files_path = os.environ.get(
    "IFCB_HDF5_DIR", "/media/nas/olayap/env_olaya/TFM/IFBC_HDF5_olaya/output/"
)

# files to load
filestraining = [os.path.join(hdf5_files_path, s + '.hdf5') for s in samplestraining]
filestest = [os.path.join(hdf5_files_path, s + '.hdf5') for s in samplestest]

# Define data loaders -- SimCLR transform (2 images data augmentation)
# NOTE(review): train_dset is built with trainingset=False, like test_dset; confirm this is intended.
train_dset = H5IFCBDataset(filestraining, classes, classattribute=classcolumn, verbose=1, trainingset=False, transform=TransformsSimCLR(size=args.image_size))
# drop_last=True keeps every training batch at exactly args.batch_size pairs.
train_loader = DataLoader(train_dset, batch_size=args.batch_size, num_workers=args.workers, shuffle=True, pin_memory=True, drop_last=True)

test_dset = H5IFCBDataset(filestest, classes, classattribute=classcolumn, verbose=1, trainingset=False, transform=TransformsSimCLR(size=args.image_size))
test_loader = DataLoader(test_dset, batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)

Loading samples: 100%|██████████| 164/164 [08:49<00:00,  3.23s/it]
Loading samples: 100%|██████████| 122/122 [10:15<00:00,  5.05s/it]


In [8]:
# Short pre-training run: 3 epochs (stored under both attribute names).
args.epochs = args.num_epochs = 3

### Load the SimCLR model, optimizer and learning rate scheduler

In [9]:
def save_model(args, model, optimizer):
    """Save ``model``'s weights to ``<args.model_path>/checkpoint_<args.current_epoch>.tar``.

    Only the state dict is written, which is the format the reload code
    (``model.load_state_dict(torch.load(...))``) expects. ``optimizer`` is
    accepted for call-site compatibility but is not saved.

    Args:
        args: needs ``model_path`` (output directory, created if missing) and
            ``current_epoch`` (used in the file name).
        model: the module to save; a ``torch.nn.DataParallel`` wrapper is unwrapped.
        optimizer: unused.

    Returns:
        The path of the written checkpoint.
    """
    out = os.path.join(args.model_path, "checkpoint_{}.tar".format(args.current_epoch))
    if args.model_path:
        os.makedirs(args.model_path, exist_ok=True)
    # Unwrap DataParallel so the saved keys match a plain (non-parallel) model.
    to_save = model.module if isinstance(model, torch.nn.DataParallel) else model
    torch.save(to_save.state_dict(), out)
    return out

In [10]:
import torch.nn as nn

# initialize ResNet
encoder = get_resnet(args.resnet, pretrained=False)
n_features = encoder.fc.in_features  
encoder.fc = nn.Linear(encoder.fc.in_features, len(classes))


# initialize model
model = SimCLR(encoder, args.projection_dim, n_features)
if args.reload:
    model_fp = os.path.join(
        args.model_path, "checkpoint_{}.tar".format(args.epoch_num))
        
    model.load_state_dict(torch.load(model_fp, map_location=device.type))
model = model.to(device)

# optimizer / loss
scheduler = None
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)  

### Initialize the criterion (NT-Xent loss)

In [11]:
# NT-Xent contrastive loss on the two projected views of each batch.
# NOTE(review): the loss appears to be sized for exactly `args.batch_size` pairs per batch
# (an earlier run failed with a shape error on a smaller batch), so the training loader
# must drop the last incomplete batch.
criterion = NT_Xent(
    args.batch_size,
    args.temperature,
    world_size=1,
)

### Train function

In [18]:
def train(args, train_loader, model, criterion, optimizer):
    """Run one SimCLR pre-training epoch.

    Args:
        args: run configuration; ``args.global_step`` is incremented once per batch.
        train_loader: yields ``((x_i, x_j), _, _)``, i.e. two augmented views of each
            image plus two further items that are ignored here.
        model: SimCLR module returning ``(h_i, h_j, z_i, z_j)`` for the two views.
        criterion: contrastive loss on the projections ``(z_i, z_j)``, e.g. NT_Xent.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        The summed loss over all batches (divide by ``len(train_loader)`` for the mean).
    """
    # Move batches to the model's own device instead of hard-coding `.cuda()`,
    # so the loop also runs when the notebook fell back to the CPU.
    device = next(model.parameters()).device
    model.train()
    loss_epoch = 0
    for step, ((x_i, x_j), _, _) in enumerate(train_loader):
        optimizer.zero_grad()
        x_i = x_i.to(device, non_blocking=True)
        x_j = x_j.to(device, non_blocking=True)

        # positive pair, with encoding
        h_i, h_j, z_i, z_j = model(x_i, x_j)

        loss = criterion(z_i, z_j)
        loss.backward()

        optimizer.step()

        if step % 50 == 0:
            print(f"Step [{step}/{len(train_loader)}]\t Loss: {loss.item()}")

        loss_epoch += loss.item()
        args.global_step += 1
    return loss_epoch

### Start training

In [19]:
# Release cached GPU memory left over from earlier cells (skipped on CPU-only machines).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

args.global_step = 0
# Checkpoint names use current_epoch, so start it where the epoch loop starts.
args.current_epoch = args.start_epoch
for epoch in range(args.start_epoch, args.epochs):
    lr = optimizer.param_groups[0]["lr"]
    loss_epoch = train(args, train_loader, model, criterion, optimizer)

    if scheduler:
        scheduler.step()

    # save every 10 epochs (including the first one)
    if epoch % 10 == 0:
        save_model(args, model, optimizer)

    print(
        f"Epoch [{epoch}/{args.epochs}]\t Loss: {loss_epoch / len(train_loader)}\t lr: {round(lr, 5)}"
    )
    args.current_epoch += 1

# end training: final checkpoint named after args.epochs
save_model(args, model, optimizer)

Step [0/2624]	 Loss: 4.763722896575928
Step [164/2624]	 Loss: 4.879242897033691
Step [328/2624]	 Loss: 4.844408988952637
Step [492/2624]	 Loss: 4.723027229309082
Step [656/2624]	 Loss: 4.7388834953308105
Step [820/2624]	 Loss: 4.6154327392578125
Step [984/2624]	 Loss: 4.828637599945068
Step [1148/2624]	 Loss: 4.824036598205566
Step [1312/2624]	 Loss: 4.757033348083496
Step [1476/2624]	 Loss: 4.790659427642822
Step [1640/2624]	 Loss: 4.795215129852295
Step [1804/2624]	 Loss: 4.833906173706055
Step [1968/2624]	 Loss: 4.740353107452393
Step [2132/2624]	 Loss: 4.719961166381836
Step [2296/2624]	 Loss: 4.7204999923706055
Step [2460/2624]	 Loss: 4.6652140617370605


RuntimeError: shape '[308, 1]' is invalid for input of size 252

# Part 2:
## Linear evaluation: logistic regression on features from the frozen, pre-trained SimCLR model

In [None]:
import torch
import torchvision
import numpy as np
import argparse

In [20]:
class LogisticRegression(nn.Module):
    """Single linear layer mapping feature vectors to class logits.

    No softmax is applied; the loss used with it (CrossEntropyLoss) works on raw logits.
    """

    def __init__(self, n_features, n_classes):
        super().__init__()
        # Keep the attribute name `model` so state_dict keys stay "model.weight" / "model.bias".
        self.model = nn.Linear(n_features, n_classes)

    def forward(self, x):
        logits = self.model(x)
        return logits

In [21]:
def train(args, loader, simclr_model, model, criterion, optimizer):
    """Run one epoch of linear-probe training on pre-computed features.

    NOTE: this redefines the SimCLR pre-training ``train`` from Part 1 under the same name.

    Args:
        args: run configuration (not read; kept for call-site compatibility).
        loader: iterable of ``(features, labels)`` batches.
        simclr_model: unused; kept for call-site compatibility.
        model: classifier being trained (e.g. LogisticRegression).
        criterion: loss on ``(logits, labels)``, e.g. CrossEntropyLoss.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        ``(loss_sum, accuracy_sum)``: per-batch loss and accuracy summed over the
        epoch; divide by the number of batches for averages.
    """
    # Use the classifier's own device so this works whether or not `args` carries a `device`.
    device = next(model.parameters()).device
    model.train()
    loss_epoch = 0
    accuracy_epoch = 0
    for x, y in loader:
        optimizer.zero_grad()

        x = x.to(device)
        y = y.to(device)

        output = model(x)
        loss = criterion(output, y)

        predicted = output.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        accuracy_epoch += acc

        loss.backward()
        optimizer.step()

        loss_epoch += loss.item()

    return loss_epoch, accuracy_epoch

In [None]:
def test(args, loader, simclr_model, model, criterion, optimizer):
    loss_epoch = 0
    accuracy_epoch = 0
    model.eval()
    for step, (x, y) in enumerate(loader):
        model.zero_grad()

        x = x.to(args.device)
        y = y.to(args.device)

        output = model(x)
        loss = criterion(output, y)

        predicted = output.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        accuracy_epoch += acc

        loss_epoch += loss.item()

    return loss_epoch, accuracy_epoch


In [22]:
from pprint import pprint
from utils import yaml_config_hook

# Re-read config/config.yaml for the linear-evaluation stage; every YAML key becomes an option.
parser = argparse.ArgumentParser(description="SimCLR")
config = yaml_config_hook("./config/config.yaml")
for k, v in config.items():
    parser.add_argument(f"--{k}", default=v, type=type(v))

# Parse an empty argv: inside a notebook sys.argv belongs to the kernel, so only the YAML defaults apply.
args = parser.parse_args([])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# The evaluation cells below read `args.device`, so store the chosen device on `args` too.
args.device = device

In [23]:
# Linear-evaluation overrides.
# If you change the architecture (e.g. args.resnet = "resnet50"), it must match the checkpoint being loaded.
# NOTE(review): Part 1 saves checkpoints under the YAML `model_path` ("save" in the printed config)
# and, with args.epochs = 3, names them checkpoint_0.tar and checkpoint_3.tar;
# confirm that logs/checkpoint_10.tar is the checkpoint intended here.
vars(args).update(model_path="logs", epoch_num=10, logistic_epochs=50)

### Load dataset into train/test dataloaders

In [24]:
# Image loaders used only to extract frozen features for the linear probe.
# NOTE(review): train_dset/test_dset were built with TransformsSimCLR (random augmentation
# yielding two views per image). Linear evaluation usually uses a deterministic test-time
# transform instead; confirm this is intended.
train_loader = torch.utils.data.DataLoader(
    train_dset,
    batch_size=args.logistic_batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=args.workers,
)

# Keep every test sample: drop_last=True would silently leave up to
# logistic_batch_size - 1 images out of the final accuracy.
test_loader = torch.utils.data.DataLoader(
    test_dset,
    batch_size=args.logistic_batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=args.workers,
)

### Load ResNet encoder / SimCLR and load model weights

In [None]:
# Rebuild the network exactly as in pre-training (Part 1) so the checkpoint's
# parameter names and shapes match.
encoder = get_resnet(args.resnet, pretrained=False)  # don't load a pre-trained model from PyTorch repo
n_features = encoder.fc.in_features  # get dimensions of fc layer
# Replace the head the same way Part 1 does. Assigning `fc.out_features` alone
# does not resize the layer's weights.
encoder.fc = torch.nn.Linear(n_features, len(classes))

# load pre-trained model from checkpoint
# SimCLR takes (encoder, projection_dim, n_features), the same call as in Part 1.
simclr_model = SimCLR(encoder, args.projection_dim, n_features)
model_fp = os.path.join(
    args.model_path, "checkpoint_{}.tar".format(args.epoch_num)
)
simclr_model.load_state_dict(torch.load(model_fp, map_location=device.type))
simclr_model = simclr_model.to(device)
# The encoder is a frozen feature extractor from here on: use inference mode.
simclr_model.eval()
    

In [None]:
## Logistic Regression
n_classes = len(classes) # stl-10 / cifar-10
model = LogisticRegression(simclr_model.n_features, n_classes)
model = model.to(device)

In [28]:
# Only the linear probe's parameters are optimised; the SimCLR features are pre-computed.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

### Helper functions that map all input data $X$ to the latent representations $h$ used in linear evaluation (computed only once)

In [29]:
def inference(loader, simclr_model, device):
    """Encode every batch of ``loader`` with the (frozen) SimCLR model.

    Args:
        loader: iterable of batches indexed as ``(images, labels, ...)``. If
            ``images`` is a pair of views (as yielded by the TransformsSimCLR
            datasets in this notebook), only the first view is encoded.
        simclr_model: called as ``simclr_model(x, x)``; its first output (the
            representation ``h``) is kept. It is switched to eval mode.
        device: device the images are moved to before encoding.

    Returns:
        ``(features, labels)`` as NumPy arrays stacked along the first axis.
    """
    # Eval mode so normalisation layers (if any) use stored statistics, not per-batch ones.
    simclr_model.eval()
    feature_vector = []
    labels_vector = []
    for step, batch in enumerate(loader):
        # NOTE(review): assumes labels are the 2nd item of 3-item dataset batches; confirm
        # against H5IFCBDataset.
        x, y = batch[0], batch[1]
        if isinstance(x, (list, tuple)):
            x = x[0]
        x = x.to(device)

        # get encoding (no gradients: the encoder is not trained here)
        with torch.no_grad():
            h = simclr_model(x, x)[0]

        feature_vector.extend(h.cpu().numpy())
        labels_vector.extend(y.numpy())

        if step % 20 == 0:
            print(f"Step [{step}/{len(loader)}]\t Computing features...")

    feature_vector = np.array(feature_vector)
    labels_vector = np.array(labels_vector)
    print("Features shape {}".format(feature_vector.shape))
    return feature_vector, labels_vector


def get_features(context_model, train_loader, test_loader, device):
    """Encode the train and test loaders with ``context_model`` via ``inference``.

    Returns:
        ``(train_X, train_y, test_X, test_y)``: feature and label arrays per split.
    """
    train_part = inference(train_loader, context_model, device)
    test_part = inference(test_loader, context_model, device)
    return (*train_part, *test_part)


def create_data_loaders_from_arrays(X_train, y_train, X_test, y_test, batch_size):
    """Wrap pre-computed feature/label NumPy arrays in unshuffled DataLoaders.

    Returns:
        ``(train_loader, test_loader)``, each yielding ``(features, labels)``
        tensor batches of ``batch_size`` (the last batch may be smaller).
    """

    def _as_loader(features, labels):
        # torch.from_numpy shares memory with the arrays instead of copying them.
        dataset = torch.utils.data.TensorDataset(
            torch.from_numpy(features), torch.from_numpy(labels)
        )
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    return _as_loader(X_train, y_train), _as_loader(X_test, y_test)

In [None]:
print("### Creating features from pre-trained context model ###")
# Use the `device` chosen in the configuration cell (`args.device` was never set there).
(train_X, train_y, test_X, test_y) = get_features(
    simclr_model, train_loader, test_loader, device
)

# Features are computed once; the probe then trains on these small in-memory arrays.
arr_train_loader, arr_test_loader = create_data_loaders_from_arrays(
    train_X, train_y, test_X, test_y, args.logistic_batch_size
)

In [None]:
# Train the linear probe; `train`/`test` return sums over batches, so divide by the
# length of the loader that was actually iterated (the feature-array loaders).
for epoch in range(args.logistic_epochs):
    loss_epoch, accuracy_epoch = train(args, arr_train_loader, simclr_model, model, criterion, optimizer)

    if epoch % 10 == 0:
        print(f"Epoch [{epoch}/{args.logistic_epochs}]\t Loss: {loss_epoch / len(arr_train_loader)}\t Accuracy: {accuracy_epoch / len(arr_train_loader)}")


# final testing
loss_epoch, accuracy_epoch = test(
    args, arr_test_loader, simclr_model, model, criterion, optimizer
)
print(
    f"[FINAL]\t Loss: {loss_epoch / len(arr_test_loader)}\t Accuracy: {accuracy_epoch / len(arr_test_loader)}"
)