In [180]:
import sys, os

# Make the repository root (parent of this notebook's directory) importable so
# `base`, `model` and `train` resolve. abspath normalizes away the "..", and the
# membership check keeps re-runs of this cell from piling up duplicate entries.
BASE = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
print(BASE)
if BASE not in sys.path:
    sys.path.append(BASE)

/mnt/cvda/zhangzimu/ADT/Ego-Prediction/sample_code/..


In [181]:
import torch, argparse
from base.dataset import ADT_Dataset
from model.simple import Simple_Eye_Gaze_MLP, Simple_Eye_Gaze_Loss, Simple_Trajectory_Loss, Simple_Trajectory_MLP
from base.utils import load_config
from tqdm import tqdm
from train import validate
from base.metrics import Average_Gaze_Angular_Error, Average_Traj_Error
import numpy as np
import quaternion
import rerun as rr 
from rerun.datatypes import Quaternion

In [182]:
# Read run hyper-parameters from the shared training config so the model built
# below matches the architecture of the saved checkpoint.
config = load_config("../configs/config.yaml")
(
    model_name,
    len_per_input_seq,
    len_per_output_seq,
    interval,
    frame_stride,
    hidden_dim,
    use_gpu,
    num_workers,
) = (
    config[cfg_key]
    for cfg_key in (
        'model',
        'len_per_input_seq',
        'len_per_output_seq',
        'interval',
        'frame_stride',
        'hidden_dim',
        'use_gpu',
        'num_workers',
    )
)

# NOTE(review): `use_gpu` and `num_workers` are read but not used in this notebook;
# inference is pinned to the CPU regardless of the config.
device = torch.device('cpu')
print("Using device: ", device)

Using device:  cpu


In [183]:
# Build the network selected by `model_name`. Layer widths are
# (values per frame) * seq_len // frame_stride; keep these expressions exactly as
# written (including operator precedence) so the sizes match the trained checkpoint.
if model_name == "simple_gaze_mlp":
    # 3 values per frame in and out (presumably a 3D gaze direction — TODO confirm).
    model = Simple_Eye_Gaze_MLP(
        input_dim=3 * len_per_input_seq // frame_stride,
        hidden_dim=hidden_dim,
        output_dim=3 * len_per_output_seq // frame_stride,
    ).to(device)
    criterion = Simple_Eye_Gaze_Loss()
    validation_criterion = Average_Gaze_Angular_Error()
elif model_name == "simple_traj_mlp":
    # Input: 3 + 16 values per frame (16 presumably a flattened 4x4 pose — TODO confirm).
    # Output: 3 position values + 4 quaternion values per predicted frame.
    model = Simple_Trajectory_MLP(
        input_dim=3 * len_per_input_seq // frame_stride + 16 * len_per_input_seq // frame_stride,
        hidden_dim=hidden_dim,
        output_dim=3 * len_per_output_seq // frame_stride + 4 * len_per_output_seq // frame_stride,
    ).to(device)
    criterion = Simple_Trajectory_Loss()
    validation_criterion = Average_Traj_Error()
else:
    raise NotImplementedError(
        f"Unsupported model {model_name!r}; expected 'simple_gaze_mlp' or 'simple_traj_mlp'"
    )

In [184]:
# Restore the trained trajectory weights. map_location remaps any tensors saved on
# another device (e.g. a GPU used for training) onto the CPU `device` used here.
CHECKPOINT_PATH = "../logs/simple_traj_mlp.pth"
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint['model'])
model.eval()

Simple_Trajectory_MLP(
  (fc1): Linear(in_features=380, out_features=256, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=256, out_features=84, bias=True)
)

In [185]:
# Held-out split (train=False), windowed with the same sequence settings as training.
dataset = ADT_Dataset(
    "../dataset/data.h5py",
    len_per_input_seq,
    len_per_output_seq,
    interval,
    frame_stride,
    train=False,
    dataset_path="../dataset/",
)
dataset_length = len(dataset)

Loading data...


100%|██████████| 208/208 [00:00<00:00, 2131.17it/s]


In [186]:
# Pick one held-out clip and predict its future camera poses.
# Set PIECE_INDEX to an int (e.g. 1096) to revisit a specific clip; set SEED = None
# for a fresh random draw on every run.
PIECE_INDEX = None
SEED = 0
rng = np.random.default_rng(SEED)
piece_index = PIECE_INDEX if PIECE_INDEX is not None else int(rng.integers(dataset_length))
print("piece_index:", piece_index)

input_clip, gt_clip = dataset[piece_index]
# Add a leading batch dimension of size 1 to every model input.
input_clip_to_model = {key: value.unsqueeze(0) for key, value in input_clip.items()}
with torch.no_grad():
    pred_coord, pred_quat = model(input_clip_to_model, device)

# Drop only the batch dim; a bare squeeze() would also collapse any other size-1 dim
# (e.g. a single predicted frame), breaking per-frame indexing downstream.
pred_coord = pred_coord.squeeze(0)
pred_quat = pred_quat.squeeze(0)

In [187]:
def _log_bboxes(clip, idx, colors):
    """Log the valid 3D bounding boxes of frame `idx` to the '3D/3dboxes' entity.

    Slots whose AABB or transform holds a non-finite leading value are dropped
    (presumably padding for absent objects). One shared mask is applied to the AABBs,
    the transforms and a per-slot color table, so all three stay aligned.
    """
    aabbs = clip['3d_boundingboxes_aabb'][idx].numpy()
    transforms = clip['3d_boundingboxes_transform_scene_object_matrix'][idx].numpy()
    valid = np.isfinite(aabbs[:, 0]) & np.isfinite(transforms[:, 0, 0])
    if not valid.any():
        # Nothing to draw: clear the entity so boxes from an earlier frame don't linger.
        rr.log("3D/3dboxes", rr.Clear(recursive=False))
        return
    aabbs, transforms = aabbs[valid], transforms[valid]
    # A per-slot color table must be masked like the boxes; anything else
    # (e.g. a single color) is passed through unchanged.
    if colors.ndim == 2 and colors.shape[0] == valid.shape[0]:
        colors = colors[valid]

    # AABB columns are read as (xmin, xmax, ymin, ymax, zmin, zmax), presumably in
    # each object's local frame, mapped to the scene by the 4x4 transform.
    local_mins = aabbs[:, [0, 2, 4]]
    local_maxs = aabbs[:, [1, 3, 5]]
    half_sizes = (local_maxs - local_mins) / 2
    local_centers = (local_mins + local_maxs) / 2
    rot_mats = transforms[:, :3, :3]
    translations = transforms[:, :3, 3]
    # The box center must go through the full rigid transform (R @ c + t); adding only
    # t misplaces every box whose local center is not on the object origin.
    centers = np.einsum("nij,nj->ni", rot_mats, local_centers) + translations
    # numpy-quaternion stores (w, x, y, z); Rerun's Quaternion takes (x, y, z, w).
    quats_xyzw = quaternion.as_float_array(quaternion.from_rotation_matrix(rot_mats))[:, [1, 2, 3, 0]]

    rr.log(
        "3D/3dboxes",
        rr.Boxes3D(
            centers=centers,
            half_sizes=half_sizes,
            rotations=[Quaternion(xyzw=q) for q in quats_xyzw],
            colors=colors,
            radii=np.full(len(centers), 0.01),
        ),
    )


def _log_camera_pose(prefix, position, quat_wxyz=None, color=None):
    """Log a camera position as a point at '3D/{prefix}_coord' and, when a quaternion
    (w, x, y, z order, as consumed by np.quaternion) is given, its orientation as an
    arrow at '3D/{prefix}_orientation'."""
    rr.log(f"3D/{prefix}_coord", rr.Points3D(positions=position, radii=0.1, colors=color))
    if quat_wxyz is not None:
        rot = np.quaternion(*(float(c) for c in quat_wxyz))
        # Rotate the camera-frame +z axis (length 0.2) into the scene frame;
        # +z is assumed to be the viewing direction — TODO confirm.
        direction = quaternion.rotate_vectors(rot, [0, 0, 0.2])
        rr.log(
            f"3D/{prefix}_orientation",
            rr.Arrows3D(origins=position, vectors=direction, colors=color, radii=0.1),
        )


def visualize_clip(clip, start_time, colors, input=True, pred_coord=None, pred_quat=None):
    """Log one dataset clip (boxes, GT camera pose, optional predictions, RGB) to Rerun.

    Args:
        clip: dict of per-frame sequences from ADT_Dataset. Reads 'timestamps',
            '3d_boundingboxes_aabb', '3d_boundingboxes_transform_scene_object_matrix',
            'scene_cam_matrix', 'cam_pose_quat' and 'video', each indexed by frame.
        start_time: timestamp subtracted from every frame's timestamp; the difference
            is divided by 1e9 (presumably nanoseconds -> seconds — TODO confirm).
        colors: box colors; a table with one row per box slot is masked together
            with the boxes so each object keeps its color.
        input: True logs the GT path as the green '3D/input_camera_traj' line strip,
            False as the blue '3D/output_camera_traj' strip.
        pred_coord: optional predicted camera positions, indexed like the clip frames.
        pred_quat: optional predicted camera quaternions; their arrows start at the
            matching pred_coord entry, so pred_coord is required when this is given.

    Raises:
        ValueError: if pred_quat is given without pred_coord.
    """
    if pred_quat is not None and pred_coord is None:
        raise ValueError("pred_quat requires pred_coord: predicted arrows are drawn from predicted positions")

    camera_color = [255, 0, 0]
    colors = np.asarray(colors)
    camera_traj = []

    timestamps = clip['timestamps']
    for idx, timestamp in tqdm(enumerate(timestamps), total=len(timestamps)):
        rr.set_time_seconds("stable_time", (timestamp - start_time) / 1e9)

        _log_bboxes(clip, idx, colors)

        camera_coord = clip['scene_cam_matrix'][idx].numpy()[:3, 3]
        camera_traj.append(camera_coord)
        _log_camera_pose("gt_camera", camera_coord, clip['cam_pose_quat'][idx], color=camera_color)

        if pred_coord is not None:
            pred_rot = None if pred_quat is None else pred_quat[idx]
            _log_camera_pose("pred_camera", pred_coord[idx], pred_rot)

        rr.log("2D/real_RGB", rr.Image(data=clip['video'][idx]))

    # Log the whole GT trajectory once at t=0 so it is visible from the start of the timeline.
    rr.set_time_seconds("stable_time", 0)
    if input:
        traj_entity, traj_color = "3D/input_camera_traj", [0, 255, 0]
    else:
        traj_entity, traj_color = "3D/output_camera_traj", [0, 0, 255]
    rr.log(traj_entity, rr.LineStrips3D([np.asarray(camera_traj)], colors=traj_color))

In [188]:
# Start a Rerun recording with a Y-up, right-handed 3D view, then log the input
# clip followed by the ground-truth future clip overlaid with the model's predictions.
rr.init("dataset_piece_visualization")
rr.set_time_seconds("stable_time", 0)
rr.log(
    "3D",
    rr.ViewCoordinates.RIGHT_HAND_Y_UP
)

# Both clips share a time origin: the first input frame.
start_time = input_clip['timestamps'][0]

# One random RGB color per bounding-box slot, shared by both clips so an object keeps
# its color. randint's upper bound is exclusive, so 256 is needed to reach 255.
colors = np.random.randint(0, 256, (gt_clip['3d_boundingboxes_aabb'].shape[1], 3), dtype=np.uint8)

visualize_clip(input_clip, start_time, colors, input=True)
visualize_clip(gt_clip, start_time, colors, input=False, pred_coord=pred_coord, pred_quat=pred_quat)

100%|██████████| 20/20 [00:00<00:00, 24.97it/s]
100%|██████████| 12/12 [00:00<00:00, 32.00it/s]


In [189]:
# Embed the Rerun viewer for the current recording in this cell's output.
rr.notebook_show()

Viewer()