### MLP for CIFAR10

A Multi-Layer Perceptron (MLP) is a simple neural network model that can be used for classification tasks.

In this demo, we train a 3-layer MLP on the CIFAR10 dataset and illustrate two MLP implementations.

Let us first import the required modules.

In [4]:
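#!pip install wandb
#!pip install einops
# argparse is part of the Python standard library, so it needs no install
# the *_epoch_end hooks used below were removed in pytorch_lightning 2.0; this also installs torchmetrics
#!pip install "pytorch_lightning<2.0"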
import torch
import torchvision
import wandb
import math
from torch import nn
from einops import rearrange
from argparse import ArgumentParser
from pytorch_lightning import LightningModule, Trainer, Callback
from pytorch_lightning.loggers import WandbLogger
from torchmetrics.functional import accuracy
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch_lightning
  Downloading pytorch_lightning-1.7.7-py3-none-any.whl (708 kB)
Collecting torchmetrics>=0.7.0
  Downloading torchmetrics-0.10.0-py3-none-any.whl (529 kB)
Collecting pyDeprecate>=0.3.1
  Downloading pyDeprecate-0.3.2-py3-none-any.whl (10 kB)
Installing collected packages: torchmetrics, pyDeprecate, pytorch-lightning
Successfully installed pyDeprecate-0.3.2 pytorch-lightning-1.7.7 torchmetrics-0.10.0


#### MLP using PyTorch `nn.Linear`

The most straightforward way to implement an MLP is to use the `nn.Linear` module. In the following code, we implement a 3-layer MLP with the GELU activation function, which can be replaced by other activation functions such as ReLU.
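
For reference, here is a small framework-free sketch of the two activations, assuming the exact (erf-based) form of GELU, which is what `nn.GELU()` computes by default: GELU(x) = x · Φ(x), where Φ is the standard normal CDF. Unlike ReLU, GELU lets small negative values through instead of zeroing them. The helper names `gelu` and `relu` are illustrative only.

```python
import math


def gelu(x: float) -> float:
    # exact GELU: x * Phi(x), with the standard normal CDF Phi written via erf
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def relu(x: float) -> float:
    # ReLU clips negative inputs to zero
    return max(0.0, x)


for x in (-2.0, -1.0, 0.0, 1.0, 2.0):
    print(f"x={x:+.1f}  gelu={gelu(x):+.4f}  relu={relu(x):+.4f}")
```

For large positive inputs both functions behave like the identity; they differ mainly around and below zero.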

Please pay attention to the layer sizes. The input size of `fc1` is `n_features`, the size of the flattened input `x` (3 × 32 × 32 = 3072 for CIFAR10). The output size of `fc1` is `n_hidden`, which is also the input size of `fc2`, and so on: each layer's output size matches the next layer's input size, and `fc3` outputs `num_classes` logits.
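
A quick way to double-check these sizes is to list each layer's `(in_features, out_features)` pair and verify that they chain. The sketch below uses plain Python with hypothetical helper names (`layer_shapes`, `chains`, `count_params`) and mirrors the default arguments of `SimpleMLP`. The parameter count assumes one weight per input-output pair plus one bias per output unit, as `nn.Linear(..., bias=True)` has.

```python
def layer_shapes(n_features=3 * 32 * 32, n_hidden=512, num_classes=10):
    # (in_features, out_features) of fc1, fc2, fc3, mirroring SimpleMLP's defaults
    return [(n_features, n_hidden), (n_hidden, n_hidden), (n_hidden, num_classes)]


def chains(shapes):
    # each layer's output size must equal the next layer's input size
    return all(out_prev == in_next
               for (_, out_prev), (in_next, _) in zip(shapes, shapes[1:]))


def count_params(shapes):
    # one weight per (input, output) pair plus one bias per output unit
    return sum(n_in * n_out + n_out for n_in, n_out in shapes)


shapes = layer_shapes()
print(shapes)                # [(3072, 512), (512, 512), (512, 10)]
print(chains(shapes))        # True
print(count_params(shapes))  # 1841162
```

Almost all of the roughly 1.8M parameters sit in `fc1`, because every one of the 3072 input pixels connects to every hidden unit.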


In [5]:
class SimpleMLP(nn.Module):
    def __init__(self, n_features=3*32*32, n_hidden=512, num_classes=10):
        super().__init__()
        # the 3 Linear layers of the MLP
        self.fc1 = nn.Linear(n_features, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.fc3 = nn.Linear(n_hidden, num_classes)

    def forward(self, x):
        # flatten x - (batch_size, 3, 32, 32) -> (batch_size, 3*32*32)
        # ascii art for the case that x is 1 x 2 x 2 (channel, height, width)
        # ---------         ----------------- 
        # | 1 | 2 |    ---> | 1 | 2 | 3 | 4 |
        # | 3 | 4 |         -----------------
        # ---------
        # we can use any of the following methods to flatten the tensor
        #y = torch.flatten(x, 1)
        #y = x.view(x.size(0), -1)
        # but this is the most intuitive since it shows the actual flattening
        y = rearrange(x, 'b c h w -> b (c h w)')
        y = nn.GELU()(self.fc1(y))
        y = nn.GELU()(self.fc2(y))
        y = self.fc3(y)
        return y
        # we don't need to compute softmax here since it is already
        # built into the CE loss function (nn.CrossEntropyLoss) in PyTorch
        #return F.log_softmax(y, dim=1)
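
To see why the model can return raw logits, here is a plain-Python sketch of what `nn.CrossEntropyLoss` computes for a single sample: a log-softmax followed by the negative log-likelihood of the target class. The function names and logit values are illustrative, not PyTorch APIs. It also shows that log-softmax is idempotent, so uncommenting the `F.log_softmax` line would be redundant (the loss value is unchanged) rather than wrong.

```python
import math


def log_softmax(logits):
    # subtract the max before exponentiating for numerical stability
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum_exp for z in logits]


def cross_entropy(logits, target):
    # log-softmax followed by negative log-likelihood of the target class
    return -log_softmax(logits)[target]


logits = [2.0, 0.5, -1.0]  # made-up raw model outputs for one sample
target = 0

loss = cross_entropy(logits, target)
prob_target = math.exp(log_softmax(logits)[target])
print(loss, -math.log(prob_target))                # same value (up to rounding)
print(cross_entropy(log_softmax(logits), target))  # unchanged: log-softmax is idempotent
```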

#### MLP implementation using Tensors

Here we implement the formula of an MLP layer directly with weight and bias tensors. The parameters are created with `torch.empty`, which allocates uninitialized memory, so if we remove the explicit initialization of the weights and biases they hold arbitrary values and the model will likely fail to converge. In the previous example, `nn.Linear` initializes its weights and biases automatically.
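
Written out, the forward pass of `TensorMLP` computes the following, assuming the default sizes: $x$ is a batch of flattened images of shape $B \times 3072$, $W_1 \in \mathbb{R}^{512 \times 3072}$, $W_2 \in \mathbb{R}^{512 \times 512}$, $W_3 \in \mathbb{R}^{10 \times 512}$, and each bias is added to every row of the batch (broadcasting). The weights are stored as (out, in), the same layout as `nn.Linear`, which is why the code multiplies by the transpose.

$$
\begin{aligned}
h_1 &= \mathrm{GELU}\left(x W_1^\top + b_1\right) \\
h_2 &= \mathrm{GELU}\left(h_1 W_2^\top + b_2\right) \\
\hat{y} &= h_2 W_3^\top + b_3
\end{aligned}
$$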

In [6]:
class TensorMLP(nn.Module):
    def __init__(self, n_features=3*32*32, n_hidden=512, num_classes=10):
        super().__init__()

        # weights and biases for layer 1
        self.w1 = nn.Parameter(torch.empty((n_hidden, n_features)))
        self.b1 = nn.Parameter(torch.empty((n_hidden,)))

        # weights and biases for layer 2
        self.w2 = nn.Parameter(torch.empty((n_hidden, n_hidden)))
        self.b2 = nn.Parameter(torch.empty((n_hidden,)))

        # weights and biases for layer 3
        self.w3 = nn.Parameter(torch.empty((num_classes, n_hidden)))
        self.b3 = nn.Parameter(torch.empty((num_classes,)))

        # initialize the parameters manually because torch.empty leaves them uninitialized
        self.reset_parameters()

    def reset_parameters(self):
        # Kaiming-uniform init with a=sqrt(5), the same scheme nn.Linear uses for its weights
        nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
        # zero for biases (nn.Linear instead samples its biases uniformly)
        nn.init.constant_(self.b1, 0)
        nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5))
        nn.init.constant_(self.b2, 0)
        nn.init.kaiming_uniform_(self.w3, a=math.sqrt(5))
        nn.init.constant_(self.b3, 0)

    def forward(self, x):
        # flatten
        y = rearrange(x, 'b c h w -> b (c h w)')
        # we manually compute the output of each layer
        y = y @ self.w1.T + self.b1
        y = nn.GELU()(y)
        y = y @ self.w2.T + self.b2
        y = nn.GELU()(y)
        y = y @ self.w3.T + self.b3
        return y
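
Why `a=math.sqrt(5)`? `kaiming_uniform_` (with its default `mode="fan_in"` and leaky-ReLU nonlinearity) samples from $U(-\text{bound}, \text{bound})$ with $\text{bound} = \text{gain} \cdot \sqrt{3/\text{fan\_in}}$ and $\text{gain} = \sqrt{2/(1+a^2)}$. Choosing $a = \sqrt{5}$ collapses the bound to $1/\sqrt{\text{fan\_in}}$, the default weight range of `nn.Linear`. The plain-Python sketch below reproduces that arithmetic; the helper names are illustrative and not part of `torch.nn.init`.

```python
import math


def leaky_relu_gain(a: float) -> float:
    # same formula as nn.init.calculate_gain("leaky_relu", a)
    return math.sqrt(2.0 / (1.0 + a ** 2))


def kaiming_uniform_bound(fan_in: int, a: float) -> float:
    # kaiming_uniform_ with mode="fan_in": std = gain / sqrt(fan_in), bound = sqrt(3) * std
    std = leaky_relu_gain(a) / math.sqrt(fan_in)
    return math.sqrt(3.0) * std


# fan_in is each layer's number of input features: 3072 for w1, 512 for w2 and w3
for fan_in in (3072, 512):
    print(fan_in, kaiming_uniform_bound(fan_in, math.sqrt(5)), 1.0 / math.sqrt(fan_in))
```

With `a=0` (plain ReLU gain) the same formula gives the larger bound $\sqrt{6/\text{fan\_in}}$.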


#### PyTorch Lightning Module for MLP

This is the PyTorch Lightning (PL) module. It lets us easily swap the MLP implementation and compare the results. More detailed results can be found on the `wandb.ai` page.

Using the `model` parameter, we can switch between the two MLP implementations shown above. We also benchmark the results against a ResNet18 model. The rest of the code is similar to our PL module example for MNIST.

In [7]:
class LitCIFAR10Model(LightningModule):
    def __init__(self, num_classes=10, lr=0.001, batch_size=64,
                 num_workers=4, max_epochs=30,
                 model=SimpleMLP):
        super().__init__()
        self.save_hyperparameters()
        self.model = model(num_classes=num_classes)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    # this is called during fit()
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)
        return {"loss": loss}

    # calls to self.log() are recorded in wandb
    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("train_loss", avg_loss, on_epoch=True)

    # this is called for each batch of the test set
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)
        acc = accuracy(y_hat, y) * 100.
        # we use y_hat to display predictions during callback
        return {"y_hat": y_hat, "test_loss": loss, "test_acc": acc}

    # this is called at the end of the test epoch with the outputs of all test steps
    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["test_acc"] for x in outputs]).mean()
        self.log("test_loss", avg_loss, on_epoch=True, prog_bar=True)
        self.log("test_acc", avg_acc, on_epoch=True, prog_bar=True)

    # validation is the same as test
    def validation_step(self, batch, batch_idx):
        return self.test_step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        return self.test_epoch_end(outputs)

    # we use Adam optimizer
    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.hparams.lr)
        # this decays the learning rate to 0 after max_epochs using cosine annealing
        scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams.max_epochs)
        return [optimizer], [scheduler]
    
    # this is called to initialize the datasets and dataloaders (and download CIFAR10 if needed)
    def setup(self, stage=None):
        self.train_dataloader()
        self.test_dataloader()

    # build train and test dataloaders using the CIFAR10 dataset
    # we use a simple ToTensor transform
    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            torchvision.datasets.CIFAR10(
                "./data", train=True, download=True, 
                transform=torchvision.transforms.ToTensor()
            ),
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            torchvision.datasets.CIFAR10(
                "./data", train=False, download=True, 
                transform=torchvision.transforms.ToTensor()
            ),
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        return self.test_dataloader()
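
`CosineAnnealingLR` follows the documented closed form $\eta_t = \eta_{\min} + \tfrac{1}{2}(\eta_0 - \eta_{\min})\left(1 + \cos(\pi t / T_{\max})\right)$, with $\eta_{\min} = 0$ by default. Lightning steps a scheduler returned this way once per epoch by default, so with `T_max=max_epochs` the learning rate reaches 0 at the end of training. The sketch below evaluates the formula in plain Python for the notebook's defaults (`lr=0.001`, 30 epochs); `cosine_lr` is a hypothetical helper, not PyTorch's internal (recursive) implementation.

```python
import math


def cosine_lr(epoch: int, base_lr: float = 1e-3, t_max: int = 30, eta_min: float = 0.0) -> float:
    # closed-form cosine annealing, valid for 0 <= epoch <= t_max
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * epoch / t_max))


for epoch in (0, 10, 15, 20, 30):
    print(epoch, f"{cosine_lr(epoch):.6f}")
```

The schedule decays slowly at first, fastest around the midpoint, and slowly again near the end.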

#### Arguments

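Change the default of the `--model` argument (or uncomment one of the alternative lines) to switch between the models used as the CIFAR10 classifier. Since `get_args()` calls `parse_args("")`, command-line flags are ignored and the defaults are used.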

In [8]:
def get_args():
    parser = ArgumentParser(description="PyTorch Lightning CIFAR10 Example")
    parser.add_argument("--max-epochs", type=int, default=30, help="num epochs")
    parser.add_argument("--batch-size", type=int, default=64, help="batch size")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")

    parser.add_argument("--num-classes", type=int, default=10, help="num classes")

    parser.add_argument("--devices", default=1)
    parser.add_argument("--accelerator", default='gpu')
    parser.add_argument("--num-workers", type=int, default=4, help="num workers")
    
    #parser.add_argument("--model", default=torchvision.models.resnet18)
    #parser.add_argument("--model", default=TensorMLP)
    parser.add_argument("--model", default=SimpleMLP)
    args = parser.parse_args("")
    return args
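
If you would rather pick the model from the command line, one option (a sketch, not the notebook's approach) is to register the classes under string names and restrict `--model` with `choices`, since a class object cannot be typed on the command line. The empty placeholder classes below stand in for `SimpleMLP` and `TensorMLP`; `MODELS` and `parse_model` are hypothetical names. In the notebook you would map the names to the real classes (and, for example, `"resnet18"` to `torchvision.models.resnet18`).

```python
from argparse import ArgumentParser


# Placeholder stand-ins for the notebook's SimpleMLP and TensorMLP classes
class SimpleMLP:
    pass


class TensorMLP:
    pass


# Hypothetical name -> class registry
MODELS = {"SimpleMLP": SimpleMLP, "TensorMLP": TensorMLP}


def parse_model(argv):
    parser = ArgumentParser(description="PyTorch Lightning CIFAR10 Example")
    # restrict --model to registered names; argparse rejects anything else
    parser.add_argument("--model", default="SimpleMLP", choices=sorted(MODELS))
    args = parser.parse_args(argv)
    # swap the string for the class so the rest of the code can call args.model(...)
    args.model = MODELS[args.model]
    return args


print(parse_model([]).model.__name__)                        # SimpleMLP
print(parse_model(["--model", "TensorMLP"]).model.__name__)  # TensorMLP
```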

#### Weights and Biases Callback

The callback logs train and validation metrics to `wandb`. It also logs sample predictions. This is similar to our `WandbCallback` example for MNIST.

In [9]:
class WandbCallback(Callback):

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # process first 10 images of the first batch
        if batch_idx == 0:
            label_human = ["airplane", "automobile", "bird", "cat",
                           "deer", "dog", "frog", "horse", "ship", "truck"]
            n = 10
            x, y = batch
            outputs = outputs["y_hat"]
            outputs = torch.argmax(outputs, dim=1)
            # log image, ground truth and prediction on wandb table
            columns = ['image', 'ground truth', 'prediction']
            data = [[wandb.Image(x_i), label_human[y_i], label_human[y_pred]] for x_i, y_i, y_pred in list(
                zip(x[:n], y[:n], outputs[:n]))]
            # use the trainer's WandbLogger instead of relying on a global variable
            trainer.logger.log_table(
                key=pl_module.model.__class__.__name__,
                columns=columns,
                data=data)

#### Training and Validation of Different Models

The validation accuracies of the two MLP implementations are nearly the same, at `~53%`, which suggests that the two implementations are functionally equivalent.

Meanwhile, the ResNet18 model reaches an accuracy of `~78%`, so the MLP still has a long way to go.

In [10]:
if __name__ == "__main__":
    args = get_args()
    model = LitCIFAR10Model(num_classes=args.num_classes,
                            lr=args.lr, batch_size=args.batch_size,
                            num_workers=args.num_workers,
                            # keep the cosine schedule's T_max in sync with the Trainer
                            max_epochs=args.max_epochs,
                            model=args.model)
    model.setup()

    # printing the model is useful for debugging
    print(model)
    print(model.model.__class__.__name__)

    

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz



Extracting ./data/cifar-10-python.tar.gz to ./data




Files already downloaded and verified
LitCIFAR10Model(
  (model): SimpleMLP(
    (fc1): Linear(in_features=3072, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=10, bias=True)
  )
  (loss): CrossEntropyLoss()
)
SimpleMLP


In [None]:
if __name__ == "__main__":
    # wandb is a great way to debug and visualize this model
    wandb_logger = WandbLogger(project="mlp-cifar")
    
    trainer = Trainer(accelerator=args.accelerator,
                      devices=args.devices,
                      max_epochs=args.max_epochs,
                      logger=wandb_logger,
                      callbacks=[WandbCallback()])
    trainer.fit(model)
    trainer.test(model)

    wandb.finish()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.

