In [None]:
### Import all the necessary libraries
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import os
import pickle

In [None]:
# Render matplotlib figures inline in the notebook output (default in modern Jupyter; harmless to keep).
%matplotlib inline

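**Combo tensor** — each sample is a pair of feature vectors, and each vector is a 3x4 matrix flattened to 12 values:

`[Batch Size x Pairs x (3x4)]` → `[Batch Size x 2 x 12]`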

In [None]:
class SiameseNet(nn.Module):
    """Siamese network that scores whether the two vectors of a pair match.

    Both vectors of a pair are passed through the same shared tower
    (fc1 -> ReLU -> fc2 -> ReLU). The element-wise absolute difference of the
    two embeddings is then mapped to a single logit by fc3.

    Args:
        input_dim: length of each vector in the pair (default 12, a flattened 3x4 matrix).
        hidden_dim: width of the first hidden layer of the tower (default 24).
        embed_dim: size of the per-vector embedding produced by the tower (default 10).
    """

    def __init__(self, input_dim=12, hidden_dim=24, embed_dim=10):
        super().__init__()  # Python 3 zero-argument super(); no need to pass the class and self
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)  # deliberately shallow tower to start with
        self.fc3 = nn.Linear(embed_dim, 1)  # output layer: one logit per pair
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, combo_tensor):
        """Return one raw logit per pair, shape (batch, 1).

        Expects `combo_tensor` indexed as (batch, pair_position, input_dim), with
        pair positions 0 and 1 holding the two vectors to compare.
        No sigmoid is applied here: BCEWithLogitsLoss applies it internally during training.
        """
        embedding_0 = self.score_one_vector(combo_tensor[:, 0])  # first vector of each pair
        embedding_1 = self.score_one_vector(combo_tensor[:, 1])  # second vector of each pair
        # |difference| makes the score independent of pair order; a Euclidean
        # distance between the embeddings is an alternative if this trains poorly.
        diff = torch.abs(embedding_0 - embedding_1)
        return self.fc3(diff)

    def score_one_vector(self, input_vec):
        """Embed a batch of single vectors with the shared tower; returns (batch, embed_dim)."""
        hidden = self.relu(self.fc1(input_vec))
        return self.relu(self.fc2(hidden))

    def eval_forward(self, combo_tensor):
        """Evaluation helper: sigmoid of the forward-pass logits, i.e. match probabilities in (0, 1)."""
        # Call the module (not .forward) so any registered nn.Module hooks still run.
        return self.sigmoid(self(combo_tensor))

**Instantiate the model**

In [None]:
# Seed torch's global RNG so the random weight initialisation (and any torch.randn
# scratch data drawn afterwards) is reproducible on Restart & Run All.
torch.manual_seed(0)
my_model = SiameseNet()

# Scratch code: check that the forward pass works on random data

In [None]:
t1 = torch.randn(100, 2, 12)  # 100 random pairs of 12-value vectors: (batch, pair, features)

In [None]:
# Shape of the first vector of every pair; expected torch.Size([100, 12]).
t1[:, 0, :].shape

In [None]:
# Batches of 10 pairs from the random scratch tensor.
# shuffle=False keeps the batch order identical on every pass, so results collected
# in separate loops over `dl` (raw logits vs. eval_forward probabilities) line up
# batch-for-batch and can be compared directly.
# num_workers=0 loads in the main process; worker processes are pure overhead for
# a tiny in-memory tensor.
dl = DataLoader(t1, batch_size=10, shuffle=False, num_workers=0)

In [None]:
# Pull a single batch from the loader to inspect it.
batch = next(iter(dl))

In [None]:
# First pair in the batch: one row per vector of the pair, shape (2, 12).
batch[0]

In [None]:
# Raw logits (no sigmoid) for every scratch batch, without tracking gradients.
with torch.no_grad():
    my_model.eval()
    # Call the module rather than .forward() so nn.Module hooks run; the
    # comprehension also avoids leaking loop variables into the notebook namespace.
    train_pairs = [my_model(batch_x) for batch_x in dl]

In [None]:
# Sigmoid probabilities for every scratch batch via the model's evaluation helper.
my_model.eval()
with torch.no_grad():
    output_pairs = [my_model.eval_forward(batch_x) for batch_x in dl]

In [None]:
# Sigmoid of the first raw logit, to compare against the eval_forward probabilities.
torch.sigmoid(train_pairs[0][0])

In [None]:
# Hard 0/1 predictions for the first batch.
torch.round(output_pairs[0])

In [None]:
# Hand-made 0/1 labels for a batch of 10 (1 presumably = match), positives at indices 3, 6 and 9.
t2 = torch.zeros(10, dtype=torch.long)
t2[[3, 6, 9]] = 1

In [None]:
# Labels as a (10, 1) column, matching the model's (batch, 1) output shape.
t2[:, None]

In [None]:
# Number of correct 0/1 predictions in the first batch, thresholding probabilities at 0.5.
(
    (output_pairs[0] > 0.5)
    .float()
    .eq(t2.reshape(t2.shape[0], -1))
    .sum()
    .item()
)

In [None]:
# Raw logits (before sigmoid) of the first scratch batch.
train_pairs[0]

**Instantiate the loss function and optimizer**

In [None]:
# BCEWithLogitsLoss applies the sigmoid itself, which is why SiameseNet.forward returns raw logits.
loss_func = nn.BCEWithLogitsLoss()
optimizer_func = optim.Adam(my_model.parameters(), lr=1e-3)

In [None]:
def train(model, training_loader, loss, optimizer,
          training_device, num_epochs,
          print_interval):
    """Train `model` for `num_epochs` passes over `training_loader`.

    Args:
        model: network to optimize; it is switched to training mode here.
        training_loader: DataLoader yielding (data, target) batches. Its len() and
            .dataset are used for progress printing.
        loss: loss callable applied as loss(output, target), e.g. BCEWithLogitsLoss.
        optimizer: optimizer over the model's parameters.
        training_device: device every batch is moved to.
        num_epochs: number of passes over the data.
        print_interval: print progress every `print_interval` batches.

    Returns:
        List with the mean batch loss of each epoch (NaN for an epoch with no batches).
    """
    epoch_loss = []
    model.train()

    for epoch in range(num_epochs):
        loss_counter = 0.0
        num_batches = 0

        for batch_idx, (data, target) in enumerate(training_loader):
            data, target = data.to(training_device), target.to(training_device)

            optimizer.zero_grad()
            output = model(data)
            # target is a 0/1 match label per pair. BCEWithLogitsLoss needs a float
            # target with the same shape as the logits, so cast it and reshape it,
            # e.g. (batch,) -> (batch, 1).
            target = target.to(output.dtype).reshape_as(output)

            loss_value = loss(output, target)
            loss_value.backward()
            optimizer.step()

            batch_loss = loss_value.item()
            loss_counter += batch_loss
            num_batches += 1

            if batch_idx % print_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(training_loader.dataset),
                    100. * batch_idx / len(training_loader), batch_loss))

        # Average over the number of batches seen (batch_idx is 0-based, so it is
        # one less than the batch count and 0 for a single-batch loader).
        epoch_loss.append(loss_counter / num_batches if num_batches else float('nan'))
    return epoch_loss

def test(model, test_loader, loss,
          training_device):
    
    model.eval()
    
    batch_loss_array = []
    
    percent_correct_array = []
    
#     percent_correct_negative = []
    
    
    with torch.no_grad():
        
        for batch_idx, (data, target) in enumerate(training_loader):
            
            data, target = data.to(training_device), target.to(training_device)
            # target should be a 0/1 scalar indicating match/no match
            
            output = model(data)
#             output_negative = model(data[:, 1])
            # Run forward and eval forward pass sep to get loss and cats
            loss_value = loss(output, target[:, 0])
#             loss_negative = loss(output_negative, target[:, 1])
            
#             loss = loss_positive + loss_negative
            
            batch_loss_array.append(loss_value.item())
            
            num_correct = output_positive.round().eq(target).sum()
            percent_correct = (num_correct / num_correct.shape[0]).item()
            percent_correct_array.append(percent_correct)
            
            print(f'Batch Number: {batch_idx} - Percent Correct: {percent_correct}\n')
            
#             if batch_idx % print_interval == 0:
#                 print('Test Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
#                     epoch, batch_idx * len(data), len(training_loader.dataset),
#                     100. * batch_idx / len(training_loader), loss.item()))
                
    return batch_loss_array, percent_correct