In [1]:
%load_ext autoreload
%autoreload 2

import torch
import yaml
import sys
from torchmetrics import MetricCollection
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from pytorch_lightning import seed_everything
sys.path.append('fbsource/fbcode/surreal/')
from maploc.data import MapillaryDataModule
from maploc.module import GenericModule
from maploc.utils.viz_2d import plot_images, features_to_RGB, save_plot, add_text
from maploc.utils.viz_localization import likelihood_overlay, plot_pose, plot_dense_rotations
from maploc.osm.viz import Colormap, plot_nodes
from maploc.models.metrics import Location2DError, AngleError
from maploc.models.hough_voting import argmax_xyr, fuse_gps
from maploc.models.refinement import FeaturemetricRefiner, subpixel_refinement
# Disable autograd globally for every cell that follows.
torch.set_grad_enabled(False)
# A threshold below one switches off matplotlib's "too many open figures" warning.
plt.rcParams['figure.max_open_warning'] = 0

In [10]:
# Base data configuration stored alongside the project.
base_conf = OmegaConf.load('fbsource/fbcode/surreal/maploc/conf/data/mapillary_v4.yaml')

# Notebook-specific overrides, written as YAML and merged on top of the base config.
overrides_yaml = """
# Comment the next 3 lines to pull data from Manifold
local_dir: "./data/mapillary_dumps_v4/"
dump_dir: ${.local_dir}
scenes:
    - sanfrancisco_soma
    # - sanfrancisco_hayes
    # - amsterdam
    # - berlin
    # - lemans
    # - montrouge
    # - toulouse
    # - nantes
    # - vilnius
    # - avignon
    # - helsinki
    # - milan
    # - paris
    # - brussels
    # - newyork_hoboken
    # - metropolis_train
split: "splits_mly14_from-v2.json"
# split: null
filter_for: null
tiles_filename: tiles.pkl
resize_image: 512
force_camera_height: 1.65

return_gps: true
loading:
    val: {batch_size: 1, num_workers: 0}
    train: ${.val}
random: false
augmentation: {rot90: false, flip: false, image: {apply: false}}
add_map_mask: false
max_init_error_rotation: 10
"""
overrides = OmegaConf.create(yaml.full_load(overrides_yaml))
conf = OmegaConf.merge(base_conf, overrides)
# Resolve the ${...} interpolations in place before handing the config over.
OmegaConf.resolve(conf)

# Build the datamodule from the merged config, then run its two setup stages.
dataset = MapillaryDataModule(conf)
dataset.prepare_data()
dataset.setup()

# Load the image-encoder sub-config and the experiment config, then plug the
# former into the latter together with the datamodule's own config.
encoder_cfg = OmegaConf.load('fbsource/fbcode/surreal/maploc/conf/model/image_encoder/resnet_fpn.yaml')
cfg2 = OmegaConf.load('fbsource/fbcode/surreal/maploc/conf/bev_plane-v2.yaml')
cfg2.model.image_encoder = encoder_cfg
cfg2.data = dataset.cfg
OmegaConf.resolve(cfg2)

# Instantiate the module from the resolved config and switch it to eval mode.
model = GenericModule(cfg2).eval()

In [19]:
# Override the rotation-error bound directly on the live datamodule config.
# The YAML overrides used to build `dataset` already set this key to 10, so
# this only matters when experimenting with a different value.
dataset.cfg.max_init_error_rotation = 10

In [None]:
# Seed before creating the shuffled loader so the batch that is drawn is repeatable.
seed_everything(25)
loader = dataset.dataloader("val", shuffle=True)
batch = next(iter(loader))

# Forward pass on the first batch drawn from the "val" loader.
pred = model(batch)

# Quantities reused by the cells below.
camera, height = batch["camera"], batch["camera_height"]
ppm = model.model.conf.pixel_per_meter

In [42]:
# Index of the image feature map (and matching encoder scale) used for alignment.
level = 0
# Rescale from the batch's original camera rather than from the `camera` global,
# so that re-running this cell does not compound the scaling.
camera = batch["camera"].scale(1 / model.model.image_encoder.scales[level])
feats_map = pred["map"]["map_features"][0]
# Keep only as many image channels as the map features have (dim 1), then
# flatten the two trailing dims and move them ahead of the channel dim
# (presumably B,C,H,W -> B,HW,C; confirm against the encoder output).
feats_image = pred["image"]["feature_maps"][level][:, :feats_map.shape[1]]
feats_image = feats_image.flatten(-2).permute(0, 2, 1)

from maploc.models.utils import rotmat2d, rotmat2d_grad, make_grid, deg2rad, rad2deg
from maploc.models.interpolation import Interpolator

def masked_mean(x, mask, dim):
    """Average `x` along `dim`, counting only the entries selected by `mask`.

    Args:
        x: tensor of values to average.
        mask: tensor multiplied elementwise with `x` after conversion to
            float; non-zero entries mark the values that contribute.
        dim: dimension (or dimensions) to reduce.

    Returns:
        The masked sum divided by the mask sum along `dim`. The divisor is
        clamped to at least 1, so a slice without any selected entry gives 0
        instead of NaN.
    """
    weights = mask.float()
    total = torch.sum(weights * x, dim)
    count = torch.sum(weights, dim).clamp(min=1)
    return total / count

# get grid in camera coordinates
# feats_image has already been flattened and permuted at this point, so its
# last two dims are (H*W, C) and no longer the spatial size. Read H and W from
# the unflattened feature map instead.
h, w = pred["image"]["feature_maps"][level].shape[-2:]
print(h, w, camera.size)
grid = make_grid(w, h, device=feats_image.device).reshape(-1, 2)  # HW,2
# NOTE(review): the result is named grid_norm and used below as normalized
# image coordinates; confirm that `denormalize` (rather than `normalize`) is
# the Camera method that maps this pixel grid to normalized coordinates.
grid_norm = camera.denormalize(grid.unsqueeze(0)) # B,HW,2
eps = 1e-5
u, v = grid_norm[..., 0], grid_norm[..., 1]
# Clamp v to at least eps so the divisions below stay finite; the clamped
# entries are flagged as invalid through valid_grid.
v_clipped = v.clip(min=eps)
# Scale each ray by camera height / v to get its forward distance (depth)
# and lateral offset (x).
depth = height.unsqueeze(1) / v_clipped
x = u / v_clipped * height.unsqueeze(1)
valid_grid = v > eps
grid_xy_cam = torch.stack([x, depth], -1) # B,HW,2

# Sampler used to read map features at the projected grid locations.
interpolator = Interpolator(mode="linear", pad=1)

# Initial pose: position from the batch, yaw prior converted from degrees to radians.
uv_pose = batch["xy_init"]
yaw_init = batch["yaw_prior"][..., 0]
yaw_pose = deg2rad(yaw_init)

# Scale the camera-frame points by pixels-per-meter and negate the second axis.
axis_signs = grid_xy_cam.new_tensor([1, -1])
grid_uv_cam = grid_xy_cam * ppm * axis_signs

# Rotate by the yaw, then translate by the position, to land in map coordinates.
R = rotmat2d(yaw_pose)  # B,2,2
grid_uv_rotated = torch.einsum('bij,bnj->bni', R, grid_uv_cam)
grid_uv_world = uv_pose[..., None, :] + grid_uv_rotated

# Sample the map features (and their spatial gradients) at those locations.
feats_map_view, valid_view, J_f_p2d = interpolator(
    feats_map,
    grid_uv_world,
    return_gradients=True,
)

# A point counts only if it passed the v > eps test and the sampler marked it valid.
valid_cost = valid_grid & valid_view

# Squared feature residual per point, averaged over the valid points.
res = feats_map_view - feats_image
error = res.pow(2).sum(-1)
cost = masked_mean(error, valid_cost, -1)

from maploc.models.refinement import build_system, optimizer_step
# One damped optimisation step on the pose (u, v, yaw).
R_grad = rotmat2d_grad(yaw_pose)  # presumably dR/dyaw, B,2,2 (confirm in maploc.models.utils)
# Jacobian of each world point w.r.t. the pose: [I_2 | R_grad @ p_cam].
# The identity block must be 2x2 per point, so it is sized from the 2-D points.
# Sizing it from the C-channel residual produced a CxC block that cannot be
# concatenated with the 2x1 rotation column.
J_p2d_T = torch.cat(
    [
        torch.diag_embed(torch.ones_like(grid_uv_cam)),
        torch.einsum('bij,bnj->bni', R_grad, grid_uv_cam).unsqueeze(-1)
    ],
    -1,
)  # B,HW,2,3
# Chain rule: feature gradient w.r.t. the 2-D point times point w.r.t. the pose.
J = J_f_p2d @ J_p2d_T
damping = 0.1
# Invalid points are weighted out of the system through the float mask.
g, H = build_system(res, J, valid_cost.float())
delta = optimizer_step(g, H, lambda_=damping)

# Apply the update: the first two components move the position, the third the yaw.
uv_new = uv_pose + delta[..., :2]
yaw_new = yaw_pose + delta[..., 2]

# Work in progress: sample a fresh initial position around the map centre.
# NOTE(review): feats_map.shape[-2:] is presumably (H, W) while poses are (u, v);
# this only matches for square maps. Confirm the intended ordering.
center_map = feats_map.new_tensor(feats_map.shape[-2:]) / 2 - 0.5
uv_pose_range = ppm * dataset.cfg.max_init_error
bounds = center_map + center_map.new_tensor([-1, 1]) * uv_pose_range
uv_pose_reinit = torch.rand_like(uv_new)
# `B` and `shift_u` were never defined in this notebook, so the two sampling
# lines raised NameError. Take the batch size and device from uv_new instead
# (assumes its leading dim is the batch; TODO confirm).
B = uv_new.shape[0]
rand_u = torch.distributions.uniform.Uniform(-1, 1).sample([B, 1]).to(uv_new.device)
rand_v = torch.distributions.uniform.Uniform(-1, 1).sample([B, 1]).to(uv_new.device)
# TODO: bounds, uv_pose_reinit, rand_u and rand_v are not yet combined into a pose.

In [43]:
# Sanity check: the sampled map features and the flattened image features are
# subtracted elementwise, so their shapes should agree.
feats_map_view.shape, feats_image.shape

In [36]:
# Inspect the shape of the camera-frame grid after scaling by pixels-per-meter.
grid_uv_cam.shape

In [26]:
# get satellite coordinates given a pose

# compute the jacobian from 