<a href="https://colab.research.google.com/github/jspark9703/nerf_study/blob/main/examples/vision/ipynb/nerf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

In [8]:
import os
# The Keras backend must be selected before `keras` is imported below.
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
tf.random.set_seed(42)  # global TF seed so random sampling ops are reproducible
import keras
from keras import layers
import numpy as np

# Initialize global variables (placeholder assumptions; some are overwritten below)
AUTO = tf.data.AUTOTUNE
BATCH_SIZE = 5
NUM_SAMPLES_COARSE = 64 # example value: coarse samples per ray
H = 100 # example value; reset from the loaded data below
W = 100 # example value; reset from the loaded data below
focal = 138.8889 # example value; reset from the loaded data below

# --- Data loading (taken from the user's code) ---
# Download the tiny NeRF dataset (images, camera poses, focal length).
url = (
    "http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz"
)
data_path = keras.utils.get_file(origin=url)
data = np.load(data_path)
images = data["images"]
(num_images_total, H_data, W_data, _) = images.shape # H and W globals are updated from these below
poses = data["poses"]  # presumably camera-to-world matrices (as get_rays expects) — confirm
focal_data = data["focal"]

# Update the global variables from the loaded data
H = H_data
W = W_data
focal = focal_data.astype(np.float32) # cast the focal length to float32

# -------------------------------------------

def encode_position(x, num_encoding_functions):
    """Map coordinates or directions to NeRF Fourier features.

    The output is the raw input followed, for every frequency level ``i`` in
    ``range(num_encoding_functions)``, by ``sin(2**i * x)`` and then
    ``cos(2**i * x)``. All pieces are concatenated on the last axis. An input
    whose last dimension is D therefore yields ``D * (1 + 2 * num_encoding_functions)``
    features.

    Args:
        x: Tensor of coordinates or directions (features on the last axis).
        num_encoding_functions: Number of frequency levels (the paper's L).

    Returns:
        Tensor of encoded features.
    """
    features = [x]
    for level in range(num_encoding_functions):
        scaled = (2.0**level) * x
        features.extend((tf.sin(scaled), tf.cos(scaled)))
    return tf.concat(features, axis=-1)


def get_rays(height, width, focal_val, pose): # `focal_val` is named to avoid clashing with the global `focal`
    """Compute a world-space ray (origin and direction) for every pixel.

    Args:
        height: Image height in pixels.
        width: Image width in pixels.
        focal_val: Focal length, in the same pixel units as the image grid.
        pose: Camera-to-world matrix. Its top-left 3x3 block rotates the camera
            directions and the first three entries of its last column give the
            camera position.

    Returns:
        ``(ray_origins, ray_directions)``, each of shape (height, width, 3).
        The directions are not normalised.
    """
    cols, rows = tf.meshgrid(
        tf.range(width, dtype=tf.float32),
        tf.range(height, dtype=tf.float32),
        indexing="xy",
    )
    # Camera-space direction through each pixel: x right, y up, looking down -z.
    x_cam = (cols - width * 0.5) / focal_val
    y_cam = (rows - height * 0.5) / focal_val
    camera_dirs = tf.stack([x_cam, -y_cam, -tf.ones_like(cols)], axis=-1)

    rotation = pose[:3, :3]
    # Row-wise dot products: applies `rotation` to every pixel direction.
    ray_directions = tf.reduce_sum(camera_dirs[..., None, :] * rotation, axis=-1)
    ray_origins = tf.broadcast_to(pose[:3, -1], tf.shape(ray_directions))
    return (ray_origins, ray_directions)


def render_flat_rays(ray_origins, ray_directions, near, far, num_samples, rand=False):
    """Sample 3D points along every ray between `near` and `far`.

    Args:
        ray_origins: Ray origins with xyz on the last axis, e.g. (H, W, 3).
        ray_directions: Ray directions, same shape as `ray_origins`.
        near: Depth of the first sample.
        far: Depth of the last (un-jittered) sample.
        num_samples: Number of samples per ray.
        rand: If True, jitter every depth by a uniform offset in
            ``[0, (far - near) / num_samples)``.

    Returns:
        Tuple of:
            rays_flat: Sample points, shape (num_rays * num_samples, 3).
            directions_flat: Ray direction for each sample point, same shape.
            t_vals: Sample depths per ray, shape (*ray_shape, num_samples),
                e.g. (H, W, num_samples).
    """
    sample_shape = tf.concat([tf.shape(ray_origins)[:-1], [num_samples]], axis=0)
    t_vals = tf.broadcast_to(tf.linspace(near, far, num_samples), sample_shape)
    if rand:
        t_vals = t_vals + tf.random.uniform(shape=sample_shape) * (far - near) / num_samples

    points = ray_origins[..., None, :] + ray_directions[..., None, :] * t_vals[..., None]
    point_dirs = tf.broadcast_to(ray_directions[..., None, :], tf.shape(points))
    return tf.reshape(points, [-1, 3]), tf.reshape(point_dirs, [-1, 3]), t_vals

# map_fn_coarse processes one (pose, image) pair of the zipped dataset.
def map_fn_coarse(pose, image): # receives the two zipped elements as separate arguments
    """Convert a (pose, image) pair into a training example with coarse depths.

    Reads the module-level ``H``, ``W``, ``focal`` and ``NUM_SAMPLES_COARSE``.

    Returns:
        ``(image, (ray_origins, ray_directions, t_vals))``. ``t_vals`` has
        shape (H, W, NUM_SAMPLES_COARSE) and is jittered (``rand=True``)
        between depths 2.0 and 6.0.
    """
    ray_origins, ray_directions = get_rays(height=H, width=W, focal_val=focal, pose=pose)
    # Only the per-ray depths are kept. The flattened points and directions are
    # rebuilt from them later at render time.
    _, _, t_vals = render_flat_rays(
        ray_origins=ray_origins,
        ray_directions=ray_directions,
        near=2.0,
        far=6.0,
        num_samples=NUM_SAMPLES_COARSE,
        rand=True,
    )
    return (image, (ray_origins, ray_directions, t_vals))

# Create the training split: the first 80% of the views train, the rest validate.
split_index = int(num_images_total * 0.8)

# Split the images into training and validation.
train_images_data = images[:split_index]
val_images_data = images[split_index:]

# Split the poses into training and validation.
train_poses_data = poses[:split_index]
val_poses_data = poses[split_index:]

# Target pixels must lie in [0, 1] to match the model's sigmoid RGB head and
# `tf.image.psnr(..., max_val=1.0)`. Rescale by 255 only when the stored images
# are in a 0-255 range (integer dtype, or float values above 1). Images already
# in [0, 1] are only cast. An unconditional `/ 255.0` would squash normalised
# targets to almost black.
# NOTE(review): tiny_nerf_data.npz appears to ship float images in [0, 1];
# this check handles either representation.
_image_divisor = (
    255.0
    if np.issubdtype(images.dtype, np.integer) or float(np.max(images)) > 1.0
    else 1.0
)

# Make the training pipeline.
train_img_ds = tf.data.Dataset.from_tensor_slices(
    train_images_data.astype(np.float32) / _image_divisor
)
train_pose_ds = tf.data.Dataset.from_tensor_slices(train_poses_data.astype(np.float32))
# Each element is a (pose, image) tuple. Dataset.map unpacks it, so
# map_fn_coarse is called as map_fn_coarse(pose, image).
train_dataset_raw = tf.data.Dataset.zip((train_pose_ds, train_img_ds))

train_ds = (
    train_dataset_raw.map(map_fn_coarse, num_parallel_calls=AUTO)
    .shuffle(BATCH_SIZE * 10)
    .batch(BATCH_SIZE, drop_remainder=True, num_parallel_calls=AUTO)
    .prefetch(AUTO)
)


# Make the validation pipeline (not shuffled).
# NOTE(review): map_fn_coarse uses rand=True, so validation coarse depths are
# jittered too and validation metrics vary slightly between evaluations.
val_img_ds = tf.data.Dataset.from_tensor_slices(
    val_images_data.astype(np.float32) / _image_divisor
)
val_pose_ds = tf.data.Dataset.from_tensor_slices(val_poses_data.astype(np.float32))
val_dataset_raw = tf.data.Dataset.zip((val_pose_ds, val_img_ds))

val_ds = (
    val_dataset_raw.map(map_fn_coarse, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE, drop_remainder=True, num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

# 실행 테스트 (데이터셋에서 첫 번째 배치 가져오기)
# for images_batch, (rays_o_batch, rays_d_batch, t_vals_coarse_batch) in train_ds.take(1):
#     print("Images batch shape:", images_batch.shape)
#     print("Ray origins batch shape:", rays_o_batch.shape)
#     print("Ray directions batch shape:", rays_d_batch.shape)
#     print("Coarse t_vals batch shape:", t_vals_coarse_batch.shape)
#     # t_vals_coarse_batch는 (BATCH_SIZE, H, W, NUM_SAMPLES_COARSE)가 되어야 합니다.
#     break

# print("Dataset pipeline created successfully.")

In [9]:
# Sanity check: print the shapes of the tensors in one training batch.
for images_batch, (ray_origins, ray_directions, t_vals_batch) in train_ds.take(1):
    shape_report = [
        ("✅ 이미지 배치", "images_batch.shape:", images_batch),
        ("\n✅ 광선 원점 (ray origins)", "ray_origins.shape:", ray_origins),
        ("\n✅ 광선 방향 (ray directions)", "ray_directions.shape:", ray_directions),
        ("\n✅ 각 샘플 포인트의 깊이 값 (t_vals)", "t_vals_batch.shape:", t_vals_batch),
    ]
    for heading, label, tensor in shape_report:
        print(heading)
        print(label, tensor.shape)


✅ 이미지 배치
images_batch.shape: (5, 100, 100, 3)

✅ 광선 원점 (ray origins)
ray_origins.shape: (5, 100, 100, 3)

✅ 광선 방향 (ray directions)
ray_directions.shape: (5, 100, 100, 3)

✅ 각 샘플 포인트의 깊이 값 (t_vals)
t_vals_batch.shape: (5, 100, 100, 64)


## NeRF model

The model is a multi-layer perceptron (MLP), with ReLU as its non-linearity.

An excerpt from the paper:

*"We encourage the representation to be multiview-consistent by
restricting the network to predict the volume density sigma as a
function of only the location `x`, while allowing the RGB color `c` to be
predicted as a function of both location and viewing direction. To
accomplish this, the MLP first processes the input 3D coordinate `x`
with 8 fully-connected layers (using ReLU activations and 256 channels
per layer), and outputs sigma and a 256-dimensional feature vector.
This feature vector is then concatenated with the camera ray's viewing
direction and passed to one additional fully-connected layer (using a
ReLU activation and 128 channels) that output the view-dependent RGB
color."*

This implementation uses the paper's layer widths: the position MLP has 256
Dense units per layer (`dense_units=256`), and the view-dependent RGB branch has 128 (`dense_units // 2`).

In [10]:
import tensorflow as tf
from keras import layers
import keras # Added for keras.Input, keras.Model

# Number of positional-encoding frequency levels (the paper's L).
# get_nerf_model below binds these as its default arguments when it is defined,
# and the rendering code reads them when encoding sample points.
POS_ENCODE_DIMS_XYZ = 10  # for 3D sample positions
POS_ENCODE_DIMS_DIR = 4  # for viewing directions

def get_nerf_model(num_layers=8, dense_units=256, skip_layer=4,
                   pos_encode_dims_xyz=POS_ENCODE_DIMS_XYZ, # module-level default, bound at definition time
                   pos_encode_dims_dir=POS_ENCODE_DIMS_DIR): # module-level default, bound at definition time
    """Build the NeRF MLP as a two-input functional Keras model.

    Args:
        num_layers: Number of ReLU Dense layers in the position trunk.
        dense_units: Width of the trunk layers. The view-dependent RGB layer
            uses ``dense_units // 2``.
        skip_layer: Index of the trunk layer whose output is concatenated with
            the encoded position (skip connection).
        pos_encode_dims_xyz: Frequency levels L used to encode positions.
        pos_encode_dims_dir: Frequency levels L used to encode directions.

    Returns:
        ``keras.Model`` named "NeRF_MLP" taking ``[encoded_xyz, encoded_dir]``
        and returning a (None, 4) tensor: sigmoid RGB (3 channels) followed by
        the raw, un-activated sigma (1 channel).
    """
    def encoded_width(num_freqs):
        # Each of the 3 coordinates contributes itself plus one sin and one cos
        # per frequency level.
        return 3 * (1 + 2 * num_freqs)

    input_xyz = keras.Input(shape=(encoded_width(pos_encode_dims_xyz),), name="encoded_xyz")
    input_dir = keras.Input(shape=(encoded_width(pos_encode_dims_dir),), name="encoded_dir")

    # Position trunk; the encoded input is re-appended after layer `skip_layer`.
    hidden = input_xyz
    for layer_idx in range(num_layers):
        hidden = layers.Dense(dense_units, activation="relu")(hidden)
        if layer_idx == skip_layer:
            hidden = layers.concatenate([hidden, input_xyz])

    # Density depends on position only. It is left linear here; ReLU is
    # applied during volumetric rendering.
    sigma = layers.Dense(1, activation=None, name="sigma")(hidden)

    # Linear feature vector; it is concatenated with the encoded view direction
    # to predict view-dependent colour.
    feature_vector = layers.Dense(dense_units, activation=None)(hidden)
    view_features = layers.concatenate([feature_vector, input_dir])
    rgb_hidden = layers.Dense(dense_units // 2, activation="relu")(view_features)
    rgb = layers.Dense(3, activation="sigmoid", name="rgb")(rgb_hidden)  # colours in [0, 1]

    return keras.Model(
        inputs=[input_xyz, input_dir],
        outputs=layers.concatenate([rgb, sigma], axis=-1),
        name="NeRF_MLP",
    )


def volumetric_rendering(raw_output, t_vals, batch_size, H_dim, W_dim, num_samples_dim):
    """Alpha-composite per-sample colours and densities into images.

    Args:
        raw_output: Flat MLP output, shape (batch_size * H * W * num_samples, 4),
            laid out as [r, g, b, sigma]. RGB is used as-is (get_nerf_model
            already applies a sigmoid); sigma is passed through ReLU here.
        t_vals: Sample depths along each ray, shape (batch_size, H, W, num_samples).
        batch_size, H_dim, W_dim, num_samples_dim: Accepted for interface
            compatibility but unused; every dimension is read from
            ``tf.shape(t_vals)``.

    Returns:
        Tuple of:
            rgb_map: Rendered colours (batch_size, H, W, 3).
            depth_map: Expected depth per ray (batch_size, H, W).
            weights: Per-sample compositing weights (batch_size, H, W, num_samples).
    """
    grid_shape = tf.shape(t_vals)  # (batch, H, W, num_samples)
    samples = tf.reshape(
        raw_output,
        (grid_shape[0], grid_shape[1], grid_shape[2], grid_shape[3], 4),
    )
    rgb = samples[..., :3]
    sigma_a = tf.nn.relu(samples[..., 3])

    # Distance between consecutive samples. The last sample gets a huge 1e10
    # interval, so it is (almost) opaque whenever its density is positive.
    gaps = t_vals[..., 1:] - t_vals[..., :-1]
    last_gap = tf.broadcast_to([1e10], shape=tf.concat([tf.shape(gaps)[:-1], [1]], axis=0))
    delta = tf.concat([gaps, last_gap], axis=-1)

    alpha = 1.0 - tf.exp(-sigma_a * delta)
    # T_i = prod_{j<i} (1 - alpha_j); the 1e-10 keeps the product from hitting exactly 0.
    transmittance = tf.math.cumprod(1.0 - alpha + 1e-10, axis=-1, exclusive=True)
    weights = alpha * transmittance

    rgb_map = tf.reduce_sum(weights[..., None] * rgb, axis=-2)
    depth_map = tf.reduce_sum(weights * t_vals, axis=-1)
    return rgb_map, depth_map, weights


def sample_pdf(bins, weights, num_samples_fine, det=False):
    """Draw fine depths by inverse-transform sampling of a piecewise-linear CDF.

    Args:
        bins: Bin edges along each ray, shape (..., N_bins). Typically the
            midpoints between consecutive coarse depths.
        weights: Non-negative weights of the intervals between consecutive
            bin edges, shape (..., N_bins - 1). Use coarse ``weights[..., 1:-1]``
            when ``bins`` are the coarse midpoints. The CDF built from them is
            prefixed with 0, so it has exactly N_bins entries, one per bin edge.
        num_samples_fine: Number of fine samples per ray.
        det: If True, use evenly spaced quantiles in [0, 1]. Otherwise draw
            uniform random quantiles.

    Returns:
        Fine sample depths, shape (..., num_samples_fine). They are not sorted
        when ``det`` is False.

    Raises:
        ValueError: If the rank of the CDF or of ``bins`` is not statically known.
    """
    def _gather_batch_dims(tensor, name):
        # tf.gather's `batch_dims` must be a Python int. `tf.rank()` returns a
        # Tensor. That converts to int in eager mode but not inside a traced
        # tf.function (as used by `model.fit`), so use the static rank instead.
        rank = tf.TensorShape(tensor.shape).rank
        if rank is None:
            raise ValueError(f"sample_pdf: `{name}` must have a statically known rank.")
        return rank - 1

    weights = weights + 1e-5  # keep every PDF normalisable, even on empty rays
    pdf = weights / tf.reduce_sum(weights, axis=-1, keepdims=True)
    cdf = tf.cumsum(pdf, axis=-1)
    cdf = tf.concat([tf.zeros_like(cdf[..., :1]), cdf], axis=-1)  # (..., N_bins)

    u_shape = tf.concat([tf.shape(cdf)[:-1], [num_samples_fine]], axis=0)
    if det:
        u = tf.broadcast_to(tf.linspace(0.0, 1.0, num_samples_fine), u_shape)
    else:
        u = tf.random.uniform(shape=u_shape, dtype=pdf.dtype)

    # For each quantile, find the CDF segment [cdf[below], cdf[above]] containing it.
    # searchsorted does not require `u` itself to be sorted.
    inds = tf.searchsorted(cdf, u, side="right")
    below = tf.maximum(0, inds - 1)
    above = tf.minimum(tf.shape(cdf)[-1] - 1, inds)
    inds_g = tf.stack([below, above], axis=-1)  # (..., num_samples_fine, 2)

    cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=_gather_batch_dims(cdf, "cdf"))
    bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=_gather_batch_dims(bins, "bins"))

    # Linear interpolation inside the segment; flat (near zero-width) CDF
    # segments use a denominator of 1 to avoid dividing by ~0.
    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = tf.where(denom < 1e-5, tf.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples


def render_nerf_hierarchical(
    coarse_model, fine_model, ray_origins, ray_directions, t_vals_coarse,
    num_samples_fine, rand_fine_sampling=True, training=True):
    """Render rays with hierarchical (coarse-to-fine) sampling.

    The coarse MLP is evaluated at ``t_vals_coarse``. Its compositing weights
    define a PDF along each ray, from which ``num_samples_fine`` extra depths
    are drawn. The fine MLP is then evaluated at the sorted union of coarse
    and fine depths.

    Args:
        coarse_model: Coarse NeRF MLP taking ``[encoded_xyz, encoded_dir]``.
        fine_model: Fine NeRF MLP with the same interface.
        ray_origins: Ray origins, shape (batch_size, H, W, 3).
        ray_directions: Ray directions, shape (batch_size, H, W, 3).
        t_vals_coarse: Coarse sample depths, shape (batch_size, H, W, Nc).
        num_samples_fine: Number of importance samples drawn per ray.
        rand_fine_sampling: Random quantiles for the fine samples when True;
            evenly spaced quantiles when False.
        training: Forwarded as ``training=`` to both MLP calls (default True).
            NeRFModel's train_step/test_step/call pass it explicitly.

    Returns:
        Dict with ``rgb_coarse`` / ``rgb_fine`` of shape (batch_size, H, W, 3),
        ``depth_coarse`` / ``depth_fine`` of shape (batch_size, H, W), and
        ``weights_coarse`` of shape (batch_size, H, W, Nc).
    """
    batch_size = tf.shape(ray_origins)[0]
    h_dim = tf.shape(ray_origins)[1]
    w_dim = tf.shape(ray_origins)[2]

    def _query(model, t_vals):
        """Evaluate `model` at every depth in `t_vals` and composite the result."""
        points = ray_origins[..., None, :] + ray_directions[..., None, :] * t_vals[..., None]
        dirs = tf.broadcast_to(ray_directions[..., None, :], tf.shape(points))
        encoded_points = encode_position(tf.reshape(points, [-1, 3]), POS_ENCODE_DIMS_XYZ)
        encoded_dirs = encode_position(tf.reshape(dirs, [-1, 3]), POS_ENCODE_DIMS_DIR)
        raw_output = model([encoded_points, encoded_dirs], training=training)
        return volumetric_rendering(
            raw_output, t_vals, batch_size, h_dim, w_dim, tf.shape(t_vals)[-1]
        )

    # 1. Coarse pass on the (stratified) coarse depths.
    rgb_map_coarse, depth_map_coarse, weights_coarse = _query(coarse_model, t_vals_coarse)

    # 2. Importance-sample fine depths from the coarse weights.
    # `bins` are the Nc-1 midpoints between coarse samples. The Nc-2 interior
    # weights (weights_coarse[..., 1:-1]) belong to the intervals between
    # consecutive midpoints. The CDF sample_pdf builds (leading 0 plus Nc-2
    # cumulative values) therefore has Nc-1 entries, matching the bins 1:1.
    mid_t_vals_coarse = 0.5 * (t_vals_coarse[..., :-1] + t_vals_coarse[..., 1:])
    t_vals_fine_sampled = sample_pdf(
        bins=mid_t_vals_coarse,
        weights=weights_coarse[..., 1:-1],
        num_samples_fine=num_samples_fine,
        det=not rand_fine_sampling,
    )
    # The sampled depths only choose *where* to query the fine MLP. Block
    # gradients so the fine loss does not reach the coarse MLP through sample_pdf.
    t_vals_fine_sampled = tf.stop_gradient(t_vals_fine_sampled)

    # 3. Fine pass on the union of coarse and fine depths, sorted along each ray.
    t_vals_all = tf.sort(tf.concat([t_vals_coarse, t_vals_fine_sampled], axis=-1), axis=-1)
    rgb_map_fine, depth_map_fine, _ = _query(fine_model, t_vals_all)

    return {
        "rgb_coarse": rgb_map_coarse,
        "depth_coarse": depth_map_coarse,
        "rgb_fine": rgb_map_fine,
        "depth_fine": depth_map_fine,
        "weights_coarse": weights_coarse, # returned for debugging or other uses
    }

## Training

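The training step is implemented as part of a custom `keras.Model` subclass
so that we can make use of the `model.fit` functionality. The subclass trains the coarse and fine MLPs
jointly with a single optimizer, minimizing the sum of their reconstruction (MSE) losses.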

In [None]:
import tensorflow as tf
import keras
from keras import layers
import numpy as np
import os # For path creation
import glob # For create_gif
import imageio.v2 as imageio # For create_gif
from tqdm import tqdm # For create_gif
import matplotlib.pyplot as plt # For TrainMonitor

# Hyper-parameters for this training cell (get_nerf_model, render_nerf_hierarchical
# and the datasets come from the earlier cells).
# NOTE(review): BATCH_SIZE here should match the value train_ds/val_ds were
# batched with. Changing it only here does not re-batch the datasets.
POS_ENCODE_DIMS_XYZ = 10  # positional-encoding frequency levels for positions
POS_ENCODE_DIMS_DIR = 4  # positional-encoding frequency levels for directions
NUM_SAMPLES_FINE = 128  # importance samples per ray for the fine pass
EPOCHS = 20
BATCH_SIZE = 5


class NeRFModel(keras.Model):
    """Keras model that trains a coarse and a fine NeRF MLP jointly.

    ``fit`` / ``evaluate`` consume batches of
    ``(target_images, (rays_o, rays_d, t_vals_coarse))``. The loss is the sum
    of the coarse and fine reconstruction losses; PSNR is reported on the fine
    render. Calling the model returns ``(rgb_fine, depth_fine)``.
    """

    def __init__(self, coarse_mlp, fine_mlp, num_samples_fine):
        super().__init__()
        self.coarse_mlp = coarse_mlp
        self.fine_mlp = fine_mlp
        self.num_samples_fine = num_samples_fine
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.psnr_tracker = keras.metrics.Mean(name="psnr") # PSNR of the fine output

    def compile(self, optimizer, loss_fn):
        """Configure training with one optimizer shared by both MLPs.

        Args:
            optimizer: Optimizer applied to the coarse and fine variables.
            loss_fn: Callable ``loss_fn(y_true, y_pred)`` applied to both renders.
        """
        # Register the optimizer through Keras' own compile. Calling compile()
        # with its default optimizer and then overwriting `self.optimizer`
        # would bypass Keras' optimizer set-up and create a throwaway default.
        super().compile(optimizer=optimizer)
        self.loss_fn = loss_fn

    @property
    def metrics(self):
        # Exposed so Keras can reset these trackers between epochs/evaluations.
        return [self.loss_tracker, self.psnr_tracker]

    def _render_and_loss(self, inputs, rand_fine_sampling, training):
        """Render a batch; return (total_loss, rgb_fine, target_images)."""
        target_images, (rays_o, rays_d, t_vals_coarse) = inputs
        renderings = render_nerf_hierarchical(
            self.coarse_mlp, self.fine_mlp, rays_o, rays_d, t_vals_coarse,
            self.num_samples_fine,
            rand_fine_sampling=rand_fine_sampling,
            training=training,
        )
        rgb_fine = renderings["rgb_fine"]
        loss_coarse = self.loss_fn(target_images, renderings["rgb_coarse"])
        loss_fine = self.loss_fn(target_images, rgb_fine)
        return loss_coarse + loss_fine, rgb_fine, target_images

    def _update_metrics(self, total_loss, target_images, rgb_fine):
        """Accumulate loss and fine-render PSNR; return the running values."""
        self.loss_tracker.update_state(total_loss)
        psnr = tf.image.psnr(target_images, rgb_fine, max_val=1.0)
        self.psnr_tracker.update_state(psnr)
        return {"loss": self.loss_tracker.result(), "psnr": self.psnr_tracker.result()}

    def train_step(self, inputs):
        # Random fine sampling while training.
        with tf.GradientTape() as tape:
            total_loss, rgb_fine, target_images = self._render_and_loss(
                inputs, rand_fine_sampling=True, training=True
            )

        # One optimizer updates the variables of both MLPs.
        trainable_vars = self.coarse_mlp.trainable_variables + self.fine_mlp.trainable_variables
        grads = tape.gradient(total_loss, trainable_vars)
        self.optimizer.apply_gradients(zip(grads, trainable_vars))

        return self._update_metrics(total_loss, target_images, rgb_fine)

    def test_step(self, inputs):
        # Deterministic fine sampling for evaluation.
        total_loss, rgb_fine, target_images = self._render_and_loss(
            inputs, rand_fine_sampling=False, training=False
        )
        return self._update_metrics(total_loss, target_images, rgb_fine)

    def call(self, inputs, training=False):
        """Render ``inputs = (rays_o, rays_d, t_vals_coarse)``.

        Returns:
            ``(rgb_fine, depth_fine)``. Fine sampling is random only when
            ``training`` is truthy.
        """
        rays_o, rays_d, t_vals_coarse = inputs
        renderings = render_nerf_hierarchical(
            self.coarse_mlp, self.fine_mlp, rays_o, rays_d, t_vals_coarse,
            self.num_samples_fine,
            rand_fine_sampling=training,
            training=training,
        )
        return renderings["rgb_fine"], renderings["depth_fine"]


# Two MLPs with identical architecture: one for the coarse pass and one for the
# fine pass of hierarchical sampling.
_mlp_kwargs = dict(
    num_layers=8,
    dense_units=256,
    skip_layer=4,
    pos_encode_dims_xyz=POS_ENCODE_DIMS_XYZ,
    pos_encode_dims_dir=POS_ENCODE_DIMS_DIR,
)
coarse_nerf_mlp = get_nerf_model(**_mlp_kwargs)
fine_nerf_mlp = get_nerf_model(**_mlp_kwargs)

# Wrap both MLPs in the trainable NeRF system.
nerf_system = NeRFModel(coarse_nerf_mlp, fine_nerf_mlp, NUM_SAMPLES_FINE)

# A single Adam optimizer updates both MLPs; each render is scored with MSE.
learning_rate = 5e-4
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
loss_function = keras.losses.MeanSquaredError()
nerf_system.compile(optimizer=optimizer, loss_fn=loss_function)

# Training-loss history; TrainMonitor appends to it after every epoch.
loss_list_train = []

# One fixed validation batch, rendered after each epoch to visualise progress.
val_iter = iter(val_ds)
val_batch_for_viz = next(val_iter)

class TrainMonitor(keras.callbacks.Callback):
    """Record the training loss and save a progress figure after every epoch.

    At each epoch end it appends ``logs["loss"]`` to the global
    ``loss_list_train``. It then renders the global ``val_batch_for_viz`` with
    the model and writes ``images_original_nerf/<epoch>.png``. For up to three
    images the figure shows the target, the predicted fine RGB and the
    predicted fine depth, plus the loss curve in the top-right panel.
    """

    def on_epoch_end(self, epoch, logs=None):
        if logs is not None: # logs might be None if training is interrupted early
            loss_list_train.append(logs.get("loss"))

        target_images_viz, (rays_o_viz, rays_d_viz, t_vals_coarse_viz) = val_batch_for_viz

        # Calling the model invokes NeRFModel.call (inference behaviour).
        predicted_rgb_fine, predicted_depth_fine = self.model(
            (rays_o_viz, rays_d_viz, t_vals_coarse_viz), training=False
        )

        num_images_to_show = min(BATCH_SIZE, 3)
        fig, axes = plt.subplots(nrows=num_images_to_show, ncols=4, figsize=(20, 5 * num_images_to_show))
        # plt.subplots returns a 1-D array for a single row. Normalise to
        # (rows, 4) so `axes[row, col]` is valid for every row count; previously
        # the single-row case indexed a whole row / out of range.
        axes = np.reshape(axes, (num_images_to_show, 4))

        for i in range(num_images_to_show):
            ax_row = axes[i]

            ax_row[0].imshow(keras.utils.array_to_img(np.asarray(target_images_viz[i])))
            ax_row[0].set_title(f"Target Image {i}")

            ax_row[1].imshow(keras.utils.array_to_img(np.asarray(predicted_rgb_fine[i])))
            ax_row[1].set_title(f"Predicted Fine RGB {i} (Epoch: {epoch:03d})")

            depth_fine_display = np.asarray(predicted_depth_fine[i])
            ax_row[2].imshow(keras.utils.array_to_img(depth_fine_display[..., None]), cmap='inferno')
            ax_row[2].set_title(f"Predicted Fine Depth {i} (Epoch: {epoch:03d})")

        # The loss curve lives in the top-right panel; the rest of column 3 is hidden.
        loss_ax = axes[0, 3]
        loss_ax.plot(loss_list_train)
        loss_ax.set_xticks(np.arange(0, EPOCHS + 1, 5.0 if EPOCHS >= 5 else 1.0))
        loss_ax.set_title(f"Training Loss (Epoch: {epoch:03d})")
        for j in range(1, num_images_to_show):
            axes[j, 3].axis('off')

        save_dir = "images_original_nerf"
        os.makedirs(save_dir, exist_ok=True)
        fig.savefig(os.path.join(save_dir, f"{epoch:03d}.png"))
        # plt.show() # Uncomment for interactive display if not in a headless environment
        plt.close(fig)


# Folder for the per-epoch figures written by TrainMonitor.
_snapshot_dir = "images_original_nerf"
if not os.path.exists(_snapshot_dir):
    os.makedirs(_snapshot_dir)

# train_ds / val_ds come from the data cell; EPOCHS from the constants above.
_fit_kwargs = dict(validation_data=val_ds, epochs=EPOCHS, callbacks=[TrainMonitor()])

print("Starting training...")
nerf_system.fit(train_ds, **_fit_kwargs)
print("Training finished.")


def create_gif(path_to_images_pattern, name_gif):
    """Assemble the images matching a glob pattern into an animated GIF.

    Matching files are read in sorted filename order. Files that fail to load
    are reported and skipped. Nothing is written when no file matches or when
    none could be read. Errors while writing the GIF are reported, not raised.

    Args:
        path_to_images_pattern: Glob pattern selecting the frame images.
        name_gif: Output path for the GIF.

    Note:
        Frames are saved with duration=0.25 and loop=0. NOTE(review): the unit
        of ``duration`` (seconds vs. milliseconds) depends on the installed
        imageio version/plugin; confirm.
    """
    matched_paths = sorted(glob.glob(path_to_images_pattern))
    if not matched_paths:
        print(f"No images found for pattern {path_to_images_pattern}")
        return

    frames = []
    for frame_path in tqdm(matched_paths):
        try:
            frame = imageio.imread(frame_path)
        except Exception as err:
            # Best-effort: report the unreadable file and keep going.
            print(f"Error reading {frame_path}: {err}")
        else:
            frames.append(frame)

    if not frames:
        print("No valid images were read to create GIF.")
        return

    try:
        imageio.mimsave(name_gif, frames, "GIF", duration=0.25, loop=0)
        print(f"GIF saved as {name_gif}")
    except Exception as err:
        print(f"Error creating GIF: {err}")


# Stitch the per-epoch snapshots saved by the TrainMonitor callback into a GIF.
create_gif(
    path_to_images_pattern="images_original_nerf/*.png",
    name_gif="training_original_nerf.gif",
)

ValueError: Cannot convert '99' to a shape.

In [None]:
# --- Inference and Video Generation (Similar to provided code, adapted for the new model) ---

# Take one batch from val_ds (used as the test set here) for visualization.
test_imgs_viz, (test_rays_o_viz, test_rays_d_viz, test_t_vals_coarse_viz) = next(iter(val_ds))

# Run the trained model in inference mode on that batch.
recons_images_fine, depth_maps_fine = nerf_system(
    (test_rays_o_viz, test_rays_d_viz, test_t_vals_coarse_viz), training=False
)

# Show at most 5 rows. Also bound the count by the number of images actually in
# this batch: a short (final or only) batch can hold fewer than BATCH_SIZE
# images, and indexing past it below would raise.
num_viz_samples = min(BATCH_SIZE, int(test_imgs_viz.shape[0]), 5)

# squeeze=False always returns a 2-D (rows, cols) array of Axes, so the
# single-row case needs no special handling.
fig, axes = plt.subplots(
    nrows=num_viz_samples,
    ncols=3,
    figsize=(10, 4 * num_viz_samples),
    squeeze=False,
)

for i in range(num_viz_samples):
    ax_row = axes[i]
    ax_row[0].imshow(keras.utils.array_to_img(test_imgs_viz[i]))
    ax_row[0].set_title("Original")

    ax_row[1].imshow(keras.utils.array_to_img(recons_images_fine[i]))
    ax_row[1].set_title("Reconstructed (Fine)")

    # Append a trailing channel axis: array_to_img expects an (H, W, C) array.
    ax_row[2].imshow(keras.utils.array_to_img(depth_maps_fine[i, ..., None]), cmap="inferno")
    ax_row[2].set_title("Depth Map (Fine)")
    ax_row[2].axes.get_xaxis().set_visible(False)
    ax_row[2].axes.get_yaxis().set_visible(False)


fig.tight_layout()
# plt.show()  # Uncomment when running locally with a display.
fig.savefig("inference_visualization_original_nerf.png")
print("Inference visualization saved as inference_visualization_original_nerf.png")
plt.close(fig)

# Helpers for the novel-view-synthesis video: build 4x4 camera-to-world poses on a sphere around the scene.
def get_translation_t(t):
    """Build a 4x4 homogeneous transform that translates by `t` along Z.

    Args:
        t: Translation distance placed in row 2, column 3 of the matrix.

    Returns:
        A (4, 4) float32 tf.Tensor: the identity with `t` as its Z offset.
    """
    # Start from a 4x4 identity written as nested lists, then set the Z offset.
    rows = [[1 if r == c else 0 for c in range(4)] for r in range(4)]
    rows[2][3] = t
    return tf.convert_to_tensor(rows, dtype=tf.float32)

def get_rotation_phi(phi):  # around X-axis
    """Build a 4x4 homogeneous rotation about the X axis.

    Args:
        phi: Rotation angle in radians (passed straight to tf.cos / tf.sin).

    Returns:
        A (4, 4) float32 tf.Tensor whose lower-right 2x2 of the rotation part
        is [[cos, -sin], [sin, cos]].
    """
    cos_p = tf.cos(phi)
    sin_p = tf.sin(phi)
    return tf.convert_to_tensor(
        [
            [1, 0, 0, 0],
            [0, cos_p, -sin_p, 0],
            [0, sin_p, cos_p, 0],
            [0, 0, 0, 1],
        ],
        dtype=tf.float32,
    )

def get_rotation_theta(theta):  # around Y-axis
    """Build a 4x4 homogeneous rotation about the Y axis.

    The rotation part is laid out as [[c, 0, -s], [0, 1, 0], [s, 0, c]].
    Note the sign placement: this is the transpose of the usual right-handed
    R_y(theta) form.

    Args:
        theta: Rotation angle in radians (passed straight to tf.cos / tf.sin).

    Returns:
        A (4, 4) float32 tf.Tensor.
    """
    cos_t, sin_t = tf.cos(theta), tf.sin(theta)
    rotation = [
        [cos_t, 0, -sin_t, 0],
        [0, 1, 0, 0],
        [sin_t, 0, cos_t, 0],
        [0, 0, 0, 1],
    ]
    return tf.convert_to_tensor(rotation, dtype=tf.float32)

def pose_spherical(theta, phi, radius):
    """Build a camera pose on a sphere of the given radius.

    The pose is composed as
    ``axis_fix @ R_theta(theta) @ R_phi(phi) @ T_z(radius)``. The result is
    used as the camera-to-world pose passed to get_rays.

    Args:
        theta: Rotation about the Y axis, in degrees.
        phi: Rotation about the X axis, in degrees.
        radius: Translation along Z applied before the rotations.

    Returns:
        A (4, 4) float32 tf.Tensor.
    """
    # Degrees -> radians; keep the exact original operation order.
    theta_rad = theta / 180.0 * np.pi
    phi_rad = phi / 180.0 * np.pi

    # Axis fix-up: negates the X row and swaps the Y and Z rows of the pose.
    axis_fix = tf.convert_to_tensor(
        np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]),
        dtype=tf.float32,
    )

    pose = get_translation_t(radius)
    pose = tf.matmul(get_rotation_phi(phi_rad), pose)
    pose = tf.matmul(get_rotation_theta(theta_rad), pose)
    return tf.matmul(axis_fix, pose)


# Render a turntable video. Sweep the camera azimuth over 360 degrees at a fixed
# elevation and radius, render each pose with the trained model in batches,
# and write the frames to an MP4 file.
print("Generating novel view synthesis video...")
rgb_frames = []
num_video_frames = 60 # Fewer frames for quicker generation while testing
video_batch_size = BATCH_SIZE # Number of poses rendered per model call

for i in tqdm(range(0, num_video_frames, video_batch_size)):
    # The last batch is shorter when num_video_frames is not a multiple of
    # video_batch_size.
    current_batch_size = min(video_batch_size, num_video_frames - i)
    batch_rays_o = []
    batch_rays_d = []
    batch_t_vals_coarse = []

    for j in range(current_batch_size):
        # Evenly spaced azimuth, in degrees, over [0, 360).
        theta = (i + j) * (360.0 / num_video_frames)
        c2w = pose_spherical(theta, -30.0, 4.0) # azimuth theta, phi=-30 deg, radius 4.0

        ray_origins_single, ray_directions_single = get_rays(H, W, focal, c2w)
        # get_rays is called on one pose. Its outputs are presumably un-batched
        # per-pixel tensors (H, W, 3); TODO confirm against get_rays.
        # Per-frame results are stacked into a batch after this inner loop.

        # Coarse sample depths for this single view. A leading axis of size 1
        # is added with [None, ...] for the render_flat_rays call.
        # NOTE(review): the exact return layout of render_flat_rays is not
        # visible here; only its third output (t-values) is used.
        _, _, t_vals_coarse_single = render_flat_rays(
            ray_origins_single[None,...], ray_directions_single[None,...], # Add batch dim for this call
            near=2.0, far=6.0,
            num_samples=NUM_SAMPLES_COARSE,
            rand=False # No random sampling for video generation
        )
        # Assumes the returned t-values keep the size-1 leading axis added
        # above. Squeeze it so the frames can be re-stacked along axis 0 below.
        batch_rays_o.append(ray_origins_single)
        batch_rays_d.append(ray_directions_single)
        batch_t_vals_coarse.append(tf.squeeze(t_vals_coarse_single, axis=0))


    # Stack the per-frame tensors into one batch along a new leading axis.
    rays_o_batch_video = tf.stack(batch_rays_o, axis=0) # expected (current_batch_size, H, W, 3)
    rays_d_batch_video = tf.stack(batch_rays_d, axis=0) # expected (current_batch_size, H, W, 3)
    t_vals_coarse_batch_video = tf.stack(batch_t_vals_coarse, axis=0) # expected (current_batch_size, H, W, N_coarse)

    # Run the model in inference mode. Only the fine RGB output is used for the video.
    rgb_fine_video_batch, _ = nerf_system(
        (rays_o_batch_video, rays_d_batch_video, t_vals_coarse_batch_video), training=False
    )

    # Convert each rendered frame to uint8: scale by 255, clip to [0, 255],
    # then cast (astype truncates toward zero, it does not round).
    for k in range(current_batch_size):
        img_np = np.clip(255 * rgb_fine_video_batch[k].numpy(), 0.0, 255.0).astype(np.uint8)
        rgb_frames.append(img_np)


if rgb_frames:
    rgb_video_path = "rgb_video_original_nerf.mp4"
    imageio.mimwrite(rgb_video_path, rgb_frames, fps=30, quality=8, macro_block_size=1) # quality / macro_block_size are forwarded to imageio's video writer
    print(f"Novel view video saved as {rgb_video_path}")
else:
    print("No frames generated for the video.")