In [1]:
import logging
import numpy as np
from pathlib import Path
from PIL import Image
import torch
from detectron2.structures import Instances
from depr import utils
from copy import deepcopy
import trimesh
import open3d as o3d
import depth_pro

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Notebook-wide runtime setup.
# Inference only: switch autograd off globally so no cell builds a graph.
torch.set_grad_enabled(False)
# Make GPU 0 the current CUDA device for everything that follows.
torch.cuda.set_device(0)
# Only show WARNING and above from the root logger.
logging.basicConfig(level=logging.WARNING)

In [3]:
def initialize_depth_pro():
    """Load the Depth Pro model onto CUDA and return a depth-prediction closure.

    The library's default config is copied before its ``checkpoint_uri`` is
    pointed at ``checkpoint/depth_pro.pt``, and the copy is passed explicitly to
    ``create_model_and_transforms``. This keeps the module-level default object
    untouched and does not rely on the factory sharing that same object as its
    default argument.

    Returns:
        A function ``get_depth_from_depth_pro(image, f_px)`` that runs the
        model on one image and returns ``(depth, focal_length_px)`` as CPU
        tensors.
    """
    # Work on a private copy: mutating the shared default would silently change
    # the checkpoint path for every other user of the library in this kernel.
    config = deepcopy(depth_pro.depth_pro.DEFAULT_MONODEPTH_CONFIG_DICT)
    config.checkpoint_uri = "checkpoint/depth_pro.pt"
    model, transform = depth_pro.create_model_and_transforms(config=config)
    model.eval()
    model.cuda()

    def get_depth_from_depth_pro(image: np.ndarray, f_px: float) -> tuple[torch.Tensor, torch.Tensor]:
        """Predict depth and focal length (in pixels) for a single image.

        Args:
            image: Input accepted by the Depth Pro ``transform``; the notebook
                passes a PIL image.
            f_px: Focal length forwarded unchanged to ``model.infer``. The
                notebook passes ``None`` -- presumably this lets the model
                estimate it; confirm against the Depth Pro documentation.

        Returns:
            ``(depth, f_px)`` taken from the ``"depth"`` and
            ``"focallength_px"`` entries of the prediction, detached and moved
            to the CPU.
        """
        image = transform(image).cuda()
        prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]
        f_px = prediction['focallength_px']
        return depth.detach().cpu(), f_px.detach().cpu()

    return get_depth_from_depth_pro

# Build the model once and keep the returned prediction closure; later cells
# call it as get_depth_from_depth_pro(image, f_px).
get_depth_from_depth_pro = initialize_depth_pro()

In [4]:
%%capture
# %%capture silences this cell's printed output while the model is built.
# Paths are relative to the notebook's working directory.
config_path = "checkpoint/config.yaml"
ckpt_path = "checkpoint/unet.safetensors"
# `model` is the object called on `batched_inputs` further down; its scheduler,
# denormalize_triplanes and call signature are used by later cells.
# NOTE(review): presumably prepare_model loads the config and the safetensors
# weights -- confirm in depr.utils.inference.
model = utils.inference.prepare_model(config_path, ckpt_path)

In [5]:
# Number of steps handed to the model's scheduler.
# NOTE(review): presumably the number of diffusion sampling steps -- confirm.
NUM_INFERENCE_STEPS = 50
model.scheduler.set_timesteps(NUM_INFERENCE_STEPS)

# Device that the batched model inputs are moved to.
device = torch.device("cuda")
# Directory passed to read_sam_segmentation as `sam_path`.
sam_path = "demo/segm"

In [6]:
# --- Load one demo image, predict depth, and build pinhole intrinsics ---------
image_id = "1"
image_path = Path("demo/imgs") / f"{image_id}.jpg"
print(f"Processing {image_id}...")
output_path = Path("output/demo/output") / image_id
output_path.mkdir(parents=True, exist_ok=True)

# Force 3-channel RGB so the (H, W, 3) -> (3, H, W) permute below also works
# for grayscale / RGBA / palette files; this is a no-op for ordinary RGB JPEGs.
# The PIL image gets its own name so `image` is only ever the tensor.
image_pil = Image.open(image_path).convert("RGB")
depth, f_px = get_depth_from_depth_pro(image_pil, None)

image = torch.as_tensor(np.array(image_pil)).permute(2, 0, 1)  # 3, H, W

# Principal point at the image centre, derived from the actual image size
# instead of being hard-coded (323, 242 for the 646x484 demo image).
height, width = image.shape[-2:]
# f_px comes back as a tensor; take the plain Python float before building K.
focal = float(f_px)
K = torch.tensor(
    [
        [focal, 0.0, width / 2],
        [0.0, focal, height / 2],
        [0.0, 0.0, 1.0],
    ],
    dtype=torch.float32,
)
intrinsics = K
print(f"f_px={f_px}")

# Read the precomputed segmentation for this image from `sam_path`.
masks, inst_ids, labels = utils.data.read_sam_segmentation(
    image_id, sam_path=sam_path, return_labels=True
)

# Wrap the masks in a detectron2 Instances container sized to the image (H, W).
img_h, img_w = image.shape[-2], image.shape[-1]
instances = Instances((img_h, img_w), gt_masks=masks)

# Single-element batch: every entry is moved to the inference device.
scene_inputs = dict(
    image=image.to(device),
    depth=depth.to(device),
    intrinsics=intrinsics.to(device),
    instances=instances.to(device),
)
batched_inputs = [scene_inputs]
print("Number of instances:", len(instances.gt_masks))


# Point clouds derived from the depth map, the masks and the intrinsics;
# `target_list` is later used as the registration target.
# NOTE(review): presumably one cloud of 5000 sampled points per mask -- confirm
# in depr.utils.geom.
target_list = utils.geom.prepare_depth_point_cloud(
    depth, masks, intrinsics, num_sample_points=5000, clean_up_with_cluster=False
)

# Stack all clouds along a new leading axis and write them out as one PLY file.
scene_depth_points = np.stack(target_list, axis=0)
scene_depth_file = output_path / "scene_depth.ply"
utils.geom.save_colored_pointcloud(scene_depth_points, scene_depth_file)

# Run the model on the one-image batch and keep the result for that image.
batch_outputs = model(
    batched_inputs,
    return_step_results=True,
    guidance_scale=None,
    return_first_step=False,
)
output = batch_outputs[0]

# Undo the model's triplane normalisation before decoding geometry.
triplanes = model.denormalize_triplanes(output["triplane"])
valid_indices = output["gt_indices"]

# Decode the triplanes into point clouds (`source_list`, later the registration
# source) and meshes (`source_meshes`), requesting 5000 points per object with
# the "poisson" sampling method.
source_list, source_meshes = utils.geom.triplanes_to_point_clouds(
    model, triplanes, num_points=5000, sample_method="poisson", return_meshes=True
)

# Write each decoded object mesh to obj_<idx>.ply in the output directory.
for idx, mesh in enumerate(source_meshes):
    # The scene-assembly cell treats entries of `source_meshes` as possibly
    # None; skip those here too instead of handing None to the Open3D writer.
    # The file index stays aligned with the instance index.
    if mesh is None:
        continue
    o3d.io.write_triangle_mesh(
        (output_path / f"obj_{idx}.ply").as_posix(),
        mesh,
    )

Processing 1...
f_px=455.90167236328125
Number of instances: 4


In [7]:
# Options for aligning the decoded object clouds (`source_list`) to the
# depth-derived clouds (`target_list`).
# NOTE(review): in the recorded run the reported total loss is 100 x the 3D loss
# up to step 300 and includes the 2D loss from step 300 on, which matches
# loss_3d_weight / loss_2d_weight / cutoff_steps -- confirm the exact semantics
# in depr.utils.sample.
registration_options = dict(
    num_repeats=30,
    num_steps=500,
    lr=5e-2,
    cutoff_steps=300,
    verbose=True,
    return_losses=False,
    loss_3d_weight=100,
    loss_2d_weight=1,
    dof=5,
    enable_global_rotation=True,
)
transform_matrices = utils.sample.get_scene_transformations(
    source_list, target_list, intrinsics, **registration_options
)

# Persist the per-object transforms together with the camera intrinsics.
info_file = output_path / "info.npz"
np.savez(info_file, transform_matrices=transform_matrices, intrinsics=K.numpy())

# Place every object mesh into the scene with its estimated transform.
pred_meshes = []
for idx, mesh in enumerate(source_meshes):
    t = transform_matrices[idx]
    # Keep only objects that have both a mesh and a transform.
    if t is not None and mesh is not None:
        # Transform a copy so `source_meshes` stays untouched and this cell can
        # be re-run safely.
        placed = deepcopy(mesh)
        pred_meshes.append(placed.transform(t))

scene_file = output_path / "scene.glb"
utils.geom.save_scene_meshes(pred_meshes, scene_file)

# Reload the exported scene with trimesh and negate its Y and Z axes.
# NOTE(review): presumably converts between camera and viewer axis conventions
# -- confirm.
scene = trimesh.load(scene_file)
flip_yz = np.diag([1, -1, -1, 1])
scene.apply_transform(flip_yz)

Beginning scene registration...


  0%|          | 2/500 [00:00<00:26, 19.07it/s]

Step 0: 3D Loss: 0.06291420757770538, 2D Loss: 2764.984375, Total Loss: 6.291420936584473, Best Loss: 0.23658880591392517


 11%|█         | 55/500 [00:01<00:14, 30.55it/s]

Step 50: 3D Loss: 0.022647922858595848, 2D Loss: 758.2872924804688, Total Loss: 2.2647922039031982, Best Loss: 0.0732509046792984


 21%|██▏       | 107/500 [00:03<00:12, 30.69it/s]

Step 100: 3D Loss: 0.01798519305884838, 2D Loss: 571.7958374023438, Total Loss: 1.798519492149353, Best Loss: 0.060663528740406036


 31%|███       | 155/500 [00:05<00:11, 30.80it/s]

Step 150: 3D Loss: 0.01701332814991474, 2D Loss: 530.4622802734375, Total Loss: 1.701332926750183, Best Loss: 0.057855088263750076


 41%|████▏     | 207/500 [00:06<00:09, 30.90it/s]

Step 200: 3D Loss: 0.016743382439017296, 2D Loss: 522.1926879882812, Total Loss: 1.6743381023406982, Best Loss: 0.057142473757267


 51%|█████     | 255/500 [00:08<00:07, 30.73it/s]

Step 250: 3D Loss: 0.016668066382408142, 2D Loss: 519.1973876953125, Total Loss: 1.6668065786361694, Best Loss: 0.057099416851997375


 61%|██████    | 303/500 [00:09<00:06, 30.52it/s]

Step 300: 3D Loss: 0.016552280634641647, 2D Loss: 514.0314331054688, Total Loss: 515.6866455078125, Best Loss: 0.057105615735054016


 71%|███████   | 355/500 [00:11<00:04, 30.17it/s]

Step 350: 3D Loss: 0.01841696910560131, 2D Loss: 466.5566711425781, Total Loss: 468.39837646484375, Best Loss: 0.06080016493797302


 81%|████████  | 403/500 [00:13<00:03, 30.41it/s]

Step 400: 3D Loss: 0.01809140481054783, 2D Loss: 440.80645751953125, Total Loss: 442.6155700683594, Best Loss: 0.06127791479229927


 91%|█████████ | 455/500 [00:14<00:01, 30.23it/s]

Step 450: 3D Loss: 0.018200859427452087, 2D Loss: 437.137451171875, Total Loss: 438.9575500488281, Best Loss: 0.061569929122924805


100%|██████████| 500/500 [00:16<00:00, 30.37it/s]


Scene registration completed.


<trimesh.Scene(len(geometry)=4)>

In [8]:
# Display the assembled, axis-flipped scene; as the cell's last expression,
# the viewer returned by trimesh is rendered as the cell output.
scene.show()