PyTorch Lightning
* Enforced standards
* Rids code of boilerplate
* Abstracts away a lot of extras - logging parallelization

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl


In [None]:
train_data_set = MNIST(root='data', train=True, download=True, transform=ToTensor())
validation_data_set = MNIST(root='data', train=False, download=True, transform=ToTensor())

batch_size = 64

training_dataloader = DataLoader(train_data_set, batch_size=batch_size, shuffle=True)
validation_dataloader = DataLoader(validation_data_set, batch_size=batch_size, shuffle=False)

In [None]:
class MNISTModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(784, 10)
        self.learning_rate = 0.5

    def forward(self, data_batch):
        data_batch = data_batch.flatten(1, -1)
        return self.lin(data_batch)
    
    def training_step(self, batch, batch_idx):
        data_batch, label_batch = batch
        pred = self(data_batch)
        loss = F.cross_entropy(pred, label_batch)
        
        # Calculate accuracy
        acc = (pred.argmax(dim=1) == label_batch).float().mean()
        
        # Log metrics
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        data_batch, label_batch = batch
        pred = self(data_batch)
        loss = F.cross_entropy(pred, label_batch)
        
        # Calculate accuracy
        acc = (pred.argmax(dim=1) == label_batch).float().mean()
        
        # Log metrics
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        
        return loss

    def configure_optimizers(self):
        return optim.SGD(self.parameters(), lr=self.learning_rate)

In [7]:
from pytorch_lightning.loggers import TensorBoardLogger
tensor_board_logger = TensorBoardLogger('tb_logs')

ModuleNotFoundError: Neither `tensorboard` nor `tensorboardX` is available. Try `pip install`ing either.
Requirement 'tensorboardX' not met. HINT: Try running `pip install -U 'tensorboardX'`
Requirement 'tensorboard' not met. HINT: Try running `pip install -U 'tensorboard'`

In [None]:
mnist_model = MNISTModel()

trainer = pl.Trainer(max_epochs=3, logger=tensor_board_logger)

trainer.fit(mnist_model, training_dataloader, validation_dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
c:\Users\badmo\anaconda3\envs\DL4begs\Lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name | Type   | Params | Mode 
----------------------------------------
0 | lin  | Linear | 7.9 K  | train
----------------------------------------
7.9 K     Trainable params
0         Non-trainable params
7.9 K     Total params
0.031     Total estimated model params size (MB)
1         Modules in train mode
0         Modules in eval mode
c:\

Epoch 1: 100%|██████████| 938/938 [00:13<00:00, 71.51it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 938/938 [00:13<00:00, 71.45it/s, v_num=0]


In [None]:
# Test the model
trainer.test(mnist_model, validation_dataloader)

# Make a prediction on a single sample
import matplotlib.pyplot as plt

# Get a sample from validation set
sample_data, sample_label = next(iter(validation_dataloader))
sample_image = sample_data[0]
sample_true_label = sample_label[0]

# Make prediction
mnist_model.eval()
with torch.no_grad():
    prediction = mnist_model(sample_image.unsqueeze(0))
    predicted_label = prediction.argmax(dim=1).item()

# Display the image and prediction
plt.figure(figsize=(6, 4))
plt.subplot(1, 2, 1)
plt.imshow(sample_image.squeeze(), cmap='gray')
plt.title(f'True: {sample_true_label}, Predicted: {predicted_label}')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.bar(range(10), F.softmax(prediction, dim=1).squeeze().numpy())
plt.title('Prediction Probabilities')
plt.xlabel('Digit')
plt.ylabel('Probability')
plt.show()