In [1]:
import os
import torch
import random
import pprint
import numpy as np
from loguru import logger
from torch.utils.data import DataLoader
from torch import nn

from pose.dataset import pose_dataset
from pose.utils import (
    collate_fn,
    geodesic_distance,
    relative_pose_error,
    relative_pose_error_np,
    recall_object,
    aggregate_metrics
)
from pose.model import Mkpts_Reg_Model
from pose.animator import Animator


if os.name == 'nt':
    LM_dataset_path = 'd:/git_project/POPE/data/LM_dataset/'
    LM_dataset_json_path = 'd:/git_project/POPE/data/pairs/LINEMOD-test.json'
    LM_dataset_points_path = 'd:/git_project/POPE/data/LM_dataset-points/'

    onepose_path = 'd:/git_project/POPE/data/onepose/'
    onepose_json_path = 'd:/git_project/POPE/data/pairs/Onepose-test.json'
    onepose_points_path = 'd:/git_project/POPE/data/onepose-points/'

    onepose_plusplus_path = 'd:/git_project/POPE/data/onepose_plusplus/'
    onepose_plusplus_json_path = 'd:/git_project/POPE/data/pairs/OneposePlusPlus-test.json'
    onepose_plusplus_points_path = 'd:/git_project/POPE/data/onepose_plusplus-points/'

    ycbv_path = 'd:/git_project/POPE/data/ycbv/'
    ycbv_json_path = 'd:/git_project/POPE/data/pairs/YCB-VIDEO-test.json'
    ycbv_points_path = 'd:/git_project/POPE/data/ycbv-points'
elif os.name == 'posix':
    LM_dataset_path = 'data/LM_dataset/'
    LM_dataset_json_path = 'data/pairs/LINEMOD-test.json'
    LM_dataset_points_path = 'data/LM_dataset-points/'

    onepose_path = 'data/onepose/'
    onepose_json_path = 'data/pairs/Onepose-test.json'
    onepose_points_path = 'data/onepose-points/'

    onepose_plusplus_path = 'data/onepose_plusplus/'
    onepose_plusplus_json_path = 'data/pairs/OneposePlusPlus-test.json'
    onepose_plusplus_points_path = 'data/onepose_plusplus-points/'

    ycbv_path = 'data/ycbv/'
    ycbv_json_path = 'data/pairs/YCB-VIDEO-test.json'
    ycbv_points_path = 'data/ycbv-points'

paths = [
    # ('linemod', LM_dataset_path, LM_dataset_json_path, LM_dataset_points_path),
    # ('onepose', onepose_path, onepose_json_path, onepose_points_path),
    ('onepose_plusplus', onepose_plusplus_path, onepose_plusplus_json_path, onepose_plusplus_points_path),
    # ('ycbv', ycbv_path, ycbv_json_path, ycbv_points_path),
]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataset = pose_dataset(paths)
mkpts_max_len, mkpts_sum_len = dataset.get_mkpts_info()

 56%|█████▌    | 5/9 [00:03<00:03,  1.29it/s]

data/onepose_plusplus-points/0706-teabox-box/mkpts0/820.png-761.png.txt does not exist


 89%|████████▉ | 8/9 [00:05<00:00,  1.53it/s]

data/onepose_plusplus-points/0712-insta-others/mkpts0/1570.png-125.png.txt does not exist
data/onepose_plusplus-points/0712-insta-others/mkpts0/1605.png-185.png.txt does not exist
data/onepose_plusplus-points/0712-insta-others/mkpts0/1628.png-210.png.txt does not exist


100%|██████████| 9/9 [00:06<00:00,  1.37it/s]


In [2]:
random.seed(20231223)
torch.manual_seed(20231223)
torch.cuda.manual_seed(20231223)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

In [3]:
num_sample = 300
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True, collate_fn=collate_fn(num_sample))
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True, drop_last=True, collate_fn=collate_fn(num_sample))

In [4]:
# linemod
# net = torch.load('./weights/linemod-relative_r-gt_t-quat-300-2024-03-08-16-56-16-0.2183.pth').to(device)
# onepose
# net = torch.load('./weights/onepose-relative_r-gt_t-6d-300-2024-03-02-16-10-39-0.2373.pth').to(device)
# onepose++
net = torch.load('./weights/relative_r-gt_t-6d-300-2024-03-12-23-54-12-0.2483.pth').to(device)
# ycbv
# net = torch.load('./weights/ycbv-relative_r-gt_t-6d-300-2024-03-02-17-01-40-6281.5225.pth').to(device)

net.eval()

Mkpts_Reg_Model(
  (embedding): Embedding()
  (transformerlayer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=76, out_features=76, bias=True)
    )
    (linear1): Linear(in_features=76, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=76, bias=True)
    (norm1): LayerNorm((76,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((76,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=76, out_features=76, bias=True)
        )
        (linear1): Linear(in_features=76, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inp

In [5]:
id2name_dict = {
    1: "ape",
    2: "benchvise",
    4: "camera",
    5: "can",
    6: "cat",
    8: "driller",
    9: "duck",
    10: "eggbox",
    11: "glue",
    12: "holepuncher",
    13: "iron",
    14: "lamp",
    15: "phone",
}

ycbv_dict = {
    1: 'one',
    2: 'two',
    3: 'three',
    4: 'four',
    5: 'five',
    6: 'six',
    7: 'seven',
    8: 'eight',
    9: 'nine',
    10: 'ten',
}

In [6]:
# linemod
ape_data = []
benchvise_data = []
camera_data = []
can_data = []
cat_data = []
driller_data = []
duck_data = []
eggbox_data = []
glue_data = []
holepuncher_data = []
iron_data = []
lamp_data = []
phone_data = []
# onepose
aptamil_data = []
jzhg_data = []
minipuff_data = []
hlyormosiapie_data = []
brownhouse_data = []
oreo_data = []
mfmilkcake_data = []
diycookies_data = []
taipingcookies_data = []
tee_data = []
# onepose++
toyrobot_data = []
yellowduck_data = []
sheep_data = []
fakebanana_data = []
teabox_data = []
orange_data = []
greenteapot_data = []
lecreusetcup_data = []
insta_data = []
# ycbv
one_data = []
two_data = []
three_data = []
four_data = []
five_data = []
six_data = []
seven_data = []
eight_data = []
nine_data = []
ten_data = []

all_data = {
    # linemod
    'ape_data': ape_data,
    'benchvise_data': benchvise_data,
    'camera_data': camera_data,
    'can_data': can_data,
    'cat_data': cat_data,
    'driller_data': driller_data,
    'duck_data': duck_data,
    'eggbox_data': eggbox_data,
    'glue_data': glue_data,
    'holepuncher_data': holepuncher_data,
    'iron_data': iron_data,
    'lamp_data': lamp_data,
    'phone_data': phone_data,
    # onepose
    'aptamil_data': aptamil_data,
    'jzhg_data': jzhg_data,
    'minipuff_data': minipuff_data,
    'hlyormosiapie_data': hlyormosiapie_data,
    'brownhouse_data': brownhouse_data,
    'oreo_data': oreo_data,
    'mfmilkcake_data': mfmilkcake_data,
    'diycookies_data': diycookies_data,
    'taipingcookies_data': taipingcookies_data,
    'tee_data': tee_data,
    # onepose++
    'toyrobot_data': toyrobot_data,
    'yellowduck_data': yellowduck_data,
    'sheep_data': sheep_data,
    'fakebanana_data': fakebanana_data,
    'teabox_data': teabox_data,
    'orange_data': orange_data,
    'greenteapot_data': greenteapot_data,
    'lecreusetcup_data': lecreusetcup_data,
    'insta_data': insta_data,
    # ycbv
    'one_data': one_data,
    'two_data': two_data,
    'three_data': three_data,
    'four_data': four_data,
    'five_data': five_data,
    'six_data': six_data,
    'seven_data': seven_data,
    'eight_data': eight_data,
    'nine_data': nine_data,
    'ten_data': ten_data,
}

In [7]:
linemod_type = ['ape_data', 'benchvise_data', 'camera_data', 'can_data', 'cat_data', 'driller_data', 'duck_data', 'eggbox_data', 'glue_data', 'holepuncher_data', 'iron_data', 'lamp_data', 'phone_data']
onepose_type = ['aptamil_data', 'jzhg_data', 'minipuff_data', 'hlyormosiapie_data', 'brownhouse_data', 'oreo_data', 'mfmilkcake_data', 'diycookies_data', 'taipingcookies_data', 'tee_data']
oneposeplusplus_type = ['toyrobot_data', 'yellowduck_data', 'sheep_data', 'fakebanana_data', 'teabox_data', 'orange_data', 'greenteapot_data', 'lecreusetcup_data', 'insta_data']
ycbv_type = ['one_data', 'two_data', 'three_data', 'four_data', 'five_data', 'six_data', 'seven_data', 'eight_data', 'nine_data', 'ten_data']

In [8]:
for i, batch in enumerate(test_dataloader):
    for data in batch:
        if 'lm' in data['name']:
            all_data[f"{id2name_dict[int(data['name'][2:])]}_data"].append(data)
        else:
            if data['name'] in ('12345678910'):
                all_data[f"{ycbv_dict[int(data['name'])]}_data"].append(data)
            else:
                all_data[f"{data['name']}_data"].append(data)

empty_keys = []
for key in all_data.keys():
    if len(all_data[key]) == 0:
        empty_keys.append(key)

for key in empty_keys:
    all_data.pop(key)

for key in all_data.keys():
    print(key, len(all_data[key]))

print('len(all_data):', len(all_data))

toyrobot_data 57
yellowduck_data 66
sheep_data 52
fakebanana_data 48
teabox_data 65
orange_data 75
greenteapot_data 40
lecreusetcup_data 47
insta_data 94
len(all_data): 9


In [9]:
res_table = []

model_type = 'relative_r-gt_t'

for key in all_data.keys():
    if key in linemod_type:
        logger.info(f"LINEMOD: {key}")
    elif key in onepose_type:
        logger.info(f"ONEPOSE: {key}")
    elif key in oneposeplusplus_type:
        logger.info(f"ONEPOSE++: {key}")
    elif key in ycbv_type:
        logger.info(f"YCBV: {key}")
    metrics = dict()
    metrics.update({'R_errs':[], 't_errs':[], 'inliers':[], "identifiers":[]})
    recall_image, all_image = 0, 0
    for item in all_data[key]:
        all_image += 1
        K0 = item['K0']
        K1 = item['K1']
        pose0 = item['pose0']
        pose1 = item['pose1']
        pre_bbox = item['pre_bbox']
        gt_bbox = item['gt_bbox']
        mkpts0 = item['mkpts0']
        mkpts1 = item['mkpts1']
        pre_K = item['pre_K']
        name = item['name']
        pair_name = item['pair_name']
        # linemod
        if 'lm' in name:
            name = id2name_dict[int(name[2:])]
        # ycbv
        if name in '12345678910':
            name = ycbv_dict[int(name)]

        if name not in key:
            print(f'name: {name}, key: {key}')
            continue

        is_recalled = recall_object(pre_bbox, gt_bbox)

        recall_image = recall_image + int(is_recalled > 0.5)

        batch_mkpts0 = torch.from_numpy(mkpts0).unsqueeze(0).float().to(device)
        batch_mkpts1 = torch.from_numpy(mkpts1).unsqueeze(0).float().to(device)
        pre_t, pre_rot = net(batch_mkpts0, batch_mkpts1)
        pre_t = pre_t.squeeze(0).detach().cpu().numpy()
        pre_rot = pre_rot.squeeze(0).detach().cpu().numpy()

        if model_type == 'gt':
            t_err, R_err = relative_pose_error_np(pose1, pre_rot, pre_t, ignore_gt_t_thr=0.0)
        elif model_type == 'relative':
            relative_pose = np.matmul(pose1, np.linalg.inv(pose0))
            t_err, R_err = relative_pose_error_np(relative_pose, pre_rot, pre_t, ignore_gt_t_thr=0.0)
        elif model_type == 'relative_r-gt_t':
            relative_pose = np.matmul(pose1, np.linalg.inv(pose0))
            gt_pose = np.zeros_like(pose1)
            gt_pose[:3, :3] = relative_pose[:3, :3]
            gt_pose[:3, 3] = pose1[:3, 3]
            t_err, R_err = relative_pose_error_np(gt_pose, pre_rot, pre_t, ignore_gt_t_thr=0.0)

        metrics['R_errs'].append(R_err)
        metrics['t_errs'].append(t_err)
        metrics['identifiers'].append(pair_name)

    print(f"Acc: {recall_image}/{all_image}")
    val_metrics_4tb = aggregate_metrics(metrics, 5e-4)
    val_metrics_4tb["AP50"] = recall_image / all_image
    logger.info('\n' + pprint.pformat(val_metrics_4tb))
    res_table.append([f"{name}"] + list(val_metrics_4tb.values()))

[32m2024-03-13 00:48:02.413[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mONEPOSE++: toyrobot_data[0m
[32m2024-03-13 00:48:02.832[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 57 unique items...[0m
[32m2024-03-13 00:48:02.834[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m70[0m - [1m
{'AP50': 0.9824561403508771,
 'R:ACC15': 0.5087719298245614,
 'R:ACC30': 0.9824561403508771,
 'R:auc@15': 0.20297678838112,
 'R:auc@30': 0.48179159283018097,
 'R:meanErr': 15.842011862003424,
 'R:medianErr': 14.453136257675741,
 't:ACC15': 0.9473684210526315,
 't:ACC30': 1.0,
 't:auc@15': 0.4657032176837489,
 't:auc@30': 0.727030897983841,
 't:meanErr': 8.383617111102335,
 't:medianErr': 8.018844976112238}[0m
[32m2024-03-13 00:48:02.835[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mONEPOSE++: yellowduck_data[0m


Acc: 56/57


[32m2024-03-13 00:48:03.168[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 66 unique items...[0m
[32m2024-03-13 00:48:03.170[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m70[0m - [1m
{'AP50': 1.0,
 'R:ACC15': 0.6060606060606061,
 'R:ACC30': 0.9696969696969697,
 'R:auc@15': 0.2984369333909168,
 'R:auc@30': 0.5841621653071588,
 'R:meanErr': 12.741589157544135,
 'R:medianErr': 12.828962824503257,
 't:ACC15': 0.9696969696969697,
 't:ACC30': 1.0,
 't:auc@15': 0.5518080398809759,
 't:auc@30': 0.7750880579396106,
 't:meanErr': 6.885326730929353,
 't:medianErr': 6.655576742255706}[0m
[32m2024-03-13 00:48:03.170[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mONEPOSE++: sheep_data[0m


Acc: 66/66


[32m2024-03-13 00:48:03.433[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 52 unique items...[0m
[32m2024-03-13 00:48:03.434[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m70[0m - [1m
{'AP50': 0.36538461538461536,
 'R:ACC15': 0.5384615384615384,
 'R:ACC30': 0.9423076923076923,
 'R:auc@15': 0.26540674959730526,
 'R:auc@30': 0.49370699754714614,
 'R:meanErr': 15.876070145436223,
 'R:medianErr': 13.945945228978177,
 't:ACC15': 0.9423076923076923,
 't:ACC30': 1.0,
 't:auc@15': 0.5854631636768445,
 't:auc@30': 0.7860240272615172,
 't:meanErr': 6.6888499872237,
 't:medianErr': 5.530322712702267}[0m
[32m2024-03-13 00:48:03.435[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mONEPOSE++: fakebanana_data[0m


Acc: 19/52


[32m2024-03-13 00:48:03.677[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 48 unique items...[0m
[32m2024-03-13 00:48:03.678[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m70[0m - [1m
{'AP50': 0.9791666666666666,
 'R:ACC15': 0.5,
 'R:ACC30': 0.9375,
 'R:auc@15': 0.17332172304937277,
 'R:auc@30': 0.470436884023851,
 'R:meanErr': 16.40575413574921,
 'R:medianErr': 14.823982375801897,
 't:ACC15': 0.9583333333333334,
 't:ACC30': 1.0,
 't:auc@15': 0.5192690459662292,
 't:auc@30': 0.7587681692347599,
 't:meanErr': 7.416296215210778,
 't:medianErr': 6.952786227097967}[0m
[32m2024-03-13 00:48:03.679[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mONEPOSE++: teabox_data[0m


Acc: 47/48


[32m2024-03-13 00:48:04.007[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 65 unique items...[0m
[32m2024-03-13 00:48:04.008[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m70[0m - [1m
{'AP50': 0.9384615384615385,
 'R:ACC15': 0.5538461538461539,
 'R:ACC30': 0.9692307692307692,
 'R:auc@15': 0.29965172011741764,
 'R:auc@30': 0.5544385138773578,
 'R:meanErr': 13.78980077120767,
 'R:medianErr': 13.280419293390842,
 't:ACC15': 0.9846153846153847,
 't:ACC30': 1.0,
 't:auc@15': 0.4807467854717978,
 't:auc@30': 0.7403752592078569,
 't:meanErr': 7.904814717660415,
 't:medianErr': 7.6722868045121}[0m
[32m2024-03-13 00:48:04.009[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mONEPOSE++: orange_data[0m


Acc: 61/65


[32m2024-03-13 00:48:04.389[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 75 unique items...[0m
[32m2024-03-13 00:48:04.390[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m70[0m - [1m
{'AP50': 0.9866666666666667,
 'R:ACC15': 0.5733333333333334,
 'R:ACC30': 0.9733333333333334,
 'R:auc@15': 0.23526939481390752,
 'R:auc@30': 0.5292313588317142,
 'R:meanErr': 14.578011954926891,
 'R:medianErr': 13.534363443821134,
 't:ACC15': 0.9866666666666667,
 't:ACC30': 1.0,
 't:auc@15': 0.6454604736085439,
 't:auc@30': 0.8216590126350332,
 't:meanErr': 5.490509977669284,
 't:medianErr': 4.835217165093264}[0m
[32m2024-03-13 00:48:04.391[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mONEPOSE++: greenteapot_data[0m


Acc: 74/75


[32m2024-03-13 00:48:04.597[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 40 unique items...[0m
[32m2024-03-13 00:48:04.598[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m70[0m - [1m
{'AP50': 0.925,
 'R:ACC15': 0.5,
 'R:ACC30': 0.95,
 'R:auc@15': 0.22146638491559967,
 'R:auc@30': 0.4874353741818929,
 'R:meanErr': 15.870033111994285,
 'R:medianErr': 15.29933426565154,
 't:ACC15': 0.975,
 't:ACC30': 1.0,
 't:auc@15': 0.4493896827103111,
 't:auc@30': 0.7248295640251455,
 't:meanErr': 8.455589935586321,
 't:medianErr': 8.749679991707387}[0m
[32m2024-03-13 00:48:04.599[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mONEPOSE++: lecreusetcup_data[0m


Acc: 37/40


[32m2024-03-13 00:48:04.836[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 47 unique items...[0m
[32m2024-03-13 00:48:04.838[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m70[0m - [1m
{'AP50': 0.851063829787234,
 'R:ACC15': 0.5957446808510638,
 'R:ACC30': 0.9787234042553191,
 'R:auc@15': 0.279969254710428,
 'R:auc@30': 0.545231469097702,
 'R:meanErr': 14.268790332489079,
 'R:medianErr': 13.753004609111962,
 't:ACC15': 0.8297872340425532,
 't:ACC30': 1.0,
 't:auc@15': 0.4008195398236183,
 't:auc@30': 0.688081868702586,
 't:meanErr': 9.562268698301692,
 't:medianErr': 8.94957894454996}[0m
[32m2024-03-13 00:48:04.838[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mONEPOSE++: insta_data[0m


Acc: 40/47


[32m2024-03-13 00:48:05.314[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 94 unique items...[0m
[32m2024-03-13 00:48:05.315[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m70[0m - [1m
{'AP50': 0.5638297872340425,
 'R:ACC15': 0.43617021276595747,
 'R:ACC30': 0.8829787234042553,
 'R:auc@15': 0.19973920487606375,
 'R:auc@30': 0.44601708260661554,
 'R:meanErr': 17.681497734857814,
 'R:medianErr': 17.02191408933925,
 't:ACC15': 0.8936170212765957,
 't:ACC30': 1.0,
 't:auc@15': 0.44198483702584074,
 't:auc@30': 0.7089585433221335,
 't:meanErr': 8.879315460452167,
 't:medianErr': 8.08998651100503}[0m


Acc: 53/94


In [10]:
for i, key in enumerate(all_data.keys()):
    if i == 1:
        break
    if key in linemod_type:
        print(f"LINEMOD")
    elif key in onepose_type:
        print(f"ONEPOSE")
    elif key in oneposeplusplus_type:
        print(f"ONEPOSE++")
    elif key in ycbv_type:
        print(f"YCBV")

from tabulate import tabulate
headers = ["Category"] + list(val_metrics_4tb.keys())
all_data = np.array(res_table)[:, 1:].astype(np.float32)
res_table.append(["Avg"] + all_data.mean(0).tolist())
print(tabulate(res_table, headers=headers, tablefmt='fancy_grid'))

ONEPOSE++
╒══════════════╤════════════╤════════════╤═══════════╤═══════════╤═══════════════╤═════════════╤════════════╤════════════╤═══════════╤═══════════╤═══════════════╤═════════════╤══════════╕
│ Category     │   R:auc@15 │   R:auc@30 │   R:ACC15 │   R:ACC30 │   R:medianErr │   R:meanErr │   t:auc@15 │   t:auc@30 │   t:ACC15 │   t:ACC30 │   t:medianErr │   t:meanErr │     AP50 │
╞══════════════╪════════════╪════════════╪═══════════╪═══════════╪═══════════════╪═════════════╪════════════╪════════════╪═══════════╪═══════════╪═══════════════╪═════════════╪══════════╡
│ toyrobot     │   0.202977 │   0.481792 │  0.508772 │  0.982456 │       14.4531 │     15.842  │   0.465703 │   0.727031 │  0.947368 │         1 │       8.01884 │     8.38362 │ 0.982456 │
├──────────────┼────────────┼────────────┼───────────┼───────────┼───────────────┼─────────────┼────────────┼────────────┼───────────┼───────────┼───────────────┼─────────────┼──────────┤
│ yellowduck   │   0.298437 │   0.584162 │  0.6060