# Import libraries

In [9]:
import copy
import os
import pprint

import ipywidgets as widgets
import matplotlib.pyplot as plt
import torch
from easydict import EasyDict as edict
from ipywidgets import interact
from IPython.display import display
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm

from src.constants import DATA_PATH, MASTER_THESIS_DIR, TRAINING_CONFIG_PATH, MODEL_CONFIG_PATH
from src.data_loader.data_set import Data_Set
from src.data_loader.utils import convert_2_5D_to_3D, convert_to_2_5D
from src.experiments.evaluation_utils import calculate_epe_statistics, evaluate
from src.experiments.utils import get_experiement_args, process_experiment_args
from src.models.baseline_model import BaselineModel
from src.utils import get_console_logger, read_json
from src.visualization.visualize import plot_hand
from src.experiments.evaluation_utils import calculate_epe_statistics, evaluate

# Load configuration

In [30]:
# Model hyper-parameters read from the project's model config file;
# the `gpu` flag is switched on for this notebook.
model_config = edict(read_json(MODEL_CONFIG_PATH))
model_config.gpu = True

# Data/training settings used below to build the dataset.
# NOTE(review): `crop` / `rotate` look like preprocessing/augmentation flags
# interpreted by Data_Set — confirm their exact meaning there.
train_config = edict(read_json(TRAINING_CONFIG_PATH))
train_config.crop = True
train_config.rotate = False

print(model_config)

{'alpha': 5, 'gpu': True, 'resnet_trainable': True, 'learning_rate': 0.0001, 'scheduler': {'choice': 'cosine_annealing', 'cosine_annealing': {'T_max': 20, 'verbose': True}, 'reduce_on_plateau': {'mode': 'min', 'factor': 0.5, 'patience': 4, 'verbose': True}}}


# Load data

In [32]:
# Build the training split; images are only converted to tensors (no
# normalization or other transforms in this pipeline).
train_data = Data_Set(
    config=train_config,
    transforms=transforms.Compose([transforms.ToTensor()]),
    train_set=True,
)

# The validation view is a shallow copy of the training dataset object,
# switched over with is_training(False). Requires `copy` from the import cell
# (it was previously missing, so this cell failed on a fresh kernel).
# NOTE(review): a shallow copy shares mutable attributes (e.g. the config)
# with train_data — confirm is_training() only rebinds attributes on val_data
# rather than mutating shared state.
val_data = copy.copy(train_data)
val_data.is_training(False)

# Load model

In [13]:
# Checkpoint of the trained baseline model to evaluate.
# NOTE(review): absolute, machine-specific path — consider deriving it from a
# configurable root (e.g. an environment variable) so the notebook is portable.
CHECKPOINT_PATH = (
    "/local/home/adahiya/Documents/master_thesis/models/models/master-thesis/"
    "54de390a2c1f455481fc94a0b2063cd8/checkpoints/epoch=260.ckpt"
)

# Use CUDA only when the config asks for it AND a GPU is actually present,
# instead of unconditionally requesting "cuda" (which crashes on CPU-only hosts).
device = torch.device(
    "cuda" if model_config.gpu and torch.cuda.is_available() else "cpu"
)

model = BaselineModel(config=model_config)
# map_location remaps tensors saved on another device (e.g. GPU) onto `device`,
# so the checkpoint also loads on machines without CUDA.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint["state_dict"])
model.eval()
# Trailing ';' suppresses the very long module repr in the cell output.
model.to(device);


BaselineModel(
  (resnet18): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tr

# Calculate results

In [33]:
# Evaluate the model on the validation split; the result maps metric names
# (2D/3D EPE means and medians) to scalar tensors.
results = evaluate(
    model,
    val_data,
    batch_size=32,
    num_workers=8,
)
pprint.pprint(results)

407it [00:03, 116.74it/s]
100%|██████████| 13024/13024 [00:03<00:00, 3962.00it/s]

{'Mean_EPE_2D': tensor(22.0244),
 'Mean_EPE_3D': tensor(0.4368),
 'Mean_EPE_3D_R': tensor(0.4180),
 'Mean_EPE_3D_R_v_3D': tensor(0.1832),
 'Median_EPE_2D': tensor(17.9482),
 'Median_EPE_3D': tensor(0.2279),
 'Median_EPE_3D_R': tensor(0.1830),
 'Median_EPE_3D_R_V_3D': tensor(0.1437)}





# Visualize sample predictions

In [None]:
# Move all model parameters and buffers back to host memory.
model.cpu()

In [29]:
@interact(
    idx=widgets.IntSlider(min=1, max=100, step=5, value=3),
    plot_gt=widgets.Checkbox(value=True, description='Ground truth'),
    plot_pred=widgets.Checkbox(value=True, description='Predicted_labels'),
)
def visualize(idx, plot_gt, plot_pred):
    """Show validation sample ``idx`` with optional ground-truth/predicted hand skeletons.

    Args:
        idx: Index into ``val_data`` (driven by the slider).
        plot_gt: If True, overlay the ground-truth joints (solid line style).
        plot_pred: If True, overlay the model's prediction (dotted line style).
    """
    sample = val_data[idx]
    joints = sample["joints"]
    img = sample["image"]

    # Run inference on whichever device the model currently lives on, so the
    # widget works whether the model was left on the GPU or moved to the CPU.
    # Previously a CPU tensor was fed to a possibly-CUDA model, which crashes.
    device = next(model.parameters()).device
    with torch.no_grad():
        batch = img.unsqueeze(0).to(device)  # add a leading batch dimension of 1
        prediction = model(batch).view(21, 3).cpu().numpy()

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(transforms.ToPILImage()(img))
    if plot_gt:
        plot_hand(ax, joints)
    if plot_pred:
        plot_hand(ax, prediction, linestyle=':')
    plt.show()

interactive(children=(IntSlider(value=3, description='idx', min=1, step=5), Checkbox(value=True, description='…