In [None]:
import pathlib as pl
import torch
from torch import nn
from torchvision import io
from torchvision import transforms
import json

# The scene directory is assembled as <root>/<source>/<dataset>/<scene>.
data_source_name = "NeRF_Data"
dataset_name = "nerf_synthetic"
scene_name = "lego"

root_data_dir = pl.Path('./data/')
data_path = root_data_dir.joinpath(data_source_name, dataset_name, scene_name)

# Use CUDA when PyTorch reports it as available, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

from src.data import FrameDataset

# Pattern matching runs of decimal digits (not referenced again in this cell).
digit_pattern = re.compile(r"\d+")

# Training split of the configured scene.
train_dataset = FrameDataset(data_path, "train")

# Sanity-check the example image: its shape, then the maxima of entries 0 and 3
# along its first axis.
print(train_dataset.ex_img.shape)
print(*(train_dataset.ex_img[idx].max() for idx in (0, 3)))
# Preview the image. `.T` followed by `swapaxes(0, 1)` reorders the axes for
# imshow — presumably from channels-first to channels-last; TODO confirm the
# layout of `ex_img`.
plt.imshow(train_dataset.ex_img.T.swapaxes(0, 1))

In [None]:
# Number of image paths held by the training dataset.
len(train_dataset.img_paths)

In [None]:
from src.data import get_rays

# Image size and focal value as exposed by the training dataset.
H, W = train_dataset.shape
focal = train_dataset.focal
# Entry 5 of the dataset's transform matrices; the name suggests a
# camera-to-world matrix — TODO confirm against FrameDataset.
c2w = train_dataset.transform_matrixes[5]

# NOTE(review): the result is unpacked as (directions, origins), whereas the
# RayDataset sample further down is unpacked origin-first. Confirm the return
# order of `get_rays` — if it returns (origins, directions) these two names are
# swapped.
rays_d, rays_o = get_rays(H, W, focal, c2w)
rays_d.shape, rays_o.shape

In [None]:
# Disabled override of the image size, kept from the original cell:
# H, W = 8, 8

# Positional-encoding levels: L1 is passed for sample positions, L2 for ray
# directions.
L1, L2 = 10, 4
# Size of the trailing axis of the point / direction tensors.
D = 3

# Near and far bounds handed to the depth sampler.
t_near, t_far = 0.1, 5.0

# Batch size and number of depth bins per ray.
batch_size, n_bins = 1, 5
# One ray per pixel of an H x W image.
n_rays = H * W
n_samples = n_rays * n_bins

In [None]:
from src.training import positional_encoding

# Smoke-test `positional_encoding` on random stand-ins for sample positions
# (tmp_o) and ray directions (tmp_d), both shaped (batch, rays, bins, D).
tmp_o = torch.randn(batch_size, n_rays, n_bins, D, device=DEVICE)
tmp_d = torch.randn(batch_size, n_rays, n_bins, D, device=DEVICE)
tex = positional_encoding(tmp_o, L1)
# Encode the direction tensor here; the previous version re-encoded `tmp_o`,
# which left `tmp_d` unused and made `ted` an encoding of positions.
ted = positional_encoding(tmp_d, L2)

tex.shape

In [None]:
from src.model import TestNet

# Build the network. L1 and L2 are the positional-encoding levels used for
# positions and directions elsewhere in this notebook; the meaning of the
# literals 3 and 128 depends on TestNet's signature — presumably the
# coordinate dimension and a hidden width; TODO confirm in src.model.
tn = TestNet(L1, L2, 3, 128)
# Disabled forward-pass check on the encoded random tensors:
#tmp_c, tmp_sigma = tn(tex, ted)
#tmp_c.shape, tmp_sigma.shape

In [None]:
from src.training import get_t

# Call the `get_t` imported from src.training and inspect the shapes of the two
# tensors it returns. Note that a later cell defines its own `get_t`, which
# shadows this import once that cell has been run.
tmp_t, tmp_dt = get_t(batch_size, n_rays, n_bins, t_near, t_far)
tmp_t.shape, tmp_dt.shape

In [None]:
from src.training import expected_color

#tmp_c_hat = expected_color(tmp_c, tmp_sigma, tmp_dt).shape

In [None]:
from src.data import RayDataset

# Wrap the frame dataset in a ray dataset and expose it through a DataLoader.
train_render_dataset = RayDataset(train_dataset)
train_render_loader = DataLoader(
    train_render_dataset,
    batch_size=batch_size,
    shuffle=False,  # samples are served in dataset order
    num_workers=2,
    pin_memory=True,
)

In [None]:
# Take the first dataset sample, give each of its three tensors a leading batch
# axis of size 1, and move them to the compute device.
r_o, r_d, C_r = (
    item.unsqueeze(0).to(DEVICE) for item in next(iter(train_render_dataset))
)

def strat_sampling(N: int, t_near: float, t_far: float, device=None) -> torch.Tensor:
    """Draw one uniformly jittered sample from each of N equal bins of [t_near, t_far].

    The interval is split into N bins of width (t_far - t_near) / N and sample i
    is drawn uniformly from bin i, so the result is sorted in ascending order.

    Args:
        N: Number of bins, and therefore number of samples.
        t_near: Lower bound of the sampled interval.
        t_far: Upper bound of the sampled interval.
        device: Device on which the samples are created. When omitted, the
            notebook-level ``DEVICE`` is used.

    Returns:
        A 1-D tensor of N samples.
    """
    if device is None:
        device = DEVICE
    bin_index = torch.arange(N, device=device)
    jitter = torch.rand(N, device=device)
    # The t_near offset was missing before, which placed the samples in
    # [0, t_far - t_near) instead of [t_near, t_far).
    samples = t_near + (bin_index + jitter) * (t_far - t_near) / N  # <N>
    return samples


def get_t(batch_size, n_rays, n_bins, t_near, t_far, device=None) -> "tuple[torch.Tensor, torch.Tensor]":
    """Stratified sample depths for every ray, plus the gaps between them.

    Each ray gets its own n_bins samples: [t_near, t_far] is split into n_bins
    equal bins and one sample is drawn uniformly inside each bin, independently
    per ray. The previous version stratified a single flattened axis of
    batch_size * n_rays * n_bins bins and reshaped it, so every ray only covered
    a thin slice of the depth range (and the t_near offset was missing).

    Args:
        batch_size: Number of batch entries.
        n_rays: Number of rays per batch entry.
        n_bins: Number of depth bins (samples) per ray.
        t_near: Lower bound of the sampled depth range.
        t_far: Upper bound of the sampled depth range.
        device: Device on which the tensors are created. When omitted, the
            notebook-level ``DEVICE`` is used.

    Returns:
        A pair ``(t, dt)``: ``t`` has shape (batch_size, n_rays, n_bins) and is
        ascending along the last axis; ``dt`` holds the differences between
        neighbouring samples and has shape (batch_size, n_rays, n_bins - 1).
    """
    if device is None:
        device = DEVICE
    bin_index = torch.arange(n_bins, device=device)  # broadcasts over batch and rays
    jitter = torch.rand(batch_size, n_rays, n_bins, device=device)
    t = t_near + (bin_index + jitter) * (t_far - t_near) / n_bins
    dt = torch.diff(t, dim=-1)
    return t, dt

# Build the query points for the fetched sample and encode them.
B = r_o.size(0)
t, dt = get_t(B, n_rays, n_bins, t_near, t_far)

# Normalise the directions along their last axis.
r_d = nn.functional.normalize(r_d, dim=-1)

# Give the rays a singleton sample axis and the depths a trailing singleton
# axis so that `r_o + t * r_d` broadcasts to (B, rays, bins, 3).
r_o = r_o.reshape(B, -1, 1, 3)
r_d = r_d.reshape(B, -1, 1, 3)
t = t[..., None]

pts = r_o + t * r_d
# The bare `pts, r_d` expression that used to sit here was removed: it was not
# the last statement of the cell, so it was never displayed and had no effect.
print(pts.shape, r_d.shape)

ex = positional_encoding(pts, L1)
ed = positional_encoding(r_d, L2)

print(ex.shape, ed.shape)

In [None]:
from src.training import LitNerf

# Wrap the network in the LitNerf module, passing the ray count, bin count,
# depth bounds and both positional-encoding levels, then move it to DEVICE.
nerf = LitNerf(tn, n_rays, n_bins, t_near, t_far, L1, L2, learning_rate=3e-4).to(DEVICE)
# Disabled single-batch check of the training step:
#nerf.training_step(next(iter(train_render_loader)), 0)

In [None]:
import pytorch_lightning as ptl

# Train for at most 10 epochs on the ray DataLoader built earlier.
trainer = ptl.Trainer(max_epochs=10)
trainer.fit(nerf, train_dataloaders=train_render_loader)