# In [None]:
import torch
from torch import nn
import json
from glob import glob
import numpy as np 
from tqdm import tqdm, trange

import matplotlib.pyplot as plt
import cv2
import os
from stacked_hourglass.model import hg2, hg8

def Euclidian_distance(a, b):
    """Return the Euclidean (L2) distance between ``a`` and ``b``.

    ``a`` and ``b`` must support element-wise subtraction (e.g. NumPy arrays).
    Every element of ``a - b`` contributes to the sum, so 2-D inputs give the
    Frobenius norm of their difference.
    """
    delta = a - b
    squared_sum = np.sum(delta ** 2)
    return np.sqrt(squared_sum)

def getCoordFromHeatmap(heatmap):
    """Return the index of the maximum value of a heatmap as ``(row, col)``.

    The order is row first, i.e. ``(y, x)`` in image terms; reverse the result
    if you need ``(x, y)``. When several cells share the maximum, the first one
    in row-major (C) order is returned. NaN cells count as the maximum
    (NumPy ``argmax`` semantics).

    Args:
        heatmap: array-like of scores, normally 2-D ``(H, W)``.

    Returns:
        Tuple ``(row, col)`` of NumPy integers (the first two index components).
    """
    heatmap = np.asarray(heatmap)
    # argmax scans the array once and returns the first maximum in flattened
    # order; unravel_index converts that flat position back to per-axis indices.
    # This avoids building a full boolean mask plus index arrays for every tie.
    idx = np.unravel_index(np.argmax(heatmap), heatmap.shape)
    return idx[0], idx[1]

def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``'module.'`` removed from each key.

    Only the prefix is stripped (``str.replace`` would also remove the substring
    anywhere later in a key), e.g. for checkpoints saved through nn.DataParallel.
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_gt_keypoints(json_path, scale_x, scale_y):
    """Load one annotation file as a (K, 2) array of (x, y) scaled by (scale_x, scale_y).

    Points are stacked in the order eyelid, iris, pupil, read from the
    ``<part>_x`` / ``<part>_y`` lists of the JSON object.
    """
    with open(json_path, 'r') as f:
        ann = json.load(f)
    parts = []
    for part in ('eyelid', 'iris', 'pupil'):
        xs = np.asarray(ann[part + '_x'], dtype=float) * scale_x
        ys = np.asarray(ann[part + '_y'], dtype=float) * scale_y
        parts.append(np.stack([xs, ys], axis=1))
    return np.concatenate(parts, axis=0)


def _predict_heatmaps(model, img_bgr, input_size=256):
    """Run ``model`` on one BGR image and return that image's heatmaps.

    The image is converted to grayscale, resized to ``input_size`` x ``input_size``,
    divided by 255 and fed as a (1, 1, H, W) float CUDA tensor.

    Returns:
        NumPy array indexed as ``[keypoint, row, col]``.
    """
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    gray = cv2.resize(gray, (input_size, input_size)) / 255.0
    x = torch.from_numpy(gray).float()[None, None].cuda()  # (1, 1, H, W)
    with torch.no_grad():  # inference only: don't build an autograd graph
        output = model(x)
    # NOTE(review): stacked-hourglass nets often return a list with one output per
    # stack; if so, use the last (final) stack. Confirm against stacked_hourglass.model.
    if isinstance(output, (list, tuple)):
        output = output[-1]
    return output[0].cpu().numpy()  # drop the batch dimension


pth_path = './checkpoint/checkpoint.pth.tar'
model = hg2().cuda()
state_dict = torch.load(pth_path)['state_dict']
dict_ = _strip_module_prefix(state_dict)
model.load_state_dict(dict_)
# nn.Modules start in training mode; switch to eval so BatchNorm/Dropout layers
# (if any) use inference behaviour instead of per-batch statistics.
model.eval()

# glob() returns files in arbitrary order, so sort both lists to pair image i with
# annotation i. NOTE(review): assumes each image and its .json sort into the same
# position (e.g. they share a basename) — confirm for this dataset.
test_img_list = sorted(glob('./data/test/eye0_sub16/*.jpg'))
test_json_list = sorted(glob('./data/test/eye0_sub16/*.json'))
if not test_img_list:
    raise FileNotFoundError('no test images found in ./data/test/eye0_sub16')
if len(test_img_list) != len(test_json_list):
    raise ValueError(
        f'found {len(test_img_list)} images but {len(test_json_list)} annotation files'
    )

# Per-image mean keypoint error, measured in heatmap cells.
# NOTE(review): x and y are scaled by different factors when the source image is
# not square, so this distance is not isotropic in original-image pixels.
dist_list = []
for img_path, json_path in tqdm(zip(test_img_list, test_json_list), total=len(test_img_list)):
    img_ = cv2.imread(img_path)
    if img_ is None:  # cv2.imread signals failure by returning None, not raising
        raise FileNotFoundError(f'could not read image {img_path}')

    pred = _predict_heatmaps(model, img_)
    num_keypoints, hm_h, hm_w = pred.shape
    img_h, img_w = img_.shape[:2]

    # Map annotation coordinates (presumably original-image pixels) onto the
    # heatmap grid; for 640x480 images and 64x64 heatmaps this equals the former
    # hard-coded 64/640 and 64/480 ratios.
    orig_keypoints = _load_gt_keypoints(json_path, hm_w / img_w, hm_h / img_h)
    if len(orig_keypoints) != num_keypoints:
        raise ValueError(
            f'{json_path}: {len(orig_keypoints)} annotated keypoints but '
            f'{num_keypoints} heatmap channels'
        )

    # One peak per keypoint channel. getCoordFromHeatmap returns (row, col) == (y, x),
    # so reverse it to (x, y) to match the ground-truth layout.
    pred_coord = np.array([getCoordFromHeatmap(hm)[::-1] for hm in pred])

    dist = np.mean([Euclidian_distance(p, g) for p, g in zip(pred_coord, orig_keypoints)])
    dist_list.append(dist)

print(np.mean(dist_list))