In [66]:
import warnings
from os import listdir
from random import choice
from typing import Callable, Generator, Iterable, Tuple, Union

import cv2
import numpy as np
import torch
import torchvision
from IPython.display import clear_output

from facenet import FaceNetNN2

# Utils

In [69]:
# Preprocessing applied to every image read with cv2.imread before it is fed to the model:
#   1. ToTensor: HxWxC uint8 array -> CxHxW float tensor scaled to [0, 1]
#   2. Resize:   spatial size -> 224x224
# NOTE(review): cv2.imread returns channels in BGR order and no BGR->RGB swap is done here.
preproc_img = torchvision.transforms.Compose(
    transforms=[
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize(size=(224, 224)),
    ]
)

# Dataset

In [None]:
%%capture
# Download and unpack the LFW (Labeled Faces in the Wild) archive, then delete it.
# The code below expects the images under ./lfw/<identity>/<image file>.
# %%capture hides the (very verbose) wget/tar output.
!wget http://vis-www.cs.umass.edu/lfw/lfw.tgz
!tar xvf lfw.tgz
!rm lfw.tgz

In [7]:
def edist(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    '''
    Euclidean (L2) distance between two points given as tensors of
    broadcastable shape (typically (1, d) embeddings).

    Every element of `p - q` is squared and summed, so the result is always
    a 0-dim tensor.
    NOTE: the derivative of sqrt is infinite at 0, so backpropagating through
    edist(x, x) produces NaN gradients.
    Reference: https://en.wikipedia.org/wiki/Euclidean_distance
    '''
    squared_diff = torch.pow(p - q, 2)
    return torch.sqrt(torch.sum(squared_diff))

In [8]:
def is_valid_triplet(
    a_emb: torch.Tensor,
    p_emb: torch.Tensor,
    n_emb: torch.Tensor,
    margin: float
) -> bool:
    '''
    Determines whether a triplet is useful for training, i.e. whether it
    VIOLATES the FaceNet margin constraint

        ||a - p||^2 + margin < ||a - n||^2

    Triplets that already satisfy the constraint produce zero (hinge) triplet
    loss and thus no gradient ("easily satisfied" triplets in the paper), so
    only violating ones (hard and semi-hard negatives) are selected.

    Args:
        a_emb, p_emb, n_emb: anchor / positive / negative embeddings; all
            elements of each difference are summed.
        margin: the triplet-loss margin (alpha in the paper).

    Returns:
        True if ||a - p||^2 + margin > ||a - n||^2.

    Reference: https://arxiv.org/abs/1503.03832 (Triplet Selection)
    '''
    # Squared distances computed directly (no sqrt followed by squaring).
    d_ap = (a_emb - p_emb).pow(2).sum()
    d_an = (a_emb - n_emb).pow(2).sum()
    return bool(d_ap + margin > d_an)

In [60]:
@torch.no_grad()
def generate_triplets(
    model: torch.nn.Module,
    margin: float = 0.0,
    root: str = 'lfw'
) -> Generator[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], None, None]:
    '''
    Lazily mines training triplets (anchor, positive, negative) from the LFW dataset,
    going through every ordered pair of distinct identities but not every
    anchor/positive image combination (to speed things up).

    For each anchor identity A, each other identity N and each image of N, an
    anchor image and (when A has more than one image) a different positive
    image are drawn at random from A. The triplet is embedded with `model`
    (without gradient tracking) and yielded only if `is_valid_triplet` accepts it.

    Args:
        model: embedding network, called on single preprocessed images.
        margin: margin forwarded to `is_valid_triplet`.
        root: directory holding one sub-directory of images per identity.

    Yields:
        (a_img_tensor, p_img_tensor, n_img_tensor): preprocessed images, each
        with a leading batch dimension of 1. The caller re-embeds them with
        gradients enabled.

    Raises:
        FileNotFoundError: if OpenCV cannot read an image.
    '''
    # List each identity's images once up front instead of re-listing the
    # negative identity's directory for every anchor identity (O(N^2) listdir
    # calls). Sorted so the iteration order does not depend on the filesystem.
    identities = sorted(listdir(root))
    imgs_by_identity = {
        identity: sorted(listdir(f'{root}/{identity}')) for identity in identities
    }

    def _load(identity: str, img_name: str) -> torch.Tensor:
        # Read one image with OpenCV and preprocess it into a batch of one.
        path = f'{root}/{identity}/{img_name}'
        img = cv2.imread(path)
        if img is None:  # cv2.imread signals failure by returning None, not by raising
            raise FileNotFoundError(f'Could not read image: {path}')
        return preproc_img(img).unsqueeze(0)

    for ap_identity in identities:
        ap_imgs_name = imgs_by_identity[ap_identity]
        if not ap_imgs_name:  # no image to use as anchor
            continue
        for n_identity in identities:
            if n_identity == ap_identity:
                continue
            for n_img_name in imgs_by_identity[n_identity]:
                a_img_name = choice(ap_imgs_name)
                other_imgs_name = [name for name in ap_imgs_name if name != a_img_name]
                # Identities with a single image reuse the anchor as the positive.
                p_img_name = choice(other_imgs_name) if other_imgs_name else a_img_name
                same_img = p_img_name == a_img_name

                a_img_tensor = _load(ap_identity, a_img_name)
                p_img_tensor = a_img_tensor.clone() if same_img else _load(ap_identity, p_img_name)
                n_img_tensor = _load(n_identity, n_img_name)

                a_emb = model(a_img_tensor)
                p_emb = a_emb.clone() if same_img else model(p_img_tensor)
                n_emb = model(n_img_tensor)

                if is_valid_triplet(a_emb, p_emb, n_emb, margin):
                    yield (a_img_tensor, p_img_tensor, n_img_tensor)

# Model

In [46]:
@torch.no_grad()
def init_weights(m: torch.nn.Module):
    '''
    Re-initializes `m` and, recursively, all of its sub-modules in place:
    every non-None `weight` is filled by torch.nn.init.uniform_ with its
    default range U(0, 1), and every non-None `bias` is set to zero.

    Because this already recurses into `m.children()`, call it directly as
    `init_weights(model)`. Passing it to `model.apply(...)` also works, but
    re-initializes nested modules once per ancestor.
    '''
    # `is not None` instead of `!= None`: comparing a tensor with `!=` goes
    # through Tensor.__ne__ rather than an identity check.
    weight = getattr(m, 'weight', None)
    if weight is not None:
        torch.nn.init.uniform_(weight)
    bias = getattr(m, 'bias', None)
    if bias is not None:
        torch.nn.init.zeros_(bias)

    for childm in m.children():
        init_weights(childm)

In [65]:
class TripletLoss():
    '''
    Triplet Loss implementation:

        L = mean_i max(0, ||a_i - p_i||^2 - ||a_i - n_i||^2 + margin)

    Reference: https://arxiv.org/abs/1503.03832
    '''
    def __init__(self, margin: float):
        self.margin = margin

    @staticmethod
    def _sq_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Squared L2 distance per sample: dim 0 is treated as the batch and all
        # remaining dims are flattened; a 1-D input is treated as one sample.
        diff = x - y
        if diff.dim() > 1:
            diff = diff.flatten(start_dim=1)
        return diff.pow(2).sum(dim=-1)

    def __call__(self, a: torch.Tensor, p: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
        '''
        Args:
            a, p, n: anchor / positive / negative embeddings of equal shape,
                e.g. (1, d) or (batch, d).

        Returns:
            A 0-dim tensor: the hinge triplet loss averaged over the batch.
        '''
        # Squared distances are computed directly rather than squaring a sqrt:
        # d/dx sqrt(x) is infinite at x = 0, so sqrt-then-square gives NaN
        # gradients whenever a == p (e.g. the anchor image reused as positive).
        d_ap = self._sq_dist(a, p)
        d_an = self._sq_dist(a, n)
        # Hinge: triplets that already satisfy the margin contribute zero loss
        # instead of a negative, unbounded value that keeps pushing negatives away.
        return torch.clamp(d_ap - d_an + self.margin, min=0.0).mean()

In [61]:
%%capture
# Build the embedding network (FaceNetNN2 is imported from the `facenet` module).
# %%capture suppresses anything printed while constructing it.
model = FaceNetNN2()
# Optional custom initialization, currently disabled.
# NOTE(review): init_weights already recurses into m.children(), so model.apply(init_weights)
# re-initializes nested modules once per ancestor; init_weights(model) alone is enough.
# model.apply(init_weights)

In [71]:
# Select triplets with the same margin (0.2) that the TripletLoss in `train` uses,
# instead of the generator's default of 0.0.
triplets_generator = generate_triplets(model, margin=0.2)

In [67]:
def train(
        model: torch.nn.Module,
        triplets_generator: Union[
            Iterable[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
            Callable[[], Iterable[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]],
        ],
        epochs: int,
        margin: float = 0.2,
        lr: float = 0.05,
    ):
    '''
    Trains `model` with the triplet loss, one triplet per SGD step.

    Args:
        model: embedding network whose parameters are optimized in place.
        triplets_generator: either an iterable of (anchor, positive, negative)
            image tensors, or a zero-argument callable returning such an iterable.
            A generator object can be consumed only once, so to train for more
            than one epoch pass a callable, e.g. `lambda: generate_triplets(model)`,
            which provides a fresh stream of triplets every epoch.
        epochs: number of passes over the triplet stream.
        margin: margin of the TripletLoss (default 0.2).
        lr: SGD learning rate (default 0.05).
    '''
    loss_fn = TripletLoss(margin=margin)
    optim_fn = torch.optim.SGD(model.parameters(), lr=lr)

    for ep in range(epochs):
        triplets = triplets_generator() if callable(triplets_generator) else triplets_generator
        n_triplets = 0
        for tr, (a_img_tensor, p_img_tensor, n_img_tensor) in enumerate(triplets):
            optim_fn.zero_grad()
            a_emb = model(a_img_tensor)
            p_emb = model(p_img_tensor)
            n_emb = model(n_img_tensor)
            loss = loss_fn(a_emb, p_emb, n_emb)
            loss.backward()
            optim_fn.step()
            n_triplets += 1

            # wait=True keeps the previous line until the new one is printed (no flicker).
            clear_output(wait=True)
            print(f'Epoch: {ep}. Triplet {tr}. Loss: {loss.item()}')

        # A one-shot iterator (e.g. a generator) is exhausted after the first epoch;
        # make that visible instead of silently running empty epochs.
        if n_triplets == 0 and ep > 0 and not callable(triplets_generator):
            warnings.warn(
                f'Epoch {ep} received no triplets: the triplet iterator is exhausted. '
                'Pass a zero-argument callable returning a fresh iterator to train '
                'for more than one epoch.'
            )
            break

In [None]:
# Single training pass over the mined triplets.
train(
    model=model,
    triplets_generator=triplets_generator,
    epochs=1,
)