In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import torchvision.utils
import numpy as np
import random
from PIL import Image
import torch
from torch.autograd import Variable
import PIL.ImageOps    
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

In [2]:
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.cnn1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(1, 4, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(4),
            nn.Dropout2d(p=.2),
            
            nn.ReflectionPad2d(1),
            nn.Conv2d(4, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
            nn.Dropout2d(p=.2),

            nn.ReflectionPad2d(1),
            nn.Conv2d(8, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
            nn.Dropout2d(p=.2),
        )

        self.fc1 = nn.Sequential(
            nn.Linear(8*100*100, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 5)
        )

    def forward_once(self, x):
        output = self.cnn1(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2

In [24]:
class ContrastiveLoss(nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output, label):
        output1, output2 = output
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive


In [4]:
class Config():
    training_dir = "./data/faces/training/"
    testing_dir = "./data/faces/testing/"
    train_batch_size = 64
    train_number_epochs = 10

In [5]:
class SiameseNetworkDataset(Dataset):
    
    def __init__(self, root):
        self.class_dict = {}
        self.image_folder = datasets.ImageFolder(root)
        transform = transforms.Compose([
                        transforms.Resize((100,100)),
                        transforms.ToTensor()
                    ])
        self.all_images = []
        
        # 1. separate paths by index, open image, and transform
        for path, class_index in self.image_folder.imgs:
            if class_index not in self.class_dict:
                self.class_dict[class_index] = []
            image = transform(
                        Image.open(path).convert('L') # L is grayscale mode
                    )
            self.class_dict[class_index].append(image)
            self.all_images.append((image, class_index))
            
        # 2. prepare dict of other indexes that are not current index (for random)
        indexes = [self.image_folder.class_to_idx[class_name]
                   for class_name in self.image_folder.classes]
        self.not_index = {}
        for index in indexes:
            new_indexes = indexes[:]
            new_indexes.remove(index)
            self.not_index[index] = new_indexes
        
    def get_images_by_class_index(self, class_index):
        return self.class_dict[class_index]
    
    def __getitem__(self, index):
        image_0, class_index_0 = random.choice(self.all_images)
        is_same_class = random.randint(0, 1)
        if is_same_class:
            class_index_1 = class_index_0
        else:
            # to make sure getting images from same class 50%
            class_index_1 = random.choice(self.not_index[class_index_0])
        image_1 = random.choice(self.get_images_by_class_index(class_index_1))
        
        return image_0, image_1, torch.tensor(is_same_class, dtype=torch.float32)
    def __len__(self):
        return len(self.image_folder.imgs)

### Train

In [36]:
from exitai.learner import Learner
train_dataset = SiameseNetworkDataset(root=Config.training_dir)
test_dataset = SiameseNetworkDataset(root=Config.testing_dir)
train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=8, batch_size=Config.train_batch_size)
test_dataloader = DataLoader(test_dataset, shuffle=True, num_workers=8, batch_size=512)

net = SiameseNetwork()
criterion = ContrastiveLoss()
learner = Learner(train_dataloader, test_dataloader, net, criterion)

In [35]:
learner.lr_find(start_lr=1e-7, end_lr=1, num_it=100)

Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/epinyoanun/miniconda3/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/Users/epinyoanun/miniconda3/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/Users/epinyoanun/miniconda3/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/Users/epinyoanun/miniconda3/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
  File "/Users/epinyoanun/miniconda3/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/Users/epinyoanun/miniconda3/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/Users/epinyoanun/miniconda3/lib/python3.6/multiprocessing/connectio

KeyboardInterrupt: 

In [37]:
def eval_func(loss, output, target):
    output1, output2 = output
    euclidean_distance = F.pairwise_distance(output1, output2)
    return (euclidean_distance<1).sum()
learner = Learner(train_dataloader, test_dataloader, net, criterion)
learner.fit(0.0005, eval_func=eval_func)

---- epoch:0 ------
euclidean_distance: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.uint8)
target: tensor([1., 1., 0., 1., 1., 0., 1., 0., 1., 0., 1., 0., 0., 1., 1., 1., 0., 0.,
        1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1., 0., 1., 1., 0., 0., 1., 1.,
        0., 0., 1., 0., 1., 0., 0., 1., 1., 1., 1., 0., 1., 1., 1., 0., 0., 1.,
        0., 1., 0., 0., 1., 1., 1., 1., 1., 0.])
euclidean_distance: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.uint8)
target: tensor([0., 0., 1., 1., 0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.,
        0., 1., 0., 0., 1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 1., 0., 0., 0.,
   

Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/epinyoanun/miniconda3/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/Users/epinyoanun/miniconda3/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/Users/epinyoanun/miniconda3/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/Users/epinyoanun/miniconda3/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/Users/epinyoanun/miniconda3/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/Users/epinyoanun/miniconda3/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/Users/epinyoanun/miniconda3/lib/

KeyboardInterrupt: 

In [53]:
for i in range(10):
    threshold = (i+1)/11
    print('threshold:',threshold)
    def eval_func(loss, output, target):
        output1, output2 = output
        euclidean_distance = F.pairwise_distance(output1, output2)
        predict = (euclidean_distance > 0.3).float()
        return predict.eq(target).sum().item()
    learner.predict(eval_func)

threshold: 0.09090909090909091
num_data: 30
   [test] Average loss: 7.5510, acc: 43.33%
threshold: 0.18181818181818182
num_data: 30
   [test] Average loss: 8.4382, acc: 36.67%
threshold: 0.2727272727272727
num_data: 30
   [test] Average loss: 6.6619, acc: 50.00%
threshold: 0.36363636363636365
num_data: 30
   [test] Average loss: 9.3270, acc: 30.00%
threshold: 0.45454545454545453
num_data: 30
   [test] Average loss: 7.5510, acc: 43.33%
threshold: 0.5454545454545454
num_data: 30
   [test] Average loss: 6.2188, acc: 53.33%
threshold: 0.6363636363636364
num_data: 30
   [test] Average loss: 7.1059, acc: 46.67%
threshold: 0.7272727272727273
num_data: 30
   [test] Average loss: 7.1071, acc: 46.67%
threshold: 0.8181818181818182
num_data: 30
   [test] Average loss: 9.3267, acc: 30.00%
threshold: 0.9090909090909091
num_data: 30
   [test] Average loss: 6.2193, acc: 53.33%
