**Imports**

In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from random import randint
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

**Model**

In [16]:
dropout_value = 0.06

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.convblock1 = nn.Sequential(
            
            nn.Conv2d(in_channels=1, out_channels=12, kernel_size=(3, 3), bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(12),
            nn.Dropout(dropout_value), # output_size = 26
            
        )
      

        self.convblock2 = nn.Sequential(
            
            nn.Conv2d(in_channels=12, out_channels=12, kernel_size=(3, 3), bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(12),
            nn.Dropout(dropout_value), #output_size = 24
            
        )

        
        self.convblock3 = nn.Sequential(
            
            nn.Conv2d(in_channels=12, out_channels=12, kernel_size=(3, 3), bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(12),
            nn.Dropout(dropout_value), #output_size = 22
            
        )

        self.pool3 = nn.MaxPool2d(2, 2) # output_size = 11


        self.convblock4 = nn.Sequential(
          

            nn.Conv2d(in_channels=12, out_channels=12, kernel_size=(3, 3), bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(12),
            nn.Dropout(dropout_value), # output_size = 9   
        )

        

        self.convblock5 = nn.Sequential(   
            nn.Conv2d(in_channels=12, out_channels=12, kernel_size=(3, 3), bias=False), 
            nn.ReLU(),
            nn.BatchNorm2d(12),
            nn.Dropout(dropout_value),  # output_size = 7     
            
        )

        self.convblock6 = nn.Sequential(   
            nn.Conv2d(in_channels=12, out_channels=16, kernel_size=(3, 3), bias=False), 
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(dropout_value),   # output_size = 5     
            
        )

        self.convblock7 = nn.Sequential(   
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), padding = 1,bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(dropout_value),       # output_size = 5  
            
        )

        self.gap = nn.Sequential(
            nn.AvgPool2d(kernel_size=5) # output_size = 1*1*16
        ) 


        self.convblock8 = nn.Sequential(   
            nn.Conv2d(in_channels=16, out_channels=10, kernel_size=(1, 1),bias=False), # output_size = 1 * 1* 10
                      
        )
        self.addition_layer1 = nn.Linear(in_features=2, out_features=10)
        self.addition_layer2 = nn.Linear(in_features=10, out_features=30)
        self.addition_out_layer = nn.Linear(in_features=30, out_features=1)



    def forward(self, x, random_input):

        #print(f'x= {x.shape}')
        number_result = random_input
        x = self.convblock1(x)   
        x = self.convblock2(x)
        x = self.convblock3(x)
        x = self.pool3(x)
        x = self.convblock4(x)
        x = self.convblock5(x)
        x = self.convblock6(x)
        x = self.convblock7(x)
        x = self.gap(x)
        x = self.convblock8(x)
        outputImage = x.view(-1, 10)
        # outputImage = torch.flatten(x, start_dim=1)
        imageOutput = torch.argmax(outputImage, dim=1)
        image_value_and_number  = torch.stack((imageOutput.float(), random_input), dim=1)
        # print(f' concatenated = {image_value_and_number.shape}')
        #print(f'imageOutPut = {imageOutput} outputshape = {imageOutput.shape}')


        #print(f'x shape {x.shape}')
        y=0
        addition_result = self.addition_layer1(image_value_and_number)
        addition_result = self.addition_layer2(addition_result)
        addition_result = self.addition_out_layer(addition_result)
        # print(f'additionalInput = {addition_result}, ishape = {addition_result.shape}')
        # if random_input != None:
        #   y = 0
        #   print(f'r= {random_input.shape}')
        #   #random_input = F.one_hot(random_input, num_classes=10)
        #   #random_input = random_input.reshape(-1,1,1,10);
        #   #print(f'f= {random_input.shape}')
        #   #print(f'random_input = {random_input}')
        #   #random_input = F.one_hot(random_input, num_classes=10)
        #   #addition_result = self.addition_layer1(torch.tensor(random_input, dtype=torch.float32,));
        #   #addition_result = self.addition_layer2(addition_result);
        #   #addition_result = self.addition_out_layer(addition_result);
        #   #number_result = addition_result

        #return F.log_softmax(outputImage), y
        return outputImage, addition_result

In [17]:
# !pip install torchsummary
# from torchsummary import summary
# use_cuda = torch.cuda.is_available()
# device = torch.device("cuda" if use_cuda else "cpu")
# model = Net().to(device)
# summary(model, input_size=(1, 28, 28))

**Custom Dataset**

In [18]:
from torch.utils.data import Dataset

class CustomDataSet(Dataset):
  def __init__(self, isTrain):
    if isTrain:
      self.data = datasets.MNIST('./data', train=isTrain, download=True, transform=transforms.Compose([transforms.RandomRotation((-5.0, 5.0), fill=(1,)), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
    else:
      self.data = datasets.MNIST('./data', train=isTrain, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))

  def __getitem__(self, index):
    r = self.data[index]
    image, label = r
    random_input = torch.tensor(randint(0,9), dtype=torch.float32)    #torch.randint(0, 10, (1,))
    random_target = random_input + label
    #print(f'raandom target = {random_target}, random_label={label}, random_input= {random_input} ')
    return image, label, random_input, random_target

  def __len__(self):
    return len(self.data)

In [19]:
train_set = CustomDataSet(True)
test_set = CustomDataSet(False)

torch.manual_seed(1)
batch_size = 128

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(train_set,batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=True, **kwargs)


# next(iter(train_loader))[1]

In [35]:
from tqdm import tqdm
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    pbar = tqdm(train_loader)
    correct = 0
    processed = 0
    for batch_idx, (data, target,random_input, random_target) in enumerate(pbar):
        data, target, random_input, random_target = data.to(device), target.to(device), random_input.to(device), random_target.to(device)
        optimizer.zero_grad()
        output, number_output = model(data, random_input)
        #print(f' label= {target.shape} image= {data.shape} random_label = {random_target.shape} ')
        loss = F.cross_entropy(output, target)
        loss_l1 = F.l1_loss(number_output, random_target)

        loss = loss + loss_l1
        loss.backward()
        optimizer.step()
        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)
        pbar.set_description(desc= f'loss={loss.item()} batch_id={batch_idx} Accuracy={100*correct/processed:0.2f}')
   

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target, random_input, random_target in test_loader:
            # data, target = data.to(device), target.to(device)
            data, target, random_input, random_target = data.to(device), target.to(device), random_input.to(device), random_target.to(device)
            output, number_output = model(data, random_input)
            loss1 = F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
            loss_f1 = F.l1_loss(number_output, random_target)
            test_loss = loss1 + loss_f1
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [36]:
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

for epoch in range(1, 2):
    print("EPOCH:", epoch)
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    print('------------------------------------------')



  0%|          | 0/469 [00:00<?, ?it/s][A[A

EPOCH: 1


  del sys.path[0]


loss=10.742785453796387 batch_id=0 Accuracy=12.50:   0%|          | 0/469 [00:00<?, ?it/s][A[A

loss=10.742785453796387 batch_id=0 Accuracy=12.50:   0%|          | 1/469 [00:00<00:49,  9.40it/s][A[A

loss=9.761452674865723 batch_id=1 Accuracy=11.72:   0%|          | 1/469 [00:00<00:49,  9.40it/s] [A[A

loss=9.273329734802246 batch_id=2 Accuracy=11.72:   0%|          | 1/469 [00:00<00:49,  9.40it/s][A[A

loss=8.594279289245605 batch_id=3 Accuracy=11.33:   0%|          | 1/469 [00:00<00:49,  9.40it/s][A[A

loss=8.594279289245605 batch_id=3 Accuracy=11.33:   1%|          | 4/469 [00:00<00:41, 11.30it/s][A[A

loss=7.170487880706787 batch_id=4 Accuracy=11.09:   1%|          | 4/469 [00:00<00:41, 11.30it/s][A[A

loss=6.906628608703613 batch_id=5 Accuracy=11.98:   1%|          | 4/469 [00:00<00:41, 11.30it/s][A[A

loss=7.484557628631592 batch_id=6 Accuracy=12.95:   1%|          | 4/469 [00:00<00:41, 11.30it/s][A[A

loss=7.484557628631592 batch_id=6 Accura


Test set: Average loss: 0.0007, Accuracy: 9838/10000 (98.38%)

------------------------------------------




In [31]:
next(iter(train_set))[0].shape, next(iter(train_set))[2]

(torch.Size([1, 28, 28]), tensor(8.))