# Import libraries

In [1]:
import copy
import os
import pprint

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import torch
from easydict import EasyDict as edict
from IPython.display import display
from ipywidgets import interact
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm

from src.constants import DATA_PATH, MASTER_THESIS_DIR, TRAINING_CONFIG_PATH, MODEL_CONFIG_PATH
from src.data_loader.data_set import Data_Set
from src.data_loader.utils import convert_2_5D_to_3D, convert_to_2_5D
from src.data_loader.utils import get_train_val_split
from src.experiments.evaluation_utils import calculate_epe_statistics, evaluate,get_predictions_and_ground_truth, cal_auc_joints, get_pck_curves
from src.experiments.utils import get_experiement_args, process_experiment_args
from src.models.baseline_model import BaselineModel
from src.models.supervised_head_model import SupervisedHead
from src.utils import get_console_logger, read_json
from src.visualization.visualize import plot_hand

# Load configuration

In [2]:
# Model hyper-parameters come from the SimCLR experiment config inside the repository.
model_config = edict(read_json(os.path.join(MASTER_THESIS_DIR, "src", "experiments", "simclr_config.json")))

# Inverted "ait" table from joint_mapping.json: each value in the file becomes a key and
# each key becomes the value. The file is located relative to MASTER_THESIS_DIR, like the
# config above, instead of through a machine-specific absolute path.
# NOTE(review): assumes MASTER_THESIS_DIR is the repository root that contains src/ — confirm.
joints_mapping = {
    v: k
    for k, v in read_json(
        os.path.join(MASTER_THESIS_DIR, "src", "data_loader", "joint_mapping.json")
    )["ait"].items()
}
model_config.gpu = True

# Training configuration, with the augmentation flags used in this notebook:
# crop and resize enabled, rotation disabled.
train_config = edict(read_json(TRAINING_CONFIG_PATH))
train_config.augmentation_flags.crop = True
train_config.augmentation_flags.resize = True
train_config.augmentation_flags.rotate = False
print(model_config)
train_config.num_workers = 8
train_config.epochs = 150
model_config.warmup_epochs = 10

# Last expression of the cell: display both configurations.
train_config, model_config

{'batch_size': 128, 'lr': 0.0001, 'opt_weight_decay': 0.0001, 'output_dim': 128, 'projection_head_hidden_dim': 2048, 'projection_head_input_dim': 512, 'warmup_epochs': 10, 'gpu': True}


({'batch_size': 128,
  'epochs': 150,
  'train_ratio': 0.9,
  'gpu': True,
  'num_workers': 8,
  'seed': 5,
  'augmentation_flags': {'color_drop': False,
   'color_jitter': False,
   'crop': True,
   'cut_out': False,
   'flip': False,
   'gaussian_blur': False,
   'random_crop': False,
   'resize': True,
   'rotate': False},
  'augmentation_params': {'crop_margin': 1.5,
   'crop_margin_range': [0.8, 1.5],
   'cut_out_fraction': [0.05, 0.2],
   'hue_factor_range': [-0.5, 0.5],
   'max_angle': 359,
   'min_angle': 0,
   'resize_shape': [128, 128],
   'sat_factor_range': [-0.5, 0.5],
   'value_factor_alpha_range': [0.5, 1],
   'value_factor_beta_range': [10, 50]}},
 {'batch_size': 128,
  'lr': 0.0001,
  'opt_weight_decay': 0.0001,
  'output_dim': 128,
  'projection_head_hidden_dim': 2048,
  'projection_head_input_dim': 512,
  'warmup_epochs': 10,
  'gpu': True})

# Load data

In [3]:
# Data set built from train_config; the only transform applied here is ToTensor.
to_tensor = transforms.Compose([transforms.ToTensor()])
train_data = Data_Set(config=train_config, transform=to_tensor, train_set=True)

# Validation view: a shallow copy of the training data set with training mode switched off.
# NOTE(review): copy.copy is shallow, so mutable attributes are shared between train_data
# and val_data — confirm that is_training(False) does not alter state train_data relies on.
val_data = copy.copy(train_data)
val_data.is_training(False)

# Load model

In [25]:
# supervised model
# Builds the baseline model and restores the weights stored in the epoch=284 checkpoint.
model = BaselineModel(config=model_config)

# The checkpoint is located relative to MASTER_THESIS_DIR rather than through a
# machine-specific absolute path.
# NOTE(review): assumes MASTER_THESIS_DIR is the repository root that contains data/models — confirm.
baseline_checkpoint_path = os.path.join(
    MASTER_THESIS_DIR,
    "data",
    "models",
    "master-thesis",
    "8364a50940f44dc3963173d0f7b3c68a",
    "checkpoints",
    "epoch=284.ckpt",
)
# map_location="cpu" makes the load independent of the device the checkpoint was saved
# from; load_state_dict then copies the values into the model's own parameters.
checkpoint = torch.load(baseline_checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
model.eval()

# Use the GPU when one is present, otherwise stay on the CPU instead of raising.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print()  # last statement returns None, so the module repr is not echoed as cell output




In [15]:
# semi supervised model
# Checkpoints are located relative to MASTER_THESIS_DIR rather than through
# machine-specific absolute paths.
# NOTE(review): assumes MASTER_THESIS_DIR is the repository root that contains data/models — confirm.
checkpoints_root = os.path.join(MASTER_THESIS_DIR, "data", "models", "master-thesis")
encoder_checkpoint_path = os.path.join(
    checkpoints_root, "14910753afb9499ab9229174d6377efb", "checkpoints", "epoch=999.ckpt"
)
head_checkpoint_path = os.path.join(
    checkpoints_root, "abf5b702ea4c4ad78703b6f81221acd4", "checkpoints", "epoch=149.ckpt"
)

# NOTE(review): model_config is passed as both the first and the third positional
# argument — confirm against the SupervisedHead signature that this is intended.
model = SupervisedHead(
    model_config,
    encoder_checkpoint_path,
    model_config,
)
# map_location="cpu" makes the load independent of the device the checkpoint was saved
# from; load_state_dict then copies the values into the model's own parameters.
checkpoint = torch.load(head_checkpoint_path, map_location="cpu")["state_dict"]
model.load_state_dict(checkpoint)
model.eval()

SupervisedHead(
  (encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tr

# Calculate results

In [16]:
# strong augmentations
# NOTE(review): nothing in this cell selects an augmentation setting; the numbers depend on
# whatever `model` and `val_data` hold in the kernel when the cell runs — confirm the setup.
eval_kwargs = {"num_workers": 8, "batch_size": 32}

# Aggregate error metrics for the current model on the validation data.
results = evaluate(model, val_data, **eval_kwargs)
pprint.pprint(results)

# 3D predictions and ground truth, used for the per-joint PCK curves and AUC.
pred_dict = get_predictions_and_ground_truth(model, val_data, **eval_kwargs)
pred, gt = pred_dict["predictions_3d"], pred_dict["ground_truth_3d"]
sts = calculate_epe_statistics(pred, gt, dim=3)
eucledian_dist = sts["eucledian_dist"]

# x, y and auc are kept as notebook-level names; the PCK plotting cell reads them.
y, x = get_pck_curves(eucledian_dist, per_joint=True)
auc = cal_auc_joints(eucledian_dist, per_joint=True)
print(f"AUC : {np.mean(auc)}")

407it [00:03, 118.68it/s]
100%|██████████| 13024/13024 [00:03<00:00, 3975.73it/s]

{'Mean_EPE_2D': tensor(13.4872),
 'Mean_EPE_3D': tensor(0.3836),
 'Mean_EPE_3D_R': tensor(0.3444),
 'Mean_EPE_3D_R_v_3D': tensor(0.1832),
 'Median_EPE_2D': tensor(10.5245),
 'Median_EPE_3D': tensor(0.1891),
 'Median_EPE_3D_R': tensor(0.1308),
 'Median_EPE_3D_R_V_3D': tensor(0.1437)}



407it [00:03, 118.92it/s]
100%|██████████| 13024/13024 [00:03<00:00, 3800.39it/s]


AUC : 0.50091322308425


In [14]:
# weak augmentations
# NOTE(review): nothing in this cell selects an augmentation setting; the numbers depend on
# whatever `model` and `val_data` hold in the kernel when the cell runs — confirm the setup.
eval_kwargs = {"num_workers": 8, "batch_size": 32}

# Aggregate error metrics for the current model on the validation data.
results = evaluate(model, val_data, **eval_kwargs)
pprint.pprint(results)

# 3D predictions and ground truth, used for the per-joint PCK curves and AUC.
pred_dict = get_predictions_and_ground_truth(model, val_data, **eval_kwargs)
pred, gt = pred_dict["predictions_3d"], pred_dict["ground_truth_3d"]
sts = calculate_epe_statistics(pred, gt, dim=3)
eucledian_dist = sts["eucledian_dist"]

# x, y and auc are kept as notebook-level names; the PCK plotting cell reads them.
y, x = get_pck_curves(eucledian_dist, per_joint=True)
auc = cal_auc_joints(eucledian_dist, per_joint=True)
print(f"AUC : {np.mean(auc)}")

407it [00:03, 120.07it/s]
100%|██████████| 13024/13024 [00:03<00:00, 4019.47it/s]

{'Mean_EPE_2D': tensor(27.1314),
 'Mean_EPE_3D': tensor(0.9543),
 'Mean_EPE_3D_R': tensor(0.9589),
 'Mean_EPE_3D_R_v_3D': tensor(0.1832),
 'Median_EPE_2D': tensor(24.9557),
 'Median_EPE_3D': tensor(0.4825),
 'Median_EPE_3D_R': tensor(0.4832),
 'Median_EPE_3D_R_V_3D': tensor(0.1437)}



407it [00:03, 120.44it/s]
100%|██████████| 13024/13024 [00:03<00:00, 3991.53it/s]


AUC : 0.24114521374102332


# Visualize sample predictions

In [18]:
# Move the model to the CPU before the interactive viewer calls it on single samples.
cpu_device = torch.device("cpu")
model.to(cpu_device)
print()  # last statement returns None, so the module repr is not echoed as cell output




In [19]:
@interact(
    idx=widgets.IntSlider(min=1, max=100, step=5, value=3),
    plot_gt=widgets.Checkbox(value=True, description='Ground truth'),
    plot_pred=widgets.Checkbox(value=True, description='Predicted_labels'),
)
def visualize1(idx, plot_gt, plot_pred):
    """Show one validation image with ground-truth and/or predicted hand joints.

    Args:
        idx: Index into ``val_data`` of the sample to display.
        plot_gt: When true, draw the joints stored with the sample.
        plot_pred: When true, draw the model's prediction using a dotted line style.
    """
    sample = val_data[idx]
    joints = sample["joints"]
    img = sample["image"]

    # Run on whatever device the model currently lives on, so this cell does not
    # depend on the model having been moved to the CPU beforehand.
    model_device = next(model.parameters()).device
    with torch.no_grad():  # pure inference: no autograd graph is needed
        # unsqueeze(0) adds the batch dimension the model call expects.
        img_input = img.unsqueeze(0).to(model_device)
        prediction = model(img_input).view(21, 3).cpu().numpy()

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(transforms.ToPILImage()(img))
    if plot_gt:
        plot_hand(ax, joints)
    if plot_pred:
        plot_hand(ax, prediction, linestyle=':')
    plt.show()

interactive(children=(IntSlider(value=3, description='idx', min=1, step=5), Checkbox(value=True, description='…

In [8]:
@interact(joint_id=widgets.IntSlider(min=0, max=20, step=1, value=3))
def visualize2(joint_id):
    """Plot the PCK curve of one joint together with the average over all joints.

    Reads the notebook-level ``x``, ``y`` and ``auc`` produced by the evaluation cell
    and ``joints_mapping`` from the configuration cell.

    Args:
        joint_id: Index of the joint whose curve and AUC are shown.
    """
    fig, ax = plt.subplots(figsize=(8, 8))

    # Selected joint first, then the mean curve as a faint dotted reference.
    ax.plot(x, y[joint_id])
    ax.set_title(f"PCK curve, AUC: {auc[joint_id]}")
    ax.plot(x, np.mean(y, axis=0), color='black', linestyle=":", alpha=0.5)

    ax.set_xlabel("error in mm")
    ax.set_ylabel("Ratio of points below the error")
    ax.legend([f"PCK curve for {joints_mapping[joint_id]}", "Average PCK curve"])
    plt.show()

interactive(children=(IntSlider(value=3, description='joint_id', max=20), Output()), _dom_classes=('widget-int…