In [10]:
import torch
import torch.nn as nn
import numpy as np

import os
import imageio
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from dataset import get_rays
from rendering import rendering
from model import Voxels, Nerf
from ml_helpers import training

from skimage.measure import marching_cubes
import trimesh
import trimesh.smoothing
import os

In [2]:
print("torch version: ", torch.__version__)

# Pick the compute device once for the whole notebook. Preference order:
# CUDA GPU, then Apple's Metal (MPS) backend, then plain CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
# Older torch builds have no `torch.backends.mps`, so look it up defensively.
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("device: ", device)

torch version:  2.5.1
device:  mps


In [4]:
import numpy as np
from pathlib import Path
import quaternion  # You'll need to install numpy-quaternion package

def read_cameras(path):
    """Read COLMAP cameras.txt and convert to intrinsic matrices.

    Each non-comment line has the layout
    ``CAMERA_ID MODEL WIDTH HEIGHT PARAMS[]``. The focal length(s) and
    principal point are taken from ``PARAMS[]``; distortion coefficients
    are ignored.

    Args:
        path: Path to the ``cameras.txt`` file.

    Returns:
        dict mapping ``camera_id`` (int) to a 4x4 ``float32`` matrix
        ``[[fx, 0, cx, 0], [0, fy, cy, 0], [0, 0, 1, 0], [0, 0, 0, 1]]``.
        Cameras whose model is not recognised are left out.

    Note:
        A later cell of this notebook defines another ``read_cameras`` with
        a different return schema (dicts instead of matrices); whichever
        cell ran last wins.
    """
    # Models whose PARAMS[] start with (f, cx, cy, ...).
    single_focal_models = {
        'SIMPLE_PINHOLE', 'SIMPLE_RADIAL', 'RADIAL',
        'SIMPLE_RADIAL_FISHEYE', 'RADIAL_FISHEYE',
    }
    # Models whose PARAMS[] start with (fx, fy, cx, cy, ...).
    dual_focal_models = {
        'PINHOLE', 'OPENCV', 'OPENCV_FISHEYE', 'FULL_OPENCV',
        'FOV', 'THIN_PRISM_FISHEYE',
    }

    cameras = {}
    with open(path, 'r') as handle:
        for raw_line in handle:
            line = raw_line.strip()
            # Skip blank lines and comments (also when the '#' is indented).
            if not line or line.startswith('#'):
                continue

            data = line.split()
            camera_id = int(data[0])
            model = data[1]
            params = [float(value) for value in data[4:]]

            if model in single_focal_models:
                fx = fy = params[0]
                cx, cy = params[1], params[2]
            elif model in dual_focal_models:
                fx, fy, cx, cy = params[:4]
            else:
                # Unknown model: no intrinsics are produced for this camera.
                continue

            # Create 4x4 intrinsic matrix
            cameras[camera_id] = np.array([
                [fx, 0, cx, 0],
                [0, fy, cy, 0],
                [0, 0, 1, 0],
                [0, 0, 0, 1]
            ], dtype=np.float32)

    return cameras

def read_images(path):
    """Parse COLMAP images.txt into one 4x4 pose matrix per image.

    Every image occupies two consecutive lines: a pose line
    (``IMAGE_ID QW QX QY QZ TX TY TZ CAMERA_ID ...``) followed by a line of
    2D points, which is ignored here.

    Args:
        path: Path to the ``images.txt`` file.

    Returns:
        dict mapping ``image_id`` to ``{'transform': T, 'camera_id': id}``
        where ``T`` is a 4x4 ``float32`` matrix holding the rotation built
        from the quaternion and the translation exactly as stored in the
        file (no inversion is applied).
    """
    with open(path, 'r') as f:
        rows = iter(f.readlines())

    poses = {}
    for row in rows:
        # Comment and blank lines carry no pose information.
        if row[0] == '#' or not row.strip():
            continue

        fields = row.strip().split()
        image_id = int(fields[0])
        qw, qx, qy, qz = map(float, fields[1:5])
        tx, ty, tz = map(float, fields[5:8])
        camera_id = int(fields[8])

        # Quaternion -> 3x3 rotation via the numpy-quaternion package.
        rotation = quaternion.as_rotation_matrix(np.quaternion(qw, qx, qy, qz))

        # Homogeneous 4x4 transform: rotation block plus translation column.
        transform = np.eye(4, dtype=np.float32)
        transform[:3, :3] = rotation
        transform[:3, 3] = [tx, ty, tz]

        poses[image_id] = {'transform': transform, 'camera_id': camera_id}

        # Consume the 2D-points line that follows every pose line.
        next(rows, None)

    return poses

def save_matrix(matrix, output_path):
    """Write a matrix to a text file, one value per line, in row-major order.

    Args:
        matrix: Array to store; it is flattened before writing.
        output_path: Destination file path.
    """
    flattened = matrix.reshape(-1)
    np.savetxt(output_path, flattened, fmt='%.16f')

def convert_colmap_to_matrices(colmap_dir, output_dir):
    """Convert COLMAP output to individual matrix files.

    Reads ``cameras.txt`` and ``images.txt`` from ``colmap_dir`` and writes,
    for every image, ``pose_<id>.txt`` and (when the image's camera was
    parsed) ``intrinsics_<id>.txt`` into ``output_dir``. ``<id>`` is the
    image id zero-padded to three digits.

    Args:
        colmap_dir: Directory containing ``cameras.txt`` and ``images.txt``.
        output_dir: Directory receiving the matrix files; it is created,
            including missing parent directories, if it does not exist.
    """
    colmap_dir = Path(colmap_dir)
    output_dir = Path(output_dir)
    # parents=True so a nested output path works even when its parent
    # directories have not been created yet.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Read camera and image data
    cameras = read_cameras(colmap_dir / 'cameras.txt')
    images = read_images(colmap_dir / 'images.txt')

    # Save matrices
    for image_id, image_data in images.items():
        # Save camera intrinsics (skipped when the camera id is unknown)
        camera_id = image_data['camera_id']
        if camera_id in cameras:
            intrinsics_path = output_dir / f'intrinsics_{image_id:03d}.txt'
            save_matrix(cameras[camera_id], intrinsics_path)

        # Save camera pose
        pose_path = output_dir / f'pose_{image_id:03d}.txt'
        save_matrix(image_data['transform'], pose_path)

# Runs on cell execution: converts the COLMAP text export of the 'fox' scene
# into one intrinsics file and one pose file per image.
if __name__ == '__main__':
    # Example usage
    colmap_dir = 'fox/colmap'  # Directory containing cameras.txt and images.txt
    output_dir = 'fox/colmap/matrices'  # Output directory for matrix files
    convert_colmap_to_matrices(colmap_dir, output_dir)

In [5]:
import numpy as np
import os
import imageio.v2 as imageio
from typing import Dict, Tuple
import logging

# Show INFO-level records so the dataset summary logged by get_rays is visible.
logging.basicConfig(level=logging.INFO)
# Logger shared by the functions defined in this cell.
logger = logging.getLogger(__name__)

def read_cameras(path: str) -> Dict:
    """Parse a COLMAP ``cameras.txt`` file.

    Each data line has the layout ``CAMERA_ID MODEL WIDTH HEIGHT PARAMS[]``.
    Comment lines (starting with ``#``) and blank lines are skipped.

    Args:
        path: Path to the ``cameras.txt`` file.

    Returns:
        dict mapping ``camera_id`` to a dict with keys ``"model"`` (str),
        ``"width"`` (int), ``"height"`` (int) and ``"params"`` (float64
        array with the remaining values of the line).

    Note:
        This definition replaces the earlier ``read_cameras`` of this
        notebook, which returned 4x4 intrinsic matrices instead.
    """
    cameras = {}
    with open(path, "r") as f:
        for line in f:
            data = line.strip().split()
            # An empty token list means a blank line; indexing it would fail.
            if not data or data[0].startswith("#"):
                continue
            camera_id = int(data[0])
            cameras[camera_id] = {
                "model": data[1],
                "width": int(data[2]),
                "height": int(data[3]),
                "params": np.array(data[4:], dtype=np.float64)
            }
    return cameras

def read_images(path: str) -> Dict:
    """Parse a COLMAP ``images.txt`` file.

    Every image is described by two consecutive lines: a pose line
    ``IMAGE_ID QW QX QY QZ TX TY TZ CAMERA_ID NAME`` and a line of
    ``X Y POINT3D_ID`` triplets (possibly empty). Comment and blank lines
    found where a pose line is expected are skipped one at a time, so the
    pairing does not depend on how many header lines the file has.

    Args:
        path: Path to the ``images.txt`` file.

    Returns:
        dict mapping ``image_id`` to a dict with keys ``"qvec"`` (4 floats),
        ``"tvec"`` (3 floats), ``"camera_id"`` (int), ``"name"`` (str),
        ``"xys"`` (2D point coordinates) and ``"point3D_ids"`` (ints).
    """
    with open(path, "r") as f:
        lines = f.readlines()

    images = {}
    i = 0
    while i < len(lines):
        pose_line = lines[i].strip()
        if not pose_line or pose_line.startswith("#"):
            i += 1
            continue

        data = pose_line.split()
        # The points line may be missing entirely at the end of the file.
        points_data = lines[i + 1].split() if i + 1 < len(lines) else []

        image_id = int(data[0])
        images[image_id] = {
            "qvec": np.array(data[1:5], dtype=float),
            "tvec": np.array(data[5:8], dtype=float),
            "camera_id": int(data[8]),
            "name": data[9],
            "xys": np.array([(float(points_data[j]), float(points_data[j+1]))
                             for j in range(0, len(points_data), 3)]),
            "point3D_ids": np.array([int(points_data[j+2])
                                     for j in range(0, len(points_data), 3)])
        }
        # Step over the pose line and its points line together.
        i += 2
    return images

def qvec2rotmat(qvec: np.ndarray) -> np.ndarray:
    """Convert a quaternion to a 3x3 rotation matrix.

    Args:
        qvec: Sequence whose first four entries are read as ``(w, x, y, z)``.
            The formula used is the one for a unit quaternion; the input is
            not normalised here.

    Returns:
        3x3 array with the rotation matrix.
    """
    w, x, y, z = qvec[0], qvec[1], qvec[2], qvec[3]

    first_row = [1 - 2 * y**2 - 2 * z**2,
                 2 * x * y - 2 * w * z,
                 2 * z * x + 2 * w * y]
    second_row = [2 * x * y + 2 * w * z,
                  1 - 2 * x**2 - 2 * z**2,
                  2 * y * z - 2 * w * x]
    third_row = [2 * z * x - 2 * w * y,
                 2 * y * z + 2 * w * x,
                 1 - 2 * x**2 - 2 * y**2]

    return np.array([first_row, second_row, third_row])

def generate_rays(camera: Dict, image_data: Dict, img: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build one world-space ray per pixel of an image posed by COLMAP.

    Args:
        camera: Camera record; ``camera['params'][0]`` is used as the focal
            length in pixels. The principal point is taken to be the image
            centre ``(W/2, H/2)``.
        image_data: Image record with ``'qvec'`` and ``'tvec'``, used as the
            world-to-camera rotation and translation (``x_cam = R x_world + t``).
        img: ``(H, W, 3)`` or ``(H, W, 4)`` image. An alpha channel is
            composited onto a white background (values presumably in
            [0, 1] -- the caller in this notebook divides by 255).

    Returns:
        ``(rays_o, rays_d, target_px_values)``, each of shape ``(H*W, 3)``:
        the camera centre repeated for every pixel, unit ray directions in
        world coordinates, and the pixel colours.
    """
    # Keep original dimensions
    H, W = img.shape[:2]
    f = camera['params'][0]

    # RGBA -> RGB by blending onto white.
    if img.shape[2] == 4:
        img = img[..., :3] * img[..., -1:] + (1 - img[..., -1:])

    R = qvec2rotmat(image_data['qvec'])
    t = image_data['tvec']
    # Camera centre in world coordinates for a world-to-camera pose (R, t).
    C = -R.T @ t

    u, v = np.meshgrid(np.arange(W), np.arange(H))
    # COLMAP's camera frame has +x right, +y down and +z pointing forward
    # (into the scene). The previous (x, -y, -f) directions follow the
    # OpenGL convention and, rotated with COLMAP's R, point behind the camera.
    dirs = np.stack([(u - W/2), (v - H/2), np.ones_like(u) * f], axis=-1)
    # Rotate camera-frame directions into the world frame, then normalise.
    dirs = (R.T @ dirs.reshape(-1, 3).T).T
    dirs = dirs / np.linalg.norm(dirs, axis=-1, keepdims=True)

    rays_o = np.broadcast_to(C, dirs.shape)
    rays_d = dirs
    target_px_values = img.reshape(-1, 3)

    return rays_o, rays_d, target_px_values

def get_rays(datapath: str, mode: str = 'train') -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load a COLMAP-posed image set and return per-pixel rays and colours.

    Expects ``<datapath>/colmap/{cameras,images}.txt`` and the image files
    under ``<datapath>/imgs``. Images are ordered by file name; the first
    90% form the training split and the remainder is returned for any other
    ``mode`` value.

    Args:
        datapath: Root directory of the dataset.
        mode: ``'train'`` for the training split, anything else for the rest.

    Returns:
        ``(rays_o, rays_d, target_px_values)``, each reshaped to
        ``(n_images, H, W, 3)`` where ``H, W`` come from the first image
        listed in ``images.txt``.

    Note:
        This definition shadows the ``get_rays`` imported from ``dataset``
        at the top of the notebook.
    """
    colmap_root = os.path.join(datapath, 'colmap')
    image_root = os.path.join(datapath, 'imgs')

    cameras = read_cameras(os.path.join(colmap_root, 'cameras.txt'))
    images = read_images(os.path.join(colmap_root, 'images.txt'))

    # The first listed image fixes the resolution every other image must match.
    reference_name = list(images.values())[0]['name']
    H, W = imageio.imread(os.path.join(image_root, reference_name)).shape[:2]

    def by_file_name(entry):
        return entry[1]['name']

    ordered = sorted(images.items(), key=by_file_name)
    cut = int(len(ordered) * 0.9)
    if mode == 'train':
        selected_images = ordered[:cut]
    else:
        selected_images = ordered[cut:]

    per_image = []
    for _, meta in selected_images:
        frame = imageio.imread(os.path.join(image_root, meta['name'])).astype(np.float32) / 255.0

        # Bring differently sized images to the reference resolution.
        if frame.shape[:2] != (H, W):
            from skimage.transform import resize
            frame = resize(frame, (H, W, 3), anti_aliasing=True)

        per_image.append(generate_rays(cameras[meta['camera_id']], meta, frame))

    # Stack the per-image results along a new leading axis.
    rays_o = np.stack([rays[0] for rays in per_image])
    rays_d = np.stack([rays[1] for rays in per_image])
    target_px_values = np.stack([rays[2] for rays in per_image])

    n_images = len(selected_images)
    logger.info(f"Number of images: {n_images}, Image dimensions: {H}x{W}")

    per_pixel_shape = (n_images, H, W, 3)
    return (rays_o.reshape(per_pixel_shape),
            rays_d.reshape(per_pixel_shape),
            target_px_values.reshape(per_pixel_shape))

# Split the matrix files into test/train folders, then rename them

In [7]:
import os
import shutil
from pathlib import Path

def organize_files(source_dir, test_dir, train_dir, test_count=10):
    """
    Copy matrix files into test and train directories with proper subdirectories.

    Files are grouped per image by the part of the name that follows the
    ``intrinsics_`` / ``pose_`` prefix (e.g. ``001.txt``), so an image that
    lacks one of the two files cannot shift the pairing of the others. The
    first ``test_count`` images (in sorted order) go to ``test_dir``, the
    rest to ``train_dir``. Source files are copied, not moved.

    Args:
        source_dir: Path to source directory containing matrix files
        test_dir: Path to test directory
        train_dir: Path to train directory
        test_count: Number of images whose files go to the test directory
    """
    prefixes = {'intrinsics': 'intrinsics_', 'pose': 'pose_'}

    # Create required directories
    for base_dir in (test_dir, train_dir):
        for subdir in prefixes:
            os.makedirs(os.path.join(base_dir, subdir), exist_ok=True)

    # Index the source files of each kind by their shared per-image suffix.
    source_path = Path(source_dir)
    files_by_kind = {
        subdir: {f.name[len(prefix):]: f for f in source_path.glob(f'{prefix}*.txt')}
        for subdir, prefix in prefixes.items()
    }
    image_keys = sorted(set().union(*files_by_kind.values()))

    # A negative count would otherwise select images from the end.
    test_count = max(0, test_count)

    for index, key in enumerate(image_keys):
        target_dir = test_dir if index < test_count else train_dir
        for subdir, files in files_by_kind.items():
            source_file = files.get(key)
            if source_file is None:
                # This image has no file of this kind; nothing to copy.
                continue
            shutil.copy2(source_file, os.path.join(target_dir, subdir, source_file.name))

# Runs on cell execution: splits the per-image matrix files of the 'fox'
# scene into test and train folders (the files are copied, so the source
# directory is left untouched).
if __name__ == '__main__':
    # Define paths
    source_dir = 'fox/colmap/matrices'
    test_dir = 'fox/test'
    train_dir = 'fox/train'

    # Organize files: the first 10 images become the test split
    organize_files(source_dir, test_dir, train_dir, test_count=10)
    print("Files have been organized successfully!")

Files have been organized successfully!


# Rename the matrix files to `test_<n>.txt` / `train_<n>.txt`

In [8]:
import os
from pathlib import Path

def rename_files(base_dir):
    """
    Rename files in test and train directories to follow the pattern:
    test_X.txt or train_X.txt where X is a sequential number

    Files are ordered by the trailing number in their name, so running the
    function again on already renamed files changes nothing. Renames go
    through temporary names so that no existing file is overwritten.

    Args:
        base_dir: Base directory containing test and train folders
    """
    import re  # local import: this cell only imports os and Path

    def numeric_order(path):
        # Sort on the trailing integer of the stem ("pose_010" -> 10,
        # "test_2" -> 2); names without one go last, ordered by name.
        match = re.search(r'(\d+)$', path.stem)
        number = int(match.group(1)) if match else float('inf')
        return (number, path.name)

    # Process both test and train directories
    for dir_type in ['test', 'train']:
        dir_path = Path(base_dir) / dir_type

        # Process both intrinsics and pose subdirectories
        for subdir in ['intrinsics', 'pose']:
            subdir_path = dir_path / subdir

            # Get all txt files in the directory
            files = sorted(subdir_path.glob('*.txt'), key=numeric_order)

            # Work out which files actually need a new name.
            pending = []
            for i, file_path in enumerate(files, 1):
                new_name = f"{dir_type}_{i}.txt"
                if file_path.name != new_name:
                    pending.append((file_path, subdir_path / new_name))

            # Phase 1: park every file to be renamed under a temporary name,
            # freeing all target names before any of them is claimed.
            staged = []
            for file_path, new_path in pending:
                temp_path = new_path.with_name(new_path.name + '.renaming')
                os.rename(file_path, temp_path)
                staged.append((file_path, temp_path, new_path))

            # Phase 2: move the parked files to their final names.
            for file_path, temp_path, new_path in staged:
                os.rename(temp_path, new_path)
                print(f"Renamed {file_path.name} to {new_path.name} in {subdir}")

# Runs on cell execution: renames the files under fox/test and fox/train
# in place.
if __name__ == '__main__':
    # Define base directory
    base_dir = 'fox'

    # Rename files
    rename_files(base_dir)
    print("Files have been renamed successfully!")

Renamed intrinsics_001.txt to test_1.txt in intrinsics
Renamed intrinsics_002.txt to test_2.txt in intrinsics
Renamed intrinsics_003.txt to test_3.txt in intrinsics
Renamed intrinsics_004.txt to test_4.txt in intrinsics
Renamed intrinsics_005.txt to test_5.txt in intrinsics
Renamed intrinsics_006.txt to test_6.txt in intrinsics
Renamed intrinsics_007.txt to test_7.txt in intrinsics
Renamed intrinsics_009.txt to test_8.txt in intrinsics
Renamed intrinsics_010.txt to test_9.txt in intrinsics
Renamed intrinsics_011.txt to test_10.txt in intrinsics
Renamed pose_001.txt to test_1.txt in pose
Renamed pose_002.txt to test_2.txt in pose
Renamed pose_003.txt to test_3.txt in pose
Renamed pose_004.txt to test_4.txt in pose
Renamed pose_005.txt to test_5.txt in pose
Renamed pose_006.txt to test_6.txt in pose
Renamed pose_007.txt to test_7.txt in pose
Renamed pose_009.txt to test_8.txt in pose
Renamed pose_010.txt to test_9.txt in pose
Renamed pose_011.txt to test_10.txt in pose
Renamed intrinsics

# Camera / Dataset

In [None]:
batch_size = 1024
o, d, target_px_values = get_rays('fox', mode='train')


def _stack_rays(origins, directions, colors, rows=slice(None), cols=slice(None)):
    """Flatten per-pixel arrays into one (N, 9) float tensor.

    Each input is indexed as ``[:, rows, cols, :]`` (image, row, column,
    channel), flattened to ``(-1, 3)`` and the three results are
    concatenated column-wise as ``[origin | direction | colour]``.
    """
    columns = [
        torch.from_numpy(array)[:, rows, cols, :].reshape(-1, 3).type(torch.float)
        for array in (origins, directions, colors)
    ]
    return torch.cat(columns, dim=1)


# Regular dataloader
dataloader = DataLoader(_stack_rays(o, d, target_px_values),
                        batch_size=batch_size, shuffle=True)

# Warmup dataloader with cropped center region: the central half of every
# image along each axis (rows/cols 100:300 for 400x400 images). Sizes are
# read from the data instead of being hard-coded.
# Assumes `o` is (n_images, H, W, 3), as returned by the get_rays defined in
# this notebook.
_, img_h, img_w, _ = o.shape
center_rows = slice(img_h // 4, img_h - img_h // 4)
center_cols = slice(img_w // 4, img_w - img_w // 4)
dataloader_warmup = DataLoader(_stack_rays(o, d, target_px_values, center_rows, center_cols),
                               batch_size=batch_size, shuffle=True)

# Test data
test_o, test_d, test_target_px_values = get_rays('fox', mode='test')

# Training

In [9]:
# `device` is defined by the device-selection cell near the top of the
# notebook; `dataloader` and `dataloader_warmup` come from the
# "Camera / Dataset" cell, which must be run before this one.

tn = 8.   # passed to training() together with tf; presumably near/far ray bounds -- confirm in ml_helpers
tf = 12.
nb_epochs = 1 #15 30
lr =  1e-3 # 1e-3 5e-4
gamma = .5 #0.5 0.7
nb_bins = 100 #100 256
warmup_epochs = 1  # epochs spent on the centre-cropped warm-up loader

model = Nerf(hidden_dim=256).to(device) #Nerf(hidden_dim=128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Multiply the learning rate by `gamma` at scheduler steps 5 and 10. The same
# scheduler object is shared by the warm-up and the main training run.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10], gamma=gamma)

# Stage 1: warm-up on the central crop of every image.
training_loss = training(model, optimizer, scheduler, tn, tf, nb_bins, warmup_epochs, dataloader_warmup, device=device)
fig, ax = plt.subplots()
ax.plot(training_loss)
ax.set(title="Warm-up training loss", xlabel="Recorded loss value (index)", ylabel="Loss")
plt.show()

# Stage 2: training on all pixels.
training_loss = training(model, optimizer, scheduler, tn, tf, nb_bins, nb_epochs, dataloader, device=device)
fig, ax = plt.subplots()
ax.plot(training_loss)
ax.set(title="Training loss", xlabel="Recorded loss value (index)", ylabel="Loss")
plt.show()

NameError: name 'dataloader_warmup' is not defined

In [None]:
# Serialises the whole model object (not only its state_dict) to the file
# 'model_nerf_colmap' in the working directory; the mesh-extraction section
# reloads it with torch.load.
torch.save(model, 'model_nerf_colmap')

# Mesh extraction

In [None]:
# The checkpoint holds a fully pickled model object (it was written with
# torch.save(model, ...)), so it must be loaded with weights_only=False;
# newer torch releases default to weights_only=True and reject such files.
# Unpickling executes code from the file -- only load checkpoints you created.
# map_location='cpu' makes loading independent of the device the model was
# saved from; the model is moved to `device` afterwards.
model = torch.load('model_nerf_colmap', map_location='cpu', weights_only=False).to(device)

In [None]:
import torch
import numpy as np
from skimage.measure import marching_cubes
import trimesh
import torch.nn.functional as F

def analyze_density_field(density_volume):
    """Print summary statistics of a density volume and suggest a threshold.

    Args:
        density_volume: Object exposing ``min()``, ``max()``, ``mean()`` and
            ``std()`` (e.g. a torch tensor).

    Returns:
        float: ``mean + std`` of the volume, used as the suggested
        iso-surface threshold.
    """
    stats = {
        "Min": float(density_volume.min()),
        "Max": float(density_volume.max()),
        "Mean": float(density_volume.mean()),
        "Std": float(density_volume.std()),
    }

    print("Density field statistics:")
    for label, value in stats.items():
        print(f"{label}: {value:.6f}")

    # Threshold heuristic: one standard deviation above the mean.
    return stats["Mean"] + stats["Std"]

def extract_mesh(nerf_model, resolution=128, threshold=None, bbox_min=[-1.5, -1.5, -1.5], 
                bbox_max=[1.5, 1.5, 1.5], device=torch.device("cpu")):
    """
    Extract a colored mesh from a trained NeRF model.

    The model is sampled on a regular ``resolution**3`` grid spanning the
    bounding box, marching cubes is run on the resulting density volume and
    the model is queried again at the mesh vertices for their colors. All
    queries use a zero view direction.

    Args:
        nerf_model: Trained NeRF model; called as ``nerf_model(points, dirs)``
            and assumed to return a ``(rgb, sigma)`` tuple
        resolution: Grid resolution for marching cubes (at least 2)
        threshold: Density threshold for surface extraction (if None, will be auto-determined)
        bbox_min: Minimum corner of bounding box
        bbox_max: Maximum corner of bounding box
        device: Torch device to use

    Returns:
        trimesh.Trimesh: Colored mesh

    Raises:
        ValueError: If ``resolution`` is below 2, or re-raised from marching
            cubes (e.g. when the threshold lies outside the density range).
    """
    if resolution < 2:
        raise ValueError("resolution must be at least 2 to build a sampling grid")

    print(f"Creating density volume with resolution {resolution}...")

    # Create grid of points. The grid stays on the CPU; only the chunk being
    # evaluated is moved to `device`, which keeps accelerator memory bounded.
    x = torch.linspace(bbox_min[0], bbox_max[0], resolution)
    y = torch.linspace(bbox_min[1], bbox_max[1], resolution)
    z = torch.linspace(bbox_min[2], bbox_max[2], resolution)
    xx, yy, zz = torch.meshgrid(x, y, z, indexing='ij')
    points = torch.stack([xx, yy, zz], dim=-1).reshape(-1, 3)

    # Create density volume; `flat_density` is a view, so writing to it
    # fills `density_volume`.
    density_volume = torch.zeros((resolution, resolution, resolution))
    flat_density = density_volume.view(-1)
    chunk_size = 512 * 512  # Process in chunks to avoid OOM

    print("Sampling density field...")
    with torch.no_grad():
        for i in range(0, points.shape[0], chunk_size):
            chunk_points = points[i:i+chunk_size].to(device)
            # Assume model returns (rgb, sigma) tuple
            _, chunk_densities = nerf_model(chunk_points, torch.zeros_like(chunk_points))
            # reshape(-1) accepts both (N,) and (N, 1) density outputs.
            flat_density[i:i+chunk_size] = chunk_densities.reshape(-1).cpu()

    # Auto-determine threshold if not provided
    if threshold is None:
        threshold = analyze_density_field(density_volume)
        print(f"Auto-determined threshold: {threshold:.6f}")

    print(f"Extracting mesh with threshold {threshold}...")

    # linspace places `resolution` samples with both endpoints included, so
    # neighbouring samples are (max - min) / (resolution - 1) apart.
    spacing = tuple(
        (bbox_max[axis] - bbox_min[axis]) / (resolution - 1) for axis in range(3)
    )

    try:
        # Extract mesh using marching cubes
        vertices, faces, normals, _ = marching_cubes(
            density_volume.numpy(),
            threshold,
            spacing=spacing
        )
    except ValueError as e:
        print("Error during marching cubes:")
        print(e)
        print("\nTry adjusting the threshold based on the density statistics above.")
        raise

    print(f"Mesh extracted with {len(vertices)} vertices and {len(faces)} faces")

    # Adjust vertices to match bbox (marching cubes starts at the origin)
    vertices = vertices + np.array(bbox_min)

    # Sample colors at vertex positions
    vertex_colors = torch.zeros((len(vertices), 3))
    vertices_tensor = torch.tensor(vertices, dtype=torch.float32)

    print("Sampling colors...")
    with torch.no_grad():
        for i in range(0, len(vertices), chunk_size):
            chunk_vertices = vertices_tensor[i:i+chunk_size].to(device)
            # Assume model returns (rgb, sigma) tuple
            chunk_colors, _ = nerf_model(chunk_vertices, torch.zeros_like(chunk_vertices))
            vertex_colors[i:i+chunk_size] = chunk_colors.cpu()

    # Keep colors inside [0, 1] so the uint8 conversion cannot wrap around.
    color_bytes = (vertex_colors.clamp(0.0, 1.0).numpy() * 255).astype(np.uint8)

    # Create mesh with vertex colors
    mesh = trimesh.Trimesh(
        vertices=vertices,
        faces=faces,
        vertex_colors=color_bytes,
        vertex_normals=normals
    )

    return mesh

def save_colored_mesh(nerf_model, output_path, resolution=256, threshold=None, device=torch.device("cpu")):
    """
    Build a colored mesh from a NeRF model, clean it up and write it to disk.

    Args:
        nerf_model: Trained NeRF model
        output_path: Destination file for the mesh (should end in .ply or .obj)
        resolution: Grid resolution handed to extract_mesh
        threshold: Density threshold (if None, extract_mesh determines it)
        device: Torch device to use

    Returns:
        The processed mesh that was exported.
    """
    colored_mesh = extract_mesh(
        nerf_model,
        resolution=resolution,
        threshold=threshold,
        device=device,
    )

    # Optional cleanup pass with validation enabled.
    print("Processing mesh...")
    colored_mesh = colored_mesh.process(validate=True)

    # Write the result to the requested location.
    print(f"Saving mesh to {output_path}...")
    colored_mesh.export(output_path)

    return colored_mesh

# Run after the trained model has been loaded into `model`.
# A 700^3 grid is 343,000,000 density samples; lower the resolution if you
# run into memory issues, raise it for finer detail.
resolution = 700  # Increase for better quality, decrease if you run into memory issues
output_path = "nerf_mesh.obj"  # A path ending in .ply is also accepted (see save_colored_mesh)

# Extract and save the mesh
mesh = save_colored_mesh(model, output_path, resolution=resolution, device=device)

In [None]:
#!pip install Pymcubes
#!pip install trimesh
#!pip install -U scikit-image
#!pip install genesis-world  # Requires Python >=3.9;
#!pip uninstall genesis-world
#!conda install -c anaconda trimesh