# Online (Incremental) Gaussian Splatting [Pure PyTorch Edition]

This notebook implements an **Online / Incremental** 3D reconstruction pipeline using **Pure PyTorch**.

**Updates**:
- **Pure PyTorch Rasterizer**: Robust, needs no compiled CUDA extension (no `gsplat`), and works on any Colab runtime.
- **Point Dilation**: Renders 2x2 pixel blocks to make sparse points visible and improve gradients.
- **Debug Stats**: Logs render vs ground truth intensity to verify tracking.
- **Stability Fix**: Uses `torch.logit` with clamping to prevent NaN losses.
- **Dual Depth Support**: Toggle between AI Depth (UniDepth) and Sensor Depth (Kinect).
- **Metrics & Viz**: Added PSNR, Rolling Average Loss, and Error Heatmaps.
- **Edge Cleaning**: Added Median Blur to sensor depth to reduce ghosting.

**Pipeline**:
1.  **Process frames sequentially**.
2.  **Spawn New Gaussians** using UniDepth or Kinect.
3.  **Optimization** using a custom Python rasterizer (with dilation).
4.  **Live Visualization**.

**Hardware Requirement**: T4 GPU (free tier) or A100 (Pro) on Google Colab.

In [None]:
# 1. Setup Environment
# Notebook cell: lines starting with `!` run in a shell and lines starting
# with `%` are IPython magics, so this cell is not valid plain Python.
# Print GPU information.
!nvidia-smi

# Install core libs (No gsplat needed!)
!pip install torch torchvision torchaudio tqdm opencv-python matplotlib scipy pandas imageio plyfile plotly

# Install UniDepth
import os
import sys

# Clone only once so re-running the cell does not fail on an existing checkout.
if not os.path.exists("UniDepth"):
    !git clone https://github.com/lpiccinelli-eth/UniDepth.git

# Editable install from inside the checkout, then return to the parent directory.
%cd UniDepth
!pip install -e .
!pip install timm huggingface_hub
%cd ..

# CRITICAL: Add UniDepth to python path so it can be imported immediately
# NOTE(review): the absolute path assumes the clone above ran with /content as
# the working directory — confirm when running outside Colab.
sys.path.append("/content/UniDepth")

In [None]:
# 2. Download TUM Dataset (fr3_office)
# Notebook cell built from shell commands and `%cd` magics.
%cd /content
!mkdir -p datasets/tum
%cd datasets/tum

# Fixed URL using vision.in.tum.de
dataset_url = "https://vision.in.tum.de/rgbd/dataset/freiburg3/rgbd_dataset_freiburg3_long_office_household.tgz"
# `{dataset_url}` is IPython interpolation of the Python variable into the shell command.
!wget -O dataset.tgz {dataset_url}
!tar -xzf dataset.tgz
# Rename the extracted folder to the short name used by the later cells.
# NOTE(review): re-running this cell repeats the download, extraction and
# rename even when fr3_office is already present.
!mv rgbd_dataset_freiburg3_long_office_household fr3_office
%cd /content

## 3. Generate Metric Depth (UniDepth)
We pre-process the sequence to obtain clean metric depth maps that are used to initialize new Gaussians.
(A real-time C++ system would run depth inference live; in this Python Colab notebook we pre-compute and cache the maps for speed.)

In [None]:
import torch
import numpy as np
import os
import glob
from PIL import Image
from tqdm import tqdm
from unidepth.models import UniDepthV2

def generate_unidepth(dataset_path, output_depth_path, frame_stride=1):
    """Run UniDepth V2 over a sequence's RGB frames and cache depth as .npy.

    Every ``frame_stride``-th ``*.png`` in ``<dataset_path>/rgb`` (sorted by
    file name) is passed to the model and the predicted depth map is written
    to ``output_depth_path`` under the same base name with a ``.npy``
    extension.

    Args:
        dataset_path: Sequence root that contains an ``rgb`` sub-directory.
        output_depth_path: Directory for the ``.npy`` files; created if missing.
        frame_stride: Process every n-th frame. The default of 1 processes
            all frames, which matches the previous behaviour.

    Raises:
        ValueError: If ``frame_stride`` is smaller than 1.
    """
    if frame_stride < 1:
        raise ValueError(f"frame_stride must be >= 1, got {frame_stride}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Loading UniDepth V2...")
    model = UniDepthV2.from_pretrained("lpiccinelli/unidepth-v2-vitl14").to(device).eval()

    rgb_path = os.path.join(dataset_path, "rgb")
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(output_depth_path, exist_ok=True)

    image_files = sorted(glob.glob(os.path.join(rgb_path, "*.png")))[::frame_stride]

    print(f"Processing {len(image_files)} frames...")
    with torch.no_grad():
        for img_path in tqdm(image_files):
            pil_img = Image.open(img_path).convert("RGB")
            # HWC image array -> 1xCxHxW tensor on the inference device.
            rgb_tensor = torch.from_numpy(np.array(pil_img)).permute(2, 0, 1).unsqueeze(0).to(device)

            predictions = model.infer(rgb_tensor)
            depth_map = predictions["depth"].squeeze().cpu().numpy()

            # Save as .npy for full float precision (easier for custom loader).
            # The name is derived exactly as associate_data derives it.
            basename = os.path.basename(img_path).replace(".png", ".npy")
            np.save(os.path.join(output_depth_path, basename), depth_map)

# Sequence root and the depth cache directory inside it.
dataset_root = "/content/datasets/tum/fr3_office"
depth_root = os.path.join(dataset_root, "depth_unidepth")
# Pre-compute the UniDepth maps for the whole sequence.
generate_unidepth(dataset_root, depth_root)

## 4. Custom TUM Dataset Loader & Pure PyTorch Rasterizer (Dilated)

In [None]:
import pandas as pd
from scipy.spatial.transform import Rotation
import torch.nn as nn
import torch.nn.functional as F

# --- Dataset Loader ---
def associate_data(root_dir, max_difference=0.02, offset=0.0):
    """Associate RGB frames with ground-truth poses and sensor depth images.

    Reads ``rgb.txt``, ``groundtruth.txt`` and ``depth.txt`` from ``root_dir``
    and, for each RGB timestamp, picks the closest pose and the closest depth
    entry. A frame is kept only when both are closer than ``max_difference``.

    Args:
        root_dir: Sequence directory containing the three index files.
        max_difference: Exclusive upper bound on the allowed timestamp gap.
        offset: Value added to each RGB timestamp before matching.

    Returns:
        A list of dicts in RGB-timestamp order with the keys ``rgb_path``,
        ``unidepth_path``, ``sensor_depth_path``, ``c2w`` (4x4 matrix built
        from the pose translation and quaternion) and ``timestamp``.
    """
    from bisect import bisect_left

    def read_file_list(filename):
        """Parse an index file into ``{timestamp: [remaining columns]}``."""
        entries = {}
        with open(filename) as handle:
            text = handle.read()
        for line in text.replace(",", " ").replace("\t", " ").split("\n"):
            if not line or line[0] == "#":
                continue
            fields = line.split()
            if len(fields) > 1:
                entries[float(fields[0])] = fields[1:]
        return entries

    def nearest(sorted_stamps, query):
        """Return the stamp closest to ``query`` (earlier one on ties), or None."""
        if not sorted_stamps:
            return None
        pos = bisect_left(sorted_stamps, query)
        # Only the neighbours on either side of the insertion point can be closest.
        candidates = sorted_stamps[max(pos - 1, 0):pos + 1]
        return min(candidates, key=lambda stamp: abs(stamp - query))

    rgb_list = read_file_list(os.path.join(root_dir, "rgb.txt"))
    gt_list = read_file_list(os.path.join(root_dir, "groundtruth.txt"))
    depth_list = read_file_list(os.path.join(root_dir, "depth.txt"))  # Sensor depth index

    gt_timestamps = sorted(gt_list)
    dep_timestamps = sorted(depth_list)

    matches = []
    for t in sorted(rgb_list):
        t_with_offset = t + offset
        best_gt_t = nearest(gt_timestamps, t_with_offset)
        best_dep_t = nearest(dep_timestamps, t_with_offset)
        if best_gt_t is None or best_dep_t is None:
            continue
        if (abs(best_gt_t - t_with_offset) < max_difference
                and abs(best_dep_t - t_with_offset) < max_difference):
            matches.append((t, best_gt_t, best_dep_t))

    data = []
    for t_rgb, t_gt, t_dep in matches:
        rgb_file = rgb_list[t_rgb][0]

        # AI depth path: same base name as the RGB frame, stored as .npy.
        unidepth_file_name = os.path.basename(rgb_file).replace(".png", ".npy")
        unidepth_path = os.path.join(root_dir, "depth_unidepth", unidepth_file_name)

        # Sensor depth path, relative to the sequence root.
        sensor_depth_path = os.path.join(root_dir, depth_list[t_dep][0])

        # Pose columns are read as tx ty tz qx qy qz qw.
        tx, ty, tz, qx, qy, qz, qw = (float(v) for v in gt_list[t_gt][:7])

        c2w = np.eye(4)
        c2w[:3, :3] = Rotation.from_quat([qx, qy, qz, qw]).as_matrix()
        c2w[:3, 3] = [tx, ty, tz]

        data.append({
            "rgb_path": os.path.join(root_dir, rgb_file),
            "unidepth_path": unidepth_path,
            "sensor_depth_path": sensor_depth_path,
            "c2w": c2w,
            "timestamp": t_rgb
        })
    return data

# Build the per-frame records (paths, pose, timestamp) used by the online loop.
dataset_data = associate_data(dataset_root)
print(f"Associated {len(dataset_data)} frames (RGB + GT + SensorDepth).")

# --- Pure PyTorch Rasterizer (with Dilation) ---
def pure_pytorch_rasterization(means, colors, opacities, scales, quats, viewmat, K, height, width):
    """Render points as opaque 2x2 pixel blocks; the nearest point wins a pixel.

    Points are transformed with ``viewmat``, projected with the pinhole matrix
    ``K`` and painted as a 2x2 block whose top-left pixel is the floored
    projection. When several blocks cover the same pixel, the one with the
    smallest camera-space depth is kept. Exactly one source row is selected
    per pixel before writing, because ``index_copy_`` with duplicate indices
    is nondeterministic and its backward pass would also hand pixel gradients
    to occluded points.

    Args:
        means: (N, 3) point positions in world space.
        colors: (N, 3) per-point colors; gradients flow to the visible ones.
        opacities, scales, quats: Accepted for interface compatibility; unused.
        viewmat: World-to-camera matrix; the top-left 3x3 is used as rotation
            and the last column as translation.
        K: 3x3 intrinsics; ``K[0,0]``, ``K[1,1]``, ``K[0,2]``, ``K[1,2]`` are read.
        height, width: Output image size in pixels.

    Returns:
        A (height, width, 3) image; pixels not covered by any point are zero.
        Pixel positions are integer-quantized, so no gradient reaches ``means``.
    """
    def _blank():
        return torch.zeros((height, width, 3), device=means.device, dtype=colors.dtype)

    # 1. World -> Camera
    R = viewmat[:3, :3]
    t = viewmat[:3, 3]
    means_c = (R @ means.T).T + t

    in_front = means_c[:, 2] > 0.1  # Near plane
    points = means_c[in_front]
    colors = colors[in_front]
    if points.shape[0] == 0:
        return _blank()

    # 2. Project to Screen (Perspective)
    fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
    z = points[:, 2]
    x = points[:, 0] * fx / z + cx
    y = points[:, 1] * fy / z + cy

    # Coarse screen-bounds cull with a 20 px margin.
    on_screen = (x >= -20) & (x < width + 20) & (y >= -20) & (y < height + 20)
    x, y, z, colors = x[on_screen], y[on_screen], z[on_screen], colors[on_screen]
    if x.shape[0] == 0:
        return _blank()

    # 3. Splatting with 2x2 Dilation
    # floor (not truncation) so coordinates in (-1, 0) map to pixel -1, not 0.
    ix_base = torch.floor(x).long()
    iy_base = torch.floor(y).long()

    # Offsets: (0,0), (1,0), (0,1), (1,1)
    off_x = torch.tensor([0, 1, 0, 1], device=means.device)
    off_y = torch.tensor([0, 0, 1, 1], device=means.device)

    ix = (ix_base.unsqueeze(1) + off_x.unsqueeze(0)).reshape(-1)   # [N*4]
    iy = (iy_base.unsqueeze(1) + off_y.unsqueeze(0)).reshape(-1)   # [N*4]
    colors_rep = colors.unsqueeze(1).expand(-1, 4, -1).reshape(-1, 3)  # [N*4, 3]
    z_rep = z.detach().unsqueeze(1).expand(-1, 4).reshape(-1)      # [N*4]

    # Valid check after expansion
    valid = (ix >= 0) & (ix < width) & (iy >= 0) & (iy < height)
    pixel = (iy * width + ix)[valid]
    colors_rep = colors_rep[valid]
    z_rep = z_rep[valid]
    if pixel.shape[0] == 0:
        return _blank()

    # 4. Depth resolve: order by depth, then stably by pixel, so the first
    # entry of every pixel group is that pixel's nearest point.
    _, by_depth = torch.sort(z_rep, stable=True)
    pixel_sorted, by_pixel = torch.sort(pixel[by_depth], stable=True)
    order = by_depth[by_pixel]
    is_first = torch.ones_like(pixel_sorted, dtype=torch.bool)
    is_first[1:] = pixel_sorted[1:] != pixel_sorted[:-1]
    winners = order[is_first]

    # Paint: indices are unique now, so the copy is deterministic.
    canvas_flat = torch.zeros((height * width, 3), device=means.device, dtype=colors.dtype)
    canvas_flat.index_copy_(0, pixel[winners], colors_rep[winners])
    return canvas_flat.reshape(height, width, 3)

# --- Gaussian Model ---
class SimpleGaussianModel(nn.Module):
    """Growable container of per-point parameters rendered by the rasterizer.

    Stores means, log-scales, quaternions, opacity logits and color logits as
    ``nn.Parameter`` tensors that start empty and grow through
    :meth:`add_gaussians`.

    Args:
        device: Device for the initial (empty) parameters. ``None`` selects
            CUDA when it is available and falls back to CPU otherwise.
    """

    def __init__(self, device=None):
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.means = nn.Parameter(torch.empty(0, 3, device=device))
        self.scales = nn.Parameter(torch.empty(0, 3, device=device))
        self.quats = nn.Parameter(torch.empty(0, 4, device=device))
        self.opacities = nn.Parameter(torch.empty(0, 1, device=device))
        self.colors = nn.Parameter(torch.empty(0, 3, device=device))

    def add_gaussians(self, new_means, new_colors):
        """Append points to the model.

        Every parameter is replaced by a new, larger ``nn.Parameter``, so any
        optimizer holding the old tensors must be rebuilt afterwards.

        Args:
            new_means: (N, 3) positions.
            new_colors: (N, 3) colors in [0, 1]; stored as logits.

        Raises:
            ValueError: If the two inputs have a different number of rows.
        """
        with torch.no_grad():
            N_new = new_means.shape[0]
            if N_new == 0:
                return
            if new_colors.shape[0] != N_new:
                raise ValueError(
                    f"new_means has {N_new} rows but new_colors has {new_colors.shape[0]}"
                )

            # Follow the model's current device/dtype rather than assuming CUDA.
            device, dtype = self.means.device, self.means.dtype
            new_means = new_means.to(device=device, dtype=dtype)
            new_colors = new_colors.to(device=device, dtype=dtype)

            new_scales = torch.full((N_new, 3), -5.0, device=device, dtype=dtype)
            new_quats = torch.zeros(N_new, 4, device=device, dtype=dtype)
            new_quats[:, 0] = 1.0
            new_opacities = torch.zeros(N_new, 1, device=device, dtype=dtype)

            # STABLE INIT: Clamp colors to avoid log(0) -> NaN
            new_colors = torch.logit(torch.clamp(new_colors, 0.01, 0.99))  # inverse sigmoid

            self.means = nn.Parameter(torch.cat([self.means, new_means]))
            self.scales = nn.Parameter(torch.cat([self.scales, new_scales]))
            self.quats = nn.Parameter(torch.cat([self.quats, new_quats]))
            self.opacities = nn.Parameter(torch.cat([self.opacities, new_opacities]))
            self.colors = nn.Parameter(torch.cat([self.colors, new_colors]))

    def forward(self, viewmat, K, height, width):
        """Render the stored points and return the (height, width, 3) image."""
        render_colors = torch.sigmoid(self.colors)
        opacities = torch.sigmoid(self.opacities)

        rgb = pure_pytorch_rasterization(
            self.means,
            render_colors,
            opacities,
            torch.exp(self.scales),
            self.quats,
            viewmat,
            K,
            height, width
        )
        return rgb

## 5. Online Reconstruction (The Main Loop)
We iterate through frames, adding new Gaussians (Keyframing) and optimizing immediately.

In [None]:
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

# Setup
# Empty model on the GPU; points are added inside the online loop.
model = SimpleGaussianModel().cuda()

# Optimizer config
lr_means = 0.00016
lr_colors = 0.0025

# Only means and colors are optimized; scales, quats and opacities are not
# registered with the optimizer. The top-level lr=0.0 is just the fallback
# default, since every group sets its own learning rate.
optimizer = torch.optim.Adam([
    {'params': [model.means], 'lr': lr_means},
    {'params': [model.colors], 'lr': lr_colors},
], lr=0.0)

# Camera Intrinsics
# Hard-coded image size and pinhole parameters.
# NOTE(review): presumably the calibration of the downloaded freiburg3
# sequence — confirm before using another dataset.
H, W = 480, 640
fx, fy, cx, cy = 535.4, 539.2, 320.1, 247.6
K_mat = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
K_torch = torch.tensor(K_mat, device="cuda", dtype=torch.float32)

# Params
DEPTH_MODE = "sensor" # Options: "unidepth" (AI) or "sensor" (Kinect)
KEYFRAME_EVERY = 5 # Add new Gaussians every X frames
ITERS_PER_FRAME = 5 # Optimize X steps per frame
VIZ_EVERY = 10 # Update plot every X frames


def spawn_gaussians_from_frame(frame_data, mode="sensor", subsample=4):
    """Back-project one RGB-D frame into colored world-space points.

    Args:
        frame_data: Record with ``rgb_path``, ``unidepth_path``,
            ``sensor_depth_path`` and a 4x4 ``c2w`` matrix.
        mode: ``"unidepth"`` loads the cached ``.npy`` depth; any other value
            loads the sensor depth PNG.
        subsample: Keep every ``subsample``-th pixel along both image axes.

    Returns:
        ``(xyz_world, colors)`` as float32 CUDA tensors of shape (M, 3),
        restricted to pixels whose depth lies strictly between 0.1 and 8.0.

    Uses the module-level ``H``, ``W``, ``fx``, ``fy``, ``cx`` and ``cy``.
    """
    # Load Image
    img = np.array(Image.open(frame_data['rgb_path']).convert("RGB")) / 255.0

    # Load Depth based on mode
    if mode == "unidepth":
        # Metric prediction (float32)
        depth = np.load(frame_data['unidepth_path'])
    else:
        # Imported locally: this notebook never imports cv2, so the previous
        # cv2.medianBlur call raised NameError in the default "sensor" mode.
        from scipy.ndimage import median_filter

        # Sensor depth (PNG uint16, scale=5000) from Kinect
        depth_png = np.array(Image.open(frame_data['sensor_depth_path']))

        # 5x5 median filter with replicated borders to remove "flying pixels"
        # and ghost edges (the "duplicated cup" artifact).
        depth_png = median_filter(depth_png, size=5, mode="nearest")

        depth = depth_png.astype(np.float32) / 5000.0
        # Missing values (0) -> negative, so the range mask below drops them.
        depth[depth == 0] = -1.0

    # Backproject
    ys, xs = np.indices((H, W))

    # Subsample
    ys, xs = ys[::subsample, ::subsample], xs[::subsample, ::subsample]
    z = depth[::subsample, ::subsample]
    img_small = img[::subsample, ::subsample]

    x = (xs - cx) * z / fx
    y = (ys - cy) * z / fy
    xyz_cam = np.stack([x, y, z], axis=-1)

    # Transform to World
    c2w = frame_data['c2w']
    xyz_world = (c2w[:3, :3] @ xyz_cam.reshape(-1, 3).T).T + c2w[:3, 3]
    colors = img_small.reshape(-1, 3)

    # Filter invalid (Close, Far, or Missing)
    z_flat = z.reshape(-1)
    mask = (z_flat > 0.1) & (z_flat < 8.0)

    return torch.tensor(xyz_world[mask], dtype=torch.float32, device="cuda"), \
           torch.tensor(colors[mask], dtype=torch.float32, device="cuda")


# --- Online Loop ---
limit_frames = 300 # Process first 300 frames for demo speed
process_data = dataset_data[:limit_frames]

# Metrics
# One entry is appended per processed frame; the loop averages the last 20.
loss_history = []
psnr_history = []

def get_psnr(pred, gt):
    """Return the PSNR in dB between two images, using a peak value of 1.

    The mean squared error is clamped to at least 1e-10, so identical inputs
    give a finite 100 dB instead of ``inf``; an infinite value would otherwise
    make every rolling average that includes it infinite.

    Args:
        pred: Predicted image tensor.
        gt: Reference tensor, broadcast-compatible with ``pred``.

    Returns:
        A 0-dim tensor holding the PSNR.
    """
    mse = torch.mean((pred - gt) ** 2)
    return -10 * torch.log10(torch.clamp(mse, min=1e-10))

# Online reconstruction: for each frame, optionally spawn new points, take a
# few optimization steps against that frame, record metrics and plot.
print(f"Starting Online Reconstruction with DEPTH_MODE='{DEPTH_MODE}'...")
for i, frame in enumerate(process_data):

    # 1. Spawn New Gaussians (Keyframing)
    if i % KEYFRAME_EVERY == 0:
        new_means, new_colors = spawn_gaussians_from_frame(frame, mode=DEPTH_MODE, subsample=8)
        model.add_gaussians(new_means, new_colors)

        # add_gaussians replaces the parameter tensors, so the optimizer is
        # rebuilt around the new ones (this also resets the Adam state).
        optimizer = torch.optim.Adam([
            {'params': [model.means], 'lr': lr_means},
            {'params': [model.colors], 'lr': lr_colors},
        ], lr=0.0)

    # 2. Prepare Data for Optimization
    gt_rgb = torch.tensor(np.array(Image.open(frame['rgb_path']).convert("RGB")) / 255.0, dtype=torch.float32, device="cuda")
    gt_c2w = torch.tensor(frame['c2w'], dtype=torch.float32, device="cuda")
    w2c = torch.inverse(gt_c2w)

    # 3. Incremental Optimization (L1 photometric loss)
    for _ in range(ITERS_PER_FRAME):
        render_rgb = model(w2c, K_torch, H, W)

        loss = torch.abs(render_rgb - gt_rgb).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Track Metrics (from the last render, i.e. before the final step's update)
    with torch.no_grad():
        curr_psnr = get_psnr(render_rgb, gt_rgb).item()
        loss_history.append(loss.item())
        psnr_history.append(curr_psnr)
        avg_loss = np.mean(loss_history[-20:]) # Mean of last 20 frames
        avg_psnr = np.mean(psnr_history[-20:])

    # 4. Live Visualization
    if i % VIZ_EVERY == 0:
        render_np = render_rgb.detach().cpu().numpy()
        gt_np = gt_rgb.detach().cpu().numpy()
        error_np = np.abs(render_np - gt_np).mean(axis=2) # Error Map

        # Clear first: printing before clear_output(wait=True) meant the stats
        # line was erased as soon as the figure was displayed.
        clear_output(wait=True)

        # Print stats to verify valid rendering
        print(f"Frame {i}/{len(process_data)} | Loss: {loss.item():.4f} | Avg Loss (20): {avg_loss:.4f} | PSNR: {curr_psnr:.2f} | Avg PSNR (20): {avg_psnr:.2f} | Points: {model.means.shape[0]}")

        fig, ax = plt.subplots(1, 3, figsize=(15, 5))
        ax[0].imshow(render_np); ax[0].set_title(f"Render (PSNR: {curr_psnr:.2f})")
        ax[1].imshow(gt_np); ax[1].set_title("Ground Truth")
        im = ax[2].imshow(error_np, cmap='jet', vmin=0, vmax=0.5); ax[2].set_title(f"Error Residual (Avg: {avg_loss:.4f})")
        plt.colorbar(im, ax=ax[2])
        plt.show()
        # Release the figure so repeated plotting does not accumulate open figures.
        plt.close(fig)

## 6. Save Final Map

In [None]:
from plyfile import PlyData, PlyElement

# Export the optimized points (positions + 8-bit colors) as a PLY point cloud.
means = model.means.detach().cpu().numpy()
colors = torch.sigmoid(model.colors).detach().cpu().numpy()

# Create structured array. Fields are filled column-wise instead of building
# one Python tuple per point, which is slow for large clouds.
vertex = np.empty(len(means),
                  dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])
vertex['x'] = means[:, 0]
vertex['y'] = means[:, 1]
vertex['z'] = means[:, 2]

# Sigmoid output is below 1, so the truncating cast stays within 0..254.
colors_u8 = (colors * 255).astype(np.uint8)
vertex['red'] = colors_u8[:, 0]
vertex['green'] = colors_u8[:, 1]
vertex['blue'] = colors_u8[:, 2]

el = PlyElement.describe(vertex, 'vertex')
PlyData([el]).write('online_tum_reconstruction.ply')
print("Saved online_tum_reconstruction.ply")

## 7. Interactive Visualization (Plotly)
Visualize the resulting Point Cloud directly in the notebook.

In [None]:
import plotly.graph_objects as go
import numpy as np
from plyfile import PlyData

def visualize_ply(ply_path, subsample=10):
    """Show a colored PLY point cloud as an interactive Plotly 3D scatter.

    Args:
        ply_path: Path of a PLY file whose ``vertex`` element provides
            ``x``, ``y``, ``z``, ``red``, ``green`` and ``blue``.
        subsample: Keep every ``subsample``-th vertex to limit the number of
            markers sent to the browser.
    """
    print("Loading PLY...")
    cloud = PlyData.read(ply_path)['vertex']

    # Thin the cloud before doing anything else.
    stride = slice(None, None, subsample)
    px = cloud['x'][stride]
    py = cloud['y'][stride]
    pz = cloud['z'][stride]
    rgb = np.stack(
        [cloud['red'][stride], cloud['green'][stride], cloud['blue'][stride]],
        axis=-1,
    ) / 255.0

    # Outlier removal: keep points inside the 5th-95th percentile band on
    # every axis, so stray points do not distort the plot scale.
    def within_central_band(values):
        low = np.percentile(values, 5)
        high = np.percentile(values, 95)
        return (values >= low) & (values <= high)

    keep = within_central_band(px) & within_central_band(py) & within_central_band(pz)
    px = px[keep]
    py = py[keep]
    pz = pz[keep]
    rgb = rgb[keep]

    print(f"Visualizing {len(px)} points (Subsample={subsample}, Outliers Removed)...")

    marker_style = dict(size=2, color=rgb, opacity=0.8)
    # Y and Z are swapped for the usual viewing orientation.
    cloud_trace = go.Scatter3d(x=px, y=pz, z=py, mode='markers', marker=marker_style)
    fig = go.Figure(data=[cloud_trace])

    scene_cfg = dict(
        aspectmode='data',  # 1:1:1 scale based on the data range
        camera=dict(eye=dict(x=-1.5, y=-1.5, z=0.5)),
    )
    fig.update_layout(scene=scene_cfg, margin=dict(l=0, r=0, b=0, t=0))
    fig.show()

# Display the saved reconstruction, keeping every 50th point.
visualize_ply('online_tum_reconstruction.ply', subsample=50)