In [None]:
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import matplotlib

import nerf_utils
import Datasets

# Neural Radiance Fields (NeRF)
Contributors: Onur Bagoren, Hendrik Dreger, Wuao Liu

## Load the Datasets
### Load the chair dataset

In [None]:
!mkdir -p datasets/chair
!mkdir -p datasets/lego

## Chair Dataset
The chair scene is one of the synthetic datasets used in the original NeRF paper.
In the next cells, we will load and visualize some of the images from the dataset.

In [None]:
# Enter the root directory of the dataset
root_dir = f'{sys.path[0]}/datasets/chair'
chairs = Datasets.ChairData(root_dir, mode='train')
chairs_loader = torch.utils.data.DataLoader(dataset=chairs, batch_size=1, shuffle=True)

In [None]:
# Display 25 images (in shuffled order, since shuffle=True) on a 5x5 grid
fig, axs = plt.subplots(5, 5, figsize=(15, 15))
for ii, (image, _, _) in enumerate(chairs_loader):
    if ii == 25:
        break
    row = ii // 5
    col = ii % 5
    axs[row, col].imshow(image[0])
    axs[row, col].axis('off')

## Implementing NeRF
NeRF learns a continuous representation of a scene from images of the same object taken from multiple viewpoints. It is trained by minimizing the error between the images it renders and the captured images.

To do this, NeRF represents the scene as a function of a 5-dimensional input $\mathbf{X}$, where $(x, y, z)$ is the position of a query point in the world frame and $(\theta, \phi)$ is the viewing direction from which that point is observed.

\begin{align}
\mathbf{X} = 
\begin{bmatrix}
x \\
y \\
z \\
\theta \\
\phi
\end{bmatrix}
\end{align}

In particular, the input consists of a position $\mathbf{x} = \left(x, y, z\right)$ and a viewing direction $\mathbf{d} = \left(\theta, \phi\right)$, i.e. the direction of the camera ray that passes through the point.

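The output of the model is the volume density $\sigma$ and the color $\mathbf{c}$ at each query point, so the network approximates a function $F_\Theta: (\mathbf{x}, \mathbf{d}) \mapsto (\mathbf{c}, \sigma)$. In the original model, the density depends only on $\mathbf{x}$, while the color depends on both $\mathbf{x}$ and $\mathbf{d}$. Evaluating the network at the points $t_i$ sampled along the ray cast from the camera through a pixel gives $\sigma_i$ and $\mathbf{c}_i$. The volume density can be interpreted as the differential probability of the ray terminating at an infinitesimal particle at that point.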

### The model architecture
The model is a multi-layer perceptron (MLP) with ReLU activations. At certain layers, an encoded version of the input position $\mathbf{x}$ or input direction $\mathbf{d}$ is concatenated to the layer's input. The architecture is shown below.

![Nerf-arch](images/NERF_arch.png)
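The architecture in the figure can be sketched in PyTorch as below. It follows the description in the original NeRF paper: eight 256-unit ReLU layers on the encoded position $\gamma(\mathbf{x})$, a skip connection that re-injects $\gamma(\mathbf{x})$ after the fifth layer, a density output, and a 128-unit layer that also receives the encoded direction $\gamma(\mathbf{d})$ before the sigmoid color output. The class name `PaperStyleNerfMLP` is hypothetical; this is not the `Nerf` class from `nerf_utils`, which is a smaller variant (see "Changes compared to the original model"). The default input sizes 63 and 27 assume encodings with $L = 10$ and $L = 4$ that also keep the raw coordinates.

```python
import torch
import torch.nn as nn


class PaperStyleNerfMLP(nn.Module):
    """MLP in the style of the original NeRF paper (hypothetical sketch)."""

    def __init__(self, pos_dim: int = 63, dir_dim: int = 27, width: int = 256, skip_layer: int = 5):
        super().__init__()
        self.skip_layer = skip_layer
        # Eight fully connected layers on the encoded position gamma(x).
        # The layer at index `skip_layer` also receives gamma(x) again (skip connection).
        self.pos_layers = nn.ModuleList(
            [nn.Linear(pos_dim, width)]
            + [nn.Linear(width + (pos_dim if i == skip_layer else 0), width) for i in range(1, 8)]
        )
        self.sigma_head = nn.Linear(width, 1)                     # view-independent density
        self.feature = nn.Linear(width, width)                    # features for the color branch
        self.dir_layer = nn.Linear(width + dir_dim, width // 2)   # 128 units, sees gamma(d)
        self.rgb_head = nn.Linear(width // 2, 3)                  # view-dependent color

    def forward(self, x_enc: torch.Tensor, d_enc: torch.Tensor) -> torch.Tensor:
        h = x_enc
        for i, layer in enumerate(self.pos_layers):
            if i == self.skip_layer:
                h = torch.cat([h, x_enc], dim=-1)     # re-inject the encoded position
            h = torch.relu(layer(h))
        sigma = torch.relu(self.sigma_head(h))        # rectified so density is non-negative
        h = torch.cat([self.feature(h), d_enc], dim=-1)
        h = torch.relu(self.dir_layer(h))
        rgb = torch.sigmoid(self.rgb_head(h))         # colors in [0, 1]
        return torch.cat([rgb, sigma], dim=-1)        # (..., 4): RGB followed by density
```

For example, `PaperStyleNerfMLP()(torch.randn(10, 63), torch.randn(10, 27))` returns a `(10, 4)` tensor, one RGB color and one density per query point.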

### Positional Encoding
The positional encoding maps each input coordinate to a set of sine and cosine functions of increasing frequency before it is passed to the MLP.
This is necessary because images contain high-frequency content: colors can change sharply from one pixel to the next. Without the encoding, the network struggles to represent such rapid variation, and the rendered images lose fine textures and small patterns.

The paper uses $\sin$ and $\cos$ as basis functions and encodes each coordinate $p$ as
\begin{align}
\gamma(p) = \left(\sin(2^0\pi p), \cos(2^0\pi p), \ldots, \sin(2^{L-1}\pi p), \cos(2^{L-1}\pi p)\right),
\end{align}
where $L$ is the number of frequency bands ($L = 10$ for $\mathbf{x}$ and $L = 4$ for $\mathbf{d}$ in the paper). The highest frequency in the encoding is $2^{L-1}$, so it grows exponentially with the number of frequency bands.
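The encoding from the paper, which maps each coordinate to sines and cosines at frequencies $2^k\pi$ for $k = 0, \ldots, L-1$, can be sketched as follows. The function name and the `include_input` flag are assumptions for illustration; keeping the raw input is a common implementation choice, and some implementations also omit the factor $\pi$. Whether `nerf_utils` does either is not shown in this notebook.

```python
import math

import torch


def positional_encoding(x: torch.Tensor, num_encoding_functions: int = 6,
                        include_input: bool = True) -> torch.Tensor:
    """Apply gamma(p) to every coordinate of x (hypothetical helper).

    x: (..., D). Returns (..., D * (2L + 1)) if include_input else (..., 2 * D * L).
    """
    features = [x] if include_input else []
    for k in range(num_encoding_functions):
        freq = (2.0 ** k) * math.pi      # frequencies 2^0 pi, 2^1 pi, ..., 2^(L-1) pi
        features.append(torch.sin(freq * x))
        features.append(torch.cos(freq * x))
    return torch.cat(features, dim=-1)
```

With $L = 6$ and a 3D position, this produces $2 \cdot 3 \cdot 6 + 3 = 39$ features, matching `input_dim = 2 * 3 * num_encoding_functions + 3` used later in this notebook. Without the raw input, the paper's $L = 10$ gives 60 features per position.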

In [None]:
# A demo of the positional encoding function used to represent an image

## Ray-casting and sampling of query points

To render the scene, we first need to model how the camera projects the scene onto the image plane. Conventional rasterization-based graphics projects the scene geometry onto the image plane with a perspective projection. NeRF instead uses ray casting with a pinhole (perspective) camera model: for every pixel, a ray is cast from the camera center through that pixel into the scene, and the scene is queried at points along that ray.

The code below visualizes the function that casts the rays and samples the query points. Running the cell produces two figures:
1. The ray cast in the camera frame
2. The ray cast in the world frame

In both figures, the image plane is shown as a blue plane, and the rays from the image plane to the camera pinhole (also called the ray origin) are shown as red lines. Each cast ray is drawn as a yellow segment followed by a blue segment. The blue segment, called the query vector, is the part of the ray from which the query points are sampled; it starts at the near threshold and ends at the far threshold. The yellow segment is the part of the ray between the ray origin and the start of the query vector.

In [None]:
from visualize import visualize_ray_casting
from scipy.spatial.transform import Rotation as R
# Enter values to change the visualization
roll = 0
pitch = -torch.pi/16
yaw = -torch.pi/16
rot = R.from_euler('xyz', [roll, pitch, yaw], degrees=False)
rot_mat = torch.tensor(rot.as_matrix())
trans = torch.tensor([1, 0, 0])
c2w = torch.eye(4)
c2w[:3,:3] = rot_mat
c2w[:3,3] = trans

near_threshold = 1
far_threshold = 20
num_samples = 10
H, W = (10, 10)
fig = plt.figure(figsize=(20,20))
visualize_ray_casting(H, W, near_threshold, far_threshold, num_samples, c2w, fig)
# plt.show()

### Generating query points / representing the ray as a function for the model
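Each cast ray is represented as a function $\mathbf{r}(t) = \mathbf{o} + t\mathbf{d}$, where $\mathbf{o}$ is the ray origin and $\mathbf{d}$ is the ray direction. The query points for the model are obtained by evaluating this function at `num_samples` depths $t$ between the near and far thresholds.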
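A minimal sketch of ray casting and query-point sampling for a pinhole camera is shown below. It assumes the camera looks down its local $-z$ axis with $y$ pointing up (the convention of the Blender synthetic scenes) and that `c2w` is a float32 4x4 camera-to-world matrix like the one built in the visualization cell. The function names `get_rays` and `sample_query_points` are hypothetical and not taken from `nerf_utils`. The samples are evenly spaced; the original paper instead uses stratified sampling, which adds a random offset within each depth bin during training.

```python
import torch


def get_rays(H: int, W: int, focal: float, c2w: torch.Tensor):
    """One ray origin and direction per pixel, expressed in the world frame."""
    # Pixel grid: j indexes rows (image y), i indexes columns (image x).
    j, i = torch.meshgrid(torch.arange(H, dtype=torch.float32),
                          torch.arange(W, dtype=torch.float32), indexing="ij")
    # Direction from the pinhole through each pixel, in the camera frame.
    dirs = torch.stack([(i - 0.5 * W) / focal,
                        -(j - 0.5 * H) / focal,
                        -torch.ones_like(i)], dim=-1)       # (H, W, 3)
    # Rotate into the world frame: d_world = R @ d_cam, written for row vectors.
    rays_d = dirs @ c2w[:3, :3].T                           # (H, W, 3)
    # Every ray starts at the camera center (the translation part of c2w).
    rays_o = c2w[:3, 3].expand(rays_d.shape)                # (H, W, 3)
    return rays_o, rays_d


def sample_query_points(rays_o, rays_d, near: float, far: float, num_samples: int):
    """Evenly spaced query points r(t) = o + t d for t in [near, far]."""
    t_vals = torch.linspace(near, far, num_samples)         # (N,)
    # Broadcast to (H, W, N, 3): origin + depth * direction.
    query_points = rays_o[..., None, :] + rays_d[..., None, :] * t_vals[:, None]
    return query_points, t_vals
```

With the visualization settings above (`H, W = (10, 10)`, `near_threshold = 1`, `far_threshold = 20`, `num_samples = 10`), `sample_query_points` would return a `(10, 10, 10, 3)` tensor of points, one row of 10 samples per pixel.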

## Volume Rendering
For every depth sample along a ray, our model outputs a color $\mathbf{c}_i$ and a volume density $\sigma_i$.
We combine these values into a single color for the pixel that the ray passes through.
For each sample $i$ along the ray, we compute the opacity $\alpha_i = 1 - \exp(-\sigma_i \delta_i)$, where $\delta_i = t_{i+1} - t_i$ is the distance between adjacent samples, and the accumulated transmittance $T_i$, i.e. the probability that the ray reaches sample $i$ without being absorbed. Weighting each sample's color by $T_i \alpha_i$ and summing over all $N$ samples gives the predicted color of the pixel:
\begin{align}
\hat{C}(\mathbf{r}) &= \sum_{i=1}^{N} T_i\left(1 - \exp\left(-\sigma_i \delta_i\right)\right)\mathbf{c}_i \\
T_i &= \exp\left(-\sum_{j=1}^{i-1}\sigma_j\delta_j\right)
\end{align}
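The rendering equation above can be computed for a whole batch of rays at once, as in the sketch below. It assumes the colors are already in $[0, 1]$ and the densities already non-negative (i.e. the output activations were applied by the model), and, as is common in NeRF implementations, gives the last sample a very large $\delta$. The name `render_rays` is hypothetical; `nerf_utils.render_volume_density` takes different arguments and may differ in details. It uses the identity $\exp\left(-\sum_{j<i}\sigma_j\delta_j\right) = \prod_{j<i}(1 - \alpha_j)$ to compute $T_i$ with a cumulative product.

```python
import torch


def render_rays(rgb: torch.Tensor, sigma: torch.Tensor, t_vals: torch.Tensor) -> torch.Tensor:
    """Composite per-sample colors and densities into one color per ray.

    rgb:    (..., N, 3) colors c_i, already in [0, 1]
    sigma:  (..., N)    non-negative densities sigma_i
    t_vals: (N,)        increasing sample depths t_i along each ray
    """
    # delta_i = t_{i+1} - t_i; the last sample gets a very large distance.
    deltas = torch.cat([t_vals[1:] - t_vals[:-1], torch.full_like(t_vals[:1], 1e10)])
    alpha = 1.0 - torch.exp(-sigma * deltas)                # opacity alpha_i
    # T_i = prod_{j<i} (1 - alpha_j): an exclusive cumulative product, so T_1 = 1.
    trans = torch.cumprod(1.0 - alpha, dim=-1)
    trans = torch.cat([torch.ones_like(trans[..., :1]), trans[..., :-1]], dim=-1)
    weights = trans * alpha                                 # T_i * alpha_i
    return (weights[..., None] * rgb).sum(dim=-2)           # C_hat(r), shape (..., 3)
```

A ray whose first sample has a very large density returns that sample's color, because $T_i \approx 0$ for every later sample; a ray with zero density everywhere returns black.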

### Metrics for the model
#### PSNR
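PSNR (peak signal-to-noise ratio) measures the quality of the rendered image relative to the ground truth. It is the ratio between the squared maximum possible pixel value and the mean squared error (MSE) between the output and the ground truth, expressed in decibels: $\mathrm{PSNR} = 10\log_{10}\left(\mathrm{MAX}^2/\mathrm{MSE}\right)$. Higher values indicate a better reconstruction.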
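A minimal PSNR sketch for images with pixel values in $[0, 1]$ (so $\mathrm{MAX} = 1$ and PSNR reduces to $-10\log_{10}(\mathrm{MSE})$) is shown below. The function name is hypothetical, and `nerf_utils` may compute the metric differently.

```python
import torch


def psnr(pred: torch.Tensor, target: torch.Tensor, max_val: float = 1.0) -> torch.Tensor:
    """PSNR in decibels: 10 * log10(max_val**2 / MSE). Returns inf for identical images."""
    mse = torch.mean((pred - target) ** 2)
    return 10.0 * torch.log10(max_val ** 2 / mse)
```

For example, a uniform error of 0.1 per pixel gives an MSE of 0.01 and therefore a PSNR of 20 dB; halving the error to 0.05 adds about 6 dB.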

#### Structural Similarity Index (SSIM)
SSIM measures how similar two images are in terms of luminance, contrast, and structure; a value of 1 means the images are identical. We use the SSIM implementation provided by the scikit-image (`skimage`) package.
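For reference, the SSIM between two image windows $\mathbf{x}$ and $\mathbf{y}$ combines their means $\mu$, variances $\sigma^2$, and covariance $\sigma_{xy}$ (here $\sigma$ denotes a standard deviation, not the volume density):
\begin{align}
\mathrm{SSIM}(\mathbf{x}, \mathbf{y}) = \frac{(2\mu_x\mu_y + C_1)(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)}
\end{align}
with $C_1 = (K_1 D)^2$ and $C_2 = (K_2 D)^2$, where $D$ is the data range of the pixel values and $K_1 = 0.01$, $K_2 = 0.03$ by default. scikit-image evaluates this expression on local windows across the image and reports the mean value.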

# Changes compared to the original model
In order to reduce the training time for our model, we applied several changes:
- Lower resolution images: our images are downsampled to 100x100 resolution
- Fewer samples per ray: we use 32 samples per ray during training (the original uses 64 coarse samples plus 128 fine samples)
- No hierarchical sampling: we sample each ray once at 32 points, without the coarse-to-fine resampling step of the original
- Smaller MLP: instead of an 8-layer MLP with 256 hidden units per layer, we use 4 linear layers with 256 hidden units each, and we do not concatenate the positional encoding to intermediate layers
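A rough count of MLP evaluations needed to render one full image illustrates why these changes shorten training. The original settings assumed here come from the NeRF paper: 800x800 Blender images, $N_c = 64$ coarse samples, and a fine network evaluated on $N_c + N_f = 64 + 128$ samples per ray. The count ignores that each evaluation is also cheaper with the smaller MLP, and that training actually processes batches of rays rather than whole images.

```python
# Original NeRF (paper settings): 800x800 images, coarse pass with 64 samples,
# fine pass evaluated on 64 + 128 samples per ray.
paper_pixels = 800 * 800
paper_queries_per_ray = 64 + (64 + 128)
paper_queries = paper_pixels * paper_queries_per_ray

# This notebook: 100x100 images, 32 samples per ray, a single pass.
our_pixels = 100 * 100
our_queries_per_ray = 32
our_queries = our_pixels * our_queries_per_ray

print(f"paper: {paper_queries:,} queries, ours: {our_queries:,} queries, "
      f"ratio: {paper_queries // our_queries}x")
# prints "paper: 163,840,000 queries, ours: 320,000 queries, ratio: 512x"
```

Most of the reduction comes from the resolution (64x fewer pixels), with the remaining 8x from the smaller number of samples per ray.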

# Training on the Chair Dataset

## Loading the dataset

In [None]:
root_dir_chair = f'{sys.path[0]}/datasets/chair'
chairs_train = Datasets.ChairData(root_dir_chair, mode='train')
train_loader_chairs = torch.utils.data.DataLoader(dataset=chairs_train, batch_size=1, shuffle=True)

chairs_test = Datasets.ChairData(root_dir_chair, mode='test')
test_loader_chairs = torch.utils.data.DataLoader(dataset=chairs_test, batch_size=1, shuffle=True)

In [None]:
from nerf_utils import Nerf
num_encoding_functions = 6
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
input_dim = 2 * 3 * num_encoding_functions + 3

model = Nerf(input_dim)
model = model.to(device)

In [None]:
# Set up the optimizer and other hyperparameters
lr = 5e-4
num_iterations = int(1e3)
log_interval = 100

optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

fixed_test = True

In [None]:
from nerf_utils import train_synthetic

# Fix the seed for reproducibility
torch.manual_seed(0)
np.random.seed(0)

trained_model_chair, train_losses, test_losses, ssims, psnrs = train_synthetic(
    train_dataloader=train_loader_chairs,
    test_dataloader=test_loader_chairs,
    model=model,
    optimizer=optimizer,
    num_iterations=num_iterations,
    log_interval_iterations=log_interval,
    plot_intervals=num_iterations // 10,
    device=device,
    fixed_test=fixed_test,    
)

# Display the training plots
fig, axs = plt.subplots(1, 3, figsize=(23, 5))
N = len(train_losses)
x = np.linspace(0, num_iterations, N)
axs[0].plot(x, train_losses, label='train')
axs[0].set_title('Loss Curves', fontsize=14)
axs[0].plot(x, test_losses, label='test')
axs[0].set_xlabel('Iteration')
axs[0].set_ylabel('Loss')
axs[0].legend()

axs[1].plot(x, psnrs)
axs[1].set_title('PSNR Curves', fontsize=14)
axs[1].set_xlabel('Iteration')
axs[1].set_ylabel('PSNR')

axs[2].plot(x, ssims)
axs[2].set_title('SSIM Curves', fontsize=14)
axs[2].set_xlabel('Iteration')
axs[2].set_ylabel('SSIM')

# Training on the Lego Dataset
## Initialize the model and load the dataset

In [None]:
from nerf_utils import Nerf
num_encoding_functions = 6
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
input_dim = 2 * 3 * num_encoding_functions + 3

model = Nerf(input_dim)
model = model.to(device)

In [None]:
root_dir_lego = f'{sys.path[0]}/datasets/lego'
train_lego = Datasets.ChairData(root_dir_lego, mode='train')
train_loader_lego = torch.utils.data.DataLoader(dataset=train_lego, batch_size=1, shuffle=True)

test_lego = Datasets.ChairData(root_dir_lego, mode='test')
test_loader_lego = torch.utils.data.DataLoader(dataset=test_lego, batch_size=1, shuffle=True)

In [None]:
# Set up the optimizer and other hyperparameters
lr = 5e-4
num_iterations = int(1e3)
log_interval = 100

optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

fixed_test = True

In [None]:
from nerf_utils import train_synthetic

trained_model_lego, train_losses, test_losses, ssims, psnrs = train_synthetic(
    train_dataloader=train_loader_lego,
    test_dataloader=test_loader_lego,
    model=model,
    optimizer=optimizer,
    num_iterations=num_iterations,
    log_interval_iterations=log_interval,
    plot_intervals=num_iterations // 10,
    device=device,
    fixed_test=fixed_test,    
)

# Display the training plots
fig, axs = plt.subplots(1, 3, figsize=(23, 5))
N = len(train_losses)
x = np.linspace(0, num_iterations, N)
axs[0].plot(x, train_losses, label='train')
axs[0].set_title('Loss Curves', fontsize=14)
axs[0].plot(x, test_losses, label='test')
axs[0].set_xlabel('Iteration')
axs[0].set_ylabel('Loss')
axs[0].legend()

axs[1].plot(x, psnrs)
axs[1].set_title('PSNR Curves', fontsize=14)
axs[1].set_xlabel('Iteration')
axs[1].set_ylabel('PSNR')

axs[2].plot(x, ssims)
axs[2].set_title('SSIM Curves', fontsize=14)
axs[2].set_xlabel('Iteration')
axs[2].set_ylabel('SSIM')

# Different Perspectives from the Training Dataset

#### Chair Dataset

In [None]:
from nerf_utils import generate_network_input, render_volume_density

chair_val = Datasets.ChairData(root_dir_chair, mode='train')
val_loader_chair = torch.utils.data.DataLoader(dataset=chair_val, batch_size=1, shuffle=True)

# Create the iterator once so that each draw returns the next shuffled image
val_iter = iter(val_loader_chair)
for ii in range(5):
    val_img, val_pose, val_focal = next(val_iter)
    val_img = val_img[0,...].to(device)
    val_pose = val_pose[0,...].to(device)
    val_focal = val_focal[0].to(device)

    network_input = generate_network_input(val_img, val_focal, val_pose)
    encoded_points, query_points, depth_values = network_input

    # Inference only: skip gradient tracking to save memory
    with torch.no_grad():
        model_output = trained_model_chair(encoded_points)
        predicted_rgb = render_volume_density(model_output, query_points, depth_values)

    plotting_rgb = predicted_rgb.cpu().detach().numpy()
    plotting_gt = val_img.cpu().detach().numpy()

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(plotting_gt)
    axs[0].set_title('Original Image')
    axs[1].imshow(plotting_rgb)
    axs[1].set_title('Predicted Image')
    fig.suptitle(f'Train Image {ii+1}')

#### Lego Dataset

In [None]:
from nerf_utils import generate_network_input, render_volume_density

lego_val = Datasets.ChairData(root_dir_lego, mode='train')
val_loader_lego = torch.utils.data.DataLoader(dataset=lego_val, batch_size=1, shuffle=True)

# Create the iterator once so that each draw returns the next shuffled image
val_iter = iter(val_loader_lego)
for ii in range(5):
    val_img, val_pose, val_focal = next(val_iter)
    val_img = val_img[0,...].to(device)
    val_pose = val_pose[0,...].to(device)
    val_focal = val_focal[0].to(device)

    network_input = generate_network_input(val_img, val_focal, val_pose)
    encoded_points, query_points, depth_values = network_input

    # Inference only: skip gradient tracking to save memory
    with torch.no_grad():
        model_output = trained_model_lego(encoded_points)
        predicted_rgb = render_volume_density(model_output, query_points, depth_values)

    plotting_rgb = predicted_rgb.cpu().detach().numpy()
    plotting_gt = val_img.cpu().detach().numpy()
    
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(plotting_gt)
    axs[0].set_title('Original Image')
    axs[1].imshow(plotting_rgb)
    axs[1].set_title('Predicted Image')
    fig.suptitle(f'Train Image {ii+1}')

# Different Perspectives from the Validation Dataset

#### Chair Dataset

In [None]:
from nerf_utils import generate_network_input, render_volume_density

chair_val = Datasets.ChairData(root_dir_chair, mode='val')
val_loader_chair = torch.utils.data.DataLoader(dataset=chair_val, batch_size=1, shuffle=True)

# Create the iterator once so that each draw returns the next shuffled image
val_iter = iter(val_loader_chair)
for ii in range(5):
    val_img, val_pose, val_focal = next(val_iter)
    val_img = val_img[0,...].to(device)
    val_pose = val_pose[0,...].to(device)
    val_focal = val_focal[0].to(device)

    network_input = generate_network_input(val_img, val_focal, val_pose)
    encoded_points, query_points, depth_values = network_input

    # Inference only: skip gradient tracking to save memory
    with torch.no_grad():
        model_output = trained_model_chair(encoded_points)
        predicted_rgb = render_volume_density(model_output, query_points, depth_values)

    plotting_rgb = predicted_rgb.cpu().detach().numpy()
    plotting_gt = val_img.cpu().detach().numpy()

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(plotting_gt)
    axs[0].set_title('Original Image')
    axs[1].imshow(plotting_rgb)
    axs[1].set_title('Predicted Image')
    fig.suptitle(f'Validation Image {ii+1}')

#### Lego Dataset

In [None]:
from nerf_utils import generate_network_input, render_volume_density

lego_val = Datasets.ChairData(root_dir_lego, mode='val')
val_loader_lego = torch.utils.data.DataLoader(dataset=lego_val, batch_size=1, shuffle=True)
# Create the iterator once so that each draw returns the next shuffled image
val_iter = iter(val_loader_lego)
for ii in range(5):
    val_img, val_pose, val_focal = next(val_iter)
    val_img = val_img[0,...].to(device)
    val_pose = val_pose[0,...].to(device)
    val_focal = val_focal[0].to(device)

    network_input = generate_network_input(val_img, val_focal, val_pose)
    encoded_points, query_points, depth_values = network_input

    # Inference only: skip gradient tracking to save memory
    with torch.no_grad():
        model_output = trained_model_lego(encoded_points)
        predicted_rgb = render_volume_density(model_output, query_points, depth_values)

    plotting_rgb = predicted_rgb.cpu().detach().numpy()
    plotting_gt = val_img.cpu().detach().numpy()
    
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(plotting_gt)
    axs[0].set_title('Original Image')
    axs[1].imshow(plotting_rgb)
    axs[1].set_title('Predicted Image')
    fig.suptitle(f'Validation Image {ii+1}')

# Results

For both the chair and the lego dataset, we are able to render images from novel perspectives. Over training, the loss decreases, while the PSNR increases and the SSIM approaches 1.
Due to time constraints, the models in this notebook are trained for only 1,000 iterations. Separately, we trained a model on the chair dataset for 5,000 iterations; its output is shown in the video below.

## Video

In [None]:
from IPython.display import Video

Video('images/rendering_predicted_result.mp4', width=800)

### Example Outputs from models we trained for longer

#### 5000 Iterations on chair dataset, with metrics

![](images/4800.png)
![](images/Figure_1.png)

#### 15000 Iterations on lego dataset

![](images/14250.png)
