In [2]:
import torch.nn as nn
import torch.utils.data
from tqdm import tqdm
import glob
from matplotlib.image import imread
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
import torch

In [3]:
print("torch version: ", torch.__version__)

if torch.backends.mps.is_available():
    device = torch.device("mps")
else: 
    device = torch.device("cpu")

print("device: ", device)

torch version:  2.5.1
device:  mps


In [4]:
@torch.no_grad()
def test(model, camera_intrinsics, camera_extrinsics, hn, hf, images, chunk_size=10, img_index=0, nb_bins=192, H=400,
         W=400):
    ray_origins, ray_directions, _ = sample_batch(camera_extrinsics, camera_intrinsics, images, None, H, W,
                                                  img_index=img_index, sample_all=True)
    data = []
    for i in range(int(np.ceil(H / chunk_size))):
        ray_origins_ = ray_origins[i * W * chunk_size: (i + 1) * W * chunk_size].to(camera_intrinsics.device)
        ray_directions_ = ray_directions[i * W * chunk_size: (i + 1) * W * chunk_size].to(camera_intrinsics.device)
        regenerated_px_values = render_rays(model, ray_origins_, ray_directions_, hn=hn, hf=hf, nb_bins=nb_bins)
        data.append(regenerated_px_values)
    img = torch.cat(data).data.cpu().numpy().reshape(H, W, 3)
    plt.imshow(img)
    plt.savefig(f'Imgs/novel_view.png', bbox_inches='tight')
    plt.close()


class NerfModel(nn.Module):
    def __init__(self, embedding_dim_pos=10, embedding_dim_direction=4, hidden_dim=128):
        super(NerfModel, self).__init__()

        self.block1 = nn.Sequential(nn.Linear(embedding_dim_pos * 6 + 3, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), )
        self.block2 = nn.Sequential(nn.Linear(embedding_dim_pos * 6 + hidden_dim + 3, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim + 1), )
        self.block3 = nn.Sequential(nn.Linear(embedding_dim_direction * 6 + hidden_dim + 3, hidden_dim // 2),
                                    nn.ReLU(), )
        self.block4 = nn.Sequential(nn.Linear(hidden_dim // 2, 3), nn.Sigmoid(), )

        self.embedding_dim_pos = embedding_dim_pos
        self.embedding_dim_direction = embedding_dim_direction
        self.relu = nn.ReLU()

    @staticmethod
    def positional_encoding(x, L):
        out = [x]
        for j in range(L):
            out.append(torch.sin(2 ** j * x))
            out.append(torch.cos(2 ** j * x))
        return torch.cat(out, dim=1)

    def forward(self, o, d):
        emb_x = self.positional_encoding(o, self.embedding_dim_pos)
        emb_d = self.positional_encoding(d, self.embedding_dim_direction)
        tmp = self.block2(torch.cat((self.block1(emb_x), emb_x), dim=1))
        h, sigma = tmp[:, :-1], self.relu(tmp[:, -1])
        c = self.block4(self.block3(torch.cat((h, emb_d), dim=1)))
        return c, sigma


def compute_accumulated_transmittance(alphas):
    accumulated_transmittance = torch.cumprod(alphas, 1)
    return torch.cat((torch.ones((accumulated_transmittance.shape[0], 1), device=alphas.device),
                      accumulated_transmittance[:, :-1]), dim=-1)


def render_rays(nerf_model, ray_origins, ray_directions, hn=0, hf=0.5, nb_bins=192):
    device = ray_origins.device
    t = torch.linspace(hn, hf, nb_bins, device=device).expand(ray_origins.shape[0], nb_bins)
    mid = (t[:, :-1] + t[:, 1:]) / 2.
    lower = torch.cat((t[:, :1], mid), -1)
    upper = torch.cat((mid, t[:, -1:]), -1)
    u = torch.rand(t.shape, device=device)  # Perturb sampling along each ray.
    t = lower + (upper - lower) * u  # [batch_size, nb_bins]
    delta = torch.cat((t[:, 1:] - t[:, :-1], torch.tensor([1e10], device=device).expand(ray_origins.shape[0], 1)), -1)

    x = ray_origins.unsqueeze(1) + t.unsqueeze(2) * ray_directions.unsqueeze(1)  # [batch_size, nb_bins, 3]
    ray_directions = ray_directions.expand(nb_bins, ray_directions.shape[0], 3).transpose(0, 1)

    colors, sigma = nerf_model(x.reshape(-1, 3), ray_directions.reshape(-1, 3))
    alpha = 1 - torch.exp(-sigma.reshape(x.shape[:-1]) * delta)  # [batch_size, nb_bins]
    weights = compute_accumulated_transmittance(1 - alpha).unsqueeze(2) * alpha.unsqueeze(2)
    return (weights * colors.reshape(x.shape)).sum(dim=1)  # Pixel values


def train(nerf_model, optimizers, schedulers, training_images, camera_extrinsics, camera_intrinsics, batch_size,
          nb_epochs, hn=0., hf=1., nb_bins=192):
    H, W = training_images.shape[1:3]

    training_loss = []
    for _ in tqdm(range(nb_epochs)):
        ids = np.arange(training_images.shape[0])
        np.random.shuffle(ids)
        for img_index in ids:
            rays_o, rays_d, samples_idx = sample_batch(camera_extrinsics, camera_intrinsics, training_images,
                                                       batch_size, H, W, img_index=img_index)
            gt_px_values = torch.from_numpy(training_images[samples_idx]).to(camera_intrinsics.device)
            regenerated_px_values = render_rays(nerf_model, rays_o, rays_d, hn=hn, hf=hf, nb_bins=nb_bins)
            loss = ((gt_px_values - regenerated_px_values) ** 2).sum()

            for optimizer in optimizers:
                optimizer.zero_grad()
            loss.backward()
            for optimizer in optimizers:
                optimizer.step()
            training_loss.append(loss.item())
        for scheduler in schedulers:
            scheduler.step()
    return training_loss


def initialize_camera_parameters(images, device=device):
    camera_intrinsics = torch.ones(1, device=device, requires_grad=True)
    camera_extrinsics = torch.zeros((images.shape[0], 6), device=device, dtype=torch.float32, requires_grad=True)
    return camera_intrinsics, camera_extrinsics


def load_images(data_path):
    image_paths = glob.glob(data_path)
    images = None
    for i, image_path in enumerate(image_paths):
        img = imread(image_path)
        # Ensure we only take RGB channels if image is RGBA
        if img.shape[-1] == 4:
            img = img[..., :3]
        # Add batch dimension
        img = np.expand_dims(img, 0)
        images = np.concatenate((images, img)) if images is not None else img
    return images


def get_ndc_rays(H, W, focal, rays_o, rays_d, near=1.):
    # We shift o to the ray’s intersection with the near plane at z = −n (before the NDC conversion)
    t = -(near + rays_o[..., 2]) / rays_d[..., 2]
    rays_o = rays_o + t[..., None] * rays_d

    rays_o = torch.stack([- focal / W / 2. * rays_o[..., 0] / rays_o[..., 2],
                          - focal / H / 2. * rays_o[..., 1] / rays_o[..., 2],
                          1. + 2. * near / rays_o[..., 2]], -1)  # Eq 25 https://arxiv.org/pdf/2003.08934.pdf
    rays_d = torch.stack([- focal / W / 2. * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2]),
                          - focal / H / 2. * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2]),
                          - 2. * near / rays_o[..., 2]], -1)  # Eq 26 https://arxiv.org/pdf/2003.08934.pdf
    return rays_o, rays_d


def sample_batch(camera_extrinsics, camera_intrinsics, images, batch_size, H, W, img_index=0, sample_all=False):
    if sample_all:
        image_indices = (torch.zeros(W * H) + img_index).type(torch.long)
        u, v = np.meshgrid(np.linspace(0, W - 1, W, dtype=int), np.linspace(0, H - 1, H, dtype=int))
        u = torch.from_numpy(u.reshape(-1)).to(camera_intrinsics.device)
        v = torch.from_numpy(v.reshape(-1)).to(camera_intrinsics.device)
    else:
        image_indices = (torch.zeros(batch_size) + img_index).type(torch.long)  # Sample random images
        u = torch.randint(W, (batch_size,), device=camera_intrinsics.device)  # Sample random pixels
        v = torch.randint(H, (batch_size,), device=camera_intrinsics.device)

    focal = camera_intrinsics[0] ** 2 * W
    t = camera_extrinsics[img_index, :3]
    r = camera_extrinsics[img_index, -3:]

    # Creating the c2w matrix, Section 4.1 from the paper
    phi_skew = torch.stack([torch.cat([torch.zeros(1, device=r.device), -r[2:3], r[1:2]]),
                            torch.cat([r[2:3], torch.zeros(1, device=r.device), -r[0:1]]),
                            torch.cat([-r[1:2], r[0:1], torch.zeros(1, device=r.device)])], dim=0)
    alpha = r.norm() + 1e-15
    R = torch.eye(3, device=r.device) + (torch.sin(alpha) / alpha) * phi_skew + (
            (1 - torch.cos(alpha)) / alpha ** 2) * (phi_skew @ phi_skew)
    c2w = torch.cat([R, t.unsqueeze(1)], dim=1)
    c2w = torch.cat([c2w, torch.tensor([[0., 0., 0., 1.]], device=c2w.device)], dim=0)

    rays_d_cam = torch.cat([((u.to(camera_intrinsics.device) - .5 * W) / focal).unsqueeze(-1),
                            (-(v.to(camera_intrinsics.device) - .5 * H) / focal).unsqueeze(-1),
                            - torch.ones_like(u).unsqueeze(-1)], dim=-1)
    rays_d_world = torch.matmul(c2w[:3, :3].view(1, 3, 3), rays_d_cam.unsqueeze(2)).squeeze(2)
    rays_o_world = c2w[:3, 3].view(1, 3).expand_as(rays_d_world)
    rays_o_world, rays_d_world = get_ndc_rays(H, W, focal, rays_o=rays_o_world, rays_d=rays_d_world)
    return rays_o_world, F.normalize(rays_d_world, p=2, dim=1), (image_indices, v.cpu(), u.cpu())


if __name__ == "__main__":
    device = device
    nb_epochs = 60 #int(1e4)

    training_images = load_images("fox/_fox/imgs/*.png")
    print("Training images shape:", training_images.shape)
    camera_intrinsics, camera_extrinsics = initialize_camera_parameters(training_images, device=device)
    batch_size = 1024
    import os

    # Create Imgs directory if it doesn't exist
    os.makedirs('Imgs', exist_ok=True)

    # Part 1
    model = NerfModel(hidden_dim=256).to(device)
    model_optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    optimizer_camera_parameters = torch.optim.Adam({camera_extrinsics}, lr=0.0009)
    optimizer_focal = torch.optim.Adam({camera_intrinsics}, lr=0.001)
    scheduler_model = torch.optim.lr_scheduler.MultiStepLR(
        model_optimizer, [10 * (i + 1) for i in range(nb_epochs // 10)], gamma=0.9954)
    scheduler_camera_parameters = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_camera_parameters, [100 * (i + 1) for i in range(nb_epochs // 100)], gamma=0.81)
    scheduler_focal = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_focal, [100 * (i + 1) for i in range(nb_epochs // 100)], gamma=0.9)
    train(model, [model_optimizer, optimizer_camera_parameters, optimizer_focal],
          [scheduler_model, scheduler_camera_parameters, scheduler_focal], training_images, camera_extrinsics,
          camera_intrinsics, batch_size, nb_epochs, hn=0., hf=1., nb_bins=192)

    # Part 2
    model = NerfModel(hidden_dim=256).to(device)
    model_optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler_model = torch.optim.lr_scheduler.MultiStepLR(
        model_optimizer, [10 * (i + 1) for i in range(nb_epochs // 10)], gamma=0.9954)
    train(model, [model_optimizer], [scheduler_model], training_images, camera_extrinsics, camera_intrinsics,
          batch_size, nb_epochs, hn=0., hf=1., nb_bins=192)

    # Test: interpolation between two images
    test(model, camera_intrinsics, (.5 * camera_extrinsics[0] + .5 * camera_extrinsics[1]).unsqueeze(0), 0., 1.,
         training_images, img_index=0, nb_bins=192, H=training_images.shape[1], W=training_images.shape[2])

Training images shape: (100, 400, 400, 3)


100%|███████████████████████████████████████████| 60/60 [31:10<00:00, 31.17s/it]
100%|███████████████████████████████████████████| 60/60 [31:10<00:00, 31.18s/it]


# Camera / Dataset

In [None]:
batch_size = 1024

o, d, target_px_values = get_rays('fox', mode='train')
dataloader = DataLoader(torch.cat((torch.from_numpy(o).reshape(-1, 3).type(torch.float),
                                   torch.from_numpy(d).reshape(-1, 3).type(torch.float),
                                   torch.from_numpy(target_px_values).reshape(-1, 3).type(torch.float)), dim=1),
                       batch_size=batch_size, shuffle=True)


dataloader_warmup = DataLoader(torch.cat((torch.from_numpy(o).reshape(90, 400, 400, 3)[:, 100:300, 100:300, :].reshape(-1, 3).type(torch.float),
                               torch.from_numpy(d).reshape(90, 400, 400, 3)[:, 100:300, 100:300, :].reshape(-1, 3).type(torch.float),
                               torch.from_numpy(target_px_values).reshape(90, 400, 400, 3)[:, 100:300, 100:300, :].reshape(-1, 3).type(torch.float)), dim=1),
                       batch_size=batch_size, shuffle=True)


test_o, test_d, test_target_px_values = get_rays('fox', mode='test')

# Training

In [None]:
device = device

tn = 8.
tf = 12.
nb_epochs = 1 #15 30
lr =  1e-3 # 1e-3 5e-4
gamma = .5 #0.5 0.7 
nb_bins = 100 #100 256

model = Nerf(hidden_dim=256).to(device) #Nerf(hidden_dim=128).to(device) 
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10], gamma=gamma)



training_loss = training(model, optimizer, scheduler, tn, tf, nb_bins, 1, dataloader_warmup, device=device)
plt.plot(training_loss)
plt.show()
training_loss = training(model, optimizer, scheduler, tn, tf, nb_bins, nb_epochs, dataloader, device=device)
plt.plot(training_loss)
plt.show()

In [5]:
torch.save(model, 'model_nerf_colmap')

# Mesh extraction

In [6]:
model = torch.load('model_nerf_colmap').to(device)

  model = torch.load('model_nerf_colmap').to(device)


In [7]:
import torch
import numpy as np
from skimage.measure import marching_cubes
import trimesh
import torch.nn.functional as F

def analyze_density_field(density_volume):
    """Analyze the density field to help choose a good threshold."""
    min_density = float(density_volume.min())
    max_density = float(density_volume.max())
    mean_density = float(density_volume.mean())
    std_density = float(density_volume.std())
    
    print(f"Density field statistics:")
    print(f"Min: {min_density:.6f}")
    print(f"Max: {max_density:.6f}")
    print(f"Mean: {mean_density:.6f}")
    print(f"Std: {std_density:.6f}")
    
    # Suggest threshold as mean + 1 std deviation
    suggested_threshold = mean_density + std_density
    return suggested_threshold

def extract_mesh(nerf_model, resolution=128, threshold=None, bbox_min=[-1.5, -1.5, -1.5], 
                bbox_max=[1.5, 1.5, 1.5], device=torch.device("cpu")):
    """
    Extract a colored mesh from a trained NeRF model.
    
    Args:
        nerf_model: Trained NeRF model
        resolution: Grid resolution for marching cubes
        threshold: Density threshold for surface extraction (if None, will be auto-determined)
        bbox_min: Minimum corner of bounding box
        bbox_max: Maximum corner of bounding box
        device: Torch device to use
    
    Returns:
        trimesh.Trimesh: Colored mesh
    """
    print(f"Creating density volume with resolution {resolution}...")
    
    # Create grid of points
    x = torch.linspace(bbox_min[0], bbox_max[0], resolution)
    y = torch.linspace(bbox_min[1], bbox_max[1], resolution)
    z = torch.linspace(bbox_min[2], bbox_max[2], resolution)
    xx, yy, zz = torch.meshgrid(x, y, z, indexing='ij')
    points = torch.stack([xx, yy, zz], dim=-1).to(device)
    
    # Create density volume
    density_volume = torch.zeros((resolution, resolution, resolution))
    chunk_size = 512 * 512  # Process in chunks to avoid OOM
    
    print("Sampling density field...")
    with torch.no_grad():
        for i in range(0, points.numel() // 3, chunk_size):
            chunk_points = points.reshape(-1, 3)[i:i+chunk_size]
            # Assume model returns (rgb, sigma) tuple
            _, chunk_densities = nerf_model(chunk_points, torch.zeros_like(chunk_points))
            density_volume.reshape(-1)[i:i+chunk_size] = chunk_densities.cpu()
    
    # Auto-determine threshold if not provided
    if threshold is None:
        threshold = analyze_density_field(density_volume)
        print(f"Auto-determined threshold: {threshold:.6f}")
    
    print(f"Extracting mesh with threshold {threshold}...")
    
    try:
        # Extract mesh using marching cubes
        vertices, faces, normals, _ = marching_cubes(
            density_volume.numpy(),
            threshold,
            spacing=((bbox_max[0] - bbox_min[0])/resolution,
                    (bbox_max[1] - bbox_min[1])/resolution,
                    (bbox_max[2] - bbox_min[2])/resolution)
        )
    except ValueError as e:
        print("Error during marching cubes:")
        print(e)
        print("\nTry adjusting the threshold based on the density statistics above.")
        raise
    
    print(f"Mesh extracted with {len(vertices)} vertices and {len(faces)} faces")
    
    # Adjust vertices to match bbox
    vertices = vertices + np.array(bbox_min)
    
    # Sample colors at vertex positions
    vertex_colors = torch.zeros((len(vertices), 3))
    vertices_tensor = torch.tensor(vertices, dtype=torch.float32).to(device)
    
    print("Sampling colors...")
    with torch.no_grad():
        for i in range(0, len(vertices), chunk_size):
            chunk_vertices = vertices_tensor[i:i+chunk_size]
            # Assume model returns (rgb, sigma) tuple
            chunk_colors, _ = nerf_model(chunk_vertices, torch.zeros_like(chunk_vertices))
            vertex_colors[i:i+chunk_size] = chunk_colors.cpu()
    
    # Create mesh with vertex colors
    mesh = trimesh.Trimesh(
        vertices=vertices,
        faces=faces,
        vertex_colors=(vertex_colors.numpy() * 255).astype(np.uint8),
        vertex_normals=normals
    )
    
    return mesh

def save_colored_mesh(nerf_model, output_path, resolution=256, threshold=None, device=torch.device("cpu")):
    """
    Extract and save a colored mesh from a NeRF model.
    
    Args:
        nerf_model: Trained NeRF model
        output_path: Path to save the mesh (should end in .ply or .obj)
        resolution: Resolution for marching cubes
        threshold: Density threshold (if None, will be auto-determined)
        device: Torch device to use
    """
    mesh = extract_mesh(nerf_model, resolution=resolution, threshold=threshold, device=device)
    
    print("Processing mesh...")
    # Optional mesh cleanup
    mesh = mesh.process(validate=True)
    
    print(f"Saving mesh to {output_path}...")
    # Save the mesh
    mesh.export(output_path)
    return mesh

# After loading your model
resolution = 700  # Increase for better quality, decrease if you run into memory issues
output_path = "nerf_mesh.obj"  # Can also use .obj format

# Extract and save the mesh
mesh = save_colored_mesh(model, output_path, resolution=resolution, device=device)

Creating density volume with resolution 700...
Sampling density field...
Density field statistics:
Min: 0.000000
Max: 0.000000
Mean: 0.000000
Std: 0.000000
Auto-determined threshold: 0.000000
Extracting mesh with threshold 0.0...


RuntimeError: No surface found at the given iso value.

In [None]:
#!pip install Pymcubes
#!pip install trimesh
#!pip install -U scikit-image
#!pip install genesis-world  # Requires Python >=3.9;
#!pip uninstall genesis-world
#!conda install -c anaconda trimesh