In [1]:
import torch
import torch.nn as nn
import numpy as np

import os
import imageio
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from dataset import get_rays
from rendering import rendering
from model import Voxels, Nerf
from ml_helpers import training

from skimage.measure import marching_cubes
import trimesh
import trimesh.smoothing
import os

In [2]:
print("torch version: ", torch.__version__)

# Use Apple's Metal (MPS) backend when this PyTorch build reports it as available;
# otherwise fall back to the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

print("device: ", device)

torch version:  2.5.1
device:  mps


# Camera / Dataset

In [3]:
batch_size = 1024

# The warm-up reshape assumes the 'fox' training rays are 90 views of 400 x 400
# pixels; the warm-up loader keeps only rows/cols 100:300 (a central 200 x 200 window).
NB_TRAIN_VIEWS, IMG_H, IMG_W = 90, 400, 400
WARMUP_CROP = slice(100, 300)


def to_ray_tensor(arr, crop=None):
    """Flatten a per-pixel numpy array into a (num_pixels, 3) float32 tensor.

    Args:
        arr: numpy array whose total size is a multiple of 3.
        crop: optional slice; when given, ``arr`` is first viewed as
            (NB_TRAIN_VIEWS, IMG_H, IMG_W, 3) and the slice is applied to both
            image axes before flattening.

    Returns:
        torch.Tensor of shape (num_pixels, 3) and dtype float32.
    """
    t = torch.from_numpy(arr)
    if crop is not None:
        t = t.reshape(NB_TRAIN_VIEWS, IMG_H, IMG_W, 3)[:, crop, crop, :]
    return t.reshape(-1, 3).type(torch.float)


o, d, target_px_values = get_rays('fox', mode='train')
train_arrays = (o, d, target_px_values)

# Each row of the concatenated tensor is [o (3) | d (3) | target_px_values (3)].
dataloader = DataLoader(torch.cat([to_ray_tensor(a) for a in train_arrays], dim=1),
                        batch_size=batch_size, shuffle=True)

dataloader_warmup = DataLoader(
    torch.cat([to_ray_tensor(a, crop=WARMUP_CROP) for a in train_arrays], dim=1),
    batch_size=batch_size, shuffle=True)

# NOTE(review): in the saved run this call raised AssertionError, and test_* are not
# used elsewhere in the visible notebook — check that the 'fox' test split exists.
test_o, test_d, test_target_px_values = get_rays('fox', mode='test')

AssertionError: 

# Training

In [None]:
# Training hyper-parameters; trailing comments keep the previously used values.
# NOTE(review): tn / tf are presumably the near / far sampling bounds along each
# ray used by `training` — confirm in ml_helpers.
tn = 8.
tf = 12.
nb_epochs = 30          # was 15
nb_warmup_epochs = 1    # epochs on the centre-cropped warm-up loader
lr = 5e-4               # was 1e-3
gamma = 0.7             # was 0.5
nb_bins = 256           # was 100

model = Nerf(hidden_dim=256).to(device)  # was Nerf(hidden_dim=128)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# NOTE(review): the same scheduler is used for warm-up and the main run. If `training`
# calls scheduler.step() once per epoch, the warm-up epoch also counts toward the
# milestones [5, 10] — confirm this shift is intended.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10], gamma=gamma)


def plot_loss(loss, title):
    """Plot a loss curve returned by `training` with a title and axis labels."""
    fig, ax = plt.subplots()
    ax.plot(loss)
    ax.set(title=title, xlabel="step", ylabel="loss")
    plt.show()


# Warm-up on the central crop first, then train on all rays. Distinct names keep the
# warm-up curve available after the main run instead of overwriting it.
warmup_loss = training(model, optimizer, scheduler, tn, tf, nb_bins, nb_warmup_epochs,
                       dataloader_warmup, device=device)
plot_loss(warmup_loss, "Warm-up (central crop)")

training_loss = training(model, optimizer, scheduler, tn, tf, nb_bins, nb_epochs,
                         dataloader, device=device)
plot_loss(training_loss, "Full training")

In [None]:
# Serialise the entire module object (not just its state_dict); the reload cell
# below unpickles it as a whole, so the `model` module must stay importable.
torch.save(obj=model, f='model_nerf_mps')

# Mesh extraction

In [None]:
# The checkpoint is a fully pickled nn.Module, so it needs weights_only=False:
# PyTorch 2.6 changed the torch.load default to weights_only=True, which rejects it.
# Unpickling can execute code — only load checkpoints you created yourself.
# map_location lets a checkpoint saved from MPS tensors load on a machine without MPS.
model = torch.load('model_nerf_mps', map_location=device, weights_only=False).to(device)

In [None]:
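# Optional installs — uncomment and run once if a package is missing.
# %pip (unlike !pip) installs into the environment of the running kernel.
#%pip install PyMCubes          # not used below: meshing uses skimage.measure.marching_cubes
#%pip install trimesh
#%pip install -U scikit-image
#%pip install genesis-world     # requires Python >= 3.9
#%pip uninstall genesis-world
#%conda install -c anaconda trimesh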

In [None]:
def _view_directions(n_theta=4, n_phi=4):
    """Return unit view directions on an endpoint-inclusive (theta, phi) grid, shape (n_theta*n_phi, 3).

    NOTE(review): both linspaces include their endpoints, so theta=0 / theta=pi each
    contribute n_phi copies of the same pole direction and phi=0 duplicates phi=2*pi.
    A plain mean over these directions therefore over-weights the poles. The grid is
    kept unchanged so exported colors match earlier runs.
    """
    theta = torch.linspace(0, np.pi, n_theta)
    phi = torch.linspace(0, 2 * np.pi, n_phi)
    theta, phi = torch.meshgrid((theta, phi), indexing='ij')
    return torch.stack([
        torch.sin(theta) * torch.cos(phi),
        torch.sin(theta) * torch.sin(phi),
        torch.cos(theta),
    ], dim=-1).view(-1, 3)


def _query_grid(model, xyz, dirs, device, batch_size):
    """Evaluate ``model.forward`` on the points ``xyz`` in chunks of ``batch_size``.

    Returns:
        (density, rgb) as CPU tensors concatenated over chunks. ``density`` is the
        output for the first direction in ``dirs``; ``rgb`` is the mean over all of them.
    """
    densities, colors = [], []
    dirs = dirs.to(device)
    with torch.no_grad():
        for start in range(0, xyz.shape[0], batch_size):
            chunk = xyz[start:start + batch_size].to(device)
            per_dir_rgb = []
            for k, view_dir in enumerate(dirs):
                rgb, density = model.forward(chunk, view_dir.expand(chunk.shape[0], -1))
                per_dir_rgb.append(rgb)
                if k == 0:  # density is taken from the first direction only
                    densities.append(density.cpu())
            colors.append(torch.stack(per_dir_rgb).mean(0).cpu())
    return torch.cat(densities, dim=0), torch.cat(colors, dim=0)


def _lookup_vertex_colors(verts_idx, rgb, N):
    """Color each vertex with the RGB of its nearest grid sample; returns (V, 3) uint8.

    ``verts_idx`` are marching-cubes vertices in voxel-index units. ``rgb`` is the flat
    color grid in the same 'ij' order used to build the sample points (assumes the
    model returns (batch, 3) rgb — confirm).
    """
    # Round to the nearest sample; truncation biased toward the lower corner and could
    # drop an index on float error (e.g. 4.9999 -> 4).
    idx = np.clip(np.rint(verts_idx).astype(np.int64), 0, N - 1)
    flat = (idx[:, 0] * N + idx[:, 1]) * N + idx[:, 2]
    colors = rgb.numpy()[flat]
    # Clip before casting: values outside [0, 1] would not saturate in a uint8 cast.
    return (np.clip(colors, 0.0, 1.0) * 255).astype(np.uint8)


def generate_colored_mesh(model, device, N=200, scale=1.5, batch_size=50000,
                          density_threshold=None, threshold_factor=30.0):
    """Generate a smoothed, vertex-colored triangle mesh from a NeRF model.

    The model is sampled on an N x N x N grid spanning [-scale, scale]^3. The density
    grid is polygonised with marching cubes, and each vertex takes the view-averaged
    color of its nearest grid sample. Only the largest connected component (by surface
    area) is kept, and Laplacian smoothing is then applied.

    Args:
        model: object whose ``forward(xyz, dirs)`` returns ``(rgb, density)``.
        device: device the sample points and directions are moved to.
        N: grid resolution per axis (must be >= 2).
        scale: half-width of the sampled cube.
        batch_size: number of grid points per model evaluation.
        density_threshold: marching-cubes iso-level; ``None`` uses
            ``threshold_factor * mean(density)``.
        threshold_factor: multiplier for the automatic threshold (default 30).

    Returns:
        trimesh.Trimesh whose vertices are in the same world coordinates as the samples.

    Raises:
        ValueError: if ``N < 2``; marching_cubes also raises ValueError when the
            iso-level lies outside the density range.
    """
    if N < 2:
        raise ValueError(f"N must be >= 2, got {N}")

    axis = torch.linspace(-scale, scale, N)
    gx, gy, gz = torch.meshgrid((axis, axis, axis), indexing='ij')
    xyz = torch.stack([gx, gy, gz], dim=-1).reshape(-1, 3)

    density, rgb = _query_grid(model, xyz, _view_directions(), device, batch_size)
    density = density.numpy().reshape(N, N, N)

    if density_threshold is None:
        density_threshold = threshold_factor * np.mean(density)

    # Mesh in voxel-index units (default spacing 1) so vertices index the color grid
    # directly; they are mapped to world coordinates below.
    verts_idx, faces, normals, _ = marching_cubes(density, level=density_threshold)
    vertex_colors = _lookup_vertex_colors(verts_idx, rgb, N)

    # torch.linspace includes both endpoints, so grid index i sits at
    # -scale + i * 2*scale/(N-1). The old spacing 2*scale/N with no offset placed
    # the mesh in [0, 2*scale*(N-1)/N] instead of [-scale, scale].
    step = 2 * scale / (N - 1)
    vertices = verts_idx * step - scale

    # `vertex_normals` is Trimesh's keyword for per-vertex normals. Passing the colors
    # to the constructor keeps them aligned if Trimesh merges duplicate vertices.
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces,
                           vertex_normals=normals, vertex_colors=vertex_colors)

    # Keep only the largest connected component (by surface area).
    components = mesh.split(only_watertight=False)
    if len(components) > 1:
        mesh = max(components, key=lambda c: c.area)

    # Laplacian smoothing with trimesh's defaults; it returns the smoothed mesh.
    return trimesh.smoothing.filter_laplacian(mesh)

In [None]:
# Mesh-extraction settings consumed by the generate_colored_mesh call below.
# NOTE: this rebinds `batch_size` (1024 in the dataset cell). The DataLoaders are
# already built, so they are unaffected.
N, scale = 250, 1.5          # grid resolution per axis; sampled cube is [-scale, scale]^3
batch_size = 50_000          # grid points per model evaluation
density_threshold = None     # None -> generate_colored_mesh derives it from 30 x mean density

In [None]:
# Generate mesh with colors
print("Generating colored mesh...")
mesh = generate_colored_mesh(
    model=model,
    device=device,
    N=N,
    scale=scale,
    batch_size=batch_size,
    density_threshold=density_threshold,
)

# Save mesh in different formats
output_dir = "nerf_output"
os.makedirs(output_dir, exist_ok=True)

# Flip face winding and normals before exporting. NOTE(review): presumably the
# marching-cubes normals point inward for this density field — confirm visually.
# invert() mutates `mesh` in place, so both exports below get the flipped mesh.
mesh.invert()

# PLY stores per-vertex colors natively.
ply_path = os.path.join(output_dir, "nerf_mesh_colored.ply")
mesh.export(ply_path)
print(f"Saved colored mesh as PLY: {ply_path}")

# OBJ export. The mesh carries vertex colors, not a texture image, so there is no
# textured material to write. NOTE(review): confirm the target viewer reads OBJ
# vertex colors.
obj_path = os.path.join(output_dir, "nerf_mesh_colored.obj")
mesh.export(obj_path, include_texture=True)
print(f"Saved colored mesh as OBJ: {obj_path}")

# Display statistics
print("\nMesh Statistics:")
print(f"Number of vertices: {len(mesh.vertices)}")
print(f"Number of faces: {len(mesh.faces)}")
# trimesh computes volume as a surface integral, which is only meaningful for a
# closed mesh. The marching-cubes surface can be open where it meets the grid boundary.
if mesh.is_watertight:
    print(f"Mesh volume: {mesh.volume:.2f}")
else:
    print("Mesh volume: n/a (mesh is not watertight)")

# Optional: Display the mesh
mesh.show()