<a href="https://colab.research.google.com/github/darshats/TSAI/blob/main/tsai_assignment4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import torch.optim as optim
import random
import numpy as np

In [2]:
## FashionMNIST training split, stored under ./data (downloaded when missing).
## Every image is passed through ToTensor on access.
mnist_train_set = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [3]:
## Wrapper class on MNIST plus a randomly generated scalar to add.
## Each item is a tuple (img, b), each label is a tuple (i, i+b)
class compound_dataset(Dataset):
  """Pairs every sample of an image dataset with a random digit to add.

  Item ``idx`` is ``((img, b), (label_onehot, sum_onehot))`` where ``img`` and
  its class index come from the wrapped dataset, ``b`` is an int drawn once in
  ``__init__`` from ``random.randrange(0, 10)``, ``label_onehot`` is a float
  vector of length 10 and ``sum_onehot`` is a float vector of length 19
  (the sum ``label + b`` ranges over 0..18).

  Args:
    mnist_dataset: sized iterable of ``(image, class_index)`` pairs with
      class indices in 0..9.
  """

  def __init__(self, mnist_dataset):
    ## stash away mnist just in case
    self.mnist = mnist_dataset
    ## generate random numbers for the second input
    add_list = [random.randrange(0, 10) for _ in range(len(mnist_dataset))]

    self.compound_input = []
    self.compound_label = []
    ## iterate over the dataset that was passed in, not over a notebook global
    for (img, digit), b in zip(mnist_dataset, add_list):
      digit = int(digit)
      ## input is image, and random number
      self.compound_input.append((img, b))
      ## label is the image class and its sum with the random number
      self.compound_label.append((digit, digit + b))

  def __len__(self):
    """Number of (image, digit) pairs."""
    return len(self.compound_input)

  @staticmethod
  def _one_hot(index, num_classes):
    """Return a float vector of length ``num_classes`` with a 1 at ``index``."""
    encoded = torch.zeros(num_classes, dtype=torch.float)
    encoded[index] = 1.0
    return encoded

  def __getitem__(self, idx):
    """Return ``((img, b), (label_onehot, sum_onehot))`` for sample ``idx``."""
    mnist_label, add_label = self.compound_label[idx]
    ## labels need to be one hot encoded
    labels = (self._one_hot(mnist_label, 10), self._one_hot(add_label, 19))
    return (self.compound_input[idx], labels)

In [4]:
## Wrap the training set in the compound dataset and build a shuffling loader over it.
## batch_size is 1: larger values were reported to cause confusing errors related to the
## one-hot encoding of the second digit. Whether to batch is left as an open question.
compound_ds = compound_dataset(mnist_train_set)
train_loader = torch.utils.data.DataLoader(
    dataset=compound_ds,
    batch_size=1,
    shuffle=True,
)

**Network**

In [5]:
import torch.nn as nn
import torch.nn.functional as F

class Network(nn.Module):
  """Two-input, two-output network: an image plus a digit to add.

  The image goes through two conv/pool stages, the digit is one-hot encoded
  and passed through a linear layer, and the two feature vectors are
  concatenated and fed to a shared hidden layer with two linear heads.
  """

  def __init__(self):
    super().__init__()

    ## This part is same as MNIST network taught in class
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
    self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

    ## this maps one hot encoded second input to 20 features
    self.fc_u = nn.Linear(in_features=10, out_features=20)

    ## this maps concatenated (activation of image + second digit) to 100 features
    self.fc_combo = nn.Linear(in_features=212, out_features=100)

    ## define two outputs
    ## mnist
    self.out_mnist = nn.Linear(in_features=100, out_features=10)
    ## sum
    self.out_sum = nn.Linear(in_features=100, out_features=19)

  ## t is the batch of images, u is batch of numbers to add
  def forward(self, t, u):
    """Run a batch through the network.

    Args:
      t: image batch; the reshape below expects inputs that reduce to
        12x4x4 features per sample (e.g. 1x28x28 images).
      u: integer tensor holding one digit in 0..9 per sample.

    Returns:
      Tuple ``(out_mnist, out_sum)`` of logits with 10 and 19 columns.
    """
    ## regular mnist network for images
    x = self.conv1(t)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2) #28 | 24 | 12

    x = self.conv2(x)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2) #12 | 8 | 4 >> 12x4x4

    # reshape
    x = x.reshape(-1, 12*4*4)

    ## now pass the second input through fcn
    ## one hot encode each sample separately so that any batch size works;
    ## a single shared length-10 vector is only correct for a batch of one
    digits = u.reshape(-1).long()
    u = F.one_hot(digits, num_classes=10).to(device=x.device, dtype=x.dtype)
    u = self.fc_u(u)
    u = F.relu(u)
    u = u.reshape(-1, 20)

    ## combine the vectors for image and digit encoding
    combo = torch.cat([x, u], dim=1) #192,20=212
    combo = self.fc_combo(combo) #100
    combo = F.relu(combo)

    out_mnist = self.out_mnist(combo) #input 100, output 10
    out_sum = self.out_sum(combo) #input 100, output 19
    return (out_mnist, out_sum)

In [6]:
## Instantiate the model and let Adam optimise all of its parameters.
network = Network()
optimizer = optim.Adam(
    params=network.parameters(),
    lr=0.01,
)

In [None]:
## Train both heads jointly; per-epoch totals are accumulated and printed.
for epoch in range(10):
  total_loss = 0
  total_mnist_loss = 0
  total_add_loss = 0
  total_mnist_correct = 0
  total_add_correct = 0
  for batch in train_loader:
    ## each batch is ((images, digits), (one-hot image label, one-hot sum label))
    (images, digits), (mnist_onehot, add_onehot) = batch
    ## cross_entropy wants class indices, so undo the one-hot encoding
    mnist_target = mnist_onehot.argmax(dim=1)
    add_target = add_onehot.argmax(dim=1)

    pred_mnist, pred_add = network(images, digits)
    loss_mnist = F.cross_entropy(pred_mnist, mnist_target)
    loss_add = F.cross_entropy(pred_add, add_target)

    ## Both heads share one graph. Summing the losses and calling backward() once
    ## trains both heads and avoids
    ## "RuntimeError: Trying to backward through the graph a second time".
    loss = loss_mnist + loss_add

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    total_loss += loss.item()
    total_mnist_loss += loss_mnist.item()
    total_add_loss += loss_add.item()
    total_mnist_correct += pred_mnist.argmax(dim=1).eq(mnist_target).sum().item()
    total_add_correct += pred_add.argmax(dim=1).eq(add_target).sum().item()
  print(f'epoch {epoch} total_loss {total_loss}, add loss {total_add_loss}, mnist loss {total_mnist_loss}, mnist correct {total_mnist_correct}, add correct {total_add_correct}')

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
