In [None]:
import torch
import torch.nn as nn
import numpy as np

import os
import imageio
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from dataset import get_rays
from rendering import rendering
from model import Voxels, Nerf
from ml_helpers import training
from tqdm import tqdm

In [None]:
# tn / tf are forwarded unchanged to rendering(); presumably the near/far
# sampling bounds along each ray — TODO confirm against rendering.py.
# Previously tried values are kept below for reference.
#tn, tf = 8., 12.
#tn, tf = 2., 6.
tn, tf = 2., 6.
# Fall back to the CPU when no CUDA device is available so the notebook still
# runs (slowly) on machines without a GPU instead of failing on `.to('cuda')`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

#datapath = 'C:/_sw/eb_python/deep_learning/_dataset/NeRF/images/fox'
#test_o, test_d, test_target_px_values = get_rays(datapath, mode='test')
datapath = 'C:/_sw/eb_python/deep_learning/_dataset/NeRF/images/helmet/400x400'
# NOTE(review): the 'train' split is loaded into the test_* names — confirm
# this is intentional (evaluating on training views) and not a leftover.
test_o, test_d, test_target_px_values = get_rays(datapath, mode='train')

In [None]:
def mse2psnr(mse):
    """Convert a mean squared error into a peak signal-to-noise ratio.

    Evaluates ``20 * log10(1 / sqrt(mse))``, i.e. the peak signal value in
    the ratio is 1.

    Args:
        mse: Mean squared error; any value NumPy can take the square root of.

    Returns:
        The PSNR corresponding to ``mse``.
    """
    rmse = np.sqrt(mse)
    inverse_rmse = 1 / rmse
    return 20 * np.log10(inverse_rmse)


@torch.no_grad()
def test(model, o, d, tn, tf, nb_bins=100, chunk_size=10, H=400, W=400, target=None):
    """Render one image from its rays and optionally score it against a target.

    The rays are split into pieces, each piece is passed through
    ``rendering``, and the results are concatenated and reshaped into an
    ``(H, W, 3)`` NumPy array. Gradients are disabled by the decorator.

    Args:
        model: Model handed to ``rendering``.
        o: Tensor split along dim 0 (presumably ray origins — confirm with
            ``get_rays``).
        d: Tensor split the same way as ``o`` (presumably ray directions).
        tn, tf: Forwarded unchanged to ``rendering``.
        nb_bins: Forwarded to ``rendering`` as ``nb_bins``.
        chunk_size: Passed to ``Tensor.chunk``, so despite the name it is the
            number of pieces the rays are split into, not the size of a piece.
        H, W: Height and width used to reshape the rendered pixels.
        target: Optional reference image to compare against; when provided
            the MSE and PSNR of the rendering are also returned.

    Returns:
        ``image`` alone when ``target`` is None, otherwise the tuple
        ``(image, mse, psnr)``.
    """
    # Render piece by piece; each piece runs on the device its rays live on.
    rendered_parts = [
        rendering(model, origins, directions, tn, tf, nb_bins=nb_bins, device=origins.device)
        for origins, directions in zip(o.chunk(chunk_size), d.chunk(chunk_size))
    ]
    image = torch.cat(rendered_parts).reshape(H, W, 3).cpu().numpy()

    # Without a reference there is nothing to score.
    if target is None:
        return image

    mse = ((image - target) ** 2).mean()
    return image, mse, mse2psnr(mse)

In [None]:
def ShowTestReults(model, tn, tf, device, title):
    """Render a 13x4 grid of dataset images with ``model`` and show their PSNR.

    Reads the module-level ``test_o``, ``test_d`` and
    ``test_target_px_values`` loaded earlier in the notebook, so those must
    exist (and contain every sampled index) before this is called.

    Args:
        model: Model passed to ``test`` for rendering.
        tn, tf: Forwarded unchanged to ``test``.
        device: Device the per-image ray tensors are moved to.
        title: Figure-level title.
    """
    n_cols, n_rows = 4, 13
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 65))

    for grid_row in tqdm(range(n_rows)):
        for grid_col in range(n_cols):
            # Each column steps 18 images ahead and each row 72 (= 4 * 18),
            # so the grid samples every 18th image of the dataset.
            img_idx = 72 * grid_row + 18 * grid_col

            origins = torch.from_numpy(test_o[img_idx]).to(device).float()
            directions = torch.from_numpy(test_d[img_idx]).to(device).float()
            reference = test_target_px_values[img_idx].reshape(400, 400, 3)

            img, _, psnr = test(model, origins, directions, tn, tf,
                                nb_bins=100, chunk_size=10, target=reference)

            ax = axs[grid_row, grid_col]
            ax.imshow(img, cmap='gray')
            ax.set_title(f'Image {img_idx}, PSNR: {psnr:.1f}')
            ax.axis('off')  # Hide axis for a cleaner look

    plt.tight_layout()  # Adjust subplots to fit in the figure area
    fig.suptitle(title, fontsize=16, y=1.05)
    plt.show()

In [None]:
#nn_model_path = 'model_nerf'
#nn_model_path = 'C:/_sw/eb_python/deep_learning/_dataset/NeRF/nn_models/helmet/400x400/model_nerf0'
nn_model_path = 'C:/_sw/eb_python/deep_learning/nerf/udemy_class/_test/model_nerf-epoch_6'
# map_location remaps the stored tensors onto `device` while loading, so a
# checkpoint saved from a GPU can still be opened on a host where that GPU
# (or CUDA altogether) is unavailable.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust. On PyTorch >= 2.6 the default weights_only=True rejects pickled full
# modules; pass weights_only=False if loading fails there.
model = torch.load(nn_model_path, map_location=device).to(device)

ShowTestReults(model, tn, tf, device, "Training loss - Epoch #7")

In [None]:
# Render a single dataset image, print its PSNR and display the result.
img_idx = 1
img, mse, psnr = test(
    model,
    o=torch.from_numpy(test_o[img_idx]).to(device).float(),
    d=torch.from_numpy(test_d[img_idx]).to(device).float(),
    tn=tn,
    tf=tf,
    nb_bins=100,
    chunk_size=10,
    target=test_target_px_values[img_idx].reshape(400, 400, 3),
)

print(psnr)
plt.imshow(img)