**Imports**

In [1]:
from __future__ import print_function
import random
from random import randint

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
# Pick the compute device once: the GPU when PyTorch can see one, otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

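**Model**

A small CNN maps each MNIST image to 10 digit logits. The argmax of those logits (the predicted digit) is stacked with a random integer in 0–9 into a 2-feature vector. That vector feeds a 2 → 10 → 10 → 1 MLP that regresses the sum of the two numbers. Because `argmax` is not differentiable, the addition loss trains only the MLP; the CNN learns from the classification loss alone.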

In [2]:
dropout_value = 0.05


class Net(nn.Module):
    """CNN digit classifier plus a small MLP that adds the predicted digit to a number.

    The convolutional trunk maps a 1x28x28 image to 10 digit logits
    (28 -> 26 -> 24 -> 22 -> pool 11 -> 9 -> 7 -> 5 -> 5 -> GAP 1x1). The argmax of
    those logits and ``random_input`` are stacked into a 2-feature vector and passed
    through a 2 -> 10 -> 10 -> 1 MLP that regresses their sum.

    ``torch.argmax`` is not differentiable, so the addition loss only trains the MLP;
    the convolutional trunk learns from the classification loss alone.

    Args:
        dropout: dropout probability used in every conv block. Defaults to the
            module-level ``dropout_value``.
    """

    def __init__(self, dropout=dropout_value):
        super(Net, self).__init__()
        # Attribute names and registration order match the original hand-written
        # blocks, so state_dict keys and weight initialisation are unchanged.
        self.convblock1 = self._conv_block(1, 12, dropout)     # 28 -> 26
        self.convblock2 = self._conv_block(12, 12, dropout)    # 26 -> 24
        self.convblock3 = self._conv_block(12, 12, dropout)    # 24 -> 22
        self.pool3 = nn.MaxPool2d(2, 2)                        # 22 -> 11
        self.convblock4 = self._conv_block(12, 12, dropout)    # 11 -> 9
        self.convblock5 = self._conv_block(12, 12, dropout)    # 9 -> 7
        self.convblock6 = self._conv_block(12, 16, dropout)    # 7 -> 5
        self.convblock7 = self._conv_block(16, 16, dropout, padding=1)  # 5 -> 5
        self.gap = nn.Sequential(
            nn.AvgPool2d(kernel_size=5)                        # 5x5x16 -> 1x1x16
        )
        self.convblock8 = nn.Sequential(
            # 1x1 conv acts as the classifier: 1x1x16 -> 1x1x10 logits.
            nn.Conv2d(in_channels=16, out_channels=10, kernel_size=(1, 1), bias=False),
        )
        self.addition_layer1 = nn.Linear(in_features=2, out_features=10)
        self.addition_layer2 = nn.Linear(in_features=10, out_features=10)
        self.addition_out_layer = nn.Linear(in_features=10, out_features=1)

    @staticmethod
    def _conv_block(in_channels, out_channels, dropout, padding=0):
        """Build the repeated 3x3 Conv (no bias) -> ReLU -> BatchNorm -> Dropout unit."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=(3, 3), padding=padding, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(dropout),
        )

    def forward(self, x, random_input):
        """Classify the digit and regress ``predicted_digit + random_input``.

        Args:
            x: image batch; assumed (B, 1, 28, 28). The 5x5 GAP only reduces the
               feature map to 1x1 for 28x28 inputs.
            random_input: one addend per sample, shape (B,) or (B, 1), any numeric dtype.

        Returns:
            (digit_logits of shape (B, 10), addition_result of shape (B, 1)).
        """
        x = self.convblock1(x)
        x = self.convblock2(x)
        x = self.convblock3(x)
        x = self.pool3(x)
        x = self.convblock4(x)
        x = self.convblock5(x)
        x = self.convblock6(x)
        x = self.convblock7(x)
        x = self.gap(x)
        x = self.convblock8(x)
        outputImage = x.view(-1, 10)

        # Hard digit prediction: carries no gradient back into the CNN.
        imageOutput = torch.argmax(outputImage, dim=1)
        # Flatten and cast so (B,) / (B, 1) / integer addends all stack into (B, 2) floats.
        addend = random_input.reshape(-1).float()
        image_value_and_number = torch.stack((imageOutput.float(), addend), dim=1)

        addition_result = F.relu(self.addition_layer1(image_value_and_number))
        addition_result = F.relu(self.addition_layer2(addition_result))
        addition_result = self.addition_out_layer(addition_result)
        return outputImage, addition_result

**Custom Dataset**

In [3]:
from torch.utils.data import Dataset

class CustomDataSet(Dataset):
    """MNIST wrapper that pairs every image with a random integer to add to its label.

    Each item is ``(image, label, random_input, random_target)``:
    ``random_input`` is a 0-d float32 tensor holding an integer in [0, 9], and
    ``random_target = random_input + label``. A new addend is drawn on every access.

    Args:
        isTrain: load the MNIST training split when truthy, the test split otherwise.
            Only the training split gets the +/-5 degree random-rotation augmentation.
        root: directory MNIST is downloaded to / read from.
    """

    def __init__(self, isTrain, root='./data'):
        steps = [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        if isTrain:
            # Augment the training split only; evaluation images stay untouched.
            steps.insert(0, transforms.RandomRotation((-5.0, 5.0), fill=(1,)))
        self.data = datasets.MNIST(root, train=isTrain, download=True,
                                   transform=transforms.Compose(steps))

    def __getitem__(self, index):
        image, label = self.data[index]
        # Draw from torch's RNG rather than Python's `random`, so that torch.manual_seed
        # makes the addends reproducible. DataLoader workers reseed torch per worker.
        random_input = torch.randint(0, 10, (1,)).to(torch.float32).squeeze(0)
        random_target = random_input + label
        return image, label, random_input, random_target

    def __len__(self):
        return len(self.data)

## Data loaders

In [4]:
# Seed before anything random happens. torch drives shuffling and weight init.
# Python's `random` is seeded as well, in case the dataset draws its addends from it.
torch.manual_seed(1)
random.seed(1)
batch_size = 128

train_set = CustomDataSet(True)
test_set = CustomDataSet(False)

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, **kwargs)
# Evaluation metrics do not depend on order, so iterate the test split deterministically.
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, **kwargs)



## Train and Test Functions

In [5]:
from tqdm import tqdm
# mse_loss = nn.MSELoss()
# cross_entropy = nn.CrossEntropyLoss()

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    # pbar = tqdm(train_loader)
    correct = 0
    processed = 0
    correct_addition = 0
    for batch_idx, (data, target,random_input, random_target) in enumerate(train_loader):
        data, target, random_input, random_target = data.to(device), target.to(device), random_input.to(device), random_target.to(device).reshape(-1,1)
        optimizer.zero_grad()
        output, number_output = model(data, random_input)
        loss = F.cross_entropy(output, target)
        loss_l1 = F.mse_loss(number_output, random_target)

        loss = loss + loss_l1

        loss.backward()
        optimizer.step()
        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)
        correct_addition += (torch.round(number_output).reshape(1,-1) == torch.round(random_target).reshape(1,-1)).sum().item()
    print(f'loss={loss.item()} batch_id={batch_idx} Accuracy of Image={100*correct/processed}, Accuracy of addition={100*correct_addition/processed}, correct: {(torch.round(number_output) == torch.round(random_target)).sum().item()}')
   

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    correct_addition = 0
    with torch.no_grad():
        for data, target, random_input, random_target in test_loader:
            data, target, random_input, random_target = data.to(device), target.to(device), random_input.to(device), random_target.to(device).reshape(-1,1)
            output, number_output = model(data, random_input)
            loss1 = F.cross_entropy(output, target)  # sum up batch loss
            loss_f1 = F.mse_loss(number_output, random_target)
            test_loss = loss1 + loss_f1
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            correct_addition += (torch.round(number_output).reshape(1,-1) == torch.round(random_target).reshape(1,-1)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    print(f'Test accuracy {100 * correct_addition/len(test_loader.dataset)}')

## Running Epochs

In [6]:
# Training hyperparameters, named so they can be tuned in one place.
LEARNING_RATE = 0.001
MOMENTUM = 0.9
NUM_EPOCHS = 9  # epochs are numbered 1..NUM_EPOCHS

model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

for epoch in range(1, NUM_EPOCHS + 1):
    print("EPOCH:", epoch)
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    print('------------------------------------------')
    print('------------------------------------------')

EPOCH: 1
loss=2.6698544025421143 batch_id=468 Accuracy of Image=64.38666666666667, Accuracy of addition=32.74, correct: 76

Test set: Average loss: 0.0002, Accuracy: 9048/10000 (90.48%)
Test accuracy 74.42
------------------------------------------
EPOCH: 2
loss=1.3775510787963867 batch_id=468 Accuracy of Image=92.35166666666667, Accuracy of addition=82.97833333333334, correct: 91

Test set: Average loss: 0.0000, Accuracy: 9572/10000 (95.72%)
Test accuracy 94.65
------------------------------------------
EPOCH: 3
loss=0.8939895629882812 batch_id=468 Accuracy of Image=95.4, Accuracy of addition=92.59333333333333, correct: 95

Test set: Average loss: 0.0003, Accuracy: 9686/10000 (96.86%)
Test accuracy 96.07
------------------------------------------
EPOCH: 4
loss=1.2808877229690552 batch_id=468 Accuracy of Image=96.53833333333333, Accuracy of addition=94.91, correct: 89

Test set: Average loss: 0.0003, Accuracy: 9777/10000 (97.77%)
Test accuracy 96.87
------------------------------------