# Siamese One-Shot Learning Network

In [5]:
import codecs
import datetime
import os
import random
import struct
import time

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

In [9]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image

## Set Parameters

In [4]:
# --- Run switches ---------------------------------------------------------
# NOTE(review): presumably selects whether to train; confirm where it is read.
DO_LEARN: bool = True
# NOTE(review): presumably a checkpoint interval (in epochs); confirm in train().
SAVE_FREQUENCY: int = 2

# --- Optimisation hyperparameters -----------------------------------------
BATCH_SIZE: int = 16
N_EPOCHS: int = 10
LR: float = 0.001          # learning rate
WEIGHT_DECAY: float = 1e-4

## Utility Functions

In [6]:
def get_int(b):
    """Decode *b* as a big-endian unsigned integer.

    Args:
        b: bytes-like object, e.g. a 4-byte field of an IDX file header.

    Returns:
        The non-negative ``int`` encoded by *b* (most significant byte
        first); ``0`` for empty input.
    """
    # int.from_bytes decodes directly, without a detour through a hex string.
    return int.from_bytes(b, 'big')

In [8]:
def read_image_file(path):
    """Read an IDX3 image file (magic number 2051) into a tensor.

    The file starts with a 16-byte header of four big-endian unsigned
    32-bit integers: magic number, image count, rows, columns. The rest of
    the file is one unsigned byte per pixel.

    Args:
        path: path to the uncompressed IDX3 file.

    Returns:
        ``torch.uint8`` tensor of shape ``(count, rows, cols)`` that owns
        its own writable memory.

    Raises:
        ValueError: if the header is truncated, the magic number is not
            2051, or the pixel payload size does not match the header.
    """
    with open(path, 'rb') as f:
        data = f.read()

    header_size = 16
    if len(data) < header_size:
        raise ValueError(
            f'{path}: {len(data)} bytes is too short for an IDX3 header')
    magic, length, num_rows, num_cols = struct.unpack_from('>IIII', data, 0)
    # Explicit checks instead of assert: asserts vanish under `python -O`.
    if magic != 2051:
        raise ValueError(
            f'{path}: bad magic number {magic}, expected 2051 (IDX3 images)')
    expected = length * num_rows * num_cols
    actual = len(data) - header_size
    if actual != expected:
        raise ValueError(
            f'{path}: header promises {expected} pixel bytes, found {actual}')

    # np.frombuffer yields a read-only view of the immutable bytes object;
    # copy so the tensor owns writable memory (torch.from_numpy warns on,
    # and would alias, non-writable arrays otherwise).
    parsed = np.frombuffer(
        data, dtype=np.uint8, count=expected, offset=header_size).copy()
    return torch.from_numpy(parsed).view(length, num_rows, num_cols)

In [7]:
def read_label_file(path):
    """Read an IDX1 label file (magic number 2049) into a tensor.

    The file starts with an 8-byte header of two big-endian unsigned
    32-bit integers: magic number and label count. The rest of the file
    is one unsigned byte per label.

    Args:
        path: path to the uncompressed IDX1 file.

    Returns:
        ``torch.int64`` (long) tensor of shape ``(count,)``.

    Raises:
        ValueError: if the header is truncated, the magic number is not
            2049, or the payload size does not match the header.
    """
    with open(path, 'rb') as f:
        data = f.read()

    header_size = 8
    if len(data) < header_size:
        raise ValueError(
            f'{path}: {len(data)} bytes is too short for an IDX1 header')
    magic, length = struct.unpack_from('>II', data, 0)
    # Explicit checks instead of assert: asserts vanish under `python -O`.
    if magic != 2049:
        raise ValueError(
            f'{path}: bad magic number {magic}, expected 2049 (IDX1 labels)')
    actual = len(data) - header_size
    if actual != length:
        raise ValueError(
            f'{path}: header promises {length} labels, found {actual} bytes')

    parsed = np.frombuffer(data, dtype=np.uint8, count=length,
                           offset=header_size)
    # astype copies the read-only bytes view into a fresh, writable int64
    # array, so torch.from_numpy neither warns nor aliases the bytes buffer.
    return torch.from_numpy(parsed.astype(np.int64))

In [10]:
class BalancedMNISTPair(torch.utils.data.Dataset):
    """Placeholder ``Dataset`` of MNIST image pairs; no implementation here.

    NOTE(review): the name suggests pairs balanced between same-class and
    different-class examples, but the body is empty in this section, so
    confirm the intended contract before relying on it.
    """

In [11]:
class SiameseNet(nn.Module):
    """Siamese network that scores a pair of single-channel image batches.

    Both members of the pair pass through the same convolutional tower,
    so the two subnetworks share one set of weights. The element-wise
    absolute difference of their 512-d embeddings is mapped to 2 scores.

    ``linear1`` expects 2304 = 256 * 3 * 3 features, i.e. a 3x3 map after
    ``conv3``; 28x28 inputs (e.g. MNIST digits) produce exactly that.
    """

    def __init__(self):
        super().__init__()
        # Construction order fixes RNG consumption during weight init and
        # the attribute names fix the state_dict keys; keep both stable.
        self.conv1 = nn.Conv2d(1, 64, 7)      # 1 -> 64 channels, 7x7 kernel
        self.pool1 = nn.MaxPool2d(2)          # halves spatial size
        self.conv2 = nn.Conv2d(64, 128, 5)
        self.conv3 = nn.Conv2d(128, 256, 5)
        self.linear1 = nn.Linear(2304, 512)   # flattened 256x3x3 -> embedding
        self.linear2 = nn.Linear(512, 2)      # |difference| -> 2 scores

    def _embed(self, images):
        """Map one image batch to ReLU-activated 512-d embeddings."""
        hidden = self.pool1(F.relu(self.conv1(images)))
        hidden = F.relu(self.conv2(hidden))
        hidden = F.relu(self.conv3(hidden))
        flat = hidden.view(hidden.size(0), -1)  # keep batch dim, flatten rest
        return F.relu(self.linear1(flat))

    def forward(self, data):
        """Score a pair of image batches.

        Args:
            data: indexable pair; ``data[0]`` and ``data[1]`` are image
                batches (presumably ``(batch, 1, 28, 28)`` -- see class doc).

        Returns:
            Tensor of shape ``(batch, 2)``.
        """
        first = self._embed(data[0])
        second = self._embed(data[1])
        return self.linear2((second - first).abs())

In [12]:
def train():
    """Training routine placeholder; the body is not implemented here."""

In [13]:
def test():
    """Evaluation routine placeholder; the body is not implemented here."""

In [15]:
def one_shot():
    """One-shot routine placeholder; the body is not implemented here.

    NOTE(review): presumably meant for one-shot evaluation with SiameseNet;
    confirm once an implementation exists.
    """

---