In [2]:
%load_ext autoreload
%autoreload 2

import torch
import yaml
import sys
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from pytorch_lightning import seed_everything
sys.path.append('fbsource/fbcode/scripts/psarlin/')
from maploc.data.loader_mapillary import MapillaryDataModule
from maploc.module import GenericModule
from maploc.utils.viz_2d import plot_images, plot_keypoints, features_to_RGB, save_plot, add_text
from maploc.utils.viz_localization import likelihood_overlay, plot_pose, plot_dense_rotations
from maploc.osm.viz import Colormap, plot_nodes
from maploc.models.bev_localizer import PolarProjection
from maploc.models.hough_voting import argmax_xyr, fuse_gps
from maploc.models.refinement import FeaturemetricRefiner, subpixel_refinement
# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)
# The visualization loop below may open many figures; a value below 1
# disables matplotlib's "too many open figures" warning.
plt.rcParams['figure.max_open_warning'] = 0

In [None]:
# Start from the base Mapillary data config, then override it for the Aria dumps:
# a single scene, GPS-initialized, deterministic batches of one image with no augmentation.
# `${.local_dir}` / `${.val}` are relative interpolations: dump_dir mirrors local_dir
# and the train loader settings mirror the val ones.
base_conf = OmegaConf.load('fbsource/fbcode/scripts/psarlin/maploc/conf/data_mapillary.yaml')
overrides = OmegaConf.create(yaml.full_load("""
local_dir: "./data/aria_dumps/"
dump_dir: ${.local_dir}
scenes:
    # - reloc_seattle_downtown
    - reloc_detroit_greektown
    # - reloc_detroit_gcp
tiles_filename: tiles.pkl
max_init_error: 0
init_from_gps: true
return_gps: true
val: {batch_size: 1, num_workers: 1}
train: ${.val}
random: false
augmentation: {rot90: false, flip: false}
"""))
conf = OmegaConf.merge(base_conf, overrides)
# Resolve all interpolations in place so the data module sees concrete values.
OmegaConf.resolve(conf)

dataset = MapillaryDataModule(conf)
dataset.prepare_data()
dataset.setup()

In [4]:
# Experiment (checkpoint directory) to visualize.
# Other experiments that were compared previously:
#   "bev1-osm2-mly12-n100_vgg16-vgg13_bs9-resize256_norm-d8-nrot64"
#   "bev1-osm2-mly12-n100_vgg16-vgg13_bs10-resize256_attn-fix-2-2-128d_norm-d8-nrot64"
#   "bev1-osm2-mly12-n100_vgg16-vgg13_bs6-resize256_attn-fix-2-2-256d_norm-d8-nrot32"
exper = "bev1-osm2-mly12-n100_vgg16-vgg13-plane_bs9-resize256_norm-d8-nrot64-prior"

root = "manifold://psarlin/tree/maploc/experiments"
path = f'{root}/{exper}/last.ckpt'
print(path)

# Model-config override applied when loading the checkpoint.
# NOTE(review): 128 rotations differs from the "nrot64" in the experiment name —
# presumably a deliberately finer rotation sampling at evaluation time; confirm.
cfg = {'model': {"num_rotations": 128}}
model = GenericModule.load_from_checkpoint(path, strict=True, find_best=True, cfg=cfg)
model = model.eval()  # .cuda() is left disabled: the notebook runs on CPU
proj_polar = PolarProjection(model.cfg.model.z_max, model.cfg.model.pixel_per_meter)

In [None]:
# Visualize localization results on a few random validation examples:
# one figure with the image, map, likelihoods and poses, and one with the BEV inputs.
NUM_EXAMPLES = 15  # number of validation batches to visualize
GPS_SIGMA = 10  # `sigma` passed to fuse_gps; NOTE(review): presumably meters (pixel_per_meter is passed too) — confirm

seed_everything(42)
loader = dataset.dataloader("val", shuffle=True)
print(exper)
colormap = Colormap()
for i, batch in zip(range(NUM_EXAMPLES), loader):
    # `dataloader_idx` identifies which dataloader the batch came from. There is a
    # single dataloader here, so it is always 0 (previously the batch counter was passed).
    batch = model.transfer_batch_to_device(batch, model.device, 0)
    pred = model(batch)

    # Fuse the predicted pose distribution with the GPS prior and take its argmax.
    # The `x, = ...` unpacking fails unless the batch size is 1.
    uv_gps, = pred["xy_gps"] = batch["xy_gps"]
    uvt_fused = argmax_xyr(fuse_gps(
        pred["log_probs"], uv_gps, model.cfg.model.pixel_per_meter, sigma=GPS_SIGMA))
    uv_fused, = pred["xy_fused"] = uvt_fused[..., :2]
    yaw_fused, = pred["yaw_fused"] = uvt_fused[..., -1]
    # pred["xy_subpix"], pred["yaw_subpix"] = subpixel_refinement(
    #     pred["log_probs"][0], uv_fused, yaw_fused)

    # Not used below; kept so the last example's loss can be inspected after the loop.
    loss = model.model.loss(pred, batch)

    scene = batch["scene"][0]
    name = batch["name"][0]
    view = dataset.dumps[scene][batch["sequence"][0]]["views"][name]
    image = batch["image"][0].permute(1, 2, 0)
    rasters = batch["map"][0]

    # Pose log-likelihoods, presumably indexed (u, v, rotation); the max over the
    # last axis keeps the best rotation per location.
    lp_uvt = pred["log_probs"][0]
    lp_uv = lp_uvt.max(-1).values
    prob = lp_uv.exp()
    feats_map = pred["map"]["map_features"][0][0]
    feats_q = pred["features_bev"][0]
    mask_bev = pred["valid_bev"][0]
    if "log_prior" in pred["map"]:
        # NOTE(review): sigmoid treats `log_prior` as logits; if it is a
        # log-probability, .exp() would give the probability instead — confirm.
        prior = pred["map"]["log_prior"][0][0].sigmoid()
    else:
        prior = None
    # BEV weight: the model's own confidence if it predicts one, else the BEV
    # feature norm. Cells outside the valid BEV mask are set to NaN.
    norm_q = torch.norm(feats_q, dim=0)
    conf_q = pred["bev"].get("confidence", norm_q[None])[0].clone()
    conf_q.masked_fill_(~mask_bev, np.nan)

    # Project the input image itself into the BEV. Only the first two outputs
    # of the polar projection are forwarded to the BEV projection.
    image_polar = model.model.projection_polar(
        batch["image"], batch["ground_plane"], batch["camera"])
    # Alternative for models conditioned on camera height instead of a ground plane:
    # image_polar = model.model.projection_polar(
    #     batch["image"], batch["camera_height"], batch["camera"])
    image_bev, _, _ = model.model.projection_bev(*image_polar[:2], batch["camera"])

    # `.cpu()` is a no-op for CPU tensors and keeps these explicit numpy conversions
    # working if the model is moved to GPU. Other tensors handed to the plotting
    # helpers below are still assumed to be on CPU.
    mask_bev_np = mask_bev.cpu().numpy()
    image_bev = image_bev[0].permute(1, 2, 0).cpu().numpy()
    # Append the validity mask as an extra last channel (presumably alpha, for RGBA display).
    image_bev = np.concatenate([image_bev, mask_bev_np[..., None]], -1)

    feats_map_rgb, = features_to_RGB(feats_map.cpu().numpy())
    feats_q_rgb, = features_to_RGB(feats_q.cpu().numpy(), masks=[mask_bev])
    # feats_map_rgb, feats_q_rgb, = features_to_RGB(feats_map.numpy(), feats_q.numpy(), masks=[None, mask_bev])
    feats_q_rgb = np.concatenate([feats_q_rgb, mask_bev_np[..., None]], -1)

    # TODO: the first two titles were meant to show localization errors
    # (max xy/yaw error; GPS/fused/subpixel xy error), but no metrics are computed here.
    text1 = text2 = ""
    map_viz = colormap.apply(rasters)
    plot_images([image, map_viz, lp_uv, likelihood_overlay(prob.cpu().numpy(), map_viz), feats_map_rgb],
                titles=[text1, text2, 'loglikelihood', 'likelihood', 'map features'], dpi=75, cmaps='jet')
    colormap.add_colorbar()
    # Overlays, presumably on panel 1 (the map): map nodes, fused pose (blue),
    # GPS position (red) and the raw argmax pose (black); rotations on panel 2.
    plot_nodes(1, rasters[2])
    plot_pose([1], uv_fused, yaw_fused, s=1/35, c="blue", w=0.015)
    plot_pose([1], uv_gps, c="red")
    plot_pose([1], pred["xy_max"][0], pred["yaw_max"][0], s=1/35, c="k", w=0.015)
    plot_dense_rotations(2, lp_uvt.exp(), s=1/15)

    # Second figure: BEV inputs, plus the map prior when the model predicts one.
    bev_panels = [image_bev, conf_q, feats_q_rgb]
    bev_titles = ["BEV image", "BEV weight", "BEV features"]
    if prior is not None:
        bev_panels.append(prior)
        bev_titles.append('map prior')
    plot_images(bev_panels, titles=bev_titles, dpi=50, cmaps='jet')