In [3]:
import import_ipynb
import Model

import lightning as L
import torchmetrics
import torch
import torch.nn.functional as F

# Lightning Model
- A LightningModule acts as an interface that makes training, validation, testing, and logging easier to set up.
- Implement `__init__` to store the model and hyperparameters. Then define the training, validation, and test steps, plus the optimizer in `configure_optimizers`.

In [4]:
class LightningModel(L.LightningModule):
    """LightningModule that adds training, evaluation, and logging around a classifier.

    Args:
        model: Module whose forward pass returns class logits, one row per sample.
        lr: Learning rate passed to the SGD optimizer.
        num_classes: Number of target classes for the multiclass accuracy
            metrics. Defaults to 10, the value previously hard-coded.
    """

    def __init__(self, model, lr, num_classes=10):
        super().__init__()
        self.model = model
        self.lr = lr

        # Record lr and num_classes as hyperparameters. The wrapped model is a
        # module, not a configuration value, so it is excluded.
        self.save_hyperparameters(ignore=["model"])

        # Each stage gets its own metric instance, so accumulated state from
        # training never mixes with validation or test state.
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its logits."""
        return self.model(x)

    def _common_step(self, batch):
        """Compute cross-entropy loss and argmax predictions for one batch.

        Args:
            batch: An ``(images, true_labels)`` pair.

        Returns:
            Tuple of ``(loss, predicted_labels, true_labels)``.
        """
        images, true_labels = batch
        logits = self(images)
        loss = F.cross_entropy(logits, true_labels)
        # The predicted class is the index of the largest logit along dim 1.
        predicted_labels = torch.argmax(logits, dim=1)
        return loss, predicted_labels, true_labels

    def _shared_step(self, stage, batch):
        """Run one step for ``stage`` ("train", "val" or "test"), log loss and accuracy, return the loss."""
        loss, predicted_labels, true_labels = self._common_step(batch)
        accuracy = getattr(self, f"{stage}_acc")
        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        # Update the metric state first, then log the metric object. Lightning
        # calls compute() and reset() on it at the end of the epoch.
        accuracy(predicted_labels, true_labels)
        self.log(f"{stage}_acc", accuracy, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Log training loss/accuracy and return the loss for backpropagation."""
        return self._shared_step("train", batch)

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy for one batch."""
        self._shared_step("val", batch)

    def test_step(self, batch, batch_idx):
        """Log test loss and accuracy for one batch."""
        self._shared_step("test", batch)

    def configure_optimizers(self):
        """Return plain SGD over all parameters, using the configured learning rate."""
        optimizer = torch.optim.SGD(self.parameters(), lr=self.lr)
        return optimizer
    