In [None]:
import matplotlib.pyplot as pl
import numpy as np
import torch
from scipy.spatial.transform import Rotation
import plotly.graph_objects as go
import plotly.subplots as sp
import trimesh
import hydra
from omegaconf import OmegaConf
import os
import time
import copy

import rootutils
rootutils.setup_root("/opt/hpcaas/.mounts/fs-0565f60d669b6a2d3/home/jianingy/research/accel-cortex/dust3r/fast3r/src", indicator=".project-root", pythonpath=True)

from src.dust3r.inference_multiview import inference
from src.dust3r.model import FlashDUSt3R
from src.dust3r.utils.image import load_images, rgb
from src.dust3r.viz import CAM_COLORS, OPENGL, add_scene_cam, cat_meshes, pts3d_to_trimesh


pl.ion()


def get_reconstructed_scene(
    outdir,
    model,
    device,
    silent,
    image_size,
    filelist,
    dtype=torch.float32,
):
    """Load the images in ``filelist`` and run multi-view inference on them.

    No alignment or 3D-model export happens here; the raw result of
    ``inference`` is returned as-is.

    Args:
        outdir: Not used by this function; kept so existing callers that
            pass it keep working.
        model: Network handed to ``inference``.
        device: Device handed to ``inference``.
        silent: When true, ``load_images`` and ``inference`` are called with
            ``verbose=False``.
        image_size: Forwarded to ``load_images`` as ``size``.
        filelist: Image paths forwarded to ``load_images``.
        dtype: Forwarded to ``inference``.

    Returns:
        Whatever ``inference`` returns. The rest of this notebook reads its
        ``'views'`` and ``'preds'`` entries.
    """
    verbose = not silent
    multiple_views_in_one_sample = load_images(filelist, size=image_size, verbose=verbose)

    # perf_counter is monotonic and high resolution, so the measurement cannot
    # be skewed by system clock adjustments the way time.time() can.
    # NOTE(review): if ``inference`` returns before queued CUDA work finishes,
    # this under-reports GPU time — confirm whether a synchronize is needed.
    start = time.perf_counter()
    output = inference(multiple_views_in_one_sample, model, device, dtype=dtype, verbose=verbose)
    elapsed = time.perf_counter() - start
    print(f"Time elapsed: {elapsed}")

    return output



def plot_rgb_images(views, title="RGB Images"):
    """Show the RGB image of every view side by side in one Plotly row.

    Each ``view['img']`` is moved to the CPU, squeezed, transposed from
    channel-first to channel-last and mapped from [-1, 1] to [0, 255].

    Args:
        views: Sequence of mappings that each hold an ``'img'`` tensor.
        title: Figure title.
    """
    n_views = len(views)
    panel_titles = [f"View {idx} Image" for idx in range(n_views)]
    fig = sp.make_subplots(rows=1, cols=n_views, subplot_titles=panel_titles)

    for column, view in enumerate(views, start=1):
        # presumably (1, 3, H, W) -> (H, W, 3); confirm against load_images
        channels_first = view['img'].cpu().numpy().squeeze()
        channels_last = channels_first.transpose(1, 2, 0)
        pixels = ((channels_last + 1) * 127.5).astype(int).clip(0, 255)
        fig.add_trace(go.Image(z=pixels), row=1, col=column)

    fig.update_layout(title=title, margin=dict(l=0, r=0, b=0, t=40))
    fig.show()

def plot_confidence_maps(preds, title="Confidence Maps"):
    """Show the confidence map of every prediction as a heatmap, one per column.

    Args:
        preds: Sequence of mappings that each hold a ``'conf'`` tensor.
        title: Figure title.
    """
    n_preds = len(preds)
    panel_titles = [f"View {idx} Confidence" for idx in range(n_preds)]
    fig = sp.make_subplots(rows=1, cols=n_preds, subplot_titles=panel_titles)

    for column, pred in enumerate(preds, start=1):
        confidence = pred['conf'].cpu().numpy().squeeze()
        heatmap = go.Heatmap(z=confidence, colorscale='Viridis', showscale=False)
        fig.add_trace(heatmap, row=1, col=column)

    fig.update_layout(title=title, margin=dict(l=0, r=0, b=0, t=40))

    # Reverse every y axis so that row 0 of the map is drawn at the top.
    for axis_number in range(1, n_preds + 1):
        fig['layout'][f'yaxis{axis_number}'].update(autorange='reversed')

    fig.show()

def _rescale_rgb_to_uint8(img_rgb):
    """Map colours from [-1, 1] to uint8 [0, 255], clipping before the cast so nothing wraps."""
    return ((img_rgb + 1) * 127.5).clip(0, 255).astype(np.uint8)


def _view_points_colors_conf(pred, view):
    """Return ``(pts3d, uint8 colours, conf)`` as numpy arrays for one view."""
    pts3d = pred['pts3d_in_other_view'].cpu().numpy().squeeze()
    # presumably (1, 3, H, W) -> (H, W, 3); confirm against load_images
    img_rgb = view['img'].cpu().numpy().squeeze().transpose(1, 2, 0)
    conf = pred['conf'].cpu().numpy().squeeze()
    return pts3d, _rescale_rgb_to_uint8(img_rgb), conf


def plot_3d_points_with_colors(preds, views, title="3D Points Visualization", flip_axes=False, as_mesh=False, min_conf_thr_percentile=80, export_ply_path=None):
    """Plot the predicted 3D points of all views, coloured by their pixels.

    Args:
        preds: Sequence of mappings holding ``'pts3d_in_other_view'`` and
            ``'conf'`` tensors.
        views: Sequence of mappings holding an ``'img'`` tensor; ``views[i]``
            supplies the colours for ``preds[i]``.
        title: Figure title.
        flip_axes: When true, map ``(x, y, z)`` to ``(x, z, -y)`` in the plot
            and in the exported file.
        as_mesh: When true, build one combined mesh with ``pts3d_to_trimesh``
            and ``cat_meshes``; otherwise draw one scatter trace per view.
        min_conf_thr_percentile: Per view, only points whose confidence is
            strictly above this percentile of that view's map are kept.
        export_ply_path: If truthy, the mesh or point cloud is also exported
            to this path.
    """
    fig = go.Figure()

    if as_mesh:
        meshes = []
        for i, pred in enumerate(preds):
            pts3d, img_rgb, conf = _view_points_colors_conf(pred, views[i])
            mask = conf > np.percentile(conf, min_conf_thr_percentile)
            meshes.append(pts3d_to_trimesh(img_rgb, pts3d, valid=mask))

        combined_mesh = trimesh.Trimesh(**cat_meshes(meshes))

        if flip_axes:
            # Swap Y and Z, then negate the new Z.
            combined_mesh.vertices[:, [1, 2]] = combined_mesh.vertices[:, [2, 1]]
            combined_mesh.vertices[:, 2] = -combined_mesh.vertices[:, 2]

        if export_ply_path:
            combined_mesh.export(export_ply_path)

        # Colour each face with the mean RGB of its three vertices.
        vertex_colors = np.asarray(combined_mesh.visual.vertex_colors)[:, :3]
        face_rgb = vertex_colors[combined_mesh.faces].mean(axis=1).astype(int)
        face_colors = [f'rgb({r}, {g}, {b})' for r, g, b in face_rgb]

        fig.add_trace(go.Mesh3d(
            x=combined_mesh.vertices[:, 0],
            y=combined_mesh.vertices[:, 1],
            z=combined_mesh.vertices[:, 2],
            i=combined_mesh.faces[:, 0],
            j=combined_mesh.faces[:, 1],
            k=combined_mesh.faces[:, 2],
            facecolor=face_colors,
            opacity=0.5,
            name="Combined Mesh"
        ))
    else:
        all_points = []
        all_colors = []
        for i, pred in enumerate(preds):
            pts3d, img_rgb, conf = _view_points_colors_conf(pred, views[i])
            mask = conf > np.percentile(conf, min_conf_thr_percentile)

            points = pts3d[mask]   # (N, 3)
            colors = img_rgb[mask]  # (N, 3) uint8

            if flip_axes:
                points = points[:, [0, 2, 1]]  # fancy indexing copies, so preds stay untouched
                points[:, 2] *= -1

            # Collected after rescaling and flipping so the exported file
            # matches what is plotted (and what the mesh branch exports).
            all_points.append(points)
            all_colors.append(colors)

            color_strings = [f'rgb({r}, {g}, {b})' for r, g, b in colors]

            fig.add_trace(go.Scatter3d(
                x=points[:, 0], y=points[:, 1], z=points[:, 2],
                mode='markers',
                marker=dict(size=2, opacity=0.8, color=color_strings),
                name=f"View {i}"
            ))

        if export_ply_path and all_points:
            point_cloud = trimesh.PointCloud(vertices=np.vstack(all_points), colors=np.vstack(all_colors))
            point_cloud.export(export_ply_path)

    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z'
        ),
        margin=dict(l=0, r=0, b=0, t=40),
        height=1000
    )

    fig.show()


import numpy as np
import torch
import plotly.graph_objects as go
from src.dust3r.cloud_opt.init_im_poses import fast_pnp
from src.dust3r.viz import auto_cam_size
from src.dust3r.viz_plotly import SceneViz
from src.dust3r.utils.image import rgb  # Assuming you have this utility for image processing

# Function to estimate camera poses using fast_pnp
def estimate_camera_poses(preds, views, niter_PnP=10):
    """Estimate a camera-to-world pose and a focal length per view with ``fast_pnp``.

    Args:
        preds: Sequence of mappings holding ``'pts3d_in_other_view'`` and
            ``'conf'`` tensors.
        views: Not read by this function; kept so existing callers keep working.
        niter_PnP: Forwarded to ``fast_pnp``.

    Returns:
        ``(poses_c2w, estimated_focals)``: two parallel lists. Views whose
        estimation fails are reported on stdout and left out, so the lists can
        be shorter than ``preds`` and entry ``k`` need not belong to view ``k``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    poses_c2w = []  # List of camera-to-world poses
    estimated_focals = []  # List of estimated focal lengths

    for view_idx, pred in enumerate(preds):
        pts3d = pred["pts3d_in_other_view"].cpu().numpy().squeeze()
        valid_mask = pred["conf"].cpu().numpy().squeeze() > 0.5  # Confidence mask

        result = fast_pnp(
            torch.tensor(pts3d, device=device),  # Unmasked points
            None,  # No focal length supplied
            torch.tensor(valid_mask, device=device, dtype=torch.bool),  # Valid mask (unflattened)
            device,
            pp=None,  # No principal point supplied
            niter_PnP=niter_PnP
        )

        # NOTE(review): upstream DUSt3R's fast_pnp reports failure by returning
        # a bare None rather than a (focal, pose) pair — confirm for this fork.
        # Both forms are treated as a failure here.
        pose_c2w = None if result is None else result[1]
        if pose_c2w is None:
            print(f"Failed to estimate pose for view {view_idx}")
            continue

        poses_c2w.append(pose_c2w.cpu().numpy())
        estimated_focals.append(result[0])

    return poses_c2w, estimated_focals

# Function to visualize 3D points and camera poses with SceneViz
def plot_3d_points_with_estimated_camera_poses(preds, views, title="3D Points and Camera Poses", flip_axes=False, min_conf_thr_percentile=80, export_ply_path=None):
    """Show every view's point cloud and its estimated camera in a ``SceneViz``.

    Args:
        preds: Sequence of mappings holding ``'pts3d_in_other_view'`` and
            ``'conf'`` tensors.
        views: Sequence of mappings holding an ``'img'`` tensor; ``views[i]``
            belongs to ``preds[i]``.
        title: Currently not used by this function.
        flip_axes: When true, work on a copy of ``preds`` whose points are
            mapped from ``(x, y, z)`` to ``(x, z, -y)`` before pose estimation.
        min_conf_thr_percentile: Per view, only points whose confidence is
            strictly above this percentile of that view's map are shown and
            exported.
        export_ply_path: If truthy, the filtered points of all views are
            exported to this path as one point cloud with 0-255 colours.
    """
    viz = SceneViz()

    if flip_axes:
        preds = copy.deepcopy(preds)  # never modify the caller's predictions
        for pred in preds:
            pts3d = pred['pts3d_in_other_view'].cpu().numpy().squeeze()
            pts3d = pts3d[..., [0, 2, 1]]  # Swap Y and Z axes
            pts3d[..., 2] *= -1  # Flip the sign of the Z axis
            pred['pts3d_in_other_view'] = torch.tensor(pts3d)

    # Estimate one view at a time: estimate_camera_poses leaves failed views
    # out of its result lists, so a single batched call cannot tell which pose
    # belongs to which view once any view fails.
    pose_by_view = {}
    focal_by_view = {}
    for i in range(len(preds)):
        poses, focals = estimate_camera_poses([preds[i]], [views[i]], niter_PnP=10)
        if poses:
            pose_by_view[i] = poses[0]
            focal_by_view[i] = focals[0]

    poses_c2w = list(pose_by_view.values())
    # With no pose at all, no camera is drawn and the size is irrelevant.
    cam_size = max(auto_cam_size(poses_c2w), 0.05) if poses_c2w else 0.05

    all_points = []
    all_colors = []
    for i, pred in enumerate(preds):
        pts3d = pred['pts3d_in_other_view'].cpu().numpy().squeeze()
        # presumably (1, 3, H, W) -> (H, W, 3); confirm against load_images
        img_raw = views[i]['img'].cpu().numpy().squeeze().transpose(1, 2, 0)
        img_rgb = rgb(img_raw)
        conf = pred['conf'].cpu().numpy().squeeze()

        mask = conf > np.percentile(conf, min_conf_thr_percentile)

        # The point cloud is shown even if this view's pose could not be estimated.
        viz.add_pointcloud(pts3d, img_rgb, mask=mask, point_size=2, view_idx=i)

        if i in pose_by_view:
            viz.add_camera(
                pose_c2w=pose_by_view[i],  # Estimated camera-to-world pose
                focal=focal_by_view[i],  # Estimated focal length of this view
                color=np.random.randint(0, 256, size=3),  # Random RGB colour per camera
                image=img_rgb,  # Image of the view
                cam_size=cam_size,
                view_idx=i
            )

        if export_ply_path:
            all_points.append(pts3d[mask])
            # Export 0-255 colours rather than the raw [-1, 1] image values.
            export_rgb = ((img_raw + 1) * 127.5).clip(0, 255).astype(np.uint8)
            all_colors.append(export_rgb[mask])

    if export_ply_path and all_points:
        point_cloud = trimesh.PointCloud(vertices=np.vstack(all_points), colors=np.vstack(all_colors))
        point_cloud.export(export_ply_path)

    viz.show()


In [None]:
import matplotlib.pyplot as plt
from PIL import Image

# Root directory of the datasets referenced by the file lists in this cell.
data_root = "/opt/hpcaas/.mounts/fs-0565f60d669b6a2d3/home/jianingy/research/accel-cortex/dust3r/data"

# Frames 1 and 2 of the apple sequence.
filelist_train = [
    f"{data_root}/co3d_subset_processed/apple/189_20393_38136/images/frame{frame:06d}.jpg"
    for frame in (1, 2)
]

# apple (this list is replaced by a later ``filelist_test`` assignment in this cell)
filelist_test = [
    f"{data_root}/co3d_subset_processed/apple/189_20393_38136/images/frame{frame:06d}.jpg"
    for frame in (200, 85, 90, 170, 199)
]


# bench test
# filelist_test = [
#     f"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000006.jpg",
#     f"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000016.jpg",
#     f"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000026.jpg",
#     f"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000036.jpg",
#     f"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000096.jpg",
#     f"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000126.jpg",
#     f"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000156.jpg",
#     f"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000186.jpg",
# ]

# # teddy bear train
# filelist_test = [
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000001.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000002.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000003.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000004.jpg",
#     # f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000012.jpg",
#     # f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000022.jpg",
#     # f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000032.jpg",
# ]
# teddy bear test
# filelist_test = [
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000006.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000016.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000026.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000036.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000096.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000126.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000156.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000186.jpg",
# ]

# teddy bear random order
# filelist_test = [
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000126.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000036.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000096.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000006.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000026.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000186.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000016.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000156.jpg",
# ]


# filelist_test = [
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000006.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000036.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000066.jpg",
#     f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000096.jpg",
# ]

# suitcase test
# filelist_test = [
#     f"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000006.jpg",
#     f"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000016.jpg",
#     f"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000026.jpg",
#     f"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000036.jpg",
#     f"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000096.jpg",
#     f"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000126.jpg",
#     f"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000156.jpg",
#     f"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000186.jpg",
# ]

# cake test
# filelist_test = [
#     f"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000006.jpg",
#     f"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000016.jpg",
#     f"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000026.jpg",
#     f"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000036.jpg",
#     f"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000096.jpg",
#     f"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000126.jpg",
#     f"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000156.jpg",
#     f"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000186.jpg",
# ]

# in-the-wild obj: book
# filelist_test = [
#     f"{data_root}/unseen_book/IMG_9837.jpg",
#     f"{data_root}/unseen_book/IMG_9838.jpg",
#     f"{data_root}/unseen_book/IMG_9839.jpg",
#     f"{data_root}/unseen_book/IMG_9840.jpg",
#     f"{data_root}/unseen_book/IMG_9841.jpg",
#     f"{data_root}/unseen_book/IMG_9842.jpg",
#     f"{data_root}/unseen_book/IMG_9843.jpg",
#     f"{data_root}/unseen_book/IMG_9844.jpg",
# ]

# in-the-wild obj: beef jerky
# filelist_test = [
#     f"{data_root}/beef_jerky/IMG_0050.jpg",
#     f"{data_root}/beef_jerky/IMG_0051.jpg",
#     f"{data_root}/beef_jerky/IMG_0052.jpg",
#     f"{data_root}/beef_jerky/IMG_0053.jpg",
#     f"{data_root}/beef_jerky/IMG_0054.jpg",
#     f"{data_root}/beef_jerky/IMG_0055.jpg",
#     f"{data_root}/beef_jerky/IMG_0056.jpg",
#     f"{data_root}/beef_jerky/IMG_0057.jpg",
#     f"{data_root}/beef_jerky/IMG_0058.jpg",
# ]

# HSSD
# filelist_test = [
#     f"{data_root}/0_102344022_0/rgb/0000{i:02d}.png" for i in range(8)
# ]

# Frames 2-7 of scene 17_102344250_4 (presumably HSSD, like the commented-out
# list above). This is the last uncommented assignment, so it is the list the
# rest of the cell uses.
filelist_test = [f"{data_root}/17_102344250_4/rgb/0000{frame_idx:02d}.png" for frame_idx in range(2, 8)]

# unseen obj: teddy bear from co3d
# filelist_test = [
#     "/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000006.jpg",
#     "/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000036.jpg",
#     "/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000056.jpg",
#     "/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000086.jpg",
#     "/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000096.jpg",
#     "/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000126.jpg",
#     "/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000156.jpg",
#     "/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000186.jpg",
# ]


# unseen obj: keyboard from co3d
# filelist_test = [
#     "/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000096.jpg",
#     "/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000126.jpg",
#     "/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000156.jpg",
#     # "/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000186.jpg",
#     # "/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000006.jpg",
#     # "/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000016.jpg",
#     # "/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000026.jpg",
#     "/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000036.jpg",
# ]

# filelist_test = [
#     "/fsx-cortex/jianingy/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000006.jpg",
#     "/fsx-cortex/jianingy/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000016.jpg",
#     "/fsx-cortex/jianingy/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000026.jpg",
#     "/fsx-cortex/jianingy/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000036.jpg",
#     "/fsx-cortex/jianingy/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000096.jpg",
#     "/fsx-cortex/jianingy/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000126.jpg",
#     "/fsx-cortex/jianingy/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000156.jpg",
#     "/fsx-cortex/jianingy/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000186.jpg",
# ]

# DTU
# filelist_test = [
#     "/fsx-cortex/jianingy/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_001_max.png",
#     "/fsx-cortex/jianingy/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_002_max.png",
#     "/fsx-cortex/jianingy/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_003_max.png",
#     "/fsx-cortex/jianingy/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_004_max.png",
#     "/fsx-cortex/jianingy/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_005_max.png",
#     "/fsx-cortex/jianingy/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_006_max.png",
#     "/fsx-cortex/jianingy/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_007_max.png",
#     "/fsx-cortex/jianingy/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_008_max.png",
# ]

# display the images
def display_images(filelist, title):
    """Show the images in ``filelist`` side by side in one matplotlib row.

    Args:
        filelist: Non-empty sequence of image paths.
        title: Figure title.

    Raises:
        ValueError: If ``filelist`` is empty.
    """
    if not filelist:
        raise ValueError("display_images needs at least one image path")

    # squeeze=False always returns a 2-D array of axes, so a single image
    # needs no special handling.
    fig, axes = plt.subplots(1, len(filelist), figsize=(30, 4), squeeze=False)
    fig.suptitle(title)
    for ax, filepath in zip(axes[0], filelist):
        # Close the file handle as soon as matplotlib has taken the pixels.
        with Image.open(filepath) as img:
            ax.imshow(img)
        ax.axis('off')
    plt.show()

# # Display train images
# display_images(filelist_train, 'Train Images')

# Display test images
display_images(filelist_test, 'Test Images')

In [None]:
# Device that hosts the model; it is also passed to get_reconstructed_scene later.
device = torch.device("cuda")

# Directory that holds the FlashDUSt3R ``.pth`` checkpoints referenced in this cell.
checkpoint_root = "/opt/hpcaas/.mounts/fs-0565f60d669b6a2d3/home/jianingy/research/accel-cortex/dust3r/checkpoints"

# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_demo_224_longer_epochs/checkpoint-best.pth').to(device)
# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_demo_224_multiview/checkpoint-best.pth').to(device)
# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_demo_224_multiview_co3d_full/checkpoint-last.pth').to(device)
# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_demo_224_multiview_co3d_full_100_epochs_100_samples_per_window/checkpoint-last.pth').to(device)
# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset/checkpoint-best.pth').to(device)
# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset_100_epochs_unfreeze_dec_and_head/checkpoint-last.pth').to(device)
# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset_100_epochs_unfreeze_everything/checkpoint-last.pth').to(device)
# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset_100_epochs_unfreeze_everything_large/checkpoint-last.pth').to(device)
# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset_100_epochs_unfreeze_everything_co3d_scannetpp_megadepth/checkpoint-last.pth').to(device)
# Load the selected pretrained checkpoint and move the model to ``device``.
# NOTE: the Lightning cell below assigns ``model`` again, so whichever of the
# two cells runs last decides which weights are used for inference.
model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset_100_epochs_unfreeze_everything_co3d_scannetpp_megadepth_large/checkpoint-last.pth').to(device)
# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset_100_epochs_unfreeze_everything_co3d_scannetpp_megadepth_large/checkpoint-last.pth').to(device)
# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/bf16_flash_attn_unfreeze_everything_co3d_scannetpp_megadepth_large_bs4/checkpoint-10.pth').to(device)
# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/bf16_flash_attn_unfreeze_everything_co3d_scannetpp_megadepth_large/checkpoint-last.pth').to(device)

In [None]:
# Lightning model
# instantiate lit module from hydra yaml /opt/hpcaas/.mounts/fs-0565f60d669b6a2d3/home/jianingy/research/accel-cortex/dust3r/fast3r/configs/model/multiview_dust3r.yaml
# Device that hosts the Lightning-trained network.
device = torch.device("cuda")
# NOTE: this rebinds ``checkpoint_root`` to a different, shorter path than the
# one assigned in the previous cell.
checkpoint_root = "/opt/hpcaas/.mounts/fs-0565f60d669b6a2d3/home/jianingy/research/accel-cortex"

# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/train/runs/2024-08-13_04-40-37"  #fp32-fancy-sun-181
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/train/runs/2024-08-13_08-06-08"  #fp32_workers11_giddy-gorge-182
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4/runs/fp32_bs6_views4_3782640"
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs2_views8/runs/fp32_bs2_views8_3782638"
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4/runs/fp32_bs6_views4_4007485"  # with random image idx embeddings
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4/runs/fp32_bs6_views4_4030983"  # fix Regr3D loss (wrong rotation)
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4/runs/fp32_bs6_views4_4037511"  # fix Regr3D loss (fixed rotation)
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4_scannetpp_only/runs/fp32_bs6_views4_scannetpp_only_4060428"  # ScanNet++ only no random emb
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4_scannetpp_only/runs/fp32_bs6_views4_scannetpp_only_4051504"  # ScanNet++ only
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4_arkitscenes_only/runs/arkitscenes_only_4123064"  # ARKitScenes only
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/arkitscenes_only_no_pairs/runs/arkitscenes_only_no_pairs_4129400"  # ARKitScenes only no pairs
# Run directory to load from; this cell reads ``.hydra/config.yaml`` and
# ``checkpoints/last.ckpt`` inside it.
checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/co3d_scannetpp/runs/co3d_scannetpp_4123062"  # co3d_scannetpp
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/co3d_scannetpp_arkitscenes/runs/co3d_scannetpp_arkitscenes_4123063"  # co3d_scannetpp_arkitscenes

print("Creating an empty lightning module to hold the weights...")
# Config file stored in the run directory.
cfg = OmegaConf.load(os.path.join(checkpoint_dir, '.hydra/config.yaml'))

# Replace all occurrences of "dust3r." in cfg.model.net with "src.dust3r." (needed because our code was relocated).
def replace_dust3r_in_config(cfg):
    """Rewrite ``"dust3r."`` to ``"src.dust3r."`` in every string of ``cfg``, in place.

    Nested mappings and mutable sequences are searched recursively. The check
    is against the abstract container types rather than ``dict`` so that
    mapping classes that do not subclass ``dict`` are descended into as well.
    Strings that already contain ``"src.dust3r."`` are left untouched, which
    makes repeated calls harmless.

    Args:
        cfg: Mapping or mutable sequence to update. Any other value is
            returned unchanged.

    Returns:
        The same ``cfg`` object.
    """
    from collections.abc import Mapping, MutableSequence

    if isinstance(cfg, Mapping):
        entries = list(cfg.items())
    elif isinstance(cfg, MutableSequence):
        entries = list(enumerate(cfg))
    else:
        return cfg

    # ``entries`` is a snapshot, so assigning into ``cfg`` below is safe.
    for key, value in entries:
        if isinstance(value, str):
            if "dust3r." in value and "src.dust3r." not in value:
                cfg[key] = value.replace("dust3r.", "src.dust3r.")
        elif isinstance(value, (Mapping, MutableSequence)):
            replace_dust3r_in_config(value)
    return cfg


# Point the module paths in the network config at the relocated ``src.dust3r`` package.
cfg.model.net = replace_dust3r_in_config(cfg.model.net)

cfg.model.net.patch_embed_cls = "PatchEmbedDust3R"  # TODO: investigate what exactly this does; it seems to support inference on portrait-orientation images
cfg.model.net.landscape_only = False  # TODO: investigate what exactly this does


# Instantiate the Lightning module from the config, with both criteria set to None.
empty_lit_module = hydra.utils.instantiate(cfg.model, train_criterion=None, validation_criterion=None)


print("Loading weights from checkpoint...")
CKPT_PATH = os.path.join(checkpoint_dir, 'checkpoints/last.ckpt')

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. No map_location is passed — consider map_location="cpu" if
# the checkpoint was saved on a different device layout.
checkpoint = torch.load(CKPT_PATH)
empty_lit_module.load_state_dict(checkpoint["state_dict"])
# Only the wrapped network is kept for inference.
model = empty_lit_module.net.to(device)


In [None]:
# Run inference on the selected test images with image_size=512 in float32.
output = get_reconstructed_scene(
    outdir="./output",
    model=model,
    device=device,
    silent=False,
    # image_size=224,
    image_size=512,
    filelist=filelist_test,
    dtype=torch.float32,
    # dtype=torch.bfloat16,
)

In [None]:
%load_ext autoreload
%autoreload 2


# Usage example in your context
# Plot the RGB images
plot_rgb_images(output['views'])

# Plot the confidence maps
plot_confidence_maps(output['preds'])

# Plot the 3D points along with estimated camera poses
# NOTE: despite its name, the exported file holds a point cloud, and the next
# cell exports to the same path, so running both overwrites this file.
# NOTE(review): nothing visible in this notebook creates ./output — confirm it exists.
plot_3d_points_with_estimated_camera_poses(
    output['preds'],  # Predictions containing 3D points
    output['views'],  # Views containing RGB images
    flip_axes=True,   # Enable flipping of axes (swap Y and Z and flip Z)
    min_conf_thr_percentile=40,  # Confidence threshold percentile for filtering points
    export_ply_path='./output/combined_mesh.ply'  # Export path for the .ply file
)


In [None]:

# Plot the RGB images
plot_rgb_images(output['views'])

# Plot the confidence maps
plot_confidence_maps(output['preds'])

# Plot the 3D points
# NOTE: with as_mesh=False the exported file is a point cloud, and it is written
# to the same path as the export in the previous cell.
plot_3d_points_with_colors(output['preds'], output['views'], flip_axes=True, as_mesh=False, min_conf_thr_percentile=30, export_ply_path='./output/combined_mesh.ply')


In [None]:
import numpy as np
import plotly.graph_objects as go
from scipy.linalg import rq

from tqdm import tqdm

def estimate_camera_matrix(world_points, image_points):
    """
    Estimate the camera matrix from 3D world points and 2D image points using DLT.

    Each correspondence contributes two rows to a homogeneous system
    ``A @ p = 0``; the solution is the right singular vector of ``A`` that
    belongs to the smallest singular value.

    Parameters:
    world_points (np.ndarray): Array of 3D points in the world coordinates, shape (N, 3).
    image_points (np.ndarray): Array of 2D points in the image coordinates, shape (N, 2).

    Returns:
    np.ndarray: The 3x4 camera matrix, determined only up to scale and sign
    (the returned matrix has unit Frobenius norm).

    Raises:
    AssertionError: If the two arrays hold a different number of points.
    ValueError: If fewer than 6 correspondences are given; the 11 degrees of
    freedom of the matrix are not constrained by fewer points.
    """
    world_points = np.asarray(world_points, dtype=float)
    image_points = np.asarray(image_points, dtype=float)
    assert world_points.shape[0] == image_points.shape[0], "Number of points must match"
    num_points = world_points.shape[0]
    if num_points < 6:
        raise ValueError(f"DLT needs at least 6 point correspondences, got {num_points}")

    # Add homogeneous coordinates to the world points
    homogeneous_world_points = np.hstack((world_points, np.ones((num_points, 1))))
    u = image_points[:, 0:1]
    v = image_points[:, 1:2]

    # Even rows hold the u-equation and odd rows the v-equation of each point.
    A = np.zeros((2 * num_points, 12))
    A[0::2, 0:4] = homogeneous_world_points
    A[0::2, 8:12] = -u * homogeneous_world_points
    A[1::2, 4:8] = homogeneous_world_points
    A[1::2, 8:12] = -v * homogeneous_world_points

    # The reduced SVD avoids building the (2N x 2N) left factor, which is never
    # used. With 2N >= 12 rows, Vt is still the full 12 x 12 matrix.
    _, _, Vt = np.linalg.svd(A, full_matrices=False)

    # The last row of Vt (or last column of V) is the solution
    return Vt[-1].reshape(3, 4)

def decompose_camera_matrix(P):
    """
    Decompose the camera matrix into intrinsic and extrinsic matrices.

    ``P`` is treated as ``s * K @ [R | t]`` with an arbitrary non-zero scale
    ``s``, which is what a DLT estimate provides. The signs left open by the
    RQ decomposition are chosen so that ``K`` has a positive diagonal and
    ``R`` is a proper rotation (determinant +1).

    Parameters:
    P (np.ndarray): The 3x4 camera matrix.

    Returns:
    K (np.ndarray): The 3x3 intrinsic matrix, normalized so that K[2, 2] = 1.
    R (np.ndarray): The 3x3 rotation matrix.
    t (np.ndarray): The translation vector, shape (3,).
    """
    M = P[:, :3]  # The first 3x3 part of P

    # RQ decomposition: M = K @ R with K upper triangular and R orthogonal
    K, R = rq(M)

    # Make the diagonal of K positive. With D = diag(signs), D @ D = I, so
    # (K @ D) @ (D @ R) is still equal to M.
    signs = np.sign(np.diag(K))
    signs[signs == 0] = 1
    K = K * signs            # scales the columns of K
    R = signs[:, None] * R   # scales the rows of R

    # P[:, 3] = K @ t must be solved with K at the same scale as P, i.e.
    # before K is normalized.
    t = np.linalg.solve(K, P[:, 3])

    # A negative overall scale of P shows up as a reflection; undo it.
    if np.linalg.det(R) < 0:
        R = -R
        t = -t

    # Normalize K so that K[2,2] = 1
    K = K / K[2, 2]

    return K, R, t

def plot_camera_cones(fig, R, t, K, color='blue', scale=0.1):
    """
    Add a single-colored cone representing one camera to an existing 3D figure.

    The cone is placed at the camera center ``-R.T @ t`` and points along the
    camera's +Z axis expressed in world coordinates, with its length taken from
    the focal length ``K[0, 0] / K[2, 2]``.

    Parameters:
    fig (plotly.graph_objects.Figure): The existing Plotly figure; a trace is added to it.
    R (np.ndarray): The 3x3 rotation matrix.
    t (np.ndarray): The 3x1 translation vector.
    K (np.ndarray): The 3x3 intrinsic matrix.
    color (str): Color of the camera cone.
    scale (float): Passed to Plotly as ``sizeref`` (cone size reference).
    """
    # fx, normalized by K[2, 2] in case K was not scaled to K[2, 2] = 1
    fx = K[0, 0] / K[2, 2]

    # Viewing axis (+Z in the camera frame) rotated into the world frame, then stretched by fx
    optical_axis = R.T @ np.array([0, 0, 1])
    optical_axis = optical_axis * fx

    # World-space position of the camera; this is where the cone tip sits
    apex = -R.T @ t

    cone_trace = go.Cone(
        x=[apex[0]],
        y=[apex[1]],
        z=[apex[2]],
        u=[optical_axis[0]],
        v=[optical_axis[1]],
        w=[optical_axis[2]],
        # A two-stop colorscale with identical colors renders the cone in one flat color
        colorscale=[[0, color], [1, color]],
        showscale=False,
        sizemode="absolute",
        sizeref=scale,
        anchor="tip",
        name="Camera Cone",
    )
    fig.add_trace(cone_trace)

def plot_3d_points_with_estimated_camera(output, fig, camera_poses, min_conf_thr_percentile=80):
    """
    Plot 3D points together with estimated camera cones in the same plot.

    For every prediction, points whose confidence is strictly above the given
    percentile of that view's confidence map are added to ``fig`` as a colored
    scatter trace. Afterwards one cone per entry of ``camera_poses`` is added.

    Parameters:
    output (dict): Must provide ``output['preds'][i]['pts3d_in_other_view']``,
        ``output['preds'][i]['conf']`` and ``output['views'][i]['img']`` as tensors
        (``.cpu().numpy()`` is called on each). Assumes the image is channel-first
        with values in [-1, 1] -- TODO confirm against the data loader.
    fig (plotly.graph_objects.Figure): The existing 3D plot; modified in place.
    camera_poses (list): List of estimated camera poses as (R, t, K) tuples.
    min_conf_thr_percentile (int): Percentile threshold for confidence values to filter points.
    """
    for i, pred in enumerate(output['preds']):
        pts3d = pred['pts3d_in_other_view'].cpu().numpy().squeeze()  # 3D points
        img_rgb = output['views'][i]['img'].cpu().numpy().squeeze().transpose(1, 2, 0)  # channel-last RGB image
        conf = pred['conf'].cpu().numpy().squeeze()  # Confidence map

        # Apply confidence threshold
        conf_thr = np.percentile(conf, min_conf_thr_percentile)
        keep = (conf > conf_thr).reshape(-1)

        # Rescale RGB values from [-1, 1] to [0, 255]. Clip BEFORE casting: casting
        # first lets out-of-range values wrap around in uint8, after which the clip
        # can no longer correct them.
        img_rgb = np.clip((img_rgb + 1) * 127.5, 0, 255).astype(np.uint8)

        # Flatten the points and colors to one row per pixel, and apply mask
        xyz = pts3d.reshape(-1, 3)[keep]
        rgb = img_rgb.reshape(-1, 3)[keep]

        colors = ['rgb({}, {}, {})'.format(r, g, b) for r, g, b in rgb.tolist()]

        # Add points to the plot
        fig.add_trace(go.Scatter3d(
            x=xyz[:, 0], y=xyz[:, 1], z=xyz[:, 2],
            mode='markers',
            marker=dict(size=2, opacity=0.8, color=colors),
            name=f"View {i} Points"
        ))

    # Now, plot the estimated cameras as cones
    for R, t, K in camera_poses:
        plot_camera_cones(fig, R, t, K, color='blue')

    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z'
        ),
        margin=dict(l=0, r=0, b=0, t=40)
    )

def estimate_camera_poses(output, min_conf_thr_percentile=80):
    """
    Estimate camera poses from 3D points and 2D image points.

    For every prediction, the per-pixel 3D points with confidence strictly above
    the given percentile are paired with their pixel coordinates, a camera matrix
    is fitted with ``estimate_camera_matrix`` and then split with
    ``decompose_camera_matrix``. Intermediate results are printed.

    Parameters:
    output (dict): Must provide ``output['preds'][i]['pts3d_in_other_view']`` and
        ``output['preds'][i]['conf']`` as tensors; after ``squeeze()`` the points
        must have shape (H, W, 3) and the confidence map shape (H, W).
    min_conf_thr_percentile (int): Percentile threshold for confidence values to filter points.

    Returns:
    list: A list of camera poses (R, t, K) where R is rotation, t is translation, and K is intrinsic matrix.
        Views without any point above the threshold are skipped, so the list may be
        shorter than ``output['preds']``.
    """
    camera_poses = []

    # Loop through all views in output['preds']
    for i, pred in enumerate(output['preds']):
        # Get the 3D points and confidence map for the current view
        world_points = pred['pts3d_in_other_view'].cpu().numpy().squeeze()  # Shape: (H, W, 3)
        conf = pred['conf'].cpu().numpy().squeeze()  # Confidence map

        # Determine the confidence threshold based on the percentile
        conf_thr = np.percentile(conf, min_conf_thr_percentile)

        # Apply confidence mask to filter points
        mask = conf > conf_thr
        world_points_filtered = world_points[mask]

        # Generate 2D pixel coordinates corresponding to the filtered points.
        # np.indices yields (row, col) grids, but the DLT solver expects each image
        # point as (u, v) = (x, y) = (col, row), so the two are swapped when stacking.
        h, w, _ = world_points.shape
        rows, cols = np.indices((h, w))
        image_points = np.stack((cols, rows), axis=-1).reshape(-1, 2)  # Shape: (H*W, 2), row-major like the mask
        image_points_filtered = image_points[mask.flatten()]  # Apply mask to 2D points

        if world_points_filtered.shape[0] == 0:
            print(f"View {i}: No points above confidence threshold. Skipping camera estimation.")
            continue

        # Estimate the camera matrix
        P = estimate_camera_matrix(world_points_filtered, image_points_filtered)
        print(f"Camera matrix for view {i}:\n", P)

        # Decompose into intrinsic and extrinsic matrices
        K, R, t = decompose_camera_matrix(P)
        print(f"Intrinsic matrix (K) for view {i}:\n", K)
        print(f"Rotation matrix (R) for view {i}:\n", R)
        print(f"Translation vector (t) for view {i}:\n", t)

        # Store the camera pose (rotation, translation, and intrinsic matrix)
        camera_poses.append((R, t, K))

    return camera_poses

# Estimate one (R, t, K) camera pose per view, using only points above the
# 80th confidence percentile. `output` must already exist in the notebook session.
camera_poses = estimate_camera_poses(output, min_conf_thr_percentile=80)

# Create an empty 3D figure, then add the point clouds (here with a looser
# 50th-percentile confidence cut-off) together with the estimated camera cones
fig = go.Figure()
plot_3d_points_with_estimated_camera(output, fig, camera_poses, min_conf_thr_percentile=50)

# Display the final plot with 3D points and camera cones
fig.show()


In [None]:
# Inspect the shape of the first view's image tensor (shown as the cell result)
output['views'][0]['img'].shape

# Align with DTU point cloud

In [None]:
# The Rt matrix of the first image lives at /fsx-cortex/jianingy/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Calibration/cal18/pos_001.txt
# it looks like this:
# 2607.429996 -3.844898 1498.178098 -533936.661373
# -192.076910 2862.552532 681.798177 23434.686572
# -0.241605 -0.030951 0.969881 22.540121
# I'd like to use this matrix to rotate the input 3D points into the correct orientation.
# My 3D points assume the camera is at (0, 0, 0), looking toward (0, 0, 1).



In [None]:
import numpy as np
import trimesh
import plotly.graph_objs as go
from scipy.linalg import rq

def load_camera_matrix(filepath):
    """Loads the camera calibration matrix from the given file.

    Every non-blank line is parsed as one matrix row of whitespace-separated
    floats. Blank lines (e.g. a trailing empty line) are skipped, because an
    empty row would make the result ragged instead of a 2D array.

    Parameters:
    filepath (str): Path to the text file holding the matrix.

    Returns:
    np.ndarray: The parsed matrix, one row per non-blank line.
    """
    with open(filepath, 'r') as f:
        rows = [[float(value) for value in line.split()] for line in f if line.strip()]
    return np.array(rows)

def decompose_camera_matrix(camera_matrix):
    """Decomposes the camera calibration matrix into intrinsic matrix (K), rotation matrix (R), and translation vector (t).

    The 3x4 input is treated as ``K @ [R | t]``, possibly multiplied by -1.

    Parameters:
    camera_matrix (np.ndarray): The 3x4 camera (projection) matrix.

    Returns:
    K (np.ndarray): 3x3 upper-triangular intrinsic matrix with a positive diagonal
        (not rescaled, so K[2, 2] keeps the scale of the input).
    R (np.ndarray): 3x3 rotation matrix with det(R) = +1.
    t (np.ndarray): Translation vector, shape (3,).
    camera_position (np.ndarray): Camera center in world coordinates, ``-R.T @ t``.
    """
    # The camera matrix is 3x4
    M = camera_matrix[:, :3]
    p4 = camera_matrix[:, 3]

    # RQ decomposition to separate K and R
    K, R = rq(M)

    # Normalize K to ensure the sign of the diagonal is positive. np.where is used
    # instead of np.sign so a zero diagonal entry cannot zero out a whole column.
    T = np.diag(np.where(np.diag(K) < 0, -1.0, 1.0))
    K = K @ T
    R = T @ R

    # If R is a reflection, the input carries a negative overall sign; negating
    # both R and the last column describes the same camera with a proper rotation.
    if np.linalg.det(R) < 0:
        R = -R
        p4 = -p4

    # Compute translation vector t
    t = np.linalg.inv(K) @ p4

    # Camera position C = -R^T * t
    camera_position = -R.T @ t

    return K, R, t, camera_position

def apply_transformation_to_point_cloud(ply_filepath, camera_matrix_filepath):
    """Load a point cloud, map every vertex through ``R @ x + t`` and return the result.

    R and t come from decomposing the camera matrix stored at
    ``camera_matrix_filepath``. The per-axis extent of the cloud is printed before
    and after the transformation, together with the camera position.

    NOTE(review): ``R @ x + t`` is applied as-is; whether this direction
    (as opposed to its inverse) is the intended one should be confirmed.

    Parameters:
    ply_filepath (str): Path of the .ply file to load with trimesh.
    camera_matrix_filepath (str): Path of the text file holding the camera matrix.

    Returns:
    trimesh.PointCloud: New point cloud with transformed vertices and the original colors.
    """
    def _print_axis_ranges(xyz):
        # One line per axis: "<axis> range: <min> - <max> = <extent>"
        for axis, label in enumerate("XYZ"):
            print(f"{label} range: {np.min(xyz[:, axis])} - {np.max(xyz[:, axis])} = {np.max(xyz[:, axis]) - np.min(xyz[:, axis])}")

    # Read the inputs: first the cloud, then the calibration
    point_cloud = trimesh.load(ply_filepath)
    projection = load_camera_matrix(camera_matrix_filepath)
    _, R, t, camera_position = decompose_camera_matrix(projection)

    # Extent of the cloud before it is moved
    _print_axis_ranges(point_cloud.vertices)

    # Where the camera sits
    print(f"Camera position: {camera_position}")

    # Rotate every vertex, then shift by t
    transformed_points = (R @ point_cloud.vertices.T).T + t

    # Extent of the cloud after the transformation
    _print_axis_ranges(transformed_points)

    # Wrap the moved vertices in a fresh point cloud, carrying the colors over
    return trimesh.PointCloud(vertices=transformed_points, colors=point_cloud.colors)

def plot_point_cloud(point_cloud, title="Transformed Point Cloud"):
    """Show a point cloud as an interactive Plotly 3D scatter plot.

    Parameters:
    point_cloud: Object exposing ``vertices`` (indexed as ``[:, 0..2]``) and ``colors``.
    title (str): Figure title.
    """
    vertices = point_cloud.vertices
    # Colors are divided by 255 before being handed to Plotly as marker colors
    marker_style = dict(
        size=2,
        color=point_cloud.colors / 255.0,
        opacity=0.8,
    )

    scatter = go.Scatter3d(
        x=vertices[:, 0],
        y=vertices[:, 1],
        z=vertices[:, 2],
        mode='markers',
        marker=marker_style,
    )
    fig = go.Figure(data=[scatter])

    axis_titles = dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z')
    fig.update_layout(
        title=title,
        scene=axis_titles,
        margin=dict(l=0, r=0, b=0, t=40),
        height=800,
    )

    fig.show()

# Example usage: transform the reconstructed mesh with the calibration of DTU pose 001.
# Both paths are absolute and machine-specific; adjust them before running elsewhere.
ply_filepath = '/opt/hpcaas/.mounts/fs-0565f60d669b6a2d3/home/jianingy/research/accel-cortex/dust3r/fast3r/notebooks/output/combined_mesh.ply'
camera_matrix_filepath = '/fsx-cortex/jianingy/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Calibration/cal18/pos_001.txt'

# Prints the per-axis ranges before/after and returns the transformed cloud
transformed_point_cloud = apply_transformation_to_point_cloud(ply_filepath, camera_matrix_filepath)

# Save the transformed point cloud to a new .ply file
transformed_point_cloud.export('/opt/hpcaas/.mounts/fs-0565f60d669b6a2d3/home/jianingy/research/accel-cortex/dust3r/fast3r/notebooks/output/transformed_output.ply')

# Visualize the transformed point cloud
plot_point_cloud(transformed_point_cloud)


In [None]:
import numpy as np
import trimesh

def load_and_print_xyz_ranges(ply_filepath):
    """Load a .ply file with trimesh and print min, max and extent for the X, Y and Z axes.

    Parameters:
    ply_filepath (str): Path of the .ply file to inspect.
    """
    xyz = trimesh.load(ply_filepath).vertices

    # Gather (min, max) for all three axes first, then report them
    bounds = [(np.min(xyz[:, axis]), np.max(xyz[:, axis])) for axis in range(3)]

    for label, (lowest, highest) in zip("XYZ", bounds):
        print(f"{label} range: {lowest} - {highest} = {highest - lowest}")

# Example usage: print the coordinate ranges of a DTU reference point cloud,
# for comparison with the ranges printed for the transformed reconstruction.
# The path is absolute and machine-specific.
reference_ply_filepath = '/fsx-cortex/jianingy/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Points/stl/stl006_total.ply'

load_and_print_xyz_ranges(reference_ply_filepath)


In [None]:
import numpy as np
import plotly.graph_objs as go
import os
from scipy.linalg import rq

def load_camera_matrix(filepath):
    """Loads the camera calibration matrix from the given file.

    Every non-blank line is parsed as one matrix row of whitespace-separated
    floats. Blank lines (e.g. a trailing empty line) are skipped, because an
    empty row would make the result ragged instead of a 2D array.

    Parameters:
    filepath (str): Path to the text file holding the matrix.

    Returns:
    np.ndarray: The parsed matrix, one row per non-blank line.
    """
    with open(filepath, 'r') as f:
        rows = [[float(value) for value in line.split()] for line in f if line.strip()]
    return np.array(rows)

def decompose_camera_matrix(camera_matrix):
    """Decomposes the camera calibration matrix into intrinsic matrix (K), rotation matrix (R), and translation vector (t).

    The 3x4 input is treated as ``K @ [R | t]``, possibly multiplied by -1.

    Parameters:
    camera_matrix (np.ndarray): The 3x4 camera (projection) matrix.

    Returns:
    K (np.ndarray): 3x3 upper-triangular intrinsic matrix with a positive diagonal
        (not rescaled, so K[2, 2] keeps the scale of the input).
    R (np.ndarray): 3x3 rotation matrix with det(R) = +1.
    t (np.ndarray): Translation vector, shape (3,).
    camera_position (np.ndarray): Camera center in world coordinates, ``-R.T @ t``.
    """
    # The camera matrix is 3x4
    M = camera_matrix[:, :3]
    p4 = camera_matrix[:, 3]

    # RQ decomposition to separate K and R
    K, R = rq(M)

    # Normalize K to ensure the sign of the diagonal is positive. np.where is used
    # instead of np.sign so a zero diagonal entry cannot zero out a whole column.
    T = np.diag(np.where(np.diag(K) < 0, -1.0, 1.0))
    K = K @ T
    R = T @ R

    # If R is a reflection, the input carries a negative overall sign; negating
    # both R and the last column describes the same camera with a proper rotation.
    if np.linalg.det(R) < 0:
        R = -R
        p4 = -p4

    # Compute translation vector t
    t = np.linalg.inv(K) @ p4

    # Camera position C = -R^T * t
    camera_position = -R.T @ t

    return K, R, t, camera_position

def plot_camera_poses(base_path, pose_count):
    """Plots all camera poses and visualizes them in Plotly.

    Reads ``pos_001.txt`` ... ``pos_<pose_count>.txt`` from ``base_path``,
    decomposes each camera matrix, prints the results, and shows a 3D figure with
    one marker per camera position and one cone per viewing direction.

    Parameters:
    base_path (str): Directory containing the ``pos_XXX.txt`` calibration files.
    pose_count (int): Number of poses to load; files are numbered starting at 1.
    """
    camera_positions = []
    camera_orientations = []

    # With K's diagonal made positive by the decomposition, the camera looks along
    # +Z of its own frame; R.T maps that axis into world coordinates.
    forward_axis = np.array([0, 0, 1])

    for i in range(1, pose_count + 1):
        filepath = os.path.join(base_path, f'pos_{i:03d}.txt')
        camera_matrix = load_camera_matrix(filepath)

        # Print the full camera matrix
        print(f"Camera Matrix {i}:\n{camera_matrix}\n")

        K, R, t, camera_position = decompose_camera_matrix(camera_matrix)

        # Print the decomposed matrices
        print(f"Intrinsic Matrix (K) {i}:\n{K}\n")
        print(f"Rotation Matrix (R) {i}:\n{R}\n")
        print(f"Translation Vector (t) {i}:\n{t}\n")
        print(f"Camera Position {i}: {camera_position}\n")

        # Camera viewing direction in world coordinates
        camera_direction = R.T @ forward_axis

        camera_positions.append(camera_position)
        camera_orientations.append(camera_direction)

    # Convert lists to (N, 3) arrays; the reshape keeps column indexing valid
    # even when no poses were loaded
    camera_positions = np.array(camera_positions).reshape(-1, 3)
    camera_orientations = np.array(camera_orientations).reshape(-1, 3)

    # Create the 3D scatter plot for camera positions
    scatter = go.Scatter3d(
        x=camera_positions[:, 0],
        y=camera_positions[:, 1],
        z=camera_positions[:, 2],
        mode='markers',
        marker=dict(size=5, color='blue'),
        name='Camera Positions'
    )

    # Create the 3D quiver plot for camera orientations
    quiver = go.Cone(
        x=camera_positions[:, 0],
        y=camera_positions[:, 1],
        z=camera_positions[:, 2],
        u=camera_orientations[:, 0],
        v=camera_orientations[:, 1],
        w=camera_orientations[:, 2],
        sizemode='scaled',
        sizeref=2,
        colorscale='Blues',
        name='Camera Orientations'
    )

    # Set up the layout
    layout = go.Layout(
        title='Camera Poses Visualization',
        scene=dict(
            xaxis=dict(title='X'),
            yaxis=dict(title='Y'),
            zaxis=dict(title='Z'),
        ),
        margin=dict(l=0, r=0, b=0, t=40),
        height=800
    )

    # Create the figure and show it
    fig = go.Figure(data=[scatter, quiver], layout=layout)
    fig.show()

# Example usage: visualize the DTU calibration poses pos_001.txt .. pos_049.txt.
# The path is absolute and machine-specific.
base_path = '/fsx-cortex/jianingy/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Calibration/cal18'
pose_count = 49  # Adjust this according to the number of poses available

plot_camera_poses(base_path, pose_count)
