In [9]:
import torch
from torchvision.transforms import ToTensor
import open3d as o3d

from utils.dataset import NYUv2
from utils.dataloader import DepthToTensor
from utils.dataset import get_segmentation_colors


# RGB Parameters
fx, fy = 5.1885790117450188e+02, 5.1946961112127485e+02
cx, cy = 3.2558244941119034e+02, 2.5373616633400465e+02


In [10]:
def depth_to_point_cloud(depth, K):
    """
    depth: Tensor of shape (H, W) - values in meters
    K: camera params
    Retorna: Tensor of shape (N, 3) with 3D points
    """
    device = depth.device
    H, W = depth.shape

    y, x = torch.meshgrid(torch.arange(H, device=device), torch.arange(W, device=device), indexing='ij')

    z = depth
    x = (x - K["cx"]) * z / K["fx"]
    y = (y - K["cy"]) * z / K["fy"]
    xyz = torch.stack((x, y, z), dim=-1)  # (H, W, 3)
    xyz = xyz.view(-1, 3)  # (N, 3)
    valid = (depth > 0).view(-1)  # Mask for valid points
    return xyz[valid], valid


In [11]:
dataset = NYUv2(
    root="data/nyuv2",
    train=False,
    rgb_transform=ToTensor(),
    seg_transform=ToTensor(),
    depth_transform=DepthToTensor(),
)

rgb, seg, depth = dataset[0]  # depth in meters
print(f"RGB shape: {rgb.shape}, Segmentation shape: {seg.shape}, Depth shape: {depth.shape}")
depth = depth.squeeze(0)  # (H, W, 1) -> (H, W)

K = {"fx": fx, "fy": fy, "cx": cx, "cy": cy}

RGB shape: torch.Size([3, 480, 640]), Segmentation shape: torch.Size([480, 640]), Depth shape: torch.Size([1, 480, 640])


In [12]:
point_cloud, valid = depth_to_point_cloud(depth, K)  # (N, 3)

print("Point cloud:", point_cloud.shape)


Point cloud: torch.Size([307200, 3])


In [13]:
def visualize_point_cloud(xyz, rgb=None):
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(xyz.cpu().numpy())
    if rgb is not None:
        pcd.colors = o3d.utility.Vector3dVector(rgb.cpu().numpy())
    o3d.visualization.draw_geometries([pcd])

In [14]:
valid_rgb = rgb.permute(1, 2, 0).reshape(-1, 3)[valid]  # (N, 3)
visualize_point_cloud(point_cloud, valid_rgb)

In [None]:
colored_seg = get_segmentation_colors(seg)

IndexError: too many indices for array: array is 3-dimensional, but 4 were indexed

In [None]:
valid_seg = colored_seg.permute(1, 2, 0).reshape(-1, 3)[(depth > 0).view(-1)]  # (N, 3)
print("Valid segmentation shape:", valid_seg.shape)
visualize_point_cloud(point_cloud, valid_seg)

Valid segmentation shape: torch.Size([253016, 3])


In [None]:
import torch
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.is_available())

2.6.0
12.6
True
