In [1]:
%load_ext autoreload
%autoreload 2
import torch
import yaml
import sys
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from pathlib import Path
sys.path.append('fbsource/surreal/')
from maploc.data.loader_mapillary import MapillaryDataModule
from maploc.train import GenericModule
from maploc.utils.viz_2d import plot_images, plot_keypoints, features_to_RGB, save_plot, add_text
from maploc.utils.viz_localization import likelihood_overlay, plot_pose, plot_dense_rotations
from maploc.osm.viz import Colormap, plot_nodes
from maploc.data.collate import collate
from maploc.utils.geo import BoundaryBox, Projection
# Inference-only notebook: turn autograd off globally (trailing `;` hides the cell output).
torch.set_grad_enabled(False);
# A threshold of 0 disables matplotlib's warning about too many open figures.
plt.rcParams.update({'figure.max_open_warning': 0})

In [2]:
# Checkpoint location template; `{}` is filled with the experiment name below.
ckpt = 'manifold://psarlin/tree/maploc/experiments/{}/last.ckpt'
# Alternative experiment (same name without the "-plane" suffix), kept for quick switching:
# exper = "bev1-osm2-mly12-n100_vgg16-vgg13_bs9-resize256_norm-d8-nrot64"
exper = "bev1-osm2-mly12-n100_vgg16-vgg13-plane_bs9-resize256_norm-d8-nrot64"
# NOTE(review): this empty config is overwritten on the next line and has no effect;
# presumably a manual toggle for loading the checkpoint without overrides.
cfg = {}
# Override handed to the loader: set the model's `num_rotations` to 128.
cfg = {'model': {"num_rotations": 128}}

path = ckpt.format(exper)
print(path)
# NOTE(review): the meaning of `strict`, `find_best` and `cfg` is defined by
# GenericModule.load_from_checkpoint, which is not visible here — confirm there.
model = GenericModule.load_from_checkpoint(path, strict=True, find_best=True, cfg=cfg)
# Put the module in evaluation mode; the trailing `;` hides the returned object's repr.
model.eval();

In [3]:
# Start from the project's default Mapillary data config, then merge the notebook-specific
# overrides given as an inline YAML document below.
conf = OmegaConf.load('fbsource/fbcode/scripts/psarlin/maploc/conf/data_mapillary.yaml')
# The YAML text is a literal in this cell (not external input), so `yaml.full_load` is fine here.
# `${.local_dir}` / `${.val}` are OmegaConf relative interpolations, resolved further down.
conf = OmegaConf.merge(conf, OmegaConf.create(yaml.full_load("""
local_dir: "./data/aria_dumps/"
dump_dir: ${.local_dir}
scenes:
    - reloc_seattle_downtown
    # - reloc_detroit_greektown
    # - reloc_detroit_gcp
tiles_filename: tiles.pkl
max_init_error: 0
init_from_gps: true
return_gps: true
val: {batch_size: 1, num_workers: 1}
train: ${.val}
random: false
augmentation: {rot90: false, flip: false}
""")))
# Replace the interpolations with concrete values in place.
OmegaConf.resolve(conf)
datamodule = MapillaryDataModule(conf)
datamodule.prepare_data()
datamodule.setup()
# Later cells index the validation split by frame name.
dataset = datamodule.dataset("val")
colormap = Colormap()
# Map each entry of `dataset.names` to its position so frames can be fetched by name.
name2idx = dict(zip(dataset.names, range(len(dataset.names))))
# NOTE(review): not read anywhere in the visible cells — confirm whether it is still needed.
is_gt_absolute = False

In [None]:
# Work on the first configured scene.
scene = dataset.cfg.scenes[0]
# NOTE(review): picks the *second* key of the scene's dump (index 1); the choice is hard-coded.
seq_key = list(dataset.dumps[scene])[1]
# Keep the (sequence, frame-name) pairs of `dataset.names` that belong to this scene and sequence.
frames = [(sq, n) for sc, sq, n in dataset.names if sc == scene and sq == seq_key]
print(len(frames))
# Per-frame capture times, in the same order as `frames`.
# NOTE(review): units are not visible here; a later cell divides by 1000 before matching
# against microsecond timestamps, so these are presumably nanoseconds — confirm.
timestamps = np.array([dataset.dumps[scene][sq]["views"][n]["capture_time"] for sq, n in frames])
# Show the first frame; permute(1, 2, 0) moves the leading axis last
# (presumably CHW -> HWC for display — confirm).
plot_images([dataset[name2idx[(scene,)+frames[0]]]["image"].permute(1, 2, 0)])

# Parse MDC LeeLoo output

In [314]:
import json
import gaia
from kapture.io.csv import rigs_from_file
import quaternion
from maploc.mapillary.processing import decompose_rotmat
from maploc.data.aria import get_gaia_gt_id

# The sequence key is interpreted as the numeric Gaia id of the recording.
gaia_id = int(seq_key)
gt_id = get_gaia_gt_id(gaia_id)
# NOTE(review): hard-coded, user-specific absolute path; consider making the root configurable.
aria_dir = Path("/home/psarlin/local/aria", str(gaia_id)).absolute()

# Download the ground-truth file only when no local copy exists yet.
gt_path = aria_dir / "gt.mdc"
if not gt_path.exists():
    # NOTE(review): the destination name passed is the stem ("gt") while the file opened below
    # is "gt.mdc"; presumably gaia.download_file appends the extension — confirm.
    gaia.download_file(gt_id, output_dir=aria_dir, destination_file_name=gt_path.stem)
# The file is parsed as JSON despite its .mdc extension.
with open(gt_path) as fd:
    gt_dict = json.load(fd)

kapture_path = aria_dir / f"kapture/{gaia_id}/sensors"
sensor_id = f"{gaia_id}_RGB"
# Load the rig calibration from rigs.txt, restricted to the RGB sensor id.
rigs = rigs_from_file(kapture_path / "rigs.txt", [sensor_id])
# NOTE(review): the rgb->rig transform is obtained by inverting the stored rig entry;
# the direction convention of kapture rigs is not visible here — confirm.
T_rgb2rig = rigs[str(gaia_id), sensor_id].inverse()

def match_timestamps(ts_rgb, ts_kf):
    """Return, for every query timestamp, the index of the nearest reference timestamp.

    Args:
        ts_rgb: 1-D array-like of reference timestamps. The returned indices point
            into this array. It does not need to be sorted. (The name is historical:
            the visible caller passes the ground-truth timestamps here.)
        ts_kf: 1-D array-like of query timestamps, in the same unit as `ts_rgb`.

    Returns:
        Integer array with one entry per element of `ts_kf`, indexing `ts_rgb`.
        When a query is exactly halfway between two references, the earlier
        reference is chosen.

    Raises:
        ValueError: if `ts_rgb` is empty (there is nothing to match against).

    Side effects:
        Prints the largest absolute matching error divided by 1e6 (i.e. in seconds
        when the inputs are microseconds). Nothing is printed for an empty query.
    """
    ts_rgb = np.asarray(ts_rgb)
    ts_kf = np.asarray(ts_kf)
    if ts_rgb.size == 0:
        raise ValueError("match_timestamps: ts_rgb must contain at least one timestamp")
    if ts_kf.size == 0:
        # `.max()` of an empty array raises; an empty query simply has no matches.
        return np.zeros(0, dtype=np.intp)

    # np.searchsorted requires a sorted array. Sort a view and map the result back so
    # unsorted references are handled; a stable sort keeps already-sorted input
    # (including duplicates) in its original order, so results are unchanged there.
    order = np.argsort(ts_rgb, kind="stable")
    ts_sorted = ts_rgb[order]

    # Candidate neighbours: insertion position (clamped to the last element) and its predecessor.
    idx = np.searchsorted(ts_sorted, ts_kf)
    idx = np.minimum(idx, len(ts_sorted)-1)
    prev = np.maximum(idx-1, 0)
    # Strict `<` means ties resolve to the earlier neighbour.
    use_idx = np.abs(ts_sorted[idx] - ts_kf) < np.abs(ts_sorted[prev] - ts_kf)
    idx_kf2rgb = order[np.where(use_idx, idx, prev)]
    print("Max timestamp error", np.abs(ts_kf - ts_rgb[idx_kf2rgb]).max()/1e6)
    return idx_kf2rgb

# Index the ground-truth rig poses by their capture timestamp.
# (A dict comprehension previously bound to `compactse3d` here was dead code: the name was
# overwritten inside the loop before ever being read, so it has been removed.)
ts2pose_rig = {}
for p in gt_dict["trajectory"]["sampled_poses"]:
    ts = p["center_capture_time_us"]
    compactse3d = p["anchored_pose_type"]["world_anchored_pose"]["transform_parent_trajpose"]
    # The first four values form the quaternion, the remainder the translation.
    # NOTE(review): x, y, z, w storage order is taken from the variable name — confirm against the MDC format.
    q_xyzw, t_rig2w = np.split(compactse3d, [4])
    # Move the last component to the front: numpy-quaternion expects (w, x, y, z).
    R_rig2w = quaternion.as_rotation_matrix(quaternion.from_float_array(np.r_[q_xyzw[-1], q_xyzw[:3]]))
    ts2pose_rig[ts] = (R_rig2w, t_rig2w)

# Associate every RGB frame timestamp with the closest ground-truth timestamp.
# NOTE(review): the division by 1000 presumably converts nanoseconds to the microseconds
# used by "center_capture_time_us" — confirm the unit of `timestamps`.
timestamps_gt = np.array(list(ts2pose_rig))
ts_rgb2gt = dict(zip(timestamps, timestamps_gt[match_timestamps(timestamps_gt, timestamps/1000)]))

# Compose the rig pose with the RGB-to-rig extrinsics and keep the planar position and yaw.
xy_gt_all = []
yaws_gt_all = []
for ts in timestamps:
    R_rig2w, t_rig2w = ts2pose_rig[ts_rgb2gt[ts]]
    R_rgb2w = R_rig2w @ quaternion.as_rotation_matrix(T_rgb2rig.r)
    t_rgb2w = R_rig2w @ T_rgb2rig.t[:, 0] + t_rig2w
    # Only the last value returned by decompose_rotmat is used, as the yaw.
    *_, yaw = decompose_rotmat(R_rgb2w)
    xy_gt_all.append(t_rgb2w[:2])
    yaws_gt_all.append(yaw)
xy_gt_all = np.stack(xy_gt_all)
yaws_gt_all = np.stack(yaws_gt_all)

# Subsample for eval

In [443]:
# Indices of the frames kept for evaluation. The stride of 1 keeps every frame.
idx_select = np.arange(len(timestamps))[::1]  # alternative stride: [::len(timestamps)//100]
frames_select = [frames[i] for i in idx_select]
print(len(frames_select))
# Ground-truth positions and yaws of the selected frames, as float32 tensors.
xy_gt = torch.from_numpy(xy_gt_all[idx_select]).float()
yaws_gt = torch.from_numpy(yaws_gt_all[idx_select]).float()

In [None]:
# Ground-truth trajectory in the xy plane; colour runs from 0 to 1 with the frame order ("jet" colormap).
plt.scatter(*xy_gt_all.T, c=np.linspace(0, 1, len(xy_gt_all)), cmap="jet")
# Equal axis scaling so the trajectory is not distorted.
plt.gca().set_aspect("equal")

In [444]:
%%time
from maploc.data.sequential import unpack_batches
# Fetch every selected frame from the dataset by its (scene, sequence, name) key.
batches = [dataset[name2idx[(scene, sq, n)]] for sq, n in frames_select]
# One forward pass per frame, each collated on its own into a batch of one.
preds = [model(collate([b])) for b in batches]
# Take the first item of each prediction (the only one, since every batch holds a single frame).
uvt_p = [p["xyr_max"][0] for p in preds]
logprobs = [p["log_probs"][0] for p in preds]
# Unpack the per-frame fields; the fifth returned value is not used in this notebook.
images, canvas, maps, yaws_slam, _, xy_slam, xy_gps = unpack_batches(batches)
# Convert each raw map to its colour rendering for display.
maps = list(map(colormap.apply, maps))

In [445]:
from maploc.models.sequential import gps_alignment
# Per-frame GPS accuracy; frames without an "accuracy_gps" entry fall back to 15.
# NOTE(review): the unit of the fallback (presumably metres) is not visible here — confirm.
accuracy_gps = torch.stack([b.get("accuracy_gps", torch.tensor(15.)) for b in batches])
# Align the sequence using the GPS readings only, searching over 512 rotations.
belief_align_gps, xy_seq_gps, uvt_align_gps, xy_align_gps = gps_alignment(
    xy_gps, accuracy_gps, canvas, xy_gt, yaws_gt, num_rotations=512)
# Show the belief maximised over its last axis, as-is and exponentiated.
plot_images([belief_align_gps.max(-1).values, belief_align_gps.max(-1).values.exp()], cmaps="jet")

In [None]:
%%time
from maploc.models.sequential import rigid_alignment
from maploc.osm.raster import Canvas
from maploc.utils.geo import BoundaryBox

# Initial search state: start around the first GPS fix with a coarse grid.
ppm = 1
num_rotations = 256
tile_size = dataset.cfg.crop_size_meters
xy_init = xy_gps[0].numpy()

# Five refinement rounds. Each round searches a box of half-extent `tile_size` centred on
# the current estimate, then tightens the search around the estimate it found.
for _ in range(5):
    search_box = BoundaryBox(xy_init - tile_size, xy_init + tile_size)
    canvas_align = Canvas(search_box, ppm)
    belief_align, uvt_seq, _, uvt_align_ref = rigid_alignment(
        logprobs, canvas, xy_gt, yaws_gt, num_rotations, canvas_align)
    xy_align_ref = canvas_align.to_xy(uvt_align_ref[:2])
    print(xy_align_ref, uvt_align_ref[-1])

    # Belief maximised over its last axis, shown as-is and exponentiated.
    belief_max = belief_align.max(-1).values
    plot_images([belief_max, belief_max.exp()], cmaps="jet")

    # Next round: double the canvas resolution, double the rotations (capped at 4096),
    # quarter the half-extent (never below 1) and recentre on the latest estimate.
    ppm = 2 * ppm
    num_rotations = min(2 * num_rotations, 4096)
    tile_size = max(tile_size / 4, 1)
    xy_init = xy_align_ref

In [None]:
# Query one map tile covering all GPS fixes, padded by the dataset crop size.
bounds = (xy_gps.numpy().min(0), xy_gps.numpy().max(0))
bbox_total = BoundaryBox(*bounds) + dataset.cfg.crop_size_meters
tile_total = dataset.tile_managers[scene].query(bbox_total)
map_total = colormap.apply(tile_total.raster)

# Express every trajectory in the uv coordinates of that tile.
uv_gps = tile_total.to_uv(xy_gps.numpy())
uv_seq_gps = tile_total.to_uv(xy_seq_gps.numpy())
# Per-frame estimates are first converted with their own canvas, then mapped onto the tile.
uv_seq_total = tile_total.to_uv([c.to_xy(uvt[:2].numpy()) for c, uvt in zip(canvas, uvt_seq)])

plot_images([map_total])
uv_total = tile_total.to_uv([c.to_xy(uvt[:2].numpy()) for c, uvt in zip(canvas, uvt_p)])
plt.plot(*uv_total.T, marker='o', ms=4, lw=1, c="k", label="single", alpha=0.5);
plt.plot(*uv_gps.T, marker='o', ms=4, lw=1, c="blue", label="GPS");
plt.plot(*uv_seq_gps.T, marker='o', ms=4, lw=1, c="orange", label="GPS seq");
plt.plot(*uv_seq_total.T, marker='o', ms=4, lw=1, c="r", label="sequential");
# The labels above are only visible once a legend is drawn; without it the four
# trajectories cannot be told apart in the figure.
plt.legend();

# Check pointcloud alignment

In [None]:
from pathlib import Path
from livemaps.mapping.sparse_map import create_map_context_from_filesystem
# Only the sparse map is requested; the other three returned values are discarded.
sparse_map, _, _, _ = create_map_context_from_filesystem(
    mdc_prefix=str(aria_dir / f"gaia:{gaia_id}/reconstruction/result"),
    sparse_map=True, pose_graph=False, image_reader=False, descriptor_db=False,
)
# Stack the world-frame position of every track into one array (one row per track id).
all_p3d = np.stack([sparse_map.get_point_in_world(i) for i in sparse_map.track_ids()])

In [416]:
from maploc.models.utils import rotmat2d, deg2rad

# Bring the xy coordinates of the sparse points into the map frame, using the first frame
# as anchor: subtract its SLAM position, rotate by the difference between its SLAM yaw and
# its estimated yaw, then add its estimated map position.
p3d_offset = torch.from_numpy(all_p3d[:, :2]) - xy_slam[0]
yaw_delta = yaws_slam[0] - uvt_seq[0][2]
rot_delta = rotmat2d(deg2rad(yaw_delta))
p3d_rotated = p3d_offset @ rot_delta.T
p3d_xy = p3d_rotated + canvas[0].to_xy(uvt_seq[0][:2]).float()

# Same points in the uv coordinates of the full map tile.
p3d_uv = tile_total.to_uv(p3d_xy)

In [436]:
plot_images([map_total], dpi=200)

# Draw only the transformed points that fall inside the tile's bounding box.
inside_tile = tile_total.bbox.contains(p3d_xy.numpy())
plt.scatter(*p3d_uv[inside_tile].T, c="r", s=0.2, alpha=0.5)

# Zoom on a window of +/-30 (uv units) around a frame 15 steps before the middle of the
# sequence. The y limits are passed upper-first, so the vertical axis is inverted.
zoom_radius = 30
zoom_center = uv_seq_total[len(uv_seq_total)//2-15]
plt.ylim([zoom_center[1]+zoom_radius, zoom_center[1]-zoom_radius])
plt.xlim([zoom_center[0]-zoom_radius, zoom_center[0]+zoom_radius]);