In [1]:
import time
import shutil


import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.progress import ProgressBar


import tableprint as tp
import torchmetrics

In [None]:
# Install DVC with the %pip magic so the package lands in the environment of
# the running kernel (a plain `!pip` may resolve to a different interpreter).
%pip install -q dvc
# Fetch the zipped dataset from the DVC dataset registry into the working directory.
!dvc get https://github.com/iterative/dataset-registry tutorials/versioning/data.zip
# -o overwrites files left by a previous extraction, so re-running this cell
# never stops at unzip's "replace?" prompt.
!unzip -q -o data.zip
# The archive is no longer needed once extracted.
!rm -f data.zip

In [2]:
# Image folders: TRAIN_PATH feeds the training set, TEST_PATH feeds the
# validation set built below.
TRAIN_PATH = './data/train'
TEST_PATH = './data/validation'

# Preprocessing applied to every image, in order: resize to 224x224, convert
# to a tensor, then normalize each channel with the given mean / std.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
]
img_transforms = transforms.Compose(_preprocess_steps)

# Both datasets share the same preprocessing pipeline.
train_data = datasets.ImageFolder(TRAIN_PATH, transform=img_transforms)
val_data = datasets.ImageFolder(TEST_PATH, transform=img_transforms)

In [3]:
num_workers = 4
batch_size = 32

# Options common to both loaders; only the training loader shuffles.
_loader_options = dict(batch_size=batch_size, num_workers=num_workers)
train_loader = torch.utils.data.DataLoader(train_data, shuffle=True, **_loader_options)
val_loader = torch.utils.data.DataLoader(val_data, **_loader_options)

In [4]:
class CNN(nn.Module):
    """Small convolutional classifier producing two logits per image.

    Three stride-2, 3x3 convolutions (no padding) are followed by two fully
    connected layers. ``fc1`` expects 2187 input features, which is what the
    convolution stack yields for 3x224x224 inputs (3 channels * 27 * 27).

    Dropout (p=0.2) is applied before each fully connected layer while the
    module is in training mode and is disabled after ``.eval()``.
    """

    def __init__(self):
        super(CNN, self).__init__()
        # conv layers
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=18, kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(in_channels=18, out_channels=3, kernel_size=3, stride=2)

        # dense layers
        self.fc1 = nn.Linear(2187, 1024)
        self.fc2 = nn.Linear(1024, 2)

    def forward(self, X):
        """Return raw class scores (logits) of shape ``(rows, 2)`` for image batch ``X``.

        No softmax is applied here; the caller is expected to feed the result
        to a loss that works on logits.
        """
        X = F.relu(self.conv1(X))  # here, RELU is being treated as a function rather than a layer/module
        X = F.relu(self.conv2(X))
        X = F.relu(self.conv3(X))
        # Flatten to the width fc1 was built for instead of repeating the number.
        X = X.view(-1, self.fc1.in_features)
        # F.dropout defaults to training=True, so the module's mode has to be
        # passed explicitly; otherwise dropout would stay active in eval mode.
        X = F.dropout(X, p=0.2, training=self.training)
        X = F.relu(self.fc1(X))
        X = F.dropout(X, p=0.2, training=self.training)
        X = self.fc2(X)
        return X

In [5]:
# Instantiate the plain PyTorch network; a later cell hands this object to the
# Lightning wrapper.
base_model = CNN()

In [6]:
class Model(pl.LightningModule):
    """Lightning wrapper that trains ``model`` with cross-entropy and prints a per-epoch table.

    Each epoch contributes one row (epoch number, train/valid accuracy in
    percent, train/valid mean loss, elapsed time) to a ``tableprint`` table,
    and the same numbers (without the time) are sent to the attached logger.
    Epoch metrics are also published through ``self.log`` with
    ``logger=False`` under the names ``train_acc``, ``train_loss``,
    ``val_acc``, ``val_loss`` and ``epoch_num``.

    NOTE(review): the table row written in ``on_train_epoch_end`` combines
    values stored by ``training_epoch_end`` and ``validation_epoch_end``; this
    assumes the installed Lightning version has run both before that hook.

    Args:
        model: the network to train; it is called on the image batch and its
            output is passed to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, model):
        super(Model, self).__init__()
        self.model = model
        self.loss_fn = nn.CrossEntropyLoss()
        self.train_accm = torchmetrics.Accuracy()
        self.valid_accm = torchmetrics.Accuracy()

        # Latest per-epoch values, refreshed by the *_epoch_end hooks.
        self.avg_train_loss = 0.
        self.avg_valid_loss = 0.
        self.train_acc = 0.
        self.valid_acc = 0.

        # Wall-clock bookkeeping for the current epoch.
        self.start_time = 0
        self.end_time = 0
        self.epoch_mins = 0
        self.epoch_secs = 0

        # Created lazily when the first row is written, closed in on_train_end.
        self.table_context = None

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=0.0005."""
        optim = torch.optim.Adam(self.parameters(), lr=0.0005)
        return optim

    def training_step(self, batch, batch_idx):
        """Run one training batch.

        Returns a dict with the batch loss (``"loss"``), the predicted class
        indices (``"p"``) and the targets (``"y"``).
        """
        data, target = batch
        output = self.model(data)
        _, predictions = torch.max(output, 1)
        # Called for its side effect: accumulates into the epoch-level metric.
        self.train_accm(predictions, target)
        loss = self.loss_fn(output, target)
        return {"loss": loss, "p": predictions, "y": target}

    def validation_step(self, batch, batch_idx):
        """Run one validation batch; returns the same dict layout as ``training_step``."""
        data, target = batch
        output = self.model(data)
        _, predictions = torch.max(output, 1)
        # Called for its side effect: accumulates into the epoch-level metric.
        self.valid_accm(predictions, target)
        loss_valid = self.loss_fn(output, target)
        return {"loss": loss_valid, "p": predictions, "y": target}

    def on_train_epoch_start(self):
        """Record the wall-clock time at which the epoch starts."""
        self.start_time = time.time()

    def validation_epoch_end(self, outputs):
        """Store and log the mean validation loss and accuracy (percent) for the epoch.

        Nothing is stored or logged during Lightning's sanity check.
        """
        if self.trainer.sanity_checking:
            # Drop what the sanity-check batches accumulated; otherwise they
            # would be counted in the first real validation accuracy.
            self.valid_accm.reset()
            return

        self.avg_valid_loss = torch.stack([x['loss'] for x in outputs]).mean().item()
        self.valid_acc = (self.valid_accm.compute() * 100).item()
        self.valid_accm.reset()
        self.log("epoch_num", int(self.current_epoch+1), on_step=False, on_epoch=True, prog_bar=False, logger=False)
        self.log("val_loss", self.avg_valid_loss, on_step=False, on_epoch=True, prog_bar=False, logger=False)
        self.log("val_acc", self.valid_acc, on_step=False, on_epoch=True, prog_bar=False, logger=False)

    def training_epoch_end(self, outputs):
        """Store and log the mean training loss and accuracy (percent) for the epoch."""
        self.avg_train_loss = torch.stack([x['loss'] for x in outputs]).mean().item()
        self.train_acc = (self.train_accm.compute() * 100).item()
        self.train_accm.reset()
        self.log("train_acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=False, logger=False)
        self.log("train_loss", self.avg_train_loss, on_step=False, on_epoch=True, prog_bar=False, logger=False)
        self.log("epoch_num", int(self.current_epoch+1), on_step=False, on_epoch=True, prog_bar=False, logger=False)

    def on_train_epoch_end(self):
        """Append this epoch's row to the console table and send the metrics to the logger."""
        self.end_time = time.time()
        self.epoch_mins, self.epoch_secs = self.epoch_time(self.start_time, self.end_time)
        time_int = f'{self.epoch_mins}m {self.epoch_secs}s'

        metrics = {'epoch': self.current_epoch+1, 'Train Acc': self.train_acc, 'Train Loss': self.avg_train_loss,  'Valid Acc': self.valid_acc, 'Valid Loss': self.avg_valid_loss}
        if self.table_context is None:
            # First row of this run: open the table, which prints the header.
            self.table_context = tp.TableContext(headers=['epoch', 'Train Acc', 'Train Loss', 'Valid Acc', 'Valid Loss', 'Time'])
            self.table_context.__enter__()
        self.table_context([self.current_epoch+1, self.train_acc, self.avg_train_loss, self.valid_acc, self.avg_valid_loss, time_int])
        self.logger.log_metrics(metrics)

    def on_train_end(self):
        """Close the console table once training stops.

        Closing here rather than on the epoch matching ``max_epochs - 1`` also
        covers runs that end before the configured number of epochs, and
        resetting the context lets a later ``fit`` call start a fresh table.
        """
        if self.table_context is not None:
            self.table_context.__exit__()
            self.table_context = None

    def epoch_time(self, start_time, end_time):
        """Split the interval ``end_time - start_time`` (seconds) into whole minutes and seconds.

        Returns:
            ``(elapsed_mins, elapsed_secs)`` as ints, both truncated toward zero.
        """
        elapsed_time = end_time - start_time
        elapsed_mins = int(elapsed_time / 60)
        elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
        return elapsed_mins, elapsed_secs

In [11]:
# Wrap the network in the Lightning module. The wrapper stores the same object
# (no copy), so training through `model` also updates `base_model`.
model = Model(base_model)

In [12]:
# Snapshot the final-layer bias through both handles to show that they refer
# to the same parameters before any training has happened.
_bias_via_base = base_model.state_dict()["fc2.bias"]
_bias_via_wrapper = model.state_dict()["model.fc2.bias"]
print(
    'Before training',
    f'From base model: {_bias_via_base}',
    f'From PTL Module: {_bias_via_wrapper}',
    '',
    sep='\n',
)

Before training
From base model: tensor([-0.0006, -0.0033])
From PTL Module: tensor([-0.0006, -0.0033])



In [13]:
# Keep the checkpoint with the lowest validation loss. Monitoring training
# accuracy would always prefer the most over-fitted epoch, and the file name
# below already reports val_loss, so the selection criterion now matches it.
# 'val_loss' is published by Model.validation_epoch_end via self.log.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='./ckpt',
    filename='model-{epoch_num:.0f}-{val_loss:.2f}',
    mode='min'
)

In [14]:
# CSV logger rooted at 'csv_logs', experiment name 'E1', fixed version 0.
csvlogger = CSVLogger('csv_logs', name='E1', version=0)

# Trainer settings gathered in one place: 5 epochs, no sanity validation pass,
# no GPUs, and the checkpoint callback defined in the previous cell.
_trainer_options = dict(
    max_epochs=5,
    num_sanity_val_steps=0,
    logger=csvlogger,
    gpus=0,
    callbacks=[checkpoint_callback],
    log_every_n_steps=1,
)
trainer = pl.Trainer(**_trainer_options)

trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name       | Type             | Params
------------------------------------------------
0 | model      | CNN              | 2.2 M 
1 | loss_fn    | CrossEntropyLoss | 0     
2 | train_accm | Accuracy         | 0     
3 | valid_accm | Accuracy         | 0     
------------------------------------------------
2.2 M     Trainable params
0         Non-trainable params
2.2 M     Total params
8.981     Total estimated model params size (MB)


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

╭─────────────┬─────────────┬─────────────┬─────────────┬─────────────┬─────────────╮
│       epoch │   Train Acc │  Train Loss │   Valid Acc │  Valid Loss │        Time │
├─────────────┼─────────────┼─────────────┼─────────────┼─────────────┼─────────────┤
│           1 │        55.4 │      0.6837 │      59.625 │     0.67976 │      0m 10s │


Validating: 0it [00:00, ?it/s]

│           2 │        64.4 │     0.65185 │        57.5 │     0.67269 │      0m 10s │


Validating: 0it [00:00, ?it/s]

│           3 │        70.4 │     0.58671 │       59.25 │     0.68084 │      0m 11s │


Validating: 0it [00:00, ?it/s]

│           4 │        75.1 │     0.53537 │      59.875 │     0.71231 │      0m 11s │


Validating: 0it [00:00, ?it/s]

│           5 │        79.9 │     0.43738 │      58.625 │     0.78774 │      0m 11s │
╰─────────────┴─────────────┴─────────────┴─────────────┴─────────────┴─────────────╯


In [15]:
# Read the final-layer bias through both handles again; identical values show
# that training the wrapper updated the wrapped network as well.
_report_lines = [
    'After training',
    f'From base model: {base_model.state_dict()["fc2.bias"]}',
    f'From PTL Module: {model.state_dict()["model.fc2.bias"]}',
    '',
]
for _line in _report_lines:
    print(_line)

After training
From base model: tensor([ 0.0071, -0.0110])
From PTL Module: tensor([ 0.0071, -0.0110])

