# Training a facial recognition system with Contrastive Learning

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision.models import resnet18,ResNet18_Weights
from torchvision.transforms import v2
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import tqdm
from glob import glob
import random
from torchvision.io import read_image, ImageReadMode

Here I just get a ResNet18 model, and replace its classification layer with an output layer that outputs a 256-dimensional vector.  That will be the size of my embeddings.

In [2]:
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

model = resnet18(weights=ResNet18_Weights.DEFAULT)
num_feats = model.fc.in_features
model.fc = nn.Linear(num_feats, 256)

Here's my dataset.  A lot actually is happening here.  The key thing to understand here is that it returns an *anchor* image, a *positive* image of the same person, and a *negative* image of a different person.

In [3]:
class ContLearnDataset(Dataset):
    def __init__(self, dir, transforms=None):
        self.dir = dir
        self.transforms = transforms
        self.filenames = glob(dir+'/*/*.png')
        self.ids = set([ self.path2id(pth) for pth in self.filenames ])

    def path2id(self, path):
        return path.split('/')[-2]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        anchor = self.filenames[idx]

        id = self.path2id(anchor)
        
        poscands = [fn for fn in self.filenames if id in fn and fn != anchor]
        ind = torch.randint(len(poscands),(1,))[0]
        positive = poscands[ind]
        
        negcands = [fn for fn in self.filenames if id not in fn]
        ind = torch.randint(len(negcands),(1,))[0]
        negative = negcands[ind]

        anchor = read_image(anchor,mode=ImageReadMode.RGB)
        positive = read_image(positive,mode=ImageReadMode.RGB)
        negative = read_image(negative,mode=ImageReadMode.RGB)
        if self.transforms is not None:
            return self.transforms(anchor), self.transforms(positive), self.transforms(negative)
        return anchor, positive, negative

I build my datasets and my dataloader...

In [4]:
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True)
])
cld = ContLearnDataset('data/faces/training/',transforms=transform)
cld2 = ContLearnDataset('data/faces/testing/',transforms=transform)

dl = DataLoader(cld, batch_size=64, num_workers=10)

I train using TripletMarginLoss, to encourage the embeddings of the same person to be close to each other, and the embeddings of different people to be farther away from each other.

In [5]:
model=model.to('cuda')
EPOCHS = 201

criterion = nn.TripletMarginLoss()
optimizer = optim.Adam(model.parameters(), lr=.001)

for epoch in tqdm.tqdm(range(EPOCHS)):
    totalloss=0
    for batch, (a, p, n) in enumerate(dl):
        a,p,n = a.to('cuda'), p.to('cuda'), n.to('cuda')
        aem = model(a)
        pem = model(p)
        nem = model(n)
        loss = criterion(aem, pem, nem)

            
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        totalloss+=loss.item()
    if epoch%100==0:
        print(totalloss)
torch.save(model, 'fr.pt')

  0%|▏                                          | 1/201 [00:00<03:07,  1.07it/s]

3.775298625230789


 50%|████████████████████▌                    | 101/201 [01:00<00:59,  1.67it/s]

0.024434760212898254


100%|█████████████████████████████████████████| 201/201 [02:00<00:00,  1.67it/s]

0.1770007610321045





In [6]:
model=torch.load('fr.pt')
model=model.to('cuda')

Here I test on a testing set, which was not trained on.  `a,p,n` are an anchor, positive, and negative examples from a different dataset of people.  You can see the distance between the anchor and positive example is far smaller than the distances between the anchor and negative, and positive and negative examples.

In [8]:
a,p,n = cld2[20]
a,p,n = a.reshape((1,3,112,92)), p.reshape((1,3,112,92)), n.reshape((1,3,112,92))
vals=torch.cat((a,p,n),dim=0).to('cuda')
res=model(vals).detach()
print(res.shape)
res[0].shape
print(f'a/p: {nn.functional.mse_loss(res[0],res[1])}')
print(f'a/n: {nn.functional.mse_loss(res[0],res[2])}')
print(f'p/n: {nn.functional.mse_loss(res[1],res[2])}')


torch.Size([3, 256])
a/p: 0.007957154884934425
a/n: 0.5905918478965759
p/n: 0.5849518775939941
