In [1]:
import torch
import torchvision.transforms as transforms
import cv2
from model.resnet import Resnet34Triplet


In [2]:
from fpt.data import join_face_df
from fpt.path import DTFR
from fpt.split import read_split

In [3]:
# Dataset category handed to join_face_df when the face table is built.
DATA_CATEGORY: str = "aihub_family"
# Name of the split handed to read_split to select which faces are used.
SPLIT: str = "valid"

In [4]:
# Choose the compute device once: CUDA when PyTorch reports a usable GPU, the CPU otherwise.
flag_gpu_available = torch.cuda.is_available()
device = torch.device("cuda" if flag_gpu_available else "cpu")


In [5]:
# Load the saved checkpoint, remapping its tensors onto the device selected earlier.
# NOTE(review): torch.load deserialises with pickle (subject to the installed PyTorch
# version's `weights_only` default), so only load checkpoint files from a trusted source.
checkpoint = torch.load("model/model_resnet34_triplet.pt", map_location=device)
# Rebuild the network with the embedding size stored in the checkpoint, then restore its weights.
model = Resnet34Triplet(embedding_dimension=checkpoint["embedding_dimension"])
model.load_state_dict(checkpoint["model_state_dict"])
# Distance threshold stored alongside the weights; it is not used in the cells shown here.
# NOTE(review): presumably the cut-off for deciding whether two embeddings match — confirm.
best_distance_threshold = checkpoint["best_distance_threshold"]

In [None]:
# Move the model onto the selected device before running inference.
# NOTE(review): presumably Resnet34Triplet is a torch.nn.Module, so .to() moves it in place — confirm.
model.to(device)
# Switch to evaluation mode. As the last expression of the cell, its return value is what
# the notebook displays.
model.eval()

In [7]:
# Input pipeline for the pre-trained model: array -> PIL image -> resize -> tensor -> normalise.
preprocess = transforms.Compose([
    transforms.ToPILImage(),
    # An integer size is applied to the shorter image side (aspect ratio is kept), so only
    # square face crops come out exactly 140x140, the input size the pre-trained model uses.
    transforms.Resize(size=140),
    transforms.ToTensor(),
    # Normalization settings for the model: the calculated per-channel (RGB) mean and std
    # values of the tightly-cropped glint360k face dataset.
    transforms.Normalize(
        mean=[0.6071, 0.4609, 0.3944],
        std=[0.2457, 0.2175, 0.2129],
    ),
])


In [None]:
# Build the face table for the configured data category using the project helper join_face_df.
# NOTE(review): presumably a pandas DataFrame indexed by face uuid with an `image` column — confirm.
face = join_face_df(DTFR, DATA_CATEGORY)
# Index labels of the faces that belong to the configured split, as returned by read_split.
valid_face_uuids = read_split(SPLIT)
# Label-based lookup: keep only the rows of `face` selected by those labels.
x_valid = face.loc[valid_face_uuids]

In [9]:
# Embed one face image from the selected split and display the embedding's shape.
image_path = x_valid.iloc[0].image
img_bgr = cv2.imread(image_path)  # Or a frame from a cv2 video capture stream
if img_bgr is None:
    # cv2.imread reports a missing or undecodable file by returning None instead of raising;
    # fail here with a clear message rather than with a TypeError on the next line.
    raise OSError(f"cv2.imread could not load image: {image_path!r}")

# Note that you need to use a face detection model here to crop the face from the image and then
#  create a new face image object that will be inputted to the facial recognition model later.

# Convert the image from BGR color (which OpenCV uses) to RGB color.
# cvtColor returns a fresh contiguous array instead of a negative-stride view.
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

# Preprocess, add a leading batch dimension and move the tensor to the device the model is on.
img = preprocess(img_rgb).unsqueeze(0).to(device)

# Inference only: disable autograd so no computation graph is built for the forward pass.
with torch.no_grad():
    embedding = model(img)

# Turn embedding Torch Tensor to Numpy array
embedding = embedding.detach().cpu().numpy()
embedding.shape

(1, 512)