## Prepare dataset

In [1]:
import os
from PIL import Image

from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torchvision import transforms, utils, models
import torch

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# https://www.kaggle.com/pinocookie/pytorch-dataset-and-dataloader
import random

class DatasetPix2Pix(Dataset):
    """Triplet dataset over side-by-side pix2pix images (e.g. edges2shoes).

    Each file is expected to contain two square tiles of ``tile_size`` pixels
    placed side by side. The left tile is returned as the anchor and the right
    tile as the positive. The negative is the right tile of a different,
    uniformly chosen file.

    Args:
        data_path: directory containing the paired image files.
        tile_size: side length in pixels of each square tile (default 256).
    """

    def __init__(self, data_path, tile_size=256):
        # Sort so a given index maps to the same file on every run; os.listdir order is arbitrary.
        self.filenames = [os.path.join(data_path, filename)
                          for filename in sorted(os.listdir(data_path))]
        self.tile_size = tile_size

        # Mean/std are the standard ImageNet statistics used by torchvision's pretrained models.
        self.data_transform = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def _random_other_index(self, index):
        """Return a uniformly random file index different from ``index``.

        Draws from the n-1 remaining positions and steps over ``index``. This
        avoids copying the whole filename list on every sample.

        Raises:
            ValueError: if the dataset holds fewer than two files.
            IndexError: if ``index`` is out of range.
        """
        n = len(self.filenames)
        if n < 2:
            raise ValueError('DatasetPix2Pix needs at least two images to sample a negative')
        index = range(n)[index]  # normalise negative indices; IndexError if out of range
        other = random.randrange(n - 1)
        return other + 1 if other >= index else other

    def __getitem__(self, index):
        """Return transformed ``(anchor, positive, negative)`` tensors for ``index``."""
        t = self.tile_size
        left_box = (0, 0, t, t)
        right_box = (t, 0, 2 * t, t)

        # `with` closes file handles deterministically.
        # convert('RGB') guarantees three channels for the 3-channel Normalize above.
        with Image.open(self.filenames[index]) as pair:
            pair = pair.convert('RGB')
            anchor = pair.crop(left_box)
            positive = pair.crop(right_box)

        with Image.open(self.filenames[self._random_other_index(index)]) as other:
            negative = other.convert('RGB').crop(right_box)

        return self.data_transform(anchor), \
               self.data_transform(positive), \
               self.data_transform(negative)

    def __len__(self):
        """Number of image files in the dataset."""
        return len(self.filenames)

    def __add__(self, other):
        """Concatenate with another dataset via ``ConcatDataset``."""
        return ConcatDataset([self, other])

In [3]:
# Build the triplet training set from the paired edges2shoes images and wrap it
# in a shuffled loader that yields (anchor, positive, negative) batches of 8.
data = DatasetPix2Pix(data_path='edges2shoes/train/')
dataloader = DataLoader(
    data,
    batch_size=8,
    shuffle=True,
    num_workers=4,
)

## Models

In [4]:
import torchvision

class Vectorizer(torch.nn.Module):
    """Embed inputs with a backbone followed by a two-layer MLP projection head.

    Args:
        resnet: backbone module whose output has ``in_features`` features in
            its last dimension (1000 for a torchvision ResNet classifier).
        enc_size: dimensionality of the output embedding.
        in_features: size of the backbone output fed to the head (default 1000).
        hidden_size: width of the head's hidden layer (default 512).
    """

    def __init__(self, resnet, enc_size=256, in_features=1000, hidden_size=512):
        # Zero-argument super() is subclass-safe.
        # super(self.__class__, self) recurses forever in any subclass.
        super().__init__()

        self.resnet = resnet

        self.code = torch.nn.Sequential(
            torch.nn.Linear(in_features, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, enc_size),
        )

    def forward(self, x):
        """Return the ``enc_size``-dimensional embedding of ``x``."""
        return self.code(self.resnet(x))

In [5]:
class Siamize(torch.nn.Module):
    """Triplet network: one backbone shared by two separate projection heads.

    ``vec1`` embeds the anchor. ``vec2`` embeds both the positive and the
    negative, so the two candidates always pass through identical weights.

    Args:
        enc_size: dimensionality of each embedding.
        resnet: optional backbone module, which must output 1000 features per
            sample. Defaults to a pretrained torchvision ResNet-34.
    """

    def __init__(self, enc_size=256, resnet=None):
        # Zero-argument super() is subclass-safe.
        # super(self.__class__, self) recurses forever in any subclass.
        super().__init__()

        if resnet is None:
            # NOTE(review): newer torchvision deprecates `pretrained=` in favour of `weights=`.
            resnet = torchvision.models.resnet34(pretrained=True)
        self.resnet = resnet

        # The same backbone instance is passed to both heads, so its weights are shared.
        self.vec1 = Vectorizer(self.resnet, enc_size)
        self.vec2 = Vectorizer(self.resnet, enc_size)

    def forward(self, anchor, positive, negative):
        """Return ``(anchor_embedding, positive_embedding, negative_embedding)``."""
        return self.vec1(anchor), self.vec2(positive), self.vec2(negative)

In [6]:
import torch.nn.functional as F

def triplet_loss(enc_anch, enc_pos, enc_neg, alpha=1):
    """Hinge triplet loss with Euclidean distances, averaged over the batch.

    Computes ``mean(max(||a - p|| - ||a - n|| + alpha, 0))``. The hinge is
    essential. Without it, the loss is unbounded below and can keep decreasing
    just by pushing negatives ever further away, even after every triplet
    already satisfies the margin.

    Args:
        enc_anch, enc_pos, enc_neg: embedding batches of equal shape,
            presumably (batch, dim) as produced by the model.
        alpha: margin by which the negative must be further than the positive.

    Returns:
        A scalar tensor.
    """
    pos_dist = F.pairwise_distance(enc_anch, enc_pos, 2)
    neg_dist = F.pairwise_distance(enc_anch, enc_neg, 2)

    return torch.mean(F.relu(pos_dist - neg_dist + alpha))

In [7]:
# Use the GPU when one is available; otherwise fall back to the CPU so the notebook still runs.
siam = Siamize().to('cuda' if torch.cuda.is_available() else 'cpu')

In [8]:
# Grab one batch as a smoke test. The built-in next() works on all PyTorch versions.
# The iterator's .next() method was a Python 2 shim that recent releases no longer provide.
a, p, n = next(iter(dataloader))

In [9]:
# Move the batch to whichever device holds the model's weights, instead of assuming CUDA.
model_device = next(siam.parameters()).device
siam(a.to(model_device), p.to(model_device), n.to(model_device))

(tensor([[-0.0961, -0.4084, -0.9887,  ..., -0.5545,  0.2012,  0.0194],
         [-0.6371, -0.2477, -0.4460,  ..., -0.6385, -0.4635,  0.4135],
         [-0.2660,  0.0477, -0.5092,  ..., -0.5798, -0.4693,  0.1814],
         ...,
         [ 0.0099, -0.2003, -0.2601,  ..., -0.0406, -0.6406,  0.1064],
         [ 0.0675,  0.2043, -0.7045,  ..., -0.9387, -0.0453, -0.1944],
         [ 0.3235, -0.4560, -0.3258,  ...,  0.1041,  0.5716,  0.3094]],
        device='cuda:0', grad_fn=<ThAddmmBackward>),
 tensor([[-0.4714,  0.3293, -0.4218,  ...,  0.1529,  0.1064,  0.6215],
         [-0.1963,  0.0871, -0.4475,  ..., -0.2807, -0.0438,  0.4317],
         [-0.4859, -0.0240, -0.3416,  ...,  0.3543, -0.3397,  0.3067],
         ...,
         [-0.0281,  0.0280,  0.1903,  ...,  0.5573, -0.1121,  0.2209],
         [-0.3024, -0.0112, -0.8941,  ..., -0.1444, -0.1628,  0.6426],
         [-0.0799, -1.3123, -0.6613,  ...,  0.7629, -0.2278,  1.0274]],
        device='cuda:0', grad_fn=<ThAddmmBackward>),
 tensor([[-0