# Siamese One Shot Learning Network

In [1]:
import os
import codecs
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import random
import datetime
import time

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image

## Set Parameters

In [3]:
DO_LEARN = True
SAVE_FREQUENCY = 2
BATCH_SIZE = 16
LR = 0.001
N_EPOCHS = 10
WEIGHT_DECAY = 0.0001

In [4]:
LOAD_MODEL_PATH = './weights/'

## Set Utils

In [5]:
def get_int(b):
    return int(codecs.encode(b, 'hex'), 16)

In [6]:
def read_image_file(path):
    with open(path, 'rb') as f:
        data = f.read()
    assert get_int(data[:4]) == 2051
    length = get_int(data[4:8])
    num_rows = get_int(data[8:12])
    num_cols = get_int(data[12:16])
    
    images = []
    parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
    
    return torch.from_numpy(parsed).view(length, num_rows, num_cols)

In [7]:
def read_label_file(path):
    with open(path, 'rb') as f:
        data = f.read()
    assert get_int(data[:4]) == 2049
    length = get_int(data[4:8])
    parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
    
    return torch.from_numpy(parsed).view(length).long()

## Configure Custom Data Loader

In [8]:
class BalancedMNISTPair(torch.utils.data.Dataset):
    """ Dataset that on each iteration will provides two pairs of MNIST images randomly. 
        One pair is of the same number (positive sample) and 
        other one is of the two different numbers (negative sample).
    """
    
    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'
    
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        
        if download:
            self.download()
            
        if not self._check_exists():
            raise RuntimeError('Dataset not found.' + ' You can use download=True to download it.')
            
        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))
            
            train_data_class = []
            train_labels_class = []
            for i in range(10):
                indices = torch.squeeze((self.train_labels == i).nonzero())
                train_data_class.append(torch.index_select(self.train_data, 0, indices))
                train_labels_class.append(torch.index_select(self.train_labels, 0, indices))

            # generated balanced pairs
            self.train_data = []
            self.train_labels = []
            lengths = [x.shape[0] for x in train_labels_class]
            for i in range(10):
                for j in range(500):
                    random_class  = random.randint(0,8)
                    if random_class >= i:
                        random_class = random_class + 1
                        
                    random_dist = random.randint(0,100)
                    
                    self.train_data.append(torch.stack([train_data_class[i][j], train_data_class[i][j+random_dist], train_data_class[random_class][j]]))
                    self.train_labels.append([1,0])
                    
            self.train_data = torch.stack(self.train_data)
            self.train_labels = torch.tensor(self.train_labels)
            
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file))
            
            test_data_class = []
            test_labels_class = []
            for i in range(10):
                indices = torch.squeeze((self.test_labels == i).nonzero())
                test_data_class.append(torch.index_select(self.test_data, 0, indices))
                test_labels_class.append(torch.index_select(self.test_labels, 0, indices))
                
            # generated balanced pairs
            self.test_data = []
            self.test_labels = []
            lengths = [x.shape[0] for x in test_labels_class]
            for i in range(10):
                for j in range(500):
                    random_class  = random.randint(0,8)
                    if random_class >= i:
                        random_class = random_class + 1
                        
                    random_dist = random.randint(0,100)
                    
                    self.test_data.append(torch.stack([test_data_class[i][j], test_data_class[i][j+random_dist], test_data_class[random_class][j]]))
                    self.test_labels.append([1,0])
                    
            self.test_data = torch.stack(self.test_data)
            self.test_labels = torch.tensor(self.test_labels)
            
    def __getitem__(self, index):
        if self.train:
            images, targets = self.train_data[index], self.train_labels[index]
        else:
            images, targets = self.test_data[index], self.test_labels[index]
            
        image_list = []
        for i in range(len(images)):
            image = Image.fromarray(images[i].numpy(), mode='L')
            if self.transform is not None:
                image = self.transform(image)
            image_list.append(image)
            
        if self.target_transform is not None:
            targets = self.target_transform(targets)
            
        return image_list, targets
    
    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)
        
    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
         os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
    
    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        tmp = 'train' if self.train is True else 'test'
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        
        return fmt_str
        
    def download(self):
        """ Download the MNIST data if it doesn't exist in processed_folder already. """
        from six.moves import urllib
        import gzip
        
        if self._check_exists():
            return
        
        try:
            os.makedirs(os.path.join(self.root, self.raw_folder))
            os.makedirs(os.path.join(self.root, self.processed_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise
        
        for url in self.urls:
            print('Downloading ' + url)
            data = urllib.request.urlopen(url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            with open(file_path, 'wb') as f:
                f.write(data.read())
            with open(file_path.replace('.gz', ''), 'wb') as out_f, \
                   gzip.GzipFile(file_path) as zip_f:
                out_f.write(zip_f.read())
            os.unlink(file_path)
            
        print('Processing...')
        
        training_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        
        with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

## Build Siamese Network Architecture

In [9]:
class SiameseNet(nn.Module):
    
    def __init__(self):
        super().__init__()
        
        self.conv1 = nn.Conv2d(1, 64, 7)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(64, 128, 5)
        self.conv3 = nn.Conv2d(128, 256, 5)
        self.linear1 = nn.Linear(2304, 512)
        self.linear2 = nn.Linear(512, 2)
        
    def forward(self, data):
        fvectors = []
        for i in range(2): # the layers in the two subnetworks share the same weights
            x = data[i]
            x = self.conv1(x)
            x = F.relu(x)
            x = self.pool1(x)
            
            x = self.conv2(x)
            x = F.relu(x)
            x = self.conv3(x)
            x = F.relu(x)
            
            x = x.view(x.shape[0], -1)
            x = self.linear1(x)
            fvectors.append(F.relu(x))
            
        distance = torch.abs(fvectors[1] - fvectors[0])
        score = self.linear2(distance)
        
        return score

In [10]:
train_loss_to_display = []
def train(model, device, train_loader, epoch, optimizer):
    model.train()
    
    print('# -----------------')
    print('# TRAINING PROCESS')
    print('# -----------------')
    for batch_idx, (data, target) in enumerate(train_loader):
        for i in range(len(data)):
            data[i] = data[i].to(device)
            
        optimizer.zero_grad()
        output_positive = model(data[:2])
        output_negative = model(data[0:3:2])
        
        target = target.type(torch.LongTensor).to(device)
        target_positive = torch.squeeze(target[:,0])
        target_negative = torch.squeeze(target[:,1])
        
        loss_positive = F.cross_entropy(output_positive, target_positive)
        loss_negative = F.cross_entropy(output_negative, target_negative)
        
        loss = loss_positive + loss_negative
        loss.backward()
        
        optimizer.step()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format(epoch,
                                                                           batch_idx*BATCH_SIZE,
                                                                           len(train_loader.dataset),
                                                                           100. * batch_idx*BATCH_SIZE / len(train_loader.dataset),
                                                                           loss.item()))
        train_loss_to_display.append(loss.item())

In [11]:
test_loss_to_display = []
def test(model, device, test_loader):
    model.eval()
    
    print('# -----------------')
    print('# TESTING PROCESS')
    print('# -----------------')
    with torch.no_grad():
        
        accurate_labels = 0
        all_labels = 0
        loss = 0
        for batch_idx, (data, target) in enumerate(test_loader):
            for i in range(len(data)):
                data[i] = data[i].to(device)
                
            output_positive = model(data[:2])
            output_negative = model(data[0:3:2])
            
            target = target.type(torch.LongTensor).to(device)
            target_positive = torch.squeeze(target[:,0])
            target_negative = torch.squeeze(target[:,1])
            
            loss_positive = F.cross_entropy(output_positive, target_positive)
            loss_negative = F.cross_entropy(output_negative, target_negative)
            
            loss = loss + loss_positive + loss_negative
            
            accurate_labels_positive = torch.sum(torch.argmax(output_positive, dim=1) == target_positive).cpu()
            accurate_labels_negative = torch.sum(torch.argmax(output_negative, dim=1) == target_negative).cpu()
            
            accurate_labels = accurate_labels + accurate_labels_positive + accurate_labels_negative
            all_labels = all_labels + len(target_positive) + len(target_negative)
            
        accuracy = 100. * accurate_labels / all_labels
        print('Test Accuracy: {}/{} ({:.3f}%) Loss: {:.6f}'.format(accurate_labels, all_labels, accuracy, loss))
        test_loss_to_display.append(loss)

In [12]:
def one_shot():
    model.eval()
    
    with torch.no_grad():
        for i in range(len(data)):
            data[i] = data[i].to(device)
            
        output = model(data)
        return torch.squeeze(torch.argmax(output, dim=1)).cpu().item()

In [13]:
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([transforms.ToTensor(), 
                                    transforms.Normalize((0.5,), (1.0,))])
    
    model = SiameseNet().to(device)
    
    if DO_LEARN:
        train_loader = torch.utils.data.DataLoader(BalancedMNISTPair('./datasets', train=True, download=True, transform=transform), 
                                                   batch_size=BATCH_SIZE,
                                                   shuffle=True)
        test_loader = torch.utils.data.DataLoader(BalancedMNISTPair('./datasets', train=False, download=True, transform=transform),
                                                  batch_size=BATCH_SIZE,
                                                  shuffle=False)
        
        optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
        for epoch in range(N_EPOCHS):
            train(model, device, train_loader, epoch, optimizer)
            test(model, device, test_loader)
            if epoch & SAVE_FREQUENCY == 0:
                if not os.path.exists('./weights'): os.makedirs('./weights')
                torch.save(model, 'siamese_{:03}.pt'.format(epoch))
                
    else:
        prediction_loader = torch.utils.data.DataLoader(BalancedMNISTPair('./datasets', train=False, download=True, transform=transform),
                                                        batch_size=1,
                                                        shuffle=True)
        model.load_state_dict(torch.load(LOAD_MODEL_PATH))
        
        data = []
        data.extend(next(iter(prediction_loader))[0][:3:2])
        
        similarity_score = one_shot(model)
        if similarity_score > 0:
            print('These two images are of the same number.')
        else:
            print('These two images are not of the same number.')

In [14]:
if __name__ == '__main__':
    main()

# -----------------
# TRAINING PROCESS
# -----------------
# -----------------
# TESTING PROCESS
# -----------------
Test Accuracy: 8496/10000 (84.000%) Loss: 228.786758
# -----------------
# TRAINING PROCESS
# -----------------


  "type " + obj.__name__ + ". It won't be checked "




KeyboardInterrupt: 

In [None]:
plt.figure(figsize=(10,5))
plt.title("Train & Test Loss of Siamese Network")
plt.plot(train_loss_to_display, label="Train Loss")
plt.plot(test_loss_to_display, label="Test Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

if not os.path.exists('./images/'): os.makedirs('./images/')
plt.savefig('./images/final_loss.png')
plt.show()

---