# Pytorch Lightning

### 0. Import libs and Config

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torchvision
from torchvision import transforms, models
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

import os
import time
import pickle
import math
import itertools
import copy

In [5]:
# Device config
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

# Hyper params
batch_size = 32
num_epochs = 50
learning_rate = 0.001
momentum = 0.9


# Random seed
seed = 2022
torch.random.manual_seed(seed)
np.random.seed(2022)

# Path
# Data directory
data_dir = './data/pascal_2007'
if not os.path.exists(data_dir):
    os.mkdir(data_dir)
    
# Save directory
save_dir = './saved'
if not os.path.exists(save_dir):
    os.mkdir(save_dir)
    
# Labels
labels = pickle.load(open(os.path.join(save_dir, 'labels.sav'), 'rb'))

# logger
logger = TensorBoardLogger('runs', 'lightning')

### 1. Data Module

In [6]:
# Define datasets
class MultilabelsDataset(Dataset):
    def __init__(self, images, labellists, labels, transforms=None):
        super().__init__()
        self.images = images
        self.labellists = labellists
        self.labels = labels
        self.transforms = transforms
        self.n_samples = len(self.images)
    
    # Convert str label to index
    def toIndex(self, labellist, labels):
        l = [self.labels.index(label) for label in labellist]
        return l
    
    def __getitem__(self, index): 
        # One hot labels
        labels_index = self.toIndex(self.labellists[index], self.labels)
        labels_index = torch.tensor(labels_index)
        labels_onehot = F.one_hot(labels_index, num_classes=len(self.labels))
        labels_onehot = labels_onehot.sum(dim=0).float()
        
        # Transform input
        image = self.images[index]
        if self.transforms:
            image = self.transforms(image)
            
        return image, labels_onehot
    
    def __len__(self):
        return self.n_samples

In [7]:
# Imagenet mean and std
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Data transform
data_transform = {
    'train': transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(), # Random flip horizontal
        transforms.ColorJitter(), # Random change color
        transforms.RandomErasing(scale=(0.02, 0.08)), # Random erase a small rectangle
        transforms.Normalize(mean, std)
    ]),
    
    'val': transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.Normalize(mean, std)
    ]),
    
    'test': transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.Normalize(mean, std)
    ]),
}

In [8]:
class LitDataModule(pl.LightningDataModule):
    def __init__(self, data_dir, save_dir, transforms, batch_size):
        super().__init__()
        
        # Path
        self.data_dir = data_dir
        self.train_img_dir = os.path.join(self.data_dir, 'train')
        self.test_img_dir = os.path.join(self.data_dir, 'test')
        self.train_anno_path = os.path.join(self.data_dir, 'train.csv')
        self.test_anno_path = os.path.join(self.data_dir, 'test.csv')
        self.save_dir = save_dir
        
        # Hyper parameters
        self.batch_size = batch_size
        
        # Dataset Transforms
        self.transforms = transforms
        
    def setup(self, stage=None):
        # Load data
        print('-' * 10, 'LOAD DATA', '-' * 10)
        (train_images, train_labels), (valid_images, valid_labels) = self.load_images_and_labels(self.train_img_dir, self.train_anno_path, True)
        (test_images, test_labels) = self.load_images_and_labels(self.test_img_dir, self.test_anno_path, False)
        
        # Get labels
        self.labels = self.get_labels(train_labels)
        
        # Create dataset
        self.image_datasets = {
            'train': MultilabelsDataset(train_images, train_labels, self.labels, data_transform['train']),
            'val': MultilabelsDataset(valid_images, valid_labels, self.labels, data_transform['val']),
            'test': MultilabelsDataset(test_images, test_labels, self.labels, data_transform['test'])
        }
    
    # Data Loader
    def train_dataloader(self):
        return DataLoader(self.image_datasets['train'], num_workers=12, batch_size=self.batch_size, shuffle=True)
    
    def val_dataloader(self):
        return DataLoader(self.image_datasets['val'], num_workers=12, batch_size=self.batch_size, shuffle=False)
    
    def test_dataloader(self):
        return DataLoader(self.image_datasets['test'], num_workers=12, batch_size=self.batch_size, shuffle=False)
    
    # Get label list
    def get_labels(self, labellists):
        # All labels
        labels = []

        for labellist in labellists:
            labels = labels + labellist
            
        labels = sorted(list(set(labels)))
        pickle.dump(labels, open(os.path.join(self.save_dir, 'labels.sav'), 'wb'))
        print('Number of labels: ', len(labels))
        return labels
    
    # Load images and labels
    def load_images_and_labels(self, image_dir, anno_path, is_train=True):
        # Load anno files
        anno = pd.read_csv(anno_path).values  
        
        if is_train:
            # Train set
            train_images = []
            train_labellists = []
            
            # Valid set
            valid_images = []
            valid_labellists = []
            
            # Split train and valid set
            train_anno = anno[~ anno[:, 2].astype(np.bool8)][:, :2]
            valid_anno = anno[anno[:, 2].astype(np.bool8)][:, :2]
            
            # Load training set
            print('Loading training set')
            print('-'*10)
            for filename, labellist in train_anno:
                
                image = plt.imread(os.path.join(image_dir, filename))
                train_images.append(image)
                
                train_labellists.append(labellist.split(' '))
            
            # Load valid set
            print('Loading validate set')
            print('-'*10)
            for filename, labellist in valid_anno:
                
                image = plt.imread(os.path.join(image_dir, filename))
                valid_images.append(image)
                
                valid_labellists.append(labellist.split(' '))
            
            return (train_images, train_labellists), (valid_images, valid_labellists)
        
        else:
            # Test set
            test_images = []
            test_labellists = []
            
            # Load test set
            print('Loading test set')
            print('-'*10)
            for filename, labellist in anno:
                
                image = plt.imread(os.path.join(image_dir, filename))
                test_images.append(image)
                
                test_labellists.append(labellist.split(' '))
            
            return (test_images, test_labellists) 

In [9]:
data_module = LitDataModule(data_dir, save_dir, transforms, batch_size)

### 2. Model Module

In [10]:
# define model
class FinetunedResnet50(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.pretrained_model = models.resnet50(pretrained=True)
        
        self.num_classes = num_classes
        
        # Finetune
        in_features = self.pretrained_model.fc.in_features
        self.pretrained_model.fc = nn.Linear(in_features, 128)
        self.fc_head = nn.Linear(128, self.num_classes)
        
    def forward(self, x):
        x = F.relu(self.pretrained_model(x))
        x = torch.sigmoid(self.fc_head(x))
        return x

In [11]:
class LitModelModule(pl.LightningModule):
    def __init__(
        self, 
        num_classes, 
        learning_rate=0.001, 
        momentum=0.9, 
        lr_schedule_step=5, 
        lr_schedule_factor=0.5
    ):
        super().__init__()
        self.num_classes = num_classes
        self.model = FinetunedResnet50(num_classes)
        
        # Hyper parameters
        self.learning_rate = learning_rate
        self.criterion = nn.BCELoss()
        self.momentum = momentum
        self.lr_schedule_step = lr_schedule_step
        self.lr_schedule_factor = lr_schedule_factor
        
    def forward(self, x):
        x = self.model(x)
        return x
    
    # Config optimizer and lr_schedule
    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=self.learning_rate, momentum=self.momentum)
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.lr_schedule_step, gamma=self.lr_schedule_factor)
        return [optimizer], [lr_scheduler]
    
    # Config callbacks
    def configure_callbacks(self):
        return [
            EarlyStopping(monitor='val/accuracy', mode='max', verbose=1, patience=10), 
            ModelCheckpoint(os.path.join(save_dir, 'checkpoints'), monitor='val/accuracy', mode='max', save_weights_only=True, verbose=1)
        ]
    
    # Train, valid, test
    def training_step(self, batch, batch_idx):
        images, labellists = batch
        outputs = self(images)
        loss = self.criterion(outputs, labellists)
        
        preds = outputs.round()
        num_corrects = torch.sum(preds == labellists.data, dim=0)
        cls_accuracy = num_corrects / labellists.size(0)
        avg_accuracy = cls_accuracy.mean() * 100
        
        tensorboard_log = {'acc': avg_accuracy}
        
        return {
            'loss': loss,
            'log': tensorboard_log
        }
    
    def validation_step(self, batch, batch_idx):
        images, labellists = batch
        outputs = self(images)
        loss = self.criterion(outputs, labellists)
        
        preds = outputs.round()
        num_corrects = torch.sum(preds == labellists.data, dim=0)
        cls_accuracy = num_corrects / labellists.size(0)
        avg_accuracy = cls_accuracy.mean() * 100
        
        tensorboard_log = {'acc': avg_accuracy}
        
        return {
            'loss': loss,
            'log': tensorboard_log
        }
    
    def test_step(self, batch, batch_idx):
        images, labellists = batch
        outputs = self(images)
        loss = self.criterion(outputs, labellists)

        preds = outputs.round()
        num_corrects = torch.sum(preds == labellists.data, dim=0)
        cls_accuracy = num_corrects / labellists.size(0)
        avg_accuracy = cls_accuracy.mean() * 100
        tensorboard_log = {'acc': avg_accuracy}
        
        return {
            'loss': loss,
            'log': tensorboard_log
        }
    
    # Save and load model state
    def state_dict(self):
        return self.model.state_dict()
    
    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict)

In [12]:
model = LitModelModule(
    num_classes=len(labels),
    learning_rate=learning_rate,
    momentum=momentum
)

### 3. Callbacks

In [13]:
# Logging Callback
class MyLoggingCallback(pl.callbacks.Callback):
    def __init__(self, every_batch=100):
        super().__init__()
        self.every_batch = every_batch

    # Train
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, unused=0): 
        # Calculate and log batch accuracy
        self.log('train/loss', outputs['loss'])
        self.log('train/accuracy', outputs['log']['acc'])
    
    # Valid 
    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, unused=0): 
        # Calculate and log batch accuracy
        self.log('val/loss', outputs['loss'])
        self.log('val/accuracy', outputs['log']['acc'])
        
    
    # Test
    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, unused=0): 
        # Calculate and log batch accuracy
        self.log('test/loss', outputs['loss'])
        self.log('test/accuracy', outputs['log']['acc'])

In [14]:
# Visualization Callback
class MyVisualizationCallback(pl.callbacks.Callback):
    
    def teardown(self, trainer, pl_module, stage=None):
        if stage:
            if stage == 'fit':
                pl_module.logger.experiment.add_image('filters', self.get_kernels(pl_module))
                
            # Visualize model    
            sampleImg=torch.rand((32, 3, 224, 224))
            pl_module.logger.experiment.add_graph(pl_module.model,sampleImg)
                
    def get_kernels(self, pl_module):
        kernels = pl_module.model.pretrained_model.conv1.weight.detach().cpu()
        
        # Min max scale
        kernels = (kernels - kernels.min()) / kernels.max()
        filters = torchvision.utils.make_grid(kernels.clamp(0, 1))
        return filters  

### 4. Trainer

In [15]:
trainer = pl.Trainer(
    gpus=[1],
    max_epochs=num_epochs,
    logger=logger,
    fast_dev_run=False,
    log_every_n_steps=1,
    callbacks=[MyLoggingCallback(), MyVisualizationCallback()]
)

trainer.fit(model, data_module)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint


---------- LOAD DATA ----------
Loading training set
----------
Loading validate set
----------
Loading test set
----------
Number of labels:  20


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name      | Type              | Params
------------------------------------------------
0 | model     | FinetunedResnet50 | 23.8 M
1 | criterion | BCELoss           | 0     
------------------------------------------------
23.8 M    Trainable params
0         Non-trainable params
23.8 M    Total params
95.092    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val/accuracy improved. New best score: 92.052
Epoch 0, global step 79: 'val/accuracy' reached 92.05180 (best 92.05180), saving model to '/home/anhtranthe/workspace/3_pytorch/saved/checkpoints/epoch=0-step=79.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val/accuracy improved by 0.185 >= min_delta = 0.0. New best score: 92.237
Epoch 1, global step 158: 'val/accuracy' reached 92.23705 (best 92.23705), saving model to '/home/anhtranthe/workspace/3_pytorch/saved/checkpoints/epoch=1-step=158.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val/accuracy improved by 0.014 >= min_delta = 0.0. New best score: 92.251
Epoch 2, global step 237: 'val/accuracy' reached 92.25100 (best 92.25100), saving model to '/home/anhtranthe/workspace/3_pytorch/saved/checkpoints/epoch=2-step=237.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val/accuracy improved by 0.032 >= min_delta = 0.0. New best score: 92.283
Epoch 3, global step 316: 'val/accuracy' reached 92.28287 (best 92.28287), saving model to '/home/anhtranthe/workspace/3_pytorch/saved/checkpoints/epoch=3-step=316.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val/accuracy improved by 0.145 >= min_delta = 0.0. New best score: 92.428
Epoch 4, global step 395: 'val/accuracy' reached 92.42828 (best 92.42828), saving model to '/home/anhtranthe/workspace/3_pytorch/saved/checkpoints/epoch=4-step=395.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val/accuracy improved by 0.253 >= min_delta = 0.0. New best score: 92.681
Epoch 5, global step 474: 'val/accuracy' reached 92.68127 (best 92.68127), saving model to '/home/anhtranthe/workspace/3_pytorch/saved/checkpoints/epoch=5-step=474.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 6, global step 553: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Metric val/accuracy improved by 0.122 >= min_delta = 0.0. New best score: 92.803
Epoch 7, global step 632: 'val/accuracy' reached 92.80279 (best 92.80279), saving model to '/home/anhtranthe/workspace/3_pytorch/saved/checkpoints/epoch=7-step=632.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val/accuracy improved by 0.078 >= min_delta = 0.0. New best score: 92.880
Epoch 8, global step 711: 'val/accuracy' reached 92.88048 (best 92.88048), saving model to '/home/anhtranthe/workspace/3_pytorch/saved/checkpoints/epoch=8-step=711.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val/accuracy improved by 0.056 >= min_delta = 0.0. New best score: 92.936
Epoch 9, global step 790: 'val/accuracy' reached 92.93626 (best 92.93626), saving model to '/home/anhtranthe/workspace/3_pytorch/saved/checkpoints/epoch=9-step=790.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val/accuracy improved by 0.028 >= min_delta = 0.0. New best score: 92.964
Epoch 10, global step 869: 'val/accuracy' reached 92.96414 (best 92.96414), saving model to '/home/anhtranthe/workspace/3_pytorch/saved/checkpoints/epoch=10-step=869.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val/accuracy improved by 0.066 >= min_delta = 0.0. New best score: 93.030
Epoch 11, global step 948: 'val/accuracy' reached 93.02988 (best 93.02988), saving model to '/home/anhtranthe/workspace/3_pytorch/saved/checkpoints/epoch=11-step=948.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val/accuracy improved by 0.074 >= min_delta = 0.0. New best score: 93.104
Epoch 12, global step 1027: 'val/accuracy' reached 93.10358 (best 93.10358), saving model to '/home/anhtranthe/workspace/3_pytorch/saved/checkpoints/epoch=12-step=1027.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val/accuracy improved by 0.022 >= min_delta = 0.0. New best score: 93.125
Epoch 13, global step 1106: 'val/accuracy' reached 93.12550 (best 93.12550), saving model to '/home/anhtranthe/workspace/3_pytorch/saved/checkpoints/epoch=13-step=1106.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 14, global step 1185: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Metric val/accuracy improved by 0.016 >= min_delta = 0.0. New best score: 93.141
Epoch 15, global step 1264: 'val/accuracy' reached 93.14143 (best 93.14143), saving model to '/home/anhtranthe/workspace/3_pytorch/saved/checkpoints/epoch=15-step=1264.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val/accuracy improved by 0.052 >= min_delta = 0.0. New best score: 93.193
Epoch 16, global step 1343: 'val/accuracy' reached 93.19323 (best 93.19323), saving model to '/home/anhtranthe/workspace/3_pytorch/saved/checkpoints/epoch=16-step=1343.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 17, global step 1422: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 18, global step 1501: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Metric val/accuracy improved by 0.016 >= min_delta = 0.0. New best score: 93.209
Epoch 19, global step 1580: 'val/accuracy' reached 93.20916 (best 93.20916), saving model to '/home/anhtranthe/workspace/3_pytorch/saved/checkpoints/epoch=19-step=1580.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 20, global step 1659: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Metric val/accuracy improved by 0.016 >= min_delta = 0.0. New best score: 93.225
Epoch 21, global step 1738: 'val/accuracy' reached 93.22510 (best 93.22510), saving model to '/home/anhtranthe/workspace/3_pytorch/saved/checkpoints/epoch=21-step=1738.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val/accuracy improved by 0.006 >= min_delta = 0.0. New best score: 93.231
Epoch 22, global step 1817: 'val/accuracy' reached 93.23108 (best 93.23108), saving model to '/home/anhtranthe/workspace/3_pytorch/saved/checkpoints/epoch=22-step=1817.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 23, global step 1896: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 24, global step 1975: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 25, global step 2054: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 26, global step 2133: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 27, global step 2212: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 28, global step 2291: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 29, global step 2370: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 30, global step 2449: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Metric val/accuracy improved by 0.012 >= min_delta = 0.0. New best score: 93.243
Epoch 31, global step 2528: 'val/accuracy' reached 93.24303 (best 93.24303), saving model to '/home/anhtranthe/workspace/3_pytorch/saved/checkpoints/epoch=31-step=2528.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 32, global step 2607: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 33, global step 2686: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 34, global step 2765: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 35, global step 2844: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 36, global step 2923: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 37, global step 3002: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 38, global step 3081: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 39, global step 3160: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 40, global step 3239: 'val/accuracy' was not in top 1


Validation: 0it [00:00, ?it/s]

Monitored metric val/accuracy did not improve in the last 10 records. Best score: 93.243. Signaling Trainer to stop.
Epoch 41, global step 3318: 'val/accuracy' was not in top 1


In [17]:
trainer.test(model, data_module)

The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: EarlyStopping, ModelCheckpoint


---------- LOAD DATA ----------
Loading training set
----------
Loading validate set
----------
Loading test set
----------


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Number of labels:  20


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test/accuracy          93.37136840820312
        test/loss           0.21799778938293457
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test/loss': 0.21799778938293457, 'test/accuracy': 93.37136840820312}]

### 5. Tensorboard

In [None]:
! tensorboard --logdir=runs/lightning