In [None]:
from skimage import io
import numpy as np
import matplotlib.pyplot as plt

In [None]:
im_size = 28
n_channels = 1

# Default root folder of the MNIST PNG export; pass `data_root` to get_batch to override it.
_DEFAULT_DATA_ROOT = '/home/cicconet/Development/MachineLearning/MNIST'

# phase -> (sub-folder under the data root, number of categories, images per category)
_PHASES = {
    'train': ('Train', 10, 1000),
    'test': ('Test', 10, 100),
}


def _load_image(data_path, category, index):
    """Read one image file, convert it to float64 divided by 255, and prepend a channel axis."""
    path = '%s/%d/Image%05d.png' % (data_path, category, index)
    return np.expand_dims(io.imread(path).astype('double')/255, 0)


def get_batch(n, phase='train', data_root=_DEFAULT_DATA_ROOT):
    """Build a batch of `n` image pairs for training or testing the siamese network.

    Even positions hold two images drawn from the same category (label 0);
    odd positions hold two images drawn from different categories (label 1).
    When `n` is odd, the final position is a same-category pair, so no slot
    is ever left as an all-zero placeholder.

    Parameters
    ----------
    n : int
        Number of pairs in the batch.
    phase : {'train', 'test'}
        Which split to sample from.
    data_root : str
        Folder containing the 'Train' and 'Test' sub-folders. Images are read
        from '<data_root>/<split>/<category>/Image<5-digit index>.png'.

    Returns
    -------
    a_batch, b_batch : numpy.ndarray
        Arrays of shape (n, n_channels, im_size, im_size) with the first and
        second image of every pair.
    y_batch : numpy.ndarray
        Array of shape (n,) with 0 for 'same category' and 1 for 'different'.

    Raises
    ------
    ValueError
        If `phase` is not one of the known splits.
    """
    if phase not in _PHASES:
        raise ValueError("unknown phase %r; expected one of %s" % (phase, sorted(_PHASES)))
    subdir, n_categories, n_imgs_per_category = _PHASES[phase]
    data_path = '%s/%s' % (data_root, subdir)

    a_batch = np.zeros((n, n_channels, im_size, im_size))
    b_batch = np.zeros((n, n_channels, im_size, im_size))
    y_batch = np.zeros((n, ))  # zeros already encode the 'same' label
    for k in range(n):
        if k % 2 == 0:
            # same category
            category_a = category_b = np.random.randint(n_categories)
        else:
            # different categories: redraw the second one until it differs
            category_a = np.random.randint(n_categories)
            category_b = category_a
            while category_b == category_a:
                category_b = np.random.randint(n_categories)
            y_batch[k] = 1
        ind_a = np.random.randint(n_imgs_per_category)
        ind_b = np.random.randint(n_imgs_per_category)
        a_batch[k] = _load_image(data_path, category_a, ind_a)
        b_batch[k] = _load_image(data_path, category_b, ind_b)

    return a_batch, b_batch, y_batch

In [None]:
n = 8
A, B, Y = get_batch(n)

# Two-row preview of one batch: the top row shows the first image of each pair
# (titled with its same/diff label), the bottom row shows the second image.
fig = plt.figure(figsize=(n, 2))
for col, (pair_a, pair_b, label) in enumerate(zip(A, B, Y)):
    top_ax = fig.add_subplot(2, n, col + 1)
    top_ax.imshow(pair_a.mean(axis=0), cmap='gray')  # average over the channel axis
    top_ax.axis('off')
    top_ax.set_title('diff' if label != 0 else 'same')

    bottom_ax = fig.add_subplot(2, n, n + col + 1)
    bottom_ax.imshow(pair_b.mean(axis=0), cmap='gray')
    bottom_ax.axis('off')
plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

embed_dim = 64

class CNNBranch(nn.Module):
    """Convolutional branch that maps an image batch to `embed_dim`-dimensional embeddings.

    Layout: conv(5x5, 6 maps) -> ReLU -> 2x2 max-pool -> conv(5x5, 16 maps)
    -> ReLU -> 2x2 max-pool -> linear(120) -> ReLU -> linear(embed_dim).

    Parameters
    ----------
    in_channels : int, optional
        Number of input channels. Defaults to the notebook-level `n_channels`.
    input_size : int, optional
        Height/width of the (square) input images. Defaults to 28, which gives
        the 16 * 4 * 4 flattened feature size.

    Raises
    ------
    ValueError
        If `input_size` is too small to survive both conv/pool stages.
    """

    def __init__(self, in_channels=None, input_size=28):
        super().__init__()
        if in_channels is None:
            in_channels = n_channels

        # Side length after conv(5) -> pool(2) -> conv(5) -> pool(2), no padding.
        feat_side = ((input_size - 4) // 2 - 4) // 2
        if feat_side < 1:
            raise ValueError("input_size=%r is too small for this architecture" % (input_size,))
        self.flat_dim = 16 * feat_side * feat_side

        self.conv1 = nn.Conv2d(in_channels, 6, 5) # n chan in, n chan out, kernel size
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(self.flat_dim, 120)
        self.fc2 = nn.Linear(120, embed_dim)

    def forward(self, x):
        """Return the embeddings for a batch `x` of shape (batch, channels, height, width)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch axis. Unlike view(-1, k), this can
        # never silently change the batch size when the spatial size is wrong;
        # the mismatch surfaces as a shape error in fc1 instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Shape check: push one batch through a standalone branch and print the
# input tensor shapes followed by the embedding shapes.
cnn_branch = CNNBranch()
A, B, Y = get_batch(n)
A = torch.from_numpy(A).to(torch.float32)
B = torch.from_numpy(B).to(torch.float32)
Y = torch.from_numpy(Y).to(torch.int64)
print(A.size(), B.size(), Y.size())
print(cnn_branch(A).size(), cnn_branch(B).size())

In [None]:
n_classes = 2

class SiameseNet(nn.Module):
    """Siamese classifier: one shared CNNBranch embeds both inputs, and a linear
    layer turns the element-wise absolute difference of the two embeddings into
    `n_classes` raw scores (no softmax is applied here).
    """

    def __init__(self):
        super().__init__()
        # A single branch instance, so both inputs go through the same weights.
        self.cnn_branch = CNNBranch()
        self.fc = nn.Linear(embed_dim, n_classes)

    def forward(self, x_a, x_b):
        """Return the raw class scores for the pair batches `x_a` and `x_b`."""
        embed_a = self.cnn_branch(x_a)
        embed_b = self.cnn_branch(x_b)
        distance = (embed_a - embed_b).abs()
        scores = self.fc(distance)
        return scores
    
# Build the model and check that its output lines up with the label tensor.
siamese_net = SiameseNet()
prediction = siamese_net(x_a=A, x_b=B)
print(prediction.size(), Y.size())

In [None]:
import torch.optim as optim

# Cross-entropy over the model's raw scores, optimised with SGD + momentum
# across every parameter of the siamese network.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=siamese_net.parameters(), momentum=0.9, lr=0.001)

In [None]:
# Use the first CUDA device when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
running_loss = 0.0
n_steps =10000

# Move the model and the loss module to the selected device before training.
siamese_net.to(device)
criterion.to(device)

for step in range(1, n_steps + 1):
    # Draw a fresh random batch of pairs and convert it to tensors on the device.
    A, B, Y = get_batch(n)
    A = torch.from_numpy(A).to(device=device, dtype=torch.float32)
    B = torch.from_numpy(B).to(device=device, dtype=torch.float32)
    Y = torch.from_numpy(Y).to(device=device, dtype=torch.int64)

    # One optimisation step: clear gradients, forward, loss, backward, update.
    optimizer.zero_grad()
    prediction = siamese_net(A, B)
    loss = criterion(prediction, Y)
    loss.backward()
    optimizer.step()

    # Exponentially smoothed loss (equal weight on the history and the new value).
    running_loss = 0.5*running_loss+0.5*loss.item()
    if step % 1000 == 0:
        print('step', step, 'loss', running_loss)

In [None]:
correct = 0
total = 0

n_batches = 100
# Gradients are not needed for evaluation.
with torch.no_grad():
    for _ in range(n_batches):
        A, B, Y = get_batch(n, 'test')
        A = torch.from_numpy(A).to(device=device, dtype=torch.float32)
        B = torch.from_numpy(B).to(device=device, dtype=torch.float32)
        Y = torch.from_numpy(Y).to(device=device, dtype=torch.int64)

        prediction = siamese_net(A, B)
        # The predicted label is the index of the largest score in each row.
        predicted_labels = torch.max(prediction, 1)[1]
        total += Y.shape[0]
        correct += torch.eq(predicted_labels, Y).sum().item()

print('accuracy', correct / total)