In [1]:
import os
import cv2
import json
import torch
import pprint
import numpy as np
from tqdm import tqdm
from loguru import logger
from tabulate import tabulate

from pose.utils import collate_fn, geodesic_distance, relative_pose_error, aggregate_metrics, recall_object, project_points

In [2]:
# Read the OnePose test-pair manifest. Each entry is later indexed by rotation
# key (e.g. "0") to obtain a list of pair names. The trailing expression
# displays how many entries were loaded.
with open("data/pairs/Onepose-test.json") as pairs_file:
    dir_list = json.load(pairs_file)
len(dir_list)

10

In [3]:
# Select the OnePose dataset root for the current platform.
# os.name is 'nt' on Windows and 'posix' on Linux/macOS -- it is never 'unix',
# so the previous check left ROOT_DIR undefined on every non-Windows machine.
if os.name == 'nt':
    ROOT_DIR = 'e:/datasets/OnePose/test_data/'
else:
    ROOT_DIR = 'data/onepose/'

In [4]:
# Accumulates one row per evaluated object: [category name, metric values...].
res_table = list()

In [5]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Number of matched keypoints fed to the network per image pair; pairs with
# more matches are subsampled and pairs with fewer are zero-padded.
num_sample = 3000

# The checkpoint stores the whole pickled module (not just a state_dict):
#   * map_location lets a checkpoint saved on CUDA load on a CPU-only machine;
#   * weights_only=False is required for full-module checkpoints on PyTorch
#     releases where torch.load defaults to weights_only=True.
# NOTE(security): unpickling executes arbitrary code -- only load trusted files.
net = torch.load(
    './weights/onepose-6d-3000-2024-01-05-22-20-02-0.9768.pth',
    map_location=device,
    weights_only=False,
).to(device)

# Switch to inference mode (disables dropout); as the last expression this
# also displays the model structure.
net.eval()

Mkpts_Reg_Model(
  (embedding): Embedding()
  (transformerlayer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=76, out_features=76, bias=True)
    )
    (linear1): Linear(in_features=76, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=76, bias=True)
    (norm1): LayerNorm((76,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((76,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=76, out_features=76, bias=True)
        )
        (linear1): Linear(in_features=76, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inp

In [6]:
# Evaluate the pose network on every OnePose test object.
#
# For each object (one entry of dir_list) and each image pair listed under its
# rotation keys, the loop:
#   1. loads the ground-truth intrinsics/poses and the precomputed detections
#      and keypoint matches for the pair,
#   2. scores the predicted 2D box against the projected ground-truth 3D box,
#   3. runs the network on the (sub-sampled or zero-padded) matches and records
#      the translation / rotation errors against the ground-truth pose.
# One row of aggregated metrics per object is appended to res_table.

# Root of the precomputed per-pair inputs (pre_bbox / mkpts0 / mkpts1 / pre_K).
# NOTE(review): machine-specific absolute path kept from the original; point it
# at your local copy.
POINTS_ROOT = 'd:/git_project/POPE/data/onepose-points/'

# Seeded generator so the random keypoint sub-sampling (and therefore the
# reported metrics) is reproducible every time this cell is run.
SAMPLE_SEED = 0
rng = np.random.default_rng(SAMPLE_SEED)

for label_idx, test_dict in enumerate(dir_list):
    logger.info(f"Onepose: {label_idx + 1}/{len(dir_list)}")
    metrics = {'R_errs': [], 't_errs': [], 'inliers': [], "identifiers": []}

    # The first pair under rotation key "0" supplies the object label and the
    # image directory that is reused for every pair of this object.
    sample_data = test_dict["0"][0]
    label = sample_data.split("/")[0]
    name = label.split("-")[1]
    dir_name = os.path.dirname(sample_data)
    FULL_ROOT_DIR = os.path.join(ROOT_DIR, dir_name)

    # The object's 3D box corners do not change between pairs: load them once.
    _3d_bbox = np.loadtxt(f"{os.path.join(ROOT_DIR, label)}/box3d_corners.txt")

    recall_image, all_image = 0, 0
    for rotation_key, rotation_list in test_dict.items():
        for pair_idx, pair_name in enumerate(tqdm(rotation_list)):
            all_image = all_image + 1
            base_name = os.path.basename(pair_name)
            idx0_name = base_name.split("-")[0]
            idx1_name = base_name.split("-")[1]
            image0_name = os.path.join(FULL_ROOT_DIR, idx0_name)
            image1_name = os.path.join(FULL_ROOT_DIR, idx1_name)

            # NOTE(review): K0, pose0 and pre_K are loaded but not used further
            # down; the loads are kept so missing-file errors surface as before.
            K0_path = image0_name.replace("color", "intrin_ba").replace("png", "txt")
            K1_path = image1_name.replace("color", "intrin_ba").replace("png", "txt")
            K0 = np.loadtxt(K0_path)
            K1 = np.loadtxt(K1_path)

            pose0_path = image0_name.replace("color", "poses_ba").replace("png", "txt")
            pose1_path = image1_name.replace("color", "poses_ba").replace("png", "txt")
            pose0 = np.loadtxt(pose0_path)
            pose1 = np.loadtxt(pose1_path)
            # Promote 3x4 poses to homogeneous 4x4; each pose is checked on its
            # own so a mixed 3x4 / 4x4 pair is handled correctly.
            homogeneous_row = np.array([[0, 0, 0, 1]])
            if pose0.shape[0] == 3:
                pose0 = np.concatenate([pose0, homogeneous_row], axis=0)
            if pose1.shape[0] == 3:
                pose1 = np.concatenate([pose1, homogeneous_row], axis=0)

            points_file_path = os.path.join(POINTS_ROOT, pair_name.split("/")[0])
            points_name = pair_name.split("/")[-1]
            pre_bbox_path = os.path.join(points_file_path, "pre_bbox", f'{points_name}.txt')
            mkpts0_path = os.path.join(points_file_path, "mkpts0", f'{points_name}.txt')
            mkpts1_path = os.path.join(points_file_path, "mkpts1", f'{points_name}.txt')
            pre_K_path = os.path.join(points_file_path, "pre_K", f'{points_name}.txt')

            if not os.path.exists(pre_bbox_path):
                # No precomputed box for this pair: it stays counted in
                # all_image (lowering AP50) but contributes no pose error.
                continue
            pre_bbox = np.loadtxt(pre_bbox_path)
            # reshape(-1, 2) keeps the (N, 2) layout even when a file holds a
            # single match (loadtxt returns shape (2,)) or none at all (shape
            # (0,)); otherwise the zero-padding below fails to concatenate.
            mkpts0 = np.loadtxt(mkpts0_path).reshape(-1, 2)
            mkpts1 = np.loadtxt(mkpts1_path).reshape(-1, 2)
            pre_K = np.loadtxt(pre_K_path)

            # Bring every pair to exactly num_sample matches.
            if mkpts0.shape[0] > num_sample:
                rand_idx = rng.choice(mkpts0.shape[0], num_sample, replace=False)
                mkpts0 = mkpts0[rand_idx]
                mkpts1 = mkpts1[rand_idx]
            else:
                mkpts0 = np.concatenate([mkpts0, np.zeros((num_sample - mkpts0.shape[0], 2))], axis=0)
                mkpts1 = np.concatenate([mkpts1, np.zeros((num_sample - mkpts1.shape[0], 2))], axis=0)

            # Ground-truth 2D box: project the 3D corners with the ground-truth
            # pose/intrinsics and take their bounding rectangle.
            bbox_pts_3d, _ = project_points(_3d_bbox, pose1[:3, :4], K1)
            bbox_pts_3d = bbox_pts_3d.astype(np.int32)
            x0, y0, w, h = cv2.boundingRect(bbox_pts_3d)
            x1, y1 = x0 + w, y0 + h
            gt_bbox = np.array([x0, y0, x1, y1])
            is_recalled = recall_object(pre_bbox, gt_bbox)
            recall_image = recall_image + int(is_recalled > 0.5)

            batch_mkpts0 = torch.from_numpy(mkpts0).unsqueeze(0).float().to(device)
            batch_mkpts1 = torch.from_numpy(mkpts1).unsqueeze(0).float().to(device)
            batch_pose1 = torch.from_numpy(pose1).unsqueeze(0).float().to(device)
            # Pure evaluation: skip autograd bookkeeping for speed and memory.
            with torch.no_grad():
                pre_t, pre_rot = net(batch_mkpts0, batch_mkpts1)
                t_err, R_err = relative_pose_error(batch_pose1, pre_rot, pre_t, ignore_gt_t_thr=0.0)

            # extend() appends in place instead of rebuilding the list per pair.
            metrics['t_errs'].extend(t_err.reshape(-1).cpu().detach().numpy().tolist())
            metrics['R_errs'].extend(R_err.reshape(-1).cpu().detach().numpy().tolist())
            metrics['identifiers'].append(pair_name)

    print(f'Acc: {recall_image}/{all_image}')
    val_metrics_4tb = aggregate_metrics(metrics, 5e-4)
    # Guard the ratio so an object without any listed pair does not raise.
    val_metrics_4tb['AP50'] = recall_image / all_image if all_image else 0.0
    logger.info('\n' + pprint.pformat(val_metrics_4tb))

    res_table.append([f"{name}"] + list(val_metrics_4tb.values()))

[32m2024-01-05 22:20:52.202[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOnepose: 1/10[0m
100%|██████████| 66/66 [00:03<00:00, 18.37it/s]
100%|██████████| 120/120 [00:02<00:00, 43.39it/s]
100%|██████████| 107/107 [00:02<00:00, 48.38it/s]
100%|██████████| 59/59 [00:01<00:00, 48.70it/s]
100%|██████████| 41/41 [00:00<00:00, 49.38it/s]
100%|██████████| 46/46 [00:00<00:00, 50.68it/s]
[32m2024-01-05 22:21:03.746[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m86[0m - [1m
{'AP50': 0.9157175398633257,
 'R:ACC15': 0.1813953488372093,
 'R:ACC30': 0.3930232558139535,
 'R:auc@15': 0.05920264558274617,
 'R:auc@30': 0.18180772428364717,
 'R:medianErr': 44.934865951538086,
 't:ACC15': 0.8465116279069768,
 't:ACC30': 0.9883720930232558,
 't:auc@15': 0.4008286370510279,
 't:auc@30': 0.6762125902074252,
 't:medianErr': 8.838622093200684}[0m
[32m2024-01-05 22:21:03.747[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mO

Acc: 402/439


100%|██████████| 45/45 [00:01<00:00, 42.60it/s]
100%|██████████| 91/91 [00:02<00:00, 42.70it/s]
100%|██████████| 62/62 [00:01<00:00, 49.36it/s]
100%|██████████| 28/28 [00:00<00:00, 50.28it/s]
100%|██████████| 34/34 [00:00<00:00, 49.68it/s]
100%|██████████| 34/34 [00:00<00:00, 52.97it/s]
[32m2024-01-05 22:21:10.090[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m86[0m - [1m
{'AP50': 0.9149659863945578,
 'R:ACC15': 0.17687074829931973,
 'R:ACC30': 0.35034013605442177,
 'R:auc@15': 0.05659367025844634,
 'R:auc@30': 0.17643143735369857,
 'R:medianErr': 42.62068176269531,
 't:ACC15': 0.7687074829931972,
 't:ACC30': 0.9965986394557823,
 't:auc@15': 0.3033409126626661,
 't:auc@30': 0.6143591281981154,
 't:medianErr': 10.878254890441895}[0m
[32m2024-01-05 22:21:10.091[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOnepose: 3/10[0m


Acc: 269/294


100%|██████████| 30/30 [00:00<00:00, 49.16it/s]
100%|██████████| 40/40 [00:00<00:00, 50.52it/s]
100%|██████████| 43/43 [00:00<00:00, 49.94it/s]
100%|██████████| 39/39 [00:00<00:00, 48.80it/s]
100%|██████████| 45/45 [00:00<00:00, 48.83it/s]
100%|██████████| 30/30 [00:00<00:00, 48.69it/s]
[32m2024-01-05 22:21:14.709[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m86[0m - [1m
{'AP50': 0.8986784140969163,
 'R:ACC15': 0.2643171806167401,
 'R:ACC30': 0.47577092511013214,
 'R:auc@15': 0.11041578903422308,
 'R:auc@30': 0.23889460451627442,
 'R:medianErr': 35.80271530151367,
 't:ACC15': 0.6475770925110133,
 't:ACC30': 0.9823788546255506,
 't:auc@15': 0.25430434156269405,
 't:auc@30': 0.5653903112712586,
 't:medianErr': 12.093002319335938}[0m
[32m2024-01-05 22:21:14.710[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOnepose: 4/10[0m


Acc: 204/227


100%|██████████| 48/48 [00:01<00:00, 47.32it/s]
100%|██████████| 51/51 [00:01<00:00, 45.43it/s]
100%|██████████| 66/66 [00:01<00:00, 39.81it/s]
100%|██████████| 39/39 [00:00<00:00, 54.57it/s]
100%|██████████| 41/41 [00:00<00:00, 60.69it/s]
100%|██████████| 29/29 [00:00<00:00, 41.03it/s]
[32m2024-01-05 22:21:20.621[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m86[0m - [1m
{'AP50': 0.9562043795620438,
 'R:ACC15': 0.24817518248175183,
 'R:ACC30': 0.4416058394160584,
 'R:auc@15': 0.0949194867535519,
 'R:auc@30': 0.22368410973653308,
 'R:medianErr': 41.4505672454834,
 't:ACC15': 0.5364963503649635,
 't:ACC30': 0.9817518248175182,
 't:auc@15': 0.17777169272847418,
 't:auc@30': 0.5038098012443876,
 't:medianErr': 14.358899116516113}[0m
[32m2024-01-05 22:21:20.622[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOnepose: 5/10[0m


Acc: 262/274


100%|██████████| 37/37 [00:00<00:00, 50.11it/s]
100%|██████████| 30/30 [00:00<00:00, 45.82it/s]
100%|██████████| 41/41 [00:01<00:00, 36.31it/s]
100%|██████████| 48/48 [00:01<00:00, 45.11it/s]
100%|██████████| 34/34 [00:00<00:00, 46.55it/s]
100%|██████████| 60/60 [00:01<00:00, 48.56it/s]
[32m2024-01-05 22:21:26.197[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m86[0m - [1m
{'AP50': 0.796,
 'R:ACC15': 0.036,
 'R:ACC30': 0.12,
 'R:auc@15': 0.010312250391642254,
 'R:auc@30': 0.043143032010396314,
 'R:medianErr': 105.38240432739258,
 't:ACC15': 0.688,
 't:ACC30': 0.98,
 't:auc@15': 0.3045899465958277,
 't:auc@30': 0.5909591375867527,
 't:medianErr': 10.826921463012695}[0m
[32m2024-01-05 22:21:26.198[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOnepose: 6/10[0m


Acc: 199/250


100%|██████████| 63/63 [00:01<00:00, 44.51it/s]
100%|██████████| 53/53 [00:01<00:00, 43.77it/s]
100%|██████████| 45/45 [00:00<00:00, 45.39it/s]
100%|██████████| 50/50 [00:01<00:00, 45.96it/s]
100%|██████████| 35/35 [00:00<00:00, 45.59it/s]
100%|██████████| 36/36 [00:00<00:00, 47.63it/s]
[32m2024-01-05 22:21:32.448[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m86[0m - [1m
{'AP50': 0.9148936170212766,
 'R:ACC15': 0.1347517730496454,
 'R:ACC30': 0.6453900709219859,
 'R:auc@15': 0.044481218622085886,
 'R:auc@30': 0.2233908191076689,
 'R:medianErr': 25.230935096740723,
 't:ACC15': 0.7411347517730497,
 't:ACC30': 0.9964539007092199,
 't:auc@15': 0.34260870128095006,
 't:auc@30': 0.6337543243916604,
 't:medianErr': 10.14711332321167}[0m
[32m2024-01-05 22:21:32.449[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOnepose: 7/10[0m


Acc: 258/282


100%|██████████| 37/37 [00:00<00:00, 45.29it/s]
100%|██████████| 80/80 [00:01<00:00, 44.91it/s]
100%|██████████| 109/109 [00:02<00:00, 46.32it/s]
100%|██████████| 55/55 [00:01<00:00, 47.97it/s]
100%|██████████| 43/43 [00:00<00:00, 44.01it/s]
100%|██████████| 40/40 [00:00<00:00, 42.91it/s]
[32m2024-01-05 22:21:40.479[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m86[0m - [1m
{'AP50': 0.9065934065934066,
 'R:ACC15': 0.10714285714285714,
 'R:ACC30': 0.2774725274725275,
 'R:auc@15': 0.04718605031023969,
 'R:auc@30': 0.11060487333234853,
 'R:medianErr': 57.38472938537598,
 't:ACC15': 0.6813186813186813,
 't:ACC30': 0.9752747252747253,
 't:auc@15': 0.28551376809130663,
 't:auc@30': 0.5764867969484994,
 't:medianErr': 11.395231246948242}[0m
[32m2024-01-05 22:21:40.479[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOnepose: 8/10[0m


Acc: 330/364


100%|██████████| 76/76 [00:01<00:00, 43.26it/s]
100%|██████████| 55/55 [00:01<00:00, 43.32it/s]
100%|██████████| 62/62 [00:01<00:00, 46.14it/s]
100%|██████████| 44/44 [00:00<00:00, 44.68it/s]
100%|██████████| 43/43 [00:00<00:00, 47.17it/s]
100%|██████████| 25/25 [00:00<00:00, 48.71it/s]
[32m2024-01-05 22:21:47.281[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m86[0m - [1m
{'AP50': 0.8459016393442623,
 'R:ACC15': 0.18360655737704917,
 'R:ACC30': 0.43934426229508194,
 'R:auc@15': 0.06234972561643423,
 'R:auc@30': 0.19834973470109407,
 'R:medianErr': 40.006710052490234,
 't:ACC15': 0.45901639344262296,
 't:ACC30': 0.9016393442622951,
 't:auc@15': 0.19149427989792953,
 't:auc@30': 0.4556019436596522,
 't:medianErr': 15.865317344665527}[0m
[32m2024-01-05 22:21:47.281[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOnepose: 9/10[0m


Acc: 258/305


100%|██████████| 45/45 [00:00<00:00, 45.69it/s]
100%|██████████| 65/65 [00:01<00:00, 44.96it/s]
100%|██████████| 59/59 [00:01<00:00, 46.96it/s]
100%|██████████| 67/67 [00:01<00:00, 46.70it/s]
100%|██████████| 57/57 [00:01<00:00, 47.01it/s]
100%|██████████| 34/34 [00:00<00:00, 47.09it/s]
[32m2024-01-05 22:21:54.356[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m86[0m - [1m
{'AP50': 0.5535168195718655,
 'R:ACC15': 0.2324159021406728,
 'R:ACC30': 0.5412844036697247,
 'R:auc@15': 0.06493458193741078,
 'R:auc@30': 0.2412858520443651,
 'R:medianErr': 26.727766036987305,
 't:ACC15': 0.7889908256880734,
 't:ACC30': 0.9908256880733946,
 't:auc@15': 0.34665656374438947,
 't:auc@30': 0.6463109103805186,
 't:medianErr': 10.336851119995117}[0m
[32m2024-01-05 22:21:54.357[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOnepose: 10/10[0m


Acc: 181/327


100%|██████████| 55/55 [00:01<00:00, 40.16it/s]
100%|██████████| 98/98 [00:02<00:00, 43.42it/s]
100%|██████████| 87/87 [00:01<00:00, 46.92it/s]
100%|██████████| 87/87 [00:01<00:00, 46.51it/s]
100%|██████████| 45/45 [00:00<00:00, 48.34it/s]
100%|██████████| 32/32 [00:00<00:00, 47.65it/s]
[32m2024-01-05 22:22:03.333[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m86[0m - [1m
{'AP50': 0.7475247524752475,
 'R:ACC15': 0.15346534653465346,
 'R:ACC30': 0.6089108910891089,
 'R:auc@15': 0.03759496133319616,
 'R:auc@30': 0.23101337211753675,
 'R:medianErr': 24.494354248046875,
 't:ACC15': 0.7153465346534653,
 't:ACC30': 0.9752475247524752,
 't:auc@15': 0.27836935661806916,
 't:auc@30': 0.5895678637838206,
 't:medianErr': 11.17952299118042}[0m


Acc: 302/404


In [7]:
# Print the per-object metrics followed by an "Avg" row (column-wise mean).
# Column names come from the metrics dict of the last evaluated object.
headers = ["Category"] + list(val_metrics_4tb.keys())

# Ignore an "Avg" row left by a previous run of this cell; otherwise re-running
# would fold the old average into the new mean and stack duplicate "Avg" rows.
category_rows = [row for row in res_table if row[0] != "Avg"]
all_data = np.array([row[1:] for row in category_rows], dtype=np.float32)
if category_rows:
    # res_table still ends with exactly one "Avg" row, as before.
    res_table[:] = category_rows + [["Avg"] + all_data.mean(0).tolist()]
print(tabulate(res_table, headers=headers, tablefmt='fancy_grid'))

╒════════════════╤════════════╤════════════╤═══════════╤═══════════╤═══════════════╤════════════╤════════════╤═══════════╤═══════════╤═══════════════╤══════════╕
│ Category       │   R:auc@15 │   R:auc@30 │   R:ACC15 │   R:ACC30 │   R:medianErr │   t:auc@15 │   t:auc@30 │   t:ACC15 │   t:ACC30 │   t:medianErr │     AP50 │
╞════════════════╪════════════╪════════════╪═══════════╪═══════════╪═══════════════╪════════════╪════════════╪═══════════╪═══════════╪═══════════════╪══════════╡
│ aptamil        │  0.0592026 │   0.181808 │  0.181395 │  0.393023 │       44.9349 │   0.400829 │   0.676213 │  0.846512 │  0.988372 │       8.83862 │ 0.915718 │
├────────────────┼────────────┼────────────┼───────────┼───────────┼───────────────┼────────────┼────────────┼───────────┼───────────┼───────────────┼──────────┤
│ jzhg           │  0.0565937 │   0.176431 │  0.176871 │  0.35034  │       42.6207 │   0.303341 │   0.614359 │  0.768707 │  0.996599 │      10.8783  │ 0.914966 │
├────────────────┼──────────