In [None]:
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
%pip install tqdm
%pip install matplotlib
%pip install piq
%pip install imageio
%pip install opencv-python
%pip install tensorboard
%pip install pycolmap
%pip install pyquaternion

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import imageio.v3 as iio
import json
import pycolmap
import os
from pathlib import Path
import shutil
from pyquaternion import Quaternion  # pip install pyquaternion


In [None]:
# Define the helper
def to_homogeneous(pose_3x4):
    """
    Return a (4, 4) homogeneous transform for a (3, 4) [R | t] pose.

    A pose that is already (4, 4) is returned as a copy, so calling this on an
    already-homogeneous matrix does not append a second bottom row.

    Args:
        pose_3x4: (3, 4) or (4, 4) tensor.

    Returns:
        (4, 4) tensor with the same dtype and device as the input.

    Raises:
        ValueError: if the pose is neither (3, 4) nor (4, 4).
    """
    shape = tuple(pose_3x4.shape)
    if shape == (4, 4):
        return pose_3x4.clone()
    if shape != (3, 4):
        raise ValueError(f"expected a (3, 4) or (4, 4) pose, got {shape}")
    bottom_row = torch.tensor([[0, 0, 0, 1]], dtype=pose_3x4.dtype, device=pose_3x4.device)
    return torch.cat([pose_3x4, bottom_row], dim=0)

In [None]:
def colmap_to_nerf_point(points):
    """
    Negate the y and z coordinates of world-space points.

    This mirrors the world-axis flip that `colmap_to_nerf_pose` folds into the
    rotation of each pose, so flipped points and flipped poses stay consistent.

    Args:
        points: tensor whose last dimension is 3 (a single point or a batch).

    Returns:
        Tensor of the same shape, dtype and device with y and z negated.
    """
    axis_signs = torch.tensor([1.0, -1.0, -1.0], dtype=points.dtype, device=points.device)
    return points * axis_signs

def colmap_to_nerf_pose(pose):
    """
    Build a (4, 4) pose from a (3, 4) or (4, 4) [R | t] matrix, negating the
    second and third columns of R and keeping t unchanged.

    Right-multiplying R by diag(1, -1, -1) flips the *world* y/z axes (the
    camera frame is untouched), which matches `colmap_to_nerf_point`. When fed
    a COLMAP `cam_from_world` matrix, the result is therefore still a
    world-to-camera transform.

    Args:
        pose: tensor whose top-left (3, 4) block is [R | t].

    Returns:
        (4, 4) tensor with bottom row [0, 0, 0, 1], same dtype/device as `pose`.
    """
    column_signs = torch.tensor([1.0, -1.0, -1.0], dtype=pose.dtype, device=pose.device)
    homogeneous = torch.eye(4, dtype=pose.dtype, device=pose.device)
    # Broadcasting scales column j of R by column_signs[j] == R @ diag(column_signs).
    homogeneous[:3, :3] = pose[:3, :3] * column_signs
    homogeneous[:3, 3] = pose[:3, 3]
    return homogeneous

In [None]:
def sfm_extract(image_dir, device='cuda'):
    """
    Run SfM pipeline using pycolmap Python API, extract camera poses, intrinsics, and 3D points.

    If `temp/sfm_output` already contains files, model `0` is reloaded from it
    instead of re-running SfM. NOTE: this cache is not keyed on `image_dir`;
    delete `temp/` after switching to a different image set.

    Args:
        image_dir: directory of input images.
        device: torch device for the returned tensors.

    Returns:
        poses_w2c: dict image basename -> (4, 4) float32 world-to-camera matrix
            (COLMAP `cam_from_world` passed through `colmap_to_nerf_pose`).
        intrinsics: dict image basename -> float32 tensor of the camera's
            `params` (fx, fy, cx, cy for the PINHOLE model requested here).
        points3D: (P, 3) float32 tensor of reconstructed points passed through
            `colmap_to_nerf_point`; shape (0, 3) if there are none.

    Raises:
        RuntimeError: if incremental mapping produces no reconstruction.
    """
    image_dir = Path(image_dir)
    database_path = Path("temp/database.db")
    sfm_path = Path("temp/sfm_output")

    if sfm_path.exists() and any(sfm_path.iterdir()):
        print(f"[INFO] Loading existing SfM reconstruction from {sfm_path}")
        reconstruction = pycolmap.Reconstruction(str(sfm_path / "0"))
    else:
        # Clean up previous runs
        if database_path.exists():
            database_path.unlink()
        if sfm_path.exists():
            shutil.rmtree(sfm_path)
        # Also creates `temp/`, which the database file needs to exist in.
        sfm_path.mkdir(parents=True, exist_ok=True)

        # 1. Extract features
        print("[INFO] Extracting features...")
        pycolmap.extract_features(
            database_path=str(database_path),
            image_path=str(image_dir),
            camera_model='PINHOLE',
            camera_mode='SINGLE'
        )

        # 2. Match features
        print("[INFO] Matching features...")
        pycolmap.match_exhaustive(str(database_path))

        # 3. Incremental mapping
        print("[INFO] Performing incremental mapping...")
        reconstructions = pycolmap.incremental_mapping(
            str(database_path),
            str(image_dir),
            str(sfm_path),
            initial_image_pair_callback=lambda: print("[INFO] Initial image pair registered."),
            next_image_callback=lambda: print("[INFO] Next image registered.")
        )

        if not reconstructions:
            raise RuntimeError("No reconstructions found")
        # Model 0, matching the `sfm_path / "0"` model the cache branch reloads.
        reconstruction = reconstructions[0]

    print(f"[INFO] Number of registered images: {len(reconstruction.images)}")
    print(f"[INFO] Number of 3D points: {len(reconstruction.points3D)}")

    # COLMAP's `cam_from_world` maps world -> camera, so these are world-to-camera poses.
    poses_w2c = {}
    intrinsics_dict = {}
    camera_models = set()

    for img_id, img in reconstruction.images.items():
        img_name = os.path.basename(img.name)
        pose = torch.tensor(img.cam_from_world.matrix(), dtype=torch.float32, device=device)
        intrinsics = torch.tensor(img.camera.params, dtype=torch.float32, device=device)

        poses_w2c[img_name] = colmap_to_nerf_pose(pose)
        intrinsics_dict[img_name] = intrinsics

        camera = img.camera
        camera_models.add(camera.model)

        print(f"[CAMERA INFO] Image ID: {img_id}")
        print(f" - Image name: {img.name}")
        print(f" - Camera ID: {camera.camera_id}")
        print(f" - Camera model: {camera.model}")
        print(f" - Image size: {camera.width} x {camera.height}")
        print(f" - Intrinsic parameters ({len(camera.params)}): {camera.params}")
        print(f" - cam_from_world:\n{img.cam_from_world.matrix()}")
        print("-" * 60)

    if len(camera_models) == 1:
        print(f"[INFO] Single camera model detected: {list(camera_models)[0]}")
    else:
        print(f"[WARNING] Multiple camera models detected: {camera_models}")

    # Gather all 3D points into one array and convert once, instead of
    # building (and flipping) one small tensor per point.
    if reconstruction.points3D:
        xyz = np.stack([p.xyz for p in reconstruction.points3D.values()]).astype(np.float32)
        points3D = colmap_to_nerf_point(torch.from_numpy(xyz).to(device))
    else:
        points3D = torch.empty((0, 3), device=device)

    return poses_w2c, intrinsics_dict, points3D


In [None]:
def get_points(points3D: torch.Tensor, N: int, jitter=1e-3):
    """
    Adjust the number of points in points3D to exactly N.

    - M == N: returns a copy.
    - M > N: returns a random subset of N distinct points (no replacement).
    - M < N: every original point is kept once, unchanged, in the first M rows.
      The remaining N - M rows are copies (whole-set tiles plus a random subset
      for the remainder) with Gaussian noise added, so duplicates do not coincide.

    Args:
        points3D: (M, 3) tensor of input points
        N: desired number of points (>= 0)
        jitter: standard deviation of Gaussian noise added to duplicated points

    Returns:
        (N, 3) tensor of points

    Raises:
        ValueError: if N is negative, or points3D is empty while N > 0.
    """
    if N < 0:
        raise ValueError(f"N must be non-negative, got {N}")

    M = points3D.shape[0]
    device = points3D.device

    if M == N:
        return points3D.clone()

    if M > N:
        # Subsample without replacement
        indices = torch.randperm(M, device=device)[:N]
        return points3D[indices].clone()

    if M == 0:
        raise ValueError(f"cannot upsample an empty point set to {N} points")

    # Upsample: tile the whole set, then draw the remainder without replacement.
    repeats, remainder = divmod(N, M)
    points_repeated = points3D.repeat(repeats, 1)
    extra_indices = torch.randperm(M, device=device)[:remainder]
    upsampled = torch.cat([points_repeated, points3D[extra_indices]], dim=0)

    # The first M rows are the originals; jitter only the duplicates after them.
    duplicates = upsampled[M:]
    duplicates += torch.randn_like(duplicates) * jitter

    return upsampled

In [None]:
class NeRFDataset(Dataset):
    """
    Dataset of (image, world-to-camera pose, K) triples built from a NeRF `transforms_*.json`.

    When SfM dictionaries are supplied, a frame whose image file name is a key
    in `sfm_poses` / `sfm_intrinsics` uses those values. Otherwise the pose is
    derived from the frame's `transform_matrix` and K from `camera_angle_x`.
    Poses are returned in the OpenCV camera convention that `project_points`
    assumes (+z in front of the camera, +y toward increasing image rows).

    NOTE(review): SfM poses live in the (flipped) COLMAP world frame while the
    fallback poses live in the NeRF world frame. Mixing both in one run only
    makes sense if those frames have been aligned; confirm before relying on
    the fallback.
    """

    def __init__(self, json_path, image_size=512, device='cuda', sfm_poses=None, sfm_intrinsics=None):
        with open(json_path, 'r') as f:
            meta = json.load(f)

        self.frames = meta['frames']
        self.camera_angle_x = meta['camera_angle_x']
        self.image_size = image_size
        self.device = device
        self.base_dir = os.path.dirname(json_path)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((image_size, image_size))
        ])

        self.sfm_poses = sfm_poses
        self.sfm_intrinsics = sfm_intrinsics

    @staticmethod
    def _blender_c2w_to_w2c(c2w):
        """
        Convert a NeRF `transform_matrix` (camera-to-world) into an OpenCV-style
        world-to-camera matrix.

        NeRF synthetic / Blender cameras use the OpenGL convention (camera looks
        along -z, y up). `project_points` treats +z as "in front" and maps +y to
        increasing image rows. So after inverting, the camera's y and z axes are
        negated (left-multiplication flips the camera frame, not the world).
        """
        flip = torch.diag(torch.tensor([1.0, -1.0, -1.0, 1.0], dtype=c2w.dtype, device=c2w.device))
        return flip @ torch.inverse(c2w)

    def rescale_intrinsics(self, intrinsics, orig_size, new_size):
        """Scale (fx, fy, cx, cy) from orig_size=(w, h) to new_size=(w, h); returns a new tensor."""
        scale_x = new_size[0] / orig_size[0]
        scale_y = new_size[1] / orig_size[1]

        intrinsics = intrinsics.clone()
        intrinsics[0] *= scale_x  # fx
        intrinsics[1] *= scale_y  # fy
        intrinsics[2] *= scale_x  # cx
        intrinsics[3] *= scale_y  # cy
        return intrinsics

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        frame = self.frames[idx]
        # Key used to look up SfM results (sfm_extract keys by image basename).
        img_filename = os.path.basename(frame['file_path']) + '.png'  # or '.jpg'

        # Load raw image to get original size
        img_path = os.path.join(self.base_dir, frame['file_path'] + '.png')
        raw = iio.imread(img_path).astype(np.float32) / 255.0
        if raw.ndim == 2:
            # Grayscale: replicate to 3 channels so it matches the RGB renders.
            raw = np.repeat(raw[:, :, None], 3, axis=2)
        elif raw.shape[-1] == 4:
            # Alpha is dropped, not composited over a background.
            raw = raw[:, :, :3]

        orig_height, orig_width = raw.shape[:2]

        # Resize image
        image = self.transform(raw).to(self.device)

        # Load pose (world-to-camera, OpenCV convention)
        if self.sfm_poses is not None and img_filename in self.sfm_poses:
            sfm_pose = self.sfm_poses[img_filename]
            # SfM poses may already be 4x4; padding those again would add a 5th row.
            if sfm_pose.shape[0] == 3:
                sfm_pose = to_homogeneous(sfm_pose)
            pose = sfm_pose.to(self.device)
        else:
            c2w = torch.tensor(frame['transform_matrix'], dtype=torch.float32, device=self.device)
            pose = self._blender_c2w_to_w2c(c2w)

        # Load and rescale intrinsics
        if self.sfm_intrinsics is not None and img_filename in self.sfm_intrinsics:
            params = self.sfm_intrinsics[img_filename].cpu()
            # PINHOLE parameter order: fx, fy, cx, cy.
            fx, fy, cx, cy = params[:4]

            intrinsics = torch.tensor([fx, fy, cx, cy], dtype=torch.float32)
            intrinsics = self.rescale_intrinsics(intrinsics, (orig_width, orig_height), (self.image_size, self.image_size))

            K = torch.tensor([
                [intrinsics[0], 0, intrinsics[2]],
                [0, intrinsics[1], intrinsics[3]],
                [0, 0, 1]
            ], dtype=torch.float32, device=self.device)

        else:
            # Fallback generic intrinsics from NeRF JSON (horizontal FOV, centred principal point)
            focal = 0.5 * self.image_size / np.tan(0.5 * self.camera_angle_x)
            K = torch.tensor([
                [focal, 0, self.image_size / 2],
                [0, focal, self.image_size / 2],
                [0, 0, 1]
            ], dtype=torch.float32, device=self.device)

        return image, pose, K


In [None]:
def project_points(points, pose, K):
    """
    Pinhole-project world points into pixel coordinates.

    Args:
        points: (N, 3) world-space points.
        pose: world-to-camera matrix applied to homogeneous points. Only the
            first three output rows are kept, so (3, 4) or (4, 4) both work.
        K: (3, 3) intrinsics; only fx, fy, cx, cy are read.

    Returns:
        coords: (N, 2) pixel coordinates (u, v).
        z: (N,) raw camera-space depth (not clamped; <= 0 means behind the camera).
        cam_points: (N, 3) camera-space points.
    """
    ones = torch.ones(points.shape[0], 1, device=points.device)
    homogeneous = torch.cat([points, ones], dim=-1)       # (N, 4)
    cam_points = (pose @ homogeneous.T).T[:, :3]          # (N, 3)
    x_cam, y_cam, z = cam_points.unbind(dim=-1)

    # Only the divisor is clamped; callers still get the true depth in `z`.
    depth = z.clamp(min=1e-5)
    u = K[0, 0] * x_cam / depth + K[0, 2]
    v = K[1, 1] * y_cam / depth + K[1, 2]

    return torch.stack([u, v], dim=-1), z, cam_points

In [None]:
def render_gaussian_points(points, colors, radii, pose, K, image_size, kernel_radius=3):
    """
    Vectorized differentiable rendering of isotropic Gaussian splats.

    Each point is projected with `project_points`. Its color is spread over the
    (2R+1) x (2R+1) pixel window around the nearest pixel, weighted by a Gaussian
    of the distance from each pixel to the *sub-pixel* projected position. Each
    output pixel is the weight-normalized average of the colors reaching it.
    There is no depth sorting or occlusion.

    Args:
        points: (N, 3) tensor of world-space 3D points.
        colors: (N, 3) RGB colors.
        radii: (N,) standard deviations (pixels) for each point in screen space.
        pose: (4, 4) world-to-camera transform, as consumed by `project_points`
            (positive camera z is in front of the camera).
        K: (3, 3) camera intrinsics.
        image_size: int, output image size (square).
        kernel_radius: int, splat radius in pixels.

    Returns:
        (3, H, W) image tensor clamped to [0, 1]; pixels no splat reaches are 0.
    """
    device = points.device
    H = W = image_size
    k = 2 * kernel_radius + 1

    # Project points to image plane
    coords, z, _ = project_points(points, pose, K)  # (N, 2), (N,)

    # (k*k, 2) integer pixel offsets of the splat window, stored as (dx, dy).
    offset_y, offset_x = torch.meshgrid(
        torch.arange(-kernel_radius, kernel_radius + 1, device=device),
        torch.arange(-kernel_radius, kernel_radius + 1, device=device),
        indexing='ij'
    )
    offsets = torch.stack([offset_x, offset_y], dim=-1).view(-1, 2).to(coords.dtype)

    # Pixels covered by each splat, centred on the pixel nearest the projection.
    # Rounding has no gradient, so it is done on detached coordinates.
    pixels = coords.detach().round().unsqueeze(1) + offsets.unsqueeze(0)  # (N, k*k, 2)

    # Distance from each covered pixel to the true sub-pixel projection. `coords`
    # enters here un-detached, so the weights carry gradients back to `points`.
    # Measuring (coords + offset) - coords would give a constant instead.
    delta = pixels - coords.unsqueeze(1)                        # (N, k*k, 2)
    r2 = radii.clamp(min=1e-2).view(-1, 1) ** 2                 # (N, 1)
    weights = torch.exp(-0.5 * delta.pow(2).sum(dim=-1) / r2)   # (N, k*k)

    # Drop splats of points at/behind the camera and window pixels outside the
    # image. Clamping them would pile their color onto the border pixels.
    px, py = pixels[..., 0], pixels[..., 1]
    valid = (
        (z > 1e-5).unsqueeze(1)
        & (px >= 0) & (px <= W - 1)
        & (py >= 0) & (py <= H - 1)
    )
    weights = torch.where(valid, weights, torch.zeros_like(weights))
    # Invalid entries get index 0 with zero weight, so they contribute nothing.
    ix = torch.where(valid, px, torch.zeros_like(px)).long()
    iy = torch.where(valid, py, torch.zeros_like(py)).long()

    # Flatten for scatter
    flat_idx = (iy * W + ix).reshape(-1)                         # (N*k*k,)
    flat_weights = weights.reshape(-1)                           # (N*k*k,)
    flat_colors = colors.unsqueeze(1).expand(-1, k * k, -1).reshape(-1, 3)
    contrib = flat_weights.unsqueeze(-1) * flat_colors           # (N*k*k, 3)

    # Accumulate weighted colors and total weight per pixel.
    canvas = torch.zeros(H * W, 3, device=device, dtype=contrib.dtype)
    canvas.index_add_(0, flat_idx, contrib)
    alpha = torch.zeros(H * W, device=device, dtype=flat_weights.dtype)
    alpha.index_add_(0, flat_idx, flat_weights)

    # Reshape and normalize
    image = canvas.T.reshape(3, H, W)
    alpha = alpha.view(1, H, W).clamp(min=1e-5)
    return (image / alpha).clamp(0.0, 1.0)


In [None]:
def init_gaussians(points3D, device='cuda'):
    """
    Initialize Gaussians exactly at points3D without noise.

    Args:
        points3D: tensor (N, 3) of 3D points. It is copied, so optimising the
            returned 'xyz' never mutates the caller's tensor.
        device: device string

    Returns:
        dict with keys 'xyz', 'color', 'radius' containing nn.Parameters:
            'xyz': (N, 3) copy of points3D,
            'color': (N, 3) values drawn around 0.5 (std 0.1), clamped to [0, 1],
            'radius': (N,) all 1.0.
    """
    # `.to()` returns the same tensor when it is already on `device`; detach+clone
    # so the Parameter never shares storage with the caller's points.
    xyz = points3D.detach().to(device).clone()

    N = xyz.shape[0]
    colors = torch.full((N, 3), 0.5, device=device) + 0.1 * torch.randn(N, 3, device=device)
    colors = colors.clamp(0.0, 1.0)

    radius = torch.full((N,), 1.0, device=device)

    return {
        'xyz': nn.Parameter(xyz),
        'color': nn.Parameter(colors),
        'radius': nn.Parameter(radius)
    }


In [None]:
def save_img(tensor_img, filename):
    """
    Save a tensor image (C, H, W) to a file after cleaning invalid values and clamping.

    NaN becomes 0, +Inf becomes 1 and -Inf becomes 0 (with a printed warning).
    Values are then clamped to [0, 1] and rounded to uint8.

    Args:
        tensor_img: torch.Tensor of shape (C, H, W), with values expected in [0,1].
        filename: str, path to save the image file.
    """
    img = tensor_img.detach().float().cpu()

    # Check before clamping: clamp maps +/-Inf onto the bounds, which would hide them.
    if not torch.isfinite(img).all():
        print(f"Warning: Image contains NaN or Inf values when saving {filename}")
        img = torch.nan_to_num(img, nan=0.0, posinf=1.0, neginf=0.0)

    img = img.clamp(0.0, 1.0)

    # Convert to HWC uint8, rounding rather than truncating (1.0 - eps -> 255, not 254).
    img_np = (img.permute(1, 2, 0).numpy() * 255.0).round().astype(np.uint8)
    if img_np.shape[-1] == 1:
        # Single-channel images are written as 2-D grayscale.
        img_np = img_np[..., 0]

    # Write image using imageio
    iio.imwrite(filename, img_np)


In [None]:
def train(json_path, image_size=512, N_gaussians=250000, epochs=1000, lr=1e-2, image_dir=None):
    """
    Fit Gaussian splats to the images described by a NeRF `transforms_*.json`.

    Pipeline:
    1. Run SfM on `image_dir`.
    2. Resample the sparse cloud to `N_gaussians` points.
    3. Place one Gaussian per point.
    4. Optimise xyz / color / radius with Adam on the per-image MSE.

    After the first epoch and every 10th epoch, the last rendered/target pair is
    written to render_<epoch>.png / image_<epoch>.png. The last render is
    written to final.png at the end.

    Args:
        json_path: path to a NeRF `transforms_*.json`.
        image_size: side length images are resized to and rendered at.
        N_gaussians: number of Gaussians.
        epochs: number of passes over the dataset.
        lr: Adam learning rate shared by all parameters.
        image_dir: directory of images to run SfM on. Defaults to the `train`
            folder next to `json_path`.

    Raises:
        ValueError: if the JSON lists no frames.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if image_dir is None:
        image_dir = os.path.join(os.path.dirname(json_path), "train")

    # 1. Run SfM or load existing SfM results (world-to-camera poses and intrinsics, keyed by filename)
    poses_w2c_dict, intrinsics_dict, points3D = sfm_extract(image_dir, device=device)

    # 2. Prepare fixed number of 3D points for gaussians
    points3D = get_points(points3D, N_gaussians)

    # 3. Initialize gaussians directly at SfM points (no position noise)
    gaussians = init_gaussians(points3D, device=device)

    # 4. Create dataset with SfM poses and intrinsics dicts
    dataset = NeRFDataset(
        json_path,
        image_size=image_size,
        device=device,
        sfm_poses=poses_w2c_dict,
        sfm_intrinsics=intrinsics_dict
    )
    if len(dataset) == 0:
        raise ValueError(f"No frames found in {json_path}")
    loader = DataLoader(dataset, batch_size=1, shuffle=True)

    optimizer = torch.optim.Adam(gaussians.values(), lr=lr)

    render = None  # stays None if epochs == 0
    pbar = tqdm(range(epochs), desc="Training")
    for epoch in pbar:
        running_loss = 0.0

        for image, pose_w2c, K in loader:
            # Drop the batch dimension (batch_size=1).
            image = image.squeeze(0)
            pose_w2c = pose_w2c.squeeze(0)
            K = K.squeeze(0)

            optimizer.zero_grad()

            render = render_gaussian_points(
                gaussians['xyz'], gaussians['color'], gaussians['radius'],
                pose_w2c, K, image_size
            )

            loss = F.mse_loss(render, image)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(loader)
        pbar.set_postfix(loss=f"{avg_loss:.4f}")

        if (epoch + 1) % 10 == 0 or epoch == 0:
            save_img(render, f"render_{epoch+1}.png")
            save_img(image, f"image_{epoch+1}.png")
            # tqdm.write keeps the progress bar intact.
            tqdm.write(f"Epoch {epoch+1} - Input image range: {image.min().item():.4f} to {image.max().item():.4f}")
            tqdm.write(f"Epoch {epoch+1} - Render image range: {render.min().item():.4f} to {render.max().item():.4f}")

    if render is not None:
        save_img(render, "final.png")


In [None]:
# Train on the NeRF synthetic "lego" training split with the default settings.
train(json_path="nerf_synthetic/lego/transforms_train.json")