# Siamese One-Shot-Learning Network, AT&T Faces

In [1]:
import os
import codecs
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
import random
import datetime
import time

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image

## Set Configs

In [3]:
# Root directories of the training and test image splits.
TRAIN_DIR, TEST_DIR = './datasets/faces/training/', './datasets/faces/testing/'

# Training hyper-parameters.
BATCH_SIZE, N_EPOCHS = 64, 100

## Set Helpers

In [4]:
def show_img(img, text=None, should_save=False):
    """Display a channel-first image with an optional caption.

    Args:
        img: Rank-3 image laid out as (channels, height, width). Either an
            object with a ``.numpy()`` method (e.g. a CPU torch tensor) or
            anything ``np.asarray`` accepts.
        text: Optional caption, drawn at data coordinates (75, 8) inside a
            semi-transparent white box.
        should_save: Accepted for backward compatibility; currently unused.
    """
    # Support both tensors (``.numpy()``) and plain array-likes.
    array = img.numpy() if hasattr(img, 'numpy') else np.asarray(img)
    plt.axis('off')
    if text:
        # 'alpha' (not 'aplha'): matplotlib rejects unknown patch properties.
        plt.text(75, 8, text, style='italic', fontweight='bold',
                 bbox={'facecolor': 'white', 'alpha': 0.8, 'pad': 10})
    # imshow expects (H, W, C), so move the leading channel axis last.
    plt.imshow(np.transpose(array, (1, 2, 0)))
    plt.show()

In [5]:
def show_plot(iteration, loss):
    """Draw ``loss`` as a line over ``iteration`` and display the figure."""
    x_values, y_values = iteration, loss
    plt.plot(x_values, y_values)
    plt.show()

## Configure Custom Siamese Dataset

In [6]:
class SiameseNetworkDataset(torch.utils.data.Dataset):
    """Random image pairs with a same/different label for Siamese training.

    Each item samples an anchor uniformly from ``dataset.imgs``. A fair coin
    flip then decides whether the partner must share the anchor's class or
    come from a different class, so roughly half of the pairs are same-class.
    The ``index`` passed to ``__getitem__`` is ignored; sampling is random.

    Args:
        dataset: Object exposing ``imgs``, a sequence of ``(path, label)``
            tuples. It must contain at least two distinct labels, otherwise
            a different-class partner can never be found (see
            ``_pick_partner``).
        transform: Optional callable applied to each grayscale PIL image.
        should_invert: If True, invert pixel intensities before ``transform``.
    """

    def __init__(self, dataset, transform=None, should_invert=True):
        self.dataset = dataset
        self.transform = transform
        self.should_invert = should_invert

    def _pick_partner(self, anchor, same_class):
        """Rejection-sample an ``imgs`` entry relative to ``anchor``.

        Returns an entry whose label equals ``anchor[1]`` when ``same_class``
        is true, and differs from it otherwise. Loops forever if no entry
        satisfies the condition (e.g. a single-class dataset when
        ``same_class`` is false).
        """
        anchor_label = anchor[1]
        while True:
            candidate = random.choice(self.dataset.imgs)
            if (candidate[1] == anchor_label) == same_class:
                return candidate

    def _load_image(self, path):
        """Open ``path`` as grayscale, optionally invert it, then transform it."""
        # Close the file handle promptly; ``convert`` returns a new image
        # that no longer depends on the opened file.
        with Image.open(path) as raw:
            img = raw.convert('L')
        if self.should_invert:
            img = ImageOps.invert(img)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __getitem__(self, index):
        anchor = random.choice(self.dataset.imgs)
        # Coin flip so that approximately 50% of pairs are same-class.
        same_class = bool(random.randint(0, 1))
        partner = self._pick_partner(anchor, same_class)

        imageA = self._load_image(anchor[0])
        imageB = self._load_image(partner[0])

        # float32 tensor of shape (1,): 0.0 for same-class, 1.0 for different.
        label = torch.from_numpy(
            np.array([int(anchor[1] != partner[1])], dtype=np.float32))
        return imageA, imageB, label

    def __len__(self):
        return len(self.dataset.imgs)

---