# FreeNeRF: Improving Few-shot Neural Rendering with Free Frequency Regularization

- Ref: 
    - https://arxiv.org/abs/2303.07418
    - https://github.com/MaximeVandegar/Papers-in-100-Lines-of-Code/tree/main/FreeNeRF_Improving_Few_shot_Neural_Rendering_with_Free_Frequency_Regularization

- Conda environment: [gsplat](../gsplat/README.md#setup-a-conda-environment)

In [1]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision
import torchshow as ts

import os
from pathlib import Path
import gdown

In [2]:
# Download Data
# Fetch the pre-extracted ray datasets from Google Drive into ./temp/data.
# With resume=True, gdown skips files that are already fully downloaded.
data_dir = Path('./temp/data')
data_dir.mkdir(parents=True, exist_ok=True)

train_data_path = './temp/data/training_data.pkl'
test_data_path = './temp/data/testing_data.pkl'

_downloads = (
    ('https://drive.google.com/uc?id=1hH7NhaXxIthO9-FeT16fvpf_MVIhf41J', train_data_path),
    ('https://drive.google.com/uc?id=16M64h0KKgFKhM8hJDpqd15YWYhafUs2Q', test_data_path),
)
for url, _target in _downloads:
    gdown.download(url, _target, quiet=False, resume=True)

Skipping already downloaded file ./temp/data/training_data.pkl
Skipping already downloaded file ./temp/data/testing_data.pkl


'./temp/data/testing_data.pkl'

In [3]:
# Output locations for checkpoints and rendered novel views, plus the video frame rate.
g_model_path = './temp/09_FreeNeRF_models/'
g_video_path = './temp/09_FreeNeRF_novel_views/'
fps = 10
for _out_dir in (g_model_path, g_video_path):
    Path(_out_dir).mkdir(exist_ok=True, parents=True)

# Load the ray datasets from the same paths the download cell wrote to, rather
# than repeating the literals (a single source of truth for the file locations).
# SECURITY NOTE(review): allow_pickle=True unpickles the file, which can execute
# arbitrary code. These files come from a remote Google Drive link, so only load
# them from a source you trust.
# Rows are later sliced as [:, :3] (ray origins), [:, 3:6] (ray directions) and
# [:, 6:] (target pixel colours).
training_dataset = torch.from_numpy(np.load(train_data_path, allow_pickle=True))
testing_dataset = torch.from_numpy(np.load(test_data_path, allow_pickle=True))


In [4]:
# Pick the compute device: Apple-silicon MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    _backend = "mps"
elif torch.cuda.is_available():
    _backend = "cuda"
else:
    _backend = "cpu"
g_device = torch.device(_backend)

if _backend == "mps":
    # Log the MPS allocator state for reference.
    print(f"Current memory allocated on MPS: {torch.mps.current_allocated_memory()} bytes")
    print(f"Driver memory allocated on MPS: {torch.mps.driver_allocated_memory()} bytes")
print(g_device)

Current memory allocated on MPS: 0 bytes
Driver memory allocated on MPS: 393216 bytes
mps


In [5]:
# Model
class NerfModel(nn.Module):
    """NeRF MLP with FreeNeRF-style frequency masking of the position encoding.

    Maps sample positions ``o`` and view directions ``d`` to an RGB colour in
    (0, 1) (sigmoid output) and a non-negative density (softplus output).

    Args:
        embedding_dim_pos: number of frequency bands used to encode positions.
        embedding_dim_direction: number of frequency bands used to encode directions.
        hidden_dim: width of the hidden layers.
        T: step count that scales the position-encoding mask schedule
            (see ``positional_encoding``).
    """

    def __init__(self, embedding_dim_pos=16, embedding_dim_direction=4, hidden_dim=128, T=40_000):
        super().__init__()

        # Encoding L bands of a 3-column input yields 3 raw + 6 * L (sin, cos) features.
        pos_width = embedding_dim_pos * 6 + 3
        dir_width = embedding_dim_direction * 6 + 3

        # Trunk over the encoded position.
        self.block1 = nn.Sequential(
            nn.Linear(pos_width, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        # Skip connection re-injects the encoded position; the extra output unit is the raw density.
        self.block2 = nn.Sequential(
            nn.Linear(pos_width + hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim + 1),
        )
        # Colour head conditioned on the encoded view direction.
        self.block3 = nn.Sequential(
            nn.Linear(dir_width + hidden_dim, hidden_dim // 2), nn.ReLU(),
        )
        self.block4 = nn.Sequential(
            nn.Linear(hidden_dim // 2, 3), nn.Sigmoid(),
        )

        self.embedding_dim_pos = embedding_dim_pos
        self.embedding_dim_direction = embedding_dim_direction
        self.relu = nn.ReLU()
        self.T = T

    def positional_encoding(self, x, L, step, is_pos=False):
        """Encode ``x`` as [x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x)].

        The pieces are concatenated along dim 1 (x is expected to be (N, 3); the
        mask width below assumes 3 columns).

        When ``is_pos`` is true, every column from ``int(step / T * (6L + 3)) + 3``
        onward is zeroed. The raw coordinates are therefore always kept, and
        higher frequency bands are revealed progressively as ``step`` grows.
        Because the width includes the 3 raw columns, the mask is fully open
        slightly before ``step`` reaches ``T``.
        """
        features = [x]
        for band in range(L):
            scale = 2 ** band
            features += [torch.sin(scale * x), torch.cos(scale * x)]
        encoded = torch.cat(features, dim=1)

        if not is_pos:
            return encoded
        full_width = 2 * 3 * L + 3
        visible = int(step / self.T * full_width) + 3
        encoded[:, visible:] = 0.
        return encoded

    def forward(self, o, d, step):
        """Return ``(rgb, sigma)`` for positions ``o`` and directions ``d`` at training ``step``."""
        pos_feats = self.positional_encoding(o, self.embedding_dim_pos, step, is_pos=True)
        dir_feats = self.positional_encoding(d, self.embedding_dim_direction, step, is_pos=False)

        trunk = self.block1(pos_feats)
        out = self.block2(torch.cat((trunk, pos_feats), dim=1))
        features, raw_sigma = out[:, :-1], out[:, -1]
        sigma = torch.nn.functional.softplus(raw_sigma)

        rgb = self.block4(self.block3(torch.cat((features, dir_feats), dim=1)))
        return rgb, sigma

In [6]:
# Rendering

def compute_accumulated_transmittance(alphas):
    """Exclusive cumulative product along dim 1, with a leading column of ones.

    When called with ``1 - alpha`` per sample, entry ``[:, i]`` is the
    transmittance prod_{j<i} (1 - alpha_j) reaching sample i; the first
    sample always receives 1.
    """
    running = torch.cumprod(alphas, 1)
    leading_ones = torch.ones((running.shape[0], 1), device=alphas.device)
    return torch.cat((leading_ones, running[:, :-1]), dim=-1)


def render_rays(nerf_model, ray_origins, ray_directions, step, hn=0, hf=0.5, nb_bins=192):
    """Volume-render one RGB value per ray using stratified samples in [hn, hf].

    Args:
        nerf_model: callable ``(points, dirs, step) -> (colors, sigma)`` on flattened samples.
        ray_origins: ray origins, shape (batch_size, 3).
        ray_directions: ray directions, shape (batch_size, 3).
        step: training step forwarded to the model.
        hn: near bound of the sampling interval.
        hf: far bound of the sampling interval.
        nb_bins: number of samples per ray.

    Returns:
        Tensor of shape (batch_size, 3): pixel colours composited over a white background.
    """
    device = ray_origins.device
    n_rays = ray_origins.shape[0]

    # Evenly spaced depths, then one uniform random depth inside each bin;
    # bins are bounded by the midpoints between neighbouring depths.
    t = torch.linspace(hn, hf, nb_bins, device=device).expand(n_rays, nb_bins)
    mid = (t[:, :-1] + t[:, 1:]) / 2.
    lower = torch.cat((t[:, :1], mid), -1)
    upper = torch.cat((mid, t[:, -1:]), -1)
    t = lower + (upper - lower) * torch.rand(t.shape, device=device)  # [batch_size, nb_bins]

    # Spacing to the next sample; the last sample is treated as extending to (almost) infinity.
    far_tail = torch.tensor([1e10], device=device).expand(n_rays, 1)
    delta = torch.cat((t[:, 1:] - t[:, :-1], far_tail), -1)

    points = ray_origins.unsqueeze(1) + t.unsqueeze(2) * ray_directions.unsqueeze(1)  # [batch_size, nb_bins, 3]
    dirs = ray_directions.expand(nb_bins, n_rays, 3).transpose(0, 1)
    colors, sigma = nerf_model(points.reshape(-1, 3), dirs.reshape(-1, 3), step)
    colors = colors.reshape(points.shape)
    sigma = sigma.reshape(points.shape[:-1])

    # NeRF quadrature: alpha_i = 1 - exp(-sigma_i * delta_i), weight_i = T_i * alpha_i.
    alpha = 1 - torch.exp(-sigma * delta)  # [batch_size, nb_bins]
    weights = (compute_accumulated_transmittance(1 - alpha) * alpha).unsqueeze(2)
    pixel = (weights * colors).sum(dim=1)
    opacity = weights[..., 0].sum(dim=1)  # accumulated opacity per ray
    # Whatever opacity is missing is filled with white.
    return pixel + 1 - opacity.unsqueeze(-1)

In [7]:
# @torch.no_grad()
# def test(nerf_model, hn, hf, dataset, chunk_size=10, img_index=0, nb_bins=192, H=400, W=400):
#     ray_origins = dataset[img_index * H * W: (img_index + 1) * H * W, :3]
#     ray_directions = dataset[img_index * H * W: (img_index + 1) * H * W, 3:6]

#     data = []
#     for i in range(int(np.ceil(H / chunk_size))):
#         ray_origins_ = ray_origins[i * W * chunk_size: (i + 1) * W * chunk_size].to(g_device)
#         ray_directions_ = ray_directions[i * W * chunk_size: (i + 1) * W * chunk_size].to(g_device)
#         regenerated_px_values = render_rays(nerf_model, ray_origins_, ray_directions_, hn=hn, hf=hf, nb_bins=nb_bins)
#         data.append(regenerated_px_values.cpu())
#     img_t = torch.cat(data).reshape(H, W, 3)
#     return img_t

@torch.no_grad()
def test(nerf_model, hn, hf, dataset, step, chunk_size=10, img_index=0, nb_bins=192, H=400, W=400):
    """Render test view ``img_index`` and save it as ``img_<index>_v1.png`` in ``g_video_path``.

    Args:
        nerf_model: model passed through to ``render_rays``.
        hn: near sampling bound.
        hf: far sampling bound.
        dataset: ray tensor holding H*W consecutive rows per image; columns 0:3
            are ray origins and 3:6 ray directions.
        step: training step forwarded to the model.
        chunk_size: number of image rows rendered per forward pass.
        img_index: which image of ``dataset`` to render.
        nb_bins: samples per ray.
        H: image height in pixels.
        W: image width in pixels.
    """
    start, stop = img_index * H * W, (img_index + 1) * H * W
    ray_origins = dataset[start:stop, :3]
    ray_directions = dataset[start:stop, 3:6]

    rays_per_chunk = W * chunk_size
    chunks = []
    for i in range(int(np.ceil(H / chunk_size))):
        lo, hi = i * rays_per_chunk, (i + 1) * rays_per_chunk
        px = render_rays(nerf_model, ray_origins[lo:hi].to(g_device), ray_directions[lo:hi].to(g_device),
                         step, hn=hn, hf=hf, nb_bins=nb_bins)
        # Move each chunk off the accelerator right away, so device memory holds
        # one chunk of results at a time instead of the whole image.
        chunks.append(px.cpu())
    img = torch.cat(chunks).numpy().reshape(H, W, 3)

    fig = plt.figure()
    try:
        plt.imshow(img)
        # os.path.join avoids the doubled '/' caused by g_video_path's trailing slash.
        fig.savefig(os.path.join(g_video_path, f'img_{img_index}_v1.png'), bbox_inches='tight')
    finally:
        # Release the figure even if saving fails, so a long render loop cannot leak figures.
        plt.close(fig)

def sample_batch(data, batch_size, device):
    """Return ``batch_size`` distinct random rows of ``data`` as a tensor on ``device``.

    ``data`` may be a NumPy array or a torch tensor. If ``batch_size`` exceeds
    the number of rows, every row is returned once, in random order.
    """
    idx = torch.randperm(data.shape[0])[:batch_size]
    if isinstance(data, torch.Tensor):
        # torch.from_numpy rejects tensors, so index them directly.
        return data[idx].to(device)
    # NumPy input: index with a NumPy array rather than relying on implicit tensor conversion.
    return torch.from_numpy(data[idx.numpy()]).to(device)

# def train(nerf_model, optimizer, training_data, nb_epochs, batch_size, device='cpu', hn=0, hf=1, nb_bins=192):
#     training_loss = []
#     for step in tqdm(range(nb_epochs)):
#         batch = sample_batch(training_data, batch_size, device)
#         rays_o = batch[:, :3].to(device)
#         rays_d = batch[:, 3:6].to(device)
#         ground_truth_px_values = batch[:, 6:].to(device)

#         regenerated_px_values = render_rays(nerf_model, rays_o, rays_d, step, hn=hn, hf=hf, nb_bins=nb_bins)
#         loss = ((ground_truth_px_values - regenerated_px_values) ** 2).sum()
#         training_loss.append(loss.item())
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         if step % 100 or step == nb_epochs - 1:
#             _model_path = os.path.join(g_model_path, f'nerf_model_{step:03d}.pth')
#             torch.save(nerf_model.state_dict(), _model_path)
#     return training_loss

def train(nerf_model, optimizer, training_data, nb_epochs, batch_size, device='cpu', hn=0, hf=1, nb_bins=192,
          checkpoint_every=None, checkpoint_dir=None):
    """Optimise ``nerf_model`` on random ray batches and return the per-step losses.

    Each of the ``nb_epochs`` iterations draws one random batch of rays and takes
    one optimiser step (an iteration is a batch, not a full pass over the data).
    The loss is the summed squared colour error.

    Args:
        nerf_model: model rendered through ``render_rays``; the loop index is passed as ``step``.
        optimizer: torch optimizer over ``nerf_model``'s parameters.
        training_data: rows [origin(0:3), direction(3:6), rgb(6:)], as accepted by ``sample_batch``.
        nb_epochs: number of optimisation steps.
        batch_size: rays per step.
        device: device the batches are moved to.
        hn: near sampling bound for ``render_rays``.
        hf: far sampling bound for ``render_rays``.
        nb_bins: samples per ray for ``render_rays``.
        checkpoint_every: if set, save ``nerf_model.state_dict()`` every this many
            steps and after the last step. ``None`` (default) disables checkpointing.
        checkpoint_dir: directory for checkpoints; defaults to ``g_model_path``.

    Returns:
        list[float]: the loss of every step.
    """
    training_loss = []
    for step in tqdm(range(nb_epochs)):
        # sample_batch already places the batch on `device`.
        batch = sample_batch(training_data, batch_size, device)
        rays_o, rays_d, ground_truth_px_values = batch[:, :3], batch[:, 3:6], batch[:, 6:]

        regenerated_px_values = render_rays(nerf_model, rays_o, rays_d, step, hn=hn, hf=hf, nb_bins=nb_bins)
        loss = ((ground_truth_px_values - regenerated_px_values) ** 2).sum()
        training_loss.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Save on multiples of checkpoint_every. Note that `step % N` on its own is
        # truthy on the *other* steps, so the comparison with 0 is required.
        if checkpoint_every and (step % checkpoint_every == 0 or step == nb_epochs - 1):
            out_dir = g_model_path if checkpoint_dir is None else checkpoint_dir
            torch.save(nerf_model.state_dict(), os.path.join(out_dir, f'nerf_model_{step:03d}.pth'))
    return training_loss

In [8]:
# model = NerfModel(hidden_dim=256).to(g_device)
# model_optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(model_optimizer, milestones=[2, 4, 8], gamma=0.5)

# data_loader = DataLoader(training_dataset, batch_size=1024, shuffle=True)
# train(model, model_optimizer, scheduler, data_loader, nb_epochs=16, device=g_device, hn=2, hf=6, nb_bins=192, H=400,W=400)


# # for img_index in range(200):
# #     test(model, 2, 6, testing_dataset, img_index=img_index, nb_bins=192, H=800, W=800)
# imgT_lst = []
# for idx in tqdm(range(200), desc="Validation"):
#         imgT_lst.append(test(model, 2, 6, testing_dataset, chunk_size = 10, img_index=idx, nb_bins=192, H = 400, W = 400 ).unsqueeze(0))
# img_t = (torch.cat(imgT_lst) * 255).to(torch.uint8)
# _video_path = os.path.join(g_video_path,  f'nerf_model_test.mp4')
# torchvision.io.write_video(_video_path, img_t, fps)


In [9]:
# Few-shot training: optimise on only eight of the training views.
nb_epochs = 80_000
img_h, img_w = 400, 400
pixels_per_view = img_h * img_w
train_view_ids = (26, 86, 2, 55, 75, 93, 16, 73)

# Each view occupies a contiguous block of img_h * img_w rays in the training set.
training_data = np.concatenate([training_dataset[v * pixels_per_view:(v + 1) * pixels_per_view]
                                for v in train_view_ids])

# T = nb_epochs // 2: the position-encoding frequency mask is opened during the first half of training.
model = NerfModel(hidden_dim=256, T=nb_epochs // 2).to(g_device)
model_optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
training_loss = train(model, model_optimizer, training_data, nb_epochs, 1024,
                      device=g_device, hn=2, hf=6, nb_bins=192)
_model_path = os.path.join(g_model_path, 'nerf_model_final.pth')
torch.save(model.state_dict(), _model_path)

# Render every test view. step=nb_epochs (>= T) leaves the frequency mask fully open.
for img_index in range(200):
    test(model, 2, 6, testing_dataset, nb_epochs, img_index=img_index, nb_bins=192, H=img_h, W=img_w)

100%|██████████| 80000/80000 [3:49:18<00:00,  5.81it/s]  


In [12]:
# @torch.no_grad()
# def test(nerf_model, hn, hf, dataset, chunk_size=10, img_index=0, nb_bins=192, H=400, W=400):
#     ray_origins = dataset[img_index * H * W: (img_index + 1) * H * W, :3]
#     ray_directions = dataset[img_index * H * W: (img_index + 1) * H * W, 3:6]

#     data = []
#     for i in range(int(np.ceil(H / chunk_size))):
#         ray_origins_ = ray_origins[i * W * chunk_size: (i + 1) * W * chunk_size].to(g_device)
#         ray_directions_ = ray_directions[i * W * chunk_size: (i + 1) * W * chunk_size].to(g_device)
#         regenerated_px_values = render_rays(nerf_model, ray_origins_, ray_directions_, hn=hn, hf=hf, nb_bins=nb_bins)
#         data.append(regenerated_px_values.cpu())
#     img_t = torch.cat(data).reshape(H, W, 3)
#     return img_t

@torch.no_grad()
def test(nerf_model, hn, hf, dataset, step, chunk_size=10, img_index=0, nb_bins=192, H=400, W=400):
    """Render test view ``img_index`` of ``dataset`` and return it as a CPU tensor of shape (H, W, 3).

    Rays are rendered ``chunk_size`` image rows at a time on ``g_device``. Each
    chunk is copied back to the CPU before the next one is rendered.
    """
    first_ray = img_index * H * W
    view = dataset[first_ray:first_ray + H * W]
    origins, directions = view[:, :3], view[:, 3:6]

    rays_per_chunk = W * chunk_size
    n_chunks = int(np.ceil(H / chunk_size))
    rendered = [
        render_rays(nerf_model,
                    origins[k * rays_per_chunk:(k + 1) * rays_per_chunk].to(g_device),
                    directions[k * rays_per_chunk:(k + 1) * rays_per_chunk].to(g_device),
                    step, hn=hn, hf=hf, nb_bins=nb_bins).data.cpu()
        for k in range(n_chunks)
    ]
    return torch.cat(rendered).reshape(H, W, 3)

# Render all 200 test views with the trained model and write them out as a video.
model.eval()
frames = []
for idx in tqdm(range(200), desc="Validation"):
    frame = test(model, 2, 6, testing_dataset, nb_epochs, chunk_size=10, img_index=idx, nb_bins=192, H=400, W=400)
    # Quantise each frame as soon as it is produced (uint8 is 4x smaller than float32).
    # A bare `.to(torch.uint8)` truncates (biasing frames darker) and does not
    # saturate out-of-range values, so clamp to [0, 1] and round first.
    frames.append((frame.clamp(0, 1) * 255).round().to(torch.uint8))
img_t = torch.stack(frames)  # [num_frames, H, W, 3]
_video_path = os.path.join(g_video_path, 'nerf_model_test.mp4')
torchvision.io.write_video(_video_path, img_t, fps)


Validation: 100%|██████████| 200/200 [29:39<00:00,  8.90s/it]
