<a href="https://colab.research.google.com/github/mln00b/end2.0/blob/main/MNIST_Sum.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [1]:
import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

from typing import Dict, Tuple
from argparse import Namespace

## Dataset and DataLoaders

In [2]:
class MNISTSumDataset(Dataset):
    def __init__(self, train: bool) -> None:
        mnist_transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        self.mnist_ds = datasets.MNIST(
            './dataset', train=train, download=True,
            transform=mnist_transform
        )
    
    def __len__(self):
        return len(self.mnist_ds)
    
    def __getitem__(self, idx) -> Dict:
        random_num = ((torch.rand(1)[0]*9).int()).float()  # random no. b/w 0-9
        img, lbl = self.mnist_ds[idx]
        sum_lbl = random_num + lbl
        return {"img": img, "rand_num": random_num, "lbl": lbl, "sum_lbl": sum_lbl}


def get_data() -> Tuple[DataLoader, DataLoader]:
    
    train_ds = MNISTSumDataset(train=True)
    val_ds = MNISTSumDataset(train=False)

    train_kwargs = {'batch_size': 64}
    test_kwargs = {'batch_size': 1000}

    if torch.cuda.is_available():
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)


    train_loader = DataLoader(train_ds,**train_kwargs)
    val_loader = DataLoader(val_ds,**test_kwargs)

    return train_loader, val_loader

## Network

In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # For MNIST classification
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

        # For sum calculation
        self.fc3 = nn.Linear(1, 16)
        self.fc4 = nn.Linear(144, 16)
        self.fc5 = nn.Linear(16, 1)

    def forward(self, x, num):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)

        y = self.fc3(num.unsqueeze(dim=1))
        y = F.relu(y)
        y = torch.cat((x, y), dim=1)
        y = self.fc4(y)
        y = F.relu(y)
        y = self.fc5(y)

        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)

        return x, y

## Trainer

In [4]:
def train_mnist(args):
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    train_loader, val_loader = get_data()

    model = Net().to(device)
    print(model)
    opt = optim.Adam(model.parameters(), lr=3e-4)

    cross_el_loss = nn.CrossEntropyLoss()
    l1_loss = nn.L1Loss()

    for epoch in range(args.epochs):
        # Train
        model.train()
        for batch_idx, data in enumerate(train_loader):
            img = data["img"].to(device)
            rand_num = data["rand_num"].to(device)
            cls_lbl = data["lbl"].to(device)
            sum_lbl = data["sum_lbl"].unsqueeze(dim=1).to(device)

            opt.zero_grad()
            cls_pred, sum_pred = model(img, rand_num)

            loss_cls = cross_el_loss(cls_pred, cls_lbl)
            loss_sum = l1_loss(sum_pred, sum_lbl)

            loss = loss_cls + loss_sum
            
            loss.backward()
            opt.step()

        print(f"Epoch: {epoch}, Train Cls loss: {loss_cls.item()}, Train Sum loss: {loss_sum.item()}, Train Total loss: {loss.item()}")

        # Eval
        model.eval()
        total_cls = 0
        correct_cls = 0
        total_sum = 0
        correct_sum = 0
        with torch.no_grad():
            for batch_idx, data in enumerate(val_loader):
                img = data["img"].to(device)
                rand_num = data["rand_num"].to(device)
                cls_lbl = data["lbl"].to(device)
                sum_lbl = data["sum_lbl"].unsqueeze(dim=1).to(device)

                cls_pred, sum_pred = model(img, rand_num)

                # Can be made better
                for idx, i in enumerate(cls_pred):
                    if torch.argmax(i) == cls_lbl[idx]:
                        correct_cls += 1
                    total_cls += 1
                
                # Can be made better
                for idx, i in enumerate(sum_pred):
                    p = sum_pred[idx][0].int()
                    l = sum_lbl[idx][0].int()
                    if (p == l):
                        correct_sum += 1
                    total_sum += 1
        
        print(f"Epoch: {epoch}, Test Cls accuracy: {correct_cls / total_cls}, Test Sum accuracy: {correct_sum / total_sum}")

## Run

In [6]:
args = Namespace(
    epochs=20
)

train_mnist(args)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./dataset/MNIST/raw/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./dataset/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting ./dataset/MNIST/raw/train-images-idx3-ubyte.gz to ./dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./dataset/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting ./dataset/MNIST/raw/train-labels-idx1-ubyte.gz to ./dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting ./dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to ./dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting ./dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./dataset/MNIST/raw

Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (fc3): Linear(in_features=1, out_features=16, bias=True)
  (fc4): Linear(in_features=144, out_features=16, bias=True)
  (fc5): Linear(in_features=16, out_features=1, bias=True)
)
Epoch: 0, Train Cls loss: 0.05834564194083214, Train Sum loss: 0.7465299963951111, Train Total loss: 0.8048756122589111
Epoch: 0, Test Cls accuracy: 0.9789, Test Sum accuracy: 0.4539
Epoch: 1, Train Cls loss: 0.08374294638633728, Train Sum loss: 0.45235127210617065, Train Total loss: 0.5360941886901855
Epoch: 1, Test Cls accuracy: 0.9848, Test Sum accuracy: 0.4868
Epoch: 2, Train Cls loss: 0.03668471425771713, Train Sum loss: 0.2454855740070343, Train Total loss: 0.2821702957