In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
os.chdir('../')
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import sys
sys.path.append("Marigold")

import torch
import numpy as np
import jhutil; jhutil.color_log(1111, )

## 1. 3DGS -> GSplat

In [None]:
from collections import OrderedDict
from gaussiansplatting.scene.gaussian_model import GaussianModel


def gs_to_gsplat(gs_path, output_path, sh_degree=0, step=29999):
    """Convert a 3DGS ``.ply`` point cloud into a gsplat-style ``.pt`` checkpoint.

    The Gaussian parameters are loaded through ``GaussianModel.load_ply`` and
    written unchanged (no activation is applied here) under the key names that
    ``render_gsplat`` reads back: ``means``, ``sh0``, ``shN``, ``opacities``,
    ``scales`` and ``quats``.

    Args:
        gs_path: Path of the input point cloud; must end with ``.ply``.
        output_path: Path of the checkpoint to write; must end with ``.pt``.
        sh_degree: Value forwarded to ``GaussianModel(sh_degree=...)``.
            Defaults to 0, the value that was previously hard-coded.
        step: Value stored under the checkpoint's ``'step'`` key. Defaults to
            29999, the value that was previously hard-coded.

    Raises:
        AssertionError: If either path has an unexpected extension.
    """
    assert gs_path.endswith(".ply"), f"gs_path must end with '.ply', got {gs_path!r}"
    assert output_path.endswith(".pt"), f"output_path must end with '.pt', got {output_path!r}"

    gaussian = GaussianModel(sh_degree=sh_degree)
    gaussian.load_ply(gs_path)

    splats = OrderedDict({
        "means": gaussian._xyz.data,
        "sh0": gaussian._features_dc.data,
        "shN": gaussian._features_rest.data,
        # Drop the trailing axis so opacities are stored with one value per Gaussian.
        "opacities": gaussian._opacity.data.squeeze(-1),
        "scales": gaussian._scaling.data,
        "quats": gaussian._rotation.data,
    })

    gsplat_data = {
        'step': step,
        'splats': splats,
    }
    torch.save(gsplat_data, output_path)

# Convert the iteration-30000 3DGS point cloud of the LLFF "fern" scene into a
# gsplat checkpoint written to tmp.pt.
gs_path = "./dataset/llff_data/fern/gs/point_cloud/iteration_30000/point_cloud.ply"
output_path = "tmp.pt"
gs_to_gsplat(gs_path, output_path)

## 2. VGGT -> GSplat

In [None]:
from vggt import vggt_inference
# Run VGGT on the fern image folder. Later cells index the result with the string
# keys "world_points_from_depth", "images" and "depth".
vggt_out = vggt_inference(image_folder="./dataset/llff_data/fern/images")

In [None]:
# image_indices = [1, 3, 5]
# vggt_out["world_points_from_depth"] = vggt_out["world_points_from_depth"][[1, 3, 5]]
# vggt_out['depth'] = vggt_out['depth'][image_indices]
# vggt_out['images'] = vggt_out['images'][image_indices]

In [None]:
from einops import rearrange
from jhutil import rgb_to_sh0

def vggt_to_gsplat(vggt_out, output_path, pointmap_indices=None,
                   log_scale_offset=7.8, opacity_logit=5.0, step=29999):
    """Turn VGGT per-pixel point maps into a gsplat-style ``.pt`` checkpoint.

    Every pixel of every selected point map becomes one Gaussian. The checkpoint
    uses the same key layout as ``gs_to_gsplat`` (``means``, ``sh0``, ``shN``,
    ``opacities``, ``scales``, ``quats``).

    Args:
        vggt_out: Mapping with the entries ``"world_points_from_depth"``
            (unpacked here as ``b, h, w, c``), ``"images"`` and ``"depth"``
            (rearranged here as ``b h w 1``).
        output_path: Destination passed to ``torch.save``.
        pointmap_indices: Optional indices along the first axis selecting which
            point maps / images / depth maps to use. ``None`` keeps all of them.
        log_scale_offset: Constant subtracted from the depth to obtain the stored
            scale. Defaults to 7.8, the value that was previously hard-coded.
        opacity_logit: Constant stored as every Gaussian's opacity. Defaults to
            5.0, the value that was previously hard-coded.
        step: Value stored under the checkpoint's ``'step'`` key (default 29999).
    """
    means = vggt_out["world_points_from_depth"].float()
    images = vggt_out["images"]
    depth = vggt_out["depth"]
    if pointmap_indices is not None:
        means = means[pointmap_indices]
        images = images[pointmap_indices]
        depth = depth[pointmap_indices]
    b, h, w, _ = means.shape
    n = b * h * w
    # Allocate the constant tensors next to the point data instead of on the
    # default device, so the checkpoint does not mix devices.
    device = means.device

    means = rearrange(means, 'b h w c -> (b h w) c')
    # NOTE(review): assumes rgb_to_sh0 returns one DC coefficient per pixel in a
    # layout compatible with `shN` below — confirm against jhutil.
    sh0 = rgb_to_sh0(images)
    # Higher-order SH coefficients are all zero.
    shN = torch.zeros(n, 15, 3, device=device)

    # Scales: the same value on all three axes (isotropic), derived from depth.
    # NOTE(review): render_gsplat applies torch.exp to "scales", so this value is
    # consumed in log space; confirm that `depth - offset` (rather than
    # `log(depth) - offset`) is intended.
    scales = rearrange(depth, 'b h w 1 -> (b h w) 1').repeat(1, 3) - log_scale_offset

    # Opacities: render_gsplat applies torch.sigmoid, so a large positive constant
    # yields nearly opaque Gaussians.
    opacities = torch.ones(n, device=device) * opacity_logit

    # Identity rotation. gsplat (and the 3DGS `_rotation` tensor exported by
    # gs_to_gsplat under the same key) order quaternions as (w, x, y, z), so the
    # real part is component 0. The previous code set component 3 instead, which
    # is a 180-degree rotation about z in that convention; it was only harmless
    # because the scales above are isotropic.
    quats = torch.zeros(n, 4, device=device)
    quats[:, 0] = 1

    gsplat_data = {
        'step': step,
        'splats': OrderedDict({
            "means": means,
            "sh0": sh0,
            "shN": shN,
            "opacities": opacities,
            "scales": scales,
            "quats": quats,
        })
    }
    torch.save(gsplat_data, output_path)

# Point maps (indices along the first axis of the VGGT outputs) to export; the same
# list is passed to vggt_to_3dgs in a later cell.
pointmap_indices = [1, 3, 5]
vggt_to_gsplat(vggt_out, "tmp_vggt.pt", pointmap_indices)


## 3. VGGT -> 3DGS -> GSplat

In [None]:
# Display the raw VGGT output for inspection.
# NOTE(review): this prints the whole object; consider showing only keys/shapes.
vggt_out

In [None]:
# NOTE(review): wildcard import; the visible cells only use `vggt_to_3dgs` from this
# module — consider importing it explicitly once that is confirmed.
from src.vggt_to_3dgs import *


In [None]:

# Export the selected VGGT views into `output_dir`. Later cells read
# point_cloud/iteration_0/point_cloud.ply and sparse/0/{cameras,images}.bin from
# this directory — presumably written by this call; verify against vggt_to_3dgs.
output_dir = "./dataset/llff_data/fern/vggt"
vggt_to_3dgs(vggt_out, output_dir, pointmap_indices=pointmap_indices, extrinsic_is_c2w=False)

In [None]:
# Convert the iteration-0 point cloud under the VGGT output directory into a gsplat
# checkpoint (tmp_tmp.pt), which the rendering cell loads.
vggt_gs_path = "./dataset/llff_data/fern/vggt/point_cloud/iteration_0/point_cloud.ply"
output_path = "tmp_tmp.pt"
gs_to_gsplat(vggt_gs_path, output_path)

## 4. Rendering with GSplat

In [None]:
import torch.nn.functional as F
from gaussiansplatting.scene.colmap_loader import (
    read_intrinsics_binary, 
    read_extrinsics_binary,
    qvec2rotmat
)
from gsplat.rendering import rasterization



def render_gsplat(ckpt_path, cameras_path, images_path, cam_idx=1):
    """Render one COLMAP view of a gsplat checkpoint with ``gsplat.rasterization``.

    Args:
        ckpt_path: ``.pt`` file whose ``'splats'`` entry holds ``means``, ``quats``,
            ``scales``, ``opacities``, ``sh0`` and ``shN``. It is loaded onto CUDA.
        cameras_path: COLMAP binary intrinsics file, read with
            ``read_intrinsics_binary``.
        images_path: COLMAP binary extrinsics file, read with
            ``read_extrinsics_binary``.
        cam_idx: Position of the view to render among the images sorted by image
            id. Defaults to 1, the value that was previously hard-coded.

    Returns:
        The first element of the rasterized colours (annotated ``[H, W, 3]``).

    Raises:
        IndexError: If ``cam_idx`` does not address an available image.
        ValueError: If the camera has neither 3 nor 4 intrinsic parameters.
    """
    ckpt = torch.load(ckpt_path, map_location='cuda')['splats']

    # Prepare Gaussian parameters (apply activations)
    means = ckpt["means"]
    quats = F.normalize(ckpt["quats"], p=2, dim=-1)
    scales = torch.exp(ckpt["scales"])
    opacities = torch.sigmoid(ckpt["opacities"])
    sh0 = ckpt["sh0"]
    shN = ckpt["shN"]
    # Concatenate DC and higher-order SH coefficients along the coefficient axis.
    colors = torch.cat([sh0, shN], dim=-2)

    print(f"Number of Gaussians: {len(means)}")

    # Load COLMAP camera data
    cameras = read_intrinsics_binary(cameras_path)
    images = read_extrinsics_binary(images_path)

    print(f"Number of cameras: {len(cameras)}")
    print(f"Number of images: {len(images)}")

    # Select the view by position in the sorted image ids. `images` is keyed by
    # image id, so it must not be indexed with `cam_idx` directly: that would
    # address a different image, or raise KeyError when no such id exists.
    image_ids = sorted(images.keys())
    if not -len(image_ids) <= cam_idx < len(image_ids):
        raise IndexError(
            f"cam_idx {cam_idx} is out of range for {len(image_ids)} images"
        )
    image_id = image_ids[cam_idx]
    img = images[image_id]
    cam = cameras[img.camera_id]

    print(f"\nRendering camera {cam_idx}: {img.name}")
    print(f"  Image size: {cam.width}x{cam.height}")

    # Build viewmat (world-to-camera, 4x4)
    R = qvec2rotmat(img.qvec)
    t = img.tvec
    w2c = np.eye(4)
    w2c[:3, :3] = R
    w2c[:3, 3] = t
    viewmat = torch.from_numpy(w2c).float().cuda()

    # Build intrinsic matrix K. Four parameters are read as (fx, fy, cx, cy) as
    # before; three parameters are read as a single shared focal length
    # (f, cx, cy), the layout of COLMAP's SIMPLE_PINHOLE model.
    params = cam.params
    if len(params) == 4:
        fx, fy, cx, cy = params
    elif len(params) == 3:
        focal, cx, cy = params
        fx = fy = focal
    else:
        raise ValueError(
            f"Unsupported camera intrinsics: expected 3 or 4 parameters, got {len(params)}"
        )
    K = torch.tensor([
        [fx, 0, cx],
        [0, fy, cy],
        [0, 0, 1]
    ], dtype=torch.float32).cuda()

    width, height = cam.width, cam.height

    # Compute sh_degree from colors shape: (degree + 1)^2 coefficients per Gaussian.
    sh_degree = int(np.sqrt(colors.shape[1]) - 1)
    print(f"  SH degree: {sh_degree}")

    # Render RGB
    render_colors, render_alphas, meta = rasterization(
        means,          # [N, 3]
        quats,          # [N, 4]
        scales,         # [N, 3]
        opacities,      # [N]
        colors,         # [N, S, 3]
        viewmat[None],  # [1, 4, 4]
        K[None],        # [1, 3, 3]
        width,
        height,
        sh_degree=sh_degree,
        render_mode="RGB",
    )

    rendered_img = render_colors[0]  # [H, W, 3]
    return rendered_img


# Render the checkpoint converted from the VGGT-derived point cloud, using the
# COLMAP camera files stored under the same VGGT output directory.
ckpt_path = "tmp_tmp.pt"
cameras_path = "./dataset/llff_data/fern/vggt/sparse/0/cameras.bin"
images_path = "./dataset/llff_data/fern/vggt/sparse/0/images.bin"
rendered_img = render_gsplat(ckpt_path, cameras_path, images_path)
# Quick sanity check of the output size and value range.
print(f"\nRendered image shape: {rendered_img.shape}")
print(f"Rendered image range: [{rendered_img.min():.3f}, {rendered_img.max():.3f}]")

In [None]:
# Move the channel axis first ([H, W, 3] -> [3, H, W]) and display the image.
# NOTE(review): `.rgb` is presumably a tensor display helper provided by jhutil — confirm.
rendered_img.permute(2, 0, 1).rgb

In [None]:
from jhutil import load_img

# Load a source photo of the fern scene for visual comparison with the render.
gt_img = load_img("dataset/llff_data/fern/images/IMG_4027.JPG")
# Keep every 5th element along the last two axes to shrink the displayed image.
# NOTE(review): assumes load_img returns a channel-first image — confirm.
gt_img[:, ::5, ::5].rgb