# K-Planes: Explicit Radiance Fields in Space, Time, and Appearance

---

- Conda env : [3dcv_playgrounds](../README.md#setup-a-conda-environment)

----

- Ref: 
    - https://arxiv.org/abs/2301.10241
    - https://github.com/MaximeVandegar/Papers-in-100-Lines-of-Code/tree/main/KPlanes_Explicit_Radiance_Fields_in_Space_Time_and_Appearance


In [1]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision
import torchshow as ts

import os
from pathlib import Path
import gdown

In [2]:
# Download Data
Path('./temp/data').mkdir(exist_ok=True, parents=True)

url = 'https://drive.google.com/uc?id=1hH7NhaXxIthO9-FeT16fvpf_MVIhf41J'
train_data_path = './temp/data/training_data.pkl'
gdown.download(url, train_data_path, quiet=False, resume=True)


url = 'https://drive.google.com/uc?id=16M64h0KKgFKhM8hJDpqd15YWYhafUs2Q'
test_data_path = './temp/data/testing_data.pkl'
gdown.download(url, test_data_path, quiet=False, resume=True)

Skipping already downloaded file ./temp/data/training_data.pkl
Skipping already downloaded file ./temp/data/testing_data.pkl


'./temp/data/testing_data.pkl'

In [3]:
g_model_path = './temp/07_KplaneNeRF_models/'
g_video_path = './temp/07_KplaneNeRF_novel_views/'
fps = 10
Path(g_model_path).mkdir(exist_ok=True, parents=True)
Path(g_video_path).mkdir(exist_ok=True, parents=True)
training_dataset = torch.from_numpy(np.load('./temp/data/training_data.pkl', allow_pickle=True))
testing_dataset = torch.from_numpy(np.load('./temp/data/testing_data.pkl', allow_pickle=True))


In [4]:
if torch.backends.mps.is_available():
    g_device = torch.device("mps")
    print(f"Current memory allocated on MPS: {torch.mps.current_allocated_memory()} bytes")
    print(f"Driver memory allocated on MPS: {torch.mps.driver_allocated_memory()} bytes")
elif torch.cuda.is_available():
    g_device = torch.device("cuda")
else:
    g_device = torch.device("cpu")
print(g_device)

cuda


In [5]:
# Model
class NerfModel(nn.Module):
    def __init__(self, embedding_dim_direction=4, hidden_dim=64, N=512, F=96, scale=1.5):
        """
        The parameter scale represents the maximum absolute value among all coordinates and is used for scaling the data
        """
        super(NerfModel, self).__init__()

        self.xy_plane = nn.Parameter(torch.rand((N, N, F)))
        self.yz_plane = nn.Parameter(torch.rand((N, N, F)))
        self.xz_plane = nn.Parameter(torch.rand((N, N, F)))

        self.block1 = nn.Sequential(nn.Linear(F, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 16), nn.ReLU(), )
        self.block2 = nn.Sequential(nn.Linear(15 + 3 * 4 * 2 + 3, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, 3), nn.Sigmoid())

        self.embedding_dim_direction = embedding_dim_direction
        self.scale = scale
        self.N = N

    @staticmethod
    def positional_encoding(x, L):
        out = [x]
        for j in range(L):
            out.append(torch.sin(2 ** j * x))
            out.append(torch.cos(2 ** j * x))
        return torch.cat(out, dim=1)

    def forward(self, x, d):
        sigma = torch.zeros_like(x[:, 0])
        c = torch.zeros_like(x)

        mask = (x[:, 0].abs() < self.scale) & (x[:, 1].abs() < self.scale) & (x[:, 2].abs() < self.scale)
        xy_idx = ((x[:, [0, 1]] / (2 * self.scale) + .5) * self.N).long().clip(0, self.N - 1)  # [batch_size, 2]
        yz_idx = ((x[:, [1, 2]] / (2 * self.scale) + .5) * self.N).long().clip(0, self.N - 1)  # [batch_size, 2]
        xz_idx = ((x[:, [0, 2]] / (2 * self.scale) + .5) * self.N).long().clip(0, self.N - 1)  # [batch_size, 2]
        F_xy = self.xy_plane[xy_idx[mask, 0], xy_idx[mask, 1]]  # [batch_size, F]
        F_yz = self.yz_plane[yz_idx[mask, 0], yz_idx[mask, 1]]  # [batch_size, F]
        F_xz = self.xz_plane[xz_idx[mask, 0], xz_idx[mask, 1]]  # [batch_size, F]
        F = F_xy * F_yz * F_xz  # [batch_size, F]

        h = self.block1(F)
        h, sigma[mask] = h[:, :-1], h[:, -1]
        c[mask] = self.block2(torch.cat([self.positional_encoding(d[mask], self.embedding_dim_direction), h], dim=1))
        return c, sigma

In [6]:
# Rendering

def compute_accumulated_transmittance(alphas):
    accumulated_transmittance = torch.cumprod(alphas, 1)
    return torch.cat((torch.ones((accumulated_transmittance.shape[0], 1), device=alphas.device),
                      accumulated_transmittance[:, :-1]), dim=-1)


def render_rays(nerf_model, ray_origins, ray_directions, hn=0, hf=0.5, nb_bins=192):
    device = ray_origins.device
    t = torch.linspace(hn, hf, nb_bins, device=device).expand(ray_origins.shape[0], nb_bins)
    # Perturb sampling along each ray.
    mid = (t[:, :-1] + t[:, 1:]) / 2.
    lower = torch.cat((t[:, :1], mid), -1)
    upper = torch.cat((mid, t[:, -1:]), -1)
    u = torch.rand(t.shape, device=device)
    t = lower + (upper - lower) * u  # [batch_size, nb_bins]
    delta = torch.cat((t[:, 1:] - t[:, :-1], torch.tensor([1e10], device=device).expand(ray_origins.shape[0], 1)), -1)

    x = ray_origins.unsqueeze(1) + t.unsqueeze(2) * ray_directions.unsqueeze(1)  # [batch_size, nb_bins, 3]
    ray_directions = ray_directions.expand(nb_bins, ray_directions.shape[0], 3).transpose(0, 1)

    colors, sigma = nerf_model(x.reshape(-1, 3), ray_directions.reshape(-1, 3))
    colors = colors.reshape(x.shape)
    sigma = sigma.reshape(x.shape[:-1])

    alpha = 1 - torch.exp(-sigma * delta)  # [batch_size, nb_bins]
    weights = compute_accumulated_transmittance(1 - alpha).unsqueeze(2) * alpha.unsqueeze(2)
    c = (weights * colors).sum(dim=1)  # Pixel values
    weight_sum = weights.sum(-1).sum(-1)  # Regularization for white background
    return c + 1 - weight_sum.unsqueeze(-1)

In [7]:
@torch.no_grad()
def test(nerf_model, hn, hf, dataset, chunk_size=10, img_index=0, nb_bins=192, H=400, W=400):
    ray_origins = dataset[img_index * H * W: (img_index + 1) * H * W, :3]
    ray_directions = dataset[img_index * H * W: (img_index + 1) * H * W, 3:6]

    data = []
    for i in range(int(np.ceil(H / chunk_size))):
        ray_origins_ = ray_origins[i * W * chunk_size: (i + 1) * W * chunk_size].to(g_device)
        ray_directions_ = ray_directions[i * W * chunk_size: (i + 1) * W * chunk_size].to(g_device)
        regenerated_px_values = render_rays(nerf_model, ray_origins_, ray_directions_, hn=hn, hf=hf, nb_bins=nb_bins)
        data.append(regenerated_px_values.cpu())
    img_t = torch.cat(data).reshape(H, W, 3)
    return img_t

def train(nerf_model, optimizer, scheduler, data_loader, device='cpu', hn=0, hf=1, start_epoch = 0, nb_epochs=int(1e5), nb_bins=192, H=400, W=400, eval_steps = 5):
    training_loss = []
    for e in (range(start_epoch, nb_epochs)):
        print(f"{e:03d}-epoch")
        for ep, batch in enumerate(tqdm(data_loader, total = len(data_loader), desc="Training")):
            ray_origins = batch[:, :3].to(device)
            ray_directions = batch[:, 3:6].to(device)
            ground_truth_px_values = batch[:, 6:].to(device)

            regenerated_px_values = render_rays(nerf_model, ray_origins, ray_directions, hn=hn, hf=hf, nb_bins=nb_bins)
            loss = ((ground_truth_px_values - regenerated_px_values) ** 2).sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_loss.append(loss.item())
        scheduler.step()
        _model_path = os.path.join(g_model_path, f'nerf_model_{e:03d}.pth')
        torch.save(nerf_model.state_dict(), _model_path)

        if e % eval_steps == 0 or e == nb_epochs - 1:
            nerf_model.eval()
            imgT_lst = []
            for idx in tqdm(range(200), desc="Validation"):
                imgT_lst.append(test(nerf_model, 2, 6, testing_dataset, img_index=idx, nb_bins=192).unsqueeze(0))
            img_t = (torch.cat(imgT_lst) * 255).to(torch.uint8)
            _video_path = os.path.join(g_video_path,  f'nerf_model_{e:03d}.mp4')
            torchvision.io.write_video(_video_path, img_t, fps)
            nerf_model.train()


    return training_loss

In [8]:
model = NerfModel(hidden_dim=256).to(g_device)
model_optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(model_optimizer, milestones=[2, 4, 8], gamma=0.5)

data_loader = DataLoader(training_dataset, batch_size=1024, shuffle=True)
train(model, model_optimizer, scheduler, data_loader, nb_epochs=16, device=g_device, hn=2, hf=6, nb_bins=192, H=400,
        W=400)
# for img_index in range(200):
#     test(2, 6, testing_dataset, img_index=img_index, nb_bins=192, H=400, W=400)

000-epoch


Training: 100%|██████████| 15625/15625 [14:19<00:00, 18.17it/s]
Validation: 100%|██████████| 200/200 [05:22<00:00,  1.61s/it]


001-epoch


Training: 100%|██████████| 15625/15625 [14:31<00:00, 17.93it/s]


002-epoch


Training: 100%|██████████| 15625/15625 [14:52<00:00, 17.50it/s]


003-epoch


Training: 100%|██████████| 15625/15625 [14:49<00:00, 17.58it/s]


004-epoch


Training: 100%|██████████| 15625/15625 [14:41<00:00, 17.72it/s]


005-epoch


Training: 100%|██████████| 15625/15625 [14:56<00:00, 17.43it/s]
Validation: 100%|██████████| 200/200 [05:14<00:00,  1.57s/it]


006-epoch


Training: 100%|██████████| 15625/15625 [14:46<00:00, 17.63it/s]


007-epoch


Training: 100%|██████████| 15625/15625 [14:31<00:00, 17.94it/s]


008-epoch


Training: 100%|██████████| 15625/15625 [14:33<00:00, 17.88it/s]


009-epoch


Training: 100%|██████████| 15625/15625 [15:22<00:00, 16.93it/s]


010-epoch


Training: 100%|██████████| 15625/15625 [14:00<00:00, 18.60it/s]
Validation: 100%|██████████| 200/200 [05:13<00:00,  1.57s/it]


011-epoch


Training: 100%|██████████| 15625/15625 [14:48<00:00, 17.58it/s]


012-epoch


Training: 100%|██████████| 15625/15625 [14:45<00:00, 17.65it/s]


013-epoch


Training: 100%|██████████| 15625/15625 [14:58<00:00, 17.40it/s]


014-epoch


Training: 100%|██████████| 15625/15625 [15:04<00:00, 17.27it/s]


015-epoch


Training: 100%|██████████| 15625/15625 [14:49<00:00, 17.57it/s]
Validation: 100%|██████████| 200/200 [05:20<00:00,  1.60s/it]


[301.42633056640625,
 271.7340087890625,
 270.22711181640625,
 241.88343811035156,
 244.6896514892578,
 247.31109619140625,
 244.82174682617188,
 257.3369140625,
 254.90774536132812,
 241.6492156982422,
 253.8494873046875,
 259.4611511230469,
 244.41427612304688,
 256.9162292480469,
 242.12158203125,
 237.85336303710938,
 234.53762817382812,
 242.06610107421875,
 248.26324462890625,
 251.45741271972656,
 246.43075561523438,
 248.4685821533203,
 231.91525268554688,
 253.648193359375,
 247.74658203125,
 225.51026916503906,
 228.46397399902344,
 230.15435791015625,
 243.95828247070312,
 248.57691955566406,
 259.08673095703125,
 242.52572631835938,
 236.35211181640625,
 244.2430419921875,
 231.0795135498047,
 256.97552490234375,
 235.52493286132812,
 228.23577880859375,
 227.00892639160156,
 239.4010009765625,
 232.09121704101562,
 236.44619750976562,
 223.5017547607422,
 225.32688903808594,
 228.67697143554688,
 231.19943237304688,
 237.62783813476562,
 244.51925659179688,
 249.6893005371

In [9]:
imgT_lst = []
for idx in tqdm(range(200), desc="Validation"):
        imgT_lst.append(test(model, 2, 6, testing_dataset, chunk_size = 10, img_index=idx, nb_bins=192, H = 400, W = 400 ).unsqueeze(0))
img_t = (torch.cat(imgT_lst) * 255).to(torch.uint8)
_video_path = os.path.join(g_video_path,  f'nerf_model_test.mp4')
torchvision.io.write_video(_video_path, img_t, fps)

Validation: 100%|██████████| 200/200 [05:13<00:00,  1.57s/it]


In [10]:
# model = NerfModel(hidden_dim=256).to(g_device)
# prev_mode_path = "./temp/07_KplaneNeRF_models/nerf_model_015.pth"
# model.load_state_dict(torch.load(prev_mode_path))

# imgT_lst = []
# for idx in tqdm(range(200), desc="Validation"):
#         imgT_lst.append(test(model, 2, 6, testing_dataset, chunk_size = 10, img_index=idx, nb_bins=192, H = 400, W = 400 ).unsqueeze(0))
# img_t = (torch.cat(imgT_lst) * 255).to(torch.uint8)
# _video_path = os.path.join(g_video_path,  f'nerf_model_test.mp4')
# torchvision.io.write_video(_video_path, img_t, fps)