# Import libraries

In [None]:
import os

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import display
from ipywidgets import interact
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CometLogger
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm

from src.constants import MASTER_THESIS_DIR, FREIHAND_DATA
from src.data_loader.freihand_loader import F_DB
from src.data_loader.utils import convert_2_5D_to_3D, get_root_depth, convert_to_2_5D
from src.models.baseline_model import BaselineModel
from src.utils import read_json
from src.visualization.visualize import plot_hand

In [None]:
from src.data_loader.joints import Joints

In [None]:
# Joint-ordering helper; visualize_sample uses its freihand_to_ait / ait_to_freihand mappings.
joints = Joints()

# Dataset

In [None]:
# FreiHAND training split: RGB images from training/rgb, labels from
# training_xyz.json and camera parameters from training_K.json.
# Images are only converted to tensors; mean/std normalization is currently off:
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
image_transform = transforms.Compose([transforms.ToTensor()])
f_db = F_DB(
    root_dir=os.path.join(FREIHAND_DATA, "training", "rgb"),
    labels_path=os.path.join(FREIHAND_DATA, "training_xyz.json"),
    camera_param_path=os.path.join(FREIHAND_DATA, "training_K.json"),
    transform=image_transform,
)

In [None]:
# Sanity check: lift the 2.5D joints of sample 0 back to 3D
# (compare with the ground-truth 3D joints printed in the next cell).
sample_0 = f_db[0]  # index once; each f_db[...] access presumably reloads the whole sample
convert_2_5D_to_3D(sample_0["joints"], sample_0["scale"], sample_0["K"])

In [None]:
# Ground-truth 3D joints of sample 0, for comparison with the lifted pose above.
first_sample_joints_3D = f_db[0]["joints_3D"]
print(first_sample_joints_3D)

# Label Visualization

In [None]:
# IntSlider's ``max`` is inclusive, so the last valid index is len(f_db) - 1.
@interact(id=widgets.IntSlider(min=0, max=len(f_db) - 1, step=1, value=179))
def visualize_sample(id):
    """Plot one sample's image, 2D pose, 3D pose and the 3D pose recreated from 2.5D.

    Args:
        id: index into ``f_db``. The parameter must stay named ``id`` because
            ``interact`` binds the slider above to it by keyword.

    Prints the element-wise difference between the ground-truth 3D joints and
    the pose recreated by ``convert_2_5D_to_3D``, followed by the L2 norm of
    that difference taken over all joints and coordinates.
    """
    sample = f_db[id]  # index once; each f_db[...] access presumably reloads the sample
    # Map the joints to the other ordering and back (presumably to check that
    # the two joint-order mappings are inverses of each other).
    A = joints.freihand_to_ait(joints.ait_to_freihand(sample["joints"]))
    s = sample["scale"]
    K = sample["K"]
    img = transforms.ToPILImage('RGB')(sample["image"])
    display(img)
    Axy = sample["joints_3D"]

    fig = plt.figure(figsize=(10, 10))
    gs = fig.add_gridspec(5, 8)
    ax1 = fig.add_subplot(gs[1:3, :2])
    ax1.set_title('2d pose')
    plot_hand(ax1, np.array(A))
    ax2 = fig.add_subplot(gs[:, 2:5], projection='3d')
    ax2.set_title('3D pose')
    plot_hand(ax2, np.array(Axy), plot_3d=True)
    ax3 = fig.add_subplot(gs[:, 5:], projection='3d')
    ax3.set_title('3D pose recreated')
    # Recreate the 3D pose from the 2.5D joints, scale and camera matrix.
    Axy_recreated = convert_2_5D_to_3D(A, s, K)
    plot_hand(ax3, np.array(Axy_recreated), plot_3d=True)

    diff = Axy - Axy_recreated
    print(diff)
    # Single scalar: L2 norm of the difference over every joint and coordinate.
    print(f"total error {torch.sum(diff ** 2) ** .5}")

    plt.show()

In [None]:
def error_in_conversion(true_joints_3D, K, joints_2_5D=None, sample_scale=None, eps=1e-8):
    """Return the worst-case relative error (percent) of the 2.5D -> 3D conversion.

    The 2.5D joints are lifted back to 3D with ``convert_2_5D_to_3D`` and
    compared element-wise with the ground-truth 3D joints.

    Args:
        true_joints_3D: ground-truth 3D joints (torch tensor).
        K: camera matrix, passed through to ``convert_2_5D_to_3D``.
        joints_2_5D: 2.5D joints to convert. If omitted, the notebook-global
            ``joints25D`` is used (the helper's original behaviour).
        sample_scale: scale passed to ``convert_2_5D_to_3D``. If omitted, the
            notebook-global ``scale`` is used.
        eps: lower bound on the denominator so that a ground-truth coordinate
            of exactly zero does not produce inf/nan.

    Returns:
        0-dim tensor: max over all coordinates of
        ``|recreated - true| / max(|true|, eps) * 100``.
    """
    # Backward compatibility: callers that pass only (true_joints_3D, K) rely on
    # the globals set by the scanning loop.
    if joints_2_5D is None:
        joints_2_5D = joints25D
    if sample_scale is None:
        sample_scale = scale
    recreated = convert_2_5D_to_3D(joints_2_5D, sample_scale, K)
    # Use the magnitude of the ground truth as the denominator: with a signed
    # denominator, errors on negative coordinates become negative "percentages"
    # that torch.max then ignores.
    denominator = torch.abs(true_joints_3D).clamp(min=eps)
    error_ratio = torch.abs(recreated - true_joints_3D) / denominator
    return torch.max(error_ratio) * 100

In [None]:
# Scan the dataset for samples whose 2.5D -> 3D round trip deviates strongly
# from the ground-truth 3D joints.
# NOTE(review): only the first len(f_db) // 4 samples are checked; presumably
# the remaining images share the same annotations. Confirm.
HIGH_ERROR_THRESHOLD = 50  # percent
# Running maximum relative error (percent). Float so it stacks with the
# per-sample float errors without relying on dtype promotion.
error = torch.tensor(0.0)
high_error_index = []
for idx in tqdm(range(len(f_db) // 4)):
    sample = f_db[idx]  # index once; each f_db[...] access presumably reloads the sample
    # error_in_conversion may read joints25D and scale from the notebook globals,
    # so keep assigning them here.
    joints25D = sample["joints"]
    scale = sample["scale"]
    K = sample["K"]
    true_joints_3D = sample["joints_3D"]
    sample_error = error_in_conversion(true_joints_3D, K)
    error = torch.max(torch.stack([sample_error, error]))
    # Flag a sample by its own error. Testing the running maximum would flag
    # every index after the first high-error sample.
    if sample_error > HIGH_ERROR_THRESHOLD:
        high_error_index.append(idx)