In [1]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Directory where torchvision stores / looks for the MNIST files.
DATA_ROOT = "./mnist"
# PIL image -> float tensor (torchvision's ToTensor scales uint8 pixels to [0, 1]).
to_tensor = transforms.ToTensor()

train_dataset = datasets.MNIST(
    root=DATA_ROOT, train=True, transform=to_tensor, download=True
)

# download=True here as well, so loading the test split does not silently depend
# on the training-split call above having already populated DATA_ROOT.
# (torchvision skips the download when the files already exist.)
test_dataset = datasets.MNIST(
    root=DATA_ROOT, train=False, transform=to_tensor, download=True
)

In [2]:
import torch
from torch.utils.data.dataset import random_split

# Size of the held-out validation set carved out of the MNIST training split.
VAL_SIZE = 5_000
SPLIT_SEED = 1

# The global seed is kept because later cells (model construction, the shuffled
# training DataLoader) presumably draw from torch's global RNG.
torch.manual_seed(SPLIT_SEED)

# This cell rebinds `train_dataset` to a Subset, so it is not idempotent. Fail
# loudly instead of re-splitting an already-split dataset.
assert not isinstance(train_dataset, torch.utils.data.Subset), (
    "train_dataset is already split; re-run the dataset-loading cell first"
)

# A dedicated generator makes the split reproducible on its own, independent of
# how much of the global RNG stream earlier code has consumed.
train_dataset, val_dataset = random_split(
    train_dataset,
    lengths=[len(train_dataset) - VAL_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [11]:
# Single source of truth for the batch size used by every split. Changing it in
# one place avoids loaders (and recorded outputs) silently disagreeing.
BATCH_SIZE = 256

# Only the training split is shuffled. Validation and test keep a fixed order,
# so their metrics are computed over identical batches on every evaluation.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
from mlp import PyTorchMLP
# MLP baseline: 28 * 28 inputs, presumably one per pixel of a flattened MNIST
# image (confirm in mlp.py), and 10 outputs, one per digit class.
model = PyTorchMLP(
    num_features=28 * 28,
    num_classes=10,
)

In [5]:
import lightning as L
from classification_module import LightningModel
# Wrap the plain PyTorch MLP in the project's LightningModule. learning_rate is
# presumably the optimizer step size; confirm in classification_module.py.
lightning_model = LightningModel(model, learning_rate=0.05)

# Pinned to CPU: the recorded output warns that a GPU was available but unused.
# NOTE(review): the recorded output shows 860 train steps/epoch and 157 test
# batches. That is batch_size=64 (55000/64, 10000/64), not the 256 in the
# DataLoader cell. The output is stale hidden state; Restart & Run All before
# comparing against the CNN.
trainer = L.Trainer(accelerator="cpu", max_epochs=10)
trainer.fit(
    model=lightning_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)
trainer.test(model=lightning_model, dataloaders=test_loader)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/dcabrera/UPSE/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | model     | PyTorchMLP         | 40.8 K | train
1 | train_acc | MulticlassAccuracy | 0      | train
2 | val_acc   | MulticlassAccuracy | 0      | train
3 | test_acc  | MulticlassAccuracy | 0      | train
---------------------------------------------------------
40.8 K    Trainable params
0         Non-trainable params
40.8 K    Total params
0.163     Total estimated model params size

                                                                           

/home/dcabrera/UPSE/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
/home/dcabrera/UPSE/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Epoch 9: 100%|██████████| 860/860 [00:07<00:00, 110.56it/s, v_num=0, val_loss=0.115, val_acc=0.964, train_acc=0.973]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 860/860 [00:07<00:00, 110.51it/s, v_num=0, val_loss=0.115, val_acc=0.964, train_acc=0.973]


/home/dcabrera/UPSE/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 157/157 [00:01<00:00, 146.13it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc             0.97079998254776
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_acc': 0.97079998254776}]

In [12]:
from cnn import PyTorchCNN
# Rebinds `model` to a fresh CNN with 10 outputs, one per digit class. The MLP
# built earlier remains a submodule of `lightning_model`.
model = PyTorchCNN(
    num_classes=10,
)

In [13]:
# Same learning rate and epoch budget as the MLP run. Unlike that run, which was
# pinned to CPU, Lightning picks the device here; the recorded output shows it
# selected the CUDA GPU.
cnn_lightning_model = LightningModel(model, learning_rate=0.05)
trainer = L.Trainer(accelerator="auto", max_epochs=10)
trainer.fit(
    model=cnn_lightning_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)
trainer.test(model=cnn_lightning_model, dataloaders=test_loader)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | model     | PyTorchCNN         | 6.2 K  | train
1 | train_acc | MulticlassAccuracy | 0      | train
2 | val_acc   | MulticlassAccuracy | 0      | train
3 | test_acc  | MulticlassAccuracy | 0      | train
---------------------------------------------------------
6.2 K     Trainable params
0         Non-trainable params
6.2 K     Total params
0.025     Total estimated model params size (MB)
22        Modules in train mode
0         Modules in eval mode


Epoch 9: 100%|██████████| 215/215 [00:03<00:00, 63.11it/s, v_num=4, val_loss=0.0661, val_acc=0.980, train_acc=0.987]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 215/215 [00:03<00:00, 63.01it/s, v_num=4, val_loss=0.0661, val_acc=0.980, train_acc=0.987]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 40/40 [00:00<00:00, 96.70it/s] 
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9815999865531921
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_acc': 0.9815999865531921}]