In [1]:
import os
import json
import torch

from torch.utils.data import DataLoader
from sklearn.metrics.pairwise import cosine_similarity

from FaceDataset import FaceDataset, TRANSFORMS
from models import ModifiedResnet18




In [2]:
device = (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))

In [3]:
DATASET_PATH = "../dataset/"
TRAIN_PATH = DATASET_PATH + "/train/"
VALIDATION_PATH = DATASET_PATH + "/validation/"
TEST_PATH = DATASET_PATH + "/test/"
landmarks = json.load(open(DATASET_PATH + 'data.json'))

train_images = os.listdir(TRAIN_PATH)
validation_images = os.listdir(VALIDATION_PATH)
test_images = os.listdir(TEST_PATH)

train_dataset =  FaceDataset(train_images, TRAIN_PATH, landmarks, device, transforms=TRANSFORMS)
validation_dataset = FaceDataset(validation_images, VALIDATION_PATH, landmarks, device, transforms=TRANSFORMS)
test_dataset = FaceDataset(test_images, TEST_PATH, landmarks, device, transforms=TRANSFORMS)

In [4]:
def draw_keypoints(image: torch.Tensor, keypoints):
    if len(keypoints.shape) == 1:
        keypoints = keypoints.reshape(-1, 2)
    new_image = torch.clone(image)
    for point in keypoints:
        point = point.to(dtype=int)
        x, y = point[0], point[1]
        new_image[:, y, x] = torch.Tensor((0, 255, 0))
    return new_image

In [5]:
#new_image = draw_keypoints(image, keypoints)
#plt.figure(figsize=(12, 12))
#plt.imshow(new_image.to(device=torch.device('cpu'), dtype=torch.int).permute(1,2,0))

In [6]:
model = ModifiedResnet18().to(device=device)
model.load_state_dict(torch.load("resnet18_weights.pt"))
model.eval()

ModifiedResnet18(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tr

In [7]:
test_dataloader = DataLoader(test_dataset, batch_size=16)
max_i = 0
max_similarity = 0
real_keypoints = test_dataset[123][1]
i = 0
for batch, (image, label) in enumerate(test_dataloader):
    with torch.no_grad():
        predicted_keypoints = model(image.to(device=device, dtype=torch.float)).squeeze()
    predicted_keypoints = predicted_keypoints.to(device=torch.device('cpu')).detach().numpy()
    for keypoint in predicted_keypoints:
        similarity = cosine_similarity(keypoint.reshape(1, -1), real_keypoints.reshape(1, -1))
        if similarity > max_similarity:
            max_similarity = similarity
            max_i = i
        i += 1
    
    
print(max_i, max_similarity)

123 [[0.99991655]]
