In [None]:
# Root directory that holds requirements.txt, the colmap/ folder and the saved model.
ROOT_PATH = '/'
# uncomment for gcp:
# %pdb off
# from google.colab import drive
# drive.mount('/content/drive')
import os
import subprocess
import sys

requirements_path = os.path.join(ROOT_PATH, 'requirements.txt')
# Install through the interpreter that runs this kernel (`python -m pip`) so the
# packages land in the kernel's environment, and pass the path as an argument
# list so it is never interpreted by a shell.
# check=False keeps the install best-effort: a failure is reported, not raised.
_pip_result = subprocess.run(
    [sys.executable, '-m', 'pip', 'install', '-r', requirements_path],
    check=False,
)
if _pip_result.returncode != 0:
    print(f'pip install exited with status {_pip_result.returncode}')

# Imports

In [None]:
from PIL import Image
from typing import *

import cv2
import pickle

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as F
import torch.nn as nn
import torch.optim as optim
from torch import Tensor

import wandb
import numpy as np
from tqdm.notebook import tqdm
import einops
import ipdb
from math import ceil

# Declare constants and load Google Drive

In [None]:
# Every captured scene that has a colmap/<name>/ folder.
SCENE_NAMES = (
    'bedroom',
    'forest1',
    'forest2',
    'sidewalk',
    'study',
    'kitchen',
    'bottle',
    'apples',
    'sourcream',
)

# Scene used by the rest of the notebook; change this single line to switch scenes.
NAME = 'sourcream'
assert NAME in SCENE_NAMES, f'unknown scene {NAME!r}'

In [None]:
# COLMAP sparse reconstruction (cameras.txt / images.txt) for the selected scene.
SPARSE_PATH = os.path.join(ROOT_PATH, f'colmap/{NAME}/sparse')
# Folder holding the frames that COLMAP registered.
IMG_PATH = os.path.join(ROOT_PATH, f'colmap/{NAME}/images')
# IMG_PATH is a plain str (os.path.join), which has no .mkdir() method;
# os.makedirs creates the folder (and any missing parents) if needed.
os.makedirs(IMG_PATH, exist_ok=True)
SEED = 0
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE
# Sample images

'cuda'

In [None]:
# cap = cv2.VideoCapture(f'videos/{name}.MOV')

# frame_no = 0
# every_n_frames = 10

# while cap.isOpened():
#     ret, frame = cap.read()

#     if frame_no % every_n_frames == 0:
#         target = imgpath.joinpath(f'{frame_no:06d}.jpg').as_posix()
#         cv2.imwrite(target, frame)

#     frame_no += 1
#     if not ret: break

# cap.release()
# print('done')

# Camera class and dict

In [None]:
# initialize camera objects (for now only one)

class Camera:
    """Intrinsics of one camera plus per-pixel ray directions in the camera frame.

    Args:
        camera_id: integer id of the camera.
        model: camera model name; 'PINHOLE' and 'SIMPLE_PINHOLE' are supported.
        width: image width in pixels.
        height: image height in pixels.
        params: intrinsic parameters, (fx, fy, cx, cy) for 'PINHOLE' or
            (f, cx, cy) for 'SIMPLE_PINHOLE'.

    Raises:
        ValueError: if ``model`` is not one of the supported models.
    """

    def __init__(self, camera_id, model, width, height, params):
        self.camera_id = camera_id
        self.model = model
        self.width = width
        self.height = height
        self.params = params
        self.K = self._get_K()                  # size: 3 x 3
        self.K_inv = self.K.inverse()           # size: 3 x 3
        self.xy_pairs = self._get_xy_pairs()    # size: H*W, H*W
        self.d_camera = self._get_d()           # size: H*W x 3

    def _get_K(self):
        """Build the 3 x 3 intrinsic matrix on DEVICE from ``self.params``."""
        if self.model == 'PINHOLE':
            fx, fy, cx, cy = self.params
        elif self.model == 'SIMPLE_PINHOLE':
            f, cx, cy = self.params
            fx = fy = f
        else:
            # Previously fell through with K unbound (UnboundLocalError).
            raise ValueError(
                f'unsupported camera model {self.model!r}; '
                "expected 'PINHOLE' or 'SIMPLE_PINHOLE'"
            )

        # Explicit floating dtype so integer-valued params cannot produce an
        # integer matrix, which could not be inverted.
        K = torch.tensor(
            [
                [fx, 0,  cx],
                [0,  fy, cy],
                [0,  0,  1 ],
            ],
            dtype=torch.get_default_dtype(),
        ).to(DEVICE)
        return K

    def _get_xy_pairs(self):
        """Return (x, y) pixel coordinates of every pixel, flattened row by row."""
        y, x = torch.unravel_index(
            torch.arange(self.height * self.width).to(DEVICE),
            (self.height, self.width),
        )
        return x, y

    def _get_d(self):
        """Return unit-length ray directions, one row per pixel (HW x 3)."""
        x, y = self.xy_pairs
        # homogeneous pixel coordinates [x, y, 1]
        x_y_1 = torch.stack((x, y, torch.ones_like(x))).float()   # size: 3 x HW
        d = self.K_inv @ x_y_1              # size: (3 x 3) @ (3 x HW) = 3 x HW
        d = (d / d.norm(dim=0)).T           # size: HW x 3

        return d


def get_cameras():
    """Parse the scene's ``cameras.txt`` into a ``{camera_id: Camera}`` dict.

    Each data line is ``CAMERA_ID MODEL WIDTH HEIGHT PARAMS...`` separated by
    whitespace. Lines starting with ``#`` and empty lines are ignored, so the
    parser does not depend on the header being exactly three lines long.

    Returns:
        Dict mapping the integer camera id to its ``Camera`` object.
    """
    camera_file_path = os.path.join(ROOT_PATH, f'colmap/{NAME}/sparse/cameras.txt')

    cameras = {}
    with open(camera_file_path) as file:
        for camera_line in file:
            camera_line = camera_line.strip()
            if not camera_line or camera_line.startswith('#'):
                continue  # header comment or blank line

            (camera_id, model, width, height, *params) = camera_line.split()
            camera_id, width, height = (int(s) for s in (camera_id, width, height))
            params = [float(s) for s in params]
            cameras[camera_id] = Camera(camera_id, model, width, height, params)

    return cameras

cameras: Dict[int, Camera] = get_cameras()

# ImagePose class

In [None]:

class ImagePose:
    """Pose and pixel data of a single image.

    Args:
        image_id: integer id of the image.
        qw, qx, qy, qz: rotation quaternion, scalar part first.
        tx, ty, tz: translation vector.
        camera_id: id of the camera that took the image.
        name: file name of the image inside IMG_PATH.

    Attributes:
        r: 3 x 3 rotation matrix on DEVICE.
        t: translation vector of size 3 on DEVICE.
        image: RGB image tensor of size 3 x H x W on DEVICE.
    """

    def __init__(self, image_id, qw, qx, qy, qz, tx, ty, tz, camera_id, name):
        self.image_id = image_id
        self.r = self._quaternions_to_matrix(qw, qx, qy, qz)
        self.t = torch.tensor([tx, ty, tz]).to(DEVICE)
        self.camera_id = camera_id
        self.name = name
        self.image = self._imgfile_to_tensor()

    def _quaternions_to_matrix(self, qw, qx, qy, qz):
        """Convert a quaternion (qw, qx, qy, qz) to a 3 x 3 rotation matrix.

        No normalisation is performed; the quaternion is assumed to be unit-norm.
        """
        r = torch.tensor([
            [1 - 2 * (qy * qy + qz * qz),   2 * (qx * qy - qz * qw),        2 * (qx * qz + qy * qw)],
            [2 * (qx * qy + qz * qw),       1 - 2 * (qx * qx + qz * qz),    2 * (qy * qz - qx * qw)],
            [2 * (qx * qz - qy * qw),       2 * (qy * qz + qx * qw),        1 - 2 * (qx * qx + qy * qy)],
        ]).to(DEVICE)

        return r

    def _imgfile_to_tensor(self):
        """Load ``self.name`` from IMG_PATH as a 3 x H x W tensor on DEVICE."""
        path = os.path.join(IMG_PATH, self.name)
        # The context manager closes the file handle once the pixels are read;
        # convert('RGB') guarantees three channels for grayscale / RGBA files.
        with Image.open(path) as image:
            image_tensor = F.to_tensor(image.convert('RGB'))
        return image_tensor.to(DEVICE)


# Config and wandb

In [None]:
# Hyper-parameters of the run; the whole dict is also handed to wandb.
config = dict(
    name='initial-run-aws',
    batch_size=4096,
    initial_lr=5e-4,
    final_lr=5e-5,
    num_iter=100_000,

    # these two probably shouldn't change.
    # consider removing from config
    tn=0,
    tf=1,

    Nc=64,
    Nf=128,
    Lx=10,
    Ld=4,

    eps=1e-7,
)

# Short aliases for the values used inside the sampling / training functions.
B = config['batch_size']
TN = config['tn']
TF = config['tf']
NC = config['Nc']
NF = config['Nf']

In [None]:
# Name of this notebook, exposed to wandb through its environment variable.
os.environ["WANDB_NOTEBOOK_NAME"] = "nerf.ipynb"
# Shell escape: runs the wandb CLI login command.
!wandb login

In [None]:
# run.finish()

In [None]:
# Start a wandb run in project 'nerf', named after config['name'], and attach
# the full config dict to it. The commented-out id/resume arguments are the
# ones to enable when resuming an existing run.
run = wandb.init(
    project='nerf',
    name=config['name'],
    reinit=True,
    config=config,
    # id=None, # id of run to resume
    # resume='must', # if want to resume, comment reinit
)

[34m[1mwandb[0m: Currently logged in as: [33mjay-okoro[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Show the id of the current run (the value to pass as `id=` when resuming).
run.id

'1x6v6j70'

# Create dataset and dataloader

In [None]:
class TrainDataset(Dataset):
    '''
    Per-pixel ray dataset built from the pose lines of an images.txt file.

    Every item is one pixel of one image. The flat index is decomposed as
    ``idx = image_idx * num_pixels + pixel_idx``.

    Output format: o, d, c
        where   o is tensor of ray origin (after projection)       size: 1 x 3
                d is tensor of ray direction (after projection)    size: 1 x 3
                c is tensor of RGB value of the pixel              size: 3

    NOTE: the image of the selected pose is re-read from disk on every
    __getitem__ call (ImagePose loads it in its constructor).
    '''
    def __init__(self, images_path=os.path.join(SPARSE_PATH, 'images.txt')):
        # values for NDC projection
        # paper uses [-cx, -cy, f] b/c they use (y up, z into camera)
        # but we use [cx, cy, f] b/c we use (y down, z out of camera)
        # b/c of colmap, but also I prefer colmap's way and would use it again
        # pose dependent: 1 -> pose.camera_id
        f, cx, cy = cameras[1].params  # cx, cy = W/2, H/2
        self.x, self.y = cameras[1].xy_pairs          # size: HW
        self.d_camera = cameras[1].d_camera           # size: HW x 3
        self.num_pixels = cameras[1].height * cameras[1].width

        self.scalar = torch.tensor([f/cx, f/cy, 1]).to(DEVICE) #  = f / [cx, cy, f]
        self.two_f = 2 * f

        # skip the 4 header lines, then keep every other line (the pose lines)
        with open(images_path) as file:
            self.image_lines = file.readlines()[4::2]

    # converts line in image.txt to ImagePose object
    def _get_image_pose(self, image_line):
        (image_id, qw, qx, qy, qz, tx, ty, tz, camera_id, name
        ) = image_line.strip('\n').split()

        image_id, camera_id = int(image_id), int(camera_id)
        (qw, qx, qy, qz, tx, ty, tz,
        ) = (float(s) for s in (qw, qx, qy, qz, tx, ty, tz))

        pose = ImagePose(image_id, qw, qx, qy, qz, tx, ty, tz, camera_id, name)

        return pose

    def __len__(self):
        # Multiply the counts; replicating the list num_pixels times just to
        # take its length allocates a huge temporary list.
        return len(self.image_lines) * self.num_pixels

    def __getitem__(self, idx):
        image_idx, pixel_idx = divmod(idx, self.num_pixels)

        image_line = self.image_lines[image_idx]
        pose = self._get_image_pose(image_line)

        # find origin of camera and direction from origin to the requested
        # pixel in world coordinates and get c (color). Only the requested
        # pixel is processed instead of all HW pixels of the image.
        r, t = pose.r, pose.t.unsqueeze(0)                        # size: 3x3, 1x3
        o = -t @ r          # -r.T @ t                            # size: 1 x 3
        d = self.d_camera[pixel_idx:pixel_idx + 1] @ r            # size: 1 x 3
        c = pose.image[:, self.y[pixel_idx], self.x[pixel_idx]]   # size: 3

        # NDC projection
        #   links to understand NDC better:
        #   https://www.youtube.com/watch?v=U0_ONQQ5ZNM
        #   https://yconquesty.github.io/blog/ml/nerf/nerf_ndc.html#analysis
        # NOTE(review): the z offset 2f is added without a 1/o_z factor;
        # confirm this matches the intended NDC derivation.
        o = o / o[:,2:3]
        d = d / d[:,2:3]

        o[:,2] += self.two_f
        d -= o

        o *= self.scalar
        d *= self.scalar

        # o and d keep a leading dim of 1 so they broadcast against the
        # samples along the ray at training/inference stage.
        return o, d, c                                # size: 1x3, 1x3, 3

trainset = TrainDataset()

In [None]:
trainloader = DataLoader(
    trainset,
    batch_size=config['batch_size'],
    shuffle=True,
    drop_last=True,
)
if len(trainloader) == 0:
    raise ValueError(
        "trainset holds fewer samples than one batch; with drop_last=True "
        "the loader would yield nothing"
    )
# len(trainloader) is already the number of minibatches (optimizer steps) per
# epoch, so the epoch count is num_iter / steps-per-epoch. Multiplying by
# batch_size as well would overshoot by a factor of batch_size.
num_epochs = int(ceil(config['num_iter'] / len(trainloader)))

# Model

In [None]:
class MLP(nn.Module): # tested
    """One fully connected layer followed by an activation.

    Args:
        in_feat: number of input features.
        out_feat: number of output features.
        activation: module applied after the linear layer (ReLU by default).
    """

    def __init__(self, in_feat, out_feat, activation=nn.ReLU()):
        super().__init__()
        layers = [nn.Linear(in_feat, out_feat), activation]
        self.f = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the linear layer and then the activation to ``x``."""
        out = self.f(x)
        return out

class PositionalEncoding(nn.Module): # tested
    """Sin/cos positional encoding with L frequency bands per input coordinate.

    Each coordinate p of the last dimension is mapped to
    [sin(2^0 pi p), cos(2^0 pi p), ..., sin(2^(L-1) pi p), cos(2^(L-1) pi p)],
    and the encodings of all coordinates are concatenated.

    Args:
        L: number of frequency bands.
    """

    def __init__(self, L):
        super().__init__()

        self.L = L
        # [pi, pi, 2pi, 2pi, 4pi, 4pi, ...]: each frequency appears twice, once
        # for the sin slot and once for the cos slot.                       # size: 2L
        omega = (2.0 ** torch.arange(L, dtype=torch.float32)).repeat_interleave(2) * torch.pi
        # Non-persistent buffer: follows the module across devices with
        # .to(...) but is not written to the state_dict.
        self.register_buffer('omega', omega, persistent=False)

    def forward(self, x):                                       # size: ...x3
        gamma = x.unsqueeze(-1) * self.omega                    # size: (...x3x1) * (2L) -> ...x3x2L
        gamma = torch.stack(
            (torch.sin(gamma[..., ::2]), torch.cos(gamma[..., 1::2])),
            dim=-1,
        )                                                       # size: ...x3xLx2 (sin, cos interleaved)
        gamma = gamma.flatten(-3)                               # size: ...x6L
        return gamma

class Nerf(nn.Module): # tested
    """MLP mapping encoded positions and directions to colour and density.

    Args:
        Lx: number of positional-encoding bands for the position x.
        Ld: number of positional-encoding bands for the direction d.

    forward(x, d) returns (rgb, sig): rgb in [0, 1] from the final sigmoid,
    sig >= 0 from the final relu.
    """

    def __init__(self, Lx, Ld):
        super().__init__()
        self.pos_enc_x = PositionalEncoding(Lx)
        self.pos_enc_d = PositionalEncoding(Ld)

        self.bfr_x_res = nn.Sequential(
            # This slot used to hold nn.ReLU(), which zeroed every negative
            # sin/cos feature before the first linear layer. Identity keeps the
            # sub-module indices (and state_dict keys) unchanged.
            nn.Identity(),
            MLP(Lx*6,256),
            MLP(256,256),
            MLP(256,256),
            MLP(256,256),
            MLP(256,256),
        )
        self.bfr_d_in = nn.Sequential(
            MLP(256+Lx*6,256),
            MLP(256,256),
            MLP(256,256),
            MLP(256,257,nn.Identity()),
        )
        self.aft_d_in = nn.Sequential(
            MLP(256+Ld*6,128),
            MLP(128,3,nn.Sigmoid()),
        )

    def forward(self, x, d):                            #   size: BxNx3, BxNx3
        gamma_x = self.pos_enc_x(x)                     #   size: BxNx6Lx
        out = self.bfr_x_res(gamma_x)                   #   size: BxNx256
        out = torch.cat((out, gamma_x), -1)             #   size: BxNx(256+6Lx)
        out = self.bfr_d_in(out)                        #   size: BxNx257
        sig, out = out[..., :1], out[..., 1:]           #   size: BxNx1, BxNx256
        gamma_d = self.pos_enc_d(d)                     #   size: BxNx6Ld
        out = torch.cat((out, gamma_d), -1)             #   size: BxNx(256+6Ld)
        rgb = self.aft_d_in(out)                        #   size: BxNx3

        # paper says this is helpful for real scenes. It is a training-time
        # regulariser only, so it is skipped in eval mode to keep inference
        # deterministic. Out-of-place add avoids mutating the sliced view.
        if self.training:
            sig = sig + torch.randn_like(sig)
        sig = sig.relu()

        return rgb, sig                             #   size: BxNx3, BxNx1

model = Nerf(
    config['Lx'],
    config['Ld']
).to(DEVICE)

# Loss Function

In [None]:
# Squared error summed (not averaged) over all elements it is given.
criterion = nn.MSELoss(reduction='sum')

# Optimizer

In [None]:
optimizer = optim.Adam(model.parameters(), lr=config['initial_lr'], eps=config['eps'])
# ExponentialLR multiplies the learning rate by `gamma` on every
# scheduler.step(). To go from initial_lr to final_lr in num_epochs steps we
# need gamma ** num_epochs == final_lr / initial_lr. The previous value was the
# log-rate (log(initial) - log(final)) / num_epochs, not that factor.
gamma = (config['final_lr'] / config['initial_lr']) ** (1 / max(num_epochs, 1))
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

# Train Functions

In [None]:
def color_weights_and_t_i(
    tn: float,
    tf: float,
    N: int,
    i_random: Tensor,
    o: Tensor,
    d: Tensor,
    model: nn.Module,
    i_random_cat: Optional[Tensor] = None
):
    """Render one colour per ray by quadrature along the ray.

    Args:
        tn: ray parameter of the near bound.
        tf: ray parameter of the far bound.
        N: scale of the sample indices; an index value i is mapped to
            t = tn + i * (tf - tn) / N.
        i_random: sample indices along each ray, size B x (M+1) x 1. Must
            already be sorted when ``i_random_cat`` is None.
        o: ray origins, size B x 1 x 3.
        d: ray directions, size B x 1 x 3.
        model: callable ``model(x, d) -> (c_i, sigma_i)`` with sizes
            B x M x 3 and B x M x 1.
        i_random_cat: optional extra sample indices that are concatenated
            with ``i_random`` along dim -2 and sorted together with it.

    Returns:
        w: weight of every segment, size B x M x 1.
        c: rendered colour of every ray, size B x 3.
    """
    if i_random_cat is not None:
        # concat fine and coarse pts using t parameter
        i_random = torch.cat((i_random, i_random_cat), -2)
        i_random = i_random.sort(-2)[0]

    t_i = tn + i_random * (tf - tn) / N                                 # size: B x (M+1) x 1
    delta_i = t_i[:, 1:] - t_i[:, :-1]                                  # size: B x M x 1

    x = o + t_i[:,:-1] * d # in NDC coords :)                           # size: (Bx1x3) + (BxMx1) * (Bx1x3) -> BxMx3
    d = d.expand(-1, x.shape[-2], -1)                                   # size: BxMx3

    c_i, sigma_i = model(x, d)                                          # size: BxMx3, BxMx1

    neg_sig_dlt_i: Tensor = -sigma_i * delta_i                          # size: B x M x 1
    # Transmittance up to (not including) segment i is an EXCLUSIVE cumulative
    # sum: T_0 = 1, T_i = exp(sum_{j<i} -sigma_j * delta_j). Subtracting the
    # current term from the inclusive cumsum gives exactly that; roll(1) would
    # wrap the last segment's term into T_0.
    T_i = torch.exp(neg_sig_dlt_i.cumsum(-2) - neg_sig_dlt_i)           # size: B x M x 1

    w = T_i * (1 - torch.exp(neg_sig_dlt_i))                            # size: (BxMx1) * (BxMx1) -> BxMx1
    c = (w * c_i).sum(-2)                                               # size: ((BxMx1) * (BxMx3) -> BxMx3).sum(-2) -> Bx3

    return w, c



In [None]:
# Mask of ones with size B x Nc x 1 whose first entry along the Nc axis is 0.
w_hat_cum_mask = torch.ones(B, config['Nc'], 1).to(DEVICE)
w_hat_cum_mask[:,0] = 0

def get_i_fine(
    w: Tensor,
    Nf: int,
    i: Tensor,
    eps: float = 1e-5,
):
    """Draw ``Nf`` sample indices per ray by inverse-transform sampling of ``w``.

    The normalised weights are treated as a piecewise-constant pdf over the
    bins; every uniform sample u is mapped to
    ``i[bin] + (u - cdf_before_bin) / pdf_of_bin``.

    Args:
        w: non-negative bin weights, size B x Nc x 1.
        Nf: number of samples to draw for every ray.
        i: index value of every bin, size (Nc+1) x 1 (or at least Nc x 1).
        eps: small constant added to every weight so that a ray whose weights
            are all zero yields a uniform pdf instead of NaN.

    Returns:
        Tensor of size B x Nf x 1 with the sampled (unsorted) indices.
    """
    batch, Nc = w.shape[0], w.shape[-2]

    w = w + eps                                                 # prevents 0/0 below
    w_hat = w / w.sum(-2, True)                                 # size: BxNcx1
    cdf_hi = w_hat.cumsum(-2)                                   # size: BxNcx1, cdf at the end of each bin
    cdf_lo = cdf_hi - w_hat                                     # size: BxNcx1, cdf at the start of each bin

    u = torch.rand(batch, 1, Nf, device=w.device, dtype=w.dtype)    # size: Bx1xNf
    # bin index = number of bin ends that lie below u
    q = (u > cdf_hi).sum(-2)                                    # size: (BxNcxNf).sum(-2) -> BxNf
    # rounding can leave the last cdf value slightly below 1, which would give
    # q == Nc (out of range) for u close to 1
    q = q.clamp(max=Nc - 1)
    q_col = q.unsqueeze(-1)                                     # size: BxNfx1

    lo = cdf_lo.gather(-2, q_col)                               # size: BxNfx1
    pdf = w_hat.gather(-2, q_col)                               # size: BxNfx1

    # inverse transform sampling
    i_fine = (u.transpose(-1, -2) - lo) / pdf + i[q]            # size: BxNfx1

    return i_fine

In [None]:
def train(model, trainloader):
    """Optimise ``model`` with coarse + fine sampling along every ray.

    Uses the module-level ``optimizer``, ``scheduler``, ``criterion`` and
    ``num_epochs``; the loss of every minibatch is logged to wandb and the
    scheduler is stepped once per epoch.

    Args:
        model: network called as ``model(x, d) -> (rgb, sigma)``.
        trainloader: iterable of (o, d, c) batches of size Bx1x3, Bx1x3, Bx3.
    """
    torch.manual_seed(SEED)
    model.train()

    # t: NDC coords for ray parameter
    # Nc: num bins along ray at first
    # Nf: num new samples based on Nc dist.
    # Plain index tensor: it is not a learnable quantity, so it carries no grad.
    i = torch.arange(float(NC+1)).to(DEVICE).unsqueeze(-1)                          # size: [Nc+1]x1
    for epoch in tqdm(range(num_epochs), desc='epochs'):
        for o, d, c in tqdm(trainloader, desc='minibatches', leave=False):          # size: Bx1x3, Bx1x3, Bx3
            optimizer.zero_grad()
            b = o.shape[0] # might change in last batch

            # get coarse colors
            i_coarse = i + torch.rand(b, NC+1, 1).to(DEVICE)                        # size: ([Nc+1]x1) + (Bx[Nc+1]x1) -> Bx[Nc+1]x1
            w_coarse, c_coarse = color_weights_and_t_i(TN, TF, NC,
                                                        i_coarse,
                                                        o, d, model)                # size: BxNcx1, Bx3

            # get fine colors
            # NOTE(review): gradients flow through w_coarse into the fine
            # sample positions; confirm whether it should be detached.
            i_fine = get_i_fine(w_coarse, NF, i)                                    # size: BxNfx1
            # Fine indices are expressed in coarse-bin units (i[bin] + fraction),
            # so they must be scaled by NC, exactly like the coarse indices.
            # Scaling by NF would squeeze all samples towards the near bound.
            w_fine, c_fine = color_weights_and_t_i(TN, TF, NC, i_fine,
                                                    o, d, model, i_coarse)          # size: Bx(Nf+Nc)x1, Bx3

            loss: Tensor = criterion(c_coarse, c) + criterion(c_fine, c)

            loss.backward()

            optimizer.step()

            # not logging accuracy b/c continuous values
            # may be close but will rarely equal each other
            # .item() logs a plain float instead of a graph-attached tensor
            wandb.log({'loss': loss.item()})
            torch.cuda.empty_cache()
        scheduler.step()


In [None]:
# Train! Runs the full training loop on the model and loader defined above.
train(model, trainloader)

In [None]:
# Save Model

In [None]:
# Write the trained weights (state_dict only) under ROOT_PATH, then hand the
# file to wandb so it is stored with the run.
model_state_dict_path = os.path.join(ROOT_PATH, "model_state_dict.pth")
torch.save(model.state_dict(), model_state_dict_path)
wandb.save(model_state_dict_path)
print("Model saved to Weights and Biases!")

In [None]:
# Close the wandb run started earlier in the notebook.
wandb.finish()