In [1]:
#!pip install Pymcubes
#!pip install trimesh


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
import mcubes
import trimesh

In [3]:
import torch
import torch.nn as nn
import numpy as np

import os
import imageio
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from dataset import get_rays
from rendering import rendering
from model import Voxels, Nerf
from ml_helpers import training

In [4]:
# Inference configuration. Use the GPU when present, otherwise fall back to CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Presumably the near/far ray bounds used by `rendering`; not used by the cells shown here.
tn = 8.
tf = 12.
# The checkpoint holds a whole pickled nn.Module (not a state_dict), so weights_only=False is
# required on torch >= 2.6 where weights_only defaults to True.
# SECURITY: unpickling can execute arbitrary code -- only load checkpoints you trust.
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
model = torch.load('model_nerf', map_location=device, weights_only=False).to(device)
# Inference only: disable any training-time behaviour (e.g. dropout) the model may have.
model.eval()

In [5]:
# Regular N x N x N sampling grid over the cube [-scale, scale]^3.
N = 100
scale = 1.5

axis = torch.linspace(-scale, scale, N)

# indexing='ij' keeps (x, y, z) in array-axis order (x varies along dim 0). It is the same
# ordering torch used implicitly before, but stating it silences the meshgrid warning.
x, y, z = torch.meshgrid(axis, axis, axis, indexing='ij')

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [6]:
# Flatten the grid into an (N**3, 3) tensor of (x, y, z) query points. Rows follow the
# C-order of the grid, so a later reshape(N, N, N) restores the grid layout.
xyz = torch.stack((x, y, z), dim=-1).reshape(-1, 3)

In [7]:
# Evaluate the density field at every grid point. Points are fed to the model in chunks so
# memory stays bounded as N grows; view directions are zeros because only density is used.
DENSITY_CHUNK = 2 ** 18
density_parts = []
with torch.no_grad():
    for start in range(0, xyz.shape[0], DENSITY_CHUNK):
        pts = xyz[start:start + DENSITY_CHUNK].to(device)
        # Call the module (not .forward) so any registered hooks run.
        _, sigma_part = model(pts, torch.zeros_like(pts))
        density_parts.append(sigma_part.cpu())

density = torch.cat(density_parts, dim=0).numpy().reshape(N, N, N)

In [8]:
# Heuristic iso-level: 30x the mean sampled density.
iso_level = 30 * np.mean(density)
vertices, triangles = mcubes.marching_cubes(density, iso_level)

In [9]:
# mcubes returns vertices in voxel-index units (0 .. N-1 along each axis). Map them back onto
# the world-space grid sampled above (index i -> -scale + i * 2*scale/(N-1), as torch.linspace
# spaces N points), so the mesh shares the NeRF's coordinate frame. Dividing by N only
# produced an approximate [0, 1) normalisation.
mesh = trimesh.Trimesh(vertices * (2 * scale / (N - 1)) - scale, triangles)

In [10]:
# Open an interactive viewer on the extracted mesh.
mesh.show()

# Other methods: mesh extraction with scikit-image marching cubes

In [None]:
import torch
import numpy as np
from skimage.measure import marching_cubes
import trimesh
import torch.nn.functional as F

def extract_mesh(nerf_model, resolution=256, threshold=50.0, bbox_min=(-1.5, -1.5, -1.5),
                 bbox_max=(1.5, 1.5, 1.5), device=torch.device("cpu"), chunk_size=512 * 512):
    """
    Extract a colored mesh from a trained NeRF model.

    The density field is sampled on a regular ``resolution**3`` grid spanning the bounding
    box. A surface is extracted with scikit-image marching cubes at ``threshold``, and each
    vertex is colored by querying the model at that vertex with a zero view direction.

    Args:
        nerf_model: Trained NeRF model, called as ``nerf_model(points, directions)`` and
            expected to return an ``(rgb, sigma)`` tuple.
        resolution: Number of grid samples along each axis (must be >= 2).
        threshold: Density iso-level for surface extraction.
        bbox_min: Minimum (x, y, z) corner of the sampled bounding box.
        bbox_max: Maximum (x, y, z) corner of the sampled bounding box.
        device: Torch device used for model queries.
        chunk_size: Number of points sent to the model per forward pass.

    Returns:
        trimesh.Trimesh: Mesh in world coordinates with per-vertex colors and normals.

    Raises:
        ValueError: If ``resolution < 2``. ``marching_cubes`` also raises ValueError when
            ``threshold`` lies outside the range of sampled densities.
    """
    if resolution < 2:
        raise ValueError("resolution must be at least 2")
    bbox_min = np.asarray(bbox_min, dtype=np.float64)
    bbox_max = np.asarray(bbox_max, dtype=np.float64)

    # Build the grid on the CPU and move one chunk at a time to `device`. Materialising all
    # resolution**3 points on the GPU would cost ~1.5 GB at resolution 500.
    axes = [torch.linspace(float(lo), float(hi), resolution) for lo, hi in zip(bbox_min, bbox_max)]
    xx, yy, zz = torch.meshgrid(*axes, indexing='ij')
    points = torch.stack([xx, yy, zz], dim=-1).reshape(-1, 3)
    del xx, yy, zz

    density_flat = torch.zeros(points.shape[0])
    with torch.no_grad():
        for start in range(0, points.shape[0], chunk_size):
            chunk = points[start:start + chunk_size].to(device)
            _, sigma = nerf_model(chunk, torch.zeros_like(chunk))
            # Flatten so models returning (n,) or (n, 1) densities both fit the 1-D slot.
            density_flat[start:start + chunk_size] = sigma.reshape(-1).cpu()
    density_volume = density_flat.reshape(resolution, resolution, resolution).numpy()

    # `resolution` linspace samples span `resolution - 1` intervals, so the voxel size is
    # extent / (resolution - 1); dividing by `resolution` would shrink the mesh.
    spacing = tuple(float(s) for s in (bbox_max - bbox_min) / (resolution - 1))
    vertices, faces, normals, _ = marching_cubes(density_volume, threshold, spacing=spacing)

    # marching_cubes positions start at the grid origin; shift them into the bounding box.
    vertices = vertices + bbox_min

    # Color each vertex by querying the model at the vertex position.
    vertices_tensor = torch.as_tensor(vertices, dtype=torch.float32)
    vertex_colors = torch.zeros((len(vertices), 3))
    with torch.no_grad():
        for start in range(0, len(vertices), chunk_size):
            chunk = vertices_tensor[start:start + chunk_size].to(device)
            rgb, _ = nerf_model(chunk, torch.zeros_like(chunk))
            vertex_colors[start:start + chunk_size] = rgb.cpu()

    # Colors are scaled from [0, 1]; clamp first so out-of-range values saturate rather than
    # wrapping around in the uint8 cast.
    colors_u8 = (vertex_colors.clamp(0.0, 1.0).numpy() * 255).astype(np.uint8)

    return trimesh.Trimesh(
        vertices=vertices,
        faces=faces,
        vertex_colors=colors_u8,
        vertex_normals=normals,
    )

def save_colored_mesh(nerf_model, output_path, resolution=128, device=torch.device("cpu")):
    """
    Extract a colored mesh from a NeRF model, clean it up, and write it to disk.

    NOTE(review): `extract_mesh` is looked up when this function is called. A later cell
    redefines `extract_mesh` with a different signature and return value, so re-run this
    cell before calling this function again.

    Args:
        nerf_model: Trained NeRF model.
        output_path: Destination file; the extension (e.g. .ply or .obj) selects the format.
        resolution: Grid resolution per axis for marching cubes.
        device: Torch device used for model queries.

    Returns:
        trimesh.Trimesh: The processed mesh that was exported.
    """
    # trimesh's process(validate=True) performs the optional cleanup pass before saving.
    cleaned = extract_mesh(nerf_model, resolution=resolution, device=device).process(validate=True)
    cleaned.export(output_path)
    return cleaned

# Run after the model has been loaded.
output_path = "nerf_mesh.ply"  # .obj also works
resolution = 500  # higher gives a finer mesh; lower it if you run out of memory

# Extract the colored mesh and write it to `output_path`.
mesh = save_colored_mesh(
    model,
    output_path,
    resolution=resolution,
    device=device,
)

In [None]:
# Approach 2: uncolored mesh from a fixed density (sigma) threshold
def extract_mesh(model, device, bounds=(-1.2, 1.2), resolution=128, sigma_threshold=0.5):
    """
    Extract a mesh from a trained NeRF model using marching cubes on its density field.

    NOTE(review): this redefines the `extract_mesh` from the previous cell with a different
    signature and return type. `save_colored_mesh` resolves the name at call time and will
    break if called after this cell runs; consider renaming one of them.

    Args:
        model: Trained NeRF model, called as ``model(points, directions)``. The second
            element of its output is used as the density (sigma).
        device: Torch device for model queries (e.g. 'cpu', 'cuda', 'mps').
        bounds: (min, max) applied to all three axes (cubic volume).
        resolution: Number of grid samples per axis.
        sigma_threshold: Density iso-level for surface extraction. It must lie between the
            min and max sigma printed by this function.

    Returns:
        vertices: Marching-cubes vertex positions in world coordinates.
        faces: Triangle indices into ``vertices``.
        mesh: trimesh object built from them and Laplacian-smoothed once. Its vertices may
            differ from ``vertices`` after trimesh processing and smoothing.

    Raises:
        ValueError: Re-raised from marching cubes (typically a threshold outside the sigma range).
    """
    model.eval()

    lo, hi = bounds

    # Regular grid, flattened once to (resolution**3, 3) in C order.
    axis = np.linspace(lo, hi, resolution)
    xx, yy, zz = np.meshgrid(axis, axis, axis, indexing='ij')
    points = np.stack([xx, yy, zz], axis=-1).reshape(-1, 3)

    # Process in chunks to avoid memory issues
    chunk_size = 32768  # Adjust based on your GPU memory
    sigma = np.zeros(points.shape[0])

    with torch.no_grad():
        for start in range(0, points.shape[0], chunk_size):
            points_chunk = torch.as_tensor(
                points[start:start + chunk_size], dtype=torch.float32, device=device
            )
            directions = torch.zeros_like(points_chunk)  # dummy view directions
            output = model(points_chunk, directions)
            # Flatten so (n,) and (n, 1) density outputs both fit the 1-D slice; slicing
            # already handles the shorter final chunk.
            sigma[start:start + chunk_size] = output[1].reshape(-1).cpu().numpy()

    sigma = sigma.reshape(resolution, resolution, resolution)

    # Print sigma statistics for debugging
    print("Sigma statistics:")
    print(f"Min: {sigma.min():.6f}")
    print(f"Max: {sigma.max():.6f}")
    print(f"Mean: {sigma.mean():.6f}")
    print(f"Number of values above threshold: {np.sum(sigma > sigma_threshold)}")

    try:
        vertices, faces, _, _ = marching_cubes(sigma, sigma_threshold)
    except ValueError as e:
        print("\nError during marching cubes:", e)
        print("Try adjusting the sigma_threshold to a value between the min and max sigma values shown above.")
        raise

    # Grid index i corresponds to lo + i * (hi - lo) / (resolution - 1), since linspace
    # places `resolution` samples over `resolution - 1` intervals.
    voxel_size = (hi - lo) / (resolution - 1)
    vertices = vertices * voxel_size + lo

    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)

    # Optional: Smooth the mesh
    mesh = trimesh.smoothing.filter_laplacian(mesh, iterations=1)

    return vertices, faces, mesh

# Example usage of the density-threshold approach.
# sigma_threshold must lie between the sigma min/max that extract_mesh prints. 50 matches
# the default iso-level of the colored-mesh function earlier in this notebook; tune as needed.
vertices, faces, mesh = extract_mesh(
    model=model,
    device=device,
    bounds=(-1.2, 1.2),  # Adjust based on your scene bounds
    resolution=128,      # Increase for higher quality
    sigma_threshold=50,
)

# Flip face winding / normals BEFORE exporting so the saved file matches what is displayed
# (the colored-mesh export later in this notebook uses the same order). Presumably marching
# cubes orients faces inward for this density field.
mesh.invert()

# Save mesh
mesh.export('nerf_mesh.obj')

# Visualize (if needed)
mesh.show()

In [None]:
# Approach 1: colored mesh (vertex colors averaged over several view directions)
def generate_colored_mesh(model, device, N=200, scale=1.5, batch_size=50000, density_threshold=None):
    """
    Generate a mesh with per-vertex colors from a NeRF model.

    The model is queried on a regular N x N x N grid over [-scale, scale]^3. Density is
    taken from the first view direction, and color is averaged over a fixed set of 16 view
    directions. Marching cubes extracts the iso-surface, and each vertex takes the color of
    its nearest grid sample. Only the largest connected component (by area) is kept, and the
    result is Laplacian-smoothed.

    Args:
        model: NeRF model called as ``model(points, directions)`` returning ``(rgb, density)``.
        device: Torch device for model queries.
        N: Grid samples per axis (must be >= 2).
        scale: Half-width of the sampled cube.
        batch_size: Number of points per forward pass.
        density_threshold: Iso-level. When None, 30x the mean sampled density is used.

    Returns:
        trimesh.Trimesh: Smoothed, colored mesh in world coordinates.

    Raises:
        ValueError: If ``N < 2``.
    """
    if N < 2:
        raise ValueError("N must be at least 2")

    # Grid points, flattened in C order: row k <-> grid index (i, j, l) with k = i*N*N + j*N + l.
    axis = torch.linspace(-scale, scale, N)
    gx, gy, gz = torch.meshgrid((axis, axis, axis), indexing='ij')
    xyz = torch.stack([gx, gy, gz], dim=-1).reshape(-1, 3)

    # View directions do not depend on the batch, so build them once.
    # NOTE(review): theta in {0, pi} makes phi irrelevant and phi in {0, 2*pi} repeats, so
    # several of these 16 directions coincide and are over-weighted in the color average.
    theta, phi = torch.meshgrid(
        (torch.linspace(0, np.pi, 4), torch.linspace(0, 2 * np.pi, 4)), indexing='ij'
    )
    dirs = torch.stack([
        torch.sin(theta) * torch.cos(phi),
        torch.sin(theta) * torch.sin(phi),
        torch.cos(theta),
    ], dim=-1).view(-1, 3).to(device)

    densities, colors = [], []
    with torch.no_grad():
        for start in range(0, xyz.shape[0], batch_size):
            batch_xyz = xyz[start:start + batch_size].to(device)
            batch_colors = []
            for k, d in enumerate(dirs):
                rgb, density = model(batch_xyz, d.expand(batch_xyz.shape[0], -1))
                batch_colors.append(rgb)
                if k == 0:
                    # Keep density from the first direction only (assumes view-independent density).
                    densities.append(density.cpu())
            # Average colors across viewing directions
            colors.append(torch.stack(batch_colors).mean(0).cpu())

    density = torch.cat(densities, dim=0).numpy().reshape(N, N, N)
    colors = torch.cat(colors, dim=0).numpy()

    if density_threshold is None:
        density_threshold = 30 * np.mean(density)

    # Run marching cubes in voxel-index units so vertex positions index the color grid directly.
    vertices_idx, faces, normals, _ = marching_cubes(density, level=density_threshold)

    # Nearest grid sample for every vertex (vectorised), clamped to the grid.
    nearest = np.clip(np.rint(vertices_idx).astype(np.int64), 0, N - 1)
    flat_idx = nearest[:, 0] * N * N + nearest[:, 1] * N + nearest[:, 2]
    # Clamp to [0, 1] before scaling so out-of-range colors saturate instead of wrapping in uint8.
    vertex_colors = (np.clip(colors[flat_idx], 0.0, 1.0) * 255).astype(np.uint8)

    # Index i maps to -scale + i * 2*scale/(N-1), matching torch.linspace above.
    vertices = vertices_idx * (2 * scale / (N - 1)) - scale

    # Pass colors and normals to the constructor (`normals=` is not a Trimesh argument) so
    # trimesh keeps them aligned with the vertices if it merges vertices during processing.
    mesh = trimesh.Trimesh(
        vertices=vertices,
        faces=faces,
        vertex_normals=normals,
        vertex_colors=vertex_colors,
    )

    # Keep only the largest connected component (by surface area).
    components = mesh.split(only_watertight=False)
    if len(components) > 1:
        mesh = components[np.argmax([c.area for c in components])]

    return trimesh.smoothing.filter_laplacian(mesh)

In [None]:
# Parameters for generate_colored_mesh.
# NOTE(review): these rebind the global N and scale used by earlier cells.
scale = 1.5               # half-width of the sampled cube [-scale, scale]^3
N = 250                   # grid resolution per axis
batch_size = 50000        # points per forward pass
density_threshold = None  # None -> generate_colored_mesh uses 30x the mean density

In [None]:
# Generate the colored mesh with the parameters above
print("Generating colored mesh...")
mesh = generate_colored_mesh(
    model=model,
    device=device,
    N=N,
    scale=scale,
    batch_size=batch_size,
    density_threshold=density_threshold
)

# Save mesh in different formats
output_dir = "nerf_output"
os.makedirs(output_dir, exist_ok=True)

# Flip face winding once, before any export, so every saved file has the same orientation.
mesh.invert()

# Save as PLY (best for vertex colors)
ply_path = os.path.join(output_dir, "nerf_mesh_colored.ply")
mesh.export(ply_path)
print(f"Saved colored mesh as PLY: {ply_path}")

# Also save as OBJ. Vertex-color support among OBJ readers varies, so prefer the PLY for colors.
obj_path = os.path.join(output_dir, "nerf_mesh_colored.obj")
mesh.export(obj_path, include_texture=True)
print(f"Saved colored mesh as OBJ: {obj_path}")

# Display statistics
print("\nMesh Statistics:")
print(f"Number of vertices: {len(mesh.vertices)}")
print(f"Number of faces: {len(mesh.faces)}")
# Enclosed volume is only meaningful for a closed (watertight) surface.
if mesh.is_watertight:
    print(f"Mesh volume: {mesh.volume:.2f}")
else:
    print("Mesh volume: n/a (mesh is not watertight)")

# Optional: Display the mesh
mesh.show()