In [1]:
%load_ext autoreload
%autoreload 2
import torch
import yaml
import sys
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
sys.path.append('fbsource/fbcode/scripts/psarlin/')
from maploc.data.loader_mapillary import MapillaryDataModule
from maploc.train import GenericModule
from maploc.utils.viz_2d import plot_images, plot_keypoints, features_to_RGB, save_plot, add_text
from maploc.utils.viz_localization import likelihood_overlay, plot_pose, plot_dense_rotations
from maploc.osm.viz import Colormap, plot_nodes
from maploc.data.collate import collate
from maploc.utils.geo import BoundaryBox, Projection
# Inference-only notebook: turn autograd off globally.
torch.set_grad_enabled(False)
# Many figures are opened below; silence matplotlib's "too many open figures" warning.
plt.rcParams["figure.max_open_warning"] = 0

# Load model

In [3]:
# Checkpoint path template; `{}` is filled with the experiment name.
ckpt = 'manifold://psarlin/tree/maploc/experiments/{}/last.ckpt'
# Alternative (non-plane) experiment:
#   "bev1-osm2-mly12-n100_vgg16-vgg13_bs9-resize256_norm-d8-nrot64"
exper = "bev1-osm2-mly12-n100_vgg16-vgg13-plane_bs9-resize256_norm-d8-nrot64"
# Config overrides applied when loading. The experiment name suggests training
# used 64 rotations ("nrot64"); presumably evaluated here with 128 — TODO confirm.
cfg = {'model': {"num_rotations": 128}}

path = ckpt.format(exper)
print(path)
model = GenericModule.load_from_checkpoint(path, strict=True, find_best=True, cfg=cfg)
model.eval();

# Load dataset, parse sequences

In [None]:
base_conf = OmegaConf.load('fbsource/fbcode/scripts/psarlin/maploc/conf/data_mapillary.yaml')
# Overrides: local Aria dumps, a single scene, GPS-based init/return,
# batch size 1, random: false and no rot90/flip augmentation.
overrides = OmegaConf.create(yaml.full_load("""
local_dir: "./data/aria_dumps/"
dump_dir: ${.local_dir}
scenes:
    # - reloc_seattle_downtown
    - reloc_detroit_greektown
    # - reloc_detroit_gcp
tiles_filename: tiles.pkl
max_init_error: 0
init_from_gps: true
return_gps: true
val: {batch_size: 1, num_workers: 1}
train: ${.val}
random: false
augmentation: {rot90: false, flip: false}
"""))
conf = OmegaConf.merge(base_conf, overrides)
OmegaConf.resolve(conf)

datamodule = MapillaryDataModule(conf)
datamodule.prepare_data()
datamodule.setup()
dataset = datamodule.dataset("val")
colormap = Colormap()
# Dataset index of each (scene, sequence, frame) name tuple.
name2idx = {name: i for i, name in enumerate(dataset.names)}
# When False, GT poses are not drawn and the GPS track defines the trajectory extent.
is_gt_absolute = False

In [None]:
from maploc.data.sequential import unpack_batches, extract_sequences

scene = dataset.cfg.scenes[0]
sequences_per_scene = extract_sequences(dataset, scenes=[scene], min_length=1, split_cam=False)
chunks = sequences_per_scene[scene]

# Pick the third sequence of the scene and its second chunk.
seq_key = list(chunks)[2]
chunk_idx = 1
seq_full = chunks[seq_key][chunk_idx]
print(seq_key, len(seq_full))
# Keep every other frame.
seq = seq_full[::2]
first_frame = dataset[name2idx[(scene, seq_key, seq[0])]]
plot_images([first_frame["image"].permute(1, 2, 0)])

# Inference

In [11]:
%%time
batches = [dataset[name2idx[(scene, seq_key, n)]] for n in seq]
preds = [model(collate([b])) for b in batches]
uvt_p = [p["xyr_max"][0] for p in preds]
logprobs = [p["log_probs"][0] for p in preds]
images, canvas, maps, yaws_gt, uv_gt, xy_gt, xy_gps = unpack_batches(batches)
maps = list(map(colormap.apply, maps))

plot_images(images, dpi=75)
plot_images(maps, dpi=75, titles=[f'{y:.2f}' for y in yaws_gt])
if is_gt_absolute:
    [plot_pose([i], uv, t, s=1/45, c="green", w=0.015) for i, (uv, t) in enumerate(zip(uv_gt, yaws_gt))];
[plot_pose([i], c.to_uv(xy), s=1/45, c="b", w=0.015) for i, (c, xy) in enumerate(zip(canvas, xy_gps))];
[plot_pose([i], uvt[:2], uvt[2], s=1/45, c="k", w=0.015) for i, uvt in enumerate(uvt_p)];
plot_images([lp.max(-1).values for lp in logprobs], cmaps="jet")
[plot_dense_rotations(i, lp.exp(), s=1/15) for i, lp in enumerate(logprobs)];

In [12]:
%%time
from maploc.models.sequential import rigid_alignment
from maploc.models.hough_voting import log_softmax_spatial
belief_align, uvt_seq, priors, _ = rigid_alignment(logprobs, canvas, xy_gt, yaws_gt, num_rotations=256, return_priors=True)
beliefs = log_softmax_spatial(torch.stack(priors).cumsum(0))
plot_images(beliefs.max(-1).values, cmaps="jet")

In [None]:
%%time
from maploc.models.sequential import gps_alignment
accuracy_gps = torch.stack([b["accuracy_gps"] for b in batches])
belief_align_gps, xy_seq_gps, uvt_align_gps, xy_align_gps = gps_alignment(
    xy_gps, accuracy_gps, canvas, xy_gt, yaws_gt, num_rotations=512)
plot_images([belief_align_gps.max(-1).values, belief_align_gps.max(-1).values.exp()], cmaps="jet")

### Visualization

In [14]:
plot_images(images, dpi=75)

# Sequential beliefs overlaid on each frame's map, with poses drawn on top.
# Plain loops: the plotting calls are made only for their side effects.
plot_images([likelihood_overlay(b.max(-1).values.exp().numpy(), m) for m, b in zip(maps, beliefs)], cmaps="jet")
if is_gt_absolute:
    for i, (uv, yaw) in enumerate(zip(uv_gt, yaws_gt)):
        plot_pose([i], uv, yaw, s=1/35, c="g", w=0.015)
for i, uvt in enumerate(uvt_p):
    plot_pose([i], uvt[:2], uvt[2], s=1/35, c="k", w=0.015)  # single-frame estimate
for i, uvt in enumerate(uvt_seq):
    plot_pose([i], uvt[:2], uvt[2], s=1/35, c="r", w=0.015)  # sequential estimate

# Single-frame log-probabilities.
plot_images([lp.max(-1).values for lp in logprobs], cmaps="jet")
for i, lp in enumerate(logprobs):
    plot_dense_rotations(i, lp.exp(), s=1/15)

# Accumulated sequential beliefs.
plot_images([b.max(-1).values for b in beliefs], cmaps="jet")
for i, b in enumerate(beliefs):
    plot_dense_rotations(i, b.exp(), s=1/15)

# Animation

In [None]:
from maploc.models.utils import make_grid

# Trajectory extent: GT positions when they are absolute, otherwise the GPS track.
xy_extent = (xy_gt if is_gt_absolute else xy_gps).numpy()
bounds = (xy_extent.min(0), xy_extent.max(0))
# One tile covering the trajectory bbox combined with crop_size_meters (presumably a margin).
bbox_total = BoundaryBox(*bounds) + dataset.cfg.crop_size_meters
tile_total = dataset.tile_managers[scene].query(bbox_total)
map_total = colormap.apply(tile_total.raster)


def canvas_uvts_to_tile_uv(uvts):
    """Convert per-frame canvas poses (uv in [:2]) into pixel coordinates of ``tile_total``."""
    xys = [cnv.to_xy(uvt[:2].numpy()) for cnv, uvt in zip(canvas, uvts)]
    return tile_total.to_uv(xys)


uv_gt_total = tile_total.to_uv(xy_gt.numpy())
uv_total = canvas_uvts_to_tile_uv(uvt_p)
uv_seq_total = canvas_uvts_to_tile_uv(uvt_seq)
uv_gps = tile_total.to_uv(xy_gps.numpy())
uv_seq_gps = tile_total.to_uv(xy_seq_gps.numpy())
xy_grid_total = tile_total.to_xy(make_grid(tile_total.w, tile_total.h))

def remap_to_tile(xy_grid_total, canvas, log_probs):
    """Resample a per-frame log-probability map onto the pixel grid of the total tile.

    Args:
        xy_grid_total: (H_out, W_out, 2) tensor of world xy coordinates, one per output
            pixel (rank implied by grid_sample's 4-D grid once a batch dim is added).
        canvas: frame canvas; ``canvas.bbox.normalize`` is assumed to map xy to [0, 1]
            over the canvas extent (inferred from the ``* 2 - 1`` below — TODO confirm).
        log_probs: 2-D (H, W) log-probability map defined on ``canvas``.

    Returns:
        (H_out, W_out) tensor of log-probabilities, set to -inf wherever the bilinear
        sample is not fully supported by ``log_probs`` or is NaN.
    """
    # Map xy to grid_sample's normalized [-1, 1] coordinates.
    uv_grid_i = canvas.bbox.normalize(xy_grid_total).float() * 2 - 1
    # Flip the vertical axis; presumably raster rows run opposite to world y — TODO confirm.
    uv_grid_i[..., 1] *= -1
    grid = uv_grid_i[None]
    lp_total = torch.nn.functional.grid_sample(
        log_probs[None, None], grid, align_corners=False)[0, 0]
    # Coverage of the *source* map: resampling all-ones at log_probs' own resolution
    # gives < 1 wherever zero padding enters the bilinear interpolation. Padding with
    # zeros means log-prob 0, i.e. inflated likelihoods, so those pixels must be masked.
    # Use a tolerance: bilinear weights need not sum to exactly 1.0 in floating point.
    coverage = torch.nn.functional.grid_sample(
        torch.ones_like(log_probs)[None, None], grid, align_corners=False)[0, 0]
    valid = torch.isclose(coverage, torch.ones_like(coverage))
    lp_total.masked_fill_(~valid | lp_total.isnan(), -np.inf)
    return lp_total

def get_log_probs(idx):
    """Return the single-frame log-probabilities of frame ``idx``, max-reduced over the last axis.

    To animate the sequential belief instead, read from ``beliefs`` rather than ``logprobs``.
    """
    per_frame = logprobs[idx]
    reduced, _ = per_frame.max(-1)
    return reduced

def get_overlay(idx, gamma=1/3, alpha_min=0.01, alpha_gamma=0.3):
    """Build an RGBA jet-colormap overlay of frame ``idx``'s likelihood on the total tile grid.

    Args:
        idx: frame index into the sequence.
        gamma: exponent applied to the probabilities (< 1 lifts low values for display).
        alpha_min: normalized values at or below this threshold are fully transparent.
        alpha_gamma: exponent mapping normalized values to opacity.

    Returns:
        RGBA float array of shape (*grid_shape, 4) as produced by the matplotlib colormap.
    """
    lp = get_log_probs(idx)
    lp_total = remap_to_tile(xy_grid_total, canvas[idx], lp)
    p = lp_total.exp().numpy() ** gamma
    peak = p.max()
    # If every pixel was masked to -inf, p is all zeros: skip normalizing to avoid NaNs.
    if peak > 0:
        p = p / peak
    ov = mpl.cm.jet(p)
    ov[..., -1] = np.where(p > alpha_min, p ** alpha_gamma, 0)
    return ov

def get_text(idx):
    """Caption for frame ``idx``; currently always empty (hook for per-frame annotations)."""
    caption = ""
    return caption

# Panels: [0] current frame, [1] likelihood overlay on the total map, [2] trajectories.
plot_images([images[0], map_total, map_total])
plot_nodes(1, tile_total.raster[2])
axes = plt.gcf().axes
traj_style = dict(marker='o', ms=4, lw=1)
axes[2].plot(*uv_gps.T, c="blue", label="GPS", **traj_style)
axes[2].plot(*uv_seq_gps.T, c="orange", label="GPS seq", **traj_style)
if is_gt_absolute:
    axes[2].plot(*uv_gt_total.T, c="green", label="GT", **traj_style)
axes[2].plot(*uv_total.T, c="k", label="single", alpha=0.5, **traj_style)
axes[2].plot(*uv_seq_total.T, c="r", label="sequential", **traj_style)
# Attach the legend explicitly to the axes holding the labelled lines instead of
# relying on pyplot's implicit "current axes".
axes[2].legend(loc="lower right")

# Initial frame; `animate` later updates these same artists in place
# (last image, last scatter collection and last text of axes[1]).
idx = 0
axes[0].images[-1].set_data(images[idx])
ax = axes[1]
marker_uv = uv_gt_total if is_gt_absolute else uv_seq_total
ax.scatter(*marker_uv[idx], s=32, c="lime", zorder=9, edgecolors="k")
ax.imshow(get_overlay(idx), zorder=10)
add_text(1, get_text(idx), (0.01, 0.99), fs=16, zorder=11)

def animate(idx):
    """FuncAnimation callback: update overlay, marker, frame image and caption for ``idx``.

    ``None`` frames leave the figure untouched, so padding frames hold the last image.
    """
    if idx is not None:
        marker_uv = (uv_gt_total if is_gt_absolute else uv_seq_total)[idx]
        ax.images[-1].set_data(get_overlay(idx))
        ax.collections[-1].set_offsets(marker_uv)
        axes[0].images[-1].set_data(images[idx])
        ax.texts[-1].set_text(get_text(idx))

from pathlib import Path

from IPython.display import Image, display
from matplotlib import animation

# One tick per frame, then 4 `None` ticks (no-ops in `animate`) to hold the last frame.
# interval=500 ms matches the 2 fps used when saving.
anim = animation.FuncAnimation(plt.gcf(), animate, frames=list(range(len(seq))) + [None] * 4, interval=500)
path = f"dumps/aria_sequence_loc/{seq_key}_{chunk_idx}.gif"
# Saving fails if the output directory does not exist yet.
Path(path).parent.mkdir(parents=True, exist_ok=True)
anim.save(path, fps=2)
plt.close()
# NOTE(review): the file is a GIF but is declared as "png"; likely a workaround for IPython
# versions that reject format="gif" — confirm before changing.
display(Image(filename=path, format="png"))

In [None]:
# Check pointclo