**Imports**

In [38]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from random import randint
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

**Model**

In [43]:
dropout_value = 0.05

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.convblock1 = nn.Sequential(
            
            nn.Conv2d(in_channels=1, out_channels=12, kernel_size=(3, 3), bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(12),
            nn.Dropout(dropout_value), # output_size = 26
            
        )
      

        self.convblock2 = nn.Sequential(
            
            nn.Conv2d(in_channels=12, out_channels=12, kernel_size=(3, 3), bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(12),
            nn.Dropout(dropout_value), #output_size = 24
            
        )

        
        self.convblock3 = nn.Sequential(
            
            nn.Conv2d(in_channels=12, out_channels=12, kernel_size=(3, 3), bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(12),
            nn.Dropout(dropout_value), #output_size = 22
            
        )

        self.pool3 = nn.MaxPool2d(2, 2) # output_size = 11


        self.convblock4 = nn.Sequential(
          

            nn.Conv2d(in_channels=12, out_channels=12, kernel_size=(3, 3), bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(12),
            nn.Dropout(dropout_value), # output_size = 9   
        )

        

        self.convblock5 = nn.Sequential(   
            nn.Conv2d(in_channels=12, out_channels=12, kernel_size=(3, 3), bias=False), 
            nn.ReLU(),
            nn.BatchNorm2d(12),
            nn.Dropout(dropout_value),  # output_size = 7     
            
        )

        self.convblock6 = nn.Sequential(   
            nn.Conv2d(in_channels=12, out_channels=16, kernel_size=(3, 3), bias=False), 
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(dropout_value),   # output_size = 5     
            
        )

        self.convblock7 = nn.Sequential(   
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), padding = 1,bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(dropout_value),       # output_size = 5  
            
        )

        self.gap = nn.Sequential(
            nn.AvgPool2d(kernel_size=5) # output_size = 1*1*16
        ) 


        self.convblock8 = nn.Sequential(   
            nn.Conv2d(in_channels=16, out_channels=10, kernel_size=(1, 1),bias=False), # output_size = 1 * 1* 10
                      
        )
        self.addition_layer1 = nn.Linear(in_features=2, out_features=10)
        self.addition_layer2 = nn.Linear(in_features=10, out_features=10)
        self.addition_out_layer = nn.Linear(in_features=10, out_features=1)



    def forward(self, x, random_input):

        number_result = random_input
        x = self.convblock1(x)   
        x = self.convblock2(x)
        x = self.convblock3(x)
        x = self.pool3(x)
        x = self.convblock4(x)
        x = self.convblock5(x)
        x = self.convblock6(x)
        x = self.convblock7(x)
        x = self.gap(x)
        x = self.convblock8(x)
        outputImage = x.view(-1, 10)
        imageOutput = torch.argmax(outputImage, dim=1)
        image_value_and_number  = torch.stack((imageOutput.float(), random_input), dim=1)
        addition_result = self.addition_layer1(image_value_and_number)
        addition_result = F.relu(addition_result)
        addition_result = self.addition_layer2(addition_result)
        addition_result = F.relu(addition_result)
        addition_result = self.addition_out_layer(addition_result)
        return outputImage, addition_result

**Custom Dataset**

In [44]:
from torch.utils.data import Dataset

class CustomDataSet(Dataset):
  def __init__(self, isTrain):
    if isTrain:
      self.data = datasets.MNIST('./data', train=isTrain, download=True, transform=transforms.Compose([transforms.RandomRotation((-5.0, 5.0), fill=(1,)), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
    else:
      self.data = datasets.MNIST('./data', train=isTrain, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))

  def __getitem__(self, index):
    r = self.data[index]
    image, label = r
    random_input = torch.tensor(randint(0,9), dtype=torch.float32)    #torch.randint(0, 10, (1,))
    random_target = random_input + label
    return image, label, random_input, random_target

  def __len__(self):
    return len(self.data)

## Train loader

In [46]:
train_set = CustomDataSet(True)
test_set = CustomDataSet(False)

torch.manual_seed(1)
batch_size = 32

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(train_set,batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=True, **kwargs)



## Train and Test Functions

In [49]:
from tqdm import tqdm
mse_loss = nn.MSELoss()
cross_entropy = nn.CrossEntropyLoss()

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    pbar = tqdm(train_loader)
    correct = 0
    processed = 0
    correct_addition = 0
    for batch_idx, (data, target,random_input, random_target) in enumerate(pbar):
        data, target, random_input, random_target = data.to(device), target.to(device), random_input.to(device), random_target.to(device).reshape(-1,1)
        optimizer.zero_grad()
        output, number_output = model(data, random_input)
        loss = cross_entropy(output, target)
        loss_l1 = mse_loss(number_output, random_target)

        loss = loss + loss_l1

        loss.backward()
        optimizer.step()
        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)
        correct_addition += (torch.round(number_output).reshape(1,-1) == torch.round(random_target).reshape(1,-1)).sum().item()
    pbar.set_description(desc= f'loss={loss.item()} batch_id={batch_idx} Accuracy of Image={100*correct/processed}, Accuracy of addition={100*correct_addition/processed}')
   

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    correct_addition = 0
    with torch.no_grad():
        for data, target, random_input, random_target in test_loader:
            data, target, random_input, random_target = data.to(device), target.to(device), random_input.to(device), random_target.to(device).reshape(-1,1)
            output, number_output = model(data, random_input)
            loss1 = cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
            loss_f1 = mse_loss(number_output, random_target)
            test_loss = loss1 + loss_f1
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            correct_addition += (torch.round(number_output).reshape(1,-1) == torch.round(random_target).reshape(1,-1)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    print(f'Test accuracy {100 * correct_addition/len(test_loader.dataset)}')

## Running Epochs

In [37]:
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

for epoch in range(1, 10):
    print("EPOCH:", epoch)
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    print('------------------------------------------')





  0%|          | 0/469 [00:00<?, ?it/s][A[A[A[A

EPOCH: 1






loss=12.137669563293457 batch_id=0 Accuracy of Image=7.8125, Accuracy of addition=0.78125:   0%|          | 0/469 [00:00<?, ?it/s][A[A[A[A



loss=11.742212295532227 batch_id=1 Accuracy of Image=10.15625, Accuracy of addition=1.953125:   0%|          | 0/469 [00:00<?, ?it/s][A[A[A[A



loss=11.742212295532227 batch_id=1 Accuracy of Image=10.15625, Accuracy of addition=1.953125:   0%|          | 2/469 [00:00<00:33, 14.03it/s][A[A[A[A



loss=11.639280319213867 batch_id=2 Accuracy of Image=10.416666666666666, Accuracy of addition=1.3020833333333333:   0%|          | 2/469 [00:00<00:33, 14.03it/s][A[A[A[A



loss=11.577682495117188 batch_id=3 Accuracy of Image=12.109375, Accuracy of addition=0.9765625:   0%|          | 2/469 [00:00<00:33, 14.03it/s]                  [A[A[A[A



loss=11.577682495117188 batch_id=3 Accuracy of Image=12.109375, Accuracy of addition=0.9765625:   1%|          | 4/469 [00:00<00:30, 15.25it/s][A[A[A[A



loss=11.321386337280273 batch_i


Test set: Average loss: 0.0002, Accuracy: 9802/10000 (98.02%)
Test accuracy 12.79
------------------------------------------
EPOCH: 2






loss=1.4300951957702637 batch_id=0 Accuracy of Image=96.875, Accuracy of addition=8.59375:   0%|          | 0/469 [00:00<?, ?it/s][A[A[A[A



loss=1.4300951957702637 batch_id=0 Accuracy of Image=96.875, Accuracy of addition=8.59375:   0%|          | 1/469 [00:00<01:03,  7.32it/s][A[A[A[A



loss=1.6742832660675049 batch_id=1 Accuracy of Image=96.875, Accuracy of addition=7.421875:   0%|          | 1/469 [00:00<01:03,  7.32it/s][A[A[A[A



loss=0.44024932384490967 batch_id=2 Accuracy of Image=96.61458333333333, Accuracy of addition=35.9375:   0%|          | 1/469 [00:00<01:03,  7.32it/s][A[A[A[A



loss=0.44024932384490967 batch_id=2 Accuracy of Image=96.61458333333333, Accuracy of addition=35.9375:   1%|          | 3/469 [00:00<00:52,  8.87it/s][A[A[A[A



loss=0.6879429817199707 batch_id=3 Accuracy of Image=96.484375, Accuracy of addition=41.2109375:   1%|          | 3/469 [00:00<00:52,  8.87it/s]      [A[A[A[A



loss=0.6402328014373779 batch_id=4 Accuracy


Test set: Average loss: 0.0000, Accuracy: 9889/10000 (98.89%)
Test accuracy 49.45
------------------------------------------
EPOCH: 3






loss=0.6380431652069092 batch_id=0 Accuracy of Image=99.21875, Accuracy of addition=50.78125:   0%|          | 0/469 [00:00<?, ?it/s][A[A[A[A



loss=0.6380431652069092 batch_id=0 Accuracy of Image=99.21875, Accuracy of addition=50.78125:   0%|          | 1/469 [00:00<00:55,  8.44it/s][A[A[A[A



loss=0.38571763038635254 batch_id=1 Accuracy of Image=99.21875, Accuracy of addition=60.546875:   0%|          | 1/469 [00:00<00:55,  8.44it/s][A[A[A[A



loss=1.11574125289917 batch_id=2 Accuracy of Image=98.95833333333333, Accuracy of addition=47.65625:   0%|          | 1/469 [00:00<00:55,  8.44it/s][A[A[A[A



loss=1.11574125289917 batch_id=2 Accuracy of Image=98.95833333333333, Accuracy of addition=47.65625:   1%|          | 3/469 [00:00<00:47,  9.81it/s][A[A[A[A



loss=1.1292146444320679 batch_id=3 Accuracy of Image=98.828125, Accuracy of addition=40.4296875:   1%|          | 3/469 [00:00<00:47,  9.81it/s]    [A[A[A[A



loss=0.19551651179790497 batch_id=4 Acc


Test set: Average loss: 0.0000, Accuracy: 9914/10000 (99.14%)
Test accuracy 99.14
------------------------------------------
EPOCH: 4






loss=0.0926138162612915 batch_id=0 Accuracy of Image=99.21875, Accuracy of addition=99.21875:   0%|          | 0/469 [00:00<?, ?it/s][A[A[A[A



loss=0.0926138162612915 batch_id=0 Accuracy of Image=99.21875, Accuracy of addition=99.21875:   0%|          | 1/469 [00:00<01:03,  7.33it/s][A[A[A[A



loss=0.4150506854057312 batch_id=1 Accuracy of Image=99.21875, Accuracy of addition=88.671875:   0%|          | 1/469 [00:00<01:03,  7.33it/s][A[A[A[A



loss=0.228979229927063 batch_id=2 Accuracy of Image=99.47916666666667, Accuracy of addition=92.44791666666667:   0%|          | 1/469 [00:00<01:03,  7.33it/s][A[A[A[A



loss=0.228979229927063 batch_id=2 Accuracy of Image=99.47916666666667, Accuracy of addition=92.44791666666667:   1%|          | 3/469 [00:00<00:52,  8.90it/s][A[A[A[A



loss=0.3019122779369354 batch_id=3 Accuracy of Image=99.4140625, Accuracy of addition=92.96875:   1%|          | 3/469 [00:00<00:52,  8.90it/s]               [A[A[A[A



loss=0.275


Test set: Average loss: 0.0001, Accuracy: 9896/10000 (98.96%)
Test accuracy 51.62
------------------------------------------
EPOCH: 5






loss=0.875479519367218 batch_id=0 Accuracy of Image=97.65625, Accuracy of addition=50.78125:   0%|          | 0/469 [00:00<?, ?it/s][A[A[A[A



loss=0.875479519367218 batch_id=0 Accuracy of Image=97.65625, Accuracy of addition=50.78125:   0%|          | 1/469 [00:00<01:17,  6.03it/s][A[A[A[A



loss=0.5208895802497864 batch_id=1 Accuracy of Image=96.875, Accuracy of addition=62.890625:   0%|          | 1/469 [00:00<01:17,  6.03it/s][A[A[A[A



loss=0.7694327235221863 batch_id=2 Accuracy of Image=96.875, Accuracy of addition=49.21875:   0%|          | 1/469 [00:00<01:17,  6.03it/s] [A[A[A[A



loss=0.8648442029953003 batch_id=3 Accuracy of Image=97.265625, Accuracy of addition=42.1875:   0%|          | 1/469 [00:00<01:17,  6.03it/s][A[A[A[A



loss=0.8648442029953003 batch_id=3 Accuracy of Image=97.265625, Accuracy of addition=42.1875:   1%|          | 4/469 [00:00<01:01,  7.62it/s][A[A[A[A



loss=0.22354240715503693 batch_id=4 Accuracy of Image=97.8125, Acc


Test set: Average loss: 0.0003, Accuracy: 9924/10000 (99.24%)
Test accuracy 88.75
------------------------------------------
EPOCH: 6






loss=0.33430224657058716 batch_id=0 Accuracy of Image=98.4375, Accuracy of addition=90.625:   0%|          | 0/469 [00:00<?, ?it/s][A[A[A[A



loss=0.33430224657058716 batch_id=0 Accuracy of Image=98.4375, Accuracy of addition=90.625:   0%|          | 1/469 [00:00<01:16,  6.15it/s][A[A[A[A



loss=0.39379429817199707 batch_id=1 Accuracy of Image=98.828125, Accuracy of addition=79.6875:   0%|          | 1/469 [00:00<01:16,  6.15it/s][A[A[A[A



loss=0.4491690993309021 batch_id=2 Accuracy of Image=98.95833333333333, Accuracy of addition=74.47916666666667:   0%|          | 1/469 [00:00<01:16,  6.15it/s][A[A[A[A



loss=0.4491690993309021 batch_id=2 Accuracy of Image=98.95833333333333, Accuracy of addition=74.47916666666667:   1%|          | 3/469 [00:00<01:00,  7.69it/s][A[A[A[A



loss=0.1589049994945526 batch_id=3 Accuracy of Image=99.21875, Accuracy of addition=80.6640625:   1%|          | 3/469 [00:00<01:00,  7.69it/s]                [A[A[A[A



loss=0.4918


Test set: Average loss: 0.0000, Accuracy: 9926/10000 (99.26%)
Test accuracy 99.26
------------------------------------------
EPOCH: 7






loss=0.14295637607574463 batch_id=0 Accuracy of Image=100.0, Accuracy of addition=100.0:   0%|          | 0/469 [00:00<?, ?it/s][A[A[A[A



loss=0.14295637607574463 batch_id=0 Accuracy of Image=100.0, Accuracy of addition=100.0:   0%|          | 1/469 [00:00<01:33,  5.00it/s][A[A[A[A



loss=0.3509063422679901 batch_id=1 Accuracy of Image=100.0, Accuracy of addition=93.359375:   0%|          | 1/469 [00:00<01:33,  5.00it/s][A[A[A[A



loss=0.13468024134635925 batch_id=2 Accuracy of Image=99.73958333333333, Accuracy of addition=95.3125:   0%|          | 1/469 [00:00<01:33,  5.00it/s][A[A[A[A



loss=0.28444188833236694 batch_id=3 Accuracy of Image=99.8046875, Accuracy of addition=92.578125:   0%|          | 1/469 [00:00<01:33,  5.00it/s]     [A[A[A[A



loss=0.28444188833236694 batch_id=3 Accuracy of Image=99.8046875, Accuracy of addition=92.578125:   1%|          | 4/469 [00:00<01:10,  6.61it/s][A[A[A[A



loss=0.13926246762275696 batch_id=4 Accuracy of Imag


Test set: Average loss: 0.0002, Accuracy: 9925/10000 (99.25%)
Test accuracy 99.25
------------------------------------------
EPOCH: 8






loss=0.14571115374565125 batch_id=0 Accuracy of Image=99.21875, Accuracy of addition=99.21875:   0%|          | 0/469 [00:00<?, ?it/s][A[A[A[A



loss=0.14571115374565125 batch_id=0 Accuracy of Image=99.21875, Accuracy of addition=99.21875:   0%|          | 1/469 [00:00<01:10,  6.66it/s][A[A[A[A



loss=0.40011829137802124 batch_id=1 Accuracy of Image=98.828125, Accuracy of addition=94.53125:   0%|          | 1/469 [00:00<01:10,  6.66it/s][A[A[A[A



loss=0.12642383575439453 batch_id=2 Accuracy of Image=98.95833333333333, Accuracy of addition=96.09375:   0%|          | 1/469 [00:00<01:10,  6.66it/s][A[A[A[A



loss=0.12642383575439453 batch_id=2 Accuracy of Image=98.95833333333333, Accuracy of addition=96.09375:   1%|          | 3/469 [00:00<00:57,  8.14it/s][A[A[A[A



loss=0.46805238723754883 batch_id=3 Accuracy of Image=98.828125, Accuracy of addition=90.8203125:   1%|          | 3/469 [00:00<00:57,  8.14it/s]      [A[A[A[A



loss=0.23776721954345703 bat


Test set: Average loss: 0.0000, Accuracy: 9935/10000 (99.35%)
Test accuracy 99.35
------------------------------------------
EPOCH: 9






loss=0.13143008947372437 batch_id=0 Accuracy of Image=100.0, Accuracy of addition=100.0:   0%|          | 0/469 [00:00<?, ?it/s][A[A[A[A



loss=0.13143008947372437 batch_id=0 Accuracy of Image=100.0, Accuracy of addition=100.0:   0%|          | 1/469 [00:00<01:12,  6.49it/s][A[A[A[A



loss=0.10746511071920395 batch_id=1 Accuracy of Image=100.0, Accuracy of addition=100.0:   0%|          | 1/469 [00:00<01:12,  6.49it/s][A[A[A[A



loss=0.2762308120727539 batch_id=2 Accuracy of Image=100.0, Accuracy of addition=94.79166666666667:   0%|          | 1/469 [00:00<01:12,  6.49it/s][A[A[A[A



loss=0.2762308120727539 batch_id=2 Accuracy of Image=100.0, Accuracy of addition=94.79166666666667:   1%|          | 3/469 [00:00<01:00,  7.75it/s][A[A[A[A



loss=0.10484619438648224 batch_id=3 Accuracy of Image=100.0, Accuracy of addition=96.09375:   1%|          | 3/469 [00:00<01:00,  7.75it/s]        [A[A[A[A



loss=0.4433159828186035 batch_id=4 Accuracy of Image=99.843


Test set: Average loss: 0.0001, Accuracy: 9920/10000 (99.20%)
Test accuracy 57.98
------------------------------------------


In [35]:
model(next(iter(train_set))[0], next(iter(train_set))[2])

RuntimeError: ignored

In [None]:
t1 = torch.tensor([1,2,4,5])

t2 = torch.tensor([5,6,3,2])

In [None]:
t1 == t2