Quick debugging of the ViT trainer code before doing some test runs.

In [1]:

from pytorch_lightning import Trainer
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import cv2


import sys
sys.path.append('../vit_pytorch/')
sys.path.append('..')
from vit import ViT
from recorder import Recorder # import the Recorder and instantiate
#from dataloaders import *


import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from argparse import ArgumentParser, Namespace
import os
import random
import sys
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns
from dataloader import get_CIFAR_data
from pets_loader import get_PETS_data
%matplotlib inline

In [8]:

class ViT_Trainer(pl.LightningModule):
    """Lightning module that trains and evaluates a ``ViT`` image classifier.

    Args:
        hparams: Optional object (e.g. an ``argparse.Namespace``) whose
            attributes override the defaults set in ``__check_hparams``.
            ``None`` means "use every default".
    """

    # Send an example image + attention map to the logger every N batches.
    LOG_EVERY_N_BATCHES = 1500

    def __init__(self, hparams=None):
        super(ViT_Trainer, self).__init__()
        # Hyperparameters must be on ``self`` before prepare_data()/ViT read them.
        self.__check_hparams(hparams)
        self.hparams = hparams
        # NOTE(review): Lightning also calls prepare_data() during fit(), so the
        # loaders are built twice (the corrupted-image scan shows up twice).
        self.prepare_data()

        # Keep the mangled attribute name (``_ViT_Trainer__model``) so existing
        # checkpoint state_dict keys still match.
        self.__model = ViT(
            dim=self.dim,
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_classes=self.num_classes,
            channels=self.channels,
            depth=self.depth,
            heads=self.heads,
            mlp_dim=self.mlp_dim,
            dropout=self.dropout,
        )
        # Passed into every ViT forward call; __log_step reads ``self.rec.attn``.
        self.rec = Recorder()

    def forward(self, x):
        """Return per-class scores for the image batch ``x``.

        The scores are consumed as logits by ``F.cross_entropy`` and reduced
        with ``argmax(dim=1)`` for accuracy.
        """
        return self.__model(x, rec=self.rec)

    def _run_step(self, batch, batch_idx, step_name):
        """Run the model on one ``(img, y_true)`` batch and compute the loss.

        Returns:
            ``(loss, y_pred, y_true)`` where ``y_pred`` is the raw model output.
        """
        img, y_true = batch
        y_pred = self(img)

        if batch_idx % self.LOG_EVERY_N_BATCHES == 0:
            # Periodically log an example image and its attention map.
            self.__log_step(img, y_true, y_pred, step_name)

        loss = F.cross_entropy(y_pred, y_true)
        return loss, y_pred, y_true

    def training_step(self, batch, batch_idx):
        """Compute the training loss and report it under ``train_loss``."""
        train_loss, _, _ = self._run_step(batch, batch_idx, step_name='train')
        train_tensorboard_logs = {'train_loss': train_loss}
        return {'loss': train_loss, 'log': train_tensorboard_logs}

    def _eval_step(self, batch, batch_idx, step_name, key_prefix):
        """Shared body of validation_step and test_step.

        Returns:
            ``{f'{key_prefix}_loss': loss, f'{key_prefix}_acc': acc}`` where
            ``acc`` is a 1-element float64 tensor holding the batch accuracy.
        """
        loss, y_pred, y_true = self._run_step(batch, batch_idx, step_name=step_name)
        y_pred = y_pred.argmax(dim=1).detach().cpu()
        y_true = y_true.detach().cpu()
        # sklearn's signature is accuracy_score(y_true, y_pred).
        acc = torch.from_numpy(np.array([accuracy_score(y_true, y_pred)]))
        return {f'{key_prefix}_loss': loss, f'{key_prefix}_acc': acc}

    @staticmethod
    def _average_outputs(outputs, key_prefix):
        """Return the mean loss and mean accuracy over the per-batch dicts from _eval_step."""
        avg_loss = torch.stack([x[f'{key_prefix}_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x[f'{key_prefix}_acc'] for x in outputs]).mean()
        return avg_loss, avg_acc

    def validation_step(self, batch, batch_idx):
        """Return ``val_loss`` and ``val_acc`` for one validation batch."""
        return self._eval_step(batch, batch_idx, step_name='valid', key_prefix='val')

    def validation_epoch_end(self, outputs):
        """Average validation loss/accuracy over the epoch and log them."""
        avg_val_loss, avg_val_acc = self._average_outputs(outputs, 'val')
        val_tensorboard_logs = {'avg_val_acc': avg_val_acc, 'avg_val_loss': avg_val_loss}
        return {'val_loss': avg_val_loss, 'log': val_tensorboard_logs}

    def test_step(self, batch, batch_idx):
        """Return ``test_loss`` and ``test_acc`` for one test batch."""
        return self._eval_step(batch, batch_idx, step_name='test', key_prefix='test')

    def test_epoch_end(self, outputs):
        """Average test loss/accuracy over the epoch and log them."""
        avg_test_loss, avg_test_acc = self._average_outputs(outputs, 'test')
        test_tensorboard_logs = {'avg_test_acc': avg_test_acc, 'avg_test_loss': avg_test_loss}
        return {'test_loss': avg_test_loss, 'log': test_tensorboard_logs}

    def configure_optimizers(self):
        """Adam with weight decay plus a ReduceLROnPlateau scheduler (patience 4)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        # NOTE(review): ReduceLROnPlateau needs a monitored metric; presumably this
        # Lightning version uses 'val_loss' by default — confirm.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=4)
        return [optimizer], [scheduler]

    def prepare_data(self):
        """Build the train/valid/test loaders for ``self.dataset``.

        ``'pets'`` (the default) uses ``get_PETS_data``. Any other value is
        forwarded to ``get_CIFAR_data`` as ``dset`` (e.g. ``'cifar100'``).
        """
        if self.dataset == 'pets':
            loaders = get_PETS_data(batch_size=self.batch_size)
        else:
            loaders = get_CIFAR_data(batch_size=self.batch_size, dset=self.dataset)
        self.train_loader, self.valid_loader, self.test_loader = loaders

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.valid_loader

    def test_dataloader(self):
        return self.test_loader

    def __log_step(self, img, y_true, y_pred, step_name, limit=1):
        """Log one input image and its attention-rollout mask to the experiment logger.

        Uses the attention maps captured by ``self.rec`` during the forward
        pass. ``y_true``, ``y_pred`` and ``limit`` are accepted but unused.
        """
        j = 0  # index of the sample within the batch to visualise
        with torch.no_grad():
            # NOTE(review): assumes rec.attn[j] is (layers, heads, tokens, tokens)
            # with the output/CLS token first — confirm against Recorder.
            attn_mat = self.rec.attn[j].detach().cpu()
            attn_mat = torch.mean(attn_mat, dim=1)  # average across heads
            # Account for residual connections: add the identity and
            # re-normalise each row so it sums to 1.
            residual_att = torch.eye(attn_mat.size(1))
            aug_att_mat = attn_mat + residual_att
            aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)
            # Attention rollout: recursively multiply the per-layer matrices.
            joint_attentions = torch.zeros(aug_att_mat.size())
            joint_attentions[0] = aug_att_mat[0]
            for n in range(1, aug_att_mat.size(0)):
                joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n - 1])

        # Attention from the first (output) token of the last layer to every patch token.
        v = joint_attentions[-1]
        num_patches = aug_att_mat.size(-1) - 1  # exclude the output token itself
        grid_size = int(round(np.sqrt(num_patches)))
        mask = v[0, 1:].reshape(grid_size, grid_size).numpy()
        peak = mask.max()
        if peak > 0:  # avoid 0/0 -> NaN on an all-zero map
            mask = mask / peak
        mask = cv2.resize(mask, (self.image_size, self.image_size))

        im = img[j].detach().cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC for imshow
        self.__add_figure(f'{step_name}_image', im)
        self.__add_figure(f'{step_name}_attention_mask', mask)

    def __add_figure(self, tag, data):
        """Draw ``data`` with imshow and send the figure to the logger at the current global step."""
        fig, ax = plt.subplots()
        ax.imshow(data)
        self.logger.experiment.add_figure(tag, fig, global_step=self.trainer.global_step, close=True, walltime=None)

    def __check_hparams(self, hparams):
        """Copy hyperparameters from ``hparams`` onto ``self``, using defaults for missing ones.

        NOTE(review): several defaults here (patch_size, depth, heads, mlp_dim,
        seed) differ from those in add_model_specific_args — confirm which set
        is intended.
        """
        # architecture
        self.channels = getattr(hparams, 'channels', 3)
        self.image_size = getattr(hparams, 'image_size', 32)
        self.patch_size = getattr(hparams, 'patch_size', 8)
        self.depth = getattr(hparams, 'depth', 8)
        self.heads = getattr(hparams, 'heads', 8)
        self.dim = getattr(hparams, 'dim', 768)
        self.mlp_dim = getattr(hparams, 'mlp_dim', 512)
        self.dropout = getattr(hparams, 'dropout', 0)
        self.num_classes = getattr(hparams, 'num_classes', 100)

        # optimisation / data
        self.batch_size = getattr(hparams, 'batch_size', 128)
        self.learning_rate = getattr(hparams, 'learning_rate', 0.001)
        self.weight_decay = getattr(hparams, 'weight_decay', 0.001)
        self.seed = getattr(hparams, 'seed', 32)
        self.dataset = getattr(hparams, 'dataset', 'pets')  # 'pets' or a CIFAR name such as 'cifar100'

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return an ArgumentParser extending ``parent_parser`` with this model's options."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        # architecture specific arguments
        parser.add_argument('--channels', type=int, default=3)
        parser.add_argument('--image_size', type=int, default=32)
        parser.add_argument('--patch_size', type=int, default=4)  # not really specified
        parser.add_argument('--depth', type=int, default=12)  # 12, 24, 32
        parser.add_argument('--heads', type=int, default=12)  # 12, 16, 16
        parser.add_argument('--dim', type=int, default=768)  # 768, 1024, 1280
        parser.add_argument('--mlp_dim', type=int, default=3072)  # 3072, 4096, 5120
        parser.add_argument('--dropout', type=float, default=0)  # 0 or .1
        parser.add_argument('--num_classes', type=int, default=100)

        # setup arguments
        parser.add_argument('--batch_size', type=int, default=128)  # 4096
        # float, not int: values like 1e-4 must parse from the command line.
        parser.add_argument('--learning_rate', type=float, default=1e-4)  # .9, .999 (Adam)
        parser.add_argument('--weight_decay', type=float, default=.001)  # .1
        parser.add_argument('--seed', type=int, default=42)  # shuffling samples in data loader
        parser.add_argument('--dataset', type=str, default='pets')  # 'pets' or e.g. 'cifar100'

        return parser



In [9]:
# Build the Lightning module with all default hyperparameters (hparams=None).
# Constructing it also builds the data loaders via prepare_data().
model = ViT_Trainer(hparams=None)


checking for corrupted images


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=7390.0), HTML(value='')))

[INFO] Corrupted Image: ../data/oxford_iiit_pet/images/Abyssinian_34.jpg
[INFO] Corrupted Image: ../data/oxford_iiit_pet/images/Egyptian_Mau_139.jpg
[INFO] Corrupted Image: ../data/oxford_iiit_pet/images/Egyptian_Mau_145.jpg
[INFO] Corrupted Image: ../data/oxford_iiit_pet/images/Egyptian_Mau_167.jpg
[INFO] Corrupted Image: ../data/oxford_iiit_pet/images/Egyptian_Mau_177.jpg
[INFO] Corrupted Image: ../data/oxford_iiit_pet/images/Egyptian_Mau_191.jpg



In [10]:
# Peek at the first training batch. next() raises StopIteration on an empty
# loader instead of silently leaving `batch` undefined.
batch = next(iter(model.train_dataloader()))

In [4]:
# `t` was never defined in this notebook; take the first batch from the
# model's (PETS) training loader instead.
batch_pets = next(iter(model.train_dataloader()))

In [5]:
batch_pets[0].size()  # shape of the image tensor in the batch

torch.Size([128, 3, 32, 32])

In [13]:
# Initialize a trainer
trainer = pl.Trainer(gpus=0, max_epochs=30, progress_bar_refresh_rate=20)

# Train the model 
trainer.fit(model)

INFO:lightning:GPU available: False, used: False


checking for corrupted images


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=7390.0), HTML(value='')))

[INFO] Corrupted Image: ../data/oxford_iiit_pet/images/Abyssinian_34.jpg
[INFO] Corrupted Image: ../data/oxford_iiit_pet/images/Egyptian_Mau_139.jpg
[INFO] Corrupted Image: ../data/oxford_iiit_pet/images/Egyptian_Mau_145.jpg
[INFO] Corrupted Image: ../data/oxford_iiit_pet/images/Egyptian_Mau_167.jpg
[INFO] Corrupted Image: ../data/oxford_iiit_pet/images/Egyptian_Mau_177.jpg
[INFO] Corrupted Image: ../data/oxford_iiit_pet/images/Egyptian_Mau_191.jpg



INFO:lightning:
    | Name                                                   | Type        | Params
-----------------------------------------------------------------------------------
0   | _ViT_Trainer__model                                    | ViT         | 19 M  
1   | _ViT_Trainer__model.to_patch_embedding                 | Sequential  | 148 K 
2   | _ViT_Trainer__model.to_patch_embedding.0               | Rearrange   | 0     
3   | _ViT_Trainer__model.to_patch_embedding.1               | Linear      | 148 K 
4   | _ViT_Trainer__model.dropout                            | Dropout     | 0     
5   | _ViT_Trainer__model.transformer                        | Transformer | 18 M  
6   | _ViT_Trainer__model.transformer.layers                 | ModuleList  | 18 M  
7   | _ViT_Trainer__model.transformer.layers.0               | ModuleList  | 2 M   
8   | _ViT_Trainer__model.transformer.layers.0.0             | PreNorm     | 1 M   
9   | _ViT_Trainer__model.transformer.layers.0.0.norm       



HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…



HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

INFO:lightning:Detected KeyboardInterrupt, attempting graceful shutdown...





1

In [None]:
# IPython magic: open the post-mortem debugger on the most recent exception.
%debug