# Introduction
<hr style="border:2px solid black"> </hr>


**What?** A multilayer perceptron (MLP) trained on CIFAR-10, first in plain PyTorch and then in PyTorch Lightning.



# Import modules
<hr style="border:2px solid black"> </hr>

In [4]:
import os
import torch
from torch import nn
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torchvision import transforms
import pytorch_lightning as pl

# Load the dataset
<hr style="border:2px solid black"> </hr>

In [None]:
# Download CIFAR-10 into the current working directory (if it is not already
# there) and wrap it in a loader yielding shuffled mini-batches of 10 samples.
dataset = CIFAR10(
    root=os.getcwd(),
    download=True,
    transform=transforms.ToTensor(),
)
trainloader = DataLoader(
    dataset,
    batch_size=10,
    shuffle=True,
    num_workers=1,
)

# Model via PyTorch
<hr style="border:2px solid black"> </hr>

In [None]:
class MLP(nn.Module):
    """Fully connected classifier (multilayer perceptron).

    Each sample is flattened to 32 * 32 * 3 = 3072 features, passed through
    two ReLU hidden layers of width 64 and 32, and mapped to 10 output scores.
    """

    def __init__(self):
        super().__init__()
        in_features = 32 * 32 * 3
        hidden_1, hidden_2, n_classes = 64, 32, 10
        self.layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, hidden_1),
            nn.ReLU(),
            nn.Linear(hidden_1, hidden_2),
            nn.ReLU(),
            nn.Linear(hidden_2, n_classes),
        )

    def forward(self, x):
        """Return the raw (pre-softmax) class scores for the batch ``x``."""
        scores = self.layers(x)
        return scores

In [None]:
# Set fixed random number seed so the run is reproducible
torch.manual_seed(42)

# Training-loop settings
NUM_EPOCHS = 5   # the loop always runs exactly this many full passes over the data
LOG_EVERY = 500  # report the average loss once every LOG_EVERY mini-batches

# Initialize the MLP
mlp = MLP()

# Define the loss function and optimizer
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)

# Run the training loop
for epoch in range(NUM_EPOCHS):
    print(f'Starting epoch {epoch + 1}')

    # Running sum of the loss since the last report
    current_loss = 0.0

    for i, (inputs, targets) in enumerate(trainloader):
        # Gradients accumulate across backward() calls, so clear them first
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update
        outputs = mlp(inputs)
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()

        # .item() turns the 0-dim loss tensor into a plain Python float
        current_loss += loss.item()
        if (i + 1) % LOG_EVERY == 0:
            print(f'Loss after mini-batch {i + 1:5d}: {current_loss / LOG_EVERY:.3f}')
            current_loss = 0.0

# Process is complete.
print('Training process has finished.')

# PyTorch Lightning
<hr style="border:2px solid black"> </hr>


- PyTorch Lightning makes creating PyTorch models easier. Even a simple model in plain PyTorch requires a lot of code, and much of that code does nothing more than implement the default training process (like our training loop above). Lightning automates these parts as much as possible.

- We add two new methods to the `MLP` class: `training_step` and `configure_optimizers`. Both are required because Lightning now runs the training loop for us: `training_step` computes the loss for one batch, and `configure_optimizers` returns the optimizer to use.
- Note that the superclass has changed to `pl.LightningModule`.



In [3]:
class MLP(pl.LightningModule):
    """Multilayer perceptron as a LightningModule.

    Same architecture as the plain-PyTorch version: each sample is flattened to
    32 * 32 * 3 = 3072 features, passed through ReLU hidden layers of width 64
    and 32, and mapped to 10 output scores. Lightning drives training through
    ``training_step`` and ``configure_optimizers``.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 32 * 3, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
        )
        # Cross-entropy loss used by training_step; it must be created here,
        # otherwise accessing self.ce raises AttributeError during training.
        self.ce = nn.CrossEntropyLoss()

    def forward(self, x):
        """Forward pass: return unnormalised class scores (logits) for ``x``."""
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the cross-entropy loss for one ``(inputs, targets)`` batch."""
        x, y = batch
        # nn.Flatten inside self.layers already flattens x, so no manual view() is needed;
        # calling self(x) routes through forward() so there is a single forward path.
        y_hat = self(x)
        loss = self.ce(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Return Adam with lr=1e-4, matching the plain-PyTorch training run."""
        return torch.optim.Adam(self.parameters(), lr=1e-4)

In [None]:
# Train the Lightning model with the same data settings as the plain-PyTorch
# run: batch size 10, shuffled, one worker process, 5 epochs, seed 42.
# (A bare DataLoader(dataset) would use batch_size=1 and no shuffling.)
dataset = CIFAR10(os.getcwd(), download=True, transform=transforms.ToTensor())
pl.seed_everything(42)
mlp = MLP()
train_loader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=1)

# accelerator="cpu" is the supported way to request CPU training (gpus=0 is
# deprecated and removed in Lightning 2.0). auto_scale_batch_size is not passed:
# it only runs through trainer.tune(), so it had no effect on fit().
trainer = pl.Trainer(accelerator="cpu", devices=1, deterministic=True, max_epochs=5)
trainer.fit(mlp, train_loader)

# References
<hr style="border:2px solid black"> </hr>


- https://www.machinecurve.com/index.php/2021/01/26/creating-a-multilayer-perceptron-with-pytorch-and-lightning/

