In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt

import pytorch_lightning as pl
import torchmetrics

from pytorch_lightning.loggers import TensorBoardLogger

In [2]:
%%javascript
IPython.notebook.kernel.execute('notebook = "' + IPython.notebook.notebook_name + '"')

<IPython.core.display.Javascript object>

In [3]:
notebook

'MNIST-Lightning-TFLogger.ipynb'

In [4]:
class CNN_pl(pl.LightningModule):
    def __init__(self):
        super(CNN_pl, self).__init__()
        
        self.loss = nn.CrossEntropyLoss()
        self.lr = 0.01
        
        self.train_accm = torchmetrics.Accuracy()
        self.valid_accm = torchmetrics.Accuracy()
        self.train_acc = 0.
        self.avg_train_loss = 0.
        
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1 * 28 * 28, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 10)
        )
            
    def forward(self, X):
        X = self.model(X)
        return X
    
    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters())
        return optim
    
    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        output = self(inputs)
        loss_train = self.loss(output, targets)
        predictions = torch.argmax(output, dim=1)
        acc_train = self.train_accm(predictions, targets)
        return loss_train
    
    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        output = self(inputs)
        loss_valid = self.loss(output, targets)
        predictions = torch.argmax(output, dim=1)
        acc_valid = self.valid_accm(predictions, targets)
        return {"loss": loss_valid, "p": predictions, "y": targets}
    
    def training_epoch_end(self, outputs):
        self.train_acc = self.train_accm.compute() * 100
        self.avg_train_loss = torch.stack([x['loss'] for x in outputs]).mean()
        self.logger.experiment.add_scalar('Train Loss', self.avg_train_loss, self.current_epoch+1)
        self.logger.experiment.add_scalar('Train Acc', self.train_acc, self.current_epoch+1)
        self.train_accm.reset()
        
    def validation_epoch_end(self, outputs):
        valid_acc = self.valid_accm.compute() * 100
        avg_valid_loss = torch.stack([x['loss'] for x in outputs]).mean()
        print(f'Epoch {self.current_epoch+1}/{self.trainer.max_epochs} : Train Accuracy: {self.train_acc:.2f}%, Valid Accuracy: {valid_acc:.2f}%, Avg. Train Loss: {self.avg_train_loss:.4f}, Avg. Valid Loss: {avg_valid_loss:.4f}')
        self.logger.experiment.add_scalar('Valid Loss', avg_valid_loss, self.current_epoch+1)
        self.logger.experiment.add_scalar('Valid Acc', valid_acc, self.current_epoch+1)
        self.valid_accm.reset()
        if self.current_epoch == self.trainer.max_epochs - 1:
            self.validation_end(outputs)
    
    def validation_end(self, outputs):
        pb = [x['p'] for x in outputs]
        yb = [x['y'] for x in outputs]
        p = torch.cat(pb, 0).view(-1)
        y = torch.cat(yb, 0).view(-1)
    
    def prepare_data(self):
        self.train_dataset = torchvision.datasets.MNIST(
                    root='.',
                    train=True,
                    transform=transforms.ToTensor(),
                    download=True
                    )
        self.test_dataset  = torchvision.datasets.MNIST(
                    root='.',
                    train=False,
                    transform=transforms.ToTensor(),
                    download=True
                    )
        
    def train_dataloader(self):
        train_ds = torch.utils.data.DataLoader(self.train_dataset,
                                               batch_size= 32,
                                               shuffle=True,
                                               num_workers=4)
        return train_ds
    
    def val_dataloader(self):
        test_ds = torch.utils.data.DataLoader(self.test_dataset,
                                               batch_size= 32,
                                               shuffle=False,
                                               num_workers=4)
        return test_ds

In [5]:
logger = TensorBoardLogger('tb_logs', name=notebook)
model = CNN_pl()
trainer = pl.Trainer(max_epochs=20, num_sanity_val_steps=1, logger=logger)
trainer.fit(model)

GPU available: False, used: False
TPU available: None, using: 0 TPU cores

  | Name       | Type             | Params
------------------------------------------------
0 | loss       | CrossEntropyLoss | 0     
1 | train_accm | Accuracy         | 0     
2 | valid_accm | Accuracy         | 0     
3 | model      | Sequential       | 55.1 K
------------------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Epoch 1/20 : Train Accuracy: 0.00%, Valid Accuracy: 6.25%, Avg. Train Loss: 0.0000, Avg. Valid Loss: 2.3103


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Epoch 1/20 : Train Accuracy: 88.71%, Valid Accuracy: 95.15%, Avg. Train Loss: 0.3731, Avg. Valid Loss: 0.1611


Validating: 0it [00:00, ?it/s]

Epoch 2/20 : Train Accuracy: 94.74%, Valid Accuracy: 96.59%, Avg. Train Loss: 0.1771, Avg. Valid Loss: 0.1169


Validating: 0it [00:00, ?it/s]

Epoch 3/20 : Train Accuracy: 95.92%, Valid Accuracy: 97.05%, Avg. Train Loss: 0.1351, Avg. Valid Loss: 0.0946


Validating: 0it [00:00, ?it/s]

Epoch 4/20 : Train Accuracy: 96.48%, Valid Accuracy: 97.09%, Avg. Train Loss: 0.1126, Avg. Valid Loss: 0.0955


Validating: 0it [00:00, ?it/s]

Epoch 5/20 : Train Accuracy: 96.96%, Valid Accuracy: 97.18%, Avg. Train Loss: 0.0993, Avg. Valid Loss: 0.0891


Validating: 0it [00:00, ?it/s]

Epoch 6/20 : Train Accuracy: 97.20%, Valid Accuracy: 97.64%, Avg. Train Loss: 0.0886, Avg. Valid Loss: 0.0816


Validating: 0it [00:00, ?it/s]

Epoch 7/20 : Train Accuracy: 97.40%, Valid Accuracy: 97.59%, Avg. Train Loss: 0.0792, Avg. Valid Loss: 0.0841


Validating: 0it [00:00, ?it/s]

Epoch 8/20 : Train Accuracy: 97.59%, Valid Accuracy: 97.64%, Avg. Train Loss: 0.0748, Avg. Valid Loss: 0.0787


Validating: 0it [00:00, ?it/s]

Epoch 9/20 : Train Accuracy: 97.73%, Valid Accuracy: 97.48%, Avg. Train Loss: 0.0718, Avg. Valid Loss: 0.0837


Validating: 0it [00:00, ?it/s]

Epoch 10/20 : Train Accuracy: 97.87%, Valid Accuracy: 97.60%, Avg. Train Loss: 0.0664, Avg. Valid Loss: 0.0818


Validating: 0it [00:00, ?it/s]

Epoch 11/20 : Train Accuracy: 98.00%, Valid Accuracy: 97.66%, Avg. Train Loss: 0.0621, Avg. Valid Loss: 0.0825


Validating: 0it [00:00, ?it/s]

Epoch 12/20 : Train Accuracy: 98.05%, Valid Accuracy: 97.53%, Avg. Train Loss: 0.0601, Avg. Valid Loss: 0.0817


Validating: 0it [00:00, ?it/s]

Epoch 13/20 : Train Accuracy: 98.05%, Valid Accuracy: 97.71%, Avg. Train Loss: 0.0591, Avg. Valid Loss: 0.0834


Validating: 0it [00:00, ?it/s]

Epoch 14/20 : Train Accuracy: 98.24%, Valid Accuracy: 97.82%, Avg. Train Loss: 0.0550, Avg. Valid Loss: 0.0838


Validating: 0it [00:00, ?it/s]

Epoch 15/20 : Train Accuracy: 98.26%, Valid Accuracy: 97.75%, Avg. Train Loss: 0.0521, Avg. Valid Loss: 0.0861


Validating: 0it [00:00, ?it/s]

Epoch 16/20 : Train Accuracy: 98.30%, Valid Accuracy: 97.72%, Avg. Train Loss: 0.0528, Avg. Valid Loss: 0.0835


Validating: 0it [00:00, ?it/s]

Epoch 17/20 : Train Accuracy: 98.48%, Valid Accuracy: 97.80%, Avg. Train Loss: 0.0481, Avg. Valid Loss: 0.0897


Validating: 0it [00:00, ?it/s]

Epoch 18/20 : Train Accuracy: 98.39%, Valid Accuracy: 97.80%, Avg. Train Loss: 0.0485, Avg. Valid Loss: 0.0855


Validating: 0it [00:00, ?it/s]

Epoch 19/20 : Train Accuracy: 98.43%, Valid Accuracy: 97.99%, Avg. Train Loss: 0.0470, Avg. Valid Loss: 0.0845


Validating: 0it [00:00, ?it/s]

Epoch 20/20 : Train Accuracy: 98.52%, Valid Accuracy: 97.85%, Avg. Train Loss: 0.0447, Avg. Valid Loss: 0.0857


1

In [None]:
%load_ext tensorboard
%tensorboard --logdir tb_logs --port 6006