In [None]:
import pathlib as pl
import torch
from torch import nn
from torchvision import io
from torchvision import transforms
import json

# Location of the scene on disk: <root>/<source>/<dataset>/<scene>
data_source_name = "NeRF_Data"
dataset_name = "nerf_synthetic"
scene_name = "lego"

root_data_dir = pl.Path('./data/')
data_path = root_data_dir.joinpath(data_source_name, dataset_name, scene_name)

In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader


# Matches a run of decimal digits; used with `.search` to pull the first number out of a file name.
digit_pattern = re.compile(r"\d+")


def load_transforms(tfs_path: pl.Path) -> tuple[float, list[float], list[torch.FloatTensor]]:
    """Read a transforms JSON file and return its camera parameters.

    Args:
        tfs_path: Path to a JSON file with a "camera_angle_x" entry and a
            "frames" list whose items carry "rotation" and "transform_matrix".

    Returns:
        A tuple ``(cam_angle_x, rotations, transform_matrixes)`` where the two
        lists have one entry per frame, in the order the frames appear in the
        file (the caller assumes this matches the image ordering).
    """
    with open(tfs_path, "r") as f:
        meta = json.load(f)

    cam_angle_x = float(meta["camera_angle_x"])

    # Parse each frame once, keeping rotation and matrix paired.
    parsed = [
        (float(frame["rotation"]), torch.FloatTensor(frame["transform_matrix"]))
        for frame in meta["frames"]
    ]
    rotations = [rotation for rotation, _ in parsed]
    transform_matrixes = [matrix for _, matrix in parsed]

    return cam_angle_x, rotations, transform_matrixes
        

def extract_digit_from_path_name(path: pl.Path) -> int:
    """Return the first run of digits in ``path.name`` as an int.

    Returns ``None`` when the file name contains no digits.
    """
    found = digit_pattern.search(path.name)
    return int(found.group(0)) if found else None


def load_img_paths(imgs_path: pl.Path) -> list[pl.Path]:
    """List the numbered entries of ``imgs_path`` sorted by their number.

    Entries whose name contains no digits (for which
    ``extract_digit_from_path_name`` returns ``None``) are skipped; previously
    they produced ``None`` sort keys and made ``sorted`` raise ``TypeError``.

    Args:
        imgs_path: Directory containing the image files.

    Returns:
        Paths ordered numerically by the first number in their file name.
    """
    # iterdir() yields entries in arbitrary order, so pair each with its number first.
    numbered = [
        (number, path)
        for path in imgs_path.iterdir()
        if (number := extract_digit_from_path_name(path)) is not None
    ]
    # NOTE(review): files sharing the same leading number keep their iterdir order
    # relative to each other — confirm the split directory holds one image per frame.
    numbered.sort(key=lambda pair: pair[0])  # Ordered numerically
    return [path for _, path in numbered]


def load_frame(
        imgs_path: list[pl.Path],
        rotations: list[float],
        tf_matrixes: list[torch.FloatTensor],
        idx: int
    ) -> tuple[torch.Tensor, float, torch.Tensor]:
    """Load the image, rotation and transform matrix of frame ``idx``.

    Args:
        imgs_path: Sequence of image paths (indexed by ``idx``, so this is a
            list of paths rather than a single path).
        rotations: Per-frame rotation values.
        tf_matrixes: Per-frame transform matrices.
        idx: Index of the frame to load; the same index is used for all three
            sequences.

    Returns:
        ``(img, rotation, tf_matrix)`` where ``img`` is read from disk in
        RGB_ALPHA mode.

    Raises:
        IndexError: If ``idx`` is out of range for any of the sequences.
    """
    img_path = imgs_path[idx]
    rotation = rotations[idx]
    tf_matrix = tf_matrixes[idx]
    img = io.read_image(str(img_path), mode=io.ImageReadMode.RGB_ALPHA)

    return img, rotation, tf_matrix


class FrameDataset(Dataset):
    """Frames of one split of a scene: image, rotation and camera transform.

    Images are paired with the entries of ``transforms_<data_mode>.json`` by
    position, so the numerically sorted image list is assumed to follow the
    same order as the frames in the JSON file.
    """

    def __init__(
            self,
            data_path: pl.Path,
            data_mode: str,  # 'train', 'val', 'test'
            ex_idx: int = 5
        ) -> None:
        """Index the split and derive image size and focal length.

        Args:
            data_path: Scene directory holding the split folders and the
                ``transforms_<data_mode>.json`` files.
            data_mode: Split name; selects both the image folder and the
                transforms file.
            ex_idx: Index of the frame loaded as example image to read the
                image dimensions from.
        """
        super().__init__()

        self.imgs_path = data_path / data_mode
        self.tfs_path = data_path / f"transforms_{data_mode}.json"

        self.cam_angle_x, self.rotations, self.transform_matrixes = load_transforms(self.tfs_path)
        self.img_paths = load_img_paths(self.imgs_path)

        # Load one frame up front so the image dimensions are known.
        self.ex_img, *_ = self[ex_idx]
        self.C, self.H, self.W = self.ex_img.shape

        # Focal length in pixels, derived from the image width and camera_angle_x.
        self.focal = 0.5 * self.W / np.tan(0.5 * self.cam_angle_x)

    @property
    def shape(self) -> tuple[int, int]:
        """Image size as ``(H, W)``."""
        return self.H, self.W

    def __len__(self) -> int:
        """Number of frames, taken from the transforms file."""
        return len(self.transform_matrixes)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, float, torch.Tensor]:
        """Return ``(img, rotation, transform_matrix)`` for frame ``idx``.

        Frames are returned at full resolution; a downscaling step
        (Resize to H//8 x W//8) was prototyped here but is disabled.
        """
        return load_frame(self.img_paths, self.rotations, self.transform_matrixes, idx)


train_dataset = FrameDataset(data_path, "train")

# Inspect the example frame: its shape and the maxima of channel 0 and channel 3.
print(train_dataset.ex_img.shape)
print(train_dataset.ex_img[0].max(), train_dataset.ex_img[3].max())
# imshow wants channels last. permute(1, 2, 0) gives the same layout the old
# `.T.swapaxes(0, 1)` produced, without using the deprecated `.T` on a 3-D tensor.
plt.imshow(train_dataset.ex_img.permute(1, 2, 0))

In [None]:
# Number of image paths found for the training split.
len(train_dataset.img_paths)

In [None]:
def get_rays(
    H: int, W: int, focal: float, c2w: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build one ray per pixel of an H x W image for the camera pose ``c2w``.

    Ported from https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L123

    Args:
        H: Image height in pixels.
        W: Image width in pixels.
        focal: Focal length in pixels.
        c2w: Camera-to-world matrix; its upper-left 3x3 block rotates the
            directions and the last column of its first three rows is used as
            the ray origin.

    Returns:
        ``(rays_d, rays_o)`` — note the order: directions first, origins
        second. Both have shape ``<H, W, 3>``.
    """
    cols = torch.arange(W, dtype=torch.float32)
    rows = torch.arange(H, dtype=torch.float32)
    px, py = torch.meshgrid(cols, rows, indexing="xy")  # each <H, W>

    # Per-pixel direction in camera space; the camera looks along -z.
    x_cam = (px - W * 0.5) / focal
    y_cam = -(py - H * 0.5) / focal
    z_cam = -torch.ones_like(px)
    dirs_cam = torch.stack((x_cam, y_cam, z_cam), dim=-1)

    rotation = c2w[:3, :3]
    translation = c2w[:3, -1]

    rays_d = torch.matmul(dirs_cam, rotation.T)
    # Every ray starts at the camera position.
    rays_o = translation.expand(rays_d.shape)

    return rays_d, rays_o

In [None]:
# https://github.com/yenchenlin/nerf-pytorch/issues/41
# Image size and focal length of the training split; H and W are reused by later cells.
H, W = train_dataset.shape
focal = train_dataset.focal
# Transform matrix of frame 5 (the default example index of FrameDataset).
c2w = train_dataset.transform_matrixes[5]


# Generate the rays for that frame and display the resulting shapes.
rays_d, rays_o = get_rays(H, W, focal, c2w)
rays_d.shape, rays_o.shape

In [None]:
#H, W = 8, 8
L1 = 10  # L passed to positional_encoding for sample positions
L2 = 4   # L passed to positional_encoding for ray directions
D = 3    # components per point / direction

batch_size = 1
n_bins = 5                   # samples taken along each ray
n_rays = H * W               # one ray per pixel; H and W come from an earlier cell
n_samples = n_bins * n_rays  # samples per frame

# Near and far bounds passed to the depth sampler.
t_near = 0.1
t_far = 5.0

In [None]:
def positional_encoding(p: torch.Tensor, L: int) -> torch.Tensor:
    """Encode every component of ``p`` with ``L`` sine/cosine frequency bands.

    Args:
        p: Tensor of shape ``<B, NR, NB, D>``.
        L: Number of frequency bands; band ``i`` uses frequency ``2^i * pi``.

    Returns:
        Tensor of shape ``<B, NR*NB, 2*D*L>``. Along the last axis the values
        are grouped by band, then by component, then as ``(sin, cos)``:
        ``sin(x) cos(x) sin(y) cos(y) sin(z) cos(z)`` repeated for each band.
    """
    assert len(p.shape) == 4
    B, NR, NB, D = p.shape

    # Frequencies 2^0 .. 2^(L-1). They are created on p's device so the function
    # also works for inputs that are not on the CPU; broadcasting against the
    # trailing axis of p[..., None] replaces the former explicit repeat.
    freqs = 2 ** torch.arange(L, device=p.device)  # <L>

    # z[..., j, i] = 2^i * pi * p[..., j]
    z = freqs * (torch.pi * p[..., None])  # <B, NR, NB, D, L>

    # Pair sin and cos on a new trailing axis, move the band axis in front of
    # the component axis, then flatten bands, components and the pair.
    x = torch.stack((torch.sin(z), torch.cos(z)), dim=5)  # <B, NR, NB, D, L, 2>
    x = x.swapaxes(3, 4)                                  # <B, NR, NB, L, D, 2>
    x = x.reshape(B, NR * NB, 2 * D * L)                  # <B, N, 2*D*L>

    return x



# Random stand-ins for sample positions and ray directions to smoke-test the encoder.
tmp_o = torch.randn(batch_size, n_rays, n_bins, D)
tmp_d = torch.randn(batch_size, n_rays, n_bins, D)
tex = positional_encoding(tmp_o, L1)
# Encode the directions (tmp_d); previously tmp_o was encoded a second time by mistake.
ted = positional_encoding(tmp_d, L2)

tex.shape

In [None]:
class TestNet(nn.Module):
    """Minimal two-layer scene model mapping encoded inputs to colour and density.

    The first linear layer consumes the encoded positions and emits one density
    value plus ``n_hidden`` features; the second consumes those features
    together with the encoded directions and emits three colour channels.
    """

    def __init__(self, L1, L2, n_components, n_hidden):
        """
        Args:
            L1: Number of encoding bands used for the first input.
            L2: Number of encoding bands used for the second input.
            n_components: Components per un-encoded input vector.
            n_hidden: Number of hidden features passed to the colour layer.
        """
        super().__init__()

        # Encoded widths: each component contributes a sin and a cos per band.
        self.d1 = n_components * 2 * L1
        self.d2 = n_components * 2 * L2

        # +1 output: channel 0 of lin1 is the density, the rest are features.
        self.lin1 = nn.Linear(self.d1, n_hidden + 1)
        self.lin2 = nn.Linear(n_hidden + self.d2, 3)

    def forward(self, o: torch.Tensor, d: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Predict colour and density for every sample.

        Args:
            o: Encoded first input, shape ``<B, NS, D*2*L1>``.
            d: Encoded second input, shape ``<B, NS, D*2*L2>``.

        Returns:
            ``(rgb, sigma)`` with ``rgb`` of shape ``<B, NS, 3>`` squashed by a
            sigmoid and ``sigma`` of shape ``<B, NS>`` clamped by a ReLU.
        """
        assert len(o.shape) == 3
        assert len(d.shape) == 3
        assert o.size(2) == self.d1
        assert d.size(2) == self.d2

        z1 = self.lin1(o)

        # Channel 0 is the raw density; the remaining channels feed the colour layer.
        sigma = torch.relu(z1[..., 0])  # <B, NS>

        features = torch.cat((d, z1[..., 1:]), dim=2)
        rgb = torch.sigmoid(self.lin2(features))  # <B, NS, 3>

        return rgb, sigma

# Build the model with 3 input components and 128 hidden features,
# then run it on the encoded random inputs and display the output shapes.
tn = TestNet(L1, L2, 3, 128)
tmp_c, tmp_sigma = tn(tex, ted)
tmp_c.shape, tmp_sigma.shape

In [None]:
def strat_sampling(N: int, t_near: float, t_far: float) -> torch.Tensor:
    """Draw one uniform random sample from each of ``N`` equal bins of [t_near, t_far).

    Args:
        N: Number of bins (and therefore samples).
        t_near: Lower bound of the sampled interval.
        t_far: Upper bound of the sampled interval.

    Returns:
        Tensor of shape ``<N>`` with increasing values; sample ``i`` lies in
        bin ``i``.
    """
    bin_index = torch.arange(N)
    jitter = torch.rand(N)  # position inside each bin, in [0, 1)
    # Offset by t_near so the samples cover [t_near, t_far) instead of
    # [0, t_far - t_near).
    samples = t_near + (bin_index + jitter) * (t_far - t_near) / N  # <N>
    return samples


def get_t(batch_size, n_rays, n_bins, t_near, t_far) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample stratified depths along every ray.

    Each ray gets its own ``n_bins`` samples, one drawn uniformly from each of
    ``n_bins`` equal bins spanning [t_near, t_far). Previously one global
    stratification over all ``batch_size * n_rays * n_bins`` samples was
    reshaped, which confined each ray to a small slice of the range.

    Args:
        batch_size: Number of frames in the batch.
        n_rays: Number of rays per frame.
        n_bins: Number of samples per ray.
        t_near: Lower bound of the sampled interval.
        t_far: Upper bound of the sampled interval.

    Returns:
        ``(t, dt)`` where ``t`` has shape ``<B, NR, NB>`` and ``dt`` holds the
        differences between consecutive samples of a ray, shape
        ``<B, NR, NB-1>``.
    """
    bin_index = torch.arange(n_bins)                # <NB>, broadcast over rays
    jitter = torch.rand(batch_size, n_rays, n_bins)  # independent per ray and bin
    t = t_near + (bin_index + jitter) * (t_far - t_near) / n_bins
    dt = torch.diff(t, dim=-1)
    return t, dt

# Sample depths for one batch of rays and display the shapes of t and dt.
tmp_t, tmp_dt = get_t(batch_size, n_rays, n_bins, t_near, t_far)
tmp_t.shape, tmp_dt.shape

In [None]:
def expected_color(c, sigma, dt):
    """Composite per-sample colours into one colour per ray.

    Args:
        c: Per-sample colours, shape ``<B, NS, C>`` with ``NS = NR * NB``.
        sigma: Per-sample densities, shape ``<B, NS>``.
        dt: Distances between consecutive samples of a ray, shape
            ``<B, NR, NB-1>``.

    Returns:
        Expected colour per ray, shape ``<B, NR, C>``. The per-ray weights sum
        to one.
    """
    assert len(sigma.shape) == 2
    assert len(dt.shape) == 3
    assert len(c.shape) == 3

    B, NR, n_intervals = dt.shape
    NB = n_intervals + 1  # one more sample than intervals
    C = c.size(-1)

    # Unpack from n_samples to n_rays x n_bins
    sigma = sigma.reshape(B, NR, NB)
    c = c.reshape(B, NR, NB, C)

    # Optical depth of each interval; the last sample has no interval of its own.
    tau = dt * sigma[..., :-1]  # <B, NR, NB-1>

    # Transmittance in front of every sample: T_1 = exp(0) = 1 and
    # T_i = exp(-sum_{j<i} tau_j). new_ones keeps the device and dtype of tau.
    T = torch.cat(
        (tau.new_ones(B, NR, 1), torch.exp(-torch.cumsum(tau, dim=-1))), dim=-1
    )  # <B, NR, NB>

    # Samples with an interval get T_i * (1 - exp(-tau_i)). The last sample has
    # no delta, so it receives the remaining transmittance T_N directly. T_N is
    # exp(-sum(tau)); the plain sum used before made the weights unbounded.
    w = torch.cat((T[..., :-1] * (1 - torch.exp(-tau)), T[..., -1:]), dim=-1)

    c_hat = torch.einsum("brn,brnc->brc", w, c)
    return c_hat

# Keep the composited colours themselves in tmp_c_hat and display their shape.
tmp_c_hat = expected_color(tmp_c, tmp_sigma, tmp_dt)
tmp_c_hat.shape

In [None]:
class RayDataset(Dataset):
    """Wraps a FrameDataset and yields the rays and target colours of each frame."""

    def __init__(self, frame_dataset: FrameDataset) -> None:
        """
        Args:
            frame_dataset: Source of images, camera transforms, focal length
                and image size.
        """
        super().__init__()

        self.frame_dataset = frame_dataset

    def __len__(self):
        """Number of frames in the wrapped dataset."""
        return len(self.frame_dataset)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(r_o, r_d, C_r)`` for frame ``idx``.

        ``r_o`` and ``r_d`` are the ray origins and directions produced by
        ``get_rays``; ``C_r`` is the frame's image divided by 255.
        """
        C_r, _, c2w = self.frame_dataset[idx]
        C_r = C_r / 255.0

        focal = self.frame_dataset.focal
        H, W = self.frame_dataset.shape

        # get_rays returns (directions, origins) in that order; unpacking them
        # the other way round swapped origins and directions.
        r_d, r_o = get_rays(H, W, focal, c2w)
        return r_o, r_d, C_r
        
        
# Ray view of the training frames; each loader batch holds `batch_size` whole frames, in order.
train_render_dataset = RayDataset(train_dataset)
train_render_loader = DataLoader(train_render_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Walk through the sampling and encoding steps by hand for the first frame.
r_o, r_d, C_r = train_render_dataset[0]
# Add a leading batch axis of size 1.
r_o = r_o[None, ...]
r_d = r_d[None, ...]
C_r = C_r[None, ...]

B = r_o.size(0)
t, dt = get_t(B, n_rays, n_bins, t_near, t_far)

r_d = nn.functional.normalize(r_d, dim=-1)

# One row per ray, with a singleton axis that broadcasts over the samples of a ray.
r_o = r_o.reshape(B, -1, 1, 3)
r_d = r_d.reshape(B, -1, 1, 3)
t = t[..., None]

# Sample positions along every ray.
pts = r_o + t * r_d
print(pts.shape, r_d.shape)

ex = positional_encoding(pts, L1)
ed = positional_encoding(r_d, L2)

print(ex.shape, ed.shape)

In [None]:
import pytorch_lightning as ptl

class LitNerf(ptl.LightningModule):
    """Lightning module that trains a scene model with an MSE loss on rendered ray colours.

    The sampling and encoding settings (``n_bins``, ``t_near``, ``t_far``,
    ``L1``, ``L2``) are read from the notebook's global configuration.
    """

    def __init__(self, scene_model: nn.Module, learning_rate: float = 3e-4):
        """
        Args:
            scene_model: Module called as ``scene_model(ex, ed)`` that returns
                ``(colour, density)`` per sample.
            learning_rate: Learning rate for the Adam optimizer.
        """
        super().__init__()
        self.scene_model = scene_model
        self.criterion = nn.MSELoss()
        self.learning_rate = learning_rate

    def training_step(self, batch, batch_idx):
        """Render every ray of the batch and return the MSE against the image colours.

        Args:
            batch: ``(r_o, r_d, C_r)`` — ray origins, ray directions and the
                target image of each frame.
            batch_idx: Index of the batch (unused).
        """
        r_o, r_d, C_r = batch

        B = r_o.size(0)

        r_d = nn.functional.normalize(r_d, dim=-1)

        r_o = r_o.reshape(B, -1, 1, 3)
        # One copy of the direction per sample along the ray (a view, no data copied).
        r_d = r_d.reshape(B, -1, 1, 3).expand(-1, -1, n_bins, -1)

        # Take the ray count from the batch itself rather than a notebook global.
        rays_per_frame = r_o.size(1)
        t, dt = get_t(B, rays_per_frame, n_bins, t_near, t_far)
        # get_t creates its tensors without a device argument, so move them to
        # the device and dtype of the batch before mixing them with the rays.
        t = t.to(r_o)[..., None]
        dt = dt.to(r_o)

        # Keep the first three channels and lay the pixels out as <B, H*W, 3>.
        C_r = C_r[:, :3].reshape(B, 3, -1).swapaxes(1, 2)

        x = r_o + t * r_d
        ex = positional_encoding(x, L1)
        ed = positional_encoding(r_d, L2)

        c, sigma = self.scene_model(ex, ed)
        c_hat = expected_color(c, sigma, dt)

        loss = self.criterion(c_hat, C_r)

        self.log("train_loss", loss, prog_bar=True)

        return loss

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)


# Wrap the test network and profile two manual training steps.
# A fresh iterator is created on every pass and the loader is not shuffled,
# so both passes profile the first batch.
nerf = LitNerf(tn)
for i in range(2):
    %prun nerf.training_step(next(iter(train_render_loader)), 0)


In [None]:
# Train with a default-configured Trainer (no epoch limit or accelerator is set here).
trainer = ptl.Trainer()
trainer.fit(nerf, train_dataloaders=train_render_loader)