In [2]:
import os
import random
import itertools
import numpy as np
from collections import Counter
from PIL import Image
import torch
from torch import nn
from torch.utils.data import Dataset, Sampler
from torchvision import datasets, transforms
from torchvision.models.mobilenetv2 import MobileNetV2
import pytorch_lightning as pl
import torch.optim as optim

# Set raw directory
# Two candidate locations for the raw images: a Kaggle-style input path
# (relative to the current working directory) and a local fallback.
kaggle_path = 'kaggle/input/labeled-faces-in-the-wild-lfw-20180109/lfw'
default_path = 'data/raw'

# Use the Kaggle location only when it exists on disk; otherwise fall back.
if os.path.exists(kaggle_path):
    raw_dir = kaggle_path
else:
    raw_dir = default_path
print(f'Raw directory is set to: {raw_dir}')

# DATASET
class SiameseDataset(Dataset):
    """Dataset of image pairs built from an ``ImageFolder`` directory tree.

    Every unordered pair of image indices is enumerated up front, and each
    pair gets a binary target: 1 when both images carry the same
    ``ImageFolder`` class label, 0 otherwise.

    NOTE: pairs are generated with ``combinations_with_replacement``, so an
    image is also paired with itself (``(i, i)``), and the full pair list is
    held in memory (it grows quadratically with the number of images).

    Args:
        root_dir: Directory handed to ``datasets.ImageFolder``.
        transform: Optional callable applied to each image of a pair.
    """

    def __init__(self, root_dir, transform=None):
        self.image_folder = datasets.ImageFolder(root=root_dir)
        self.transform = transform

        # All index pairs (i, j) with i <= j.
        image_count = len(self.image_folder)
        pair_iter = itertools.combinations_with_replacement(range(image_count), 2)
        self.image_pairs = list(pair_iter)

        # 1 = both images share a class label, 0 = they differ.
        labels = self.image_folder.targets
        self.targets = [
            1 if labels[first] == labels[second] else 0
            for first, second in self.image_pairs
        ]

    def __len__(self):
        """Return the number of enumerated pairs."""
        return len(self.image_pairs)

    def __getitem__(self, index):
        """Return ``(image1, image2, target)`` for the pair at ``index``."""
        first_idx, second_idx = self.image_pairs[index]
        # ImageFolder items are (image, label); only the image is needed here.
        first_img = self.image_folder[first_idx][0]
        second_img = self.image_folder[second_idx][0]
        if self.transform:
            first_img, second_img = self.transform(first_img), self.transform(second_img)
        return first_img, second_img, self.targets[index]

# SAMPLER & DATALOADER
class RandomUnderSampler(Sampler):
    """Sampler that balances classes by undersampling to the rarest class.

    On every iteration it draws, for each class, as many distinct indices as
    the smallest class has members, so each class contributes equally.

    Args:
        targets: Sequence (list or array) with one class label per dataset
            index.
        seed: Seed for the draw. When not ``None`` every iteration yields the
            same indices; when ``None`` NumPy's global RNG is used.
        shuffle: When true, the drawn indices are shuffled; otherwise they
            are emitted grouped class by class.
    """

    def __init__(self, targets, seed=42, shuffle=False):
        # Accept plain lists too: `list == scalar` is a single bool, which
        # would silently produce empty index arrays.
        targets = np.asarray(targets)
        self.class_counts = Counter(targets.tolist())
        self.indices = {
            cls: np.flatnonzero(targets == cls) for cls in self.class_counts
        }
        self.seed = seed
        self.shuffle = shuffle

    def __iter__(self):
        # A private RandomState gives the same stream as seeding the global
        # RNG would, without resetting randomness for the rest of the program.
        rng = np.random.RandomState(self.seed) if self.seed is not None else np.random
        per_class = self.__min_count()
        sampled_indices = []
        for indices in self.indices.values():
            # replace=False: undersampling must not pick an index twice.
            chosen = rng.choice(indices, per_class, replace=False)
            sampled_indices.extend(chosen.tolist())
        if self.shuffle:
            rng.shuffle(sampled_indices)
        return iter(sampled_indices)

    def __len__(self):
        """Return the number of indices produced per iteration."""
        return self.__min_count() * len(self.class_counts)

    def __min_count(self):
        """Return the size of the rarest class (0 when there are no targets)."""
        return min(self.class_counts.values(), default=0)

# MODEL
class Embedding(nn.Module):
    """Shared-weight embedding network applied to both images of a pair.

    A ``MobileNetV2`` feature extractor is followed by global average
    pooling and a 1280 -> 1280 linear layer with a sigmoid, so every
    embedding component lies in (0, 1). The head expects the backbone to
    emit 1280 feature channels.
    """

    def __init__(self):
        super().__init__()
        self.backbone = MobileNetV2()
        self.fc = nn.Sequential(nn.Linear(1280, 1280), nn.Sigmoid())

    def forward_one_branch(self, x):
        """Embed one batch of images: features -> global pool -> flatten -> head."""
        feature_maps = self.backbone.features(x)
        pooled = nn.functional.adaptive_avg_pool2d(feature_maps, (1, 1))
        return self.fc(pooled.flatten(start_dim=1))

    def forward(self, input1, input2):
        """Embed both inputs with the same weights and return them as a tuple."""
        return tuple(self.forward_one_branch(branch) for branch in (input1, input2))

class EuclideanDistance(nn.Module):
    """Euclidean (L2) distance between paired embeddings.

    The distance is taken along the last dimension, so inputs of shape
    ``(batch, features)`` yield one distance per pair (shape ``(batch,)``)
    and 1-D inputs yield a scalar.

    Args:
        eps: Lower bound applied to the squared distance before the square
            root. Identical embeddings would otherwise hit ``sqrt(0)``,
            whose gradient is infinite and turns into NaN during backprop.
    """

    def __init__(self, eps=1e-12):
        super().__init__()
        self.eps = eps

    def forward(self, output1, output2):
        """Return the per-pair L2 distance between ``output1`` and ``output2``."""
        # Reduce over the feature axis only; summing over every axis would
        # merge the whole batch into a single number.
        squared = torch.sum((output1 - output2) ** 2, dim=-1)
        return torch.sqrt(torch.clamp(squared, min=self.eps))

class ContrastiveLoss(nn.Module):
    """Contrastive loss over pair distances.

    Label convention (matches ``SiameseDataset.targets`` in this file):
    ``label == 1`` marks a positive pair (same class) and ``label == 0`` a
    negative pair (different classes).

    * Positive pairs are penalised by the squared distance, pulling them
      together.
    * Negative pairs are penalised by ``max(margin - distance, 0) ** 2``,
      pushing them apart until they are at least ``margin`` away.

    Args:
        margin: Distance beyond which negative pairs stop contributing.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, distance, label):
        """Return the mean contrastive loss for ``distance`` and ``label``."""
        # Labels usually arrive as integers from the DataLoader; cast once so
        # the arithmetic happens in the distance's dtype and device.
        label = torch.as_tensor(label, dtype=distance.dtype, device=distance.device)
        loss_positive = label * distance ** 2
        loss_negative = (1 - label) * torch.clamp(self.margin - distance, min=0) ** 2
        return torch.mean(loss_positive + loss_negative)

class sMobileNetV2(pl.LightningModule):
    """Lightning module wiring the embedding, distance and loss together.

    A batch is a triple ``(img1, img2, label)``; both images are embedded,
    the distance between the embeddings is computed, and the contrastive
    loss (margin 1.0) of that distance against the label is optimised with
    Adam at a learning rate of 1e-3.
    """

    def __init__(self):
        super().__init__()
        self.embedding = Embedding()
        self.distance = EuclideanDistance()
        self.loss_fn = ContrastiveLoss(margin=1.0)

    def forward(self, x1, x2):
        """Return the embeddings of ``x1`` and ``x2``."""
        embedded_pair = self.embedding(x1, x2)
        return embedded_pair

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one ``(img1, img2, label)`` batch."""
        first, second, target = batch
        emb_first, emb_second = self.forward(first, second)
        return self.loss_fn(self.distance(emb_first, emb_second), target)

    def configure_optimizers(self):
        """Return the Adam optimizer over all module parameters."""
        return optim.Adam(self.parameters(), lr=1e-3)

# Setup dataset and dataloader
ds = SiameseDataset(
    raw_dir,
    transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]),
)

# shuffle=True is required here: without it the sampler emits every index of
# one class followed by every index of the other, so each batch would contain
# a single class and a run capped at 100 steps of 8 samples might never see
# the second class at all. (DataLoader's own `shuffle` cannot be combined
# with a custom sampler, so the mixing has to happen inside the sampler.)
sampler = RandomUnderSampler(np.array(ds.targets), seed=42, shuffle=True)
dl = torch.utils.data.DataLoader(ds, sampler=sampler, batch_size=8)

# Setup and start training
# Short smoke-test run: stops after one epoch or 100 optimizer steps,
# whichever comes first.
trainer = pl.Trainer(
    max_epochs=1,
    max_steps=100,
)

model = sMobileNetV2()
trainer.fit(model, dl)

Raw directory is set to: data/raw


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name      | Type              | Params | Mode 
--------------------------------------------------------
0 | embedding | Embedding         | 5.1 M  | train
1 | distance  | EuclideanDistance | 0      | train
2 | loss_fn   | ContrastiveLoss   | 0      | train
--------------------------------------------------------
5.1 M     Trainable params
0         Non-trainable params
5.1 M     Total params
20.578    Total estimated model params size (MB)
219       Modules in train mode
0         Modules in eval mode
/opt/homebrew/Caskroom/miniconda/base/envs/oneshot-face/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=100` reached.
