In [1]:
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor, Compose, Resize, Grayscale
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import os

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
class SiameseNet(nn.Module):
    def __init__(self, img_shape=(100,100)):
        super(SiameseNet,self).__init__()
        self.cnn = nn.Sequential(
            self.cnn_block(1,4),
            self.cnn_block(4,8),
            self.cnn_block(8,8),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(img_shape[0]*img_shape[1]*8, 500),
            nn.ReLU(),
            nn.Linear(500,500),
            nn.ReLU(),
            nn.Linear(500,30)
        )
        return
    
    def cnn_block(self, in_channels, out_channels):
        cnn = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels,out_channels,3),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels)
        )
        return cnn
    
    def forward(self, x):
        x = self.cnn(x)
        x = self.fc(x)
        return x

In [4]:
class ATTFaces(Dataset):
    def __init__(self, path, transform=None):
        self.img_ds = ImageFolder(path, transform=transform)
        self.n_classes = len(self.img_ds.classes)
        self.class_size = 10
    
    def _get_rand_img_from_class(self, c):
        n = np.random.randint(self.class_size)
        return self.img_ds[(c*self.class_size + n)]
    
    def _get_image_pair_same_class(self):
        c = np.random.randint(self.n_classes)
        return self._get_rand_img_from_class(c), self._get_rand_img_from_class(c)
    
    def _get_image_pair_different_class(self):
        c = np.random.randint(0,self.n_classes,2)
        return self._get_rand_img_from_class(c[0]), self._get_rand_img_from_class(c[1])
        
    def __getitem__(self, idx):
        if np.random.randint(2):
            img1, img2 = self._get_image_pair_same_class()
        else:
            img1, img2 = self._get_image_pair_different_class()
            
        return img1[0], img2[0], int(img1[1]!=img2[1])

    def __len__(self):
        return len(self.img_ds)

In [5]:
class ContrasiveLoss(nn.Module):
    def __init__(self, m):
        super(ContrasiveLoss, self).__init__()
        self.m = m
    
    def forward(self, v1, v2, t):
        d = torch.norm(v1 - v2, dim=1)
        l = (1-t)*d**2 + t*(torch.clamp(self.m - d, min=0.0))**2
        return l/2

In [6]:
dataset_path = "att_faces"
train_path = os.path.join(dataset_path, 'train_set')
test_path = os.path.join(dataset_path, 'test_set')

img_shape = (100,100)
transforms = Compose([Resize(img_shape), Grayscale(), ToTensor()])
ds = ATTFaces(train_path, transform=transforms)
dl = DataLoader(ds, batch_size=1)

ds_test = ATTFaces(test_path, transform=transforms)
dl_test = DataLoader(ds_test, batch_size=1)

In [7]:
model = SiameseNet(img_shape)
model.to(device)

optimizer = Adam(model.parameters(), lr=0.0001)
loss_fn = ContrasiveLoss(2.)

EPOCHS = 40
min_loss = 100000
for n in range(EPOCHS):
    model.train()
    total_loss = 0
    for i, mb in enumerate(dl,0):
        img1, img2, label = mb[0].to(device), mb[1].to(device), mb[2].to(device)
        optimizer.zero_grad()
        emb1 = model(img1)
        emb2 = model(img2)
        loss = loss_fn(emb1, emb2, label)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    total_loss /= len(ds)
    print(total_loss)
#     if total_loss < min_loss:
#         min_loss = total_loss
#         torch.save(model, "model.checkpoint")
#         print("saved model")
        
    model.eval()
    total_loss_test = 0
    for i, mb in enumerate(dl_test,0):
        img1, img2, label = mb[0].to(device), mb[1].to(device), mb[2].to(device)
        emb1 = model(img1)
        emb2 = model(img2)
        loss = loss_fn(emb1, emb2, label)
        total_loss_test += loss.item()
    total_loss_test /= len(ds_test)
    print(total_loss_test)
    if total_loss_test < min_loss:
        min_loss = total_loss_test
        torch.save(model, "model.checkpoint")
        print("saved model")
        
    
        


1.6538769340747845
0.8168447386473418


  "type " + obj.__name__ + ". It won't be checked "


saved model
1.3049782137788133
2.357071481309831
0.42988966459389
3.1383849369734524
0.30021020355596495
4.580989760950033
0.2514655553952919
1.7469185554639262
0.15701399960700654
8.177664301814511
0.13493491409656902
2.0973949324349817
0.09176688249669193
1.1415909814334009
0.08928141862241318
4.235594155360014
0.07019017191779009
5.170724723068997
0.0655686297580299
0.8931510563008487
0.05878621233554857
6.79754934495315
0.06188855538833498
1.93828111034533
0.053964075927421314
4.512365977652371
0.06514688304244079
1.9339282459579408
0.04921332340842734
3.994299346776679
0.05630800061218906
1.105116596070584
0.0528322378298617
0.8722413994371891
0.04323187756118083
1.0478554924903438
0.049786990485251106
2.837320207497105
0.07088253055541675
1.5101246199756861
0.04633096406488524
2.0505081094428896
0.04993934898986481
1.6841161077003926
0.03212068009384287
2.794539422355592
0.03700368360602928
2.836395698511042
0.03314233483358597
1.0497476268394894
0.02168465073117356
2.04180171246

In [8]:
model = torch.load("model.checkpoint")
model.eval()

for i, mb in enumerate(dl_test,0):
    img1, img2, label = mb[0].to(device), mb[1].to(device), mb[2].to(device)
    emb1 = model(img1)
    emb2 = model(img2)
    dist = torch.norm(emb1 - emb2, dim=1)
    print(dist.item(), label)
print(total_loss)

3.1910221576690674 tensor([0], device='cuda:0')
0.7320285439491272 tensor([1], device='cuda:0')
5.896364212036133 tensor([1], device='cuda:0')
2.2141172885894775 tensor([0], device='cuda:0')
3.546520233154297 tensor([0], device='cuda:0')
2.2753489017486572 tensor([0], device='cuda:0')
1.053887128829956 tensor([1], device='cuda:0')
7.313989162445068 tensor([1], device='cuda:0')
1.371145248413086 tensor([1], device='cuda:0')
7.882805347442627 tensor([1], device='cuda:0')
1.5637966394424438 tensor([1], device='cuda:0')
7.834571838378906 tensor([1], device='cuda:0')
6.106356620788574 tensor([0], device='cuda:0')
6.606673717498779 tensor([1], device='cuda:0')
6.308690071105957 tensor([1], device='cuda:0')
5.355000019073486 tensor([1], device='cuda:0')
0.8447266817092896 tensor([0], device='cuda:0')
3.3678133487701416 tensor([1], device='cuda:0')
6.8026604652404785 tensor([1], device='cuda:0')
2.392662763595581 tensor([1], device='cuda:0')
4.4945244789123535 tensor([1], device='cuda:0')
0.75