In [3]:
import copy
import os
import random
from tqdm.notebook import tqdm
from ipywidgets import interact
import ipywidgets as widgets
from IPython.display import display
import numpy as np
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CometLogger
from src.constants import DATA_PATH, MASTER_THESIS_DIR, TRAINING_CONFIG_PATH, MODEL_CONFIG_PATH
from src.data_loader.data_set import Data_Set
from src.experiments.utils import get_experiement_args, process_experiment_args
from src.models.baseline_model import BaselineModel
from src.utils import get_console_logger, read_json
from src.visualization.visualize import plot_hand
from torch.utils.data import DataLoader
from torchvision import transforms
from easydict import EasyDict as edict
import matplotlib.pyplot as plt

In [4]:
model_config = edict(read_json(MODEL_CONFIG_PATH))
model_config.gpu =False
train_config =  edict(read_json(TRAINING_CONFIG_PATH))
train_config.crop =True
train_config.rotate=True
print(model_config)

{'alpha': 5, 'gpu': False, 'resnet_trainable': True, 'learning_rate': 0.0001, 'scheduler': {'choice': 'cosine_annealing', 'cosine_annealing': {'T_max': 20, 'verbose': True}, 'reduce_on_plateau': {'mode': 'min', 'factor': 0.5, 'patience': 4, 'verbose': True}}}


In [5]:
train_data = Data_Set(
    config=train_config,
    transforms=transforms.Compose(
        [transforms.ToTensor()]
    ),
    train_set=True,
)
val_data = copy.copy(train_data)
val_data.is_training(False)

In [6]:
model = BaselineModel(config=model_config)

In [7]:
checkpoint = torch.load(os.path.join(MASTER_THESIS_DIR,"models","master-thesis","30a940c759aa43f1b19e22ecfff1621e","checkpoints","epoch=99.ckpt"), map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['state_dict'])
model.eval()

BaselineModel(
  (resnet18): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tr

In [8]:
@interact(idx=widgets.IntSlider(min=1, max=100, step=5, value=3))
def visualize(idx):
    sample = val_data[idx]
    joints = sample["joints"]
    img = sample["image"]
    img_input = img.view(([1]+list(img.shape)))
    prediction = model(img_input).view(21,3).detach().numpy()
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    plt.imshow(transforms.ToPILImage()(img))
#     plot_hand(ax, joints)
    plot_hand(ax, prediction,linestyle=':' )

interactive(children=(IntSlider(value=3, description='idx', min=1, step=5), Output()), _dom_classes=('widget-i…

In [21]:
val_pred_2d = []
val_pred_z = []
val_gt_z = []
val_gt_2d =[]
for i in tqdm(range(len(val_data)//100)):
    sample =train_data[i]
    joints = sample["joints"]
    img = sample["image"]
    img_input = img.view(([1]+list(img.shape)))
    pred = model(img_input).view(21,3)
    val_pred_2d.append(pred[:,:-1])
    val_pred_z.append(pred[:,-1:])
    val_gt_z.append(joints[:,-1:])
    val_gt_2d.append(joints[:,:-1])
    

HBox(children=(FloatProgress(value=0.0, max=130.0), HTML(value='')))




In [13]:
class EvalUtil:
    """ Util class for evaluation networks.
    """
    def __init__(self, num_kp=21):
        # init empty data storage
        self.data = list()
        self.num_kp = num_kp
        for _ in range(num_kp):
            self.data.append(list())

    def feed(self, keypoint_gt, keypoint_vis, keypoint_pred, skip_check=False):
        """ Used to feed data to the class. Stores the euclidean distance between gt and pred, when it is visible. """
        if not skip_check:
            keypoint_gt = np.squeeze(keypoint_gt)
            keypoint_pred = np.squeeze(keypoint_pred)
            keypoint_vis = np.squeeze(keypoint_vis).astype('bool')

            assert len(keypoint_gt.shape) == 2
            assert len(keypoint_pred.shape) == 2
            assert len(keypoint_vis.shape) == 1

        # calc euclidean distance
        diff = keypoint_gt - keypoint_pred
        euclidean_dist = np.sqrt(np.sum(np.square(diff), axis=1))

        num_kp = keypoint_gt.shape[0]
        for i in range(num_kp):
            if keypoint_vis[i]:
                self.data[i].append(euclidean_dist[i])

    def _get_pck(self, kp_id, threshold):
        """ Returns pck for one keypoint for the given threshold. """
        if len(self.data[kp_id]) == 0:
            return None

        data = np.array(self.data[kp_id])
        pck = np.mean((data <= threshold).astype('float'))
        return pck

    def _get_epe(self, kp_id):
        """ Returns end point error for one keypoint. """
        if len(self.data[kp_id]) == 0:
            return None, None

        data = np.array(self.data[kp_id])
        epe_mean = np.mean(data)
        epe_median = np.median(data)
        return epe_mean, epe_median

    def get_measures(self, val_min, val_max, steps):
        """ Outputs the average mean and median error as well as the pck score. """
        thresholds = np.linspace(val_min, val_max, steps)
        thresholds = np.array(thresholds)
        norm_factor = np.trapz(np.ones_like(thresholds), thresholds)

        # init mean measures
        epe_mean_all = list()
        epe_median_all = list()
        auc_all = list()
        pck_curve_all = list()

        # Create one plot for each part
        for part_id in range(self.num_kp):
            # mean/median error
            mean, median = self._get_epe(part_id)

            if mean is None:
                # there was no valid measurement for this keypoint
                continue

            epe_mean_all.append(mean)
            epe_median_all.append(median)

            # pck/auc
            pck_curve = list()
            for t in thresholds:
                pck = self._get_pck(part_id, t)
                pck_curve.append(pck)

            pck_curve = np.array(pck_curve)
            pck_curve_all.append(pck_curve)
            auc = np.trapz(pck_curve, thresholds)
            auc /= norm_factor
            auc_all.append(auc)

        epe_mean_all = np.mean(np.array(epe_mean_all))
        epe_median_all = np.mean(np.array(epe_median_all))
        auc_all = np.mean(np.array(auc_all))
        pck_curve_all = np.mean(np.array(pck_curve_all), 0)  # mean only over keypoints

        return epe_mean_all, epe_median_all, auc_all, pck_curve_all, thresholds

eval = EvalUtil()

In [22]:
eval.feed(val_gt_2d[0].detach().numpy(), np.ones(21), val_pred_2d[0].detach().numpy())

In [23]:
eval.get_measures(1,10, 10)

(4.2212625,
 4.2212625,
 0.6547619047619045,
 array([0.04761905, 0.14285714, 0.42857143, 0.57142857, 0.69047619,
        0.78571429, 0.88095238, 0.92857143, 0.95238095, 0.97619048]),
 array([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]))