In [None]:
import torch
import torch.nn as nn
import numpy as np

import os
import imageio
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from dataset import get_rays
from rendering import rendering
from model import Voxels, Nerf
from ml_helpers import training

# Hyperparameters / Dataset

In [None]:
# ---------------------------------------------------------------------------
# Reproducibility
# ---------------------------------------------------------------------------
SEED = 18

# ---------------------------------------------------------------------------
# Optimisation / rendering hyperparameters
# ---------------------------------------------------------------------------
batch_size = 1024   # 4096 used in paper, 1024 used in code
nb_epochs = 10
lr = 1e-3   # 5e-4 used in paper, 1e-3 used in code
#final_lr = 5e-5
gamma = .5
#gamma = (final_lr / lr) ** (1 / (nb_epochs - 1))
nb_bins = 100   # 128 used in paper, 100 used in code

# ---------------------------------------------------------------------------
# Dataset location
# ---------------------------------------------------------------------------
#dataset = 'fox'
dataset = 'helmet/400x400'
# The dataset root can be redirected with the NERF_DATA_ROOT environment
# variable; when it is unset the original machine-specific location is used.
data_root = os.environ.get('NERF_DATA_ROOT', 'C:/_sw/eb_python/deep_learning/_dataset/NeRF/images')
datapath = f'{data_root}/{dataset}'

# Image resolution used to un-flatten the rays for the centre crop below.
IMG_H, IMG_W = 400, 400

torch.manual_seed(SEED)
np.random.seed(SEED)


def _as_ray_tensor(array, center_crop=False):
    """Convert one numpy ray array into a float tensor of shape (n_rays, 3).

    With ``center_crop=True`` the array is first viewed as
    (n_images, IMG_H, IMG_W, 3) and only the central half of every image
    (rows/cols 100:300 for 400x400 images) is kept. The number of images is
    inferred with ``-1`` so the same code works for any dataset size.
    """
    tensor = torch.from_numpy(array)
    if center_crop:
        rows = slice(IMG_H // 4, 3 * IMG_H // 4)
        cols = slice(IMG_W // 4, 3 * IMG_W // 4)
        tensor = tensor.reshape(-1, IMG_H, IMG_W, 3)[:, rows, cols, :]
    return tensor.reshape(-1, 3).type(torch.float)


def _build_ray_loader(arrays, center_crop=False):
    """Build a shuffled DataLoader whose rows are the given arrays side by side.

    ``arrays`` is an iterable of numpy arrays; each is flattened to (n_rays, 3)
    and they are concatenated along dim 1, so with (origins, directions,
    target colours) every row has 9 values.
    """
    columns = [_as_ray_tensor(array, center_crop=center_crop) for array in arrays]
    return DataLoader(torch.cat(columns, dim=1), batch_size=batch_size, shuffle=True)


o, d, target_px_values = get_rays(datapath, mode='train')

# Full training set: every ray of every image.
dataloader = _build_ray_loader((o, d, target_px_values))

# Warmup set: only the rays of the central crop of every image.
# NOTE(review): presumably used so the first pass concentrates on rays that
# hit the object rather than the background - confirm.
dataloader_warmup = _build_ray_loader((o, d, target_px_values), center_crop=True)

#test_o, test_d, test_target_px_values = get_rays(datapath, mode='test')

# Testing

In [None]:
def mse2psnr(mse):
    """Convert a mean squared error into a PSNR value in decibels.

    Evaluates 20 * log10(1 / sqrt(mse)), i.e. the PSNR formula with a peak
    signal value of 1.
    """
    rmse = np.sqrt(mse)
    inverse_rmse = 1 / rmse
    return 20 * np.log10(inverse_rmse)


@torch.no_grad()
def test(model, o, d, tn, tf, nb_bins=100, chunk_size=10, H=400, W=400, target=None):
    
    o = o.chunk(chunk_size)
    d = d.chunk(chunk_size)
    
    image = []
    for o_batch, d_batch in zip(o, d):
        img_batch = rendering(model, o_batch, d_batch, tn, tf, nb_bins=nb_bins, device=o_batch.device)
        image.append(img_batch) # N, 3
    image = torch.cat(image)
    image = image.reshape(H, W, 3).cpu().numpy()
    
    if target is not None:
        mse = ((image - target)**2).mean()
        psnr = mse2psnr(mse)
    
    if target is not None: 
        return image, mse, psnr
    else:
        return image

In [None]:
def ShowTrainResults(training_loss, title):
    """Plot a sequence of training-loss values on a log-scaled y axis.

    Args:
        training_loss: sequence of loss values; plotted in order and
            summarised (min / max) in the figure title.
        title: prefix of the figure title.

    The y range is fixed to [0.001, 1]; values outside it are not visible.
    """
    # (which, linestyle, color) for the two grid layers.
    grid_styles = (
        ('major', '-', 'black'),
        ('minor', ':', 'gray'),
    )

    plt.plot(training_loss)
    plt.ylim(0.001, 1)
    for which, linestyle, color in grid_styles:
        plt.grid(which=which, linestyle=linestyle, linewidth='0.5', color=color)
    plt.yscale('log')
    plt.xlabel('Batches')
    plt.ylabel('Training Loss')

    lowest = np.min(training_loss)
    highest = np.max(training_loss)
    plt.title(f'{title} - Min/max: {lowest:.4f}/{highest:.2f}')
    plt.show()

In [None]:
def ShowTestImages(model, tn, tf, device, title):
    """Render every 18th image of the loaded ray set and show them in a grid.

    Reads the notebook-level arrays ``o``, ``d`` and ``target_px_values``
    (the rays loaded with ``get_rays(..., mode='train')``), so the PSNR in
    each subplot title is measured on those images. Images are assumed to be
    400x400 and at least 919 of them must be available.

    Args:
        model: model handed to ``test`` for rendering.
        tn, tf: forwarded to ``test``.
        device: device the per-image rays are moved to before rendering.
        title: figure-level title.
    """
    n_cols, n_rows = 4, 13
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 65))

    # Walking the axes in row-major order, cell k sits at (r, c) with
    # k == 4 * r + c, so 18 * k is the same index as 72 * r + 18 * c.
    for cell, ax in enumerate(axs.flat):
        img_idx = 18 * cell
        rays_o = torch.from_numpy(o[img_idx]).to(device).float()
        rays_d = torch.from_numpy(d[img_idx]).to(device).float()
        reference = target_px_values[img_idx].reshape(400, 400, 3)
        img, _, psnr = test(model, rays_o, rays_d, tn, tf,
                            nb_bins=100, chunk_size=10, target=reference)
        ax.imshow(img, cmap='gray')
        ax.set_title(f'Image {img_idx}, PSNR: {psnr:.1f}')
        ax.axis('off')  # Hide axis for a cleaner look

    plt.tight_layout()  # Adjust subplots to fit in the figure area
    fig.suptitle(title, fontsize=16, y=1.05)
    plt.show()

# Training

In [None]:
# Use the GPU when one is visible; otherwise fall back to the CPU so the
# notebook still runs (slowly) instead of failing on the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Near / far values handed to training and rendering.
#tn, tf = 8., 12.
tn, tf = 2., 6.

model = Nerf(hidden_dim=128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# The learning rate is multiplied by `gamma` at scheduler steps 5 and 10.
# NOTE(review): the same scheduler object is passed to both training() calls
# below; whether the warmup pass already advances it towards the milestones
# depends on how training() steps the scheduler - verify.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10], gamma=gamma)
#scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

# Stage 1: warmup on the centre-cropped rays. The literal 1 sits in the same
# argument position that receives `nb_epochs` in stage 2.
training_loss = training(model, optimizer, scheduler, tn, tf, nb_bins, 1, dataloader_warmup, device=device)
ShowTrainResults(training_loss, "Training Loss - Warmup")
ShowTestImages(model, tn, tf, device, "Test images for warmup dataloader")

# Stage 2: main training on all rays, continuing with the same model and optimizer.
training_loss = training(model, optimizer, scheduler, tn, tf, nb_bins, nb_epochs, dataloader, device=device)
ShowTrainResults(training_loss, "Training Loss for last epoch")
ShowTestImages(model, tn, tf, device, "Test images for last epoch")