In [2]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [16]:
class MNISTNET(nn.Module):
    def __init__(self):
      super(MNISTNET, self).__init__()

      self.input_block = nn.Sequential(
      #INPUT 28X28X1 >>CONV 3X3X1X16 >>26X26X16
          nn.Conv2d(1, 16, 3, bias=False), 
          nn.ReLU(),
          nn.BatchNorm2d(16),
          nn.Dropout2d(0.1),

 
      #INPUT 26X26X16 >>CONV 3X3X16X32 >>24X24X32
          nn.Conv2d(16, 32, 3, bias=False), 
          nn.ReLU(),
          nn.BatchNorm2d(32),
          nn.Dropout2d(0.1),
      )
            # translation layer
      
      self.trans1 = nn.Sequential(
          #24X24x32 >>CONV 1X1X32X8 >>24X24X8
          nn.Conv2d(32, 8, 1, bias=False), 
          nn.ReLU(),
          #24X24x8 >>MAXPOOL (2,2) >>12X12X8
          nn.MaxPool2d(2, 2),
      )
      self.conv_block = nn.Sequential(
          #12X12x8 >>CONV 3X3X8X16 PAD=1 >>12X12X16
          nn.Conv2d(8, 16, 3,padding=1, bias=False),
          nn.ReLU(),
          nn.BatchNorm2d(16),
          nn.Dropout2d(0.1),

          #12X12x16 >>CONV 3X3X16X32 >>10X10X32

         nn.Conv2d(16, 32, 3, bias=False),
          nn.ReLU(),
          nn.BatchNorm2d(32),
          nn.Dropout2d(0.1),

      )
                  # translation layer
      self.trans2 = nn.Sequential(
          #10X10x32 >>CONV 1X1X32X8 >>10X10X8
          nn.Conv2d(32, 8, 1, bias=False), 
          nn.ReLU(),
          #10X10x8 >>MAXPOOL (2,2) >>5X5X8
          nn.MaxPool2d(2, 2),
      )
      self.conv_block2 = nn.Sequential(
        #5X5X8 >>CONV 3X3X8X16 PAD=1 >>5X5X16
         nn.Conv2d(8, 16, 3,padding=1,bias=False),
        nn.ReLU(),
        nn.BatchNorm2d(16),
        nn.Dropout2d(0.1),  
        #5X5X16 >>CONV 3X3X8X32 PAD=0 >>3X3X32
        nn.Conv2d(16, 32, 3,bias=False),
        nn.ReLU(),
        nn.BatchNorm2d(32),
        nn.Dropout2d(0.1),

      )

          
      self.avg_pool = nn.Sequential(
      #3X3X32 >>CONV 1X1X32X10  >>3X3X10
      nn.Conv2d(32, 10, 1, bias=False),
      #3X3X10 >>AVG pool(3X3) >>1X1X10
      nn.AvgPool2d(3)
      )

    """forward: performs a forward pass when model(x) is called
    Params
        x: the input data
    Returns
        y: the output of the model
    """
    def forward(self, x):
        x = self.input_block(x)
        x = self.trans1(x)
        x = self.conv_block(x)
        x = self.trans2(x)
        x = self.conv_block2(x)
        x = self.avg_pool(x)
        x = x.view(-1, 10)
        #return F.log_softmax(x)
        return x

In [8]:
!pip install torchsummary
from torchsummary import summary
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model = MNISTNET().to(device)
summary(model, input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 26, 26]             144
              ReLU-2           [-1, 16, 26, 26]               0
       BatchNorm2d-3           [-1, 16, 26, 26]              32
         Dropout2d-4           [-1, 16, 26, 26]               0
            Conv2d-5           [-1, 32, 24, 24]           4,608
              ReLU-6           [-1, 32, 24, 24]               0
       BatchNorm2d-7           [-1, 32, 24, 24]              64
         Dropout2d-8           [-1, 32, 24, 24]               0
            Conv2d-9            [-1, 8, 24, 24]             256
             ReLU-10            [-1, 8, 24, 24]               0
        MaxPool2d-11            [-1, 8, 12, 12]               0
           Conv2d-12           [-1, 16, 12, 12]           1,152
             ReLU-13           [-1, 16, 12, 12]               0
      BatchNorm2d-14           [-1, 16,



In [18]:
class ADDERNET(nn.Module):
    def __init__(self):
        super(ADDERNET, self).__init__()

        self.num_classes = 10
        self.dims = (1, 28, 28)

        #INPUT 28X28X1 => OUT 1X1X10
        self.mnist_base = MNISTNET()

        # IN: 20 (10 mnist + 10 OHE rand num)
        #IN 1X1x20  #OUT 1X1X60
        self.prefinal_layer1 = nn.Sequential(
            nn.Linear(in_features=20, out_features=60, bias=False),
            nn.BatchNorm1d(60),
            nn.ReLU(),
        )
        #IN 1X1x60  #OUT 1X1X60
        self.prefinal_layer2 = nn.Sequential(
            nn.Linear(in_features=60, out_features=60, bias=False),
            nn.BatchNorm1d(60),
            nn.ReLU(),
            
        )

        # IN: 60 ; OUT: 10
        self.mnist_final_layer = nn.Sequential(
          nn.Linear(in_features=60, out_features=10, bias=False),
        )

        # IN: 60 ; OUT 19
        self.adder_final_layer = nn.Sequential(
        nn.Linear(in_features=60, out_features=19, bias=False)
        )
        self.loss = nn.CrossEntropyLoss()

    def forward(self, mnist_img, rand_num):
        rand_num = F.one_hot(rand_num, num_classes=self.num_classes)
        
        # mnist embedding: 1x10
        mnist_embed = self.mnist_base(mnist_img)

        # concat the mnist embedding and the random number = 10+10=20 features
        ccat = torch.cat([mnist_embed, rand_num], dim=-1)
        #IN 1X1x20  #OUT 1X1X60 
        pre_out = self.prefinal_layer1(ccat)
        #IN 60  # OUT 60 
        pre_out = self.prefinal_layer2(pre_out)
        #IN 60 OUT 10 
        mnist_out = self.mnist_final_layer(pre_out)
        #IN 60 OUT 19
        adder_out = self.adder_final_layer(pre_out)

        return mnist_out, adder_out


In [29]:
#!pip install torchsummary
from torchsummary import summary
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model = ADDERNET().to(device)
summary(model, (1, 28, 28,1))

TypeError: ignored

In [None]:


torch.manual_seed(1)
batch_size = 128

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])),
    batch_size=batch_size, shuffle=True, **kwargs)

In [None]:

        model = ADDERNET().to(device)
        from tqdm import tqdm
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    pbar = tqdm(train_loader)
    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output_mnist,output_adder = model(data,)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        pbar.set_description(desc= f'loss={loss.item()} batch_id={batch_idx}')
        tqdm._instances.clear()

    def training_step(self, batch, batch_idx):
        (mnist_img, rand_num), (mnist_y, adder_y) = batch

        mnist_pred, adder_pred = self(mnist_img, rand_num)

        # both mnist and adder use cross entropy loss
        mnist_loss = self.loss(mnist_pred, mnist_y)
        adder_loss = self.loss(adder_pred, adder_y)

        # final loss is sum of the two loss
        loss = mnist_loss + adder_loss

        return loss
      
    def validation_step(self, batch, batch_idx):
        (mnist_img, rand_num), (mnist_y, adder_y) = batch

        mnist_pred, adder_pred = self(mnist_img, rand_num)

        mnist_loss = self.loss(mnist_pred, mnist_y)
        adder_loss = self.loss(adder_pred, adder_y)

        loss = mnist_loss + adder_loss

        mnist_pred = torch.argmax(F.log_softmax(mnist_pred, dim=1), dim=1)
        adder_pred = torch.argmax(F.log_softmax(adder_pred, dim=1), dim=1)

        mnist_acc = accuracy(mnist_pred, mnist_y)
        adder_acc = accuracy(adder_pred, adder_y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log('val_loss', loss, prog_bar=True)
        self.log('mnist_acc', mnist_acc, prog_bar=True)
        self.log('adder_acc', adder_acc, prog_bar=True)
        self.log('total_acc', mnist_acc * adder_acc, prog_bar=True)

        return loss
