In [1]:
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import torchvision
import pytorch_lightning as pl
import torch
import torch.nn as nn

In [2]:
import wandb

In [3]:
wandb.init(
    project="demo",
    name="mnist-lighting-demo",
    tags=["demo"],
    config={
        "lr": 1e-4,
        "epoch": 4,
        "batch_size": 128,
        "weight_decay": 1e-5
    }
)
wandb.define_metric("loss", summary="min")
wandb.define_metric("val_loss", summary="min")
wandb.define_metric("val_acc", summary="max")
wandb.define_metric("test_loss", summary="mean")
wandb.define_metric("test_acc", summary="mean")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mzendwang040302[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01114896111111092, max=1.0)…

<wandb.sdk.wandb_metric.Metric at 0x17c5de2d0>

In [4]:
class MnistDataModule(pl.LightningDataModule):
    
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
    
    def prepare_data(self):
        torchvision.datasets.MNIST(root='./data', train=True, download=True)
        torchvision.datasets.MNIST(root='./data', train=False, download=True)
    
    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            full_train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=torchvision.transforms.ToTensor())
            train_size = int(0.8 * len(full_train_dataset))
            val_size = len(full_train_dataset) - train_size
            self.train, self.val = random_split(full_train_dataset, [train_size, val_size])
        else:
            self.test = torchvision.datasets.MNIST(root='./data', train=False, transform=torchvision.transforms.ToTensor())
    
    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)
    
    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)    

In [5]:
from typing import Any


class MnistModel(pl.LightningModule):
    
    def __init__(self):
        super(MnistModel, self).__init__()
        x_len = 28
        self.conv = nn.Conv2d(
            in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1
        )
        self.relu = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1
        )
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten_before_attn = nn.Flatten(start_dim=-2)
        self.attn = nn.MultiheadAttention(
            embed_dim=(x_len // 4) ** 2, num_heads=1, batch_first=True
        )
        self.flatten_after_attn = nn.Flatten(start_dim=-2)
        flattened_last_dim = 64 * (x_len // 4) ** 2
        self.fc1 = nn.Linear(flattened_last_dim, flattened_last_dim * 2)
        self.activation1 = nn.ReLU()
        self.fc2 = nn.Linear(flattened_last_dim * 2, flattened_last_dim)
        self.activation2 = nn.ReLU()
        self.fc3 = nn.Linear(flattened_last_dim, 10)
        
        self.loss_fn = nn.CrossEntropyLoss()
    
    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.maxpool2(x)
        x = self.flatten_before_attn(x)
        x, _ = self.attn(x, x, x)
        x = self.flatten_after_attn(x)
        x = self.fc1(x)
        x = self.activation1(x)
        x = self.fc2(x)
        x = self.activation2(x)
        x = self.fc3(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log('val_loss', loss)
        self.log('val_acc', acc)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log('test_loss', loss)
        self.log('test_acc', acc)
        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(), 
            lr=wandb.config.lr,
            weight_decay=wandb.config.weight_decay
        )

In [6]:
from pytorch_lightning.loggers.wandb import WandbLogger

In [7]:
logger = WandbLogger(wandb.run)

In [8]:
callbacks = [
    pl.callbacks.ModelCheckpoint(
        monitor='val_loss',
        dirpath='model',
        filename='mnist-{epoch:02d}-{val_loss:.2f}',
        save_top_k=1,
        mode='min',
    ),
    pl.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=3,
        mode='min'
    ),
    pl.callbacks.LearningRateMonitor(logging_interval='step')
]

In [9]:
trainer = pl.Trainer(
    max_epochs=wandb.config.epoch,
    logger=logger,
    callbacks=callbacks
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [10]:
data_module = MnistDataModule(wandb.config.batch_size)
model = MnistModel()

In [11]:
trainer.fit(model, data_module)

/Users/zend/Desktop/faster-pytorch/.pixi/envs/default/lib/python3.12/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.

   | Name                | Type               | Params | Mode 
--------------------------------------------------------------------
0  | conv                | Conv2d             | 320    | train
1  | relu                | ReLU               | 0      | train
2  | maxpool1            | MaxPool2d          | 0      | train
3  | conv2               | Conv2d             | 18.5 K | train
4  | maxpool2            | MaxPool2d          | 0      | train
5  | flatten_before_attn | Flatten            | 0      | train
6  | attn                | MultiheadAttention | 9.8 K  | train
7  | flatten_after_attn  | Flatten            | 0      | train
8  | fc1                 | Linear             | 19.

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/zend/Desktop/faster-pytorch/.pixi/envs/default/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/zend/Desktop/faster-pytorch/.pixi/envs/default/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=4` reached.


In [12]:
trainer.test(model, data_module)

/Users/zend/Desktop/faster-pytorch/.pixi/envs/default/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9613000154495239
        test_loss           0.11706113070249557
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.11706113070249557, 'test_acc': 0.9613000154495239}]

In [13]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆█
loss,█▄▄▃▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁
lr-Adam,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_acc,▁
test_loss,▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇█████
val_acc,▁▅▇█
val_loss,█▄▂▁

0,1
epoch,4.0
lr-Adam,0.0001
trainer/global_step,1500.0
