In [1]:
import torch
import cv2
import joblib
import quaternion as q
import numpy as np
from src.utils.render_utils import add_title, add_agent_view_on_w
from src.models.autoencoder.autoenc import Embedder
import copy
from src.utils.camera_trajectory import go_forward, go_backward, rotate_n
import imageio
from IPython.display import HTML, display
from tqdm.notebook import tqdm

In [2]:
# Load the recorded sample trajectory. Later cells read the 'rgb', 'depth',
# 'position' and 'rotation' entries of this object, indexed by time step.
# NOTE(review): joblib.load unpickles the file, so only load trusted data.
sample_traj_path = "sample_data/Elmira_random_traj.dat.gz"
data = joblib.load(sample_traj_path)

In [3]:
# Prefer the GPU, but fall back to CPU so the notebook still runs on machines
# without CUDA instead of failing at `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the pretrained RNR-Map embedder and put it in inference mode.
embedder = Embedder(pretrained_ckpt='pretrained/autoenc_large.ckpt',
                    img_res=128, w_size=128, coordinate_scale=32, w_ch=32, nerf_res=64, voxel_res=128)
embedder = embedder.to(device).eval()

## Embed RNR-Map along trajectory

In [None]:
# Embed the trajectory frame by frame into one latent map `w`, and build a
# side-by-side preview (ground truth | reconstruction | map) for every step.
# Press 'q' in the OpenCV window to stop early.
w, w_mask = None, None

T = len(data['rgb'])

# Camera intrinsics: focal length for a 90-degree field of view at the
# embedder's image resolution; the y focal length is negated. The remaining
# entries keep their identity-matrix values.
K = torch.eye(3)
K[0,0] = (embedder.img_res/2.) / np.tan(np.deg2rad(90.0) / 2)
K[1,1] = -(embedder.img_res/2.) / np.tan(np.deg2rad(90.0) / 2)
K = K.unsqueeze(0).to(device)

start_position = data['position'][0]
start_rotation = q.from_float_array(data['rotation'][0])

# orig_Rt is the inverse of the first frame's 4x4 pose matrix.
orig_Rt = np.eye(4)
orig_Rt[:3,3] = start_position
orig_Rt[:3,:3] = q.as_rotation_matrix(start_rotation)
orig_Rt = np.linalg.inv(orig_Rt)
# Loop-invariant factor used to express every frame relative to the first one;
# computed once here instead of once per frame.
orig_Rt_inv = np.linalg.inv(orig_Rt)

view_size = data['rgb'][0].shape[0]
time_embedding, time_rendering = [], []
imgs = []
try:
    for t in tqdm(range(T)):

        # Inverse pose of frame t, expressed relative to the first frame.
        Rt_t = np.eye(4)
        Rt_t[:3,3] = data['position'][t]
        Rt_t[:3,:3] = q.as_rotation_matrix(q.from_float_array(data['rotation'][t]))
        Rt_t = np.linalg.inv(Rt_t)
        Rt_t = Rt_t @ orig_Rt_inv

        # HWC numpy frames -> 1xCxHxW tensors on the target device.
        rgb_t = torch.from_numpy(data['rgb'][t]/255.).unsqueeze(0).permute(0,3,1,2).to(device)
        depth_t = torch.from_numpy(data['depth'][t]).unsqueeze(0).permute(0,3,1,2).to(device)

        Rt_t = torch.from_numpy(Rt_t).unsqueeze(0).float().to(device)

        with torch.no_grad():

            # NOTE(review): depth is scaled by 10 before masking -- presumably a
            # unit conversion for stored depth; confirm against the data format.
            output = embedder.calculate_mask_func(depth_t*10.0, Rt_t, K)
            sorted_indices, seq_unique_list, seq_unique_counts, _ = output
            input_dict = {'rgb': rgb_t.unsqueeze(1),
                          'depth': depth_t.unsqueeze(1),
                          'sorted_indices': sorted_indices.unsqueeze(1),
                          'seq_unique_counts': seq_unique_counts.unsqueeze(1),
                          'seq_unique_list': seq_unique_list.unsqueeze(1)}
            # Fold this observation into the map accumulated so far.
            w, w_mask = embedder.embed_obs(input_dict, past_w=w, past_w_num_mask=w_mask)
            recon_rgb, _ = embedder.generate(w, {'Rt': Rt_t.unsqueeze(1), 'K':K.unsqueeze(1)}, out_res=64)

            orig_rgb = add_title(data['rgb'][t], 'Gt Image')
            recon_rgb = (recon_rgb.squeeze().permute(1,2,0).detach().cpu() * 255).numpy().astype(np.uint8)
            recon_rgb = cv2.resize(recon_rgb, (view_size, view_size))
            recon_rgb = add_title(recon_rgb, 'Recon Image')

            # Collapse the latent map to one channel, min-max normalise it to
            # 0..255 and colour-map it; [:,:,::-1] reverses the channel order.
            w_im = w.mean(0).mean(0).detach().cpu().numpy()
            w_im = ((w_im - w_im.min())/(w_im.max()-w_im.min()) * 255).astype(np.uint8)
            w_im = cv2.applyColorMap(w_im, cv2.COLORMAP_VIRIDIS)[:,:,::-1]
            # Clean copy (no agent overlay) reused by the exploration cell.
            last_w_im = w_im.copy()

            w_im = add_agent_view_on_w(w_im, Rt_t, embedder.coordinate_scale, embedder.w_size, agent_size=4, view_size=15)
            w_im = cv2.resize(w_im, (view_size, view_size))
            # Mirror the map before titling, as the exploration cell does.
            # (The result used to be stored in an unused name and was lost.)
            w_im = np.fliplr(w_im)
            w_im = add_title(w_im, 'Map')

            view_im = np.concatenate([orig_rgb, recon_rgb, w_im],1)

            imgs.append(view_im)
            # Channels are reordered for OpenCV display.
            cv2.imshow("view", view_im[:,:,[2,1,0]])
            key = cv2.waitKey(1)
            if key == ord("q"): break
finally:
    # Close the preview window on normal exit, early quit, or error.
    cv2.destroyAllWindows()

last_w = w

In [5]:
# Write the per-step preview frames collected in `imgs` to an animated GIF
# and embed that file in the notebook output.
embedding_gif_path = 'demo/embedding_traj.gif'
imageio.mimwrite(embedding_gif_path, imgs, fps=15)
embedding_gif_html = '<img src={}>'.format(embedding_gif_path)
display(HTML(embedding_gif_html))

## Explore inside RNR-Map
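- Press `w` / `s` to move forward / backward, and `a` / `d` to rotate the view
- Press `q` to quit
- Keys are read by the OpenCV "view" window, so that window must have focus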

In [6]:
# Interactive exploration of the embedded map: render the view at the current
# camera pose, show it next to the map, then update the pose from a key press.
# Every displayed frame except the one showing when 'q' is pressed is recorded.
images = []
# Start from the identity pose, shaped (1, 1, 4, 4).
Rt_current = torch.eye(4).unsqueeze(0).to(device).unsqueeze(1)
# Side length used for both panels so they can be concatenated horizontally.
panel_size = data['rgb'][0].shape[0]
try:
    while True:
        with torch.no_grad():
            rgb, _ = embedder.generate(last_w, {"Rt": Rt_current, 'K': K.unsqueeze(1)}, out_res=64)
            rgb = (rgb.squeeze().permute(1,2,0).detach().cpu() * 255).numpy().astype(np.uint8)
            rgb = cv2.resize(rgb, (panel_size, panel_size))
            rgb = add_title(rgb, 'Recon Image')

            # Draw the agent on a fresh copy so the clean map image is kept.
            w_color = copy.deepcopy(last_w_im)
            w_color = add_agent_view_on_w(w_color, Rt_current, embedder.coordinate_scale, embedder.w_size, agent_size=4, view_size=15)
            # Match the RGB panel height; otherwise the concatenation below
            # fails whenever the map size differs from the image size.
            w_color = cv2.resize(w_color, (panel_size, panel_size))
            w_color = np.fliplr(w_color)
            w_color = add_title(w_color, 'Map')

            view_im = np.concatenate([rgb, w_color],1)
            cv2.imshow("view", view_im[:,:,::-1])

            # Block until a key is pressed; keep only the low byte of the code.
            key = cv2.waitKey(0) & 0xFF
            if key == ord('q'):
                break
            if key in (ord('a'), ord('d')):
                # 'a' and 'd' rotate by the same amount in opposite directions.
                angle = -10.0 if key == ord('a') else 10.0
                Rt = rotate_n(n=angle).to(device)
                Rt_current = (Rt@Rt_current.squeeze()).unsqueeze(0).unsqueeze(0)
            elif key == ord("w"):
                Rt_current = go_forward(Rt_current, step=0.1)
            elif key == ord('s'):
                Rt_current = go_backward(Rt_current, step=0.1)
            images.append(view_im)
finally:
    # Close the window on quit or error.
    cv2.destroyAllWindows()

In [7]:
# Save the frames recorded during exploration, then show the GIF. The write is
# skipped when nothing was recorded, so an existing GIF is not replaced by an
# empty recording.
explore_gif_path = "demo/explore_RNR_map.gif"
if images:
    imageio.mimwrite(explore_gif_path, images, fps=15)
display(HTML('<img src={}>'.format(explore_gif_path)))