# Import libraries

In [None]:
import copy
import os
import pprint

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import torch
from easydict import EasyDict as edict
from IPython.display import display
from ipywidgets import interact
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm

from src.constants import DATA_PATH, MASTER_THESIS_DIR, TRAINING_CONFIG_PATH, MODEL_CONFIG_PATH
from src.data_loader.data_set import Data_Set
from src.data_loader.utils import convert_2_5D_to_3D, convert_to_2_5D
from src.data_loader.utils import get_train_val_split
from src.experiments.evaluation_utils import (
    calculate_epe_statistics,
    evaluate,
    get_predictions_and_ground_truth,
    cal_auc_joints,
    get_pck_curves,
)
from src.experiments.utils import get_experiement_args, process_experiment_args
from src.models.baseline_model import BaselineModel
from src.models.supervised_head_model import SupervisedHead
from src.utils import get_console_logger, read_json
from src.visualization.visualize import plot_hand

# Load configuration

In [None]:
# Model config: SimCLR experiment settings plus local overrides.
model_config = edict(
    read_json(os.path.join(MASTER_THESIS_DIR, "src", "experiments", "config", "simclr_config.json"))
)
# Invert the "ait" joint mapping so it is keyed by the mapped value (the
# PCK plot below indexes it with an integer joint id — presumably the JSON
# values are joint indices; confirm against joint_mapping.json).
# The path is built from MASTER_THESIS_DIR, like the config path above,
# instead of a hard-coded home directory.
joints_mapping = {
    v: k
    for k, v in read_json(
        os.path.join(MASTER_THESIS_DIR, "src", "data_loader", "joint_mapping.json")
    )["ait"].items()
}
model_config.gpu = True
model_config.warmup_epochs = 10

# Training/data config: crop and resize augmentations on, rotation off.
train_config = edict(read_json(TRAINING_CONFIG_PATH))
train_config.augmentation_flags.crop = True
train_config.augmentation_flags.resize = True
train_config.augmentation_flags.rotate = False
train_config.num_workers = 8
train_config.epochs = 150

# Display both configs as the cell output, after every override above has
# been applied (an earlier print showed model_config before warmup_epochs).
train_config, model_config

# Load data

In [None]:
# Training split; the only explicit transform is the PIL/ndarray -> tensor
# conversion (any augmentation is presumably driven by
# train_config.augmentation_flags inside Data_Set — confirm).
train_data = Data_Set(
    train_set=True,
    config=train_config,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# Validation view: a shallow copy of the same dataset object, switched out of
# training mode.
# NOTE(review): copy.copy shares mutable attributes with train_data; confirm
# that is_training() only rebinds attributes on val_data itself.
val_data = copy.copy(train_data)
val_data.is_training(False)

## Sanity check against all-zero predictions
Only the 3D joints are checked.

In [None]:
# Reference point for the 3D error metric: score an all-zero prediction
# against every validation ground-truth "joints3D" tensor.
print("Reading the data")
joints3d = [val_data[idx]["joints3D"] for idx in tqdm(range(len(val_data)))]
Joints3d = torch.stack(joints3d)
Joints3d_zero = torch.zeros_like(Joints3d)
stats = calculate_epe_statistics(Joints3d_zero, Joints3d, 3)
print(stats)

# Load model

In [None]:
# supervised model: restore a trained BaselineModel checkpoint and put it on the GPU.
device = torch.device("cuda")
model = BaselineModel(config=model_config)
# Load the checkpoint tensors onto the CPU so restoring does not depend on the
# device the checkpoint was saved from; the model is moved to `device` below.
checkpoint = torch.load(
    "/local/home/adahiya/Documents/master_thesis/data/models/master-thesis/8364a50940f44dc3963173d0f7b3c68a/checkpoints/epoch=284.ckpt",
    map_location="cpu",
)
model.load_state_dict(checkpoint["state_dict"])
model.eval()
# Module.to() returns the module itself; binding it keeps Jupyter from echoing
# the whole module repr as the cell output.
model = model.to(device)

In [None]:
# semi supervised model: a SupervisedHead constructed from model_config and an
# existing checkpoint path (presumably the pretrained encoder — confirm), then
# restored from its own fine-tuned checkpoint.
# NOTE(review): model_config is passed as both the first and third argument;
# confirm the third parameter is meant to receive the same config.
model = SupervisedHead(
    model_config,
    "/local/home/adahiya/Documents/master_thesis/data/models/master-thesis/14910753afb9499ab9229174d6377efb/checkpoints/epoch=999.ckpt",
    model_config,
)
# Load onto the CPU so restoring does not depend on the device the checkpoint
# was saved from.
checkpoint = torch.load(
    "/local/home/adahiya/Documents/master_thesis/data/models/master-thesis/abf5b702ea4c4ad78703b6f81221acd4/checkpoints/epoch=149.ckpt",
    map_location="cpu",
)["state_dict"]
model.load_state_dict(checkpoint)
# NOTE(review): unlike the supervised cell, this model is not moved to the GPU
# here — confirm evaluate() handles device placement.
model.eval()

# Calculate results

In [None]:
# strong augmentations
# NOTE(review): this cell does not change any augmentation setting itself; the
# label presumably reflects how train_config / val_data were configured when it
# was run.
results = evaluate(model, val_data, num_workers=8, batch_size=32)
pprint.pprint(results)

# Per-joint PCK curves and AUC computed from the raw 3D predictions.
pred_dict = get_predictions_and_ground_truth(model, val_data, num_workers=8, batch_size=32)
pred = pred_dict["predictions_3d"]
gt = pred_dict["ground_truth_3d"]
sts = calculate_epe_statistics(pred, gt, dim=3)
eucledian_dist = sts["eucledian_dist"]  # key spelling comes from calculate_epe_statistics
y, x = get_pck_curves(eucledian_dist, per_joint=True)
auc = cal_auc_joints(eucledian_dist, per_joint=True)
# Mean over the per-joint AUC values; requires numpy imported as np.
print(f"AUC : {np.mean(auc)}")

In [None]:
# weak augmentations
# NOTE(review): this cell is identical to the "strong augmentations" one and
# changes no augmentation setting; the label presumably reflects a config edit
# made before re-running — confirm.
results = evaluate(model, val_data, num_workers=8, batch_size=32)
pprint.pprint(results)

# Per-joint PCK curves and AUC computed from the raw 3D predictions.
pred_dict = get_predictions_and_ground_truth(model, val_data, num_workers=8, batch_size=32)
pred = pred_dict["predictions_3d"]
gt = pred_dict["ground_truth_3d"]
sts = calculate_epe_statistics(pred, gt, dim=3)
eucledian_dist = sts["eucledian_dist"]  # key spelling comes from calculate_epe_statistics
y, x = get_pck_curves(eucledian_dist, per_joint=True)
auc = cal_auc_joints(eucledian_dist, per_joint=True)
# Mean over the per-joint AUC values; requires numpy imported as np.
print(f"AUC : {np.mean(auc)}")

# Visualize sample predictions

In [None]:
# Move the model to the CPU for the interactive plots below. Module.to()
# returns the module itself, and binding it suppresses Jupyter's echo of the
# module repr without printing a blank line.
model = model.to(torch.device("cpu"))

In [None]:
@interact(
    idx=widgets.IntSlider(min=1, max=100, step=5, value=3),
    plot_gt=widgets.Checkbox(value=True, description='Ground truth'),
    plot_pred=widgets.Checkbox(value=True, description='Predicted_labels'),
)
def visualize1(idx, plot_gt, plot_pred):
    """Overlay ground-truth and/or predicted hand joints on validation image ``idx``.

    Args:
        idx: index into ``val_data``.
        plot_gt: draw ``sample["joints"]`` with the default line style.
        plot_pred: draw the model prediction with dotted lines.
    """
    sample = val_data[idx]
    joints = sample["joints"]
    img = sample["image"]
    # Run a single-image batch on whichever device the model currently lives
    # on; no_grad skips building an autograd graph for this inference call.
    model_device = next(model.parameters()).device
    with torch.no_grad():
        output = model(img.unsqueeze(0).to(model_device))
    # Bring the result back to host memory before converting to numpy.
    prediction = output.view(21, 3).cpu().numpy()
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111)
    plt.imshow(transforms.ToPILImage()(img))
    if plot_gt:
        plot_hand(ax, joints)
    if plot_pred:
        plot_hand(ax, prediction, linestyle=':')
    plt.show()

In [None]:
@interact(joint_id=widgets.IntSlider(min=0, max=20, step=1, value=3))
def visualize2(joint_id):
    """Plot one joint's PCK curve against the mean PCK curve over all joints.

    Reads ``x``, ``y`` and ``auc`` produced by the evaluation cells and
    ``joints_mapping`` from the configuration cell; requires numpy as ``np``.

    Args:
        joint_id: row index into the per-joint ``y`` curves and ``auc`` values.
    """
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111)
    ax.plot(x, y[joint_id])
    # Dotted, faded reference: the PCK curve averaged over all joints.
    ax.plot(x, np.mean(y, axis=0), color='black', linestyle=":", alpha=0.5)
    ax.set_title(f"PCK curve, AUC: {auc[joint_id]}")
    ax.set_xlabel("error in mm")
    ax.set_ylabel("Ratio of points below the error")
    ax.legend([f"PCK curve for {joints_mapping[joint_id]}", "Average PCK curve"])
    plt.show()