In [None]:
%load_ext autoreload
%autoreload 2
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import dataclasses
import sys
import timeit
from typing import Tuple
import pickle
import click
import matplotlib.pyplot as plt
import numpy as np
import open3d
import torch
from tqdm import tqdm


# from home_robot.mapping.voxel import SparseVoxelMap
from home_robot.utils.point_cloud_torch import unproject_masked_depth_to_xyz_coordinates

from pytorch3d.vis.plotly_vis import AxisArgs, plot_scene, get_camera_wireframe
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import numpy as np
# Print tensors in fixed-point rather than scientific notation (makes pose/intrinsic matrices easier to read).
torch.set_printoptions(sci_mode=False)

def colormap_to_rgb_strings(data, colormap_name='viridis', include_alpha=False, min_val=None, max_val=None):
    """
    Convert numbers into Plotly-style color strings using a Matplotlib colormap.

    :param data: Values to map to colors. Anything ``np.asarray`` accepts (list, array,
        scalar, CPU tensor); it is flattened, so one color string is produced per element.
    :param colormap_name: The name of the Matplotlib colormap to use.
    :param include_alpha: If True, emit 'rgba(R,G,B,A)' strings; otherwise 'rgb(R,G,B)'.
    :param min_val: Optional minimum value for colormap scaling (defaults to min of data).
    :param max_val: Optional maximum value for colormap scaling (defaults to max of data).
    :return: A list of color strings, one per element of ``data``; ``[]`` for empty input.
    """
    values = np.asarray(data, dtype=float).ravel()
    if values.size == 0:
        # np.min / np.max raise on empty arrays, and there is nothing to color anyway.
        return []

    # Compute min and max from the data if not provided
    if min_val is None:
        min_val = values.min()
    if max_val is None:
        max_val = values.max()

    # Normalize data within the provided or computed min and max range.
    norm = plt.Normalize(min_val, max_val)
    # plt.get_cmap replaces plt.cm.get_cmap, which was deprecated in Matplotlib 3.7
    # and removed in 3.9.
    colors = plt.get_cmap(colormap_name)(norm(values))

    # Channels come back as floats in [0, 1]; RGB is scaled (truncated) to 0-255, alpha kept as-is.
    if include_alpha:
        return ["rgba({},{},{},{})".format(int(r*255), int(g*255), int(b*255), a) for r, g, b, a in colors]
    return ["rgb({},{},{})".format(int(r*255), int(g*255), int(b*255)) for r, g, b in colors[:, :3]]

def add_camera_poses(
    fig,
    poses,
    linewidth = 3,
    color = None,
    name = 'cam',
    separate = True,
    scale = 0.2,
    colormap_name='plasma'
    ):
    """
    Draw one camera wireframe per pose onto a Plotly figure, in place.

    Each pose is read as a matrix whose top-left 3x3 block is a rotation R and whose
    last column (first three rows) is a translation t. Wireframe points ``p`` are mapped
    to ``p @ R.T + t``, i.e. from the camera frame into the frame the pose maps into.

    :param fig: Plotly figure (e.g. from ``plot_scene``) that receives one Scatter3d trace per pose.
    :param poses: Sequence / stacked tensor of pose matrices (torch tensors or anything
        ``torch.as_tensor`` accepts).
    :param linewidth: Line width of each wireframe trace.
    :param color: One color used for every camera. If None, colors are sampled from
        ``colormap_name`` in pose order.
    :param name: Trace-name prefix; traces are named ``f'{name}-{i}'``.
    :param separate: Currently unused; kept for API compatibility.
    :param scale: Wireframe size passed to ``get_camera_wireframe``.
    :param colormap_name: Matplotlib colormap used when ``color`` is None.
    :return: None; ``fig`` is modified in place.
    """
    if len(poses) == 0:
        # Nothing to draw (and the colormap helper cannot normalize an empty range).
        return

    cam_points = get_camera_wireframe(scale)
    # Convert p3d (opengl) to opencv
    cam_points[:, 1] *= -1

    if color is None:
        trace_colors = colormap_to_rgb_strings(list(range(len(poses))), colormap_name=colormap_name)
    else:
        trace_colors = [color] * len(poses)

    for i, (pose, trace_color) in enumerate(zip(poses, trace_colors)):
        # torch matmul requires matching dtypes and devices. A float64 (or GPU) pose would
        # otherwise fail against the wireframe, so bring the pose onto the wireframe's
        # dtype/device. The result is moved to CPU for plotting anyway.
        pose = torch.as_tensor(pose).to(device=cam_points.device, dtype=cam_points.dtype)
        R = pose[:3, :3]
        t = pose[:3, -1]
        cam_points_world = cam_points @ R.T + t.unsqueeze(0)
        x, y, z = [v.detach().cpu().numpy().tolist() for v in cam_points_world.unbind(1)]
        fig.add_trace(
            go.Scatter3d(
                x=x,
                y=y,
                z=z,
                mode="lines",
                marker={
                    "size": 1,
                    "color": trace_color,
                },
                line=dict(
                    width=linewidth,
                    color=trace_color,
                ),
                name=f'{name}-{i}',
            )
        )


In [None]:
from home_robot.datasets.eqa.dataset import EQADataset
# Scene to load. The root and scene id are constants so switching scenes is a one-line change.
# ScanNet alternative (previously left commented out inside the constructor call):
#   EQA_FRAMES_ROOT = '/checkpoint/maksymets/eaif/datasets/eqa-v2/frames/scannet-v0/'
#   EQA_SCENE_ID = '108-scannet-scene0354_00'
# NOTE(review): these are absolute cluster paths; consider reading them from a config/env variable.
EQA_FRAMES_ROOT = '/checkpoint/maksymets/eaif/datasets/eqa-v2/frames/hm3d-v0/'
EQA_SCENE_ID = '000-hm3d-BFRyYbPCCPE'

config_dict = dict(
    dataset_name = 'eqa',
    camera_params = dict(
        image_height=1080,
        image_width=1920,
        # Scale for depth images stored in PNG format. 6553.5 is presumably 65535 / 10
        # (uint16 depth spanning 0-10 m) - TODO confirm against EQADataset.
        # Previously tried value: 1000.0 (presumably millimetre depth).
        png_depth_scale = 6553.5,
    )
)

dataset = EQADataset(
    config_dict,
    EQA_FRAMES_ROOT,
    EQA_SCENE_ID,
    desired_height=480,
    desired_width=853,
    stride=1,
    device='cpu',
)

In [None]:
# Materialise every item of the dataset into a Python list.
# (For a quick debug subset use e.g. `[dataset[i] for i in range(3)]`; the name is kept for later cells.)
short_dataset = list(dataset)

In [None]:
# Each dataset item is a 4-tuple; transpose into per-field tuples, then stack each field
# into one tensor with a leading frame dimension.
rgb, depth, Ks, pose = zip(*short_dataset)
rgb, depth, Ks, pose = [torch.stack(v) for v in [rgb, depth, Ks, pose]]
print(len(dataset))

# Preview the first (up to) two RGB frames. Slicing avoids an IndexError on single-frame
# scenes. Dividing by 255 assumes RGB values in [0, 255] - TODO confirm dataset dtype/range.
for frame_idx, frame in enumerate(rgb[:2]):
    plt.imshow(frame.cpu() / 255.)
    plt.title(f"RGB frame {frame_idx}")
    plt.show()

In [None]:
# diag(1, -1, -1, 1) flips the y and z axes. This is presumably the OpenGL <-> OpenCV camera
# convention change. It is an involution, so its inverse must equal itself; assert that
# instead of relying on eyeballing the printed difference.
v = torch.tensor([[1., 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
torch.testing.assert_close(v.inverse(), v)
v.inverse() - v

In [None]:
from pytorch3d.structures import Pointclouds
from pytorch3d.vis.plotly_vis import AxisArgs, plot_scene, get_camera_wireframe
from home_robot.utils.bboxes_3d_plotly import plot_scene_with_bboxes

# Back-project every depth pixel of every frame into the poses' frame.
# NOTE(review): unsqueeze(1).squeeze(-1) presumably turns (B, H, W, 1) depth into
# (B, 1, H, W) - confirm against the dataset's depth shape.
unprojected = unproject_masked_depth_to_xyz_coordinates(
    depth = depth.unsqueeze(1).squeeze(-1),
    pose = pose,
    # Ks is inverted before taking the 3x3 block. This presumably assumes Ks is a 3x3 K
    # padded block-diagonally to 4x4 (otherwise the two orders differ) - TODO confirm.
    inv_intrinsics = torch.linalg.inv(Ks)[:, :3, :3],
)

SUBSAMPLE_POINTS = 100_000   # points kept in the merged cloud
PLOT_MAX_POINTS = 200_000    # plot_scene's own per-cloud cap (random sampling above it)
SUBSAMPLE_SEED = 0

# Pointclouds.subsample keeps a random subset of points. Seed both NumPy and torch (covers
# whichever RNG the installed pytorch3d uses) so Restart & Run All reproduces the same cloud.
np.random.seed(SUBSAMPLE_SEED)
torch.manual_seed(SUBSAMPLE_SEED)
ptc = Pointclouds(
    [unprojected.reshape(-1, 3)],
    features = [rgb.reshape(-1, 3) / 255.],
).subsample(SUBSAMPLE_POINTS)

fig = plot_scene(
    {"global scene": dict(ptc=ptc)},
    xaxis={"backgroundcolor": "rgb(200, 200, 230)"},
    yaxis={"backgroundcolor": "rgb(230, 200, 200)"},
    zaxis={"backgroundcolor": "rgb(200, 230, 200)"},
    axis_args=AxisArgs(showgrid=True),
    pointcloud_marker_size=3,
    pointcloud_max_points=PLOT_MAX_POINTS,
    height=1000,
)

add_camera_poses(fig, pose)
# Apply the figure height on the layout explicitly; the returned figure is the cell's displayed output.
fig.update_layout(height=1000)

In [None]:
# Display the stacked pose matrix of frame 1.
# (`dataset.rawpose[1]` was previously inspected here for comparison.)
pose[1, ...]