In [2]:
import torch
import cv2
import joblib
import quaternion as q
import numpy as np
from src.utils.render_utils import add_title, add_agent_view_on_w
from src.models.autoencoder.autoenc import Embedder
import copy
from src.utils.camera_trajectory import go_forward, go_backward, rotate_n
import imageio
from IPython.display import HTML, display
from tqdm.notebook import tqdm

np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})

In [3]:
### Habitat sample data ###
# data = joblib.load("sample_data/Elmira_random_traj.dat.gz")

### Gazebo data ###
data = np.load("sample_data/gazebo_traj_seen.npy", allow_pickle=True).item()

In [4]:
device = 'cuda'
embedder = Embedder(pretrained_ckpt='pretrained/autoenc_large.ckpt',
                    img_res=128, w_size=128, coordinate_scale=32, w_ch=32, nerf_res=64, voxel_res=128)
embedder = embedder.to(device).eval()

## Embed RNR-Map along trajectory

In [5]:
from scipy.spatial.transform import Rotation as R
import time

w, w_mask = None, None

T = len(data['rgb'])
K = torch.eye(3)
K[0,0] = (embedder.img_res/2.) / np.tan(np.deg2rad(90.0) / 2)
K[1,1] = -(embedder.img_res/2.) / np.tan(np.deg2rad(90.0) / 2)
K = K.unsqueeze(0).to(device)

start_position = data['position'][0]
start_rotation = q.from_float_array(data['rotation'][0])

orig_Rt = np.eye(4)
orig_Rt[:3,3] = start_position
orig_Rt[:3,:3] = q.as_rotation_matrix(start_rotation)
orig_Rt = np.linalg.inv(orig_Rt)

view_size = 512 #data['rgb'][0].shape[0]
print(data['rgb'][0].shape[0])
time_embedding, time_rendering = [], []
imgs = []

for t in tqdm(range(T)):

    #time.sleep(0.05) #0.5)#0.08)
    
    Rt_t = np.eye(4)
    Rt_t[:3,3] = data['position'][t]
    Rt_t[:3,:3] = q.as_rotation_matrix(q.from_float_array(data['rotation'][t]))
    Rt_t = np.linalg.inv(Rt_t)
    Rt_t = Rt_t @ np.linalg.inv(orig_Rt)
    
    ####### convert from Gazebo coord to RNR-Map coord
    t_ = np.zeros(3) 
    R_ = q.from_rotation_matrix(Rt_t[:3,:3])
    
    R_.y = R_.z; R_.z = 0
    t_[0] = -Rt_t[1,3] 
    t_[2] = -Rt_t[0,3] 
    t_[1] = Rt_t[2,3]
    
    Rt_t[:3, 3] = t_
    Rt_t[:3, :3] = q.as_rotation_matrix(R_)
    
    #############
    
    print(f"[t={t}] ", np.array([data['position'][t][0], data['position'][t][2], data['position'][t][1]]), Rt_t[:3,3], np.array([q.from_rotation_matrix(Rt_t[:3,:3]).w,q.from_rotation_matrix(Rt_t[:3,:3]).x,q.from_rotation_matrix(Rt_t[:3,:3]).y,q.from_rotation_matrix(Rt_t[:3,:3]).z]))
        
    rgb_t = torch.from_numpy(data['rgb'][t]/255.).unsqueeze(0).permute(0,3,1,2).to(device)
    depth_t = torch.from_numpy(data['depth'][t]).unsqueeze(0).permute(0,3,1,2).to(device)
    
    Rt_t = torch.from_numpy(Rt_t).unsqueeze(0).float().to(device)

    with torch.no_grad():
        output = embedder.calculate_mask_func(depth_t*10.0, Rt_t, K)
        sorted_indices, seq_unique_list, seq_unique_counts, _ = output
        input_dict = {'rgb': rgb_t.unsqueeze(1),
                      'depth': depth_t.unsqueeze(1),
                    'sorted_indices': sorted_indices.unsqueeze(1),
                    'seq_unique_counts': seq_unique_counts.unsqueeze(1),
                      'seq_unique_list': seq_unique_list.unsqueeze(1)}
        w, w_mask = embedder.embed_obs(input_dict, past_w=w, past_w_num_mask=w_mask)
        
        recon_rgb, _ = embedder.generate(w, {'Rt': Rt_t.unsqueeze(1), 'K':K.unsqueeze(1)}, out_res=64)

        orig_rgb = cv2.resize(data['rgb'][t], dsize=(view_size, view_size))
        orig_rgb = add_title(orig_rgb, 'Gt Image')
        
        depth_im = cv2.resize(data['depth'][t], dsize=(view_size, view_size))
        depth_im = np.tile(depth_im[:,:,np.newaxis] * 255, (1,1,3))
        depth_im = add_title(depth_im, "Depth").astype(np.uint8)
        
        recon_rgb = (recon_rgb.squeeze().permute(1,2,0).detach().cpu() * 255).numpy().astype(np.uint8)
        recon_rgb = cv2.resize(recon_rgb, (view_size, view_size))
        recon_rgb = add_title(recon_rgb, 'Recon Image')

        w_im = w.mean(0).mean(0).detach().cpu().numpy()
        w_im = ((w_im - w_im.min())/(w_im.max()-w_im.min()) * 255).astype(np.uint8)
        w_im = cv2.applyColorMap(w_im, cv2.COLORMAP_VIRIDIS)[:,:,::-1]
        last_w_im = w_im.copy()

        w_im = add_agent_view_on_w(w_im, Rt_t, embedder.coordinate_scale, embedder.w_size, agent_size=4, view_size=15)#15)
        w_im = cv2.resize(w_im, (view_size, view_size))
        w_img = np.fliplr(w_im)
        w_im = add_title(w_im, 'Map')

        # print(orig_rgb.shape, recon_rgb.shape, w_im.shape)
        view_im = np.concatenate([orig_rgb, recon_rgb, depth_im, w_im],1)
        # view_im = cv2.resize(view_im, (0,0), fx=3, fy=3)

        imgs.append(view_im)
        cv2.imshow("view", view_im[:,:,[2,1,0]])
        key = cv2.waitKey(1)
        if key == ord("q"): break

last_w = w

128


  0%|          | 0/3421 [00:00<?, ?it/s]

[t=0]  [-0.000 0.000 0.000] [-0.000 0.000 0.000] [1.000 -0.000 -0.000 -0.000]


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[t=1]  [-0.000 0.000 0.000] [-0.000 0.000 -0.000] [-1.000 -0.000 0.000 -0.000]
[t=2]  [-0.000 0.000 0.000] [0.000 0.000 -0.000] [-1.000 -0.000 0.000 -0.000]
[t=3]  [-0.000 0.000 0.000] [0.000 0.000 -0.000] [-1.000 -0.000 0.000 -0.000]
[t=4]  [-0.000 0.000 0.000] [0.000 0.000 -0.000] [-1.000 -0.000 0.000 -0.000]
[t=5]  [-0.000 0.000 0.000] [0.000 0.000 -0.000] [-1.000 -0.000 0.000 -0.000]
[t=6]  [-0.000 0.000 0.000] [0.000 0.000 -0.000] [-1.000 -0.000 0.000 -0.000]
[t=7]  [-0.000 0.000 0.000] [0.000 0.000 -0.000] [-1.000 -0.000 0.000 -0.000]
[t=8]  [-0.000 0.000 0.000] [0.000 0.000 -0.000] [-1.000 -0.000 0.000 -0.000]
[t=9]  [0.009 0.000 0.000] [-0.000 0.000 0.009] [-1.000 -0.000 0.000 -0.000]
[t=10]  [0.027 0.000 0.000] [-0.000 0.000 0.027] [-1.000 -0.000 0.000 -0.000]
[t=11]  [0.056 0.000 0.000] [-0.000 0.000 0.056] [-1.000 -0.000 0.000 -0.000]
[t=12]  [0.089 0.000 0.000] [-0.000 0.000 0.089] [-1.000 -0.000 0.000 -0.000]
[t=13]  [0.121 0.000 0.000] [-0.000 0.000 0.121] [-1.000 -0.000 

In [6]:
imageio.mimwrite('demo/embedding_traj.gif', imgs, fps=15)
display(HTML('<img src={}>'.format("demo/embedding_traj.gif")))

## Explore inside RNR-Map
- Press 'w, a, s, d' to move
- Press 'q' to quit

In [78]:
images = []
Rt_current = torch.eye(4).unsqueeze(0).to(device).unsqueeze(1)
while True:
    with torch.no_grad():
        rgb, _ = embedder.generate(last_w, {"Rt": Rt_current, 'K': K.unsqueeze(1)}, out_res=64)
        rgb = (rgb.squeeze().permute(1,2,0).detach().cpu() * 255).numpy().astype(np.uint8)
        # rgb = cv2.resize(rgb, (data['rgb'][0].shape[0], data['rgb'][0].shape[0]))
        rgb = cv2.resize(rgb, (view_size, view_size))
        rgb = add_title(rgb, 'Recon Image')
        
        w_color = copy.deepcopy(last_w_im)
        w_color = add_agent_view_on_w(w_color, Rt_current, embedder.coordinate_scale, embedder.w_size, agent_size=4, view_size=15)
        w_color = np.fliplr(w_color)
        w_color = cv2.resize(w_color, (view_size, view_size))
        w_color = add_title(w_color, 'Map')
        view_im = np.concatenate([rgb, w_color],1)
        cv2.imshow("view", view_im[:,:,::-1])
        key = cv2.waitKey(0)
        if key == ord('q'): break
        elif key == ord('a'):
            Rt = rotate_n(n=-10.0).to(device)
            Rt_current = (Rt@Rt_current.squeeze()).unsqueeze(0).unsqueeze(0)
        elif key == ord('d'):
            Rt = rotate_n(n=10.0).to(device)
            Rt_current = (Rt@Rt_current.squeeze()).unsqueeze(0).unsqueeze(0)
        elif key == ord("w"):
            Rt_current = go_forward(Rt_current, step=0.1)
        elif key == ord('s'):
            Rt_current = go_backward(Rt_current, step=0.1)
        images.append(view_im)

In [7]:
#imageio.mimwrite("demo/explore_RNR_map.gif", images, fps=15)
display(HTML('<img src={}>'.format("demo/explore_RNR_map.gif")))

# Localize seen images

In [10]:
import torch.nn.functional as F
device = 'cuda'
embedder = Embedder(pretrained_ckpt='pretrained/autoenc_large.ckpt',
                   img_res=128, w_size=128, coordinate_scale=32, w_ch=32, nerf_res=128, voxel_res=128)
embedder = embedder.to(device).eval()

from src.utils.image_rotator import ImageRotator
from src.models.localization.models import UConv
# localizer = UConv(w_size=128,num_rot=36, w_ch=32, angle_ch=18)
localizer = UConv(w_size=128,num_rot=36, w_ch=32, angle_ch=18)
rotator = ImageRotator(36)
sd = torch.load('pretrained/img_loc.ckpt', map_location='cpu')
localizer.load_state_dict(sd)
localizer = localizer.cuda().eval()

In [15]:
data_test = {'rgb': data['rgb'], 'depth': data['depth'], 'position': data['position'], 'rotation': data['rotation'], 'map': w, 'orig_Rt': orig_Rt}


origin = torch.eye(4).unsqueeze(0).to(device)
images = []
diffs = []

coordinate_scale = embedder.coordinate_scale
map_size = embedder.w_size
patch_size = map_size//4
angle_bin = 18
VIS_RES = 512

for t in range(len(data_test['rgb'])):
    time.sleep(0.1) #0.05)
    
    # make target RNR-Map
    rgb = torch.from_numpy(data_test['rgb'][t]).unsqueeze(0).permute(0,3,1,2).to(device)
    depth = torch.from_numpy(data_test['depth'][t]).unsqueeze(0).permute(0,3,1,2).to(device)
    sorted_indices, seq_unique_list, seq_unique_counts, pose_map = embedder.calculate_mask_func(depth * 10.0, origin, K)

    sample_dict = {'sorted_indices': sorted_indices.unsqueeze(0),
               'seq_unique_list': seq_unique_list.unsqueeze(0),
               'seq_unique_counts': seq_unique_counts.unsqueeze(0),
              'rgb': rgb.unsqueeze(1)/255., 'depth': depth.unsqueeze(1)}
    for k,v in sample_dict.items():
        sample_dict[k] = v.cuda()

    with torch.no_grad():
        latent_target, _ = embedder.embed_obs(sample_dict)
        latent_target = latent_target[:, :, map_size//2 - patch_size//2 : map_size//2 + patch_size//2,
                                            map_size//2 - patch_size//2 : map_size//2 + patch_size//2]

        # Localize
        pred_heatmap, pred_angle = localizer(data_test['map'], latent_target, rotator)
        seen_area = (data_test['map'].mean(dim=1) != 0)
        bs, ws, hs = torch.where(seen_area == 0)
        pred_heatmap[bs, :, ws, hs] = -99999
        pred_heatmap[bs, :, :, -1] = -99999
        pred_heatmap[bs, :, -1, :] = -99999
        pred_heatmap_flattened = F.softmax(pred_heatmap.view(1, -1), dim=-1)
        pred = pred_heatmap_flattened.view(map_size+1, map_size+1)



    pred_max = pred_heatmap.view(1, -1).argmax(dim=1).item()
    pred_h, pred_w = pred_max//pred_heatmap.shape[-1], pred_max%pred_heatmap.shape[-1]
    pred_x = (pred_h-(map_size//2))/(map_size//2) * (coordinate_scale/2)
    pred_y = (pred_w-(map_size//2))/(map_size//2) * (coordinate_scale/2)

    pred_Rt = np.eye(4)
    pred_Rt[:3,3] = np.array([pred_x, 0., pred_y])
    pred_Rt[:3,:3] = q.as_rotation_matrix(q.from_euler_angles([0., 2*np.pi/angle_bin * pred_angle.argmax().item(), 0.0]))
    pred_Rt = np.linalg.inv(pred_Rt)
    pred_sim_Rt = np.linalg.inv(np.matmul(pred_Rt, data_test['orig_Rt']))

    # Get observation from predicted pose
    pred_position_ = pred_sim_Rt[:3,3]
    pred_rotation = q.from_rotation_matrix(pred_sim_Rt[:3,:3])
    # pred_obs = sim.get_observations_at(pred_position_, pred_rotation)
    
    with torch.no_grad():
        Rt_t = torch.from_numpy(pred_sim_Rt).unsqueeze(0).float().to(device)
        rgb_recon, _ = embedder.generate(w, {"Rt": Rt_t.unsqueeze(1), 'K': K.unsqueeze(1)}, out_res=64)
        

    # Calculate Localization error
    Rtt = np.eye(4)
    Rtt[:3,3] = data_test['position'][t]
    Rtt[:3,:3] = q.as_rotation_matrix(q.from_float_array(data_test['rotation'][t]))
    Rtt = np.linalg.inv(Rtt)@np.linalg.inv(data_test['orig_Rt'])
    answer_x, _, answer_y = np.linalg.inv(Rtt)[:3,3]
    answer_h = int(answer_x*(map_size/2)/(coordinate_scale/2.)+(map_size/2))
    answer_w = int(answer_y*(map_size/2)/(coordinate_scale/2.)+(map_size/2))

    
    ####### convert from Gazebo coord to RNR-Map coord
    
    t_ = np.zeros(3) 
    R_ = q.from_rotation_matrix(Rtt[:3,:3])   
    R_.y = R_.z; R_.z = 0
    t_[0] = -Rtt[1,3] 
    t_[2] = -Rtt[0,3] 
    t_[1] = Rtt[2,3]
    Rtt[:3, 3] = t_
    Rtt[:3, :3] = q.as_rotation_matrix(R_)
    
    #############
    
    
    #diff = np.linalg.norm(data_test['position'][t][:2]-pred_position_[:2])
    diff = np.linalg.norm(data_test['position'][t][:2]-np.array([-pred_position_[2], -pred_position_[0] ]))
    diffs.append(diff)
    # print(data_test['position'][t])
    # print(pred_position_)
    print(f"[t={t}] err:", diff)

    # Visualization
    map_im = last_w_im.copy()
    map_im = add_agent_view_on_w(map_im, Rtt, embedder.coordinate_scale, embedder.w_size, agent_size=4, view_size=15, agent_color=(255,0,0), view_color=(255,0,0))
    map_im = add_agent_view_on_w(map_im, pred_Rt, embedder.coordinate_scale, embedder.w_size, agent_size=4, view_size=15, agent_color=(0,0,255), view_color=(0,0,255))
    map_im = cv2.resize(map_im, dsize=(VIS_RES, VIS_RES))

    pred_im = pred[:-1,:-1].detach().cpu().numpy()
    pred_im = (pred_im - pred_im.min())/(pred_im.max()-pred_im.min())
    pred_im = cv2.resize((pred_im*255).astype(np.uint8), dsize=(VIS_RES, VIS_RES))
    pred_im = cv2.applyColorMap(pred_im, cv2.COLORMAP_VIRIDIS)[:,:,::-1]
    pred_im = cv2.addWeighted(cv2.resize(last_w_im, dsize=(VIS_RES, VIS_RES)), 0.3, pred_im, 0.7, 0.0)

    map_im = add_title(map_im, 'RNR-Map')
    pred_im = add_title(pred_im, 'Loc. Heatmap')

    rgb = cv2.resize(data_test['rgb'][t], dsize=(VIS_RES, VIS_RES))
    rgb = add_title(rgb, 'Query Img.')
    # pred_rgb = cv2.resize(pred_obs['rgb'], dsize=(VIS_RES, VIS_RES))
    # pred_rgb = add_title(pred_rgb, 'Localized')
    
    rgb_recon = (rgb_recon.squeeze().permute(1,2,0).detach().cpu() * 255).numpy().astype(np.uint8)
    rgb_recon = cv2.resize(rgb_recon, (view_size, view_size))
    rgb_recon = add_title(rgb_recon, 'Recon Image')

    # view_im = np.concatenate([rgb, pred_rgb,map_im, pred_im],1)
    view_im = np.concatenate([rgb, map_im, pred_im],1)
    # view_im = np.concatenate([rgb, rgb_recon, map_im, pred_im],1)
    images.append(view_im)
    cv2.imshow("view", view_im[:,:,::-1])
    key = cv2.waitKey(1)
    if key == ord("q"): break

print(f'average error: {np.mean(diffs):.3f}m' )

[t=0] err: 2.704240471489373
[t=1] err: 2.7042406335848557
[t=2] err: 2.7042406890415602
[t=3] err: 2.7042407860149424
[t=4] err: 2.704240906644439
[t=5] err: 2.7042410328377136
[t=6] err: 2.7042411327366462
[t=7] err: 2.7042412082583387
[t=8] err: 2.7042413206228435
[t=9] err: 2.6994590521199644
[t=10] err: 2.689199639843121
[t=11] err: 2.6736037534757617
[t=12] err: 2.656063078972564
[t=13] err: 2.638760782198205
[t=14] err: 4.798524253895989
[t=15] err: 4.770564717175383
[t=16] err: 4.742666610770953
[t=17] err: 2.572573970706163
[t=18] err: 2.5568123192813124
[t=19] err: 2.5455095384414093
[t=20] err: 2.5390068545262547
[t=21] err: 5.205519173268581
[t=22] err: 2.659130952335609
[t=23] err: 4.647332704267653
[t=24] err: 2.5346602825079176
[t=25] err: 2.534655917189357
[t=26] err: 2.5346519054673107
[t=27] err: 2.5346466387129167
[t=28] err: 2.5346382355510553
[t=29] err: 2.534630776572119
[t=30] err: 4.984788496386711
[t=31] err: 5.173468442191848
[t=32] err: 2.534615563623889
[t=3

# Localize unseen images (not observed during exploration)

In [7]:
import torch.nn.functional as F
device = 'cuda'
embedder = Embedder(pretrained_ckpt='pretrained/autoenc_large.ckpt',
                   img_res=128, w_size=128, coordinate_scale=32, w_ch=32, nerf_res=128, voxel_res=128)
embedder = embedder.to(device).eval()

from src.utils.image_rotator import ImageRotator
from src.models.localization.models import UConv
# localizer = UConv(w_size=128,num_rot=36, w_ch=32, angle_ch=18)
localizer = UConv(w_size=128,num_rot=36, w_ch=32, angle_ch=18)
rotator = ImageRotator(36)
sd = torch.load('pretrained/img_loc.ckpt', map_location='cpu')
localizer.load_state_dict(sd)
localizer = localizer.cuda().eval()

In [8]:
data_ = np.load("./sample_data/gazebo_traj_unseen_2.npy", allow_pickle=True).item()

data_test = {'rgb': [data_['rgb'][-1]], 'depth': [data_['depth'][-1]], 'position': [data_['position'][-1]], 'rotation': [data_['rotation'][-1]], 'map': w, 'orig_Rt': orig_Rt}


origin = torch.eye(4).unsqueeze(0).to(device)
images = []
diffs = []

coordinate_scale = embedder.coordinate_scale
map_size = embedder.w_size
patch_size = map_size//4
angle_bin = 18
VIS_RES = 512

for t in range(len(data_test['rgb'])):
    time.sleep(0.1) #0.05)
    
    # make target RNR-Map
    rgb = torch.from_numpy(data_test['rgb'][t]).unsqueeze(0).permute(0,3,1,2).to(device)
    depth = torch.from_numpy(data_test['depth'][t]).unsqueeze(0).permute(0,3,1,2).to(device)
    sorted_indices, seq_unique_list, seq_unique_counts, pose_map = embedder.calculate_mask_func(depth * 10.0, origin, K)

    sample_dict = {'sorted_indices': sorted_indices.unsqueeze(0),
               'seq_unique_list': seq_unique_list.unsqueeze(0),
               'seq_unique_counts': seq_unique_counts.unsqueeze(0),
              'rgb': rgb.unsqueeze(1)/255., 'depth': depth.unsqueeze(1)}
    for k,v in sample_dict.items():
        sample_dict[k] = v.cuda()

    with torch.no_grad():
        latent_target, _ = embedder.embed_obs(sample_dict)
        latent_target = latent_target[:, :, map_size//2 - patch_size//2 : map_size//2 + patch_size//2,
                                            map_size//2 - patch_size//2 : map_size//2 + patch_size//2]

        # Localize
        pred_heatmap, pred_angle = localizer(data_test['map'], latent_target, rotator)
        seen_area = (data_test['map'].mean(dim=1) != 0)
        bs, ws, hs = torch.where(seen_area == 0)
        pred_heatmap[bs, :, ws, hs] = -99999
        pred_heatmap[bs, :, :, -1] = -99999
        pred_heatmap[bs, :, -1, :] = -99999
        pred_heatmap_flattened = F.softmax(pred_heatmap.view(1, -1), dim=-1)
        pred = pred_heatmap_flattened.view(map_size+1, map_size+1)

    pred_max = pred_heatmap.view(1, -1).argmax(dim=1).item()
    pred_h, pred_w = pred_max//pred_heatmap.shape[-1], pred_max%pred_heatmap.shape[-1]
    pred_x = (pred_h-(map_size//2))/(map_size//2) * (coordinate_scale/2)
    pred_y = (pred_w-(map_size//2))/(map_size//2) * (coordinate_scale/2)

    pred_Rt = np.eye(4)
    pred_Rt[:3,3] = np.array([pred_x, 0., pred_y])
    pred_Rt[:3,:3] = q.as_rotation_matrix(q.from_euler_angles([0., 2*np.pi/angle_bin * pred_angle.argmax().item(), 0.0]))
    pred_Rt = np.linalg.inv(pred_Rt)
    pred_sim_Rt = np.linalg.inv(np.matmul(pred_Rt, data_test['orig_Rt']))

    # Get observation from predicted pose
    pred_position_ = pred_sim_Rt[:3,3]
    pred_rotation = q.from_rotation_matrix(pred_sim_Rt[:3,:3])
    # pred_obs = sim.get_observations_at(pred_position_, pred_rotation)

    # Calculate Localization error
    Rtt = np.eye(4)
    Rtt[:3,3] = data_test['position'][t]
    Rtt[:3,:3] = q.as_rotation_matrix(q.from_float_array(data_test['rotation'][t]))
    Rtt = np.linalg.inv(Rtt)@np.linalg.inv(data_test['orig_Rt'])
    answer_x, _, answer_y = np.linalg.inv(Rtt)[:3,3]
    answer_h = int(answer_x*(map_size/2)/(coordinate_scale/2.)+(map_size/2))
    answer_w = int(answer_y*(map_size/2)/(coordinate_scale/2.)+(map_size/2))

    
    ####### convert from Gazebo coord to RNR-Map coord
    
    t_ = np.zeros(3) #Rt_t[:3,3]
    R_ = q.from_rotation_matrix(Rtt[:3,:3])   
    R_.y = R_.z; R_.z = 0
    t_[0] = -Rtt[1,3] 
    t_[2] = -Rtt[0,3] 
    t_[1] = Rtt[2,3]
    Rtt[:3, 3] = t_
    Rtt[:3, :3] = q.as_rotation_matrix(R_)
    
    #############
    
    diff = np.linalg.norm(data_test['position'][t][:2]-np.array([-pred_position_[2], -pred_position_[0] ]))
    diffs.append(diff)
    print('err:', diff)

    # Visualization
    map_im = last_w_im.copy()
    map_im = add_agent_view_on_w(map_im, Rtt, embedder.coordinate_scale, embedder.w_size, agent_size=4, view_size=15, agent_color=(255,0,0), view_color=(255,0,0))
    map_im = add_agent_view_on_w(map_im, pred_Rt, embedder.coordinate_scale, embedder.w_size, agent_size=4, view_size=15, agent_color=(0,0,255), view_color=(0,0,255))
    map_im = cv2.resize(map_im, dsize=(VIS_RES, VIS_RES))

    pred_im = pred[:-1,:-1].detach().cpu().numpy()
    pred_im = (pred_im - pred_im.min())/(pred_im.max()-pred_im.min())
    pred_im = cv2.resize((pred_im*255).astype(np.uint8), dsize=(VIS_RES, VIS_RES))
    pred_im = cv2.applyColorMap(pred_im, cv2.COLORMAP_VIRIDIS)[:,:,::-1]
    pred_im = cv2.addWeighted(cv2.resize(last_w_im, dsize=(VIS_RES, VIS_RES)), 0.3, pred_im, 0.7, 0.0)

    map_im = add_title(map_im, 'RNR-Map')
    pred_im = add_title(pred_im, 'Loc. Heatmap')

    rgb = cv2.resize(data_test['rgb'][t], dsize=(VIS_RES, VIS_RES))
    rgb = add_title(rgb, 'Query Img.')
    # pred_rgb = cv2.resize(pred_obs['rgb'], dsize=(VIS_RES, VIS_RES))
    # pred_rgb = add_title(pred_rgb, 'Localized')

    # view_im = np.concatenate([rgb, pred_rgb,map_im, pred_im],1)
    view_im = np.concatenate([rgb,map_im, pred_im],1)
    images.append(view_im)
    cv2.imshow("view", view_im[:,:,::-1])
    key = cv2.waitKey(1)
    if key == ord("q"): break

print(f'average error: {np.mean(diffs):.3f}m' )

err: 2.1506365747037592
average error: 2.151m
