In [1]:
%load_ext autoreload
%autoreload 2

import torch
import yaml
import sys
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from pathlib import Path
from pytorch_lightning import seed_everything
sys.path.append('fbsource/fbcode/scripts/psarlin/')
from maploc.data.loader_mapillary import MapillaryDataModule
from maploc.module import GenericModule, find_best_checkpoint
from maploc.test import evaluate_single_image, evaluate_sequential
# Inference-only notebook: turn autograd off globally so no computation graphs are recorded.
torch.autograd.set_grad_enabled(False);

## Setup dataset

In [None]:
# Base Mapillary data config merged with evaluation overrides:
# the scene list, batch size 1 and augmentation turned off.
conf = OmegaConf.load('fbsource/fbcode/scripts/psarlin/maploc/conf/data_mapillary.yaml')
conf = OmegaConf.merge(conf, OmegaConf.create(yaml.full_load("""
local_dir: "./data/mapillary_dumps_v2/"
# local_dir: "./devvm1417.cln0.facebook.com/data/mapillary_dumps_v2/"
dump_dir: ${.local_dir}
scenes:
    - sanfrancisco_soma
    - sanfrancisco_hayes
    - amsterdam
    - berlin
    - lemans
    - montrouge
    - toulouse
    - nantes
    - vilnius
    - avignon
    - helsinki
    - milan
    # - newyork_hoboken
    # - metropolis_train
# split: None
return_gps: true
val: {batch_size: 1, num_workers: 3}
train: ${.val}
random: false
augmentation: {rot90: false, flip: false}
""")))
# If the data is not present locally, fall back to the copy under the devvm directory.
# `dump_dir: ${.local_dir}` is an interpolation that OmegaConf resolves lazily, so it follows
# this change until OmegaConf.resolve() below freezes all interpolated values.
local_dir = conf.local_dir
if not Path(local_dir).exists():
    fallback_dir = str(Path("devvm1417.cln0.facebook.com", local_dir))
    # Explicit error rather than a bare `assert`: it names the paths that were tried and
    # is not stripped when Python runs with -O.
    if not Path(fallback_dir).exists():
        raise FileNotFoundError(
            f"Mapillary data not found at {local_dir!r} or {fallback_dir!r}")
    conf.local_dir = fallback_dir
OmegaConf.resolve(conf)
dataset = MapillaryDataModule(conf)
dataset.prepare_data()
dataset.setup()
# Short tag used when printing results: "mply12" for the full 12-scene set, else the scene names.
scenes = "mply12" if len(dataset.cfg.scenes) == 12 else "-".join(dataset.cfg.scenes)
# Evaluation outputs keyed by experiment name; note that re-running this cell clears them.
results = {}

## Setup model

In [3]:
# Experiment (checkpoint directory) to evaluate. Alternatives kept for reference:
# exper = "basic-rotation-osm2-mly12-n100_rn18-vgg13_b9-resize256_d64_nrot32"
# exper = "bev1-osm2-mly12-n100_vgg16-vgg13_bs9-resize256_norm-d8-nrot64"
# exper = "bev1-osm2-mly12-n100_vgg16-vgg13_bs10-resize256_attn-fix-2-2-128d_norm-d8-nrot64"
# exper = "bev1-osm2-mly12-n100_vgg16-vgg13_bs6-resize256_attn-fix-2-2-128d_norm-d8-nrot64"
# exper = "bev1-osm2-mly12-n100_vgg16-vgg13-plane_bs9-resize256_norm-d8-nrot64"
exper = "bev1-osm2-mly12-n100_vgg16-vgg13-plane_bs9-resize256_norm-d8-nrot64-prior"

root = "manifold://psarlin/tree/maploc/experiments"
path = f'{root}/{exper}/last.ckpt'
# path = find_best_checkpoint(f'{root}/{exper}')
print(path)
# Number of rotations overridden at load time. The same value is appended to `exper` below,
# so results evaluated with different settings are stored under distinct keys in `results`.
num_rotations = 128
cfg = {'model': {"num_rotations": num_rotations}}
model = GenericModule.load_from_checkpoint(path, strict=True, find_best=True, cfg=cfg)
model.eval()
if torch.cuda.is_available():
    model = model.cuda()
exper += f'_nrot{num_rotations}'

# Eval loop

In [8]:
# Single-image evaluation on the validation split; the (errors, recalls, aucs) tuple is
# stored in `results` under the current experiment name.
thresholds = [1, 2, 5, 10]
seed_everything(0)
ret = evaluate_single_image(
    dataset.dataloader("val", shuffle=True), model, thresholds, fuse="rand", num=3000)  # 5k NYC
results[exper] = ret
errors, recalls, aucs = ret
# errors, recalls, recalls_per_index, aucs, aucs_per_index = evaluate_sequential(
#     dataset.dataset("val"), model, thresholds, num=75, max_length=100, min_length=1, max_dist=15, joint=True)
# results[exper+"_seq"] = ret = (errors, recalls, aucs)

# Text table: one row per metric, one column per threshold.
print(scenes, exper)
col_size = max(len(name) for name in aucs)
header = ' ' * (col_size + 1) + ' '.join(f'{t:>5}' for t in thresholds)
rows = [f'{name:>{col_size}} ' + ' '.join(f'{v:>5.2f}' for v in auc)
        for name, auc in aucs.items()]
s = '\n'.join([header] + rows)
print(s)

# Recall curves: position on the left, orientation on the right.
fig, (ax_xy, ax_yaw) = plt.subplots(1, 2, dpi=150, figsize=(10, 4))
ax_xy.plot(*recalls["xy_max_error"], label="argmax", c="k")
ax_xy.plot(*recalls["xy_gps_error"], label="gps", c="b")
if "xy_prior_error" in recalls:
    ax_xy.plot(*recalls["xy_prior_error"], label="rand", c="r")
ax_xy.plot(*recalls["xy_fused_error"], label="fused", c="g")
ax_xy.set_xlim([0, 15])
ax_xy.set_xlabel("Position error [m]")
ax_xy.set_ylabel("Recall")
ax_xy.legend()

ax_yaw.plot(*recalls["yaw_max_error"], label="argmax", c="k")
ax_yaw.plot(*recalls["yaw_fused_error"], label="fused", c="g")
ax_yaw.set_xlim([0, 10])
ax_yaw.set_xlabel("Orientation error [deg]");

# Ablation

In [None]:
# Ablation: full model (with prior), its sequential variant, and a model without BEV geometry.
exper_bev = "bev1-osm2-mly12-n100_vgg16-vgg13-plane_bs9-resize256_norm-d8-nrot64-prior_nrot128"
exper_seq = "bev1-osm2-mly12-n100_vgg16-vgg13-plane_bs9-resize256_norm-d8-nrot64-prior_nrot128_seq"
exper_nobev = "basic-rotation-osm2-mly12-n100_rn18-vgg13_b9-resize256_d64_nrot32_nrot128"

# Each `results` entry is (errors, recalls, aucs); index 1 selects the recall curves.
# Everything is read from `results` rather than the global `recalls`, which is rebound by
# whichever evaluation cell ran last and so may belong to another experiment.
bev_recalls = results[exper_bev][1]
seq_recalls = results[exper_seq][1]
nobev_recalls = results[exper_nobev][1]

lw = 3
plt.figure(dpi=150, figsize=(10, 4))
plt.subplot(121)
plt.plot(*bev_recalls["xy_max_error"], label="single image", c="k", lw=lw)
plt.plot(*bev_recalls["xy_fused_error"], label="single with prior", c="r", lw=lw)
plt.plot(*seq_recalls["xy_seq_error"], label="sequence", c="g", lw=lw)
plt.plot(*nobev_recalls["xy_max_error"], label="no geometry", c="b", lw=lw)
plt.plot(*bev_recalls["xy_prior_error"], label="prior: gps+N(0, 15)", c="r", linestyle=":", lw=lw)
plt.xlim([0, 10]);
plt.xlabel("Position error [m]")
plt.ylabel("Recall")
plt.legend(loc="lower center", ncol=2, bbox_to_anchor=(0.5, 1.));

plt.subplot(122)
plt.plot(*bev_recalls["yaw_max_error"], c="k", lw=lw)
plt.plot(*bev_recalls["yaw_fused_error"], c="r", lw=lw)
plt.plot(*seq_recalls["yaw_seq_error"], label="sequence", c="g", lw=lw)
# Argmax estimate for the no-geometry model, matching its curve in the position panel.
plt.plot(*nobev_recalls["yaw_max_error"], label="no geometry", c="b", lw=lw)
plt.xlim([0, 10]);
plt.xlabel("Orientation error [deg]");

# Sequential eval

In [None]:
exper_seq = "bev1-osm2-mly12-n100_vgg16-vgg13-plane_bs9-resize256_norm-d8-nrot64-prior_nrot128_seq75"
# Each `results` entry is (errors, recalls, aucs); index 1 selects the recall curves.
seq_recalls = results[exper_seq][1]

# Encoding shared by both panels: green = model, blue = GPS; solid = single image, dashed = sequence.
lw = 3
plt.figure(dpi=150, figsize=(10, 4))
plt.subplot(121)
plt.plot(*seq_recalls["xy_max_error"], label="single image", c="g", lw=lw)
plt.plot(*seq_recalls["xy_seq_error"], label="sequence", c="g", lw=lw, linestyle="dashed")
plt.plot(*seq_recalls["xy_gps_error"], label="gps", c="b", lw=lw)
plt.plot(*seq_recalls["xy_gps_seq_error"], label="gps sequence", c="b", lw=lw, linestyle="dashed")
plt.xlim([0, 10]);
plt.xlabel("Position error [m]")
plt.ylabel("Recall")
plt.legend(loc="lower center", ncol=2, bbox_to_anchor=(0.5, 1.));

plt.subplot(122)
plt.plot(*seq_recalls["yaw_max_error"], c="g", lw=lw)
plt.plot(*seq_recalls["yaw_seq_error"], label="sequence", c="g", lw=lw, linestyle="dashed")
plt.xlim([0, 10]);
plt.xlabel("Orientation error [deg]");

## Model comparison

In [None]:
# Legend label -> (experiment key in `results`, plot color).
exps = {
    "prior rot=64": ("bev1-osm2-mly12-n100_vgg16-vgg13-plane_bs9-resize256_norm-d8-nrot64-prior_nrot64", "r"),
    "prior rot=128": ("bev1-osm2-mly12-n100_vgg16-vgg13-plane_bs9-resize256_norm-d8-nrot64-prior_nrot128", "k"),
    "prior rot=256": ("bev1-osm2-mly12-n100_vgg16-vgg13-plane_bs9-resize256_norm-d8-nrot64-prior_nrot256", "g"),
    "noprior rot=128": ("bev1-osm2-mly12-n100_vgg16-vgg13-plane_bs9-resize256_norm-d8-nrot64_nrot128", "orange"),
}
# Prior-only baseline taken from a stored experiment instead of the global `recalls`,
# which is rebound by whichever evaluation cell ran last.
# NOTE(review): assumes the prior curve is the same for all runs above (same seed and data) — confirm.
prior_recalls = results[exps["prior rot=128"][0]][1]

plt.figure(dpi=150, figsize=(10, 4))
plt.subplot(121)
plt.plot(*prior_recalls["xy_prior_error"], label="gps + N(0, 15)", c="b")
for label, (exp_name, color) in exps.items():
    exp_recalls = results[exp_name][1]  # (errors, recalls, aucs)[1]
    plt.plot(*exp_recalls["xy_max_error"], label=label, c=color, lw=2)
    # Dotted: fused estimate of the same experiment.
    plt.plot(*exp_recalls["xy_fused_error"], c=color, linestyle=":", lw=2)
plt.xlim([0, 10]);
plt.xlabel("Position error [m]")
plt.ylabel("Recall")
plt.legend();
plt.subplot(122)
for exp_name, color in exps.values():
    exp_recalls = results[exp_name][1]
    plt.plot(*exp_recalls["yaw_max_error"], c=color, lw=2)
    plt.plot(*exp_recalls["yaw_fused_error"], c=color, linestyle=":", lw=2)
plt.xlim([0, 10]);
plt.xlabel("Orientation error [deg]");

# Sequential ablation

In [32]:
thresholds = [1, 2, 5, 10]
# Seed as in the single-image evaluation so the run is reproducible
# (in case evaluate_sequential samples sequences randomly — not visible from here).
seed_everything(0)
dset = dataset.dataset("val")
# NOTE: this rebinds `errors`, `recalls` and `aucs`, replacing the single-image results of
# the eval loop; the cells below read these sequential values.
errors, recalls, recalls_per_index, aucs, aucs_per_index = evaluate_sequential(
    dset, model, thresholds, num=200, max_length=10, min_length=10, max_dist=15)

In [None]:
def plot_recalls_seq(recalls_seq, indices=(0, 1, 3, 7, 11, 15, 23), linewidth=None):
    """Plot recall curves for selected sequence indices, colored by sequence length.

    Args:
        recalls_seq: mapping from 0-based sequence index to a recall curve that is
            unpacked into ``plt.plot`` (presumably an ``(x, y)`` pair — TODO confirm).
        indices: sequence indices to plot. Indices absent from ``recalls_seq`` are
            skipped, e.g. when sequences were evaluated with a small ``max_length``.
        linewidth: line width; defaults to the notebook-level global ``lw``.
    """
    if linewidth is None:
        linewidth = lw  # notebook global, set in the plotting cell before the call
    if not recalls_seq:
        return
    max_length = max(recalls_seq) + 1
    # Log-scale color normalization so that the longest available length maps exactly to 1.0.
    # This also avoids a division by zero when only index 0 is present.
    color_norm = np.log(1 + max_length)
    for i in indices:
        if i not in recalls_seq:
            continue
        length = i + 1
        color = mpl.cm.winter(np.log(1 + length) / color_norm)
        plt.plot(*recalls_seq[i], label=f"length={length}", c=color, lw=linewidth)

print(scenes, exper)
# plot_recalls_seq reads this global line width.
lw = 2
plt.figure(dpi=150, figsize=(10, 4))
plt.subplot(121)
# plt.plot(*recalls["xy_gps_error"], label="gps", c="b", lw=lw)
# plt.plot(*recalls["xy_max_error"], label="single", c="k", lw=lw)
plot_recalls_seq(recalls_per_index["xy_seq_error"])
plt.xlim([0, 10]);
plt.xlabel("Position error [m]")
plt.ylabel("Recall")
# Minor x ticks every 1 m.
plt.gca().minorticks_on()
plt.gca().xaxis.set_minor_locator(mpl.ticker.MultipleLocator(1))
plt.gca().tick_params(axis='x', which='minor', bottom=True)
# Single legend placed to the left of the first panel (a second legend() call would replace it).
plt.legend(ncol=1, bbox_to_anchor=(-0.2, 0.5), loc='center right');

plt.subplot(122)
# plt.plot(*recalls["yaw_max_error"], label="single", c="k", lw=lw)
plot_recalls_seq(recalls_per_index["yaw_seq_error"])
plt.xlim([0, 10]);
plt.xlabel("Orientation error [deg]");

In [None]:
# AUC table: GPS and single-image baselines, then one row per sequence index.
seq_aucs = aucs_per_index["xy_seq_error"]
seq_labels = [f"xy seq {i}" for i in seq_aucs]
col_size = max(len(name) for name in [*aucs, *seq_labels])
table_rows = [(k, aucs[k]) for k in ["xy_gps_error", "xy_max_error"]]
table_rows += zip(seq_labels, seq_aucs.values())
lines = [' ' * (col_size + 1) + ' '.join(f'{t:>5}' for t in thresholds)]
lines += [f'{label:>{col_size}} ' + ' '.join(f'{v:>5.2f}' for v in values)
          for label, values in table_rows]
s = '\n'.join(lines)
print(s)

In [292]:
# Sequential AUC per sequence index at a few thresholds, against GPS and single-image baselines.
ths = [1, 2, 5]
print(scenes, exper)
seq_aucs = aucs_per_index["xy_seq_error"]
# Keys are 0-based sequence indices; shown 1-based on the x axis.
lengths = np.array(list(seq_aucs.keys())) + 1
# One panel per threshold (works for any number of thresholds, not only three).
fig, axes = plt.subplots(1, len(ths), dpi=100, figsize=(12, 4), squeeze=False)
for ax, th in zip(axes[0], ths):
    th_index = thresholds.index(th)
    values = [auc[th_index] for auc in seq_aucs.values()]
    ax.bar(lengths, values, label="sequential")
    ax.axhline(y=aucs["xy_gps_error"][th_index], color='r', linestyle='-', zorder=1, label="gps")
    ax.axhline(y=aucs["xy_max_error"][th_index], color='k', linestyle='-', zorder=1, label="single")
    ax.set_xlabel('Sequence index')
    ax.set_ylabel(f'AUC @ {th}m')
    ax.set_xticks(lengths)
# A single shared legend above the first panel.
axes[0, 0].legend(ncol=3, bbox_to_anchor=(0.5, 1.0), loc='lower center');

In [12]:
# Number of evaluated images at each (0-based) sequence index.
# NOTE(review): assumes each recall curve's x array holds one extra point beyond the samples,
# so len(x) - 1 is the sample count — confirm against evaluate_sequential.
seq_index_recalls = recalls_per_index["xy_seq_error"]
num_per_index = {index: len(curve[0]) - 1 for index, curve in seq_index_recalls.items()}
counts = np.array([num_per_index[index] for index in range(max(num_per_index) + 1)])
lengths = 1 + np.arange(counts.size)
plt.bar(lengths, counts)
plt.xlabel('Sequence index')
plt.ylabel('Image count')
plt.xticks(lengths);