In [1]:
import itertools

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
 
from nerfy.renderer import RadianceFieldRenderer, RadianceFieldRendererConfig
from nerfy.dataset import get_nerf_datasets, TrivialCollator
from nerfy.utils import sample_images_at_mc_locs, calc_mse, calc_psnr

import wandb

In [3]:
# Start the Weights & Biases run. The training loop calls wandb.log(), which
# needs an active run; with this call commented out the notebook does not
# survive Restart Kernel -> Run All.
wandb.init(project='nerf')

---

## Data

In [4]:
# Load the 'lego' scene, requesting 200x200 images; the loader hands back
# three datasets, unpacked here as the train / val / test splits.
train_dataset, val_dataset, test_dataset = get_nerf_datasets(dataset_name='lego', image_size=(200, 200))

Loading dataset lego, image size=(200, 200) ...
Rescaling dataset (factor=0.25)


In [5]:
# Every split is served one sample per batch; only the training loader
# shuffles. Each loader receives its own TrivialCollator instance.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=TrivialCollator(),
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=1,
    collate_fn=TrivialCollator(),
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    collate_fn=TrivialCollator(),
)

---

## Model

In [6]:
# Device selection: keep using the second GPU when the host has one (the
# original setting), otherwise fall back to the first GPU, then to the CPU,
# so the notebook still runs on single-GPU and CPU-only machines.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Renderer configuration. image_width / image_height match the
# image_size=(200, 200) requested from get_nerf_datasets.
config = RadianceFieldRendererConfig(
    n_rays_per_image=5_000,
    n_pts_per_ray=64,
    image_width=200, image_height=200
)

# Build the renderer on the chosen device and put it in training mode.
model = RadianceFieldRenderer(config).to(device).train()
# Mean-squared error between rendered and ground-truth colours.
criterion = nn.MSELoss()
# Adam over all model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, eps=1e-7)

---

## Train

In [11]:
import itertools
import imageio
from tqdm import tqdm, trange

In [9]:
# Number of iterations of the outer training loop (`for epoch in trange(EPOCH)`);
# each iteration runs one pass over the train, val and test dataloaders.
EPOCH = 100

In [None]:
# Main loop. Each epoch:
#   1. runs one optimisation step per training batch and logs train metrics,
#   2. logs val metrics for every validation batch (no parameter updates),
#   3. renders every test camera, saves the frames as a GIF and logs it.
for epoch in trange(EPOCH):

    # ---- 1. Training --------------------------------------------------
    model.train()
    for batch in train_dataloader:

        batch = batch.to(device)
        (rgb, weights), ray_bundle = model.renderer(batch.camera, model.implicit_function)

        # Ground-truth colours read from the image at the locations in
        # ray_bundle.xys, so they line up with the rendered `rgb`.
        sampled_image = sample_images_at_mc_locs(batch.image, ray_bundle.xys)
        # (prediction, target) order; MSE is symmetric so the value is unchanged.
        loss = criterion(rgb, sampled_image)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Detach before moving to the CPU so the copy is not tracked by autograd.
        wandb.log({
            'train/psnr': calc_psnr(rgb.detach().cpu(), sampled_image.detach().cpu()),
            'train/mse': loss.item(),
        })

    # ---- 2. Validation ------------------------------------------------
    # NOTE(review): the model is still in train mode here (eval() is only
    # called before the test loop below). Presumably deliberate, so that the
    # renderer behaves as it does during training - confirm.
    for batch in val_dataloader:

        batch = batch.to(device)
        # Nothing in validation needs gradients, so the whole metric
        # computation lives inside no_grad.
        with torch.no_grad():
            (rgb, weights), ray_bundle = model.renderer(batch.camera, model.implicit_function)
            sampled_image = sample_images_at_mc_locs(batch.image, ray_bundle.xys)
            loss = criterion(rgb, sampled_image)

        wandb.log({
            'val/psnr': calc_psnr(rgb.detach().cpu(), sampled_image.detach().cpu()),
            'val/mse': loss.item(),
        })

    # ---- 3. Test rendering --------------------------------------------
    model.eval()
    test_rendered_images = []
    for batch in test_dataloader:

        batch = batch.to(device)
        with torch.no_grad():
            (rgb, _), _ = model.renderer(batch.camera, model.implicit_function)

        # Convert the rendered colours to an (H, W, C) uint8 frame.
        # Clamp to [0, 1] first: after scaling by 255, anything outside that
        # range does not fit in uint8 and would corrupt the frame.
        # reshape (rather than view) does not depend on the memory layout.
        frame = (rgb.clamp(0.0, 1.0).cpu() * 255) \
            .squeeze(dim=0) \
            .reshape(config.image_height, config.image_width, -1) \
            .to(torch.uint8) \
            .numpy()
        test_rendered_images.append(frame)

    # One GIF per epoch, built from the test frames in dataloader order.
    imageio.mimsave(f'lego_{epoch}.gif', test_rendered_images, fps=30)
    wandb.log({
        'test/video': wandb.Video(f'lego_{epoch}.gif', fps=30),
        'epoch': epoch
    })
    

 27%|██▋       | 27/100 [53:56<2:26:37, 120.51s/it]