<a href="https://colab.research.google.com/github/dssaenzml/simCLR_ML/blob/main/ml701_simCLR_RN50_V2_PL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SimCLR with PL + Downstream

The full list of videos is here:

https://www.youtube.com/playlist?list=PLaMu-SDt_RB4k8VXiB3hOdsn0Y3GoXo1k

Pretrained SimCLRv2 checkpoints: https://console.cloud.google.com/storage/browser/simclr-checkpoints/simclrv2/pretrained?pageState=(%22StorageObjectListTable%22:(%22f%22:%22%255B%255D%22))&prefix=&forceOnObjectsSortingFiltering=false

In [None]:
# ------- USE FOR TPU SUPPORT ------
# %%capture
# ! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl
# torch.tensor([10.]*10000000000) # trick to gain RAM (doesn't work anymore...)

In [None]:
%%capture
# Install the third-party packages imported by the following cell
# (pl_bolts, pytorch_lightning, wandb).
! pip install pytorch-lightning-bolts
! pip install pytorch-lightning
! pip install wandb


In [None]:
from multiprocessing import cpu_count

import pytorch_lightning as pl
import torch
import wandb
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.models.self_supervised.simclr.transforms import (
    SimCLREvalDataTransform, SimCLRTrainDataTransform)
from pytorch_lightning.loggers import WandbLogger
from torch import nn
from torch.optim import RMSprop
from torch.utils.data import DataLoader
# NOTE(review): `transforms` (used by MyClassifierModule.preprocessing) is not
# imported anywhere in this notebook; it presumably comes from torchvision --
# add the import once confirmed.

In [None]:
# data
# CIFAR-10 datamodule using the SimCLR train/eval transforms (both built with 32).
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)


# model
# NOTE(review): the model is configured with max_epochs=5 while the Trainer
# below runs max_epochs=1 -- confirm which value is intended.
model = SimCLR(max_epochs=5, num_samples=dm.num_samples, batch_size=dm.batch_size, dataset='cifar10', gpus=1)

# logger
# Must exist before the Trainer is built (it is passed as `logger=`), and
# watch() is registered before fit() so the training run itself is observed.
wandb_logger = WandbLogger(name='exp_1-epochs_170221', project='simCLR-ml701', id='2', log_model=True)
wandb_logger.watch(model, log="all", log_freq=50)

# Checkpoints are written to `simclr_ckp`, selected by the monitored `val_loss`.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath='simclr_ckp',
    monitor='val_loss',
    filename='{epoch}-{train_loss:.2f}-{val_loss:.2f}',
)

# fit
trainer = pl.Trainer(
    max_epochs=1,
    progress_bar_refresh_rate=20,
    gpus=1,
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
)
trainer.fit(model, datamodule=dm)

# save checkpoint with weights
checkpoint_file = "resnet50-cifar10-embeddings.ckpt"
trainer.save_checkpoint(checkpoint_file)
# save to W&B
trainer.logger.experiment.log_artifact(checkpoint_file, type="model")


In [None]:
# End the current W&B run and sync its results.
wandb.finish()

In [None]:
# DOWNSTREAM TASK CLASSIFIER
class MyClassifier(nn.Module):
    """Linear classification head on top of embeddings loaded from a SimCLR checkpoint.

    Args:
        n_classes: Number of target classes. The head has ``n_classes`` outputs
            when ``n_classes > 2`` and a single output otherwise.
        freeze_base: When true, gradients are disabled for every parameter of
            the embedding network so only the linear head is trainable.
        embeddings_model_path: Path of the checkpoint passed to
            ``SimCLR.load_from_checkpoint``.
        hidden_size: Not used by this class; kept so existing callers that pass
            it keep working.
    """

    def __init__(self, n_classes, freeze_base, embeddings_model_path, hidden_size=512):
        super().__init__()

        # NOTE(review): this relies on the loaded module exposing `.model` with
        # `.embedding` and `.projection[0]` attributes -- confirm against the
        # installed pl_bolts SimCLR implementation.
        base_model = SimCLR.load_from_checkpoint(embeddings_model_path).model

        self.embeddings = base_model.embedding

        if freeze_base:
            print("Freezing embeddings")
            for param in self.embeddings.parameters():
                param.requires_grad = False

        # Only linear projection on top of the embeddings should be enough.
        # Its input width is taken from the first layer of the projection head,
        # i.e. the size of what that layer consumes.
        self.classifier = nn.Linear(
            in_features=base_model.projection[0].in_features,
            out_features=n_classes if n_classes > 2 else 1,
        )

    def forward(self, X, *args):
        """Embed ``X`` and return the linear head's output; extra args are ignored."""
        emb = self.embeddings(X)
        return self.classifier(emb)

In [None]:
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from argparse import Namespace

class MyClassifierModule(pl.LightningModule):
    """Lightning module that trains ``MyClassifier`` with cross-entropy loss.

    Args:
        hparams: ``Namespace`` or ``dict`` providing ``n_classes``,
            ``freeze_base``, ``embeddings_path``, ``hidden_size``, ``lr``,
            ``epochs`` and ``batch_size`` (the fields read by this class).
    """

    def __init__(self, hparams):
        super().__init__()
        hparams = Namespace(**hparams) if isinstance(hparams, dict) else hparams
        # Assigning to `self.hparams` directly is not supported by current
        # Lightning releases; save_hyperparameters() populates it instead.
        self.save_hyperparameters(hparams)
        self.model = MyClassifier(hparams.n_classes, hparams.freeze_base,
                                  hparams.embeddings_path,
                                  hparams.hidden_size)
        self.loss = nn.CrossEntropyLoss()

    def total_steps(self):
        """Return the number of training batches over the whole run.

        This is batches-per-epoch multiplied by the number of epochs.
        """
        return len(self.train_dataloader()) * self.hparams.epochs

    def preprocessing(self):
        """Return the input transform: tensor conversion followed by normalisation."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    # TODO: look at custom data loading - check docs
    def get_dataloader(self, split):
        """Build a ``DataLoader`` for ``split``; only the "train" split is shuffled."""
        # NOTE(review): CIFAR10DataModule is used here as if it were a dataset
        # accepting `split=` / `transform=` -- confirm; a dataset class may be
        # what was intended.
        return DataLoader(CIFAR10DataModule(".", split=split, transform=self.preprocessing()),
                          batch_size=self.hparams.batch_size,
                          shuffle=split == "train",
                          num_workers=cpu_count(),
                          drop_last=False)

    def train_dataloader(self):
        return self.get_dataloader("train")

    def val_dataloader(self):
        # Validation runs on the "test" split.
        return self.get_dataloader("test")

    def forward(self, X):
        return self.model(X)

    def step(self, batch, step_name="train"):
        """Compute the loss for one batch and package it for logging.

        Returns a dict whose loss key is ``"loss"`` for the train step and
        ``"<step_name>_loss"`` otherwise, plus ``"log"`` / ``"progress_bar"``
        sub-dicts keyed by ``"<step_name>_loss"``.
        """
        # NOTE(review): the "log" / "progress_bar" keys are the legacy Lightning
        # logging convention; newer releases expect `self.log(...)` -- confirm
        # against the installed version.
        X, y = batch
        y_out = self.forward(X)
        loss = self.loss(y_out, y)
        loss_key = f"{step_name}_loss"
        tensorboard_logs = {loss_key: loss}

        return {("loss" if step_name == "train" else loss_key): loss, 'log': tensorboard_logs,
                "progress_bar": {loss_key: loss}}

    def training_step(self, batch, batch_idx):
        return self.step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self.step(batch, "test")

    def validation_end(self, outputs):
        """Average ``val_loss`` over the validation step outputs (0 when there are none)."""
        # NOTE(review): `validation_end` is the legacy hook name; newer Lightning
        # releases call `validation_epoch_end` -- confirm whether this hook is
        # still invoked by the installed version.
        if not outputs:
            return {"val_loss": torch.tensor(0)}
        loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        return {"val_loss": loss, "log": {"val_loss": loss}}

    # TODO: explore other schedulers and optimizers

    def configure_optimizers(self):
        """Return RMSprop plus a cosine-annealing schedule (only when epochs > 1)."""
        optimizer = RMSprop(self.model.parameters(), lr=self.hparams.lr)
        schedulers = [
            CosineAnnealingLR(optimizer, self.hparams.epochs)
        ] if self.hparams.epochs > 1 else []
        return [optimizer], schedulers

In [None]:

# Hyper-parameters consumed by MyClassifierModule.
hparams_cls = Namespace(
    lr=1e-3,
    epochs=5,
    batch_size=160,
    n_classes=10,
    freeze_base=True,
    embeddings_path="resnet50-cifar10-embeddings.ckpt",
    hidden_size=512,
)
module = MyClassifierModule(hparams_cls)

wandb_logger = WandbLogger(name='classifier-exp-1', project='simCLR-ml701', id='d1', log_model=True)
wandb_logger.watch(module, log="all", log_freq=50)

trainer = pl.Trainer(gpus=1, max_epochs=hparams_cls.epochs, logger=wandb_logger)
# Actually train the classifier: without this call the Trainer is never used
# and any later evaluation would score an untrained head.
trainer.fit(module)



In [None]:
from sklearn.metrics import classification_report

def evaluate(data_loader, module):
    """Run ``module`` over ``data_loader`` and print a classification report.

    Args:
        data_loader: Iterable yielding ``(X, y)`` batches.
        module: Model called as ``module(X)``; predictions are the argmax of
            its output over dim 1. It is switched to eval mode and moved to
            the evaluation device in place.

    Returns:
        Tuple ``(true_y, pred_y)``: lists of per-sample labels and predictions
        (CPU tensors), in iteration order.
    """
    # Use CUDA when it is available (the original behaviour) and fall back to
    # the CPU instead of crashing on hosts without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    progress = ["/", "-", "\\", "|", "/", "-", "\\", "|"]
    with torch.no_grad():
        module.eval().to(device)
        true_y, pred_y = [], []
        for i, (X, y) in enumerate(data_loader):
            # Lightweight spinner so long evaluations show signs of life.
            print(progress[i % len(progress)], end="\r")
            y_pred = torch.argmax(module(X.to(device)), dim=1)
            true_y.extend(y.cpu())
            pred_y.extend(y_pred.cpu())
        print(classification_report(true_y, pred_y, digits=3))
        return true_y, pred_y
        
# Score the classifier on its validation loader; the returned label lists are discarded.
_ = evaluate(module.val_dataloader(), module)