In [10]:
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms

import timm
from PIL import Image
import matplotlib.pyplot as plt
import json
import os

In [11]:
class ConvNeXtArcFace(nn.Module):
    """Backbone wrapper that maps an image batch to pooled feature embeddings.

    The backbone is whatever architecture ``model_name`` selects in
    ``timm.create_model``. No ArcFace head is defined in this class; it only
    returns spatially averaged backbone features.

    Args:
        model_name: Architecture name forwarded to ``timm.create_model``.
        embedding_size: Accepted for call-site compatibility but not used by
            this class; the output width is the channel count of the
            backbone's feature map.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, model_name, embedding_size, pretrained=True):
        super().__init__()
        # NOTE: the attribute name is part of every state_dict key
        # ("convnext.*"), so renaming it would break existing checkpoints.
        self.convnext = timm.create_model(model_name, pretrained=pretrained)
        # Remove the classification layer; forward() only uses the feature extractor.
        self.convnext.reset_classifier(num_classes=0, global_pool='avg')

    def forward(self, x):
        """Return one embedding vector per input image.

        Args:
            x: Image batch accepted by the backbone's ``forward_features``
                (assumed to be (batch, channels, height, width) — TODO confirm).

        Returns:
            A 2-D tensor of shape (batch, feature_channels).
        """
        x = self.convnext.forward_features(x)
        # Global average pool over the spatial dimensions. The previous fixed
        # 7x7 kernel was only "global" when the feature map was exactly 7x7;
        # adaptive pooling gives the same result in that case and stays
        # correct for any other input resolution.
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)
        return x

In [12]:
# Restore the trained weights into a fresh model for inference.
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; the model is not moved to another device in these cells.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed PyTorch version
# and the checkpoint contents allow it).
ckpt = torch.load("../safe/epoch_20.pth", map_location="cpu")
model_state_dict = ckpt['model_state_dict']
# pretrained=False: load_state_dict (strict by default) overwrites every
# weight, so fetching pretrained weights first would only cost a download.
model = ConvNeXtArcFace(
    model_name="mobilenetv4_conv_small",
    embedding_size=960,
    pretrained=False,
)
model.load_state_dict(model_state_dict)
# Switch to evaluation mode; the trailing ';' hides the module repr in the
# notebook output (previously done with a bare print()).
model.eval();




In [13]:
# Per-channel (R, G, B) normalisation constants; these match the widely used
# ImageNet mean / standard deviation values.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]

# Image -> tensor conversion followed by per-channel normalisation.
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=_channel_mean, std=_channel_std),
    ]
)
def getImageTensor(imgpath, size=(224, 224)):
    """Load an image file and turn it into a normalised single-image batch.

    Args:
        imgpath: Path (or file object) accepted by ``PIL.Image.open``.
        size: ``(width, height)`` the image is resized to before conversion.
            Defaults to the previously hard-coded 224x224.

    Returns:
        The output of ``preprocess`` for the resized RGB image, with a leading
        batch dimension of size 1 added.
    """
    # The context manager closes the underlying file as soon as convert() has
    # produced an independent RGB copy, instead of leaving it to garbage
    # collection.
    with Image.open(imgpath) as img:
        rgb = img.convert('RGB')
    # Plain resize to the target size: the aspect ratio is not preserved.
    resized = rgb.resize(size)
    return preprocess(resized).unsqueeze(0)

In [14]:
def getEmbeddings(imgpath):
    """Compute the embedding of a single image file.

    The image is loaded with ``getImageTensor`` and passed through the
    module-level ``model`` with gradient tracking disabled.

    Args:
        imgpath: Path of the image, forwarded to ``getImageTensor``.

    Returns:
        The model output with all size-1 dimensions removed.
    """
    batch = getImageTensor(imgpath)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        return model(batch).squeeze()

In [15]:
# Sample inputs for enrolling one person from a single image: the source
# image path and the label its embedding is stored under.
imgpath = "dataset/Danny Trejo/image.png"
name = "Danny Trejo"

In [16]:
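# Disabled snippet: enrol the single image at `imgpath` under `name` by
# appending its embedding to the existing employee_embeddings/embeddings.json
# (an empty list is used when the file does not exist yet).
# embeddings = getEmbeddings(imgpath)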

# try:
#     with open("employee_embeddings/embeddings.json", "r") as f:
#         embeddings_list = json.load(f)
# except FileNotFoundError:
#     embeddings_list = []

# embeddings_dict = {employee['name']: employee['embeddings'] for employee in embeddings_list}

# if name in embeddings_dict:
#     embeddings_dict[name].append(embeddings.tolist())
# else:
#     embeddings_dict[name] = [embeddings.tolist()]

# embeddings_list = [{'name': name, 'embeddings': embeddings} for name, embeddings in embeddings_dict.items()]

# with open("employee_embeddings/embeddings.json", "w") as f:
#     json.dump(embeddings_list, f, indent=4)

In [17]:
# Build the embedding database: one sub-directory of `dataset_dir` per person,
# every image inside it contributes one embedding to that person's list.
dataset_dir = "dataset"
output_path = "employee_embeddings/embeddings.json"
embeddings_dict = {}

# os.listdir returns entries in arbitrary order; sorting makes the traversal
# (and therefore the written JSON) reproducible between runs.
for person_name in sorted(os.listdir(dataset_dir)):
    person_dir = os.path.join(dataset_dir, person_name)
    if not os.path.isdir(person_dir):
        continue
    for img_name in sorted(os.listdir(person_dir)):
        img_path = os.path.join(person_dir, img_name)
        # Nested directories cannot be opened as images.
        if not os.path.isfile(img_path):
            continue
        try:
            embeddings = getEmbeddings(img_path)
        except OSError as err:
            # Pillow reports unreadable / non-image files (e.g. .DS_Store) as
            # OSError subclasses; skip them loudly instead of aborting the run.
            print(f"Skipping {img_path}: {err}")
            continue
        embeddings_dict.setdefault(person_name, []).append(embeddings.tolist())

# Serialised layout: a list of {"name": ..., "embeddings": [[...], ...]} records.
embeddings_list = [
    {'name': name, 'embeddings': embeddings}
    for name, embeddings in embeddings_dict.items()
]

# Create the target directory on first run so open() does not fail.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "w") as f:
    json.dump(embeddings_list, f, indent=4)