In [None]:
# Install the third-party training dependencies into the Colab runtime.
# NOTE(review): versions are unpinned, so a fresh run may pick up releases with
# a different API than the one this notebook was written against; pin the
# versions that reproduce the logged run (e.g. `%pip install -q pkg==X.Y.Z`).
!pip install -q pytorch_lightning
!pip install -q lightly

[K     |████████████████████████████████| 585 kB 12.1 MB/s 
[K     |████████████████████████████████| 596 kB 56.3 MB/s 
[K     |████████████████████████████████| 141 kB 64.1 MB/s 
[K     |████████████████████████████████| 419 kB 64.5 MB/s 
[K     |████████████████████████████████| 459 kB 12.0 MB/s 
[K     |████████████████████████████████| 151 kB 16.5 MB/s 
[K     |████████████████████████████████| 117 kB 57.2 MB/s 
[K     |████████████████████████████████| 79 kB 6.7 MB/s 
[?25h  Building wheel for antlr4-python3-runtime (setup.py) ... [?25l[?25hdone


In [None]:
# Databricks notebook source
import os
import glob
from PIL import Image
from torch.multiprocessing import cpu_count
import torch
import torch.nn as nn
import torchvision
import pytorch_lightning as pl
import copy
import lightly
from torchvision.models import resnet18, ResNet18_Weights
from lightly.models.modules.heads import SimCLRProjectionHead
from lightly.loss import NTXentLoss
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger

In [None]:
# Make the Google Drive folders (datasets, output directory) visible to this
# runtime under /content/drive.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# Seed the random number generators via PyTorch Lightning for reproducibility.
# workers=True additionally gives every DataLoader worker its own derived seed.
# Without it, forked workers can start from identical NumPy RNG state, so any
# NumPy-based random augmentation in the collate function (presumably lightly's
# GaussianBlur — verify for the installed version) would repeat across workers.
SEED = 1
pl.seed_everything(SEED, workers=True)

Global seed set to 1


1

In [None]:
# SimCLR collate function for the data loaders: it turns each batch of images
# into two independently augmented views (consumed as `(x0, x1)` by the model),
# using lightly's SimCLR augmentation pipeline with input_size=224.
collate_fn = lightly.data.SimCLRCollateFunction(input_size=224)

In [None]:
# Training images read from Google Drive. Each image is only resized to
# 224x224 here; the random SimCLR augmentations are applied later by
# `collate_fn` in the data loader.
dataset_train_simclr = lightly.data.LightlyDataset(
    transform=torchvision.transforms.Resize((224, 224)),
    input_dir='/content/drive/My Drive/train/',
)

In [None]:
# Held-out validation images, used to compute val_loss for early stopping
# (i.e. to detect overfitting). Preprocessing matches the training dataset.
dataset_val_simclr = lightly.data.LightlyDataset(
    transform=torchvision.transforms.Resize((224, 224)),
    input_dir='/content/drive/My Drive/val/',
)

In [None]:
# Data-loading settings shared by the training and validation loaders.
# Cap the worker count at the number of CPUs (Colab runtimes often have only 2):
# extra workers beyond the CPU count only add contention, and PyTorch may warn
# about it.
num_workers = min(4, cpu_count())
batch_size = 320

# Training loader: shuffled each epoch; `collate_fn` yields two augmented views
# per image. drop_last=True keeps every batch at exactly `batch_size`.
dataloader_train_simclr = torch.utils.data.DataLoader(
    dataset_train_simclr,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,
    num_workers=num_workers,
)

In [None]:
# Validation loader: same SimCLR collate (two augmented views per image) as the
# training loader, but without shuffling. drop_last=True keeps every batch at
# `batch_size`; presumably intentional, since the NT-Xent loss value depends on
# the number of in-batch negatives.
dataloader_val_simclr = torch.utils.data.DataLoader(
    dataset_val_simclr,
    num_workers=num_workers,
    collate_fn=collate_fn,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
)

In [None]:
max_epochs = 180  # epoch cap for the Trainer; also the cosine LR schedule period

In [None]:
class SimCLRModel(pl.LightningModule):
    """SimCLR model: ResNet-18 backbone plus a SimCLR projection head.

    The backbone is torchvision's ResNet-18 with ``ResNet18_Weights.DEFAULT``
    weights and its final classification layer removed. Training minimises
    lightly's NT-Xent loss between the embeddings of two augmented views.

    Each batch is unpacked as ``((x0, x1), _, _)``: ``x0``/``x1`` are the two
    views; the two trailing elements are ignored (presumably labels and
    filenames from the lightly collate function).

    Args:
        lr: SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        max_epochs: Period (``T_max``) of the cosine-annealing LR schedule.
            When ``None`` (default) the attached Trainer's ``max_epochs`` is
            used, so the schedule spans the configured training run.
        projection_dim: Output dimension of the projection head.
    """

    def __init__(self, lr=6e-2, momentum=0.9, weight_decay=5e-4,
                 max_epochs=None, projection_dim=128):
        super().__init__()
        # Keep constructor args in self.hparams (and in checkpoints) so that
        # load_from_checkpoint rebuilds the model with identical settings.
        self.save_hyperparameters()

        # ResNet-18 with its last child (the fc classification layer) dropped.
        resnet = resnet18(weights=ResNet18_Weights.DEFAULT)
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        hidden_dim = resnet.fc.in_features

        # Projection head mapping backbone features into the contrastive space.
        self.projection_head = SimCLRProjectionHead(hidden_dim, hidden_dim, projection_dim)

        # NT-Xent contrastive loss (no memory bank is configured).
        self.criterion = NTXentLoss()

    def forward(self, x):
        """Return projection-head embeddings for the image batch ``x``."""
        features = self.backbone(x).flatten(start_dim=1)
        return self.projection_head(features)

    def _contrastive_loss(self, batch):
        """Return (NT-Xent loss between the two views, number of samples)."""
        (x0, x1), _, _ = batch
        z0 = self(x0)
        z1 = self(x1)
        return self.criterion(z0, z1), x0.size(0)

    def training_step(self, batch, batch_idx):
        loss, n_samples = self._contrastive_loss(batch)
        # Batch size comes from the batch itself rather than a notebook global,
        # so logging stays correct for any loader configuration.
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True,
                 batch_size=n_samples)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, n_samples = self._contrastive_loss(batch)
        self.log('val_loss', loss, on_step=False, on_epoch=True, logger=True,
                 batch_size=n_samples)
        return loss

    def configure_optimizers(self):
        """SGD with momentum and weight decay, cosine-annealed per epoch."""
        optim = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.lr,
            momentum=self.hparams.momentum,
            weight_decay=self.hparams.weight_decay,
        )
        t_max = self.hparams.max_epochs
        if t_max is None:
            # Fall back to the Trainer's epoch budget instead of a global.
            t_max = self.trainer.max_epochs
        if t_max is None or t_max < 1:
            raise ValueError(
                "SimCLRModel needs a positive max_epochs for the cosine LR "
                f"schedule, got {t_max!r}; pass max_epochs to the model or Trainer."
            )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=t_max)
        return [optim], [scheduler]

In [None]:
# Switch the working directory to the Drive output folder. The CSV logs and the
# saved .pth/.ckpt files below use relative paths, so they land in this folder.
%cd /content/drive/My Drive/individual_rec_models

/content/drive/My Drive/individual_rec_models


In [None]:
# Write the logged losses to CSV files under ./embedding_training_log
# (relative to the current working directory).
logger = CSVLogger(save_dir='embedding_training_log', name='retrain_embeddings')

# Early stopping: stop once val_loss has not decreased for 5 consecutive
# validation checks.
stop_callback = EarlyStopping(monitor='val_loss', mode='min', patience=5, verbose=True)

In [None]:
# Create an instance of the SimCLR model (ResNet-18 backbone + projection head).
simclrmodel = SimCLRModel()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

In [None]:
# Train the model. `accelerator`/`devices` replace the Trainer's `gpus`
# argument, which was deprecated in PyTorch Lightning 1.7 and removed in 2.0.
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'

# Early stopping halts `patience` epochs *after* the best epoch, so the final
# weights are not the best ones. Track the lowest-val_loss checkpoint too.
checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1)

trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator=accelerator,
    devices=1,
    callbacks=[stop_callback, checkpoint_callback],
    logger=logger,
    log_every_n_steps=20,
)
trainer.fit(simclrmodel, dataloader_train_simclr, dataloader_val_simclr)

# Copy the best checkpoint's weights back into `simclrmodel` (in place, so the
# trainer still references the same module) before anything is saved.
if checkpoint_callback.best_model_path:
    best_model = SimCLRModel.load_from_checkpoint(checkpoint_callback.best_model_path)
    simclrmodel.load_state_dict(best_model.state_dict())

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type                 | Params
---------------------------------------------------------
0 | backbone        | Sequential           | 11.2 M
1 | projection_head | SimCLRProjectionHead | 328 K 
2 | criterion       | NTXentLoss           | 0     
---------------------------------------------------------
11.5 M    Trainable params
0         Non-trainable params
11.5 M    Total params
46.019    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved. New best score: 5.710


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.255 >= min_delta = 0.0. New best score: 5.456


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.114 >= min_delta = 0.0. New best score: 5.341


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.011 >= min_delta = 0.0. New best score: 5.331


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.102 >= min_delta = 0.0. New best score: 5.229


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.047 >= min_delta = 0.0. New best score: 5.182


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.021 >= min_delta = 0.0. New best score: 5.160


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.013 >= min_delta = 0.0. New best score: 5.147


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 5.143


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 5.143


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.018 >= min_delta = 0.0. New best score: 5.124


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.015 >= min_delta = 0.0. New best score: 5.110


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.010 >= min_delta = 0.0. New best score: 5.100


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.028 >= min_delta = 0.0. New best score: 5.072


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 5.071


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.018 >= min_delta = 0.0. New best score: 5.054


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.010 >= min_delta = 0.0. New best score: 5.043


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Monitored metric val_loss did not improve in the last 5 records. Best score: 5.043. Signaling Trainer to stop.


In [None]:
# Persist the trained backbone and projection head as two separate files, each
# a dict wrapping the module's state_dict under a descriptive key.
pretrained_backbone = simclrmodel.backbone
backbone_state_dict = {'resnet18_parameters': pretrained_backbone.state_dict()}

pretrained_projection_head = simclrmodel.projection_head
projection_head_state_dict = {'projection_parameters': pretrained_projection_head.state_dict()}

for payload, filename in (
    (backbone_state_dict, 'simclrresnet18embed.pth'),
    (projection_head_state_dict, 'simclr_projectionhead.pth'),
):
    torch.save(payload, filename)

In [None]:
# Save a Lightning checkpoint of the trained model via the trainer. The batch
# size in the filename is taken from `batch_size`, so it cannot drift from the
# configuration actually used for training.
trainer.save_checkpoint(f'allanimal_simclr_Resnet18_bs{batch_size}.ckpt')