In [1]:
# Copyright (c) Meta Platforms, Inc. and affiliates.

## 1. Imports and Model Loading

In [None]:
import copy
import os
import uuid

import imageio
import numpy as np
import torch
from IPython.display import Image as ImageDisplay

from inference import (
    Inference,
    ready_gaussian_for_video_rendering,
    load_image,
    load_masks,
    display_image,
    make_scene,
    render_video,
    interactive_visualizer,
)


In [None]:
# Checkpoint variant; it selects the sub-folder under ../checkpoints below.
TAG = "hf"

# Every path used in this notebook is anchored at the current working directory.
PATH = os.getcwd()

# Build the pipeline from its YAML config (compile is explicitly disabled).
config_path = f"{PATH}/../checkpoints/{TAG}/pipeline.yaml"
inference = Inference(config_path, compile=False)

## 2. Load input image to lift to 3D (multiple objects)

In [None]:
# RGB image of the scene; its parent folder doubles as the scene name and as
# the folder the object masks are read from.
IMAGE_PATH = f"{PATH}/images/nocs_0003_0354/rgb.png"
IMAGE_DIR = os.path.dirname(IMAGE_PATH)
IMAGE_NAME = os.path.basename(IMAGE_DIR)

image = load_image(IMAGE_PATH)
masks = load_masks(IMAGE_DIR, extension=".png")

# Show the image together with the loaded masks.
display_image(image, masks)

## 3. Generate Pointmap from Depth Image

In [5]:
import imageio.v3 as iio

# The depth map lives next to the RGB image of the same scene, so derive its
# location from IMAGE_PATH instead of repeating the scene folder name.
depth_path = f"{os.path.dirname(IMAGE_PATH)}/depth.png"
depth = iio.imread(depth_path).astype(np.float32)
depth = depth / 1000.0  # depth is stored in millimetres -> convert to metres
depth[depth <= 0] = np.nan  # non-positive depth means "no measurement"

H, W = depth.shape

# Pinhole camera intrinsics (in pixels) of the camera that captured the scene.
K = np.array([
    [591.012500, 0.0,      322.525000],
    [0.0,        590.167750, 244.110840],
    [0.0,        0.0,        1.0]
], dtype=np.float32)

fx = K[0, 0]
fy = K[1, 1]
cx = K[0, 2]
cy = K[1, 2]

# ------------------ PIXEL GRID ------------------
# uu[r, c] = c (column index), vv[r, c] = r (row index); both have shape (H, W).
u = np.arange(W)
v = np.arange(H)
uu, vv = np.meshgrid(u, v)

# ------------------ BACK-PROJECTION ------------------
# Standard pinhole model: a pixel (u, v) with depth Z maps to
#   X = (u - cx) * Z / fx,   Y = (v - cy) * Z / fy
# in the camera frame (x -> right, y -> down, z -> forward).
# NaN depth propagates, so invalid pixels yield NaN points.
Z = depth
X = (uu - cx) * Z / fx
Y = (vv - cy) * Z / fy

# ---------------------------------------------------------------------
# Convert camera coordinates (x→right, y→down, z→forward) into PyTorch3D
# camera coordinates:
#   PyTorch3D uses a right-handed camera frame with:
#       +x → LEFT, +y → UP, +z → forward (into the scene).
#   So we negate both X and Y:
#       -Y  converts image Y-down into Y-up,
#       -X  converts X-right into X-left.
#   Negating two axes is a 180° rotation about Z, so handedness is preserved.
# ---------------------------------------------------------------------

pointmap = np.stack([-X, -Y, Z], axis=-1)  # shape (H, W, 3)

# NOTE: the capital-P name is the one the inference cell consumes; keep both
# in sync if it is ever renamed.
pointmaP = torch.tensor(pointmap, dtype=torch.float32)

## 4. Generate Gaussian Splats

In [None]:
# Run the pipeline once per object mask; every run shares the same image,
# the same fixed seed, and the same depth-derived pointmap.
outputs = [
    inference(image, mask, seed=42, pointmap=pointmaP)
    for mask in masks
]

## 5. Mesh Alignment & Coordinate Frame Conversion

In [7]:
from pytorch3d.transforms import quaternion_to_matrix, Transform3d

# All matrices below are applied to row vectors, i.e. as `points @ R`.

# Z-up → Y-up axis remap: (x, y, z) -> (-x, z, y).
# This is a proper rotation (det = +1); it is also symmetric, hence self-inverse.
R_zup_to_yup = torch.tensor(
    [
        [-1, 0, 0],
        [0, 0, 1],
        [0, 1, 0],
    ],
    dtype=torch.float32,
)

# Y-up → Z-up: the inverse of a rotation matrix is its transpose.
R_yup_to_zup = R_zup_to_yup.T

# Mirror the Z-axis: (x, y, z) -> (x, y, -z).
R_flip_z = torch.diag(torch.tensor([1, 1, -1], dtype=torch.float32))

# Undo the pointmap convention [-X, -Y, Z] to get back to the true camera
# frame: (x, y, z) -> (-x, -y, z).
R_pytorch3d_to_cam = torch.diag(torch.tensor([-1, -1, 1], dtype=torch.float32))

def transform_mesh_vertices(vertices, rotation, translation, scale):
    """
    Place mesh vertices from local object space into the frame of the predicted pose.

    Steps (all applied to row vectors, i.e. ``points @ R``):

    1. Flip the Z-axis (depending on GLB orientation)
    2. Remap axes from Y-up (GLB) to Z-up
    3. Apply the predicted pose: scale, then rotation, then translation

    The result of step 3 is returned as-is. It is NOT converted back to Y-up;
    any further frame change is the caller's responsibility.

    Args:
        vertices: mesh vertices as a NumPy array (converted to a float32
            tensor) or a torch tensor; treated as [N, 3].
        rotation: quaternion tensor, passed to ``quaternion_to_matrix``.
        translation: indexable with three components (x, y, z).
        scale: forwarded unchanged to ``Transform3d.scale``.

    Returns:
        torch.Tensor: transformed vertices with the batch dimension removed.
    """

    if isinstance(vertices, np.ndarray):
        vertices = torch.tensor(vertices, dtype=torch.float32)

    device = vertices.device
    vertices = vertices.unsqueeze(0)  # add batch dimension -> [1, N, 3]

    # Flip Z-axis
    vertices = vertices @ R_flip_z.to(device)

    # Convert mesh from Y-up (GLB) → Z-up
    vertices = vertices @ R_yup_to_zup.to(device)

    # Compose the predicted pose; Transform3d applies the operations in the
    # order they are chained: scale, then rotate, then translate.
    R_mat = quaternion_to_matrix(rotation.to(device))
    tfm = (
        Transform3d(dtype=vertices.dtype, device=device)
        .scale(scale)
        .rotate(R_mat)
        .translate(translation[0], translation[1], translation[2])
    )
    vertices_world = tfm.transform_points(vertices)

    # remove batch dimension
    return vertices_world[0]


for i, out in enumerate(outputs):
    # Work on a copy: assigning to `mesh.vertices` on the original object would
    # mutate `outputs`, so re-running this cell would transform vertices that
    # were already transformed.
    mesh = copy.deepcopy(out["glb"])

    # Predicted pose of this object, moved to CPU as float32.
    S = out["scale"][0].cpu().float()
    T = out["translation"][0].cpu().float()
    R = out["rotation"].squeeze().cpu().float()

    # Transform vertices from object space into the pointmap frame
    vertices_transformed = transform_mesh_vertices(mesh.vertices, R, T, S)

    # --- Convert vertices from pointmap frame back to true camera frame ---
    # (undoing the earlier pointmap conversion: [-X, -Y, Z])
    vertices_transformed = vertices_transformed @ R_pytorch3d_to_cam.to(vertices_transformed.device)

    # Update the vertices of the copied mesh
    mesh.vertices = vertices_transformed.cpu().numpy().astype(np.float32)

    # Export mesh (the output folder is created on demand)
    save_path = f"{PATH}/meshes/multi/{IMAGE_NAME}/object_{i}.ply"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    mesh.export(save_path)

## 6. Visualize Gaussian Splat of the Scene
### a. Animated GIF

In [None]:
# Merge all per-object outputs into one scene and prepare it for rendering.
scene_gs = make_scene(*outputs)
scene_gs = ready_gaussian_for_video_rendering(scene_gs, fix_alignment=False)

# Make sure the output folder exists before anything is written into it
# (mirrors what the mesh export cell does for its own folder).
gaussians_dir = f"{PATH}/gaussians/multi"
os.makedirs(gaussians_dir, exist_ok=True)

# export gaussian splatting (as point cloud)
scene_gs.save_ply(f"{gaussians_dir}/{IMAGE_NAME}.ply")

video = render_video(
    scene_gs,
    r=1,
    fov=60,
    resolution=512,
)["color"]

# save video as gif
imageio.mimsave(
    f"{gaussians_dir}/{IMAGE_NAME}.gif",
    video,
    format="GIF",
    # ~30 fps; NOTE(review): recent imageio releases read this as milliseconds
    # per frame, older ones as seconds — confirm against the installed version
    duration=1000 / 30,
    loop=0,  # 0 means loop indefinitely
)

# notebook display; the random query string defeats browser caching of the GIF
ImageDisplay(url=f"gaussians/multi/{IMAGE_NAME}.gif?cache_invalidator={uuid.uuid4()}",)

### b. Interactive Visualizer

In [None]:
# might take a while to load (black screen)
# Load the scene written by the cell above, which saves under gaussians/multi/.
interactive_visualizer(f"{PATH}/gaussians/multi/{IMAGE_NAME}.ply")