In [1]:
%load_ext autoreload
%autoreload 2

In [23]:
import json
import numpy as np
from scipy.spatial.transform import Rotation

import faiss
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from typing import List
from omegaconf import OmegaConf
from hydra.utils import instantiate
from opr.pipelines.place_recognition import PlaceRecognitionPipeline
from opr.pipelines.place_recognition import TextLabelsPlaceRecognitionPipeline

In [11]:
def get_labels_by_id(labels: dict, id: str):
    """Collect every text label annotated on one frame.

    Args:
        labels: Mapping from a frame id to that frame's annotation record.
            The record is indexed with "back_cam_anno" and "front_cam_anno";
            each entry of those lists is indexed with ["value"]["text"],
            which is expected to be a list of strings.
        id: Key of the frame to look up in ``labels``.

    Returns:
        A flat list holding the texts of the back-camera annotations followed
        by the texts of the front-camera annotations.

    Raises:
        KeyError: If ``id`` or one of the annotation keys is missing.
    """
    frame = labels[id]
    all_labels = []
    for anno in frame["back_cam_anno"] + frame["front_cam_anno"]:
        # extend() flattens in linear time; sum(list_of_lists, []) copies the
        # growing accumulator on every step.
        all_labels.extend(anno["value"]["text"])
    return all_labels

def pose_to_matrix(pose):
    """Build a 4x4 homogeneous transform from a [tx ty tz qx qy qz qw] pose.

    The first three entries become the translation column; the remaining
    entries are read as a quaternion and converted to the 3x3 rotation part.
    """
    transform = np.eye(4)
    transform[:3, 3] = pose[:3]
    transform[:3, :3] = Rotation.from_quat(pose[3:]).as_matrix()
    return transform

def compute_error(estimated_pose, gt_pose):
    """Translation and rotation discrepancy between two [tx ty tz qx qy qz qw] poses.

    Both poses are converted to 4x4 matrices and the relative transform
    inv(estimated) @ gt is analysed.

    Returns:
        (dist_error, angle_error): the Euclidean norm of the relative
        translation, and the relative rotation angle in degrees after the
        fold ``abs(90 - abs(angle - 90))``.
    """
    relative = np.linalg.inv(pose_to_matrix(estimated_pose)) @ pose_to_matrix(gt_pose)

    offset = relative[:3, 3]
    dist_error = np.sum(offset**2) ** 0.5

    rotvec = Rotation.from_matrix(relative[:3, :3]).as_rotvec()
    unfolded_angle = (np.sum(rotvec**2)**0.5) * 180 / np.pi
    # The fold maps an angle a in [90, 180] to 180 - a, so a half-turn is reported as 0.
    # NOTE(review): presumably meant to ignore a forward/backward flip — confirm this is intended.
    angle_error = abs(90 - abs(unfolded_angle - 90))
    return dist_error, angle_error

def compute_translation_error(gt_pose, pred_pose):
    """Euclidean distance between the translation columns of two 4x4 pose matrices."""
    offset = gt_pose[:3, 3] - pred_pose[:3, 3]
    return np.linalg.norm(offset)

def compute_rotation_error(gt_pose, pred_pose):
    """Angle, in degrees, of the rotation between two 4x4 pose matrices.

    The relative rotation inv(R_gt) * R_pred is formed from the upper-left
    3x3 blocks and its rotation angle is returned. The previous
    implementation took the norm of the 'xyz' Euler angles instead, which
    depends on the chosen decomposition and can exceed 180 degrees for
    rotations about more than one axis.

    Args:
        gt_pose: 4x4 ground-truth pose matrix.
        pred_pose: 4x4 predicted pose matrix.

    Returns:
        The rotation angle in degrees, in the range [0, 180].
    """
    gt_rot = Rotation.from_matrix(gt_pose[:3, :3])
    pred_rot = Rotation.from_matrix(pred_pose[:3, :3])
    relative = gt_rot.inv() * pred_rot
    # magnitude() is the rotation angle in radians.
    return np.degrees(relative.magnitude())

def compute_absolute_pose_error(gt_pose, pred_pose):
    """Return (rotation_error, translation_error) for two 4x4 pose matrices.

    Delegates to compute_rotation_error and compute_translation_error, in
    that order.
    """
    return (
        compute_rotation_error(gt_pose, pred_pose),
        compute_translation_error(gt_pose, pred_pose),
    )

In [17]:
from opr.datasets.itlp import ITLPCampus

# Folder that holds the recorded indoor tracks; change it here to relocate the data.
DATASET_ROOT = "/home/docker_opr/Datasets/indoor"

# Query track and database track directories.
QUERY_TRACK_DIR = f"{DATASET_ROOT}/01_2023-11-09-twilight"
DATABASE_TRACK_DIR = f"{DATASET_ROOT}/00_2023-10-25-night"

# The text-label files sit inside their track directories, so derive the paths
# instead of repeating the directory names.
QUERY_LABELS_PATH = f"{QUERY_TRACK_DIR}/text_labels.json"
DB_LABELS_PATH = f"{DATABASE_TRACK_DIR}/text_labels.json"

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# DataLoader settings used when embedding the database track.
BATCH_SIZE = 64
NUM_WORKERS = 4

# Model definition and pretrained weights, relative to the notebook's working directory.
MODEL_CONFIG_PATH = "../configs/model/place_recognition/minkloc3d.yaml"
WEIGHTS_PATH = "../weights/place_recognition/minkloc3d_nclt.pth"

In [18]:
# Query track: lidar is the only sensor requested and every load_* flag is switched off.
query_dataset = ITLPCampus(
    dataset_root=QUERY_TRACK_DIR,
    sensors=["lidar"],
    subset="test",
    test_split=[1],
    indoor=True,
    mink_quantization_size=0.5,
    load_aruco_labels=False,
    load_text_labels=False,
    load_text_descriptions=False,
    load_semantics=False,
)

# Database track: same subset / split / sensor selection as the query track.
# NOTE(review): unlike the query dataset, mink_quantization_size and the load_* flags are
# left at the ITLPCampus defaults here — confirm that the defaults are what is intended.
db_dataset = ITLPCampus(
    dataset_root=DATABASE_TRACK_DIR,
    sensors=["lidar"],
    subset="test",
    test_split=[1],
    indoor=True,
)

# Read the per-frame text annotations of the query track.
with open(QUERY_LABELS_PATH, "rb") as f:
    query_labels = json.load(f)

# The file may contain a JSON document that was itself serialized into a JSON string
# (double-encoded). Decode a second time only in that case, so that a normally encoded
# file loads as well instead of raising TypeError in json.loads.
if isinstance(query_labels, str):
    query_labels = json.loads(query_labels)


In [19]:
# Build the network described by the YAML config.
model_config = OmegaConf.load(MODEL_CONFIG_PATH)
model = instantiate(model_config)

# map_location makes torch.load place the checkpoint tensors on DEVICE directly, so the
# checkpoint also loads on a machine that lacks the device it was saved from.
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE))
model = model.to(DEVICE)

# Switch to inference mode; the trailing semicolon suppresses the module repr in the cell output.
model.eval();

## Build the FAISS index (`index.faiss`)

In [20]:
# Batched loader over the database track. Shuffling is off, so batches are produced
# in dataset order.
db_dataloader = DataLoader(
    dataset=db_dataset,
    collate_fn=db_dataset.collate_fn,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
)


In [21]:
# Embed every database batch with the model and collect the descriptors as numpy arrays.
descriptors = []
with torch.no_grad():
    for cpu_batch in tqdm(db_dataloader):
        # Move each entry of the collated batch to the inference device.
        device_batch = {name: value.to(DEVICE) for name, value in cpu_batch.items()}
        batch_descriptors = model(device_batch)["final_descriptor"]
        descriptors.append(batch_descriptors.detach().cpu().numpy())

# Stack the per-batch arrays along the first axis into one array.
descriptors = np.concatenate(descriptors, axis=0)

  0%|          | 0/4 [00:00<?, ?it/s]

  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = tor

In [24]:
# L2 index sized to the descriptor dimension (second axis of the descriptor array).
index = faiss.IndexFlatL2(descriptors.shape[1])
index.add(descriptors)

# Sanity check: trained flag and number of stored vectors, one per line.
print(index.is_trained, index.ntotal, sep="\n")

True
241


In [25]:
# Store the index inside the database track directory; the pipelines below are given that directory.
faiss.write_index(index, f"{DATABASE_TRACK_DIR}/index.faiss")

# Pipeline

In [26]:
# Pipeline that is queried with the frame data plus the frame's text labels.
pipe = TextLabelsPlaceRecognitionPipeline(
    database_dir=DATABASE_TRACK_DIR,
    db_labels_path=DB_LABELS_PATH,
    model=model,
    device=DEVICE,
    model_weights_path=WEIGHTS_PATH,
)

# Baseline pipeline over the same database directory, model and weights,
# queried with the frame data only.
default_pipe = PlaceRecognitionPipeline(
    model=model,
    model_weights_path=WEIGHTS_PATH,
    database_dir=DATABASE_TRACK_DIR,
    device=DEVICE,
)


In [27]:
# Select the database row(s) recorded at one specific timestamp and show the index label of the first hit.
pred_i = pipe.database_df.loc[pipe.database_df["timestamp"] == 1698265583792060160]
pred_i.index[0]

5

In [29]:
# Position of the query frame to inspect. Named sample_idx rather than `id` so the
# builtin id() stays usable in the rest of the notebook.
sample_idx = 5 # 914

# NOTE(review): the label entry and the dataset frame are paired purely by position;
# nothing here checks that the label key matches the frame's own timestamp — confirm
# that query_labels and query_dataset are ordered identically.
timestamp = list(query_labels.keys())[sample_idx]
query_annos = get_labels_by_id(query_labels, timestamp)
print(f"query_annos = {query_annos}")

sample_data = query_dataset[sample_idx]
# Remove the ground-truth pose before inference so that it is not part of the pipeline input.
sample_pose_gt = sample_data.pop("pose")

sample_output = pipe.infer(sample_data, query_annos)

print(f"sample_output.keys() = {sample_output.keys()}")
print(f"sample_output['idx'] = {sample_output['idx']}")
print(f"pose = {sample_output['pose']}")
print(f"pose_gt = {sample_pose_gt.numpy()}")

# Distance / angle between the pose returned by the pipeline and the ground truth.
dist_error, angle_error = compute_error(sample_output["pose"], sample_pose_gt.numpy())
print(f"dist_error = {dist_error}, angle_error = {angle_error}")

query_annos = ['мфти', 'центр цифровых технологий']
sample_output.keys() = dict_keys(['idx', 'pose', 'descriptor'])
sample_output['idx'] = 1
pose = [ 1.97034869e+00  9.23527777e-02  2.15496325e-01  2.44941231e-04
  2.16747197e-02 -1.12763019e-02  9.99701451e-01]
pose_gt = [ 2.8521736  -2.199138   -0.05990117  0.00974635  0.01638859 -0.9013
  0.4327756 ]
dist_error = 2.4707060806175454, angle_error = 52.61228770101103


  pc = torch.tensor(pc, dtype=torch.float32)


In [30]:
# Show which fields the pipeline returned for the single query above.
sample_output.keys()

dict_keys(['idx', 'pose', 'descriptor'])

In [31]:
import time
from geotransformer.utils.registration import compute_registration_error
from geotransformer.utils.pointcloud import get_transform_from_rotation_translation

In [35]:
# A query counts as a place-recognition match when the translation error between the
# returned pose and the ground truth is at most this value (presumably meters — confirm).
PR_MATCH_THRESHOLD = 25.0
pr_matches = []  # one boolean per evaluated query
rre_list = []    # rotation errors of matched queries only
rte_list = []    # translation errors of matched queries only
times = []       # seconds spent in pipe.infer per query

# Materialize the label keys once instead of rebuilding the list on every iteration.
query_timestamps = list(query_labels.keys())

# NOTE(review): label entries are paired with dataset frames purely by position. Iterating
# over len(query_labels) was tried and disabled here, which suggests the two lengths differ —
# confirm that entry i of the label file really belongs to query_dataset[i].
for frame_idx in tqdm(range(len(query_dataset))):
    query_annos = get_labels_by_id(query_labels, query_timestamps[frame_idx])
    data = query_dataset[frame_idx]
    gt_pose = data.pop("pose")
    gt_pose = get_transform_from_rotation_translation(Rotation.from_quat(gt_pose[3:]).as_matrix(), gt_pose[:3])

    # perf_counter is a monotonic clock intended for measuring short durations.
    start_time = time.perf_counter()
    pipe_out = pipe.infer(data, query_annos)
    times.append(time.perf_counter() - start_time)

    estimated_pose = pipe_out["pose"]
    estimated_pose = get_transform_from_rotation_translation(Rotation.from_quat(estimated_pose[3:]).as_matrix(), estimated_pose[:3])

    # A single call yields both errors; the translation error doubles as the match distance.
    rre, rte = compute_registration_error(gt_pose, estimated_pose)
    pr_matched = rte <= PR_MATCH_THRESHOLD
    pr_matches.append(pr_matched)

    if pr_matched:
        rre_list.append(rre)
        rte_list.append(rte)


  0%|          | 0/239 [00:00<?, ?it/s]

  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = tor

In [36]:
# Fraction of evaluated queries that were flagged as a place-recognition match.
print(f"PlaceRecognition R@1 = {np.mean(pr_matches):0.3f}")
# rre_list / rte_list only contain matched queries, so the statistics below exclude failed retrievals.
print(f"Localization Mean RRE = {np.mean(rre_list):0.3f}")
print(f"Localization Mean RTE = {np.mean(rte_list):0.3f}")

print(f"Localization Median RRE = {np.median(rre_list):0.3f}")
print(f"Localization Median RTE = {np.median(rte_list):0.3f}")

# times holds seconds per inference call; report the mean in milliseconds.
print(f"Mean Time = {(np.mean(times) * 1000):0.2f} ms")

PlaceRecognition R@1 = 0.912
Localization Mean RRE = 67.764
Localization Mean RTE = 4.184
Localization Median RRE = 13.541
Localization Median RTE = 2.106
Mean Time = 17.74 ms


In [37]:
# (number of evaluated queries, number of those flagged as matched)
len(pr_matches), len(rre_list)

(239, 218)

### Results only on frames with text labels (TextLabelsPlaceRecognitionPipeline)

In [38]:
# A query counts as a place-recognition match when the translation error between the
# returned pose and the ground truth is at most this value (presumably meters — confirm).
PR_MATCH_THRESHOLD = 25.0
pr_matches = []  # one boolean per evaluated query
rre_list = []    # rotation errors of matched queries only
rte_list = []    # translation errors of matched queries only
times = []       # seconds spent in pipe.infer per query

# Materialize the label keys once instead of rebuilding the list on every iteration.
query_timestamps = list(query_labels.keys())

# NOTE(review): label entries are paired with dataset frames purely by position. Iterating
# over len(query_labels) was tried and disabled here, which suggests the two lengths differ —
# confirm that entry i of the label file really belongs to query_dataset[i].
for frame_idx in tqdm(range(len(query_dataset))):
    query_annos = get_labels_by_id(query_labels, query_timestamps[frame_idx])

    # Only frames that carry at least one text label take part in this evaluation.
    if len(query_annos) == 0:
        continue

    data = query_dataset[frame_idx]
    gt_pose = data.pop("pose")
    gt_pose = get_transform_from_rotation_translation(Rotation.from_quat(gt_pose[3:]).as_matrix(), gt_pose[:3])

    # perf_counter is a monotonic clock intended for measuring short durations.
    start_time = time.perf_counter()
    pipe_out = pipe.infer(data, query_annos)
    times.append(time.perf_counter() - start_time)

    estimated_pose = pipe_out["pose"]
    estimated_pose = get_transform_from_rotation_translation(Rotation.from_quat(estimated_pose[3:]).as_matrix(), estimated_pose[:3])

    # A single call yields both errors; the translation error doubles as the match distance.
    rre, rte = compute_registration_error(gt_pose, estimated_pose)
    pr_matched = rte <= PR_MATCH_THRESHOLD
    pr_matches.append(pr_matched)

    if pr_matched:
        rre_list.append(rre)
        rte_list.append(rte)

  0%|          | 0/239 [00:00<?, ?it/s]

  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = tor

In [39]:
# Fraction of evaluated queries that were flagged as a place-recognition match.
print(f"PlaceRecognition R@1 = {np.mean(pr_matches):0.3f}")
# rre_list / rte_list only contain matched queries, so the statistics below exclude failed retrievals.
print(f"Localization Mean RRE = {np.mean(rre_list):0.3f}")
print(f"Localization Mean RTE = {np.mean(rte_list):0.3f}")

print(f"Localization Median RRE = {np.median(rre_list):0.3f}")
print(f"Localization Median RTE = {np.median(rte_list):0.3f}")

# times holds seconds per inference call; report the mean in milliseconds.
print(f"Mean Time = {(np.mean(times) * 1000):0.2f} ms")

PlaceRecognition R@1 = 0.908
Localization Mean RRE = 95.755
Localization Mean RTE = 5.096
Localization Median RRE = 110.311
Localization Median RTE = 4.345
Mean Time = 22.97 ms


In [40]:
# (number of evaluated queries, number of those flagged as matched)
len(pr_matches), len(rre_list)

(76, 69)

In [41]:
# (number of query frames, percentage of them that were evaluated, i.e. had at least one text label)
len(query_dataset), len(pr_matches) / len(query_dataset) * 100

(239, 31.799163179916317)

### Results only on frames with text labels (PlaceRecognitionPipeline)

In [42]:
# A query counts as a place-recognition match when the translation error between the
# returned pose and the ground truth is at most this value (presumably meters — confirm).
PR_MATCH_THRESHOLD = 25.0
pr_matches = []  # one boolean per evaluated query
rre_list = []    # rotation errors of matched queries only
rte_list = []    # translation errors of matched queries only
times = []       # seconds spent in default_pipe.infer per query

# Materialize the label keys once instead of rebuilding the list on every iteration.
query_timestamps = list(query_labels.keys())

# NOTE(review): label entries are paired with dataset frames purely by position. Iterating
# over len(query_labels) was tried and disabled here, which suggests the two lengths differ —
# confirm that entry i of the label file really belongs to query_dataset[i].
for frame_idx in tqdm(range(len(query_dataset))):
    query_annos = get_labels_by_id(query_labels, query_timestamps[frame_idx])

    # The labels are used only to pick the frames to evaluate; the baseline pipeline is
    # queried without them.
    if len(query_annos) == 0:
        continue

    data = query_dataset[frame_idx]
    gt_pose = data.pop("pose")
    gt_pose = get_transform_from_rotation_translation(Rotation.from_quat(gt_pose[3:]).as_matrix(), gt_pose[:3])

    # perf_counter is a monotonic clock intended for measuring short durations.
    start_time = time.perf_counter()
    pipe_out = default_pipe.infer(data)
    times.append(time.perf_counter() - start_time)

    estimated_pose = pipe_out["pose"]
    estimated_pose = get_transform_from_rotation_translation(Rotation.from_quat(estimated_pose[3:]).as_matrix(), estimated_pose[:3])

    # A single call yields both errors; the translation error doubles as the match distance.
    rre, rte = compute_registration_error(gt_pose, estimated_pose)
    pr_matched = rte <= PR_MATCH_THRESHOLD
    pr_matches.append(pr_matched)

    if pr_matched:
        rre_list.append(rre)
        rte_list.append(rte)

  0%|          | 0/239 [00:00<?, ?it/s]

  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = torch.tensor(pc, dtype=torch.float32)
  pc = tor

In [43]:
# Fraction of evaluated queries that were flagged as a place-recognition match.
print(f"PlaceRecognition R@1 = {np.mean(pr_matches):0.3f}")
# rre_list / rte_list only contain matched queries, so the statistics below exclude failed retrievals.
print(f"Localization Mean RRE = {np.mean(rre_list):0.3f}")
print(f"Localization Mean RTE = {np.mean(rte_list):0.3f}")

print(f"Localization Median RRE = {np.median(rre_list):0.3f}")
print(f"Localization Median RTE = {np.median(rte_list):0.3f}")

# times holds seconds per inference call; report the mean in milliseconds.
print(f"Mean Time = {(np.mean(times) * 1000):0.2f} ms")

PlaceRecognition R@1 = 0.921
Localization Mean RRE = 42.049
Localization Mean RTE = 3.025
Localization Median RRE = 6.526
Localization Median RTE = 1.424
Mean Time = 9.67 ms


In [44]:
# (number of evaluated queries, number of those flagged as matched)
len(pr_matches), len(rre_list)

(76, 70)