In [76]:
import torch 
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import dlc_practical_prologue as prologue
%matplotlib inline
N=1000
from torch.utils.data import DataLoader, Dataset
import tqdm
from torch.autograd import Variable

In [77]:
# Seed torch's global RNG before building the datasets so re-running the notebook
# reproduces the same pair sets (assumes prologue.generate_pair_sets draws from
# torch's global RNG — TODO confirm against dlc_practical_prologue).
torch.manual_seed(0)
train_input, train_target, train_classes, test_input, test_target, test_classes = prologue.generate_pair_sets(N)

In [104]:
class CompareNet(nn.Module):
    """Weight-shared MLP that compares two digit images and classifies each one.

    Both images of a pair go through the *same* ``base`` network. The two
    resulting feature vectors are concatenated and fed to a 2-way ``comparator``.
    Each feature vector is also fed to a shared ``classifier`` (auxiliary
    per-digit prediction).

    All three outputs are raw, unnormalised logits, as expected by
    ``nn.CrossEntropyLoss``. No ReLU/softmax is applied to them.

    Args:
        input_size: number of pixels per image after flattening (default 196,
            i.e. 14*14).
        hidden_size: width of the hidden layer in ``base``.
        feature_size: size of each image's feature vector.
        n_classes: number of digit classes predicted by ``classifier``.
    """

    def __init__(self, input_size=196, hidden_size=64, feature_size=32, n_classes=10):
        super().__init__()
        self.base = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, feature_size),
        )
        # The comparator sees both feature vectors side by side.
        self.comparator = nn.Linear(2 * feature_size, 2)
        self.classifier = nn.Linear(feature_size, n_classes)

    def forward(self, x):
        """Return ``(sign_logits, digit0_logits, digit1_logits)``.

        ``x`` is indexed as ``x[:, 0]`` / ``x[:, 1]`` for the two images.
        Presumably it has shape (batch, 2, H, W) with H*W == input_size
        (TODO confirm against the data pipeline).
        """
        x_0 = F.relu(self.base(x[:, 0].flatten(1)))
        x_1 = F.relu(self.base(x[:, 1].flatten(1)))
        # Outputs are left as raw logits: a ReLU here would clamp all negative
        # scores to 0 and kill the gradient once both logits hit 0.
        sign = self.comparator(torch.cat([x_0, x_1], dim=1))
        digit_0 = self.classifier(x_0)
        digit_1 = self.classifier(x_1)
        return sign, digit_0, digit_1
    

In [105]:
class DigitPairsDataset(Dataset):
    """Map-style dataset yielding ``(img_pair, target, classes)`` triples.

    Wraps three pre-built, index-aligned containers. Item ``idx`` is taken from
    each of them in lockstep. The dataset length is the first dimension of
    ``targets``.
    """

    def __init__(self, img_pair, targets, classes):
        super().__init__()
        self.img_pair, self.targets, self.classes = img_pair, targets, classes

    def __len__(self):
        # Leading dimension of the target tensor = number of samples.
        return self.targets.shape[0]

    def __getitem__(self, idx):
        sample = (self.img_pair[idx], self.targets[idx], self.classes[idx])
        return sample

In [None]:
# Seed before building the model and loaders so weight init and the training
# shuffle order are reproducible across notebook re-runs.
torch.manual_seed(0)

net = CompareNet()
criterion = nn.CrossEntropyLoss()
# Weight of the two digit-classification losses relative to the comparison
# loss in the training objective.
mu = 1.0
optimizer = optim.Adam(net.parameters(), lr=0.001)

train_dataset = DigitPairsDataset(train_input, train_target, train_classes)
test_dataset = DigitPairsDataset(test_input, test_target, test_classes)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
# Evaluation results do not depend on sample order, so keep the test loader
# deterministic instead of shuffling it.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

def calc_accuracy(data_loader, model):
    """Return the pair-comparison accuracy of ``model`` on ``data_loader``, in percent.

    For each ``(img_pair, target, classes)`` batch, the argmax of the model's
    first output (comparison logits) is compared with ``target``. The digit
    outputs are ignored.

    Args:
        data_loader: iterable of ``(img_pair, target, classes)`` batches, e.g. a
            ``DataLoader`` over ``DigitPairsDataset``.
        model: callable returning ``(sign_logits, digit0_logits, digit1_logits)``.
            The caller is responsible for putting it in eval mode.

    Returns:
        Accuracy in [0, 100], computed over the samples actually iterated.
        Returns 0.0 when the loader yields no samples.
    """
    correct_count = 0
    total_count = 0
    # Evaluation only: skip building an autograd graph.
    with torch.no_grad():
        for img_pair, target, _classes in data_loader:
            pred_sign, _, _ = model(img_pair)
            pred = torch.argmax(pred_sign, dim=-1)
            correct_count += int(target.eq(pred).sum())
            # Count the samples seen rather than assuming a global dataset size.
            total_count += target.size(0)
    if total_count == 0:
        return 0.0
    return correct_count * 100.0 / total_count

epochs = 200
loss_arr = []
train_acc_arr = []
val_acc_arr = []

for epoch in tqdm.tqdm(range(epochs)):
    net.train()
    running_loss = 0.0
    for img_pair, target, classes in train_loader:
        optimizer.zero_grad()

        pred_sign, pred_class0, pred_class1 = net(img_pair)
        # Comparison loss plus mu-weighted digit-classification losses for
        # both images of the pair.
        loss = criterion(pred_sign, target) + mu * (
            criterion(pred_class0, classes[:, 0]) + criterion(pred_class1, classes[:, 1])
        )
        running_loss += loss.item()
        loss.backward()
        optimizer.step()

    # nn.CrossEntropyLoss() defaults to reduction='mean', so each loss.item()
    # is already a per-batch mean: average over batches, not over samples.
    running_loss /= len(train_loader)
    loss_arr.append(running_loss)

    net.eval()
    # Evaluation needs no gradients.
    with torch.no_grad():
        train_acc = calc_accuracy(train_loader, net)
        val_acc = calc_accuracy(test_loader, net)
    train_acc_arr.append(train_acc)
    val_acc_arr.append(val_acc)
    # tqdm.write prints above the progress bar instead of breaking it.
    tqdm.tqdm.write("Epoch : %d  ,   Train Accuracy : %.5f  , Validation Accuracy : %.5f , Training Loss : %.6f" % (epoch, train_acc, val_acc, running_loss))