In [1]:
# toggle to False to skip long pre-processes
restart = False
restart = True

# Imports

In [2]:
# !pip install -r requirements.txt

In [3]:
from pathlib import Path 
from PIL import Image
from typing import *

import cv2
import pickle 

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as F
import torch.nn as nn
import torch.optim as optim
from torch import Tensor

import wandb
import numpy as np
from tqdm.notebook import tqdm
import einops 


# Sample images and declare constants

In [4]:
name = 'bedroom'
name = 'forest1'
name = 'forest2'
name = 'sidewalk'
name = 'study'
name = 'kitchen'
name = 'bottle'
name = 'apples'
name = 'sourcream'

In [5]:
SPARSE_PATH = Path(f'colmap/{name}/sparse')
IMG_PATH = Path(f'colmap/{name}/images')
IMG_PATH.mkdir(exist_ok=True)
SEED = 0


In [6]:
# cap = cv2.VideoCapture(f'videos/{name}.MOV')

# frame_no = 0
# every_n_frames = 10

# while cap.isOpened():
#     ret, frame = cap.read()
    
#     if frame_no % every_n_frames == 0:
#         target = imgpath.joinpath(f'{frame_no:06d}.jpg').as_posix()
#         cv2.imwrite(target, frame)
        
#     frame_no += 1
#     if not ret: break

# cap.release()
# print('done')

# Camera class and dict

In [7]:
# initialize camera objects (for now only one)

class Camera:
    def __init__(self, camera_id, model, width, height, params):
        self.camera_id = camera_id
        self.model = model
        self.width = width
        self.height = height
        self.params = params
        self.K = self._get_K()                  # size: 3 x 3
        self.K_inv = self.K.inverse()           # size: 3 x 3
        self.xy_pairs = self._get_xy_pairs()    # size: H*W, H*W
        self.d_camera = self._get_d()           # size: H*W x 3
    
    def _get_K(self):
        if self.model == 'PINHOLE':
            fx, fy, cx, cy = self.params
            K = torch.tensor([
                [fx, 0,  cx],
                [0,  fy, cy],
                [0,  0,  1 ],
            ])
        elif self.model == 'SIMPLE_PINHOLE':
            f, cx, cy = self.params
            K = torch.tensor([
                [f, 0, cx],
                [0, f, cy],
                [0, 0, 1 ],
            ])
        return K
    
    def _get_xy_pairs(self):
        y, x = torch.unravel_index(
            torch.arange(self.height * self.width), 
            (self.height, self.width)
        )
        
        return x, y

    def _get_d(self):
        x, y = self.xy_pairs
        x = x.unsqueeze(0)                  # size: 1 x HW
        y = y.unsqueeze(0)                  # size: 1 x HW
                                            
        x_y_1 = torch.cat(                  
            (x, y, torch.ones_like(x))      
        ).float()                           # size: 3 x HW
        d = self.K_inv @ x_y_1              # size: (3 x 3) @ (3 x HW) = 3 x HW
        d = (d / d.norm(dim=0)).T           # size: HW x 3
                                                    
        return d                                               


def get_cameras():
    camera_file=f'colmap/{name}/sparse/cameras.txt'

    with open(camera_file) as file:
        camera_lines = file.readlines()[3:]
    
    cameras = {}
    for camera_line in camera_lines:
        (camera_id, model, width, height, *params) = camera_line.strip('\n').split()
        camera_id, width, height = (int(s) for s in (camera_id, width, height))
        params = [float(s) for s in params]
        camera = Camera(camera_id, model, width, height, params)
        cameras[camera_id] = camera
        
    return cameras

cameras: Dict[int, Camera] = get_cameras()

# ImagePose class

In [8]:

class ImagePose:
    def __init__(self, image_id, qw, qx, qy, qz, tx, ty, tz, camera_id, name):
        self.image_id = image_id
        self.r = self._quaternions_to_matrix(qw, qx, qy, qz)
        self.t = torch.tensor([tx, ty, tz])
        self.camera_id = camera_id
        self.name = name
        self.image = self._imgfile_to_tensor()
    
    def _quaternions_to_matrix(self, qw, qx, qy, qz):
        # R = Rotation.from_quat([qw, qx, qy, qz])
        # R = torch.tensor(R.as_matrix())
        
        R = torch.tensor([
            [1 - 2 * (qy * qy + qz * qz),   2 * (qx * qy - qz * qw),        2 * (qx * qz + qy * qw)],
            [2 * (qx * qy + qz * qw),       1 - 2 * (qx * qx + qz * qz),    2 * (qy * qz - qx * qw)],
            [2 * (qx * qz - qy * qw),       2 * (qy * qz + qx * qw),        1 - 2 * (qx * qx + qy * qy)],
        ])
        return R
    
    def _imgfile_to_tensor(self):
        path = IMG_PATH.joinpath(self.name).as_posix()
        image = Image.open(path)
        image_tensor = F.to_tensor(image)
        return image_tensor


# Config and wandb

In [9]:
config = {
    'name': 'initial-run',
    'batch_size': 4096,
    'initial_lr': 5e-4,
    'final_lr': 5e-5,
    'num_iter': int(100e3),
    
    # these two probably shouldn't change. 
    # consider removing from config
    'tn': 0,
    'tf': 1,
    
    'Nc': 64,
    'Nf': 128,
    'Lx': 10, 
    'Ld': 4, 
}
B, TN, TF, NC, NF = config['batch_size'], config['tn'], config['tf'], config['Nc'], config['Nf'] 

In [10]:
import os
os.environ["WANDB_NOTEBOOK_NAME"] = "nerf.ipynb"
# in terminal: 
# >> wandb login --relogin
# paste API key from your account

In [11]:
# run.finish()

In [12]:
run = wandb.init(
    project='nerf',
    name=config['name'],
    reinit=True,
    config=config,
    # id=None, # id of run to resume
    # resume='must', # if want to resume, comment reinit
)

[34m[1mwandb[0m: Currently logged in as: [33mjay-okoro[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [13]:
run.id

'514fbzzo'

# Create dataset and dataloader

In [14]:
if (
    restart
    # or True
):
    '''
    Collects 3D points from points_path and pairs it with various viewing 
    directions and cooresponding colors from those directions using images_path.
    Output format: (x, d), c
        where   x is tensor of 3D location          size: 3 or B x 3
                d is tensor of viewing direction    size: 3 or B x 3
                c is tensor of RGB value            size: 3 or B x 3
    '''
    class TrainDataset(Dataset):
        def __init__(self, images_path=SPARSE_PATH.joinpath('images.txt').as_posix()):
            self.num_pixels = cameras[1].height * cameras[1].width
            
            x, y = cameras[1].xy_pairs                          # size: HW
            d_camera = cameras[1].d_camera                      # size: HW x 3 
            f, cx, cy = cameras[1].params # cx, cy = W/2, H/2

            # values for NDC projection
            # paper uses [-cx, -cy, f] b/c they use (y up, z into camera)
            # but we use [cx, cy, f] b/c we use (y down, z out of camera) 
            # b/c of colmap, but also I prefer colmap's way and would use it again
            scalar = torch.tensor([f/cx, f/cy, 1]) #  = f / [cx, cy, f]
            two_f = 2 * f
            
            with open(images_path) as file:
                image_lines = file.readlines()[4::2]

            o_list, d_list, c_list = [], [], []
            for image_line in image_lines:
                pose = self._get_image_pose(image_line)
                
                # # for pose-dependent camera
                # x, y = cameras[pose.camera_id].xy_pairs       # size: HW
                # d_camera = cameras[pose.camera_id].d_camera   # size: HW x 3 
                # f, cx, cy = cameras[pose.camera_id].params 
                # # cx, cy = W/2, H/2
                
                # find origin of camera and directions from origin 
                # to pixels in world coordinates and get c (colors)
                r, t = pose.r, pose.t.unsqueeze(0)              # size: 3x3, 1x3
                o = -t @ r          # -r.T @ t                  # size: 1 x 3
                d = d_camera @ r    #  r.T @ d_camera           # size: HW x 3
                c = pose.image[:, y, x].T                       # size: HW x 3
                                                                
                # NDC projection                                
                #   links to understand NDC better:
                #   https://www.youtube.com/watch?v=U0_ONQQ5ZNM
                #   https://yconquesty.github.io/blog/ml/nerf/nerf_ndc.html#analysis
                o = o / o[:,2:3]                                
                d = d / d[:,2:3]                                
                                                                
                o[:,2] += two_f                                 
                d -= o                                          
                                                                
                o *= scalar                                     
                d *= scalar                                     

                # resize to use broadcasting for stratified 
                # sampling at training/inference stage.
                # doing it here is more efficient
                o = o.unsqueeze(1)                              # size: 1x1x3
                d = d.unsqueeze(1)                              # size: HWx1x3
                                                                
                # add to dataset                                
                o_list.append(o)                                
                d_list.append(d)                                
                c_list.append(c)                                
                                                                # A = num_images
            self.o_tensor = torch.cat(o_list)                   # size: A x 1 x 3
            self.d_tensor = torch.cat(d_list)                   # size: HWA x 1 x 3
            self.c_tensor = torch.cat(c_list)                   # size: HWA x 3
                        
        # converts line in image.txt to ImagePose object
        def _get_image_pose(self, image_line):
            (image_id, qw, qx, qy, qz, tx, ty, tz, camera_id, name
            ) = image_line.strip('\n').split()
            
            image_id, camera_id = int(image_id), int(camera_id)
            (qw, qx, qy, qz, tx, ty, tz, 
            ) = (float(s) for s in (qw, qx, qy, qz, tx, ty, tz))
            
            pose = ImagePose(image_id, qw, qx, qy, qz, tx, ty, tz, camera_id, name)
            
            return pose

        def __len__(self):
            return len(self.d_tensor)

        def __getitem__(self, idx):
            o = self.o_tensor[idx//self.num_pixels]             # size: 1 x 3
            d = self.d_tensor[idx]                              # size: 1 x 3
            c = self.c_tensor[idx]                              # size: 3
            return o, d, c
        
    trainset = TrainDataset()

In [15]:
trainloader = DataLoader(
    trainset, 
    batch_size=config['batch_size'],
    shuffle=True,
    drop_last=True,
)
num_epochs = config['num_iter'] * config['batch_size'] // len(trainloader)

# Model 

In [16]:
class MLP(nn.Module): # tested
    def __init__(self, in_feat, out_feat, activation=nn.ReLU()):
        super().__init__()     
        self.f = nn.Sequential(
            nn.Linear(in_feat, out_feat),
            activation
        )

    def forward(self, x):
        return self.f(x)

class PositionalEncoding(nn.Module): # tested

    def __init__(self, L):
        super().__init__()

        self.L = L
        self.omega = 2**torch.arange(0, L, 1/2).int() * torch.pi

    def forward(self, x):                                       # size: BxNx3
        gamma = x.unsqueeze(-1) * self.omega                    # size: (BxNx3x1) * (2L) -> BxNx3x2L
        gamma = einops.rearrange([torch.sin(gamma[...,::2]),
                                  torch.cos(gamma[...,1::2])], 
                                 't b n h w -> b n h (w t)')
        gamma = gamma.flatten(-2)                               # size: BxNx3x2L = BxNx6L
        return gamma 

class Nerf(nn.Module): # tested
    def __init__(self, Lx, Ld):
        super().__init__()
        self.pos_enc_x = PositionalEncoding(Lx)
        self.pos_enc_d = PositionalEncoding(Ld)
        
        self.bfr_x_res = nn.Sequential(
            nn.ReLU(),
            MLP(Lx*6,256),
            MLP(256,256),
            MLP(256,256),
            MLP(256,256),
            MLP(256,256),
        )
        self.bfr_d_in = nn.Sequential(
            MLP(256+Lx*6,256),
            MLP(256,256),
            MLP(256,256),
            MLP(256,257,nn.Identity()),
        )
        self.aft_d_in = nn.Sequential(
            MLP(256+Ld*6,128),
            MLP(128,3,nn.Sigmoid()),
        )

    def forward(self, x, d):                        #   size: BxNx3, BxNx3
        gamma_x = self.pos_enc_x(x)                 #   size: BxNx6Lx
        out = self.bfr_x_res(gamma_x)               #   size: BxNx256
        out = torch.cat((out, gamma_x), -1)         #   size: BxNx(256+6Lx)
        out = self.bfr_d_in(out)                    #   size: BxNx257
        sig, out = out[:,:,0:1], out[:,:,1:]        #   size: BxNx1, BxNx256
        gamma_d = self.pos_enc_d(d)                 #   size: BxNx6Ld
        out = torch.cat((out, gamma_d), -1)         #   size: BxNx(256+6Ld) torch.Size([1024, 63, 257]) torch.Size([1024, 64, 24])
        rgb = self.aft_d_in(out)                    #   size: BxNx3
        
        sig += torch.randn_like(sig) # paper says this is helpful for real scenes               
        sig = sig.relu()                                 

        return rgb, sig                             #   size: BxNx3, BxNx1

model = Nerf(
    config['Lx'],
    config['Ld']
)

# Loss Function

In [17]:
criterion = nn.MSELoss(reduction='sum')

# Optimizer

In [18]:
optimizer = optim.Adam(model.parameters(), lr=config['initial_lr'], eps=1.e-7)
gamma = - (np.log(config['final_lr']) - np.log(config['initial_lr'])) / num_epochs
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

# Train

In [19]:
delta_i_mask = torch.ones(config['Nc'], 1, requires_grad=False)
delta_i_mask_cat = torch.ones(config['Nc'] + config['Nf'], 1, requires_grad=False)
delta_i_mask[-1] = 0
delta_i_mask_cat[-1] = 0

def color_weights_and_t_i(
    tn: float, 
    tf: float, 
    N: int, 
    i_random: Tensor, 
    o: Tensor, 
    d: Tensor, 
    model: Nerf, 
    t_i_cat: Tensor = None
):
    t_i = tn + i_random * (tf - tn) / N                                 # size: BxN1x1 
    if t_i_cat is not None: 
        # concat fine and coarse pts using t parameter
        # t's must be in right order or integral will be wrong
        t_i = torch.cat((t_i, t_i_cat), -2).sort(-2)[0]                 # size: cat((BxN1x1),(BxN2x1), -2) -> Bx(N1+N2=N)x1
    
    x = o + t_i * d # in NDC coords :)                                  # size: (Bx1x3) + (BxNx1) * (Bx1x3) -> BxNx3
    d = d.tile((1, x.shape[-2], 1))                                     # size: BxNx3
    
    c_i, sigma_i = model(x, d)                                          # size: input: BxNx3, Bx3 | output: BxNx3, BxNx1
    c_i: Tuple = c_i
    sigma_i: Tuple = sigma_i

    delta_i = (t_i.roll(-1, -2) - t_i) * (delta_i_mask 
                                          if t_i_cat is None 
                                          else delta_i_mask_cat)        # size: N x 1

    neg_dlt_sig_i: Tensor = -delta_i * sigma_i  # [a, b, ..., 0]        # size: (N x 1) * (B x N x 1) -> B x N x 1
    neg_dlt_sig_im1 = neg_dlt_sig_i.roll(1, -2) # [0, a, b, ...]        # size: B x N x 1
    T_i = neg_dlt_sig_im1.cumsum(-2)            # [0, A, B, ...]        # size: B x N x 1

    w = T_i * (1 - torch.exp(neg_dlt_sig_i))                            # size: (BxNx1) * (BxNx1) -> BxNx1
    c = (w * c_i).sum(-2)                                               # size: ((BxNx1) * (BxNx3) -> BxNx3).sum(-2) -> Bx3
    
    return c, w, t_i



In [20]:
w_hat_cum_mask = torch.ones(B, config['Nc'], 1, requires_grad=False)
w_hat_cum_mask[:,0] = 0

def get_i_fine(
    w: Tensor, 
    # b: int, 
    Nf: int, 
    i: Tensor
):
    w_hat = w / w.sum(-2, True)                                     # size: (BxNcx1) / ((BxNcx1).sum(-2, True) -> Bx1x1) -> BxNcx1
    w_hat_cum = w_hat.cumsum(-2)                                    # size: BxNcx1
    
    u = torch.rand(B, 1, Nf)                                        # size: Bx1xNf
    idx = (u > w_hat_cum).sum(-2) # 0≤u≤w0: [F,F,F,...].sum: idx=0  # size: ((Bx1xNf) > (BxNcx1) -> BxNcxNf).sum(-2) -> BxNf
    
    w_hat_cum = w_hat_cum.roll(1, -2) * w_hat_cum_mask
    i_fine = (u - w_hat_cum) / w_hat + i # inv transf sample funcs  # size: ((Bx1xNf) - (BxNcx1) -> BxNcxNf) / (BxNcx1) + (Ncx1) -> BxNcxNf
    
    q = idx.flatten()                                               # size: BNf
    p, r = torch.unravel_index(torch.arange(B * Nf), (B, Nf))       # size: BNf, BNf
    i_fine = i_fine[p, q, r].view(B, Nf, 1) # pick correct function # size: BxNfx1
    
    return i_fine




In [21]:
def train(model, trainloader):
    torch.manual_seed(SEED)

    # t: NDC coords for ray parameter
    # Nc: num bins along ray at first
    # Nf: num new samples based on Nc dist.
    i = torch.arange(NC).unsqueeze(-1)                                              # size: Ncx1
    for epoch in tqdm(range(num_epochs), desc='epochs'):
        for o, d, c in tqdm(trainloader, desc='minibatches', leave=False):          # size: Bx1x3, Bx1x3, Bx3        
            optimizer.zero_grad()   
            b = o.shape[0] # might change in last batch 

            # get coarse colors 
            i_coarse = i + torch.rand(b, NC, 1)                                     # size: (Ncx1) + (BxNcx1) -> BxNcx1
            c_coarse, w, t_i = color_weights_and_t_i(TN, TF, NC, 
                                                     i_coarse, 
                                                     o, d, model)                   # Bx3, BxNcx1

            # get fine colors   
            i_fine = get_i_fine(w, NF, i)                                           # size: BxNcx1
            c_fine, _, _ = color_weights_and_t_i(TN, TF, NF, i_fine, 
                                                 o, d, model, t_i)                  # size: Bx3, BxNcx1
            
            loss: Tensor = criterion(c_coarse, c) + criterion(c_fine, c)
            loss.backward()
            optimizer.step()
            
            # not logging accuracy b/c continuous values 
            # may be close but will rarely equal each other
            wandb.log({'loss': loss}) 
        scheduler.step()


In [22]:
# torch.autograd.set_detect_anomaly(False)
train(model, trainloader)

epochs:   0%|          | 0/8256 [00:00<?, ?it/s]

minibatches:   0%|          | 0/49612 [00:00<?, ?it/s]

# Save Model

In [None]:
model_path = "model.pth"
torch.save(model, model_path)
wandb.save(model_path)
print("Model saved to Weights and Biases!")

Model saved to Weights and Biases!


In [None]:
wandb.finish()

0,1
loss,▁

0,1
loss,6769.81689


Run stopped
