In [None]:
!pip install pytorch-lightning 
!pip install torchvision
!pip install torchmetrics

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import torchmetrics
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Download MNIST (default arguments) into the current working directory; each
# image is converted to a tensor by transforms.ToTensor().
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())

# Keep a small subset for quick experiments: 5000 train / 1000 val / 1000 test.
# Everything else goes into `other` and is ignored. The fourth length is derived
# from len(dataset) because random_split requires the lengths to sum to it.
# A fixed generator seed makes the split identical on every kernel restart.
train_set, val_set, test_set, other = random_split(
    dataset,
    [5000, 1000, 1000, len(dataset) - 7000],
    generator=torch.Generator().manual_seed(42),
)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)

In [None]:
class SimpleNet(nn.Module):
    """Small fully connected classifier for flattened 28x28 inputs.

    Layers: Linear(784, 64) -> ReLU -> Linear(64, 128) -> ReLU -> Linear(128, 10).
    The output is raw logits (no softmax is applied). In this notebook the
    logits are paired with nn.CrossEntropyLoss.
    """

    def __init__(self):
        super().__init__()
        flat_pixels = 28 * 28
        stack = [
            nn.Linear(flat_pixels, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        # Positional indices inside the Sequential (0, 2, 4 for the Linear
        # layers) determine the state_dict keys, e.g. "l1.0.weight".
        self.l1 = nn.Sequential(*stack)

    def forward(self, x):
        """Map inputs whose last dimension is 784 to 10 logits."""
        logits = self.l1(x)
        return logits

In [None]:
class SimpleTraining(pl.LightningModule):
    """LightningModule that trains a classifier with cross-entropy loss.

    Logs the epoch-level training loss and accuracy.

    Args:
        simplenet: model to train. It is called on flattened (batch, features)
            inputs and should return `num_classes` logits per sample.
        num_classes: number of target classes used by the accuracy metric
            (defaults to 10, matching SimpleNet's output layer).
    """

    def __init__(self, simplenet, num_classes=10):
        super().__init__()
        self.simplenet = simplenet
        self.loss = nn.CrossEntropyLoss()
        # Current torchmetrics requires an explicit task. The bare Accuracy()
        # constructor raises a TypeError there.
        self.train_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one (inputs, targets) batch.

        Also updates the accuracy metric. Returns the loss tensor that Lightning
        backpropagates.
        """
        x, y = batch
        # Flatten everything after the batch dimension for the MLP. Unlike
        # .view, flatten also works on non-contiguous tensors.
        x = torch.flatten(x, start_dim=1)
        y_hat = self.simplenet(x)
        loss = self.loss(y_hat, y)
        self.train_accuracy(y_hat, y)

        # Both values are logged per epoch only (on_step=False). The
        # "_step" suffix in the accuracy key is kept so existing TensorBoard
        # tags stay stable.
        self.log('train_loss', loss, on_step=False, on_epoch=True)
        self.log('train_acc_step', self.train_accuracy, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Optimize all module parameters with Adam at lr=1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [None]:
# Load MNIST (default arguments) from the current working directory,
# downloading if needed, with images converted to tensors.
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())

# 5000 train / 1000 val / 1000 test; the remainder (`other`) is unused.
# The seeded generator makes re-running this cell reproduce the same subsets.
train_set, val_set, test_set, other = random_split(
    dataset,
    [5000, 1000, 1000, len(dataset) - 7000],
    generator=torch.Generator().manual_seed(42),
)

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
# Shuffling does not change validation metrics, so keep the order fixed.
# This also avoids Lightning's warning about shuffled val dataloaders.
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)

In [None]:
class SimpleTrainingWithValidation(pl.LightningModule):
    """Cross-entropy training module that also runs a validation loop.

    Logs epoch-level loss and accuracy for both the training and validation
    splits.

    Args:
        simplenet: model to train. It is called on flattened (batch, features)
            inputs and should return `num_classes` logits per sample.
        num_classes: number of target classes for the accuracy metrics
            (defaults to 10, matching SimpleNet's output layer).
    """

    def __init__(self, simplenet, num_classes=10):
        super().__init__()
        self.simplenet = simplenet
        self.loss = nn.CrossEntropyLoss()
        # Separate metric objects keep train and val accumulators independent.
        # Current torchmetrics requires the explicit `task` argument.
        self.train_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

    def _shared_step(self, batch):
        """Flatten one batch and run the model, returning (logits, targets, loss)."""
        x, y = batch
        # Flatten everything after the batch dimension for the MLP.
        x = torch.flatten(x, start_dim=1)
        y_hat = self.simplenet(x)
        return y_hat, y, self.loss(y_hat, y)

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch and update train accuracy."""
        y_hat, y, loss = self._shared_step(batch)
        self.train_accuracy(y_hat, y)

        # Logged per epoch only. The key names are kept unchanged so existing
        # TensorBoard tags stay stable.
        self.log('train_loss', loss, on_step=False, on_epoch=True)
        self.log('train_acc_step', self.train_accuracy, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Optimize all module parameters with Adam at lr=1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and accuracy for one batch."""
        y_hat, y, loss = self._shared_step(batch)
        self.val_acc(y_hat, y)
        self.log("val_loss", loss, on_step=False, on_epoch=True)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True)

In [None]:
simple_training_with_valid = SimpleTrainingWithValidation(SimpleNet())

%load_ext tensorboard
%tensorboard --logdir tb_logs/
tensorboard = TensorBoardLogger("tb_logs", name="simple_model-with_valid")

trainer = pl.Trainer(logger=tensorboard, max_epochs=100)
trainer.fit(model=simple_training_with_valid, train_dataloaders=train_loader, val_dataloaders=val_loader)