In [1]:
# Some useful settings for interactive work
%load_ext autoreload
%autoreload 2

%matplotlib widget

import torch
torch.set_float32_matmul_precision('high')

In [24]:
# Import the relevant modules
import sousvide.synthesize.rollout_generator as rg
import sousvide.visualize.plot_synthesize as ps
import figs.visualize.generate_videos as gv

In [3]:
cohort = "trace_test"

data_method = "eval_single"
eval_method = "eval_nominal"

scene = "mid_gate"

courses = [
    "traverse",
    ]   

roster = [
    "Maverick",
    "Rooster"
    ]

In [None]:
# Generate Rollouts
rg.generate_rollout_data(cohort,courses,scene,data_method)

# Review the Rollout Data
ps.plot_rollout_data(cohort)

In [50]:
import numpy as np
import torch
import os
from scipy.spatial.transform import Rotation

workspace = os.path.join("../cohorts/trace_test/rollout_data/traverse")

# Load the Rollout Data
trajs_path = os.path.join(workspace,"trajectories")
imgs_path = os.path.join(workspace,"images")

# Extract the image and trajectory files
trajs_files = [os.path.join(trajs_path,f) for f in os.listdir(trajs_path) if f.endswith('.pt')]
imgs_files = [os.path.join(imgs_path,f) for f in os.listdir(imgs_path) if f.endswith('.pt')]

trajs_files.sort()
imgs_files.sort()

trajs_files = trajs_files[:1]

def get_Rt(x:np.ndarray) -> tuple[np.ndarray,np.ndarray]:
    """
    Convert a quaternion to a rotation matrix.
    """
    R = Rotation.from_quat(x[6:10]).as_matrix()
    t = x[0:3]
    return R,t

def get_yaw(R:np.ndarray) -> float:
    """
    Get the yaw angle from a rotation matrix.
    """
    yaw = Rotation.from_matrix(R).as_euler('zyx')[0]
    return yaw

def get_local_FO(xcr:np.ndarray,Xhn:np.ndarray) -> np.ndarray:
    Rc2w,t0c_w = get_Rt(xcr)
    Rw2c = Rc2w.T
    
    Ndt = Xhn.shape[1]

    FO = np.zeros((4,Ndt))
    for i in range(Ndt):
        Ri2w,t0i_w = get_Rt(Xhn[:,i])
        Rw2i = Ri2w.T

        pic = Rw2c@(t0i_w-t0c_w)
        yaw = get_yaw(Rw2i@Rc2w)

        FO[:,i] = np.concatenate((pic,[yaw]))
        
    return FO

def get_local_PS(xcr:np.ndarray,Xhn:np.ndarray) -> np.ndarray:
    Rc2w,t0c_w = get_Rt(xcr)
    Rw2c = Rc2w.T
    
    Ndt = Xhn.shape[1]

    PS = np.zeros((7,Ndt))
    for i in range(Ndt):
        Ri2w,t0i_w = get_Rt(Xhn[:,i])

        PS[0:3,i] = Rw2c@(t0i_w-t0c_w)
        PS[3:7,i] = Rotation.from_matrix(Rw2c@Ri2w).as_quat() 
    
    return PS

def get_patch_index(u,v, patch_height, patch_width, n_rows, n_cols):

    row = int(v / patch_height)
    col = int(u / patch_width)

    # Clamp to valid index range
    row = min(max(row, 0), n_rows - 1)
    col = min(max(col, 0), n_cols - 1)

    return row, col

def get_patch_bounds(u,v, H, W, n_rows, n_cols):
    patch_height = H / n_rows
    patch_width = W / n_cols

    row,col = get_patch_index(u,v, patch_height, patch_width, n_rows, n_cols)

    top = int(round(row * patch_height))
    bottom = int(round((row + 1) * patch_height))
    left = int(round(col * patch_width))
    right = int(round((col + 1) * patch_width))

    return [top, bottom, left, right]

In [26]:
# Camera parameters
K = np.array([
    [ 462.956,   0.000, 323.076],
    [   0.000, 463.002, 181.184],
    [   0.000,   0.000,   1.000]
])

# (camera) lens to current frame
Rl2c = np.array([
    [ 0.000,  0.000, -1.000],
    [ 1.000,  0.000,  0.000],
    [ 0.000, -1.000,  0.000]
])
t0l_c = np.array([0.100, -0.030, -0.010])

# Current frame to (camera) lens
Rc2l = Rl2c.T
t0c_l = -Rc2l@t0l_c

# OpenGL to OpenCV
Rgl2cv = np.array([
    [ 1.0, 0.0, 0.0],
    [ 0.0,-1.0, 0.0],
    [ 0.0, 0.0,-1.0]
])

n_rows, n_cols = 64,64
# Body frame x
xb = np.array([1.0,0.0,0.0])

In [63]:
# Params
id0 = np.arange(1,2)
WPs,PSs,Ucr = [],[],[]
for traj_file,img_file in zip(trajs_files,imgs_files):
    # Load the trajectories
    trajs = torch.load(traj_file)
    imgs = torch.load(img_file)
    Ntj = len(trajs)

    for i in range(Ntj):
        Xro,Uro = trajs[i]['Xro'],trajs[i]['Uro']
        Iro = imgs[i]['images']
        Nro = Xro.shape[1]
        Height,Width = Iro.shape[1:3] 

        for j in range(Nro-1):
            idxs = j+id0
            idxs = np.clip(idxs,0,Nro-1)
            xcr,ucr = Xro[:,j],Uro[:,j]            
            xhn = Xro[:,idxs]
            wps = get_local_FO(xcr,xhn)
            pss = get_local_PS(xcr,xhn)

            WPs.append(wps),PSs.append(pss)
            Ucr.append(ucr)

            pts = Rc2l@wps[0:3,:]+t0c_l[:,None]
            pts = K@Rgl2cv@pts
            heights,widths = pts[0,:]/pts[2,:],pts[1,:]/pts[2,:]
            for height,width in zip(heights,widths):
                bnds = get_patch_bounds(height,width,Height,Width,n_rows,n_cols)

                Iro[j,bnds[0]:bnds[1],bnds[2]:bnds[3],:] = 0
                Iro[j,bnds[0]:bnds[1],bnds[2]:bnds[3],0] = 255
                Iro[j,bnds[0]:bnds[1],bnds[2]:bnds[3],1] = 0

        gv.images_to_mp4(Iro,'output.mp4', 20)

In [79]:
PSs = np.array(PSs)
Ucr = np.array(Ucr)

Nol = 0
for pss,ucr in zip(PSs,Ucr):
    Err_x = PSs - pss[None,:,:]
    err_x = np.linalg.norm(Err_x, axis=(1, 2), keepdims=True)
    err_x = err_x.squeeze(-1)

    Err_u = Ucr - ucr[None,:]
    err_u = np.linalg.norm(Err_u, axis=1, keepdims=True)
    err_u = err_u.squeeze(-1)

    idxs_x = np.where(err_x < 1e-3)[0]
    idxs_u = np.where(err_u < 4e-2)[0]

    test_out = err_x[idxs_u]
    print(test_out.T)
    has_overlap = bool(set(idxs_x) & set(idxs_u))
    if has_overlap:
        Nol += 1

[[0.         0.00322943 0.00321419 0.00326996]]
[[0.00322943 0.         0.00028901 0.0008518 ]]
[[0.00321419 0.00028901 0.         0.00056367]]
[[0.00326996 0.0008518  0.00056367 0.         0.00080511]]
[[0.00080511 0.         0.0010178 ]]
[[0.0010178  0.         0.00120771]]
[[0.00120771 0.         0.00137973]]
[[0.00137973 0.         0.00153746]]
[[0.00153746 0.         0.00168342]]
[[0.00168342 0.         0.00181915]]
[[0.00181915 0.         0.00194553]]
[[0.00194553 0.         0.00206289]]
[[0.00206289 0.         0.00217123]]
[[0.00217123 0.         0.00227035]]
[[0.00227035 0.         0.00235995]]
[[0.00235995 0.         0.00243969 0.00494813]]
[[0.00243969 0.         0.00250927 0.00507685]]
[[0.00494813 0.00250927 0.         0.00256846 0.00518462 0.01822173 0.01624849]]
[[0.00507685 0.00256846 0.         0.0026171  0.00527123 0.01937477 0.01709469 0.01506299]]
[[0.00518462 0.0026171  0.         0.00265515 0.00533668 0.02090849 0.01849819 0.01621644 0.01417387]]
[[0.00527123 0.002