In [1]:
from sklearn import datasets

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import lightning.pytorch as pl

import warnings
warnings.filterwarnings('ignore')

In [111]:
from torchmetrics import ExplainedVariance, Accuracy

# Define a Conv Classifier
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 32, (3, 3)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Flatten(),
            nn.LazyLinear(50),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(50, 10)
        )

    def forward(self, x):
        return self.model(x)

class ConvNetTrainer(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.model = ConvNet()
        
        self.accuracy = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        predictions = torch.argmax(logits, dim=1)
        batch_value = self.accuracy(predictions, y)
        self.log("train_acc", batch_value)
        return loss

    def on_train_epoch_end(self):
        self.accuracy.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        test_loss = F.cross_entropy(logits, y)
        test_accuracy = torch.sum(torch.argmax(logits, dim=1) == y) / len(y)
        self.log("test_loss", test_loss, on_step=False, on_epoch=True, sync_dist=True)
        self.log("test_acc", test_accuracy, on_step=False, on_epoch=True, sync_dist=True)
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        val_loss = F.cross_entropy(logits, y)
        val_accuracy = torch.sum(torch.argmax(logits, dim=1) == y) / len(y)
        self.log("val_loss", val_loss, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_acc", val_accuracy, on_step=False, on_epoch=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        return optimizer
        
    def predict_step(self, batch, batch_idx):
        X, y = batch
        return self(X), y

In [112]:
import torch.utils.data as data
from torchvision.datasets import MNIST
from torchvision import transforms

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str='.'):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

    # called within a single process on CPU
    def prepare_data(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    # automatically called after prepare_data; don't assign state here
    def setup(self, stage: str):
        if stage == "fit":
            train_dataset = MNIST(self.data_dir, train=True, transform=self.transform)
            train_set_size = int(len(train_dataset) * 0.8)
            valid_set_size = len(train_dataset) - train_set_size
            self.train_dataset, self.val_dataset = data.random_split(train_dataset, [train_set_size, valid_set_size])
        if stage == 'test':
            self.test_dataset = MNIST(self.data_dir, train=False, transform=self.transform)
        if stage == 'predict':
            self.test_dataset = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=64)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=64)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=64)

    def predict_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=64)

In [None]:
# load recent checkpoint and eval
#model = ConvNetTrainer.load_from_checkpoint(r"lightning_logs/version_27/checkpoints/epoch=3-step=3000.ckpt")
model = ConvNetTrainer()
trainer = pl.Trainer()

# [!] predict method not working 
trainer.test(model, datamodule=MNISTDataModule())

In [114]:
from lightning.pytorch.profilers import AdvancedProfiler
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import DeviceStatsMonitor

## early stopping hook performed after each validation step by default
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=3, verbose=False, mode="max")

# train model
model = ConvNetTrainer()
model.train()
trainer = pl.Trainer(
        default_root_dir=".",
        accelerator='auto',
        devices='auto',
        callbacks=[early_stop_callback],
        fast_dev_run=False,
)
trainer.fit(model, datamodule=MNISTDataModule())

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: .\lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type               | Params
------------------------------------------------
0 | model    | ConvNet            | 830   
1 | accuracy | MulticlassAccuracy | 0     
------------------------------------------------
830       Trainable params
0         Non-trainable params
830       Total params
0.003     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [115]:
predictor = pl.Trainer()
predictions = predictor.test(model, datamodule=MNISTDataModule())

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

In [116]:
data_module = MNISTDataModule()
data_module.setup(stage="predict")
dataloader = data_module.predict_dataloader()

In [117]:
# shape 157 x (64, 10) = min_batches x (batch_size, num_classes)
predictions = []
labels = []
for batch in dataloader:
    X, y = batch
    predictions.append(model(X))
    labels.append(y)
predictions = torch.cat(predictions)
predictions = torch.argmax(predictions, dim=1)
labels = torch.cat(labels)

In [130]:
from sklearn.metrics import precision_recall_fscore_support

# Performance
metrics = {"overall": {}, "class": {}}

In [131]:
overall_metrics = precision_recall_fscore_support(labels, predictions, average="macro")
metrics["overall"]["precision"] = overall_metrics[0]
metrics["overall"]["recall"] = overall_metrics[1]
metrics["overall"]["f1"] = overall_metrics[2]

In [132]:
class_metrics = precision_recall_fscore_support(labels, predictions, average=None)
for i in range(10):  
    metrics["class"][i] = {
        "precision": class_metrics[0][i],
        "recall": class_metrics[1][i],
        "f1": class_metrics[2][i],
    }

In [126]:
import json
from collections import OrderedDict
sorted_metrics = OrderedDict(sorted(metrics['class'].items()), key=lambda t: t[1]['f1'])
sorted_metrics

OrderedDict([(0,
              {'precision': array([0.96059113, 0.98003472, 0.96960784, 0.97616683, 0.96603397,
                      0.98526077, 0.98427673, 0.96157541, 0.98723404, 0.97368421]),
               'recall': array([0.99489796, 0.99471366, 0.95833333, 0.97326733, 0.98472505,
                      0.97421525, 0.98016701, 0.97373541, 0.95277207, 0.95341923]),
               'f1': array([0.97744361, 0.98731963, 0.96393762, 0.97471492, 0.97528996,
                      0.97970688, 0.98221757, 0.96761721, 0.96969697, 0.96344517])}),
             (1,
              {'precision': array([0.96059113, 0.98003472, 0.96960784, 0.97616683, 0.96603397,
                      0.98526077, 0.98427673, 0.96157541, 0.98723404, 0.97368421]),
               'recall': array([0.99489796, 0.99471366, 0.95833333, 0.97326733, 0.98472505,
                      0.97421525, 0.98016701, 0.97373541, 0.95277207, 0.95341923]),
               'f1': array([0.97744361, 0.98731963, 0.96393762, 0.97471492, 0.9752