In [1]:
import torch
import torch.nn as nn
import numpy as np

import os
import imageio
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

#from dataset import get_rays
from rendering import rendering
from model import Voxels, Nerf
from ml_helpers import training

from skimage.measure import marching_cubes
import trimesh
import trimesh.smoothing
import os

In [2]:
# Report the installed PyTorch build alongside the chosen device.
print("torch version: ", torch.__version__)

# Use the MPS backend when PyTorch reports it as available, otherwise the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

print("device: ", device)

torch version:  2.5.1
device:  mps


In [3]:
import torch
import numpy as np
import os
import imageio
import struct

def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
    """Read ``num_bytes`` from ``fid`` and unpack them with ``struct``.

    Args:
        fid: Binary file-like object exposing ``read``.
        num_bytes: Number of bytes to request from ``fid``.
        format_char_sequence: ``struct`` format characters (without byte order).
        endian_character: ``struct`` byte-order prefix; defaults to little-endian.

    Returns:
        tuple | None: The unpacked values, or ``None`` when nothing could be
        read or the bytes do not match the requested format (e.g. short read).
    """
    raw = fid.read(num_bytes)
    if raw:
        layout = endian_character + format_char_sequence
        try:
            return struct.unpack(layout, raw)
        except struct.error:
            # Wrong length for the format: reported as "no data" below.
            pass
    return None

def read_cameras_binary(path_to_model_file):
    """Parse a COLMAP ``cameras.bin`` file.

    File layout (little-endian): a uint64 camera count, then for every camera
    an int32 camera id, an int32 model id, a uint64 width and a uint64 height,
    immediately followed by the model's intrinsics as float64 values. The
    number of intrinsics is NOT stored in the file; it is implied by the model
    id, so it is looked up in a table here.

    Args:
        path_to_model_file: Path to ``cameras.bin``.

    Returns:
        dict: ``camera_id -> {"width", "height", "params", "model_id"}`` where
        ``params`` is the tuple of float intrinsics (first entry is the focal
        length for every model in the table below).

    Raises:
        ValueError: If the camera-count header cannot be read.
    """
    # Number of float64 intrinsics per COLMAP camera model id.
    num_params_by_model = {
        0: 3,    # SIMPLE_PINHOLE: f, cx, cy
        1: 4,    # PINHOLE: fx, fy, cx, cy
        2: 4,    # SIMPLE_RADIAL: f, cx, cy, k
        3: 5,    # RADIAL: f, cx, cy, k1, k2
        4: 8,    # OPENCV
        5: 8,    # OPENCV_FISHEYE
        6: 12,   # FULL_OPENCV
        7: 5,    # FOV
        8: 4,    # SIMPLE_RADIAL_FISHEYE
        9: 5,    # RADIAL_FISHEYE
        10: 12,  # THIN_PRISM_FISHEYE
    }

    cameras = {}
    with open(path_to_model_file, "rb") as fid:
        header = read_next_bytes(fid, 8, "Q")
        if header is None:
            raise ValueError(f"Cannot read the camera count from {path_to_model_file}")

        for _ in range(header[0]):
            camera_properties = read_next_bytes(fid, 24, "iiQQ")
            if camera_properties is None:
                print(f"Warning: {path_to_model_file} is truncated; stopping after {len(cameras)} cameras")
                break
            camera_id, model_id, width, height = camera_properties

            num_params = num_params_by_model.get(model_id)
            if num_params is None:
                # Without the parameter count the record length is unknown, so the
                # stream cannot be re-synchronised: keep what was read so far.
                print(f"Warning: unknown camera model id {model_id} for camera {camera_id}; "
                      f"stopping after {len(cameras)} cameras")
                break

            params = read_next_bytes(fid, 8 * num_params, "d" * num_params)
            if params is None:
                print(f"Warning: {path_to_model_file} is truncated; stopping after {len(cameras)} cameras")
                break

            cameras[camera_id] = {
                "width": width,
                "height": height,
                "params": params,
                "model_id": model_id,
            }
    return cameras

def read_images_binary(path_to_model_file, chunk_size=100):
    """Stream the image records of a COLMAP ``images.bin`` file in chunks.

    Record layout (little-endian): int32 image id, four float64 quaternion
    components, three float64 translation components, int32 camera id, a
    NUL-terminated image name, a uint64 count of 2D points and then 24 bytes
    per 2D point (skipped here).

    Args:
        path_to_model_file: Path to ``images.bin``.
        chunk_size: Maximum number of images per yielded dictionary.

    Yields:
        dict: ``image_id -> {"qvec", "tvec", "camera_id", "name"}`` with
        ``qvec``/``tvec`` as float32 NumPy arrays.

    Raises:
        ValueError: If the image-count header cannot be read.
    """
    images = {}
    with open(path_to_model_file, "rb") as fid:
        header = read_next_bytes(fid, 8, "Q")
        if header is None:
            raise ValueError(f"Cannot read the image count from {path_to_model_file}")

        for _ in range(header[0]):
            binary_image_properties = read_next_bytes(fid, 64, "idddddddi")
            if binary_image_properties is None:
                # Nothing sensible can follow a short read; stop instead of spinning.
                print(f"Warning: {path_to_model_file} is truncated; stopping early")
                break

            image_id = binary_image_properties[0]
            qvec = np.array(binary_image_properties[1:5], dtype=np.float32)
            tvec = np.array(binary_image_properties[5:8], dtype=np.float32)
            camera_id = binary_image_properties[8]

            # Collect the raw name bytes first: decoding byte-by-byte would fail
            # on any multi-byte UTF-8 character.
            name_bytes = bytearray()
            while True:
                current_byte = fid.read(1)
                if not current_byte or current_byte == b"\x00":
                    break
                name_bytes += current_byte

            # Skip points2D (x: float64, y: float64, point3D_id: int64 = 24 bytes each).
            num_points2D = read_next_bytes(fid, 8, "Q")
            if num_points2D is None:
                print(f"Warning: {path_to_model_file} is truncated; stopping early")
                break
            fid.seek(24 * num_points2D[0], 1)

            # Decode only after the record has been consumed so a bad name does
            # not leave the stream in the middle of a record.
            try:
                image_name = name_bytes.decode("utf-8")
            except UnicodeDecodeError:
                print(f"Warning: skipping image {image_id}: its name is not valid UTF-8")
                continue

            images[image_id] = {
                "qvec": qvec,
                "tvec": tvec,
                "camera_id": camera_id,
                "name": image_name,
            }

            if len(images) >= chunk_size:
                yield images
                images = {}

        if images:
            yield images

def qvec2rotmat(qvec):
    """Turn a quaternion into the equivalent 3x3 rotation matrix.

    Args:
        qvec: Indexable of at least four numbers ordered scalar-first,
            i.e. ``(w, x, y, z)``.

    Returns:
        np.ndarray: ``(3, 3)`` float32 rotation matrix.
    """
    w, x, y, z = qvec[0], qvec[1], qvec[2], qvec[3]

    first_row = [1 - 2 * y**2 - 2 * z**2,
                 2 * x * y - 2 * w * z,
                 2 * z * x + 2 * w * y]
    second_row = [2 * x * y + 2 * w * z,
                  1 - 2 * x**2 - 2 * z**2,
                  2 * y * z - 2 * w * x]
    third_row = [2 * z * x - 2 * w * y,
                 2 * y * z + 2 * w * x,
                 1 - 2 * x**2 - 2 * y**2]

    return np.array([first_row, second_row, third_row], dtype=np.float32)

def process_image_batch(image_batch, cameras, datapath, batch_size=2):
    """Generate world-space rays and target colours for COLMAP-posed images.

    Args:
        image_batch: List of ``(image_id, image_data)`` pairs where
            ``image_data`` has ``"qvec"``, ``"tvec"``, ``"camera_id"`` and ``"name"``.
        cameras: Mapping ``camera_id -> {"width", "height", "params"}``;
            ``params[0]`` is used as the focal length in pixels.
        datapath: Directory the image names are joined onto.
        batch_size: Number of images grouped into each yielded tuple.

    Yields:
        tuple: ``(rays_o, rays_d, pixels)`` arrays of shape ``(b, H*W, 3)``
        for the ``b`` images of the group that could be loaded. Images that
        cannot be loaded are skipped with a printed reason.
    """
    print(f"Processing {len(image_batch)} images")
    if not image_batch:
        # Nothing to do (and no first element to report below).
        return
    print(f"First image path: {os.path.join(datapath, image_batch[0][1]['name'])}")

    for idx in range(0, len(image_batch), batch_size):
        batch = image_batch[idx:idx + batch_size]
        batch_rays_o = []
        batch_rays_d = []
        batch_pixels = []

        for image_id, image_data in batch:
            img_path = os.path.join(datapath, image_data['name'])
            try:
                camera = cameras[image_data['camera_id']]
                H, W = int(camera['height']), int(camera['width'])
                f = float(camera['params'][0])

                img = imageio.imread(img_path).astype(np.float32) / 255.0
                if img.ndim == 2:
                    # Grayscale: replicate the single channel to RGB.
                    img = np.repeat(img[..., None], 3, axis=-1)
                elif img.shape[-1] == 4:
                    # Composite RGBA onto a white background.
                    img = img[..., :3] * img[..., -1:] + (1 - img[..., -1:])
                if img.shape != (H, W, 3):
                    raise ValueError(
                        f"image shape {img.shape} does not match camera size ({H}, {W}, 3)")
            except Exception as err:
                # Best effort: skip the image, but say why so that an empty
                # result can be diagnosed (missing file, unknown camera, ...).
                print(f"Skipping image {image_id} ({img_path}): {err!r}")
                continue

            # COLMAP stores the world-to-camera transform x_cam = R @ x_world + t,
            # so the camera centre in world coordinates is -R^T t.
            R = qvec2rotmat(image_data['qvec'])
            t = image_data['tvec']
            C = -R.T @ t

            # COLMAP camera frame: +x right, +y down, +z forward (into the scene).
            # Pixel (u, v) therefore looks along (u - cx, v - cy, f); the principal
            # point is approximated by the image centre.
            u, v = np.meshgrid(np.arange(W), np.arange(H))
            x = (u - W/2).reshape(-1)
            y = (v - H/2).reshape(-1)
            z = np.ones_like(x) * f

            dirs = np.stack([x, y, z], axis=-1)
            dirs = (R.T @ dirs.T).T  # rotate camera-frame directions into the world frame
            dirs = dirs / np.linalg.norm(dirs, axis=-1, keepdims=True)

            batch_rays_o.append(np.broadcast_to(C, (H*W, 3)))
            batch_rays_d.append(dirs)
            batch_pixels.append(img.reshape(-1, 3))

        if batch_rays_o:
            yield (np.stack(batch_rays_o),
                  np.stack(batch_rays_d),
                  np.stack(batch_pixels))

def get_rays(datapath, mode='train', split_ratio=0.9):
    """Load a COLMAP reconstruction and return rays plus target pixel colours.

    The sparse model is read from ``<datapath>/colmap/sparse/0``. Images are
    sorted by name; the first ``split_ratio`` fraction forms the training
    split and the remainder the test split.

    Args:
        datapath: Dataset root directory.
        mode: ``'train'`` selects the training split; any other value selects
            the held-out split.
        split_ratio: Fraction of the (name-sorted) images used for training.

    Returns:
        tuple: ``(rays_o, rays_d, pixels)``. In train mode each array has
        shape ``(N, H, W, 3)``; otherwise ``(N, H*W, 3)``.

    Raises:
        ValueError: If no cameras or images are found, the selected split is
            empty, or none of the selected images could be processed.
    """
    sparse_dir = os.path.join(datapath, 'colmap', 'sparse', '0')

    # Read cameras data
    cameras = read_cameras_binary(os.path.join(sparse_dir, 'cameras.bin'))
    if not cameras:
        raise ValueError("No cameras found in the binary file")

    # Merge the streamed chunks into one mapping
    all_images = {}
    for chunk in read_images_binary(os.path.join(sparse_dir, 'images.bin')):
        all_images.update(chunk)
    if not all_images:
        raise ValueError("No images found in the binary file")

    image_list = sorted(all_images.items(), key=lambda x: x[1]['name'])

    # Split train/test
    split_idx = int(len(image_list) * split_ratio)
    if mode == 'train':
        image_list = image_list[:split_idx]
    else:
        image_list = image_list[split_idx:]
    if not image_list:
        # Small datasets can leave one side of the split empty; fail with a
        # clear message rather than an IndexError further down.
        raise ValueError(
            f"The '{mode}' split is empty ({len(all_images)} images, split_ratio={split_ratio})")

    # Process images in batches
    all_rays_o = []
    all_rays_d = []
    all_pixels = []

    for rays_o, rays_d, pixels in process_image_batch(image_list, cameras, datapath):
        all_rays_o.append(rays_o)
        all_rays_d.append(rays_d)
        all_pixels.append(pixels)

    if not all_rays_o:
        raise ValueError("No valid images processed")

    rays_o = np.concatenate(all_rays_o)
    rays_d = np.concatenate(all_rays_d)
    pixels = np.concatenate(all_pixels)

    if mode == 'train':
        # Get dimensions from the first image whose camera is known. This
        # assumes every processed image shares that resolution.
        camera = next(cameras[data['camera_id']]
                      for _, data in image_list if data['camera_id'] in cameras)
        H = camera['height']
        W = camera['width']
        rays_o = rays_o.reshape(-1, H, W, 3)
        rays_d = rays_d.reshape(-1, H, W, 3)
        pixels = pixels.reshape(-1, H, W, 3)

    return rays_o, rays_d, pixels

# Camera / Dataset

In [4]:
batch_size = 1024
o, d, target_px_values = get_rays('fox', mode='train')


def _make_ray_loader(rays_o, rays_d, pixels, batch_size, shuffle=True):
    """Pack origins, directions and colours into (N, 9) float rows and wrap them in a DataLoader."""
    rows = torch.cat(
        [torch.from_numpy(np.ascontiguousarray(arr).reshape(-1, 3)).type(torch.float)
         for arr in (rays_o, rays_d, pixels)],
        dim=1,
    )
    return DataLoader(rows, batch_size=batch_size, shuffle=shuffle)


# Main dataloader
dataloader = _make_ray_loader(o, d, target_px_values, batch_size)

# Warmup dataloader with center crop: keep the central half of every image.
# The bounds are derived from the image size (100:300 for 400x400 images) so
# the crop stays centred for any resolution.
_, img_h, img_w, _ = o.shape
crop_rows = slice(img_h // 4, img_h - img_h // 4)
crop_cols = slice(img_w // 4, img_w - img_w // 4)
o_warmup = o[:, crop_rows, crop_cols, :].reshape(-1, 3)
d_warmup = d[:, crop_rows, crop_cols, :].reshape(-1, 3)
target_warmup = target_px_values[:, crop_rows, crop_cols, :].reshape(-1, 3)

dataloader_warmup = _make_ray_loader(o_warmup, d_warmup, target_warmup, batch_size)

test_o, test_d, test_target_px_values = get_rays('fox', mode='test')

ValueError: No valid images processed

# Training

In [None]:
# Hyper-parameters handed to `training` (tn/tf are presumably the near/far
# sampling bounds along each ray — confirm against ml_helpers.training).
tn = 8.
tf = 12.
nb_epochs = 1  # 15 30
lr = 1e-3  # 1e-3 5e-4
gamma = .5  # 0.5 0.7
nb_bins = 100  # 100 256

model = Nerf(hidden_dim=256).to(device)  # Nerf(hidden_dim=128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10], gamma=gamma)

# First a single pass over the cropped warm-up rays, then nb_epochs over the
# full set; the loss curve of each phase is plotted right after it finishes.
for _phase_epochs, _phase_loader in ((1, dataloader_warmup), (nb_epochs, dataloader)):
    training_loss = training(model, optimizer, scheduler, tn, tf, nb_bins,
                             _phase_epochs, _phase_loader, device=device)
    plt.plot(training_loss)
    plt.show()

In [None]:
# Serialise the whole model object (not just a state_dict) to ./model_nerf_colmap.
torch.save(model, 'model_nerf_colmap')

# Mesh extraction

In [None]:
# The checkpoint is a fully pickled module, so it must be loaded with
# weights_only=False (newer PyTorch releases default to True and would refuse
# it). Unpickling executes code: only load checkpoints you created yourself.
# Loading onto the CPU first keeps the file usable on machines without the
# device it was saved from; .to(device) then moves it where it is needed.
model = torch.load('model_nerf_colmap', map_location="cpu", weights_only=False).to(device)

In [None]:
import torch
import numpy as np
from skimage.measure import marching_cubes
import trimesh
import torch.nn.functional as F

def analyze_density_field(density_volume):
    """Print summary statistics of a density grid and propose an iso-level.

    Args:
        density_volume: Array-like exposing ``min``, ``max``, ``mean`` and ``std``.

    Returns:
        float: Suggested threshold, computed as mean plus one standard deviation.
    """
    summary = {
        "Min": float(density_volume.min()),
        "Max": float(density_volume.max()),
        "Mean": float(density_volume.mean()),
        "Std": float(density_volume.std()),
    }

    print("Density field statistics:")
    for label, value in summary.items():
        print(f"{label}: {value:.6f}")

    # One standard deviation above the mean is the suggested cut-off.
    return summary["Mean"] + summary["Std"]

def extract_mesh(nerf_model, resolution=128, threshold=None, bbox_min=[-1.5, -1.5, -1.5], 
                bbox_max=[1.5, 1.5, 1.5], device=torch.device("cpu")):
    """
    Extract a colored mesh from a trained NeRF model.

    The density field is sampled on a regular grid inside the bounding box,
    an iso-surface is extracted with marching cubes, and the model is queried
    again at the mesh vertices for colours.

    Args:
        nerf_model: Trained NeRF model; called as ``nerf_model(points, dirs)``
            and assumed to return an ``(rgb, sigma)`` tuple.
        resolution: Number of grid samples per axis (must be >= 2)
        threshold: Density threshold for surface extraction (if None, will be auto-determined)
        bbox_min: Minimum corner of bounding box (read only, never mutated)
        bbox_max: Maximum corner of bounding box (read only, never mutated)
        device: Torch device to use

    Returns:
        trimesh.Trimesh: Colored mesh

    Raises:
        ValueError: If ``resolution`` is below 2, or re-raised from marching
            cubes (e.g. the threshold lies outside the sampled density range).
    """
    if resolution < 2:
        raise ValueError("resolution must be at least 2 to build a sampling grid")

    print(f"Creating density volume with resolution {resolution}...")

    # Per-axis sample coordinates; the 3D grid is built slab by slab below so
    # the full (resolution^3, 3) point set never has to exist at once.
    x = torch.linspace(bbox_min[0], bbox_max[0], resolution)
    y = torch.linspace(bbox_min[1], bbox_max[1], resolution)
    z = torch.linspace(bbox_min[2], bbox_max[2], resolution)

    # Create density volume
    density_volume = torch.zeros((resolution, resolution, resolution))
    chunk_size = 512 * 512  # Process in chunks to avoid OOM
    # Whole x-slices per model call, about chunk_size points (at least one slice).
    slab = max(1, chunk_size // (resolution * resolution))

    print("Sampling density field...")
    with torch.no_grad():
        for start in range(0, resolution, slab):
            xs = x[start:start + slab]
            xx, yy, zz = torch.meshgrid(xs, y, z, indexing='ij')
            chunk_points = torch.stack([xx, yy, zz], dim=-1).reshape(-1, 3).to(device)
            # Assume model returns (rgb, sigma) tuple
            _, chunk_densities = nerf_model(chunk_points, torch.zeros_like(chunk_points))
            # reshape() accepts both (N,) and (N, 1) density outputs.
            density_volume[start:start + slab] = chunk_densities.reshape(
                len(xs), resolution, resolution).cpu()

    # Auto-determine threshold if not provided
    if threshold is None:
        threshold = analyze_density_field(density_volume)
        print(f"Auto-determined threshold: {threshold:.6f}")

    print(f"Extracting mesh with threshold {threshold}...")

    # linspace places `resolution` samples across the box, so neighbouring
    # samples are (max - min) / (resolution - 1) apart.
    spacing = tuple((bbox_max[axis] - bbox_min[axis]) / (resolution - 1) for axis in range(3))

    try:
        # Extract mesh using marching cubes
        vertices, faces, normals, _ = marching_cubes(
            density_volume.numpy(),
            threshold,
            spacing=spacing
        )
    except ValueError as e:
        print("Error during marching cubes:")
        print(e)
        print("\nTry adjusting the threshold based on the density statistics above.")
        raise

    print(f"Mesh extracted with {len(vertices)} vertices and {len(faces)} faces")

    # Adjust vertices to match bbox (marching cubes starts at the origin)
    vertices = vertices + np.array(bbox_min)

    # Sample colors at vertex positions
    vertex_colors = torch.zeros((len(vertices), 3))
    vertices_tensor = torch.tensor(vertices, dtype=torch.float32)

    print("Sampling colors...")
    with torch.no_grad():
        for i in range(0, len(vertices), chunk_size):
            chunk_vertices = vertices_tensor[i:i+chunk_size].to(device)
            # Assume model returns (rgb, sigma) tuple
            chunk_colors, _ = nerf_model(chunk_vertices, torch.zeros_like(chunk_vertices))
            vertex_colors[i:i+chunk_size] = chunk_colors.cpu()

    # Clip before the uint8 cast so out-of-range values cannot wrap around.
    colors_uint8 = (np.clip(vertex_colors.numpy(), 0.0, 1.0) * 255).astype(np.uint8)

    # Create mesh with vertex colors
    mesh = trimesh.Trimesh(
        vertices=vertices,
        faces=faces,
        vertex_colors=colors_uint8,
        vertex_normals=normals
    )

    return mesh

def save_colored_mesh(nerf_model, output_path, resolution=256, threshold=None, device=torch.device("cpu"),
                      bbox_min=None, bbox_max=None):
    """
    Extract and save a colored mesh from a NeRF model.

    Args:
        nerf_model: Trained NeRF model
        output_path: Path to save the mesh (should end in .ply or .obj)
        resolution: Resolution for marching cubes
        threshold: Density threshold (if None, will be auto-determined)
        device: Torch device to use
        bbox_min: Optional minimum corner of the sampled bounding box; when
            None the default of ``extract_mesh`` is used
        bbox_max: Optional maximum corner of the sampled bounding box; when
            None the default of ``extract_mesh`` is used

    Returns:
        The processed mesh that was written to ``output_path``.
    """
    # Forward only the bounds the caller actually supplied so extract_mesh's
    # own defaults stay in effect otherwise.
    bbox_kwargs = {}
    if bbox_min is not None:
        bbox_kwargs["bbox_min"] = bbox_min
    if bbox_max is not None:
        bbox_kwargs["bbox_max"] = bbox_max

    mesh = extract_mesh(nerf_model, resolution=resolution, threshold=threshold,
                        device=device, **bbox_kwargs)

    print("Processing mesh...")
    # Optional mesh cleanup
    mesh = mesh.process(validate=True)

    print(f"Saving mesh to {output_path}...")
    # Save the mesh
    mesh.export(output_path)
    return mesh

# After loading your model
# Mesh export settings (run after the model has been loaded)
resolution = 700  # Samples per axis: raise for finer detail, lower if memory runs out
output_path = "nerf_mesh.obj"  # A .ply path can be used instead

# Build the mesh from the model and write it to output_path
mesh = save_colored_mesh(nerf_model=model, output_path=output_path,
                         resolution=resolution, device=device)

In [None]:
#!pip install Pymcubes
#!pip install trimesh
#!pip install -U scikit-image
#!pip install genesis-world  # Requires Python >=3.9;
#!pip uninstall genesis-world
#!conda install -c anaconda trimesh