In [None]:
%load_ext autoreload
%autoreload 2
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

# from home_robot.mapping.voxel import SparseVoxelMap
from home_robot.utils.point_cloud_torch import unproject_masked_depth_to_xyz_coordinates


In [None]:
from home_robot.datasets.scannet import ScanNetDataset, ReferIt3dDataConfig, ScanReferDataConfig
# Root of the ScanNet data. Set the SCANNET_ROOT environment variable to run this
# notebook outside the original author's home directory; the default keeps the
# previous hard-coded location.
SCANNET_ROOT = os.environ.get(
    "SCANNET_ROOT",
    '/private/home/ssax/home-robot/src/home_robot/home_robot/datasets/scannet/data',
)
data = ScanNetDataset(
    root_dir = SCANNET_ROOT,
    frame_skip = 180,
    n_classes = 50,
    referit3d_config = ReferIt3dDataConfig(),
    scanrefer_config = ScanReferDataConfig(),
)

# Scene to inspect: "scene0192_00" is a small scene, "scene0000_00" a large one.
SCENE_ID = "scene0192_00"
if SCENE_ID not in data.scene_list:
    # Same exception type list.index would raise, but with a message that says what is missing.
    raise ValueError(
        f"Scene {SCENE_ID!r} not found among the {len(data.scene_list)} scenes under {SCANNET_ROOT}"
    )
idx = data.scene_list.index(SCENE_ID)
print(f"Loaded images of (h: {data.height}, w: {data.width}) - resized from ({data.DEFAULT_HEIGHT},{data.DEFAULT_WIDTH})")
# __getitem__ is called explicitly because `data[idx]` cannot forward the show_progress kwarg.
scene_obs = data.__getitem__(idx, show_progress=True)

# Optional: load the GT mesh as an axis-aligned point cloud (requires pytorch3d).
# from pytorch3d.io import IO, load_obj, load_ply
# from pytorch3d.structures import Pointclouds
# def transform_points(points, transform_mat):
#     return (torch.cat([points, torch.ones_like(points[:,:1])], dim=-1) @ transform_mat.T)[...,:3]

# scene_id = scene_obs['scan_name']
# print("Loading GT mesh for", scene_id)
# pointcloud = IO().load_pointcloud(data.root_dir / f'scans/{scene_id}/{scene_id}_vh_clean.ply')
# verts = transform_points(pointcloud.points_packed(), scene_obs['axis_align_mats'][0])
# pointcloud_aligned =  Pointclouds(points = [verts], features = [pointcloud.features_packed()])
# print(f"GT mesh: {len(verts)} verts")

# Map dataset class ids to class names, e.g. {1: 'wall', 3: 'floor', 4: 'cabinet', ...}.
id_to_name = dict(zip(
    data.METAINFO['CLASS_IDS'], # IDs [1, 3, 4, 5, ..., 65]
    data.METAINFO['CLASS_NAMES'] # [wall, floor, cabinet, ...]
))

In [None]:
# box_idx = find(torch.LongTensor(exp_target_ids), scene_obs['box_target_ids'])[:,1]
# scene_obs['boxes_aligned'][box_idx].shape

import torch

def resize_boxes(boxes, factor):
    """
    Resize axis-aligned bounding boxes about their centers by a per-axis factor.

    Args:
        boxes (torch.Tensor): Bounding boxes of shape [..., 3, 2] (typically [K, 3, 2]).
            The last axis holds (min, max); the second-to-last axis indexes the three
            spatial dimensions.
        factor (float, int, sequence or torch.Tensor): Scale applied to each box's extent.
            A scalar scales all three dimensions by the same amount. A length-3 sequence
            (fx, fy, fz) scales each dimension independently. Any value broadcastable
            against [..., 3] is accepted (e.g. a NumPy scalar or a 0-d tensor).

    Returns:
        torch.Tensor: Resized boxes with the same shape as ``boxes``. The dtype is
        ``boxes.dtype`` for floating-point input and torch's default floating dtype
        for integer input.

    Raises:
        ValueError: If any scale factor is negative (that would swap mins and maxes).
    """
    # Work in floating point. Casting a fractional factor to an integer box dtype
    # would truncate it (e.g. 0.7 -> 0) and collapse every box to a point.
    work_dtype = boxes.dtype if boxes.is_floating_point() else torch.get_default_dtype()
    boxes_f = boxes.to(work_dtype)

    # as_tensor handles Python/NumPy scalars, sequences and tensors alike. A 0-d
    # result broadcasts over all three axes; a length-3 result scales each axis.
    factor_t = torch.as_tensor(factor, dtype=work_dtype, device=boxes.device)
    if bool((factor_t < 0).any()):
        raise ValueError(f"resize factor must be non-negative, got {factor}")

    # Extract mins and maxes for convenience: each has shape [..., 3]
    mins = boxes_f[..., 0]
    maxes = boxes_f[..., 1]

    # Scale the extent about the box center
    center = (mins + maxes) / 2.0
    resized_sizes = (maxes - mins) * factor_t

    new_mins = center - resized_sizes / 2.0
    new_maxes = center + resized_sizes / 2.0

    # Re-pack into (min, max) along the last axis
    return torch.stack((new_mins, new_maxes), dim=-1)


In [None]:
from evaluation.refer_exp import eval_obj_selection_bboxes

def find(tensor, values):
    """Locate every position where an element of ``tensor`` equals an element of ``values``.

    ``tensor`` gets a trailing axis so it is compared against all of ``values`` at once.
    For 1-D inputs, the result is a LongTensor of shape [P, 2] whose rows are
    ``(i, j)`` with ``tensor[i] == values[j]``. Rows are in row-major order
    (sorted by ``i``, then ``j``).
    """
    pairwise_equal = tensor.unsqueeze(-1) == values
    return pairwise_equal.nonzero()

# Sanity-check the evaluator with "oracle" predictions: for each referring
# expression, use the GT box of its target, shrunk by ORACLE_BOX_SCALE on every axis.
# A box scaled by s about its own center has volumetric IoU s**3 with the original
# (0.7**3 ~= 0.34). So, if the evaluator uses 3D volumetric IoU, the 0.25
# threshold should pass and the 0.5 / 0.75 thresholds should fail.
#
# To evaluate real predictions from a SparseVoxelMap `svm` instead:
#   pred_bounds = torch.stack([inst.bounds.cpu() for inst in svm.get_instances()])
#   pred_class = torch.stack([inst.category_id.cpu() for inst in svm.get_instances()])
#   pred_scores = torch.stack([torch.max(torch.stack([v.score for v in ins.instance_views])).cpu() for ins in svm.get_instances()])
ORACLE_BOX_SCALE = 0.7

# GT
gt_bounds = scene_obs['boxes_aligned']
gt_ids = scene_obs['box_target_ids']
gt_class = scene_obs['box_classes']
exp_target_ids = list(scene_obs['ref_expr'].target_id)

# Rows are (expression index, GT box index) for every target id that matches a box id.
matches = find(torch.as_tensor(exp_target_ids, dtype=torch.long), gt_ids)
# find() drops expressions whose target has no GT box and emits extra rows for
# duplicated box ids. Either case would misalign predictions with exp_target_ids,
# so require exactly one match per expression, in order.
assert torch.equal(matches[:, 0].cpu(), torch.arange(len(exp_target_ids))), (
    "each referring-expression target id must match exactly one GT box id"
)
box_idx = matches[:, 1]
pred_bounds_oracle = gt_bounds[box_idx]
pred_bounds = resize_boxes(pred_bounds_oracle, ORACLE_BOX_SCALE)

obj_selection_results = eval_obj_selection_bboxes(
    box_gt_bounds = [gt_bounds],
    box_gt_class = [gt_class],
    box_gt_ids = [gt_ids],
    exp_target_ids = [exp_target_ids],
    box_pred_bounds = [pred_bounds],
    iou_thr = (0.25, 0.5, 0.75),
    # label_to_cat = segmenter.seg_id_to_name
)