In [None]:
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

from src.data import FrameDataset
from src.data import RayDataset

import pathlib as pl
import torch
from torch import nn

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(DEVICE)

In [None]:
import tqdm


def get_path_number(path: pl.Path):
    """Return the integer that follows the last "-" in the stem of ``path``.

    For example, a path whose final component is ``version-12`` yields ``12``.
    ``int()`` raises ``ValueError`` when that trailing piece is not a number.
    """
    _, _, trailing = path.stem.rpartition("-")
    return int(trailing)


@torch.jit.script
def strat_sampling(
    N: int, t_near: float, t_far: float, device: torch.device
) -> torch.Tensor:
    """Draw one uniformly jittered sample from each of N equal bins of [t_near, t_far).

    Args:
        N: number of bins (and of returned samples).
        t_near: lower bound of the sampled interval.
        t_far: upper bound of the sampled interval.
        device: device on which the samples are created.

    Returns:
        Tensor of shape <N>; entry i lies in bin i, i.e. in
        [t_near + i * w, t_near + (i + 1) * w) with w = (t_far - t_near) / N,
        so the samples are sorted in increasing order.
    """
    samples = (
        (torch.arange(N, device=device) + torch.rand(N, device=device))
        * (t_far - t_near)
        / N
        # Shift from [0, t_far - t_near) to [t_near, t_far); without this offset
        # the samples start at 0 and never reach t_far.
        + t_near
    )  # <N>
    return samples


@torch.jit.script
def positional_encoding(p: torch.Tensor, L: int, device: torch.device) -> torch.Tensor:
    """Encode every component of ``p`` with sines and cosines at L frequencies.

    Args:
        p: tensor of shape <B, NB, D>.
        L: number of frequencies; frequency i is 2^i * pi for i in 0..L-1.
        device: device on which the frequency table is created.

    Returns:
        Tensor of shape <B, NB, 2*D*L>. Along the last axis the layout is, for
        each frequency in turn, (sin, cos) of component 0, then (sin, cos) of
        component 1, and so on through component D-1.
    """
    assert p.dim() == 3
    B, NB, D = p.shape

    # Angle for frequency i and component j is 2^i * pi * p_j.
    # The <L> frequency vector broadcasts against <B, NB, D, 1>.
    freqs = 2 ** torch.arange(L, device=device)  # <L>
    angles = freqs * (torch.pi * p[..., None])  # <B, NB, D, L>

    # Pair up sin/cos on a new trailing axis, bring the frequency axis in
    # front of the component axis, then flatten the last three axes.
    enc = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)  # <B, NB, D, L, 2>
    enc = enc.transpose(2, 3)  # <B, NB, L, D, 2>
    return enc.reshape(B, NB, 2 * D * L)


@torch.jit.script
def get_t(
    n_rays: int, n_bins: int, t_near: float, t_far: float, device: torch.device
) -> torch.Tensor:
    """Stratified sample positions along every ray.

    Each ray independently gets one uniformly jittered sample from each of
    ``n_bins`` equal bins of [t_near, t_far), so every row spans the whole
    depth range and is sorted in increasing order.

    Args:
        n_rays: number of rays (rows of the result).
        n_bins: number of samples per ray (columns of the result).
        t_near: lower bound of the sampled interval.
        t_far: upper bound of the sampled interval.
        device: device on which the samples are created.

    Returns:
        Tensor of shape <n_rays, n_bins>.
    """
    # Stratifying n_rays * n_bins samples globally and reshaping would hand
    # ray k only the k-th thin slice of the range; stratify per ray instead.
    bin_index = torch.arange(n_bins, device=device)  # <NB>
    jitter = torch.rand([n_rays, n_bins], device=device)  # <B, NB>
    t = (bin_index + jitter) * (t_far - t_near) / n_bins + t_near  # <B, NB>
    return t


@torch.jit.script
def batchify_rays(batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
    """Flatten one image worth of rays into per-ray rows.

    Args:
        batch: ``(r_o, r_d, C_r)`` where ``r_o`` and ``r_d`` are <S, S, D>
            and ``C_r`` is channel-first with at least three channels,
            e.g. <4, S, S>.

    Returns:
        ``(r_o, r_d, C_r)`` with ``r_o`` and ``r_d`` reshaped to <N, D> and
        ``C_r`` reduced to its first three channels and laid out as <N, 3>,
        where N = S * S.
    """
    origins, directions, colors = batch
    S, _, D = origins.shape
    n_rays = S * S

    # Keep only the first three channels (RGB), flatten the spatial axes and
    # put the ray axis first so that row i is the colour of ray i.
    rgb = colors[:3].reshape(3, -1).t()  # <N, 3>

    return origins.reshape(n_rays, D), directions.reshape(n_rays, D), rgb


@torch.jit.script
def preprocess_rays(
    r_o: torch.Tensor,
    r_d: torch.Tensor,
    n_bins: int,
    t_near: float,
    t_far: float,
    L1: int,
    L2: int,
    device: torch.device,
):
    """Sample points along each ray and positionally encode them.

    Args:
        r_o: ray origins, <N, D>.
        r_d: ray directions, <N, D>; normalised to unit length here.
        n_bins: number of samples per ray (NB).
        t_near: lower bound passed to ``get_t``.
        t_far: upper bound passed to ``get_t``.
        L1: number of encoding frequencies for the sample positions.
        L2: number of encoding frequencies for the directions.
        device: device used for sampling and the encodings.

    Returns:
        ``(ex, ed, t)`` where ``ex`` is <N, NB, 2*D*L1>, ``ed`` is
        <N, NB, 2*D*L2> and ``t`` is the <N, NB> tensor returned by ``get_t``.
    """
    N, D = r_o.shape

    t = get_t(N, n_bins, t_near, t_far, device=device)  # <N, NB>
    r_d = nn.functional.normalize(r_d, dim=-1)  # <N, D>

    # Give the origins a singleton bin axis so they broadcast in r_o + t * r_d.
    r_o = r_o.reshape(N, 1, D)

    # The direction is encoded once per sample, so materialise it per bin.
    r_d = r_d.reshape(N, 1, D).repeat(1, n_bins, 1)  # <N, NB, D>

    # Trailing singleton axis on t so it scales every component of r_d.
    x = r_o + t[..., None] * r_d  # <N, NB, D>

    ex = positional_encoding(x, L1, device=device)
    ed = positional_encoding(r_d, L2, device=device)

    # Derive the encoded width from D rather than assuming D == 3.
    ex = ex.reshape(N, n_bins, 2 * D * L1)  # <N, NB, 2*D*L1>
    ed = ed.reshape(N, n_bins, 2 * D * L2)  # <N, NB, 2*D*L2>

    return ex, ed, t


@torch.jit.script
def expected_color(
    c: torch.Tensor, sigma: torch.Tensor, t: torch.Tensor
) -> torch.Tensor:
    """Composite per-sample colours along each ray into one colour per ray.

    With delta_i = t_{i+1} - t_i and transmittance
    T_i = exp(-sum_{j<i} sigma_j * delta_j), the weights are

        w_i  = T_i * (1 - exp(-sigma_i * delta_i))   for every sample but the last
        w_NB = T_NB                                  for the last sample

    The last sample has no following sample to define a delta, so it receives
    all remaining transmittance. The weights of each ray therefore sum to 1.

    Args:
        c: per-sample colours, <N, NB, 3>.
        sigma: per-sample densities, <N, NB>.
        t: per-sample positions along the ray, <N, NB>.

    Returns:
        Expected colour per ray, <N, 3>.
    """
    assert len(c.shape) == 3
    assert len(sigma.shape) == 2
    assert len(t.shape) == 2

    # Optical thickness sigma_i * delta_i of the NB-1 segments between samples.
    dt = t[..., 1:] - t[..., :-1]  # <N, NB-1>
    mul = dt * sigma[..., :-1]  # <N, NB-1>

    # Transmittance reaching each sample; the first sample is unoccluded (T = 1).
    T = torch.cat(
        [torch.ones_like(t[..., :1]), torch.exp(-torch.cumsum(mul, dim=-1))],
        dim=-1,
    )  # <N, NB>

    # The last weight is the transmittance itself, not the accumulated optical
    # depth sum(sigma * delta), which is unbounded and is not a weight.
    w = torch.cat(
        [T[..., :-1] * (1 - torch.exp(-mul)), T[..., -1:]],
        dim=-1,
    )  # <N, NB>

    c_hat = (w[..., None] * c).sum(dim=1)  # <N, 3>
    return c_hat

def train_nerf_batch(
    batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    model: nn.Module,
    criterion: nn.Module,
    optim: torch.optim.Optimizer,
    chunk_size: int,
    n_bins: int,
    t_near: float,
    t_far: float,
    L1: int,
    L2: int,
    device: torch.device,
) -> float:
    """Run one optimisation step per chunk of rays of a single image.

    The image's rays are flattened with ``batchify_rays`` and processed in
    consecutive chunks of ``chunk_size`` rays. Each chunk is moved to
    ``device``, rendered through ``model`` and ``expected_color``, and used
    for one ``optim.step()``.

    Args:
        batch: ``(r_o, r_d, C_r)`` for one image, as accepted by ``batchify_rays``.
        model: called as ``model(ex, ed)`` and expected to return ``(c, sigma)``.
        criterion: loss between the rendered and the target colours.
        optim: optimiser over the model parameters.
        chunk_size: rays per step; must divide the number of rays exactly.
        n_bins, t_near, t_far, L1, L2: forwarded to ``preprocess_rays``.
        device: device the chunks are moved to.

    Returns:
        Mean of the per-chunk losses for this image.

    Raises:
        AssertionError: if the ray count is not a multiple of ``chunk_size``.
    """
    img_size = batch[0].size(0)
    n_rays = img_size * img_size
    assert n_rays % chunk_size == 0
    n_chunks = n_rays // chunk_size

    r_o_full, r_d_full, C_r_full = batchify_rays(batch)

    model.train()
    running_loss = 0.0

    for i in range(0, n_rays, chunk_size):
        optim.zero_grad()

        # Only the current chunk is moved to the device.
        r_o = r_o_full[i : i + chunk_size].to(device)
        r_d = r_d_full[i : i + chunk_size].to(device)
        C_r = C_r_full[i : i + chunk_size].to(device)

        ex, ed, t = preprocess_rays(
            r_o, r_d, n_bins, t_near, t_far, L1, L2, device=device
        )

        c, sigma = model(ex, ed)
        c_hat = expected_color(c, sigma, t)

        batch_loss = criterion(c_hat, C_r)
        batch_loss.backward()

        optim.step()

        running_loss += batch_loss.item()

    train_loss = running_loss / n_chunks
    return train_loss


def eval_nerf_batch(
    batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    model: nn.Module,
    criterion: nn.Module,
    chunk_size: int,
    n_bins: int,
    t_near: float,
    t_far: float,
    L1: int,
    L2: int,
    device: torch.device,
) -> float:
    """Compute the mean per-chunk loss of a single image without updating the model.

    Mirrors ``train_nerf_batch`` but puts the model in eval mode, disables
    gradient tracking and performs no optimiser step.

    Args:
        batch: ``(r_o, r_d, C_r)`` for one image, as accepted by ``batchify_rays``.
        model: called as ``model(ex, ed)`` and expected to return ``(c, sigma)``.
        criterion: loss between the rendered and the target colours.
        chunk_size: rays per forward pass; must divide the number of rays exactly.
        n_bins, t_near, t_far, L1, L2: forwarded to ``preprocess_rays``.
        device: device the chunks are moved to.

    Returns:
        Mean of the per-chunk losses for this image.

    Raises:
        AssertionError: if the ray count is not a multiple of ``chunk_size``.
    """
    img_size = batch[0].size(0)
    n_rays = img_size * img_size
    assert n_rays % chunk_size == 0
    n_chunks = n_rays // chunk_size

    running_loss = 0.0

    # Note: the model is left in eval mode when this function returns.
    model.eval()
    with torch.no_grad():
        r_o_full, r_d_full, C_r_full = batchify_rays(batch)

        for i in range(0, n_rays, chunk_size):
            r_o = r_o_full[i : i + chunk_size].to(device)
            r_d = r_d_full[i : i + chunk_size].to(device)
            C_r = C_r_full[i : i + chunk_size].to(device)

            ex, ed, t = preprocess_rays(
                r_o, r_d, n_bins, t_near, t_far, L1, L2, device=device
            )

            c, sigma = model(ex, ed)
            c_hat = expected_color(c, sigma, t)

            batch_loss = criterion(c_hat, C_r)
            running_loss += batch_loss.item()

    val_loss = running_loss / n_chunks
    return val_loss


def train_nerf(
    model: nn.Module,
    optim: torch.optim.Optimizer,
    criterion: nn.Module,
    train_dataset,
    val_dataset,
    n_epochs: int,
    chunk_size: int,
    n_bins: int,
    t_near: float,
    t_far: float,
    L1: int,
    L2: int,
    base_save_path: pl.Path,
    device: torch.device,
    load_checkpoint_path: pl.Path = None,
    shuffle_train: bool = True,
    shuffle_val: bool = True,
    limit_train_size: int = None,
    limit_val_size: int = None,
):
    """Train ``model`` for up to ``n_epochs`` epochs, checkpointing after each one.

    Every run writes into a fresh ``<base_save_path>/<model.name>/version-<k>``
    directory, where k is one more than the highest existing version number
    (0 if there is none). After each epoch a ``checkpoint-<epoch>`` file is
    saved there containing the epoch index, the loss histories and the model
    and optimiser state dicts.

    Args:
        model: network to train; must expose a ``name`` attribute.
        optim: optimiser over the model parameters.
        criterion: loss between rendered and target colours.
        train_dataset, val_dataset: indexable datasets with ``len()``; each
            item is a batch accepted by ``train_nerf_batch`` / ``eval_nerf_batch``.
        n_epochs: epoch index at which training stops (exclusive).
        chunk_size, n_bins, t_near, t_far, L1, L2: forwarded to the batch functions.
        base_save_path: existing directory under which checkpoints are written.
        device: device used for training and for mapping a loaded checkpoint.
        load_checkpoint_path: optional checkpoint to resume from; training
            continues at the epoch after the one stored in it.
        shuffle_train, shuffle_val: visit the items in random order each epoch.
        limit_train_size, limit_val_size: if given, use only indices
            ``0..limit-1`` each epoch (capped at the dataset length).

    Returns:
        The checkpoint dictionary as of the last completed epoch.

    Raises:
        ValueError: if ``base_save_path`` or a given ``load_checkpoint_path``
            does not exist.
    """
    if not base_save_path.exists():
        raise ValueError(f"base_save_path does not exist: {base_save_path}")

    if load_checkpoint_path:
        if not load_checkpoint_path.exists():
            raise ValueError(
                f"load_checkpoint_path does not exist: {load_checkpoint_path}"
            )

        # map_location lets a checkpoint written on one device (e.g. a GPU)
        # be resumed on another (e.g. the CPU).
        checkpoint = torch.load(str(load_checkpoint_path), map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optim.load_state_dict(checkpoint["optim_state_dict"])
        epoch = checkpoint["epoch"] + 1

    else:
        checkpoint = {
            "epoch": None,
            "train_loss": [],
            "val_loss": [],
            "model_state_dict": None,
            "optim_state_dict": None,
        }
        epoch = 0

    model_save_path = base_save_path / model.name
    model_save_path.mkdir(exist_ok=True)

    # Next free version number: highest existing "version-<k>" plus one.
    next_version = (
        max(
            (get_path_number(p) for p in model_save_path.glob("version-*")),
            default=-1,
        )
        + 1
    )
    version_path = model_save_path / f"version-{next_version}"
    version_path.mkdir(parents=True, exist_ok=False)

    # The number of items visited per epoch does not change between epochs.
    # A limit larger than the dataset is capped so no index is out of range.
    n_batches_train = (
        len(train_dataset)
        if (limit_train_size is None)
        else min(limit_train_size, len(train_dataset))
    )
    n_batches_val = (
        len(val_dataset)
        if (limit_val_size is None)
        else min(limit_val_size, len(val_dataset))
    )

    while epoch < n_epochs:
        print(f"epoch: {epoch}/{n_epochs}")

        train_loss_epoch = 0
        batch_idxs_train = (
            torch.randperm(n_batches_train)
            if shuffle_train
            else torch.arange(n_batches_train)
        )
        for i_batch_train in tqdm.tqdm(batch_idxs_train):
            batch = train_dataset[i_batch_train]
            train_loss_batch = train_nerf_batch(
                batch,
                model,
                criterion,
                optim,
                chunk_size,
                n_bins,
                t_near,
                t_far,
                L1,
                L2,
                device=device,
            )
            train_loss_epoch += train_loss_batch

        val_loss_epoch = 0
        batch_idxs_val = (
            torch.randperm(n_batches_val)
            if shuffle_val
            else torch.arange(n_batches_val)
        )
        for i_batch_val in tqdm.tqdm(batch_idxs_val):
            batch = val_dataset[i_batch_val]
            val_loss_batch = eval_nerf_batch(
                batch,
                model,
                criterion,
                chunk_size,
                n_bins,
                t_near,
                t_far,
                L1,
                L2,
                device=device,
            )
            val_loss_epoch += val_loss_batch

        # With no items there is no mean; record NaN instead of dividing by zero.
        train_loss = (
            train_loss_epoch / n_batches_train if n_batches_train else float("nan")
        )
        val_loss = val_loss_epoch / n_batches_val if n_batches_val else float("nan")

        checkpoint_path = version_path / f"checkpoint-{epoch}"
        print(f"\ttrain loss: {train_loss:.5e}\n\tval loss: {val_loss:.5e}\n")

        checkpoint["epoch"] = epoch
        checkpoint["train_loss"].append(train_loss)
        checkpoint["val_loss"].append(val_loss)
        checkpoint["model_state_dict"] = model.state_dict()
        checkpoint["optim_state_dict"] = optim.state_dict()
        torch.save(checkpoint, checkpoint_path)

        epoch += 1

    return checkpoint

In [None]:
# Scene location on disk: <root_data_dir>/<data source>/<dataset>/<scene>.
data_source_name = "NeRF_Data"
dataset_name = "nerf_synthetic"
scene_name = "lego"

root_data_dir = pl.Path('./data/')
data_path = root_data_dir.joinpath(data_source_name, dataset_name, scene_name)

# One frame dataset per split, each wrapped in a ray dataset for training.
train_frame_dataset, val_frame_dataset = (
    FrameDataset(data_path, split, downsample_factor=4) for split in ("train", "val")
)
train_ray_dataset, val_ray_dataset = (
    RayDataset(frames) for frames in (train_frame_dataset, val_frame_dataset)
)

# Sanity check on the example image: its shape, and the maxima of its
# first and fourth leading-axis entries.
# NOTE(review): presumably ex_img is channel-first RGBA (<4, H, W>) - confirm.
print(train_frame_dataset.ex_img.shape)
print(train_frame_dataset.ex_img[0].max(), train_frame_dataset.ex_img[3].max())
# Reverse the axes and swap the first two before handing the image to imshow.
plt.imshow(train_frame_dataset.ex_img.T.swapaxes(0, 1))

In [None]:
from src.model import MediumNet

# Seed PyTorch so weight initialisation, shuffling and ray sampling repeat
# across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# --- Training hyperparameters ---
chunk_size = 8000  # rays per optimisation step; must divide the rays per image
n_epochs = 100
n_bins = 100  # samples per ray
t_near = 0.1
t_far = 5.0
n_hidden = 256
n_components = 3
learning_rate = 2e-4
L1 = 10  # positional-encoding frequencies for sample positions
L2 = 4  # positional-encoding frequencies for ray directions

# --- Checkpointing ---
base_save_path = pl.Path('./models/')
# train_nerf raises ValueError if this directory is missing, so create it up
# front to keep the notebook runnable on a fresh checkout.
base_save_path.mkdir(parents=True, exist_ok=True)
load_checkpoint_path = None

# Set these when you want to test.
# If None, the full dataset will be used on each epoch.
limit_train_size = None
limit_val_size = 30

model = MediumNet("mediumnet", L1, L2, n_components, n_hidden).to(DEVICE)
model = torch.jit.script(model)
optim = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# Bind the result so the notebook does not echo the whole checkpoint
# (including the state dicts) as the cell output.
final_checkpoint = train_nerf(
    model=model,
    optim=optim,
    criterion=criterion,
    train_dataset=train_ray_dataset,
    val_dataset=val_ray_dataset,
    n_epochs=n_epochs,
    chunk_size=chunk_size,
    n_bins=n_bins,
    t_near=t_near,
    t_far=t_far,
    L1=L1,
    L2=L2,
    base_save_path=base_save_path,
    device=DEVICE,
    load_checkpoint_path=load_checkpoint_path,
    limit_train_size=limit_train_size,
    limit_val_size=limit_val_size,
)

In [None]:
# NOTE(review): this path is hard-coded and uses the directory "mediumnet-v1",
# while the training cell constructs the model with the name "mediumnet".
# Confirm that it points at the run you intend to plot.
plot_checkpoint_path = pl.Path('./models/mediumnet-v1/version-0/checkpoint-99')

# Only the loss histories are needed, so load onto the CPU; this also works
# on machines without the GPU the checkpoint may have been written on.
cp = torch.load(plot_checkpoint_path, map_location="cpu")
train_loss = cp["train_loss"]
val_loss = cp["val_loss"]

# Take the x-axis from the logged series themselves and use a separate name,
# so the `n_epochs` hyperparameter of the training cell is not overwritten.
fig, ax = plt.subplots()
ax.plot(np.arange(len(train_loss)), train_loss, label="train loss")
ax.plot(np.arange(len(val_loss)), val_loss, label="val loss")
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax.set_title("Training and validation loss per epoch")
ax.legend();