In [None]:
%load_ext autoreload
%autoreload 2

import h5py
import torch

from conceptfusion import ConceptFusion

In [None]:
# parameters
from hydra import initialize, compose
import hydra

# Hydra keeps a process-wide singleton; clear it first so this cell can be
# re-run in the same kernel without `initialize` complaining it already ran.
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path="configs")

# Resolve `configs/conceptfusion.yaml` into a config object passed to ConceptFusion.
args = compose(config_name="conceptfusion")

# NOTE(review): not referenced by any of the cells shown here.
episode_file = "/srv/flash1/kyadav32/datasets/ovmm/hm3d_512x512/0001.h5"

In [None]:
# Inference only: switch autograd off globally so no computation graphs are
# recorded for the rest of the notebook.
torch.set_grad_enabled(False)

# Build the ConceptFusion model from the composed Hydra config.
concept_fusion = ConceptFusion(**args)

In [None]:
from tqdm import tqdm
import pickle

concept_fusion.clear()

# NOTE(review): pickle.load can execute arbitrary code -- only load these files
# if they come from a trusted source.
with open("/srv/flash1/kyadav32/datasets/ovmm/scannet/scannet_dataset.pkl", "rb") as f:
    dataset = pickle.load(f)

# A single pre-extracted ScanNet scene, wrapped in a list so the loop below can
# treat it like a dataset of scenes.
with open("/srv/flash1/kyadav32/datasets/ovmm/scannet/scannet_scene0011_00_obs.pkl", "rb") as f:
    dataset_0 = [pickle.load(f)]

# SETUP EVAL
# Map dataset class IDs to human-readable class names.
class_id_to_class_names = dict(
    zip(
        dataset.METAINFO["CLASS_IDS"],  # IDs [1, 3, 4, 5, ..., 65]
        dataset.METAINFO["CLASS_NAMES"],  # [wall, floor, cabinet, ...]
    )
)
# If this is an open-vocab detector, they sometimes require a vocab
concept_fusion.set_vocabulary(class_id_to_class_names)

# Observation entries that are moved onto the model's device before building the scene.
keys = [
    "images",
    "depths",
    "poses",
    "intrinsics",
    "boxes_aligned",
    "box_classes",
]

# Instance-extraction hyper-parameters (DBSCAN clustering + similarity cutoff).
DBSCAN_EPSILON = 0.3
DBSCAN_MIN_SAMPLES = 25
SIMILARITY_THRESH = 0.9
concept_fusion.dbscan_params.epsilon = DBSCAN_EPSILON
concept_fusion.dbscan_params.min_samples = DBSCAN_MIN_SAMPLES
concept_fusion.similarity_params.similarity_thresh = SIMILARITY_THRESH

# Accumulators for evaluation metrics; not populated by this cell yet.
gt_bounds, gt_classes, pred_bounds, pred_classes, pred_scores = [], [], [], [], []
for scene_obs in tqdm(dataset_0, desc="Evaluating scenes..."):
    # Move to device. This mutates `scene_obs` in place, so `dataset_0` holds
    # device tensors afterwards.
    for k in keys:
        scene_obs[k] = scene_obs[k].to(concept_fusion.device)

    # Query only the classes that have ground-truth boxes in this scene.
    queries = {
        int(clas): class_id_to_class_names[int(clas)]
        for clas in scene_obs["box_classes"].unique()
    }
    # Build the scene and extract instances for the query names. Pass a real
    # list (not a dict_values view) so the callee can index / len() it.
    # `instances_dict` is overwritten per scene: after the loop it holds only
    # the last scene's result.
    instances_dict = concept_fusion.build_scene_and_get_instances_for_queries(
        scene_obs, list(queries.values())
    )

In [None]:
# Visualise the instances returned for the ground-truth class queries of the
# evaluated scene (the last scene processed by the evaluation loop).
concept_fusion.show_point_cloud_pytorch3d(instances_dict)

In [None]:
import open3d as o3d
import numpy as np

pcd = o3d.io.read_point_cloud("/srv/flash1/kyadav32/datasets/ovmm/scannet/scannet_scene0011_00_gt.ply")

from pytorch3d.structures import Pointclouds
from pytorch3d.vis.plotly_vis import AxisArgs
from home_robot.utils.bboxes_3d import BBoxes3D

from home_robot.utils.bboxes_3d_plotly import plot_scene_with_bboxes
from home_robot.utils.data_tools.dict import update

from utils import COLOR_LIST

# Plotly traces keyed by legend name; the GT point cloud goes in first.
traces = {}

# Convert the Open3D ground-truth cloud straight into GPU tensors. Colours
# appear to be stored in [0, 1] (hence the * 255 for the plot features).
pc_xyz = torch.tensor(np.asarray(pcd.points), device="cuda")
pc_rgb = torch.tensor(np.asarray(pcd.colors), device="cuda")
traces["Points"] = Pointclouds(points=[pc_xyz], features=[pc_rgb * 255])


# Ground-truth boxes of the (single) evaluated scene.
box_classes = dataset_0[0]['box_classes']
box_bounds = dataset_0[0]['boxes_aligned']

# Group boxes by class name so each class becomes its own "<name>_bbox" trace.
bounds, names, colors = {}, {}, {}
for class_id, bound in zip(box_classes, box_bounds):
    class_id_int = class_id.item()
    class_name = class_id_to_class_names[class_id_int]

    # Keep every per-box tensor on the same device; BBoxes3D stacks them together.
    bounds.setdefault(class_name, []).append(bound.to("cuda"))
    # `class_id` is already a tensor: move it rather than re-wrapping it with
    # torch.tensor(), which warns when copy-constructing from a tensor.
    names.setdefault(class_name, []).append(class_id.to("cuda"))
    # Index the Python list with a plain int, not a tensor.
    colors.setdefault(class_name, []).append(
        torch.tensor(COLOR_LIST[class_id_int % len(COLOR_LIST)], device="cuda")
    )

# One BBoxes3D per class. Iterate the grouped dict -- `box_classes` is a tensor
# and has no `.keys()` -- and stack each class's own list of tensors.
for class_name in bounds:
    detected_boxes = BBoxes3D(
        bounds=[torch.stack(bounds[class_name], dim=0)],
        features=[torch.stack(colors[class_name], dim=0)],
        names=[torch.stack(names[class_name], dim=0).unsqueeze(-1)],
    )
    traces[class_name + "_bbox"] = detected_boxes


# Shared plot styling: tinted axis backgrounds, grid lines, point/box sizing,
# and one trace per box group rather than all boxes merged together.
_default_plot_args = {
    "xaxis": {"backgroundcolor": "rgb(200, 200, 230)"},
    "yaxis": {"backgroundcolor": "rgb(230, 200, 200)"},
    "zaxis": {"backgroundcolor": "rgb(200, 230, 200)"},
    "axis_args": AxisArgs(showgrid=True),
    "pointcloud_marker_size": 3,
    "pointcloud_max_points": 500_000,
    "boxes_plot_together": False,
    "boxes_wireframe_width": 3,
}

plot_kwargs = update(_default_plot_args, {})
fig = plot_scene_with_bboxes(plots={"Conceptfusion Pointcloud": traces}, **plot_kwargs)
fig.update_layout(height=800, width=1600)


In [None]:
# Display the current similarity parameters (e.g. `similarity_thresh`, set in
# the evaluation cell) to confirm the settings used for instance extraction.
concept_fusion.similarity_params

In [None]:

from plyfile import PlyData, PlyElement

# Inspect the schema of the ground-truth PLY: every element and its properties.
with open("/srv/flash1/kyadav32/datasets/ovmm/scannet/scannet_scene0011_00_gt.ply", "rb") as ply_file:
    plydata = PlyData.read(ply_file)

for ply_element in plydata.elements:
    header_lines = [f"Element name: {ply_element.name}"]
    header_lines += [f"\tProperty name: {ply_prop.name}" for ply_prop in ply_element.properties]
    print("\n".join(header_lines))

In [None]:
# Query with free-form text prompts. Presumably this reuses the scene built
# by the evaluation loop -- TODO confirm `get_instances_for_queries` does not
# need a rebuild. Clustering/similarity settings match the evaluation cell,
# plus the "thresh" visualisation mode.
concept_fusion.dbscan_params.epsilon = 0.3
concept_fusion.dbscan_params.min_samples = 25
concept_fusion.similarity_params.viz_type = "thresh"
concept_fusion.similarity_params.similarity_thresh = 0.9

text_queries = ["sofa", "windows", "table", "chair", "dressing_table"]
instances_dict = concept_fusion.get_instances_for_queries(text_queries)
# Use the public visualisation entry point (the same call used to show the
# evaluation results) instead of reaching into the private
# `_show_point_cloud_pytorch3d` helper.
concept_fusion.show_point_cloud_pytorch3d(instances_dict)