In [6]:
from PIL import Image
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.io import read_image
from fpt.path import DTFR
from fpt.data import join_face_df
from facenet_pytorch import InceptionResnetV1
from facenet_pytorch import fixed_image_standardization

In [76]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# Kept as a plain string; both nn.Module.to and Tensor.to accept it.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [5]:
# Joined face metadata. Only two columns are kept, in a fixed order, because
# CustomImageDataset reads them positionally: image path first, label second.
# NOTE: despite the name, `annotations_file` is a DataFrame, not a file path.
face = join_face_df(DTFR)
annotations_file = face[["path", "target"]]

In [69]:
# Side length of the square crops fed to InceptionResnetV1.
IMAGE_SIZE = 160

data_transform = transforms.Compose(
    [
        # Resize(int) scales only the *shorter* edge to IMAGE_SIZE. A
        # non-square image would come out as e.g. 160x213, and the
        # DataLoader's default collate could not stack the batch.
        # CenterCrop makes every sample exactly IMAGE_SIZE x IMAGE_SIZE;
        # for square inputs it is a no-op.
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        # PIL image -> float32 HWC ndarray. ToTensor only rescales uint8/PIL
        # input to [0, 1], so a float array keeps its [0, 255] pixel range,
        # which is what the next standardization step expects.
        np.float32,
        transforms.ToTensor(),
        # facenet_pytorch's (x - 127.5) / 128: maps [0, 255] to about [-1, 1].
        fixed_image_standardization,
    ]
)

In [70]:
class CustomImageDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from an annotations frame.

    Parameters
    ----------
    annotations_file : pandas.DataFrame
        Despite its name, a DataFrame rather than a path. Column 0 holds
        image file paths and column 1 holds labels; rows are addressed
        positionally with ``iloc``.
    transform : callable, optional
        Applied to the loaded RGB PIL image. With ``None`` (the default),
        the PIL image is returned unchanged.
    """

    def __init__(self, annotations_file, transform=None):
        self.img_labels = annotations_file
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load row ``idx``: read its image, apply ``transform``, return ``(image, label)``."""
        img_path = self.img_labels.iloc[idx, 0]
        label = self.img_labels.iloc[idx, 1]
        # PIL opens files lazily and would otherwise keep one handle open per
        # sample, which over an epoch can exhaust the OS file-descriptor
        # limit. convert() forces the pixels to load before the `with`
        # closes the file. It also guarantees 3 channels for grayscale,
        # RGBA or palette inputs, so every batch is 3-channel.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

In [71]:
# Seed for the shuffle order, so the sample batch inspected next is the same
# on every Restart & Run All.
SEED = 0
BATCH_SIZE = 2

ds = CustomImageDataset(annotations_file, data_transform)
loader = DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # The RandomSampler behind shuffle=True draws its permutation from this
    # generator instead of the unseeded global RNG.
    generator=torch.Generator().manual_seed(SEED),
)

In [72]:
# Pull one shuffled batch to sanity-check what the pipeline produces.
train_features, train_labels = next(iter(loader))
print(f"Feature batch shape: {train_features.size()}")
# The labels are strings, so the collated batch is a tuple of values rather
# than a tensor. Report the values and their count, not a "shape".
print(f"Labels batch ({len(train_labels)}): {train_labels}")

Feature batch shape: torch.Size([2, 3, 160, 160])
Labels batch shape: ('F0001-GM', 'F0003-D')


In [77]:
# FaceNet backbone with VGGFace2 weights, placed on `device` and switched to
# eval mode (Module.to and Module.eval both return the module itself).
resnet = InceptionResnetV1(pretrained="vggface2").to(device).eval()

In [78]:
# Inference only. no_grad skips building the autograd graph, which replaces
# the earlier .detach(). The batch from the DataLoader lives on the CPU while
# `resnet` lives on `device`, so it must be moved first; otherwise the
# forward pass fails with a device mismatch whenever device == 'cuda'.
with torch.no_grad():
    embeddings = resnet(train_features.to(device)).cpu()
embeddings.shape