In [None]:
# Import all the good stuff
from typing import Optional

import numpy as np
import torch
import matplotlib.pyplot as plt

In [None]:
def meshgrid_xy(tensor1: torch.Tensor, tensor2: torch.Tensor) -> (torch.Tensor, torch.Tensor):
    """Build a 2-D coordinate grid using "xy" (Cartesian) indexing.

    Args:
        tensor1: 1-D tensor whose values vary along the last (column) axis
            of the outputs.
        tensor2: 1-D tensor whose values vary along the first (row) axis of
            the outputs.

    Returns:
        A pair ``(ii, jj)``, each of shape ``(len(tensor2), len(tensor1))``,
        with ``ii[r, c] == tensor1[c]`` and ``jj[r, c] == tensor2[r]``.
    """
    # Ask for "ij" indexing explicitly. It is what the implicit default has
    # always produced, but omitting the argument raises a deprecation warning
    # on recent PyTorch releases and relies on a default that may change.
    ii, jj = torch.meshgrid(tensor1, tensor2, indexing="ij")
    # Swapping the two grid axes turns the "ij" layout into the "xy" layout.
    return ii.transpose(-1, -2), jj.transpose(-1, -2)


def cumprod_exclusive(tensor: torch.Tensor) -> torch.Tensor:
  """Exclusive cumulative product along the last dimension.

  Element ``k`` of the result is the product of ``tensor[..., :k]``, so the
  first element is always 1. This mimics
  ``tf.math.cumprod(..., exclusive=True)``.

  Args:
    tensor: Input tensor; the product runs over ``dim=-1`` only.

  Returns:
    A new tensor with the same shape as ``tensor``. The input is not modified.
  """
  last_axis = -1
  # Ordinary (inclusive) running product.
  inclusive = tensor.cumprod(last_axis)
  # Shift everything one slot to the right; the wrapped-around last value lands
  # in slot 0 and is overwritten below.
  shifted = inclusive.roll(shifts=1, dims=last_axis)
  # An exclusive product starts from the multiplicative identity.
  shifted[..., 0] = 1.

  return shifted

In [None]:
def get_ray_bundle(height: int, width: int, focal_length: float, tform_cam2world: torch.Tensor):
  """Compute one ray (origin and direction) per image pixel.

  Args:
    height: Image height in pixels.
    width: Image width in pixels.
    focal_length: Focal length used to scale the pixel offsets.
    tform_cam2world: Pose matrix. Its top-left 3x3 block is applied to the
      per-pixel directions and the first three entries of its last column
      are used as the ray origin (presumably a 4x4 camera-to-world
      transform; confirm against the caller).

  Returns:
    ``(ray_origins, ray_directions)``, both of shape ``(height, width, 3)``.
    ``ray_origins`` is an expanded view of the single origin.
  """
  # Pixel coordinate grids in the dtype/device of the pose.
  pixel_cols = torch.arange(width).to(tform_cam2world)
  pixel_rows = torch.arange(height).to(tform_cam2world)
  ii, jj = meshgrid_xy(pixel_cols, pixel_rows)

  # Per-pixel direction components: offsets are measured from the image
  # centre, the second component is negated and the third is fixed at -1.
  x_cam = (ii - width * .5) / focal_length
  y_cam = -(jj - height * .5) / focal_length
  z_cam = -torch.ones_like(ii)
  directions = torch.stack([x_cam, y_cam, z_cam], dim=-1)

  rotation = tform_cam2world[:3, :3]
  translation = tform_cam2world[:3, -1]

  # Broadcasted multiply-and-sum applies the 3x3 block to every direction.
  ray_directions = (directions[..., None, :] * rotation).sum(dim=-1)
  # Every ray shares the same origin.
  ray_origins = translation.expand(ray_directions.shape)
  return ray_origins, ray_directions

In [None]:
def compute_query_points_from_rays(
    ray_origins: torch.Tensor,
    ray_directions: torch.Tensor,
    near_thresh: float,
    far_thresh: float,
    num_samples: int,
    randomize: Optional[bool] = True
) -> (torch.Tensor, torch.Tensor):
  """Sample 3-D query points along every ray.

  Args:
    ray_origins: Ray origins, shape ``(..., 3)``.
    ray_directions: Ray directions, same shape as ``ray_origins``.
    near_thresh: Depth of the first sample.
    far_thresh: Depth of the last (un-jittered) sample.
    num_samples: Number of samples per ray.
    randomize: Only when this is exactly ``True`` is per-ray uniform noise in
      ``[0, (far_thresh - near_thresh) / num_samples)`` added to each depth.

  Returns:
    ``(query_points, depth_values)``. ``query_points`` has shape
    ``(..., num_samples, 3)``. ``depth_values`` has shape
    ``(..., num_samples)`` when jittered and ``(num_samples,)`` otherwise.
  """
  # Evenly spaced depths shared by all rays, shape (num_samples,).
  depth_values = torch.linspace(near_thresh, far_thresh, num_samples).to(ray_origins)

  if randomize is True:
    # One independent offset per ray and per sample.
    per_ray_shape = [*ray_origins.shape[:-1], num_samples]
    jitter = torch.rand(per_ray_shape).to(ray_origins)
    depth_values = depth_values + jitter * (far_thresh - near_thresh) / num_samples

  # Insert a sample axis so origins/directions broadcast against the depths:
  # (..., 1, 3) + (..., 1, 3) * (..., num_samples, 1) -> (..., num_samples, 3)
  origins = ray_origins[..., None, :]
  directions = ray_directions[..., None, :]
  query_points = origins + directions * depth_values[..., :, None]
  return query_points, depth_values

In [None]:
def render_volume_density(
    radiance_field: torch.Tensor,
    ray_origins: torch.Tensor,
    depth_values: torch.Tensor
) -> (torch.Tensor, torch.Tensor, torch.Tensor):
  """Composite per-sample colour and density along each ray.

  Args:
    radiance_field: Tensor of shape ``(..., num_samples, 4)``; the first three
      channels are passed through a sigmoid as colour and the fourth through
      a ReLU as density.
    ray_origins: Only its dtype and device are read here (for the padding
      constant).
    depth_values: Sample depths along the rays, last axis ``num_samples``.

  Returns:
    ``(rgb_map, depth_map, acc_map)``: the weighted colour, the weighted
    depth, and the sum of the weights for every ray.
  """
  # Squash raw network outputs into colour and non-negative density.
  rgb = torch.sigmoid(radiance_field[..., :3])
  sigma_a = torch.nn.functional.relu(radiance_field[..., 3])

  # Distance between consecutive samples; the final sample gets a very large
  # interval appended so the result keeps num_samples entries.
  far_pad = torch.tensor([1e10], dtype=ray_origins.dtype, device=ray_origins.device)
  gaps = depth_values[..., 1:] - depth_values[..., :-1]
  last_gap = far_pad.expand(depth_values[..., :1].shape)
  dists = torch.cat((gaps, last_gap), dim=-1)

  # Per-sample opacity, then weight = opacity * product of the preceding
  # (1 - opacity) terms. The 1e-10 keeps the running product away from zero.
  alpha = 1. - torch.exp(-sigma_a * dists)
  transmittance = cumprod_exclusive(1. - alpha + 1e-10)
  weights = alpha * transmittance

  rgb_map = (weights[..., None] * rgb).sum(dim=-2)
  depth_map = (weights * depth_values).sum(dim=-1)
  acc_map = weights.sum(-1)

  return rgb_map, depth_map, acc_map

In [None]:
def positional_encoding(
    tensor, num_encoding_functions=6, include_input=True, log_sampling=True
) -> torch.Tensor:
  """Encode ``tensor`` with sines and cosines at several frequencies.

  Args:
    tensor: Input tensor; the encoding is concatenated along its last axis.
    num_encoding_functions: Number of frequency bands.
    include_input: When true, the raw input is the first chunk of the output.
    log_sampling: When true, frequencies are ``2**0 ... 2**(N-1)``; otherwise
      they are spaced linearly between ``2**0`` and ``2**(N-1)``.

  Returns:
    The concatenation ``[input?, sin(f0 x), cos(f0 x), sin(f1 x), ...]`` along
    the last axis. If the encoding holds exactly one chunk, that chunk is
    returned as-is without concatenation.
  """
  band_kwargs = dict(dtype=tensor.dtype, device=tensor.device)
  max_exponent = num_encoding_functions - 1

  if log_sampling:
    frequency_bands = 2.0 ** torch.linspace(
        0.0, max_exponent, num_encoding_functions, **band_kwargs
    )
  else:
    frequency_bands = torch.linspace(
        2.0 ** 0.0, 2.0 ** max_exponent, num_encoding_functions, **band_kwargs
    )

  # The raw input (if requested) always comes first.
  encoding = [tensor] if include_input else []
  # For each band append sin then cos of the scaled input.
  for freq in frequency_bands:
    encoding.extend(fn(tensor * freq) for fn in (torch.sin, torch.cos))

  # Special case, for no positional encoding: nothing to concatenate.
  if len(encoding) == 1:
    return encoding[0]
  return torch.cat(encoding, dim=-1)

In [None]:
class VeryTinyNerfModel(torch.nn.Module):
  """A minimal NeRF MLP: three fully connected layers with ReLU in between.

  The input width is ``3 + 3 * 2 * num_encoding_functions`` and the output
  has 4 channels.
  """

  def __init__(self, filter_size=128, num_encoding_functions=6):
    """Create the three linear layers.

    Args:
      filter_size: Width of the two hidden layers.
      num_encoding_functions: Number of positional-encoding bands; it
        determines the input width.
    """
    super().__init__()
    encoded_width = 3 + 3 * 2 * num_encoding_functions
    # With the defaults: 39 -> 128 -> 128 -> 4.
    self.layer1 = torch.nn.Linear(encoded_width, filter_size)
    self.layer2 = torch.nn.Linear(filter_size, filter_size)
    self.layer3 = torch.nn.Linear(filter_size, 4)
    # Functional ReLU kept as an attribute for brevity in forward().
    self.relu = torch.nn.functional.relu

  def forward(self, x):
    """Map encoded points ``(..., encoded_width)`` to ``(..., 4)`` raw outputs."""
    hidden = self.relu(self.layer1(x))
    hidden = self.relu(self.layer2(hidden))
    # No activation on the output layer.
    return self.layer3(hidden)

class EnhancedNerfModel(torch.nn.Module):
    """A deeper, wider NeRF MLP: an input layer, four hidden layers, and an output layer.

    Every layer except the output layer is followed by a ReLU. The input
    width is ``3 + 3 * 2 * num_encoding_functions`` and the output has 4
    channels (RGB and density).
    """

    def __init__(self, filter_size=256, num_encoding_functions=6):
        """Create the linear layers.

        Args:
            filter_size: Width of every hidden layer.
            num_encoding_functions: Number of positional-encoding bands; it
                determines the input width.
        """
        super().__init__()
        encoded_width = 3 + 3 * 2 * num_encoding_functions

        # Projects the positionally encoded points to the hidden width.
        self.input_layer = torch.nn.Linear(encoded_width, filter_size)

        # Four equally sized hidden layers.
        self.layer1 = torch.nn.Linear(filter_size, filter_size)
        self.layer2 = torch.nn.Linear(filter_size, filter_size)
        self.layer3 = torch.nn.Linear(filter_size, filter_size)
        self.layer4 = torch.nn.Linear(filter_size, filter_size)

        # Raw RGB + density (sigma) head.
        self.output_layer = torch.nn.Linear(filter_size, 4)

        # Functional ReLU kept as an attribute for brevity in forward().
        self.relu = torch.nn.functional.relu

    def forward(self, x):
        """Map encoded points ``(..., encoded_width)`` to ``(..., 4)`` raw outputs."""
        activated_layers = (
            self.input_layer,
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
        )
        for linear in activated_layers:
            x = self.relu(linear(x))
        # No activation on the output layer.
        return self.output_layer(x)

In [None]:
def get_minibatches(inputs: torch.Tensor, chunksize: Optional[int] = 1024 * 8):
  """Split ``inputs`` along its first axis into consecutive chunks.

  Args:
    inputs: Tensor to split.
    chunksize: Maximum number of rows per chunk; the last chunk may be
      shorter.

  Returns:
    A list of slices of ``inputs`` in order. The list is empty when
    ``inputs`` has no rows.
  """
  total_rows = inputs.shape[0]
  minibatches = []
  for start in range(0, total_rows, chunksize):
    minibatches.append(inputs[start:start + chunksize])
  return minibatches

In [None]:
# Download sample data used in the official tiny_nerf example
import os
import urllib.request

_TINY_NERF_DATA_URL = "http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz"
_TINY_NERF_DATA_PATH = "tiny_nerf_data.npz"

if not os.path.exists(_TINY_NERF_DATA_PATH):
    # Use the standard library instead of shelling out to `wget`, so the cell
    # also runs where wget is not installed and as plain Python.
    # Download under a temporary name and rename only on success: an
    # interrupted download must not leave a truncated file that the existence
    # check above would accept on the next run.
    _partial_path = _TINY_NERF_DATA_PATH + ".part"
    urllib.request.urlretrieve(_TINY_NERF_DATA_URL, _partial_path)
    os.replace(_partial_path, _TINY_NERF_DATA_PATH)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Load input images, poses, and intrinsics
data = np.load("tiny_nerf_data.npz")

# Images
images = data["images"]
# Camera extrinsics (poses)
tform_cam2world = data["poses"]
tform_cam2world = torch.from_numpy(tform_cam2world).to(device)
# Focal length (intrinsics)
focal_length = data["focal"]
focal_length = torch.from_numpy(focal_length).to(device)

# Height and width of each image
height, width = images.shape[1:3]

# Near and far clipping thresholds for depth values.
near_thresh = 2.
far_thresh = 6.

# The first `num_train_images` views are used for training.
num_train_images = 100
# Held-out view for evaluation. It must lie OUTSIDE the training slice,
# otherwise the "test" PSNR is measured on an image the model was fitted to.
# Index 101 matches the split used by the official tiny_nerf example.
test_image_idx = 101
if not (num_train_images <= test_image_idx < images.shape[0]):
    raise ValueError(
        f"test_image_idx={test_image_idx} must be in "
        f"[{num_train_images}, {images.shape[0]}) so that the held-out view "
        "exists and is not part of the training images"
    )

# Hold one image out (for test). Poses keep their original indexing, so the
# same index selects the matching pose.
testimg, testpose = images[test_image_idx], tform_cam2world[test_image_idx]
# Keep only the first three channels, consistent with the training images below.
testimg = torch.from_numpy(testimg[..., :3]).to(device)

# Map images to device
images = torch.from_numpy(images[:num_train_images, ..., :3]).to(device)

In [None]:
# Preview the held-out test view. The tensor is detached and copied to host
# memory as a NumPy array before plotting.
testimg_np = testimg.detach().cpu().numpy()
plt.imshow(testimg_np)
plt.show()

In [None]:
# One iteration of TinyNeRF (forward pass).
def run_one_iter_of_tinynerf(height, width, focal_length, tform_cam2world,
                             near_thresh, far_thresh, depth_samples_per_ray,
                             encoding_function, get_minibatches_function,
                             nerf_model=None, minibatch_size=None):
  """Render an RGB image from one camera pose with the NeRF model.

  Args:
    height: Image height in pixels.
    width: Image width in pixels.
    focal_length: Focal length forwarded to ``get_ray_bundle``.
    tform_cam2world: Camera pose forwarded to ``get_ray_bundle``.
    near_thresh: Nearest sampling depth along each ray.
    far_thresh: Farthest sampling depth along each ray.
    depth_samples_per_ray: Number of samples taken along each ray.
    encoding_function: Callable applied to the flattened ``(N, 3)`` query
      points before they are fed to the model.
    get_minibatches_function: Callable ``(tensor, chunksize=...)`` returning
      an iterable of chunks of the encoded points.
    nerf_model: Network evaluated on each chunk. When ``None`` (the default)
      the notebook-level global ``model`` is used, as before.
    minibatch_size: Chunk size passed to ``get_minibatches_function``. When
      ``None`` (the default) the notebook-level global ``chunksize`` is used,
      as before.

  Returns:
    The RGB map returned by ``render_volume_density`` for this pose.

  Note:
    ``compute_query_points_from_rays`` is called with its default
    ``randomize`` setting, so sample depths are jittered on every call.
  """
  # Explicit arguments win; otherwise fall back to the globals this function
  # has always read, so existing nine-argument call sites behave identically.
  if nerf_model is None:
    nerf_model = model
  if minibatch_size is None:
    minibatch_size = chunksize

  # Get the "bundle" of rays through all image pixels.
  ray_origins, ray_directions = get_ray_bundle(height, width, focal_length,
                                               tform_cam2world)

  # Sample query points along each ray
  query_points, depth_values = compute_query_points_from_rays(
      ray_origins, ray_directions, near_thresh, far_thresh, depth_samples_per_ray
  )

  # "Flatten" the query points.
  flattened_query_points = query_points.reshape((-1, 3))

  # Encode the query points (default: positional encoding).
  encoded_points = encoding_function(flattened_query_points)

  # Split the encoded points into "chunks", run the model on all chunks, and
  # concatenate the results (to avoid out-of-memory issues).
  batches = get_minibatches_function(encoded_points, chunksize=minibatch_size)
  predictions = [nerf_model(batch) for batch in batches]
  radiance_field_flattened = torch.cat(predictions, dim=0)

  # "Unflatten" to obtain the radiance field.
  unflattened_shape = list(query_points.shape[:-1]) + [4]
  radiance_field = torch.reshape(radiance_field_flattened, unflattened_shape)

  # Perform differentiable volume rendering to re-synthesize the RGB image.
  rgb_predicted, _, _ = render_volume_density(radiance_field, ray_origins, depth_values)

  return rgb_predicted

In [None]:
"""
Parameters for TinyNeRF training
"""

# Number of functions used in the positional encoding 
num_encoding_functions = 6
# Specify encoding function.
encode = lambda x: positional_encoding(x, num_encoding_functions=num_encoding_functions)
# Number of depth samples along each ray.
depth_samples_per_ray = 32

# Chunksize (number of encoded points sent through the model at once)
chunksize = 16384 
# NOTE(review): this value is not read by the loop below. Evaluation renders
# `testpose` and is scored against `testimg`, both selected when the data was
# loaded.
test_image_idx = 97 # test image for evaluation

# Optimizer parameters
lr = 1e-4
num_iters = 10000


# Misc parameters
display_every = 200  # Number of iters after which stats are displayed

"""
Model
"""
# Smaller alternative:
#model = VeryTinyNerfModel(num_encoding_functions=num_encoding_functions)
#model.to(device)

model = EnhancedNerfModel(num_encoding_functions=num_encoding_functions)
model.to(device)

"""
Optimizer
"""
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

"""
Train-Eval-Repeat!
"""

# Seed RNG, for repeatability
seed = 9458
torch.manual_seed(seed)
np.random.seed(seed)

# Lists to log metrics etc.
psnrs = []
iternums = []

for i in range(num_iters):

  # Randomly pick an image as the target.
  target_img_idx = np.random.randint(images.shape[0])
  target_img = images[target_img_idx].to(device)
  target_tform_cam2world = tform_cam2world[target_img_idx].to(device)

  # Run one iteration of TinyNeRF and get the rendered RGB image.
  rgb_predicted = run_one_iter_of_tinynerf(height, width, focal_length,
                                           target_tform_cam2world, near_thresh,
                                           far_thresh, depth_samples_per_ray,
                                           encode, get_minibatches)

  # Compute mean-squared error between the predicted and target images. Backprop!
  loss = torch.nn.functional.mse_loss(rgb_predicted, target_img)
  loss.backward()
  optimizer.step()
  # Clear gradients so the next iteration starts from zero.
  optimizer.zero_grad()

  # Display images/plots/stats
  if i % display_every == 0:
    # Evaluation only: no gradients are needed, so skip building the autograd
    # graph for this full-image render.
    with torch.no_grad():
      # Render the held-out view
      rgb_predicted = run_one_iter_of_tinynerf(height, width, focal_length,
                                               testpose, near_thresh,
                                               far_thresh, depth_samples_per_ray,
                                               encode, get_minibatches)
      # Score the render against the image that belongs to `testpose`, not
      # against the randomly chosen training image of this iteration.
      loss = torch.nn.functional.mse_loss(rgb_predicted, testimg[..., :3])
    print("Loss:", loss.item())
    psnr = -10. * torch.log10(loss)

    psnrs.append(psnr.item())
    iternums.append(i)

    plt.figure(figsize=(10, 4))
    plt.subplot(121)
    plt.imshow(rgb_predicted.detach().cpu().numpy())
    plt.title(f"Iteration {i}")
    plt.subplot(122)
    plt.plot(iternums, psnrs)
    plt.title("PSNR")
    plt.show()

print('Done!')

In [None]:
import matplotlib.pyplot as plt

def visualize_comparison(original_img, rendered_img, title1='Original Image', title2='Rendered Image from Estimated Pose'):
    """Show two image tensors side by side in one figure.

    Args:
        original_img: Tensor displayed in the left panel.
        rendered_img: Tensor displayed in the right panel.
        title1: Title of the left panel.
        title2: Title of the right panel.

    Both tensors are detached and copied to the CPU before plotting.
    """
    plt.figure(figsize=(12, 6))
    # (subplot position, image, title) for the left and right panels.
    panels = (
        (121, original_img, title1),
        (122, rendered_img, title2),
    )
    for position, image, caption in panels:
        plt.subplot(position)
        plt.imshow(image.detach().cpu().numpy())
        plt.title(caption)
        plt.axis('off')
    plt.show()

def plot_loss_curve(losses, best_loss_idx):
    """Plot a loss history and mark one iteration with a vertical line.

    Args:
        losses: Sequence of loss values, plotted against their index.
        best_loss_idx: x-position of the dashed red "Best Loss" marker.
    """
    _, ax = plt.subplots(figsize=(10, 5))
    ax.plot(losses, label='Loss')
    ax.axvline(x=best_loss_idx, color='r', linestyle='--', label='Best Loss')
    ax.set_title('Loss Curve')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True)
    plt.show()

def visualize_pose_estimation(target_img, initial_pose, best_pose, model, encode, get_minibatches, height, width, focal_length, near_thresh, far_thresh, depth_samples_per_ray):
    """Show the target image next to renders from the initial and best poses.

    Args:
        target_img: Image tensor shown in the first panel.
        initial_pose: Pose rendered for the second panel.
        best_pose: Pose rendered for the third panel.
        model: Accepted for interface compatibility; not referenced in this
            function body (the render helper is called without it).
        encode: Encoding function forwarded to ``run_one_iter_of_tinynerf``.
        get_minibatches: Chunking function forwarded to
            ``run_one_iter_of_tinynerf``.
        height, width, focal_length: Camera intrinsics forwarded to the render.
        near_thresh, far_thresh: Depth range forwarded to the render.
        depth_samples_per_ray: Samples per ray forwarded to the render.
    """
    # These renders are for display only. Disabling autograd avoids keeping
    # the activations of two full-image forward passes alive.
    with torch.no_grad():
        rgb_initial = run_one_iter_of_tinynerf(height, width, focal_length, initial_pose,
                                               near_thresh, far_thresh, depth_samples_per_ray,
                                               encode, get_minibatches)

        rgb_best = run_one_iter_of_tinynerf(height, width, focal_length, best_pose,
                                            near_thresh, far_thresh, depth_samples_per_ray,
                                            encode, get_minibatches)

    # (subplot position, image, title) for the three panels, left to right.
    panels = (
        (131, target_img, 'Target Image'),
        (132, rgb_initial, 'Rendered Image from Initial Pose'),
        (133, rgb_best, 'Rendered Image from Best Pose'),
    )

    plt.figure(figsize=(18, 6))
    for position, image, caption in panels:
        plt.subplot(position)
        plt.imshow(image.detach().cpu().numpy())
        plt.title(caption)
        plt.axis('off')

    plt.show()

# pose estimation and visualization
# Two training views: one supplies the starting pose, the other the image to match.
initial_img_idx, target_img_idx = 5, 10

initial_pose = tform_cam2world[initial_img_idx].clone()
initial_img = images[initial_img_idx].to(device)

# The target's own pose is kept for reference only; it is not used directly
# in the estimation.
target_pose = tform_cam2world[target_img_idx].clone()
target_img = images[target_img_idx].to(device)

# Module-level bookkeeping for the pose search: the NeRF rendering from a
# candidate pose is compared to `target_img`.
losses = list()             # loss value recorded at every iteration
best_loss = float('inf')    # lowest loss seen so far
best_pose = None            # pose that produced `best_loss`
best_loss_idx = -1          # iteration at which `best_loss` occurred
iters = 1000                # module-level iteration budget
def estimate_pose_with_early_stopping(target_img, initial_pose, model, encode, get_minibatches, height, width, focal_length, near_thresh, far_thresh, depth_samples_per_ray, num_iters=200, lr=0.01, patience=50):
    """Optimize a camera pose so that its render matches ``target_img``.

    The pose matrix itself is optimized with Adam on the MSE between the
    rendered image and ``target_img``.

    Args:
        target_img: Image tensor the render should match.
        initial_pose: Starting pose; it is cloned, never modified.
        model: Accepted for interface compatibility; not referenced in this
            function body (the render helper is called without it).
        encode: Encoding function forwarded to ``run_one_iter_of_tinynerf``.
        get_minibatches: Chunking function forwarded to
            ``run_one_iter_of_tinynerf``.
        height, width, focal_length: Camera intrinsics forwarded to the render.
        near_thresh, far_thresh: Depth range forwarded to the render.
        depth_samples_per_ray: Samples per ray forwarded to the render.
        num_iters: Maximum number of optimization iterations.
        lr: Adam learning rate for the pose.
        patience: Stop once this many consecutive iterations fail to improve
            on the best loss.

    Returns:
        A detached copy of the pose with the lowest recorded loss. If no
        iteration recorded an improvement (e.g. ``num_iters == 0``), a
        detached copy of the starting pose is returned instead of ``None``.

    Side effects:
        Resets and then updates the module-level ``best_loss``, ``best_pose``
        and ``best_loss_idx``. Clears and then fills the module-level
        ``losses`` list with one value per iteration.
    """
    global best_loss, best_pose, best_loss_idx

    # Start every search from a clean slate. Without this a second call would
    # compare against the previous run's best loss (possibly never updating
    # `best_pose`), and `best_loss_idx` would no longer index into `losses`.
    best_loss = float('inf')
    best_pose = None
    best_loss_idx = -1
    losses.clear()

    pose = initial_pose.clone().detach().requires_grad_(True)
    optimizer = torch.optim.Adam([pose], lr=lr)
    early_stop_counter = 0

    # Honour the `num_iters` argument rather than a module-level constant.
    for i in range(num_iters):
        optimizer.zero_grad()

        # Render the image using the current pose
        rgb_predicted = run_one_iter_of_tinynerf(height, width, focal_length, pose,
                                                 near_thresh, far_thresh, depth_samples_per_ray,
                                                 encode, get_minibatches)

        # Compute the loss
        loss = torch.nn.functional.mse_loss(rgb_predicted, target_img)
        loss_value = loss.item()
        losses.append(loss_value)

        # Check for improvement
        if loss_value < best_loss:
            best_loss = loss_value
            best_pose = pose.clone().detach()
            best_loss_idx = i
            early_stop_counter = 0
        else:
            early_stop_counter += 1

        # Early stopping
        if early_stop_counter >= patience:
            print(f"Early stopping at iteration {i} with best loss {best_loss}")
            break

        # Backpropagate the loss and update the pose
        loss.backward()
        optimizer.step()

        # Print and visualize the intermediate results every 50 iterations
        if i % 50 == 0:
            print(f"Iteration {i}, Loss: {loss_value}")
            visualize_comparison(target_img, rgb_predicted, title1='Target Image', title2=f'Rendered Image at Iteration {i}')

    # Never hand `None` to callers that go on to render the result.
    if best_pose is None:
        best_pose = pose.clone().detach()

    return best_pose

# Run the pose search from `initial_pose` towards `target_img`.
estimated_pose = estimate_pose_with_early_stopping(
    target_img, initial_pose, model, encode, get_minibatches,
    height, width, focal_length, near_thresh, far_thresh,
    depth_samples_per_ray, num_iters=200,
)

# Visualize the final comparison between the target image and the rendered image from the estimated pose
# The render is for display only, so it is produced without autograd
# bookkeeping; this keeps the stored result small.
with torch.no_grad():
    estimated_pose_render = run_one_iter_of_tinynerf(
        height, width, focal_length, estimated_pose,
        near_thresh, far_thresh, depth_samples_per_ray,
        encode, get_minibatches,
    )
visualize_comparison(target_img, estimated_pose_render)

# Visualize the entire process
visualize_pose_estimation(
    target_img, initial_pose, estimated_pose, model, encode, get_minibatches,
    height, width, focal_length, near_thresh, far_thresh, depth_samples_per_ray,
)

# Plot the loss curve with the best loss indicated
plot_loss_curve(losses, best_loss_idx)
