# PyTorch example on MNIST

## References

* [examples/mnist at master · pytorch/examples](https://github.com/pytorch/examples/blob/master/mnist/main.py)

In [None]:
import os
import argparse

import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

# For type hinting
from typing import Union, Dict, List, Any
from torch import Tensor

## References for the model (`MNISTConvNet`)

* [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)

### `__init__`

* [torch.nn.Conv2d](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d)
* [torch.nn.MaxPool2d](https://pytorch.org/docs/1.9.0/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d)
* [torch.nn.Dropout2d](https://pytorch.org/docs/stable/generated/torch.nn.Dropout2d.html#torch.nn.Dropout2d)
* [torch.nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear)

### `forward`

* [LightningModule.forward (PyTorch Lightning)](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html?highlight=forward#forward)
    - [torch.nn.Module.forward](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.forward)
* [torch.nn.functional.relu](https://pytorch.org/docs/1.9.0/generated/torch.nn.functional.relu.html)
    - [torch.nn.ReLU](https://pytorch.org/docs/1.9.0/generated/torch.nn.ReLU.html#torch.nn.ReLU)
* [torch.flatten](https://pytorch.org/docs/stable/generated/torch.flatten.html)
* [torch.nn.functional.log_softmax](https://pytorch.org/docs/1.9.0/generated/torch.nn.functional.log_softmax.html#torch.nn.functional.log_softmax)
    - [torch.nn.LogSoftmax](https://pytorch.org/docs/1.9.0/generated/torch.nn.LogSoftmax.html#torch.nn.LogSoftmax)


In [None]:
class MNISTConvNet(torch.nn.Module):
    """Convolutional classifier used for MNIST in this notebook.

    ``forward`` computes::

        conv1 -> ReLU -> conv2 -> ReLU -> max-pool -> channel dropout
              -> flatten -> fullconn1 -> ReLU -> dropout -> fullconn2
              -> log_softmax(dim=1)

    Every layer size is an explicit constructor argument so the network can be
    configured from the command line. The sizes are NOT cross-checked here; the
    caller must keep ``conv2_in_channels == conv1_out_channels``,
    ``fullconn2_in_features == fullconn1_out_features`` and
    ``fullconn1_in_features`` equal to the flattened feature count after
    pooling (``12 * 12 * 64 = 9216`` for 28x28 inputs with 3x3 stride-1
    convolutions, 64 conv2 output channels and 2x2 pooling).

    The ``adadelta_*``, ``dataset_*`` and ``dataloader_*`` arguments do not
    affect the network itself; they are only stored in ``self.adadelta_params``,
    ``self.dataset_params`` and ``self.dataloader_params`` for callers to read.
    """

    def __init__(self,
            conv1_in_channels: int, conv1_out_channels: int, conv1_kernel_size: int, conv1_stride: int,
            conv2_in_channels: int, conv2_out_channels: int, conv2_kernel_size: int, conv2_stride: int,
            pool1_kernel_size: int, dropout1_p: float, dropout2_p: float,
            fullconn1_in_features: int, fullconn1_out_features: int, fullconn2_in_features: int, fullconn2_out_features: int,
            adadelta_lr: float, adadelta_rho: float, adadelta_eps: float, adadelta_weight_decay: float,
            dataset_root: str, dataset_download: bool,
            dataloader_mean: tuple, dataloader_std: tuple, dataloader_batch_size: int, dataloader_num_workers: int
            ) -> None:
        super().__init__()

        self.conv1 = torch.nn.Conv2d(in_channels=conv1_in_channels, out_channels=conv1_out_channels, kernel_size=conv1_kernel_size, stride=conv1_stride)
        self.conv2 = torch.nn.Conv2d(in_channels=conv2_in_channels, out_channels=conv2_out_channels, kernel_size=conv2_kernel_size, stride=conv2_stride)
        self.pool1 = torch.nn.MaxPool2d(kernel_size=pool1_kernel_size)
        # dropout1 is applied to the pooled conv feature map (batched input is
        # (N, C, H, W)), so Dropout2d zeroes whole channels there.
        self.dropout1 = torch.nn.Dropout2d(p=dropout1_p, inplace=False)
        # dropout2 is applied to the 2-D (N, features) output of fullconn1.
        # Dropout2d on 2-D input is deprecated in PyTorch (it warns and
        # recommends plain dropout, which gives the same element-wise result),
        # so use Dropout. Dropout has no parameters, so saved state_dicts
        # remain loadable.
        self.dropout2 = torch.nn.Dropout(p=dropout2_p, inplace=False)
        self.fullconn1 = torch.nn.Linear(in_features=fullconn1_in_features, out_features=fullconn1_out_features)
        self.fullconn2 = torch.nn.Linear(in_features=fullconn2_in_features, out_features=fullconn2_out_features)

        # Optimizer settings kept for callers; not used inside the module.
        self.adadelta_params = {
            'lr': adadelta_lr,
            'rho': adadelta_rho,
            'eps': adadelta_eps,
            'weight_decay': adadelta_weight_decay,
        }

        # Dataset settings kept for callers; not used inside the module.
        self.dataset_params = {
            'root': dataset_root,
            'download': dataset_download,
        }

        # Data-loading settings kept for callers; not used inside the module.
        self.dataloader_params = {
            'mean': dataloader_mean,
            'std': dataloader_std,
            'batch_size': dataloader_batch_size,
            'num_workers': dataloader_num_workers,
        }

    def forward(self, x: Tensor) -> Tensor:
        """Return log-probabilities of shape ``(N, fullconn2_out_features)``.

        Args:
            x: batch of images; presumably ``(N, conv1_in_channels, 28, 28)``
                for MNIST with the default sizes — TODO confirm for other
                configurations (``fullconn1_in_features`` must match).
        """
        x = self.conv1(x)
        x = F.relu(input=x)
        x = self.conv2(x)
        x = F.relu(input=x)
        x = self.pool1(x)
        x = self.dropout1(x)
        # Keep the batch dimension, flatten channels and spatial dims together.
        x = torch.flatten(input=x, start_dim=1)
        x = self.fullconn1(x)
        x = F.relu(input=x)
        x = self.dropout2(x)
        x = self.fullconn2(x)
        return F.log_softmax(input=x, dim=1)

In [None]:
def train(args, model, device, train_loader, optimizer, epoch) -> None:
    """Run one training epoch of ``model`` over ``train_loader``.

    For each batch: move inputs/targets to ``device``, compute the NLL loss of
    the model output, back-propagate and take one optimizer step. A progress
    line is printed every ``args.log_interval`` batches; if ``args.dry_run`` is
    set, the epoch stops right after the first logged batch.

    Args:
        args: namespace providing ``log_interval`` and ``dry_run``.
        model: module whose output is fed to ``F.nll_loss`` (log-probabilities).
        device: device each batch is moved to.
        train_loader: DataLoader yielding ``(inputs, targets)`` pairs.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, only used in the progress line.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        # Only report (and possibly stop) on every log_interval-th batch.
        if step % args.log_interval != 0:
            continue
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * len(inputs), len(train_loader.dataset),
            100. * step / len(train_loader), batch_loss.item()))
        if args.dry_run:
            break

In [None]:
def test(model, device, test_loader) -> None:
    """Evaluate ``model`` on ``test_loader`` and print average loss and accuracy.

    The reported loss is the summed ``F.nll_loss`` over all samples divided by
    ``len(test_loader.dataset)``; a prediction is the argmax over dim 1 of the
    model output. Runs in eval mode under ``torch.no_grad()``.

    Args:
        model: module returning log-probabilities.
        device: device each batch is moved to.
        test_loader: DataLoader yielding ``(inputs, targets)`` pairs.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            log_probs = model(inputs)
            # Sum (not mean) per batch so the final division gives a per-sample average.
            loss_sum += F.nll_loss(log_probs, labels, reduction='sum').item()
            predicted = log_probs.argmax(dim=1)
            n_correct += predicted.eq(labels.view_as(predicted)).sum().item()

    n_total = len(test_loader.dataset)
    avg_loss = loss_sum / n_total

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avg_loss, n_correct, n_total,
        100. * n_correct / n_total))

In [None]:
def get_argparser():
    """Build the command-line parser for this MNIST example.

    The first group of options configures ``MNISTConvNet`` (layer sizes,
    dropout rates, plus the Adadelta / dataset / dataloader settings the model
    stores). The second group is the training-loop interface of the upstream
    ``pytorch/examples`` MNIST script.

    Returns:
        argparse.ArgumentParser: the configured, not yet invoked, parser.
    """
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')

    # --- MNISTConvNet layer hyper-parameters ---------------------------------
    parser.add_argument('--conv1-in-channels', type=int, default=1,
                        help='input channels of conv1 (default: %(default)s)')
    parser.add_argument('--conv1-out-channels', type=int, default=32,
                        help='output channels of conv1 (default: %(default)s)')
    parser.add_argument('--conv1-kernel-size', type=int, default=3,
                        help='kernel size of conv1 (default: %(default)s)')
    parser.add_argument('--conv1-stride', type=int, default=1,
                        help='stride of conv1 (default: %(default)s)')
    parser.add_argument('--conv2-in-channels', type=int, default=32,
                        help='input channels of conv2; must equal --conv1-out-channels (default: %(default)s)')
    parser.add_argument('--conv2-out-channels', type=int, default=64,
                        help='output channels of conv2 (default: %(default)s)')
    parser.add_argument('--conv2-kernel-size', type=int, default=3,
                        help='kernel size of conv2 (default: %(default)s)')
    parser.add_argument('--conv2-stride', type=int, default=1,
                        help='stride of conv2 (default: %(default)s)')
    parser.add_argument('--pool1-kernel-size', type=int, default=2,
                        help='max-pool kernel size (default: %(default)s)')
    parser.add_argument('--dropout1-p', type=float, default=0.25,
                        help='channel dropout probability after pooling (default: %(default)s)')
    parser.add_argument('--dropout2-p', type=float, default=0.5,
                        help='dropout probability after fullconn1 (default: %(default)s)')
    # 12*12*64: a 28x28 input shrinks to 26 and then 24 through the two 3x3
    # convolutions, then to 12 after 2x2 pooling, with 64 channels.
    parser.add_argument('--fullconn1-in-features', type=int, default=12*12*64,
                        help='flattened feature count after pooling (default: %(default)s)')
    parser.add_argument('--fullconn1-out-features', type=int, default=128,
                        help='hidden units of fullconn1 (default: %(default)s)')
    parser.add_argument('--fullconn2-in-features', type=int, default=128,
                        help='must equal --fullconn1-out-features (default: %(default)s)')
    parser.add_argument('--fullconn2-out-features', type=int, default=10,
                        help='number of classes (default: %(default)s)')

    # --- Settings stored on the model ---------------------------------------
    # NOTE(review): main() passes --lr to the optimizer; --adadelta-lr is only
    # stored on the model. Confirm which one should win.
    parser.add_argument('--adadelta-lr', type=float, default=1.0,
                        help='Adadelta learning rate stored on the model (default: %(default)s)')
    parser.add_argument('--adadelta-rho', type=float, default=0.9,
                        help='Adadelta rho (default: %(default)s)')
    parser.add_argument('--adadelta-eps', type=float, default=1e-06,
                        help='Adadelta eps (default: %(default)s)')
    parser.add_argument('--adadelta-weight-decay', type=float, default=0,
                        help='Adadelta weight decay (default: %(default)s)')
    parser.add_argument('--dataset-root', type=str, default=os.getcwd(),
                        help='dataset root directory (default: current working directory)')
    # A lone store_true flag with default=True could never be switched off, so
    # pair it with an explicit negative flag writing to the same destination.
    parser.add_argument('--dataset-download', dest='dataset_download',
                        action='store_true', default=True,
                        help='download the dataset if missing (default: on)')
    parser.add_argument('--no-dataset-download', dest='dataset_download',
                        action='store_false',
                        help='do not download the dataset')
    # type=tuple would split the argument string into single characters
    # ("0.1" -> ('0', '.', '1')); parse one or more floats instead. A value
    # given on the command line arrives as a list.
    # NOTE(review): main() normalizes with (0.1307,)/(0.3081,), not these
    # defaults; confirm which values are intended.
    parser.add_argument('--dataloader-mean', type=float, nargs='+', default=(0.1302,),
                        metavar='MEAN',
                        help='per-channel normalization mean(s) (default: %(default)s)')
    parser.add_argument('--dataloader-std', type=float, nargs='+', default=(0.3069,),
                        metavar='STD',
                        help='per-channel normalization std(s) (default: %(default)s)')
    parser.add_argument('--dataloader-batch-size', type=int, default=32,
                        help='batch size stored on the model (default: %(default)s)')
    parser.add_argument('--dataloader-num-workers', type=int, default=4,
                        help='DataLoader workers stored on the model (default: %(default)s)')

    # --- Training loop (upstream pytorch/examples interface) ----------------
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    return parser

In [None]:

def _make_mnist_loaders(args, use_cuda):
    """Return ``(train_loader, test_loader)`` over the MNIST train/test splits."""
    # Shuffling is a property of the training data, not of the device: always
    # shuffle the train split and never the test split (evaluation sums over
    # the whole set, so its order does not matter).
    train_kwargs = {'batch_size': args.batch_size, 'shuffle': True}
    test_kwargs = {'batch_size': args.test_batch_size}

    if use_cuda:
        cuda_kwargs = {
            'num_workers': 1,
            'pin_memory': True,
        }
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    # NOTE(review): normalization constants and "../data" are hard-coded rather
    # than taken from --dataloader-mean/--dataloader-std/--dataset-root, whose
    # defaults differ; confirm which should be used.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,), inplace=False),
    ])
    train_set = datasets.MNIST(root="../data", train=True, transform=transform,
                               download=args.dataset_download)
    test_set = datasets.MNIST(root="../data", train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_set, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(test_set, **test_kwargs)
    return train_loader, test_loader


def main(args=None) -> None:
    """Train and evaluate ``MNISTConvNet`` on MNIST.

    Trains for ``args.epochs`` epochs with Adadelta and a per-epoch ``StepLR``
    decay (factor ``args.gamma``), evaluating on the test split after every
    epoch. If ``args.save_model`` is set, the final ``state_dict`` is written to
    ``mnist_cnn.pt``.

    Args:
        args: namespace produced by ``get_argparser()``; when ``None`` the
            options are parsed from the process command line.
    """
    if args is None:
        args = get_argparser().parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_loader, test_loader = _make_mnist_loaders(args, use_cuda)

    model = MNISTConvNet(conv1_in_channels=args.conv1_in_channels, conv1_out_channels=args.conv1_out_channels, conv1_kernel_size=args.conv1_kernel_size, conv1_stride=args.conv1_stride,
        conv2_in_channels=args.conv2_in_channels, conv2_out_channels=args.conv2_out_channels, conv2_kernel_size=args.conv2_kernel_size, conv2_stride=args.conv2_stride,
        pool1_kernel_size=args.pool1_kernel_size, dropout1_p=args.dropout1_p, dropout2_p=args.dropout2_p,
        fullconn1_in_features=args.fullconn1_in_features, fullconn1_out_features=args.fullconn1_out_features, fullconn2_in_features=args.fullconn2_in_features, fullconn2_out_features=args.fullconn2_out_features,
        adadelta_lr=args.adadelta_lr, adadelta_rho=args.adadelta_rho, adadelta_eps=args.adadelta_eps, adadelta_weight_decay=args.adadelta_weight_decay,
        dataset_root=args.dataset_root, dataset_download=args.dataset_download,
        dataloader_mean=args.dataloader_mean, dataloader_std=args.dataloader_std, dataloader_batch_size=args.dataloader_batch_size, dataloader_num_workers=args.dataloader_num_workers
        ).to(device)

    # The learning rate comes from --lr; the remaining Adadelta settings come
    # from the --adadelta-* options, whose defaults equal torch's defaults.
    optimizer = torch.optim.Adadelta(
        params=model.parameters(),
        lr=args.lr,
        rho=args.adadelta_rho,
        eps=args.adadelta_eps,
        weight_decay=args.adadelta_weight_decay,
    )

    # Multiply the learning rate by gamma after every epoch.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=1, gamma=args.gamma)

    for epoch in range(1, args.epochs + 1):
        train(args=args, model=model, device=device, train_loader=train_loader, optimizer=optimizer, epoch=epoch)
        test(model=model, device=device, test_loader=test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(obj=model.state_dict(), f="mnist_cnn.pt")


In [None]:
# Guarded so that importing this code as a module (e.g. after exporting the
# notebook to a .py file) does not start a dataset download and a training
# run. Inside a Jupyter/IPython kernel __name__ is "__main__", so the cell
# still runs when executed in the notebook.
if __name__ == "__main__":
    argparser = get_argparser()
    # Short CPU demo run: one epoch, small batches, model saved at the end.
    args = argparser.parse_args(
        [
            "--batch-size", "32",
            "--test-batch-size", "32",
            "--epochs", "1",
            "--lr", "0.5",
            "--gamma", "0.7",
            "--no-cuda",
            "--seed", "1",
            "--log-interval", "10",
            "--save-model",
        ]
    )
    main(args)