In [1]:
import os
import json
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import imageio
from torch.cuda.amp import autocast, GradScaler
from models.nerf_mlp import NeRFMLP
from utils.encoding import PositionalEncoding
from utils.render import render_rays
import matplotlib.cm as cm

In [2]:
def get_rays(H, W, focal, c2w):
    """Generate one camera ray per pixel for a pinhole camera.

    Args:
        H: Image height in pixels.
        W: Image width in pixels.
        focal: Focal length in pixels.
        c2w: Camera-to-world matrix; only the top-left 3x3 rotation and the
            first three rows of the fourth column (translation) are read.

    Returns:
        ``(rays_o, rays_d)``, both of shape ``(H, W, 3)``. ``rays_o`` is the
        camera origin broadcast (as an expanded view) to every pixel and
        ``rays_d`` is the world-space ray direction. Directions are not
        normalised.
    """
    device = c2w.device

    # Pixel coordinate grids; 'xy' indexing yields (H, W)-shaped outputs.
    cols = torch.linspace(0, W - 1, W, device=device)
    rows = torch.linspace(0, H - 1, H, device=device)
    px, py = torch.meshgrid(cols, rows, indexing='xy')

    # Camera-space directions: x to the right, y up, camera looks along -z.
    x_cam = (px - W * 0.5) / focal
    y_cam = -(py - H * 0.5) / focal
    z_cam = -torch.ones_like(px)
    dirs = torch.stack([x_cam, y_cam, z_cam], -1)

    # Rotate into world space: for each pixel, dot the direction with every
    # row of the rotation matrix.
    rotation = c2w[:3, :3]
    rays_d = (dirs[..., None, :] * rotation).sum(-1)

    # All rays share the camera centre as their origin.
    origin = c2w[:3, 3]
    rays_o = origin.expand(rays_d.shape)
    return rays_o, rays_d

In [3]:
def train(base_dir, n_iters, batch_size):
    """Train coarse and fine NeRF MLPs on a Blender-style dataset and checkpoint them.

    Reads ``transforms_train.json`` and the PNG frames it references from
    ``base_dir``, composites every frame onto a white background, pre-computes
    one ray per pixel for every training pose on the compute device, and then
    optimises both networks on randomly sampled ray batches.

    Args:
        base_dir: Dataset directory containing ``transforms_train.json`` and
            the frames referenced by each entry's ``file_path`` (without the
            ``.png`` extension).
        n_iters: Number of optimisation steps to run.
        batch_size: Number of rays drawn (with replacement) per step.

    Returns:
        ``(models, H, W, focal, pos_enc, dir_enc, device)`` where ``models`` is
        a dict with ``'coarse'`` and ``'fine'`` networks, ``H``/``W`` are the
        frame dimensions, and ``focal`` is the focal length in pixels derived
        from ``camera_angle_x``.

    Side effects:
        Writes both state dicts to ``nerf_models.pth`` in the current working
        directory.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # torch.cuda.amp is a no-op (plus warnings) without CUDA, so only enable
    # mixed precision when actually running on a GPU.
    use_amp = device.type == "cuda"

    # --- Load data into CPU RAM first ---
    print("Loading data into RAM...")
    with open(os.path.join(base_dir, 'transforms_train.json'), 'r') as f:
        meta = json.load(f)

    images_np, poses_np = [], []
    for frame in meta['frames']:
        fname = os.path.join(base_dir, frame['file_path'] + '.png')
        # Close the file handle promptly, and force RGBA so the alpha
        # composite below also works for frames saved without alpha.
        with Image.open(fname) as img_file:
            img_rgba = (np.array(img_file.convert('RGBA')) / 255.0).astype(np.float32)
        # Alpha-composite onto a white background.
        alpha = img_rgba[..., 3:]
        img_rgb = img_rgba[..., :3] * alpha + (1.0 - alpha)
        images_np.append(img_rgb)
        poses_np.append(np.array(frame['transform_matrix']))

    images_np = np.stack(images_np)
    poses_np = np.stack(poses_np)
    H, W = images_np[0].shape[:2]
    camera_angle_x = float(meta['camera_angle_x'])
    focal = 0.5 * W / np.tan(0.5 * camera_angle_x)

    # --- Pre-compute all rays and move the ENTIRE dataset to the GPU ---
    print("Pre-computing all rays and moving dataset to GPU VRAM...")
    all_rays_o = []
    all_rays_d = []
    for pose in tqdm(poses_np, desc="Processing Poses"):
        pose_tensor = torch.tensor(pose, dtype=torch.float32, device=device)
        rays_o, rays_d = get_rays(H, W, focal, pose_tensor)
        all_rays_o.append(rays_o.reshape(-1, 3))
        all_rays_d.append(rays_d.reshape(-1, 3))

    # Flattened per-pixel tensors; row k of each belongs to the same pixel.
    all_rays_o = torch.cat(all_rays_o, 0)
    all_rays_d = torch.cat(all_rays_d, 0)
    all_colors = torch.tensor(images_np.reshape(-1, 3), dtype=torch.float32, device=device)
    print("Dataset is now fully in VRAM! 🚀")

    # --- Models, Encoders, Optimizer (all on the compute device) ---
    pos_enc = PositionalEncoding(num_freqs=10).to(device)
    dir_enc = PositionalEncoding(num_freqs=4).to(device)
    models = {
        'coarse': NeRFMLP(input_ch=pos_enc.output_dims, input_ch_dir=dir_enc.output_dims).to(device),
        'fine': NeRFMLP(input_ch=pos_enc.output_dims, input_ch_dir=dir_enc.output_dims).to(device)
    }
    optimizer = torch.optim.Adam(list(models['coarse'].parameters()) + list(models['fine'].parameters()), lr=5e-4)
    # NOTE(review): the scheduler is stepped once per iteration, so
    # gamma=0.9995 shrinks the learning rate tenfold roughly every 4,600
    # steps. Confirm this is intended for long runs.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9995)
    scaler = GradScaler(enabled=use_amp)

    # --- Training loop (no DataLoader; batches are gathered by index) ---
    print("\nStarting training...")
    num_total_rays = all_rays_o.shape[0]
    for i in tqdm(range(n_iters), desc="Training Progress"):
        # 1. Draw random ray indices directly on the compute device
        indices = torch.randint(0, num_total_rays, (batch_size,), device=device)

        # 2. Gather the batch
        batch_rays_o = all_rays_o[indices]
        batch_rays_d = all_rays_d[indices]
        batch_colors = all_colors[indices]

        # 3. Forward pass; the loss is the sum of the fine and coarse MSEs
        with autocast(enabled=use_amp):
            results = render_rays(
                models, pos_enc, dir_enc, batch_rays_o, batch_rays_d,
                N_samples=64, N_importance=128, near=2.0, far=6.0,
                white_bkgd=True, device=device
            )
            loss = torch.mean((results['rgb_map'] - batch_colors) ** 2) + torch.mean((results['rgb_map0'] - batch_colors) ** 2)

        # 4. Backward pass and parameter update
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

    print("Training finished! ✅")
    torch.save({
        'coarse_model_state_dict': models['coarse'].state_dict(),
        'fine_model_state_dict': models['fine'].state_dict(),
    }, 'nerf_models.pth')
    print("Saved trained models to nerf_models.pth")

    # Everything the rendering functions need to reuse the trained models.
    return models, H, W, focal, pos_enc, dir_enc, device


In [4]:
def create_camera_path(radius=4.0, n_poses=120):
    """Create a circular camera orbit whose cameras look at the origin.

    The camera moves on a circle of the given radius in the x-z plane at
    height ``y = -0.5`` and uses +Y as the world up direction.

    The returned matrices follow the convention ``get_rays`` expects: rays are
    cast along the camera's **-z** axis, so the third rotation column is the
    *backward* direction (from the target towards the camera), not the viewing
    direction.

    NOTE(review): the orbit is around the world +Y axis; confirm that matches
    the up axis of the dataset the models were trained on.

    Args:
        radius: Orbit radius in the x-z plane. A radius of 0 puts the view
            direction parallel to world up and yields NaN rotations.
        n_poses: Number of evenly spaced poses over one full revolution
            (the end point 2*pi is excluded so the loop does not repeat).

    Returns:
        Float tensor of shape ``(n_poses, 3, 4)`` holding ``[R | t]``
        camera-to-world matrices with columns ``right, up, backward, position``.
    """
    angles = torch.linspace(0, 2 * np.pi, n_poses + 1)[:-1]
    target = torch.tensor([0., 0., 0.])    # look-at point (object centre)
    world_up = torch.tensor([0., 1., 0.])
    height = torch.tensor(-0.5)

    c2w_frames = []
    for theta in angles:
        # Camera position on the orbit
        pos = torch.stack([radius * torch.cos(theta), height, radius * torch.sin(theta)])

        # Camera +z axis points away from the target, because the camera
        # looks down its own -z axis.
        backward = pos - target
        backward = backward / torch.linalg.norm(backward)

        # Right-handed orthonormal basis: right x up = backward
        right = torch.cross(world_up, backward, dim=-1)
        right = right / torch.linalg.norm(right)
        up = torch.cross(backward, right, dim=-1)

        # 3x4 camera-to-world matrix, assembled column by column
        c2w_frames.append(torch.stack([right, up, backward, pos], dim=1))

    return torch.stack(c2w_frames)

def create_nerf_video(models, H, W, focal, pos_enc, dir_enc, device, chunk_size=4096):
    """Render an orbit of RGB frames from the trained models and save it as a video.

    Args:
        models: Dict of networks passed through to ``render_rays``; every
            value is switched to eval mode.
        H: Frame height in pixels.
        W: Frame width in pixels.
        focal: Focal length in pixels.
        pos_enc: Positional encoder for sample positions.
        dir_enc: Positional encoder for view directions.
        device: Device the camera poses are moved to and rendering runs on.
        chunk_size: Number of rays rendered per ``render_rays`` call; lower it
            to reduce peak memory.

    Side effects:
        Writes ``nerf_lego_video.mp4``; if that fails for any reason, writes
        ``nerf_lego_video.gif`` instead. Models are left in eval mode.
    """
    print("\nStarting video rendering...")
    camera_poses = create_camera_path().to(device)
    frames = []

    for model in models.values():
        model.eval()

    for pose in tqdm(camera_poses, desc="Rendering Frames"):
        with torch.no_grad():
            rays_o, rays_d = get_rays(H, W, focal, pose)
            rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)

            # Render in chunks so a full frame never has to fit in memory at once.
            all_rgb = []
            for start in range(0, rays_o.shape[0], chunk_size):
                stop = start + chunk_size
                results = render_rays(
                    models, pos_enc, dir_enc, rays_o[start:stop], rays_d[start:stop],
                    N_samples=64, N_importance=128, near=2.0, far=6.0,
                    white_bkgd=True, device=device
                )
                all_rgb.append(results['rgb_map'])

            full_image = torch.cat(all_rgb, 0).reshape(H, W, 3)
            # Clip before the uint8 cast: values outside [0, 1] would otherwise
            # wrap around and show up as speckle artefacts.
            img_np = (np.clip(full_image.cpu().numpy(), 0.0, 1.0) * 255).astype(np.uint8)
            frames.append(img_np)

    video_path = 'nerf_lego_video.mp4'

    # Try MP4 first; any failure (missing ffmpeg backend, old imageio, ...)
    # falls back to a GIF so the render is not lost.
    try:
        import imageio.v3 as iio
        iio.imwrite(video_path, frames, fps=30, codec='libx264', quality=8)
        print(f"Video saved to {video_path}! 🎉")
    except Exception as e:
        print(f"MP4 save failed: {e}")
        print("Falling back to GIF format...")
        gif_path = 'nerf_lego_video.gif'
        imageio.mimsave(gif_path, frames, fps=30)
        print(f"GIF saved to {gif_path}! 🎉")


In [None]:
def create_combined_video(models, H, W, focal, pos_enc, dir_enc, device, chunk_size=4096):
    """Render an orbit video showing RGB | Opacity | Depth side by side.

    For every pose on the orbit the scene is rendered in chunks, and the
    ``rgb_map``, ``acc_map`` and ``depth_map`` outputs of ``render_rays`` are
    turned into three ``H x W`` colour images stacked horizontally:

    * RGB as rendered (white background),
    * accumulated opacity with the ``hot`` colormap,
    * depth, min-max normalised per frame, with the ``viridis`` colormap.

    Args:
        models: Dict of networks passed through to ``render_rays``; every
            value is switched to eval mode.
        H: Frame height in pixels.
        W: Frame width in pixels.
        focal: Focal length in pixels.
        pos_enc: Positional encoder for sample positions.
        dir_enc: Positional encoder for view directions.
        device: Device the camera poses are moved to and rendering runs on.
        chunk_size: Number of rays rendered per ``render_rays`` call.

    Side effects:
        Writes ``nerf_COMBINED.mp4``; if that fails, writes
        ``nerf_COMBINED.gif`` instead. Models are left in eval mode.
    """
    print("\nRendering combined visualization...")
    camera_poses = create_camera_path().to(device)
    combined_frames = []

    for model in models.values():
        model.eval()

    for pose in tqdm(camera_poses, desc="Rendering"):
        with torch.no_grad():
            rays_o, rays_d = get_rays(H, W, focal, pose)
            rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)

            all_rgb, all_acc, all_depth = [], [], []

            for start in range(0, rays_o.shape[0], chunk_size):
                stop = start + chunk_size
                results = render_rays(
                    models, pos_enc, dir_enc, rays_o[start:stop], rays_d[start:stop],
                    N_samples=64, N_importance=128, near=2.0, far=6.0,
                    white_bkgd=True, device=device
                )
                all_rgb.append(results['rgb_map'])
                all_acc.append(results['acc_map'])
                all_depth.append(results['depth_map'])

            # RGB panel: clip before the uint8 cast so out-of-range values
            # cannot wrap around.
            rgb = torch.cat(all_rgb, 0).reshape(H, W, 3).cpu().numpy()
            rgb = (np.clip(rgb, 0.0, 1.0) * 255).astype(np.uint8)

            # Opacity panel: colormaps return RGBA in [0, 1]; drop alpha.
            acc = torch.cat(all_acc, 0).reshape(H, W).cpu().numpy()
            acc_colored = (cm.hot(acc) * 255).astype(np.uint8)[..., :3]

            # Depth panel: normalise per frame; the epsilon guards against a
            # constant-depth frame.
            depth = torch.cat(all_depth, 0).reshape(H, W).cpu().numpy()
            depth_norm = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
            depth_colored = (cm.viridis(depth_norm) * 255).astype(np.uint8)[..., :3]

            # Stack horizontally: [RGB | Opacity | Depth]
            combined = np.hstack([rgb, acc_colored, depth_colored])
            combined_frames.append(combined)

    # Save: try MP4 first, fall back to GIF. Catch Exception (not a bare
    # except) so Ctrl-C still interrupts, and report why MP4 failed.
    try:
        import imageio.v3 as iio
        iio.imwrite('nerf_COMBINED.mp4', combined_frames, fps=30, codec='libx264', quality=8)
        print("🎉 Saved nerf_COMBINED.mp4 (RGB | Opacity | Depth)")
    except Exception as e:
        print(f"MP4 save failed: {e}")
        print("Falling back to GIF format...")
        imageio.mimsave('nerf_COMBINED.gif', combined_frames, fps=30)
        print("🎉 Saved nerf_COMBINED.gif")

In [None]:
# =================================================================================
# === DIAGNOSTIC HELPERS, FOLLOWED BY THE MAIN EXECUTION BLOCK (RENDERING) ========
# =================================================================================

def opacity_diagnostic(models, H, W, focal, pos_enc, dir_enc, device):
    """Render the first training view without a background and report opacity.

    Loads the first pose from ``transforms_train.json`` under the module-level
    ``LEGO_DATA_DIR``, renders the full ``H x W`` image with
    ``white_bkgd=False``, prints min/max/mean of the RGB and accumulated
    opacity maps, prints a verdict based on the mean opacity, and saves a
    three-panel figure to ``full_diagnostic.png``.

    Args:
        models: Dict of networks passed through to ``render_rays``; every
            value is switched to eval mode.
        H: Image height in pixels.
        W: Image width in pixels.
        focal: Focal length in pixels.
        pos_enc: Positional encoder for sample positions.
        dir_enc: Positional encoder for view directions.
        device: Device to render on.
    """
    import matplotlib.pyplot as plt

    separator = "=" * 50
    print("\n🔬 OPACITY DIAGNOSTIC\n" + separator)

    transforms_path = os.path.join(LEGO_DATA_DIR, 'transforms_train.json')
    with open(transforms_path, 'r') as f:
        meta = json.load(f)

    # First training pose -> flattened per-pixel rays
    first_pose = torch.tensor(meta['frames'][0]['transform_matrix'], dtype=torch.float32, device=device)
    rays_o, rays_d = get_rays(H, W, focal, first_pose)
    rays_o = rays_o.reshape(-1, 3)
    rays_d = rays_d.reshape(-1, 3)

    for net in models.values():
        net.eval()

    # Render the whole image in chunks of 4096 rays
    chunk = 4096
    rgb_chunks, acc_chunks = [], []
    with torch.no_grad():
        for start in range(0, rays_o.shape[0], chunk):
            out = render_rays(
                models, pos_enc, dir_enc,
                rays_o[start:start + chunk], rays_d[start:start + chunk],
                N_samples=64, N_importance=128, near=2.0, far=6.0,
                white_bkgd=False,  # no background, to inspect the raw output
                device=device
            )
            rgb_chunks.append(out['rgb_map'])
            acc_chunks.append(out['acc_map'])

        rgb_map = torch.cat(rgb_chunks, 0).reshape(H, W, 3).cpu().numpy()
        acc_map = torch.cat(acc_chunks, 0).reshape(H, W).cpu().numpy()

    print("RGB (no background):")
    print(f"  Min: {rgb_map.min():.4f} | Max: {rgb_map.max():.4f} | Mean: {rgb_map.mean():.4f}")
    print("\nOpacity Map (acc_map):")
    print(f"  Min: {acc_map.min():.4f} | Max: {acc_map.max():.4f} | Mean: {acc_map.mean():.4f}")

    # Verdict from the mean accumulated opacity
    mean_opacity = acc_map.mean()
    if mean_opacity < 0.3:
        verdict = [
            "\n❌ PROBLEM: Very low opacity (mean < 0.3)",
            "   → Model outputs weak density everywhere",
            "   → Need to boost sigma output or retrain longer",
        ]
    elif mean_opacity > 0.95:
        verdict = [
            "\n⚠️  Very high opacity everywhere (mean > 0.95)",
            "   → Model might be outputting noise",
        ]
    else:
        verdict = ["\n✅ Opacity looks reasonable!"]
    for line in verdict:
        print(line)

    # Three panels: raw RGB, opacity, and RGB composited onto white
    rgb_with_bg = rgb_map + (1 - acc_map[..., None])
    panels = [
        (rgb_map, 'RGB (No Background)', {}),
        (acc_map, 'Opacity Map', {'cmap': 'hot'}),
        (np.clip(rgb_with_bg, 0, 1), 'With White Background', {}),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (image, title, imshow_kwargs) in zip(axes, panels):
        ax.imshow(image, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.savefig('full_diagnostic.png', dpi=150, bbox_inches='tight')
    print("\n💾 Saved full_diagnostic.png")
    print(separator)

def detailed_diagnostic(models, H, W, focal, pos_enc, dir_enc, device):
    """Compare a centre crop of the first training view with and without background.

    Loads the first pose from ``transforms_train.json`` under the module-level
    ``LEGO_DATA_DIR``, renders a square patch from the centre of the image
    twice (``white_bkgd=True`` and ``white_bkgd=False``), prints RGB
    statistics for both, and saves the two patches side by side to
    ``diagnostic_comparison.png``.

    The patch is 64x64 pixels (4096 rays), or smaller if the image itself is
    smaller than 64 pixels in either dimension.

    Args:
        models: Dict of networks passed through to ``render_rays``; every
            value is switched to eval mode.
        H: Image height in pixels.
        W: Image width in pixels.
        focal: Focal length in pixels.
        pos_enc: Positional encoder for sample positions.
        dir_enc: Positional encoder for view directions.
        device: Device to render on.
    """
    print("\n🔬 DETAILED DIAGNOSTIC\n" + "="*50)

    with open(os.path.join(LEGO_DATA_DIR, 'transforms_train.json'), 'r') as f:
        meta = json.load(f)

    # Use first training pose; keep the (H, W, 3) layout so a real 2-D
    # window can be cropped from it.
    pose = torch.tensor(meta['frames'][0]['transform_matrix'], dtype=torch.float32, device=device)
    rays_o, rays_d = get_rays(H, W, focal, pose)

    for model in models.values():
        model.eval()

    # Centre crop. Slicing a run of the *flattened* rays would yield a few
    # full-width image rows, which cannot be displayed as a square patch.
    patch = min(64, H, W)
    top = (H - patch) // 2
    left = (W - patch) // 2
    sample_rays_o = rays_o[top:top + patch, left:left + patch].reshape(-1, 3)
    sample_rays_d = rays_d[top:top + patch, left:left + patch].reshape(-1, 3)

    with torch.no_grad():
        results = render_rays(
            models, pos_enc, dir_enc, sample_rays_o, sample_rays_d,
            N_samples=64, N_importance=128, near=2.0, far=6.0,
            white_bkgd=True, device=device
        )

    rgb = results['rgb_map'].cpu().numpy()

    print(f"RGB Stats:")
    print(f"  Min: {rgb.min():.4f} | Max: {rgb.max():.4f} | Mean: {rgb.mean():.4f}")

    # Key diagnostic: render the same rays with NO white background, so only
    # colour contributed by actual density remains.
    with torch.no_grad():
        results_no_bg = render_rays(
            models, pos_enc, dir_enc, sample_rays_o, sample_rays_d,
            N_samples=64, N_importance=128, near=2.0, far=6.0,
            white_bkgd=False,
            device=device
        )

    rgb_no_bg = results_no_bg['rgb_map'].cpu().numpy()
    print(f"\nRGB WITHOUT white background:")
    print(f"  Min: {rgb_no_bg.min():.4f} | Max: {rgb_no_bg.max():.4f} | Mean: {rgb_no_bg.mean():.4f}")

    # If mean is very low without background, density is the problem
    if rgb_no_bg.mean() < 0.1:
        print("\n⚠️  DIAGNOSIS: Model has VERY LOW DENSITY")
        print("   The model barely learned any solid geometry!")

    # Save comparison
    import matplotlib.pyplot as plt
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    ax1.imshow(rgb.reshape(patch, patch, 3))
    ax1.set_title('With White BG')
    ax2.imshow(rgb_no_bg.reshape(patch, patch, 3))
    ax2.set_title('Without BG (Raw)')
    plt.savefig('diagnostic_comparison.png', dpi=150, bbox_inches='tight')
    print("\n💾 Saved diagnostic_comparison.png")
    print("="*50)

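# Usage: once the models are loaded in the main block below, call
# opacity_diagnostic(...) or detailed_diagnostic(...) with the same
# arguments that are passed to create_combined_video(...).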


if __name__ == '__main__':
    LEGO_DATA_DIR = './data/nerf_synthetic/lego'
    BATCH_SIZE = 1024
    N_ITERS = 20000  # Only used when TRAIN_MODEL is True; tune to your compute budget
    TRAIN_MODEL = False  # Keep as False since you have trained model

    if TRAIN_MODEL:
        # train() returns everything the renderers need, in this order.
        trained_models, H, W, focal, pos_enc, dir_enc, device = train(
            LEGO_DATA_DIR, N_ITERS, BATCH_SIZE
        )
    else:
        print("Loading saved model...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        with open(os.path.join(LEGO_DATA_DIR, 'transforms_train.json'), 'r') as f:
            meta = json.load(f)

        # Take the resolution from the first training frame, as train() does,
        # so focal stays consistent with the data. Fall back to 800x800 if
        # the frame is not available.
        first_frame = os.path.join(LEGO_DATA_DIR, meta['frames'][0]['file_path'] + '.png')
        try:
            with Image.open(first_frame) as img_file:
                W, H = img_file.size  # PIL reports (width, height)
        except FileNotFoundError:
            H, W = 800, 800
            print(f"Could not open {first_frame}; assuming {H}x{W} frames")
        camera_angle_x = float(meta['camera_angle_x'])
        focal = 0.5 * W / np.tan(0.5 * camera_angle_x)

        # Must mirror the architecture built in train() so the state dicts load.
        pos_enc = PositionalEncoding(num_freqs=10).to(device)
        dir_enc = PositionalEncoding(num_freqs=4).to(device)
        trained_models = {
            'coarse': NeRFMLP(input_ch=pos_enc.output_dims, input_ch_dir=dir_enc.output_dims).to(device),
            'fine': NeRFMLP(input_ch=pos_enc.output_dims, input_ch_dir=dir_enc.output_dims).to(device)
        }

        # torch.load unpickles the file: only load checkpoints you trust
        # (this one is written by train() above).
        checkpoint = torch.load('nerf_models.pth', map_location=device)
        trained_models['coarse'].load_state_dict(checkpoint['coarse_model_state_dict'])
        trained_models['fine'].load_state_dict(checkpoint['fine_model_state_dict'])
        print("✅ Models loaded!")

    # Render the combined video (BEST OPTION)
    create_combined_video(trained_models, H, W, focal, pos_enc, dir_enc, device)