In [1]:
import os
import cv2
import json
import torch
import pprint
import numpy as np
from tqdm import tqdm
from loguru import logger
from tabulate import tabulate

from pose.utils import collate_fn, geodesic_distance, relative_pose_error, aggregate_metrics, recall_object, project_points

In [2]:
# Load the list of test objects; each entry maps a rotation key to image-pair names.
pairs_json = "data/pairs/Onepose-test.json"
with open(pairs_json) as pairs_file:
    dir_list = json.load(pairs_file)
len(dir_list)

10

In [3]:
if os.name == 'unix':
    ROOT_DIR = 'data/onepose/'
elif os.name == 'nt':
    ROOT_DIR = 'e:/datasets/OnePose/test_data/'

In [4]:
# Per-object result rows: [category name, metric values...].
res_table = list()

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Number of 2D-2D matches per pair fed to the network; the evaluation loop
# subsamples or zero-pads every pair to exactly this many.
num_sample = 3000

WEIGHTS_PATH = './weights/onepose-6d-3000-2024-01-06-00-05-51-0.2705.pth'
# The checkpoint is a whole pickled nn.Module (not a state_dict), so
# weights_only=False is required on recent torch; only load checkpoints you trust.
# map_location lets a checkpoint saved on CUDA load on the CPU fallback device.
net = torch.load(WEIGHTS_PATH, map_location=device, weights_only=False).to(device)

net.eval()

Mkpts_Reg_Model(
  (embedding): Embedding()
  (transformerlayer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=76, out_features=76, bias=True)
    )
    (linear1): Linear(in_features=76, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=76, bias=True)
    (norm1): LayerNorm((76,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((76,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=76, out_features=76, bias=True)
        )
        (linear1): Linear(in_features=76, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inp

In [6]:
# Evaluate every OnePose test object. For each image pair we
#   (a) check whether the pre-computed predicted 2D box recalls the projected
#       ground-truth 3D box (recall_object score > 0.5, presumably IoU -- confirm),
#       which feeds the per-object AP50, and
#   (b) run the network on the sampled 2D-2D matches and compare its predicted
#       relative pose against the ground-truth relative pose.

# Pre-computed matches / boxes per pair; override with ONEPOSE_POINTS_DIR.
POINTS_ROOT = os.environ.get('ONEPOSE_POINTS_DIR', 'd:/git_project/POPE/data/onepose-points/')
SEED = 0  # makes the keypoint subsampling below reproducible


def _annotation_path(image_path, folder):
    """Map an image path under a ``color`` directory to ``<folder>/<frame>.txt``."""
    # Swap the real extension rather than every "png" substring in the path.
    stem, _ = os.path.splitext(image_path.replace("color", folder))
    return stem + ".txt"


def _load_pose(path):
    """Load a camera pose, promoting a 3x4 ``[R|t]`` matrix to homogeneous 4x4."""
    pose = np.loadtxt(path)
    if pose.shape[0] == 3:
        pose = np.vstack([pose, [0.0, 0.0, 0.0, 1.0]])
    return pose


def _load_points(path):
    """Load an ``x y`` keypoint file as an (N, 2) array, even for 0 or 1 rows."""
    # ndmin=2 keeps a single-row file from being squeezed to shape (2,).
    return np.loadtxt(path, ndmin=2).reshape(-1, 2)


def _fix_num_points(pts0, pts1, n, rng):
    """Subsample (without replacement) or zero-pad paired matches to exactly ``n`` rows."""
    assert len(pts0) == len(pts1), "mkpts0 / mkpts1 must have the same length"
    if len(pts0) > n:
        keep = rng.choice(len(pts0), n, replace=False)
        return pts0[keep], pts1[keep]
    pad = np.zeros((n - len(pts0), 2))
    return np.concatenate([pts0, pad], axis=0), np.concatenate([pts1, pad], axis=0)


rng = np.random.default_rng(SEED)
# Start from an empty table so re-running this cell does not duplicate rows.
res_table = []

for label_idx, test_dict in enumerate(dir_list):
    logger.info(f"Onepose: {label_idx + 1}/{len(dir_list)}")
    metrics = {'R_errs': [], 't_errs': [], 'inliers': [], "identifiers": []}
    sample_data = test_dict["0"][0]
    label = sample_data.split("/")[0]
    name = label.split("-")[1]
    FULL_ROOT_DIR = os.path.join(ROOT_DIR, os.path.dirname(sample_data))
    # The 3D box corners depend only on the object, so load them once per object.
    box3d_corners = np.loadtxt(f"{os.path.join(ROOT_DIR, label)}/box3d_corners.txt")
    recall_image, all_image = 0, 0
    for rotation_list in test_dict.values():
        for pair_name in tqdm(rotation_list):
            # Counted before the pre_bbox check: pairs without a prediction count as misses.
            all_image += 1
            base_name = os.path.basename(pair_name)
            idx0_name, idx1_name = base_name.split("-")[:2]
            image0_name = os.path.join(FULL_ROOT_DIR, idx0_name)
            image1_name = os.path.join(FULL_ROOT_DIR, idx1_name)

            K1 = np.loadtxt(_annotation_path(image1_name, "intrin_ba"))
            # Each pose is homogenised independently (files may be 3x4 or 4x4).
            pose0 = _load_pose(_annotation_path(image0_name, "poses_ba"))
            pose1 = _load_pose(_annotation_path(image1_name, "poses_ba"))

            points_dir = os.path.join(POINTS_ROOT, pair_name.split("/")[0])
            points_file = f'{base_name}.txt'
            pre_bbox_path = os.path.join(points_dir, "pre_bbox", points_file)
            if not os.path.exists(pre_bbox_path):
                continue
            pre_bbox = np.loadtxt(pre_bbox_path)
            mkpts0 = _load_points(os.path.join(points_dir, "mkpts0", points_file))
            mkpts1 = _load_points(os.path.join(points_dir, "mkpts1", points_file))
            mkpts0, mkpts1 = _fix_num_points(mkpts0, mkpts1, num_sample, rng)

            # Ground-truth 2D box: bounding rect of the 3D box projected into image 1.
            bbox_pts_2d, _ = project_points(box3d_corners, pose1[:3, :4], K1)
            x0, y0, w, h = cv2.boundingRect(bbox_pts_2d.astype(np.int32))
            gt_bbox = np.array([x0, y0, x0 + w, y0 + h])
            recall_image += int(recall_object(pre_bbox, gt_bbox) > 0.5)

            batch_mkpts0 = torch.from_numpy(mkpts0).unsqueeze(0).float().to(device)
            batch_mkpts1 = torch.from_numpy(mkpts1).unsqueeze(0).float().to(device)
            gt_relative_pose = torch.from_numpy(pose1 @ np.linalg.inv(pose0)).unsqueeze(0).float().to(device)
            # Inference only: skip building an autograd graph.
            with torch.no_grad():
                pre_t, pre_rot = net(batch_mkpts0, batch_mkpts1)
                t_err, R_err = relative_pose_error(gt_relative_pose, pre_rot, pre_t, ignore_gt_t_thr=0.0)

            metrics['t_errs'].extend(t_err.reshape(-1).detach().cpu().numpy().tolist())
            metrics['R_errs'].extend(R_err.reshape(-1).detach().cpu().numpy().tolist())
            metrics['identifiers'].append(pair_name)

    print(f'Acc: {recall_image}/{all_image}')
    val_metrics_4tb = aggregate_metrics(metrics, 5e-4)
    # Guard against an object whose pair lists are all empty.
    val_metrics_4tb['AP50'] = recall_image / all_image if all_image else 0.0
    logger.info('\n' + pprint.pformat(val_metrics_4tb))

    res_table.append([name] + list(val_metrics_4tb.values()))

[32m2024-01-06 00:11:01.735[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOnepose: 1/10[0m
100%|██████████| 66/66 [00:02<00:00, 22.33it/s]
100%|██████████| 120/120 [00:02<00:00, 55.97it/s]
100%|██████████| 107/107 [00:01<00:00, 57.27it/s]
100%|██████████| 59/59 [00:01<00:00, 57.06it/s]
100%|██████████| 41/41 [00:00<00:00, 59.86it/s]
100%|██████████| 46/46 [00:00<00:00, 61.39it/s]
[32m2024-01-06 00:11:11.189[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1m
{'AP50': 0.9157175398633257,
 'R:ACC15': 0.6767441860465117,
 'R:ACC30': 0.9837209302325581,
 'R:auc@15': 0.3043439340868661,
 'R:auc@30': 0.5808061372725538,
 'R:medianErr': 10.867824077606201,
 't:ACC15': 0.027906976744186046,
 't:ACC30': 0.15348837209302327,
 't:auc@15': 0.01025364787079567,
 't:auc@30': 0.05103212903636371,
 't:medianErr': 57.68374443054199}[0m
[32m2024-01-06 00:11:11.190[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [

Acc: 402/439


100%|██████████| 45/45 [00:00<00:00, 54.91it/s]
100%|██████████| 91/91 [00:01<00:00, 55.16it/s]
100%|██████████| 62/62 [00:01<00:00, 56.27it/s]
100%|██████████| 28/28 [00:00<00:00, 57.66it/s]
100%|██████████| 34/34 [00:00<00:00, 55.79it/s]
100%|██████████| 34/34 [00:00<00:00, 55.49it/s]
[32m2024-01-06 00:11:16.485[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1m
{'AP50': 0.9149659863945578,
 'R:ACC15': 0.6972789115646258,
 'R:ACC30': 0.9931972789115646,
 'R:auc@15': 0.3344629462816818,
 'R:auc@30': 0.6027458145471115,
 'R:medianErr': 10.290589809417725,
 't:ACC15': 0.04421768707482993,
 't:ACC30': 0.1292517006802721,
 't:auc@15': 0.01186611463153173,
 't:auc@30': 0.05169281072897706,
 't:medianErr': 59.70263671875}[0m
[32m2024-01-06 00:11:16.486[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOnepose: 3/10[0m


Acc: 269/294


100%|██████████| 30/30 [00:00<00:00, 53.38it/s]
100%|██████████| 40/40 [00:00<00:00, 53.61it/s]
100%|██████████| 43/43 [00:00<00:00, 53.89it/s]
100%|██████████| 39/39 [00:00<00:00, 53.20it/s]
100%|██████████| 45/45 [00:00<00:00, 57.15it/s]
100%|██████████| 30/30 [00:00<00:00, 55.51it/s]
[32m2024-01-06 00:11:20.669[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1m
{'AP50': 0.8986784140969163,
 'R:ACC15': 0.5506607929515418,
 'R:ACC30': 0.9911894273127754,
 'R:auc@15': 0.25215038777097903,
 'R:auc@30': 0.5316885529810812,
 'R:medianErr': 13.477045059204102,
 't:ACC15': 0.03524229074889868,
 't:ACC30': 0.15418502202643172,
 't:auc@15': 0.011481180345084349,
 't:auc@30': 0.054062740113066506,
 't:medianErr': 55.76251983642578}[0m
[32m2024-01-06 00:11:20.669[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOnepose: 4/10[0m


Acc: 204/227


100%|██████████| 48/48 [00:00<00:00, 53.76it/s]
100%|██████████| 51/51 [00:00<00:00, 55.13it/s]
100%|██████████| 66/66 [00:01<00:00, 56.08it/s]
100%|██████████| 39/39 [00:00<00:00, 56.43it/s]
100%|██████████| 41/41 [00:00<00:00, 57.36it/s]
100%|██████████| 29/29 [00:00<00:00, 58.75it/s]
[32m2024-01-06 00:11:25.579[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1m
{'AP50': 0.9562043795620438,
 'R:ACC15': 0.635036496350365,
 'R:ACC30': 1.0,
 'R:auc@15': 0.3020638002821418,
 'R:auc@30': 0.5761927084258583,
 'R:medianErr': 11.825913906097412,
 't:ACC15': 0.04744525547445255,
 't:ACC30': 0.1678832116788321,
 't:auc@15': 0.018532890705006543,
 't:auc@30': 0.059858781693915675,
 't:medianErr': 57.92544174194336}[0m
[32m2024-01-06 00:11:25.580[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOnepose: 5/10[0m


Acc: 262/274


100%|██████████| 37/37 [00:00<00:00, 53.83it/s]
100%|██████████| 30/30 [00:00<00:00, 55.08it/s]
100%|██████████| 41/41 [00:00<00:00, 55.33it/s]
100%|██████████| 48/48 [00:00<00:00, 56.66it/s]
100%|██████████| 34/34 [00:00<00:00, 57.75it/s]
100%|██████████| 60/60 [00:01<00:00, 58.62it/s]
[32m2024-01-06 00:11:30.029[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1m
{'AP50': 0.796,
 'R:ACC15': 0.468,
 'R:ACC30': 0.988,
 'R:auc@15': 0.22044909730752302,
 'R:auc@30': 0.4990124638835589,
 'R:medianErr': 16.109557151794434,
 't:ACC15': 0.04,
 't:ACC30': 0.164,
 't:auc@15': 0.01666362279256185,
 't:auc@30': 0.05540458437601726,
 't:medianErr': 59.90811538696289}[0m
[32m2024-01-06 00:11:30.030[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOnepose: 6/10[0m


Acc: 199/250


100%|██████████| 63/63 [00:01<00:00, 54.13it/s]
100%|██████████| 53/53 [00:00<00:00, 55.59it/s]
100%|██████████| 45/45 [00:00<00:00, 56.28it/s]
100%|██████████| 50/50 [00:00<00:00, 54.69it/s]
100%|██████████| 35/35 [00:00<00:00, 55.83it/s]
100%|██████████| 36/36 [00:00<00:00, 56.72it/s]
[32m2024-01-06 00:11:35.142[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1m
{'AP50': 0.9148936170212766,
 'R:ACC15': 0.624113475177305,
 'R:ACC30': 0.9893617021276596,
 'R:auc@15': 0.31657796833250257,
 'R:auc@30': 0.5939261024020805,
 'R:medianErr': 11.660637855529785,
 't:ACC15': 0.05673758865248227,
 't:ACC30': 0.1524822695035461,
 't:auc@15': 0.01887407922857479,
 't:auc@30': 0.0649749450367957,
 't:medianErr': 58.30123710632324}[0m
[32m2024-01-06 00:11:35.143[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOnepose: 7/10[0m


Acc: 258/282


100%|██████████| 37/37 [00:00<00:00, 53.43it/s]
100%|██████████| 80/80 [00:01<00:00, 54.89it/s]
100%|██████████| 109/109 [00:02<00:00, 53.91it/s]
100%|██████████| 55/55 [00:00<00:00, 55.26it/s]
100%|██████████| 43/43 [00:00<00:00, 56.73it/s]
100%|██████████| 40/40 [00:00<00:00, 58.64it/s]
[32m2024-01-06 00:11:41.766[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1m
{'AP50': 0.9065934065934066,
 'R:ACC15': 0.6456043956043956,
 'R:ACC30': 0.989010989010989,
 'R:auc@15': 0.26155658311896274,
 'R:auc@30': 0.5592782623711087,
 'R:medianErr': 12.125442504882812,
 't:ACC15': 0.04395604395604396,
 't:ACC30': 0.17032967032967034,
 't:auc@15': 0.014852824490585607,
 't:auc@30': 0.05739410251051515,
 't:medianErr': 54.21873092651367}[0m
[32m2024-01-06 00:11:41.767[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOnepose: 8/10[0m


Acc: 330/364


100%|██████████| 76/76 [00:01<00:00, 54.62it/s]
100%|██████████| 55/55 [00:00<00:00, 57.02it/s]
100%|██████████| 62/62 [00:01<00:00, 57.24it/s]
100%|██████████| 44/44 [00:00<00:00, 58.26it/s]
100%|██████████| 43/43 [00:00<00:00, 59.03it/s]
100%|██████████| 25/25 [00:00<00:00, 59.71it/s]
[32m2024-01-06 00:11:47.126[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1m
{'AP50': 0.8459016393442623,
 'R:ACC15': 0.6491803278688525,
 'R:ACC30': 0.9901639344262295,
 'R:auc@15': 0.32980904372012027,
 'R:auc@30': 0.5926565199276138,
 'R:medianErr': 11.174063682556152,
 't:ACC15': 0.06229508196721312,
 't:ACC30': 0.20327868852459016,
 't:auc@15': 0.0187574413695622,
 't:auc@30': 0.07615982529895553,
 't:medianErr': 52.32026672363281}[0m
[32m2024-01-06 00:11:47.127[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOnepose: 9/10[0m


Acc: 258/305


100%|██████████| 45/45 [00:00<00:00, 54.71it/s]
100%|██████████| 65/65 [00:01<00:00, 56.06it/s]
100%|██████████| 59/59 [00:01<00:00, 55.79it/s]
100%|██████████| 67/67 [00:01<00:00, 56.39it/s]
100%|██████████| 57/57 [00:00<00:00, 57.07it/s]
100%|██████████| 34/34 [00:00<00:00, 56.94it/s]
[32m2024-01-06 00:11:52.968[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1m
{'AP50': 0.5535168195718655,
 'R:ACC15': 0.5443425076452599,
 'R:ACC30': 0.9969418960244648,
 'R:auc@15': 0.2600529317583634,
 'R:auc@30': 0.547970008108837,
 'R:medianErr': 13.689202308654785,
 't:ACC15': 0.06422018348623854,
 't:ACC30': 0.1743119266055046,
 't:auc@15': 0.0256067602641721,
 't:auc@30': 0.0697242958465483,
 't:medianErr': 59.28119659423828}[0m
[32m2024-01-06 00:11:52.969[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOnepose: 10/10[0m


Acc: 181/327


100%|██████████| 55/55 [00:01<00:00, 54.36it/s]
100%|██████████| 98/98 [00:01<00:00, 55.85it/s]
100%|██████████| 87/87 [00:01<00:00, 55.97it/s]
100%|██████████| 87/87 [00:01<00:00, 55.77it/s]
100%|██████████| 45/45 [00:00<00:00, 54.87it/s]
100%|██████████| 32/32 [00:00<00:00, 52.44it/s]
[32m2024-01-06 00:12:00.298[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1m
{'AP50': 0.7475247524752475,
 'R:ACC15': 0.6188118811881188,
 'R:ACC30': 0.9975247524752475,
 'R:auc@15': 0.2848089729598647,
 'R:auc@30': 0.5677398429845426,
 'R:medianErr': 11.823050022125244,
 't:ACC15': 0.04950495049504951,
 't:ACC30': 0.1782178217821782,
 't:auc@15': 0.015998072671418142,
 't:auc@30': 0.059944056874454615,
 't:medianErr': 57.56123924255371}[0m


Acc: 302/404


In [7]:
# Per-object results plus an "Avg" row. The Avg row goes into a new list rather
# than res_table itself, so re-running this cell never stacks extra "Avg" rows
# (which the mean would otherwise include).
# Assumes aggregate_metrics returns the same keys in the same order for every object.
headers = ["Category"] + list(val_metrics_4tb.keys())
summary_table = list(res_table)
if res_table:
    # Drop the category name column and keep full float precision for the mean.
    all_data = np.array([row[1:] for row in res_table], dtype=np.float64)
    summary_table.append(["Avg"] + all_data.mean(axis=0).tolist())
print(tabulate(summary_table, headers=headers, tablefmt='fancy_grid'))

╒════════════════╤════════════╤════════════╤═══════════╤═══════════╤═══════════════╤════════════╤════════════╤═══════════╤═══════════╤═══════════════╤══════════╕
│ Category       │   R:auc@15 │   R:auc@30 │   R:ACC15 │   R:ACC30 │   R:medianErr │   t:auc@15 │   t:auc@30 │   t:ACC15 │   t:ACC30 │   t:medianErr │     AP50 │
╞════════════════╪════════════╪════════════╪═══════════╪═══════════╪═══════════════╪════════════╪════════════╪═══════════╪═══════════╪═══════════════╪══════════╡
│ aptamil        │   0.304344 │   0.580806 │  0.676744 │  0.983721 │       10.8678 │  0.0102536 │  0.0510321 │ 0.027907  │  0.153488 │       57.6837 │ 0.915718 │
├────────────────┼────────────┼────────────┼───────────┼───────────┼───────────────┼────────────┼────────────┼───────────┼───────────┼───────────────┼──────────┤
│ jzhg           │   0.334463 │   0.602746 │  0.697279 │  0.993197 │       10.2906 │  0.0118661 │  0.0516928 │ 0.0442177 │  0.129252 │       59.7026 │ 0.914966 │
├────────────────┼──────────