# Notebook for generating training data for NeRF from ROSBAG

In [None]:
#load requirements for working with point clouds
from vedo import *
from ipyvtklink.viewer import ViewInteractiveWidget
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
import tensorflow as tf

# Limit GPU memory ------------------------------------------------
# Restrict TensorFlow to a fixed memory budget on the first GPU it finds.
gpus = tf.config.experimental.list_physical_devices('GPU')
print(gpus)
if gpus:
    try:
        memlim = 10 * 1024  # MB (previously 22*1024)
        gpu_limit = [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=memlim)]
        tf.config.experimental.set_virtual_device_configuration(gpus[0], gpu_limit)
    except RuntimeError as e:
        # Reported but not fatal (e.g. the device was already initialised).
        print(e)
#-----------------------------------------------------------------

from ICET_spherical import ICET
from scipy.spatial.transform import Rotation as R
import copy

from matplotlib import pyplot as p
from nerf_utils import *
from coarse_network_utils import*

# IPython magics: reload edited modules (e.g. nerf_utils) before each cell runs,
# and autosave the notebook every 180 s.
%load_ext autoreload
%autoreload 2
%autosave 180

# Download dataset from external source

For example, the Newer College Dataset: https://ori-drs.github.io/newer-college-dataset/

Move the downloaded files into the corresponding folder in `../data/`.


# Export point clouds from the ROS bag to per-frame `.npy` files (the cells below load `point_clouds/frame_<idx>.npy`)

1. Edit `bag2mapframe.py` so that it saves the generated point clouds to the correct directory.


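2. Move `bag2mapframe.py` into the `bagconverter` package of your catkin workspace, build it, and run it. `roscore` keeps running in the foreground, so start it in its own terminal before launching the script: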

```
mv bag2mapframe.py ~/catkin_ws/src/bagconverter
cd ~/catkin_ws
catkin_make
cd src/bagconverter
roscore
python3 bag2mapframe.py
```
3. In a separate terminal, play the rosbag at 0.1x speed:

```
rosbag play -r 0.1 myBag.bag
```

# Load and filter ground truth pose data

In [None]:
import os

# Ground-truth poses of the Newer College "01_short_experiment" sequence.
# numpy's file readers do not expand "~", so expand it explicitly.
dir_name = os.path.expanduser("~/PLINK/data/NewerCollegeDataset/")
experiment_name = "01_short_experiment-20230331T172433Z-009/01_short_experiment/"
fn_gt = dir_name + experiment_name + "ground_truth/registered_poses.csv"
# columns: sec, nsec, x, y, z, qx, qy, qz, qw (one header row)
gt = np.loadtxt(fn_gt, delimiter=',', skiprows=1)
seconds = gt[:, 0]
nano_seconds = gt[:, 1]
xyz = gt[:, 2:5]
qxyzw = gt[:, 5:]
num_poses = qxyzw.shape[0]

# Build one homogeneous 4x4 transform per ground-truth pose.
sensor_poses = np.eye(4, dtype=np.float64).reshape(1, 4, 4).repeat(num_poses, axis=0)
sensor_poses[:, :3, :3] = R.from_quat(qxyzw).as_matrix()
sensor_poses[:, :3, 3] = xyz

# Fixed extrinsic T_CL, right-multiplied into every pose: a ~135 deg rotation
# about +z (quaternion x, y, z, w = 0, 0, 0.924, 0.383) plus a small offset.
T_CL = np.eye(4, dtype=np.float32)
T_CL[:3, :3] = R.from_quat([0.0, 0.0, 0.924, 0.383]).as_matrix()
T_CL[:3, 3] = np.array([-0.084, -0.025, 0.050], dtype=np.float32)
sensor_poses = np.einsum("nij,jk->nik", sensor_poses, T_CL)

# Inverse of the first pose (taken before re-anchoring the trajectory).
initial_pose = np.linalg.inv(sensor_poses[0])
# Timestamps in nanoseconds: 1 s = 1e9 ns (the literal `10e9` would be 1e10).
poses_timestamps = seconds * 1e9 + nano_seconds
# Re-express the trajectory relative to its first pose (first pose -> identity).
# NOTE: flagged by the author as something to try disabling.
sensor_poses = np.einsum("ij,njk->nik", initial_pose, sensor_poses)

# Body-frame velocity (displacement per pose step), used to remove motion
# distortion from the training scans.
vel_world_frame = np.diff(sensor_poses[:, :3, -1], axis=0)
vel_body_frame = np.linalg.pinv(sensor_poses[1:, :3, :3]) @ vel_world_frame[:, :, None]
vel_body_frame = vel_body_frame[:, :, 0]
# smooth out velocity estimates
def moving_average(a, n=10):
    """Return the moving average of ``a`` over windows of ``n`` consecutive samples.

    Element ``m`` of the result is ``mean(a[m:m + n])`` (only complete windows),
    so the output has ``len(a) - n + 1`` elements and is empty when ``n > len(a)``.

    Args:
        a: 1-D sequence of numbers.
        n: window length; must be a positive integer.

    Returns:
        1-D float array of window means.

    Raises:
        ValueError: if ``n < 1``.
    """
    if n < 1:
        raise ValueError(f"window length n must be >= 1, got {n}")
    ret = np.cumsum(a, dtype=float)
    # Turn the running sum into window sums: S[m] - S[m - n].
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n
import os

# Smooth each body-frame velocity component with a moving average over `window`
# pose steps (each output is window - 1 samples shorter than its input).
window = 50
MAx = moving_average(vel_body_frame[:, 0], n=window)
MAy = moving_average(vel_body_frame[:, 1], n=window)
MAz = moving_average(vel_body_frame[:, 2], n=window)
vel_body_frame = np.array([MAx, MAy, MAz]).T

# Per-step change of the XYZ Euler angles. When an angle crosses +/-pi,
# np.diff jumps by ~2*pi; wrap every difference back into [-pi, pi).
# (Indexing with np.argwhere's (row, col) pairs would instead zero whole,
# unrelated rows.)
rot_vel_euls = np.diff(R.from_matrix(sensor_poses[:, :3, :3]).as_euler('xyz'), axis=0)
rot_vel_euls = (rot_vel_euls + np.pi) % (2 * np.pi) - np.pi

# Load the HD map (vertices only); every `show_nth` vertex forms the submap
# that scans are registered against with ICET.
# courtyard
pl = os.path.expanduser('~/PLINK/data/NewerCollegeDataset/new-college-29-01-2020-1cm-resolution-1stSection - mesh.ply')
# forest
# pl = os.path.expanduser('~/PLINK/data/NewerCollegeDataset/new-college-29-01-2020-1cm-resolution-5thSection.ply')
HD_map = trimesh.load(pl).vertices
show_nth = 5  # was 10
submap = HD_map[::show_nth]

# Generate training data

In [None]:
# ---- training-data configuration --------------------------------------------
n_images = 24          # scans to use (was 240)
n_rots = 128           # horizontal patches per 360 degree scan
n_vert_patches = 1     # vertical patches between phimin and phimax
useICET = True         # may need to be turned off for the foliage (forest) dataset
image_width = 1024 // n_rots
image_height = 64 // n_vert_patches
shrink_factor = 0.005  # courtyard

# Horizontal patches dropped at the start and at the end of every scan
# (those columns see the researcher carrying the lidar).
n_cols_to_skip = n_rots // 8

# Ouster OS1-64 vertical field of view, calibrated by hand
# (differs from the sensor spec sheet).
phimin = np.deg2rad(-15.594)
phimax = np.deg2rad(17.743)
vert_fov = np.rad2deg(phimax - phimin)

# One entry per patch: n_images scans x n_rots columns x n_vert_patches rows.
n_patches = n_images * n_rots * n_vert_patches
poses = np.zeros([n_patches, 4, 4])
# [patch, row, col, (depth, raydrop mask)]; both channels start at 1
images = np.ones([n_patches, image_height, image_width, 2])
# [patch, row, col, xyz]
rays_o_all = np.zeros([n_patches, image_height, image_width, 3])
rays_d_all = np.zeros([n_patches, image_height, image_width, 3])

H, W = images.shape[1:3]
redfix_hist = np.zeros([n_images, 4, 4])  # per-scan corrective transforms from ICET

import os

# Folder holding the exported scans, one `frame_<idx>.npy` per lidar frame.
# np.load does not expand "~", so expand it here.
point_cloud_dir = os.path.expanduser(
    "~/PLINK/data/NewerCollegeDataset/01_Short_Experiment/point_clouds/")


def _to_training_frame(T):
    """Move the translation of 4x4 pose ``T`` into the NeRF training frame, in place.

    Adds a fixed offset, scales by the global ``shrink_factor`` and then adds a
    dataset-specific offset (courtyard values active, forest values commented
    out). Returns ``T`` so calls can be chained.
    """
    T[0, -1] += 30
    T[1, -1] += 30
    T[2, -1] += 15
    T[:3, -1] *= shrink_factor
    # courtyard
    T[0, -1] += 0.01
    T[1, -1] += 0.25
    T[2, -1] += 0.25  # translate above xy plane
    # # forest
    # T[0, -1] += 1.2
    # T[1, -1] += 1.25
    # T[2, -1] += 0.25  # translate above xy plane
    return T


for i in range(n_images):
    print(i)
    # Frame index of the i-th scan (full loop of the first courtyard section)
    idx = i*50 + 7650  # for debug and visualization
    # idx = i*5 + 7650  # train set
    # idx = i*40 + 10600  # forest
    fn1 = point_cloud_dir + "frame_" + str(idx) + ".npy"
    pc1 = np.load(fn1)

    # Apply distortion correction ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # m_hat = [dx, dy, dz, droll, dpitch, dyaw]; only the (negated, smoothed)
    # body-frame translation is used, the rotational terms are disabled
    # (they would be -rot_vel_euls[idx, 0..2]).
    m_hat = np.array([-vel_body_frame[idx, 0],
                      -vel_body_frame[idx, 1],
                      -vel_body_frame[idx, 2],
                      0., 0., 0.])
    pc1 = apply_motion_profile(pc1, m_hat, period_lidar=1.)
    pc1 = np.flip(pc1, axis=0)

    # Register the undistorted scan against the HD map with ICET to correct
    # errors in the ground-truth pose; redFix holds the resulting correction
    # (identity when ICET is disabled).
    redFix = np.eye(4)
    if useICET:
        submap_in_pc1_frame = (np.linalg.pinv(sensor_poses[idx]) @ initial_pose
                               @ np.append(submap, np.ones([len(submap), 1]), axis=1).T).T
        submap_in_pc1_frame = submap_in_pc1_frame[:, :3]

        initial_guess = tf.constant([0., 0., 0., 0., 0., 0.])
        it = ICET(cloud1=submap_in_pc1_frame, cloud2=pc1, fid=50, niter=8,
                  draw=False, group=2, RM=False, DNN_filter=False, x0=initial_guess)

        # Debug copies of the raw and ICET-corrected scan in the map frame
        pc1_in_map_frame = (initial_pose @ sensor_poses[idx]
                            @ np.append(pc1, np.ones([len(pc1), 1]), axis=1).T).T
        pc1_in_map_frame = pc1_in_map_frame[:, :3]
        pc1_corrected_in_map_frame = (initial_pose @ sensor_poses[idx]
                                      @ np.append(it.cloud2_tensor.numpy(), np.ones([len(it.cloud2_tensor.numpy()), 1]), axis=1).T).T
        pc1_corrected_in_map_frame = pc1_corrected_in_map_frame[:, :3]

        # Corrective transform from the ICET estimate X = [x, y, z, roll, pitch, yaw]
        redFix[:3, -1] = it.X[:3]
        redFix[:3, :3] = redFix[:3, :3] @ R.from_euler('xyz', [it.X[3], it.X[4], it.X[5]]).as_matrix()
    redfix_hist[i] = redFix
    # Scan with the correction applied, then placed with its ground-truth pose
    redScanFixed = (redFix @ np.append(pc1, np.ones([len(pc1), 1]), axis=1).T).T
    redScanFixed = (sensor_poses[idx] @ np.append(redScanFixed[:, :3], np.ones([len(redScanFixed), 1]), axis=1).T).T

    # Convert to a 64-row range image; channels are [r, theta, phi]
    pc1_spherical = cartesian_to_spherical(pc1).numpy()
    pcs = np.reshape(pc1_spherical, [-1, 64, 3])
    pcs = np.flip(pcs, axis=1)
    raw_data = np.transpose(pcs, [1, 0, 2])

    # Destagger depth images (OS1 unit has delay in sensor return bus):
    # each group of 4 rows is shifted by a row-dependent number of columns.
    data = np.zeros([64, 1024])
    for k in range(np.shape(data)[0] // 4):
        data[4*k, 1:-8] = raw_data[4*k, 9:, 0]
        data[4*k+1, 1:-2] = raw_data[4*k+1, 3:, 0]
        data[4*k+2, 4:] = raw_data[4*k+2, :-4, 0]
        data[4*k+3, 10:] = raw_data[4*k+3, :-10, 0]
    data = np.flip(data, axis=1)

    # Ray origins and directions for the whole scan, in the training frame
    rotm = _to_training_frame(sensor_poses[idx] @ redfix_hist[i])
    ro, rd = get_rays_from_point_cloud(pc1, m_hat, rotm)

    for j in range(n_rots):
        for k in range(n_vert_patches):
            patch = k + (j + i*n_rots)*n_vert_patches  # flat patch index
            rows = slice(k*image_height, (k+1)*image_height)
            cols = slice(j*image_width, (j+1)*image_width)

            # store rays_o and rays_d of this patch
            rays_d_all[patch] = rd[rows, cols, :]
            rays_o_all[patch] = ro[rows, cols, :]

            # channel 0: cropped depth; channel 1: raydrop mask, 0 where |depth| < 1
            depth = data[rows, cols]
            images[patch, :, :, 0] = depth
            images[patch, :, :, 1][np.abs(depth) < 1] = 0

            # Patch pose: scan pose rotated about z by -crop_angle, the azimuth of
            # this patch (the sensor points back and to the left)
            rotm = sensor_poses[idx] @ redfix_hist[i]
            crop_angle = j*(2*np.pi/n_rots) - np.pi/2 + (np.pi/n_rots)
            rotm[:3, :3] = rotm[:3, :3] @ R.from_euler('xyz', [0, 0, -crop_angle]).as_matrix()
            poses[patch] = _to_training_frame(rotm)
            images[patch, :, :, 0] *= shrink_factor

# Remove patches where the sensor is occluded by the person holding the lidar:
# drop the first and last `n_cols_to_skip` horizontal patches of every scan.
# Patches are stored scan by scan and column by column, with a column's vertical
# patches adjacent, so a patch's column is (flat index // n_vert_patches) % n_rots.
n_total_patches = n_rots * n_images * n_vert_patches
patch_column = (np.arange(n_total_patches) // n_vert_patches) % n_rots
keep = (patch_column >= n_cols_to_skip) & (patch_column < n_rots - n_cols_to_skip)
good_idx = np.flatnonzero(keep)

images = images[good_idx, :, :, :]
poses = poses[good_idx, :, :]
rays_d_all = rays_d_all[good_idx, :, :, :]
rays_o_all = rays_o_all[good_idx, :, :, :]

images = images.astype(np.float32)
poses = poses.astype(np.float32)

# Visualize training data

Draw all training data in a common frame using the depth images, ray origins (`rays_o`),
and ray directions (`rays_d`).

In [None]:
def draw_frame_from_rays(disp, n_rots=128, n_vert_patches=1, frameIdx=0,
                         color='red', stitched_map=None, n_cols_to_skip=None):
    """Rebuild one scan's points from the training patches and add them to a vedo scene.

    Reads the module-level training arrays ``images``, ``rays_o_all`` and
    ``rays_d_all``. These hold ``n_rots - 2 * n_cols_to_skip`` columns of
    ``n_vert_patches`` patches per scan. Each patch of scan ``frameIdx`` is
    converted back to 3-D points with ``add_patch``, and a ``Points`` actor for
    them is appended to ``disp``.

    Args:
        disp: list of vedo actors; modified in place.
        n_rots: horizontal patches per scan used when generating the data.
        n_vert_patches: vertical patches per scan column.
        frameIdx: index of the scan to draw.
        color: colour of the point actor.
        stitched_map: (N, 3) array of points accumulated so far; ``None``
            (the default) starts from an empty map.
        n_cols_to_skip: patches removed at each end of a scan; defaults to
            ``n_rots // 8``, as in the data-generation cell.

    Returns:
        ``stitched_map`` with this scan's points appended.
    """
    if stitched_map is None:
        stitched_map = np.zeros([0, 3])
    if n_cols_to_skip is None:
        n_cols_to_skip = n_rots // 8
    cols_per_scan = n_rots - 2 * n_cols_to_skip

    # Collect per-patch points and concatenate once (repeated np.append is
    # quadratic). Start from an empty (0, 3) array so no spurious origin point
    # ends up in the drawing or in stitched_map.
    chunks = [np.zeros([0, 3])]
    for col in range(frameIdx * cols_per_scan, (frameIdx + 1) * cols_per_scan):
        for row in range(n_vert_patches):
            patch = row + col * n_vert_patches
            rays_o = rays_o_all[patch]
            chunks.append(add_patch(rays_o, rays_d_all[patch], images[patch, :, :, 0]))
        # DEBUG: mark the first ray origin of this column's last patch
        disp.append(Points(rays_o[0, :1, :], r=15, c='purple'))

    pts1 = np.concatenate(chunks, axis=0)
    disp.append(Points(pts1, c=color, r=3., alpha=0.125))

    return np.append(stitched_map, pts1, axis=0)
    
# Overlay every generated scan, shaded in grey from dark (first) to lighter (last).
plt = Plotter(N=1, axes=1, bg=(1, 1, 1), interactive=True)  # axes = 4 (simple), 1 (scale)
disp = []

stitched_map = np.zeros([0, 3])
# colors = ['red', 'orange', 'yellow', 'green', 'blue', 'indigo', 'violet', 'red']
colors = np.linspace(0.1, 0.3, n_images)[:, None] * np.array([[1, 1, 1]])
for i, color in enumerate(colors):
    print(i)
    # Use the same patch layout as the data-generation cell
    stitched_map = draw_frame_from_rays(disp, n_rots=n_rots, n_vert_patches=n_vert_patches,
                                        frameIdx=i, color=color, stitched_map=stitched_map)

plt.show(disp, "Drawing training data from depth images, rays_o, and rays_d")
ViewInteractiveWidget(plt.window)

# Save training data 

In [None]:
import os

# Write the generated training set next to the dataset.
# np.save does not expand "~", so expand it explicitly.
out_dir = os.path.expanduser("~/PLINK/data/NewerCollegeDataset/")
np.save(os.path.join(out_dir, "images.npy"), images)
np.save(os.path.join(out_dir, "poses.npy"), poses)
np.save(os.path.join(out_dir, "rays_o.npy"), rays_o_all)
np.save(os.path.join(out_dir, "rays_d.npy"), rays_d_all)