In [None]:
%load_ext autoreload
%autoreload 2
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import dataclasses
import os
import sys
import timeit
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm


In [None]:
# from scannet_dataset import ScanNetDataset
# from referit3d_data import ReferIt3dDataConfig
# from scanrefer_data import ScanReferDataConfig
from pytorch3d.io import IO, load_obj, load_ply
from pytorch3d.structures import Pointclouds
from home_robot.datasets.scannet import ScanNetDataset, ReferIt3dDataConfig, ScanReferDataConfig, NUM_CLASSES_LONG

# Root of the ScanNet data. Set the SCANNET_ROOT environment variable to point
# at your own copy instead of editing this user-specific default path.
SCANNET_ROOT = os.environ.get(
    "SCANNET_ROOT",
    "/private/home/ssax/home-robot/src/home_robot/home_robot/datasets/scannet/data",
)
FRAME_SKIP = 180  # passed to ScanNetDataset(frame_skip=...); presumably subsamples frames -- verify in the dataset
SCENE_IDX = 0     # index of the scene to load from data.scene_list

data = ScanNetDataset(
    root_dir=SCANNET_ROOT,
    frame_skip=FRAME_SKIP,
    n_classes=NUM_CLASSES_LONG,
    # n_classes=50,
    referit3d_config=ReferIt3dDataConfig(),
    scanrefer_config=ScanReferDataConfig(),
)

# Load a specific scene. To select one by name instead of by index:
#   SCENE_IDX = data.scene_list.index("scene0192_00")  # e.g. 'scene0000_00'
idx = SCENE_IDX
print(f"Loaded images of (h: {data.height}, w: {data.width}) - resized from ({data.DEFAULT_HEIGHT},{data.DEFAULT_WIDTH})")
# __getitem__ is called directly (not data[idx]) so the extra show_progress kwarg can be passed.
scene_obs = data.__getitem__(idx, show_progress=True)

# Load the GT mesh vertices (and their features) as a point cloud.
scene_id = scene_obs['scan_name']
print("Loading GT mesh for", scene_id)
pc = IO().load_pointcloud(data.root_dir / f'scans/{scene_id}/{scene_id}_vh_clean.ply')
verts = pc.points_packed()
# Apply the scene's 4x4 axis-alignment matrix in homogeneous coordinates.
# Cast it to the vertices' dtype/device first: torch matmul raises on mixed
# dtypes (e.g. float64 matrix vs float32 vertices).
align_mat = scene_obs['axis_align_mats'][0].to(verts)
verts_h = torch.cat([verts, torch.ones_like(verts[:, :1])], dim=-1)
aligned_verts = verts_h @ align_mat.T
pointcloud_aligned = Pointclouds(
    points=aligned_verts[..., :3].unsqueeze(0),
    features=pc.features_packed().unsqueeze(0),
)


In [None]:
# Display the referring expression(s) attached to the loaded scene.
ref_expr = scene_obs['ref_expr']
ref_expr

In [None]:
# Look up the class name of one referred-to target box.
# TODO: visualize the referring expression:
#   - title: the query
#   - traces: point cloud, GT bbox, distractors of the same class
TARGET_BOX_ID = 39  # value to match in scene_obs['box_target_ids']

selected = scene_obs['box_target_ids'] == TARGET_BOX_ID
n_selected = int(selected.sum())
# `.item()` below needs exactly one match; fail with a readable message otherwise.
assert n_selected == 1, (
    f"Expected exactly one box with target id {TARGET_BOX_ID}, found {n_selected}"
)
id_to_name = dict(zip(data.METAINFO['CLASS_IDS'], data.METAINFO['CLASS_NAMES']))
id_to_name[scene_obs['box_classes'][selected].item()]

In [None]:
# K = scene_obs['intrinsics'][0][:3,:3]
# depth = scene_obs['depths'][0].squeeze().unsqueeze(0).unsqueeze(1)
# valid_depth  = (0.1 < depth) & (depth < 4.0)

# xyz = unproject_masked_depth_to_xyz_coordinates(
#     depth = depth,
#     mask  = ~valid_depth,
#     pose  = torch.eye(4).unsqueeze(0),
#     inv_intrinsics = torch.linalg.inv(K).unsqueeze(0),
# )
# rgb = scene_obs['images'][0].reshape(-1,3)[valid_depth.flatten()]
# print(scene_obs['image_paths'][0])
# print(f"Proportion depth valid: {float(valid_depth.float().mean())}")
# print(f"Depth min + max: {float(depth.flatten()[valid_depth.flatten()].min())}, {float(depth.flatten()[valid_depth.flatten()].max())}")
# print("These are the mins-maxes along each world axis. They should be in meters:")
# for i in range(3):
#     print(f"  {i}: ({float(xyz[:,i].min())}, {float(xyz[:,i].max())})")

In [None]:
# Show the first loaded frame: depth map and RGB image side by side.
# `.squeeze()` drops any singleton dimension on the depth frame, because
# imshow rejects shapes like (H, W, 1). It is a no-op for a plain (H, W) map.
fig, (ax_depth, ax_rgb) = plt.subplots(1, 2, figsize=(12, 5))
depth_im = ax_depth.imshow(scene_obs['depths'][0].squeeze())
ax_depth.set_title("Depth (frame 0)")
fig.colorbar(depth_im, ax=ax_depth)
ax_rgb.imshow(scene_obs['images'][0])
ax_rgb.set_title("RGB (frame 0)")
plt.show()


In [None]:
from home_robot.perception.detection.detic.detic_perception import DeticPerception

# Detic segmenter using the built-in COCO vocabulary on GPU 0.
# config_file / checkpoint_file are None -- presumably DeticPerception falls back
# to its own defaults; verify in detic_perception.py. (A `verbose` flag also exists.)
detic_kwargs = dict(
    config_file=None,
    vocabulary="coco",
    custom_vocabulary="",
    checkpoint_file=None,
    sem_gpu_id=0,
)
segmenter = DeticPerception(**detic_kwargs)

In [None]:
# FIXME: `res` is never assigned anywhere in this notebook. It leaked from a
# deleted or unsaved cell, presumably the output of running `segmenter` on a frame.
# Guard it so that Restart & Run All does not stop here with a NameError.
semantic_res = globals().get('res')
if semantic_res is None:
    print("Skipping semantic-frame plot: `res` is not defined (run the segmenter on a frame first).")
else:
    plt.imshow(semantic_res['semantic_frame'])
    plt.title("Segmenter semantic frame")
    plt.show()

# Boolean mask of pixels whose instance id is 0 in the last entry of instance_map
# (presumably the last loaded frame).
plt.imshow(scene_obs['instance_map'][-1] == 0)
plt.title("instance_map[-1] == 0")
plt.show()

In [None]:
# Plot the GT scene: axis-aligned GT boxes over the aligned GT mesh point cloud.
# (Target API: SparseVoxelMapWithInstanceViews.show(backend='pytorch3d').)
from home_robot.utils.bboxes_3d import BBoxes3D, join_boxes_as_scene, join_boxes_as_batch
from home_robot.utils.bboxes_3d_plotly import plot_scene_with_bboxes
from pytorch3d.vis.plotly_vis import AxisArgs
from pytorch3d.structures import Pointclouds
import seaborn as sns

box_bounds = scene_obs['boxes_aligned']
# One distinct husl color per GT box, used as the per-box feature.
colors = torch.tensor(sns.color_palette("husl", len(box_bounds)))
gt_boxes = BBoxes3D(
    bounds=[box_bounds],
    features=[colors],
    names=[scene_obs['box_classes'].unsqueeze(-1)],
)

# Traces drawn together in the single subplot for this scene.
scene_traces = {
    "GT boxes": gt_boxes,
    "GT points": pointcloud_aligned,
}
# Maps integer class ids (used as box names above) to readable labels.
class_id_to_name = dict(zip(data.METAINFO['CLASS_IDS'], data.METAINFO['CLASS_NAMES']))

fig = plot_scene_with_bboxes(
    plots={f"{scene_id}": scene_traces},
    xaxis={"backgroundcolor": "rgb(200, 200, 230)"},
    yaxis={"backgroundcolor": "rgb(230, 200, 200)"},
    zaxis={"backgroundcolor": "rgb(200, 230, 200)"},
    axis_args=AxisArgs(showgrid=True),
    pointcloud_marker_size=3,
    pointcloud_max_points=30_000,
    boxes_wireframe_width=3,
    boxes_add_cross_face_bars=False,
    boxes_name_int_to_display_name_dict=class_id_to_name,
    boxes_plot_together=False,
    height=1000,
)
fig