In [None]:
%load_ext autoreload
%autoreload 2
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import dataclasses
import sys
import timeit
from typing import Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm


In [None]:
# from scannet_dataset import ScanNetDataset
# from referit3d_data import ReferIt3dDataConfig
# from scanrefer_data import ScanReferDataConfig
from pytorch3d.io import IO
from pytorch3d.structures import Pointclouds
from home_robot.datasets.scannet import ScanNetDataset, ReferIt3dDataConfig, ScanReferDataConfig, NUM_CLASSES_LONG
import os

# Dataset root. Set the SCANNET_ROOT environment variable to point elsewhere instead
# of editing this user-specific default path.
SCANNET_ROOT = os.environ.get(
    'SCANNET_ROOT',
    '/private/home/ssax/home-robot/src/home_robot/home_robot/datasets/scannet/data',
)

data = ScanNetDataset(
    root_dir = SCANNET_ROOT,
    frame_skip = 180,
    n_classes=NUM_CLASSES_LONG,
    # n_classes=50,
    referit3d_config = ReferIt3dDataConfig(),
    scanrefer_config = ScanReferDataConfig(),
)

# Load specific scene
# idx = data.scene_list.index("scene0192_00") #'scene0000_00'
idx = 0
print(f"Loaded images of (h: {data.height}, w: {data.width}) - resized from ({data.DEFAULT_HEIGHT},{data.DEFAULT_WIDTH})")
# __getitem__ is called directly because `data[idx]` cannot pass show_progress.
scene_obs = data.__getitem__(idx, show_progress=True)


def apply_axis_alignment(points: torch.Tensor, align_mat: torch.Tensor) -> torch.Tensor:
    """Apply a 4x4 homogeneous transform to (N, 3) points and return the (N, 3) result."""
    homogeneous = torch.cat([points, torch.ones_like(points[:, :1])], dim=-1)
    return (homogeneous @ align_mat.T)[:, :3]


# NOTE(review): only the first entry of axis_align_mats is used; presumably the
# alignment matrix is the same for every frame of the scene — confirm.
axis_align_mat = scene_obs['axis_align_mats'][0]

# Load GT mesh
from pytorch3d.io import IO, load_obj, load_ply
scene_id = scene_obs['scan_name']
print("Loading GT mesh for", scene_id)
# verts = load_ply(data.root_dir / f'scans/{scene_id}/{scene_id}_vh_clean.ply')
pc = IO().load_pointcloud(data.root_dir / f'scans/{scene_id}/{scene_id}_vh_clean.ply')
verts = pc.points_packed()
aligned_verts = apply_axis_alignment(verts, axis_align_mat)
pointcloud_aligned = Pointclouds(points=aligned_verts.unsqueeze(0), features=pc.features_packed().unsqueeze(0))

# Load short-form GT mesh (50k points, with semantic + instance labels)
ins_50k = torch.from_numpy(np.load(data.instance_dir / f'{scene_id}_ins_label.npy').astype(np.int32))
_verts = torch.from_numpy(np.load(data.instance_dir / f'{scene_id}_vert.npy'))
# Columns 0-2 are xyz; the remaining columns are used as 0-255 colors, rescaled to [0, 1].
locs_50k, col_50k = _verts[:, :3], (_verts[:, 3:] / 255.)
locs_50k = apply_axis_alignment(locs_50k, axis_align_mat)


In [None]:
# Referring expressions attached to this scene, presumably from the ReferIt3D /
# ScanRefer configs passed to the dataset above — confirm.
# NOTE(review): if this is a DataFrame, prefer `.head()` so the whole table is not dumped.
scene_obs['ref_expr']

In [None]:
# Look up the class name of one referred-to target box.
# TODO: turn this into a plot — title: the query; traces: point cloud, GT bbox,
# and distractor boxes of the same class.
TARGET_BOX_ID = 39  # a value from scene_obs['box_target_ids'] (same ids as ins_50k labels)
selected = scene_obs['box_target_ids'] == TARGET_BOX_ID
n_selected = int(selected.sum())
assert n_selected == 1, f"Expected exactly one box with target id {TARGET_BOX_ID}, found {n_selected}"
id_to_name = dict(zip(data.METAINFO['CLASS_IDS'], data.METAINFO['CLASS_NAMES']))
id_to_name[scene_obs['box_classes'][selected].item()]

In [None]:
# K = scene_obs['intrinsics'][0][:3,:3]
# depth = scene_obs['depths'][0].squeeze().unsqueeze(0).unsqueeze(1)
# valid_depth  = (0.1 < depth) & (depth < 4.0)

# xyz = unproject_masked_depth_to_xyz_coordinates(
#     depth = depth,
#     mask  = ~valid_depth,
#     pose  = torch.eye(4).unsqueeze(0),
#     inv_intrinsics = torch.linalg.inv(K).unsqueeze(0),
# )
# rgb = scene_obs['images'][0].reshape(-1,3)[valid_depth.flatten()]
# print(scene_obs['image_paths'][0])
# print(f"Proportion depth valid: {float(valid_depth.float().mean())}")
# print(f"Depth min + max: {float(depth.flatten()[valid_depth.flatten()].min())}, {float(depth.flatten()[valid_depth.flatten()].max())}")
# print("These are the mins-maxes along each world axis. They should be in meters:")
# for i in range(3):
#     print(f"  {i}: ({float(xyz[:,i].min())}, {float(xyz[:,i].max())})")

In [None]:
# First loaded frame of the scene: depth (left) and RGB (right).
# squeeze() drops a possible trailing singleton channel, which imshow would reject.
frame_fig, (ax_depth, ax_rgb) = plt.subplots(1, 2, figsize=(12, 5))
depth_im = ax_depth.imshow(scene_obs['depths'][0].squeeze())
frame_fig.colorbar(depth_im, ax=ax_depth, label='depth')
ax_depth.set_title('Depth (frame 0)')
ax_rgb.imshow(scene_obs['images'][0])
ax_rgb.set_title('RGB (frame 0)')
for ax in (ax_depth, ax_rgb):
    ax.axis('off')
plt.show()


In [None]:
# -> SparseVoxelMapWithInstanceViews.show(backend='pytorch3d')

# Ground-truth scene: every annotated box (one husl color per box) drawn over the
# 50k-point labelled scan, with boxes named by their class.
from home_robot.utils.bboxes_3d import BBoxes3D, join_boxes_as_scene, join_boxes_as_batch
from home_robot.utils.bboxes_3d_plotly import plot_scene_with_bboxes
from pytorch3d.vis.plotly_vis import AxisArgs
from pytorch3d.structures import Pointclouds
import seaborn as sns

colors = torch.tensor(sns.color_palette("husl", len(scene_obs['boxes_aligned'])))
gt_boxes = BBoxes3D(
    bounds=[scene_obs['boxes_aligned']],
    features=[colors],
    names=[scene_obs['box_classes'].unsqueeze(-1)],
)

pointcloud_50k = Pointclouds(points=[locs_50k], features=[col_50k])

fig = plot_scene_with_bboxes(
    plots={
        f"{scene_id}": {
            "GT boxes": gt_boxes,
            # Other traces that can be toggled on: pointcloud_aligned (full mesh), cameras.
            "GT points smol": pointcloud_50k,
        }
    },
    xaxis={"backgroundcolor": "rgb(200, 200, 230)"},
    yaxis={"backgroundcolor": "rgb(230, 200, 200)"},
    zaxis={"backgroundcolor": "rgb(200, 230, 200)"},
    axis_args=AxisArgs(showgrid=True),
    pointcloud_marker_size=3,
    pointcloud_max_points=200_000,
    boxes_wireframe_width=3,
    boxes_add_cross_face_bars=False,
    boxes_name_int_to_display_name_dict={
        class_id: class_name
        for class_id, class_name in zip(data.METAINFO['CLASS_IDS'], data.METAINFO['CLASS_NAMES'])
    },
    boxes_plot_together=False,
    height=1000,
)
fig

In [None]:
import torch
from torch import Tensor
from typing import Tuple, Optional
from home_robot.utils.point_cloud_torch import get_bounds
from home_robot.utils.bboxes_3d import box3d_volume_from_bounds
def transform_basis(points: torch.Tensor, normal_vector: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Express points in an orthonormal frame aligned with a plane of the given normal.

    The points are centered on their centroid. The first two basis vectors are the
    two leading right-singular vectors of the centered points after their component
    along ``normal_vector`` is removed (the principal in-plane directions). The third
    basis vector is the unit normal.

    Caveat: if the in-plane points are collinear, the second singular vector is not
    constrained to lie in the plane.

    :param points: A 2D tensor of shape (N, 3), representing N points in 3D space.
        N must be at least 2 so that two in-plane singular vectors exist.
    :param normal_vector: A 1D tensor of shape (3), representing the normal vector.
        It is cast to the dtype/device of ``points`` and normalized here.
    :return: A tuple ``(transformed_points, transformation_matrix)``:
        - transformed_points: (N, 3) tensor equal to
          ``(points - centroid) @ transformation_matrix.T``. Its last column is the
          signed distance of each point from the plane through the centroid.
        - transformation_matrix: (3, 3) tensor whose rows are the two in-plane axes
          followed by the unit normal.
    """
    assert points.dim() == 2 and points.size(1) == 3, "points must be a 2D tensor with shape (N, 3)"
    assert normal_vector.dim() == 1 and normal_vector.size(0) == 3, "normal_vector must be a 1D tensor with shape (3)"
    assert points.size(0) >= 2, "points must contain at least 2 points to define two in-plane axes"

    # Match dtype/device so the matmuls below also work for e.g. float64 point clouds.
    normal_vector = normal_vector.to(points)
    normal_vector = normal_vector / torch.norm(normal_vector)

    centroid = points.mean(dim=0)
    centered = points - centroid

    # Remove each point's component along the normal; what remains lies in the plane.
    in_plane = centered - (centered @ normal_vector).unsqueeze(-1) * normal_vector

    # Rows of vh are the principal directions of the in-plane scatter.
    _, _, vh = torch.linalg.svd(in_plane, full_matrices=False)
    transformation_matrix = torch.stack([vh[0], vh[1], normal_vector])

    transformed_points = centered @ transformation_matrix.T
    return transformed_points, transformation_matrix

def fit_plane_to_points(
        normal_vec: Tensor,
        points: Tensor,
        return_residuals: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """
    Fit the offset of a hyperplane with a fixed normal to a set of points in K-dimensional space.

    The hyperplane is <n, P> + d = 0, where n is ``normal_vec`` normalized to unit
    length. With n fixed, the least-squares estimate of d is minus the mean of the
    projections <n, p_i>.

    Parameters:
    -----------
    normal_vec : torch.Tensor
        A 1D tensor of shape (K,) representing the normal vector (a1, a2, ..., aK) to the hyperplane.
        It need not be unit length; it is normalized before fitting.
    points : torch.Tensor
        A 2D tensor of shape (N, K), representing N points in K-dimensional space, where each row is a point (x1, x2, ..., xK).
    return_residuals : bool, optional
        Whether to also return the residuals, i.e. the signed distances of the points from the fitted hyperplane. Default is False.

    Returns:
    --------
    plane_params : torch.Tensor
        A 1D tensor of shape (K+1,) representing the coefficients (a1, a2, ..., aK, d) of the hyperplane
        equation, with (a1, ..., aK) the unit normal.
    residuals : torch.Tensor (only if return_residuals is True)
        A 1D tensor of shape (N,) with the signed distance <n, p_i> + d of each point from the fitted hyperplane.

    Example:
    --------
    >>> normal_vec = torch.tensor([0., 1.])
    >>> points = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
    >>> fit_plane_to_points(normal_vec, points, return_residuals=True)
    (tensor([ 0.,  1., -4.]), tensor([-2.,  0.,  2.]))
    """
    assert normal_vec.dim() == 1, "normal_vec must be a 1D tensor"
    assert points.dim() == 2 and points.size(1) == normal_vec.size(0), "points must be a 2D tensor of shape (N, n) where n is the length of normal_vec"

    # Normalize the normal vector so d and the residuals are in the points' units.
    normal_vec = normal_vec / normal_vec.norm()

    # Projection of every point onto the normal (computed once, reused for residuals).
    projections = (points * normal_vec).sum(dim=-1)

    # Least-squares solution for d in a1*x1 + ... + aK*xK + d = 0.
    d = -projections.mean()
    plane_params = torch.cat([normal_vec, d.unsqueeze(0)])

    if return_residuals:
        return plane_params, projections + d
    return plane_params

def find_placeable_location(
        pointcloud: Tensor,
        ground_normal: Tensor,
        nbr_dist: float,
        residual_thresh: float,
        max_tries: Optional[int] = None,
        min_neighborhood_points: int = 3,
        min_area_prop: float = 0.25,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[Tensor, float]:
    """
    Finds a suitable placement location on a flat surface in the given pointcloud.

    Candidate points are visited in random order (at most ``max_tries`` of them). For
    each candidate, the points within ``nbr_dist`` of it form a neighborhood. That
    neighborhood is expressed in a frame whose third axis is ``ground_normal`` (see
    ``transform_basis``). The candidate is accepted when:
      * the neighborhood has enough points,
      * its in-plane bounding box covers enough area, and
      * its mean absolute distance from the plane (with normal ``ground_normal``
        through the neighborhood centroid) is below ``residual_thresh``.

    Args:
        pointcloud (Tensor): A 2D tensor of shape (num_points, dims) with dims >= 3. Only the
            first three columns (xyz) are used for the geometry; any extra columns are carried
            through to the returned location.
        ground_normal (Tensor): A 1D tensor of shape (3,) with the normal vector of the
            surface to place on. It need not be unit length.
        nbr_dist (float): Radius of the neighborhood around a candidate point (same units as
            the pointcloud).
        residual_thresh (float): Threshold on the average absolute distance of the
            neighborhood points from the fitted plane.
        max_tries (Optional[int]): Maximum number of candidates to test. Defaults to (and is
            capped at) the number of points.
        min_neighborhood_points (int): Minimum neighborhood size for a candidate to be
            considered. Values below 2 are treated as 2, because two in-plane axes are needed.
        min_area_prop (float): Minimum in-plane bounding-box area of the neighborhood, as a
            fraction of (2 * nbr_dist) ** 2.
        generator (Optional[torch.Generator]): RNG used to order the candidates. Pass a seeded
            generator for reproducible results; None uses torch's global RNG.

    Returns:
        Tuple[Tensor, float]: The accepted row of ``pointcloud`` (1D tensor) and the average
            absolute residual of the plane fit at that location.

    Raises:
        ValueError: If no candidate among the ``max_tries`` tested is suitable.

    Usage:
        >>> pointcloud = torch.rand((1000, 3))  # Example pointcloud
        >>> ground_normal = torch.tensor([0., 1., 0.])  # Example normal vector
        >>> location, residual = find_placeable_location(
        ...     pointcloud, ground_normal, nbr_dist=0.1, residual_thresh=0.01, max_tries=100)
    """
    assert pointcloud.ndim == 2 and pointcloud.shape[1] >= 3, f"Pointcloud must be a 2D Tensor with shape (num_points, 3), not {pointcloud.shape=}"
    num_points = pointcloud.shape[0]
    max_tries = num_points if max_tries is None else min(max_tries, num_points)

    # Geometry is computed on xyz only so extra feature columns don't distort distances.
    xyz = pointcloud[:, :3]
    # transform_basis needs at least two points to define two in-plane axes.
    min_points = max(min_neighborhood_points, 2)
    # Area of the square circumscribing the neighborhood disk, scaled by the required fraction.
    min_area = (nbr_dist * 2) ** 2 * min_area_prop

    candidate_idxs = torch.randperm(num_points, generator=generator)[:max_tries]
    for idx in candidate_idxs:
        # 1. Sample a location from the pointcloud
        sample_point = xyz[idx]

        # 2. Extract a neighborhood around that location
        dists = torch.norm(xyz - sample_point.unsqueeze(0), dim=1)
        neighborhood = xyz[dists < nbr_dist]
        if neighborhood.shape[0] < min_points:
            continue

        # 3. Express the neighborhood in a plane-aligned frame. Column 2 is each point's
        #    signed distance from the plane with normal `ground_normal` through the
        #    neighborhood centroid — the least-squares fit that fit_plane_to_points computes.
        nbrhd_plane, _ = transform_basis(points=neighborhood, normal_vector=ground_normal)

        # Reject neighborhoods that cover too small an area in the plane (e.g. edges, thin strips).
        bounds = get_bounds(nbrhd_plane)
        mins, maxs = bounds[:2].unbind(dim=-1)
        area = torch.prod(maxs - mins, dim=-1)
        if area < min_area:
            continue

        # 4. If the fit average absolute residual is under some threshold, return that location
        avg_residual = nbrhd_plane[:, 2].abs().mean()
        if avg_residual < residual_thresh:
            return pointcloud[idx], float(avg_residual)
    raise ValueError(f'No suitable location found after {max_tries} tries')

In [None]:
# Pick the first annotated box of the target class in this scene. Then build point
# clouds for that object's points (via its instance id in ins_50k) and for the full
# 50k-point scene.
TARGET_CLASS_NAME = 'couch'
cls_idx = data.ALL_CLASS_NAMES_TO_CLASS_IDS[TARGET_CLASS_NAME]
matching_boxes = (scene_obs['box_classes'].cpu() == cls_idx).nonzero(as_tuple=True)[0]
if matching_boxes.numel() == 0:
    raise ValueError(f"No '{TARGET_CLASS_NAME}' box found in scene {scene_id}")
first_idx = int(matching_boxes[0])
inst_id = scene_obs['box_target_ids'][first_idx]

mask_obj = (ins_50k == inst_id)
assert bool(mask_obj.any()), f"Instance {int(inst_id)} has no points in the 50k-point scan"
pointcloud_obj_50k = Pointclouds(points=[locs_50k[mask_obj]], features=[col_50k[mask_obj]])
pointcloud_50k = Pointclouds(points=[locs_50k], features=[col_50k])

In [None]:
# `nbrhd_plane` is created by the placement-search cell further down, so it does not
# exist on a fresh kernel. Guard the lookup so Restart & Run All does not stop here.
nbrhd_plane.shape if 'nbrhd_plane' in globals() else None

In [None]:
# Search the selected object's points for a flat patch whose normal is +z in the
# axis-aligned frame, then recover the neighborhood around the chosen location.
PLACE_SEED = 0  # find_placeable_location orders candidates with torch.randperm
nbrhd_dist = 0.1  # neighborhood radius, in meters
normal_vec = torch.tensor((0., 0., 1.))

torch.manual_seed(PLACE_SEED)
obj_points = pointcloud_obj_50k.points_packed()
loc, residual = find_placeable_location(
    pointcloud = obj_points,
    ground_normal = normal_vec,
    nbr_dist = nbrhd_dist,
    residual_thresh = 0.01,  # mean |distance| to the fitted plane, in meters
    max_tries = 100,
)
dists = torch.norm(obj_points - loc.unsqueeze(0), dim=1)
nbrhd = obj_points[dists < nbrhd_dist]
nbrhd_plane, tform = transform_basis(points=nbrhd, normal_vector=normal_vec)
nbrhd.shape

In [None]:
from home_robot.utils.point_cloud_torch import get_bounds
from home_robot.utils.bboxes_3d import box3d_volume_from_bounds
# Bounds of the neighborhood in its plane-aligned frame. Presumably each row is the
# [min, max] of one axis: rows 0-1 are in-plane, row 2 is along the normal.
bounds = get_bounds(nbrhd_plane)
volume = box3d_volume_from_bounds(bounds)

# In-plane footprint of the patch. Points are in meters, so convert m^2 -> cm^2 for readability.
mins, maxs = bounds[:2].unbind(dim=-1)
area_m2 = torch.prod(maxs - mins, dim=-1)
area = area_m2 * 1e4
print(f'Area {area.item():0.4f} cm^2 / Volume {volume.item():0.4f} m^3')

In [None]:
# In-plane extent of the chosen patch along its two principal axes.
torch.sub(maxs, mins)

In [None]:

# -> SparseVoxelMapWithInstanceViews.show(backend='pytorch3d')

# Plot all GT boxes, the selected object's points, and the neighborhood `nbrhd` around
# the location chosen by find_placeable_location in the placement-search cell.
from home_robot.utils.bboxes_3d import BBoxes3D, join_boxes_as_scene, join_boxes_as_batch
from home_robot.utils.bboxes_3d_plotly import plot_scene_with_bboxes
from pytorch3d.vis.plotly_vis import AxisArgs
from pytorch3d.structures import Pointclouds
import seaborn as sns

colors = torch.tensor(sns.color_palette("husl", len(scene_obs['boxes_aligned'])))
gt_boxes = BBoxes3D(
    bounds = [scene_obs['boxes_aligned']],
    features = [colors],
    names = [scene_obs['box_classes'].unsqueeze(-1)]
)

neighborhood = Pointclouds(points=[nbrhd])

fig = plot_scene_with_bboxes(
    plots = { f"{scene_id}": {
                                "GT boxes": gt_boxes,
                                # "GT points": pointcloud_aligned,
                                "GT object points": pointcloud_obj_50k,
                                "Place location": neighborhood,
                                # "cameras": cameras,
                            }
    },
    xaxis={"backgroundcolor":"rgb(200, 200, 230)"},
    yaxis={"backgroundcolor":"rgb(230, 200, 200)"},
    zaxis={"backgroundcolor":"rgb(200, 230, 200)"},
    axis_args=AxisArgs(showgrid=True),
    pointcloud_marker_size=3,
    pointcloud_max_points=200_000,
    boxes_wireframe_width=3,
    boxes_add_cross_face_bars=False,
    boxes_name_int_to_display_name_dict = dict(zip(data.METAINFO['CLASS_IDS'], data.METAINFO['CLASS_NAMES'])),
    boxes_plot_together=False,
    height=1000,
)
fig