# Modified MNIST Demo

This notebook modifies the code from our CoE 197 class to train and test on the dataset provided in `/data/imagenet/`.
A model checkpoint callback was added to save the model with the best `test_acc`.

### Image Recognition on MNIST using PyTorch Lightning

Demonstrating the elements of machine learning:

1) **E**xperience (Datasets and Dataloaders)<br>
2) **T**ask (Classifier Model)<br>
3) **P**erformance (Accuracy)<br>

**Experience:** <br>
We use the MNIST dataset for this demo. MNIST consists of 28x28 grayscale images of handwritten digits, `0` to `9`. The train split has 60,000 images and the test split has 10,000 images.

**Task:**<br>
Our task is to classify the images into 10 classes. We use the ResNet18 model from `torchvision.models`. The first convolutional layer of ResNet18 (`conv1`) is modified to accept a single-channel input. The number of classes is set to 10.
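
To see what the modified `conv1` does to the input size, the standard convolution output-size formula can be worked out by hand. The sketch below is illustrative arithmetic only; `conv_out_size` and `conv_weight_count` are hypothetical helper names, and the layer settings (kernel 7, stride 2, padding 3, 64 output channels) are the ones used for `conv1` in the model code of this notebook.

```python
def conv_out_size(size, kernel=7, stride=2, padding=3):
    """Output height/width of a convolution with dilation 1 (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_weight_count(in_channels, out_channels=64, kernel=7):
    """Number of weights in a bias-free 2D convolution."""
    return in_channels * out_channels * kernel * kernel


print(conv_out_size(28))      # 28x28 input -> 14
print(conv_weight_count(1))   # single-channel conv1 -> 3136
print(conv_weight_count(3))   # three-channel (RGB) conv1 -> 9408
```

A single-channel `conv1` therefore holds one third of the weights of the RGB version, while the spatial output size is unchanged.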

**Performance:**<br>
We use the accuracy metric to evaluate the performance of our model on the test split. `torchmetrics.functional.accuracy` calculates the accuracy.
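
As a rough illustration of what the accuracy metric measures, here is a plain-Python sketch: the index of the largest score per sample is taken as the prediction, and matches with the labels are counted. `simple_accuracy` is a hypothetical helper and the scores are made up; this is not the `torchmetrics` implementation.

```python
def simple_accuracy(logits, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    if not logits or len(logits) != len(labels):
        raise ValueError("logits and labels must be non-empty and of equal length")
    correct = 0
    for scores, label in zip(logits, labels):
        predicted = max(range(len(scores)), key=lambda i: scores[i])  # argmax
        correct += int(predicted == label)
    return correct / len(labels)


# made-up scores for 4 samples and 3 classes
logits = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.1, 0.9], [0.4, 0.3, 0.2]]
labels = [1, 0, 2, 1]
print(simple_accuracy(logits, labels) * 100.0)  # 3 of 4 correct -> 75.0
```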

**[PyTorch Lightning](https://www.pytorchlightning.ai/):**<br>
Our demo uses PyTorch Lightning to simplify the process of training and testing. The PyTorch Lightning `Trainer` trains and evaluates our model. The default configuration assumes a GPU-enabled system with 48 CPU cores. Please change it if your system differs.
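
One setting that usually needs adjusting is the number of data-loading workers, which defaults to 48 in this notebook. A minimal sketch, assuming you want to cap the requested value at the number of cores your machine reports; `pick_num_workers` is a hypothetical helper and not part of PyTorch Lightning.

```python
import os


def pick_num_workers(requested=48, available=None):
    """Cap the requested worker count at the number of available CPU cores."""
    if available is None:
        # os.cpu_count() may return None when the count cannot be determined
        available = os.cpu_count() or 1
    return max(0, min(requested, available))


print(pick_num_workers())                  # depends on the machine
print(pick_num_workers(48, available=8))   # 8
```

The result could then be passed as `num_workers` when creating the data module.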

**[Weights and Biases](https://www.wandb.ai/):**<br>
`wandb` is used by the PyTorch Lightning module to log training and evaluation results. Use `--no-wandb` to disable `wandb`.


Let us install `pytorch-lightning` and `torchmetrics`.

In [None]:
%pip install pytorch-lightning --upgrade
%pip install torchmetrics --upgrade

In [None]:
import torch
import torchvision
import os
import wandb
from argparse import ArgumentParser
from pytorch_lightning import LightningModule, Trainer, LightningDataModule, Callback
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torchmetrics.functional import accuracy

### Custom Data Loader

In [None]:
class ImageNetDataModule(LightningDataModule):
    def __init__(self, path, batch_size=128, num_workers=0, class_dict=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.path = path
        self.batch_size = batch_size
        self.num_workers = num_workers
        # avoid a shared mutable default argument
        self.class_dict = class_dict if class_dict is not None else {}

    def prepare_data(self):
        # convert to a single channel to match the modified conv1 layer
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Grayscale(),
            torchvision.transforms.CenterCrop(100),
            torchvision.transforms.ToTensor(),
            ])

        # NOTE: the dataset root is hard-coded here; self.path is not used for loading
        self.train_dataset = torchvision.datasets.ImageNet("/data/imagenet/", split='train', transform=self.transform)

        # validation reuses the train split
        self.val_dataset = self.train_dataset

        self.test_dataset = torchvision.datasets.ImageNet("/data/imagenet/", split='val', transform=self.transform)

    def setup(self, stage=None):
        self.prepare_data()

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
        )
    
    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
        )

### PyTorch Lightning Module

Our Lightning module, a subclass of `LightningModule`, wraps a PyTorch ResNet18 model. We replace the model's input convolutional layer to support single-channel inputs. The Lightning module is also a container for the model, the optimizer, the loss function, and the metrics; in this notebook the data loaders live in the separate `ImageNetDataModule`.

`ResNet` class can be found [here](https://pytorch.org/vision/0.8/_modules/torchvision/models/resnet.html).

By using PyTorch Lightning, we simplify the training and testing processes since we do not need to write boilerplate code. This includes the automatic transfer to the chosen device (i.e. `gpu` or `cpu`), switching the model between `eval` and `train` modes, and the backpropagation routines.
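
For comparison, below is a rough sketch of the kind of loop that Lightning takes care of. It assumes a toy linear classifier and random tensors in place of the ResNet18 model and the real dataset; all names and sizes here are illustrative only.

```python
import torch

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# toy stand-ins: 64 random 1x8x8 "images" with labels from 10 classes
x = torch.randn(64, 1, 8, 8)
y = torch.randint(0, 10, (64,))
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(64, 10)).to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model.train()                                # switch to train mode
for start in range(0, len(x), 32):           # mini-batches of 32
    xb = x[start:start + 32].to(device)      # manual device transfer
    yb = y[start:start + 32].to(device)
    optimizer.zero_grad()
    loss = loss_fn(model(xb), yb)
    loss.backward()                          # backpropagation
    optimizer.step()

model.eval()                                 # switch to eval mode
with torch.no_grad():
    preds = model(x.to(device)).argmax(dim=1)
    acc = (preds == y.to(device)).float().mean().item() * 100.0
print(f"accuracy on the toy data: {acc:.1f}%")
```

With Lightning, the device transfer, mode switching, and the `zero_grad`/`backward`/`step` sequence are handled by the `Trainer`, so the module only defines what happens in each step.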

In [None]:
class LitMNISTModel(LightningModule):
    def __init__(self, num_classes=10, lr=0.001, batch_size=32):
        super().__init__()
        self.save_hyperparameters()
        self.model = torchvision.models.resnet18(num_classes=num_classes)
        self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7,
                                           stride=2, padding=3, bias=False)
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    # this is called during fit()
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)
        return {"loss": loss}

    # calls to self.log() are recorded in wandb
    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("train_loss", avg_loss, on_epoch=True)

    # this is called for every batch of the test set
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)
        acc = accuracy(y_hat, y) * 100.
        # we use y_hat to display predictions during callback
        return {"y_hat": y_hat, "test_loss": loss, "test_acc": acc}

    # this is called once all test batches have been processed
    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["test_acc"] for x in outputs]).mean()
        self.log("test_loss", avg_loss, on_epoch=True, prog_bar=True)
        self.log("test_acc", avg_acc, on_epoch=True, prog_bar=True)

    # validation is the same as test
    def validation_step(self, batch, batch_idx):
        return self.test_step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        return self.test_epoch_end(outputs)

    # we use Adam optimizer
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

### PyTorch Lightning Callback

We can instantiate a callback object to perform certain tasks during training. In this case, we log sample images, ground truth labels, and predicted labels from the first batch of each validation run.

We can also use the `ModelCheckpoint` callback to save the model during training. Here it keeps only the checkpoint with the best `test_acc`.
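
The idea behind keeping only the best checkpoint can be illustrated without Lightning. The sketch below is a simplified stand-in, not how `ModelCheckpoint` is implemented: the hypothetical `BestTracker` class remembers the best value of a monitored metric and reports when a new save would be triggered. The accuracy values are made up.

```python
class BestTracker:
    """Track the best value of a monitored metric (mode 'max' or 'min')."""

    def __init__(self, mode="max"):
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.mode = mode
        self.best = None
        self.best_epoch = None

    def update(self, epoch, value):
        """Return True when `value` improves on the best seen so far."""
        improved = (
            self.best is None
            or (self.mode == "max" and value > self.best)
            or (self.mode == "min" and value < self.best)
        )
        if improved:
            self.best, self.best_epoch = value, epoch
        return improved


tracker = BestTracker(mode="max")
# made-up per-epoch accuracy values, for illustration only
saves = [tracker.update(epoch, acc) for epoch, acc in enumerate([10.0, 25.5, 24.0, 31.2])]
print(saves)                              # [True, True, False, True]
print(tracker.best_epoch, tracker.best)   # 3 31.2
```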

In [None]:
class WandbCallback(Callback):

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # process first 10 images of the first batch
        if batch_idx == 0:
            n = 10
            x, y = batch
            outputs = outputs["y_hat"]
            outputs = torch.argmax(outputs, dim=1)
            # log image, ground truth and prediction on wandb table
            columns = ['image', 'ground truth', 'prediction']
            data = [[wandb.Image(x_i), y_i, y_pred] for x_i, y_i, y_pred in list(
                zip(x[:n], y[:n], outputs[:n]))]
            wandb_logger.log_table(
                key='ResNet18 on ImageNet Predictions',
                columns=columns,
                data=data)


### Program Arguments

When running from the command line, we can pass arguments to the program. In the Jupyter notebook, we can pass arguments using the `%run` magic command.
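
Because a notebook cell has no real command line, `parse_args` can also be given an explicit list of strings. A small sketch using two of the flags defined in this notebook; `build_parser` is a hypothetical helper name.

```python
from argparse import ArgumentParser


def build_parser():
    parser = ArgumentParser(description="argument parsing sketch")
    parser.add_argument("--max-epochs", type=int, default=5, help="num epochs")
    parser.add_argument("--no-wandb", default=False, action="store_true")
    return parser


# an empty list keeps every default
defaults = build_parser().parse_args([])
print(defaults.max_epochs, defaults.no_wandb)   # 5 False

# an explicit list behaves like flags typed on the command line
custom = build_parser().parse_args(["--max-epochs", "10", "--no-wandb"])
print(custom.max_epochs, custom.no_wandb)       # 10 True
```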


In [None]:
def get_args():
    parser = ArgumentParser(description="PyTorch Lightning MNIST Example")
    parser.add_argument("--max-epochs", type=int, default=5, help="num epochs")
    parser.add_argument("--batch-size", type=int, default=32, help="batch size")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")

    parser.add_argument("--path", type=str, default="./")

    parser.add_argument("--num-classes", type=int, default=1000, help="num classes")

    parser.add_argument("--devices", default=1)
    parser.add_argument("--accelerator", default='gpu')
    parser.add_argument("--num-workers", type=int, default=48, help="num workers")
    
    parser.add_argument("--no-wandb", default=False, action='store_true')
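    # an empty argument list keeps the defaults inside a notebook;
    # use parser.parse_args() instead when running as a script
    args = parser.parse_args("")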
    return args

### Training and Evaluation using `Trainer`

Get the command line arguments. Instantiate a PyTorch Lightning model. Train the model. Evaluate the model.

In [None]:
if __name__ == "__main__":
    args = get_args()
    model = LitMNISTModel(num_classes=args.num_classes,
                          lr=args.lr, batch_size=args.batch_size)
    datamodule = ImageNetDataModule(path=args.path, batch_size=args.batch_size, num_workers=args.num_workers)
    datamodule.setup()

    # printing the model is useful for debugging
    print(model)

    # wandb is a great way to debug and visualize this model
    wandb_logger = WandbLogger(project="pl-mnist")

    model_checkpoint = ModelCheckpoint(
        dirpath=os.path.join(args.path, "checkpoints"),
        filename="resnet18-mnist-best-acc",
        save_top_k=1,
        verbose=True,
        monitor='test_acc',
        mode='max',
    )
    
    # add the wandb callback only when wandb is enabled,
    # so that the callbacks list never contains None
    callbacks = [model_checkpoint]
    if not args.no_wandb:
        callbacks.append(WandbCallback())

    trainer = Trainer(accelerator=args.accelerator,
                      devices=args.devices,
                      max_epochs=args.max_epochs,
                      logger=wandb_logger if not args.no_wandb else None,
                      callbacks=callbacks)
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)

    wandb.finish()
