In [4]:
import torch
import torchvision
import torchvision.transforms as transforms
import os
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import pytorch_lightning as pl


batch_size = 128


transform = transforms.Compose(
    [
    transforms.ToTensor(),
    #transforms.Normalize((0.1307,), (0.3081,))
    ])

trainset = torchvision.datasets.MNIST(os.getcwd(), train=True,download=True, transform=transform)
trainset, valset = random_split(trainset, [55000, 5000])
testset = torchvision.datasets.MNIST(os.getcwd(), train=False,download=True, transform=transform)

train_dataloader = DataLoader(trainset, batch_size=batch_size,shuffle=True)
val_dataloader = DataLoader(valset, batch_size=batch_size,shuffle=True)
test_dataloader = DataLoader(testset, batch_size=batch_size,shuffle=False)

In [5]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=128):
        super().__init__()
        #self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self):
        #download
        MNIST(os.getcwd(), train=True, download=False)
        MNIST(os.getcwd(), train=False, download=False)

    def setup(self, stage: str):
        self.mnist_test = MNIST(os.getcwd(), train=False, transform=transforms.ToTensor())
        self.mnist_predict = MNIST(os.getcwd(), train=False, transform=transforms.ToTensor())
        mnist_full = MNIST(os.getcwd(), train=True, transform=transforms.ToTensor())
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)
    
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)
    
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
    
    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=self.batch_size)

In [6]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.stack = nn.Sequential(
            nn.Conv2d(1,32,3),
            nn.ReLU(),
            nn.Conv2d(32,64,3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            nn.Flatten(),
            nn.Linear(64*12*12, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128,10),
            #nn.Softmax()
        )

    def forward(self, x):
        x = self.stack(x)
        return x

In [7]:
#How to compute conv layers
layer = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
example = trainset[0][0]
layer(example).shape

torch.Size([32, 26, 26])

In [8]:
import pytorch_lightning as pl
from torchmetrics import Accuracy

class MyLitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.model = Net()
        self.loss_fn = nn.CrossEntropyLoss()
        self.train_accuracy = Accuracy(task="multiclass", num_classes=10)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=10)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=10)
        
    
    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        acc = self.train_accuracy(logits, y)
        self.log("train_accuracy", acc, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.log("val_loss", loss)
        acc = self.val_accuracy(logits, y)
        self.log("val_accuracy", acc)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.log("test_loss", loss)
        acc = self.test_accuracy(logits, y)
        self.log("test_accuracy", acc)

    def configure_optimizers(self):
        #optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9)
        optimizer = torch.optim.Adadelta(self.parameters())
        return optimizer

In [9]:
from pytorch_lightning.callbacks import EarlyStopping

early_stopping_calback = EarlyStopping(monitor="train_loss", min_delta=0.1, mode="min")

In [6]:
model = MyLitModel()
dm = MNISTDataModule()

trainer = pl.Trainer(
    max_epochs=12,
    accelerator="auto",
    devices="auto",
    callbacks=early_stopping_calback
)

trainer.fit(model, train_dataloader, val_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type               | Params
------------------------------------------------------
0 | model          | Net                | 1.2 M 
1 | loss_fn        | CrossEntropyLoss   | 0     
2 | train_accuracy | MulticlassAccuracy | 0     
3 | val_accuracy   | MulticlassAccuracy | 0     
4 | test_accuracy  | MulticlassAccuracy | 0     
------------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.800     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [7]:
trainer.test(model, test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.03228896111249924, 'test_accuracy': 0.989300012588501}]