Licensed under the Apache License, Version 2.0

# yobo x NGP, interactive training/rendering with Multiscope (v3)

In [None]:
import os
import time

import flax
from flax.training import checkpoints
import gin
import jax
import jax.extend
import optax
from jax import random
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import functools

import mediapy as media
from six.moves import reload_module
from colabtools import adhoc_import, frontend
from colabtools.interactive_widgets import ProgressIter

# NOTE(review): `multiscope` is used here (and in the renderer cells) but is not
# imported anywhere in the visible notebook, so this line raises NameError as
# written. TODO: add the multiscope import to this cell.
# Start the Multiscope server; `port` is later passed to
# `multiscope.get_dashboard_url` to open the dashboard.
port = multiscope.start_server()
# No renderer yet. The renderer cell reuses `renderer.controller.spl` from a
# previous renderer whenever this is not None.
renderer = None

In [None]:
# Edit the yobo source files, then re-run this cell: the modules below are
# re-imported and passed through reload_module, so changes take effect without
# relaunching or restarting the runtime.

# Delete every device buffer still alive from a previous run of the notebook so
# that re-running does not accumulate accelerator memory.
# NOTE(review): newer JAX releases expose `jax.live_arrays()` for this; confirm
# `live_buffers()` is still available in the JAX version in use.
backend = jax.extend.backend.get_backend()
for buf in backend.live_buffers():
  buf.delete()

# Drop all gin bindings from previous runs. `gin.clear_config()` also unlocks
# the config. A bare `gin.unlock_config()` call was removed: it is a context
# manager, and calling it without `with` never executes its body, so it did
# nothing.
gin.clear_config()


# Project modules. Each module followed by `reload_module(...)` is re-executed on
# every run of this cell, so edits to its source are picked up without a
# runtime restart. `configs`, `grid_utils` and `vis` are imported but NOT
# reloaded.
# NOTE(review): presumably deliberate (e.g. to avoid re-registering gin
# configurables defined in `configs`); confirm before adding reloads for them.
from google_research.yobo.internal import configs
from google_research.yobo.internal import grid_utils
from google_research.yobo.internal import camera_utils
camera_utils = reload_module(camera_utils)
from google_research.yobo.internal import datasets
datasets = reload_module(datasets)
# NOTE: this binds the project `math` module and shadows Python's stdlib
# `math` in the notebook namespace.
from google_research.yobo.internal import math
math = reload_module(math)
from google_research.yobo.internal import render
render = reload_module(render)
from google_research.yobo.internal import coord
coord = reload_module(coord)
from google_research.yobo.internal import sample_net_utils
sample_net_utils = reload_module(sample_net_utils)
from google_research.yobo.internal.inverse_render import render_utils
render_utils = reload_module(render_utils)
from google_research.yobo.internal import models
models = reload_module(models)
from google_research.yobo.internal import sampling
sampling = reload_module(sampling)
from google_research.yobo.internal import geometry
geometry = reload_module(geometry)
from google_research.yobo.internal import integration
integration = reload_module(integration)
from google_research.yobo.internal import shading
shading = reload_module(shading)
from google_research.yobo.internal import material
material = reload_module(material)
from google_research.yobo.internal import stepfun
stepfun = reload_module(stepfun)
from google_research.yobo.internal import train_utils
train_utils = reload_module(train_utils)
from google_research.yobo.internal import loss_utils
loss_utils = reload_module(loss_utils)
from google_research.yobo.internal import utils
utils = reload_module(utils)
from google_research.yobo.internal import vis

# Interactive renderer, reloaded after the internal modules above.
from google_research.yobo import multiscope_renderer
multiscope_renderer = reload_module(multiscope_renderer)


# Make gin resolve config file names against the depot root and the yobo
# config directory. Paths already present in gin's (private)
# `_LOCATION_PREFIXES` list are skipped, so re-running this cell does not
# register them twice.
depot_base = ''
config_base = depot_base + 'third_party/google_research/google_research/yobo/configs/'

for search_path in (depot_base, config_base):
  if search_path in gin.config._LOCATION_PREFIXES:
    continue
  gin.add_config_file_search_path(search_path)

# Load configs

## Dataset config


In [None]:
# Dataset / run hyperparameters. Everything here is plain Python; the config
# cell turns these values into gin bindings.

# Checkpoint to start from. None means training from scratch, which selects
# the "from scratch" learning rates and proposal annealing below.
ckpt_dir = None

# Per-scene overrides (e.g. `ckpt_dir`) go under the matching heading.
# Cornelly

# Lego small light

# Scraperbikes

# Populated by the config cell.
config = None

# Model variants.
use_material = True
use_light_sampler = False

# The cache is only optimised when the material model is off, and material
# resampling is only meaningful when the material model is on.
optimize_cache = not use_material
resample_material = use_material
render_variate = True

# Ray jitter is disabled in every configuration.
jitter_rays = 0
# Proposal-sampler annealing only when training from scratch.
anneal_slope = 10.0 if ckpt_dir is None else 0.0

num_secondary_samples = 8 if resample_material else 2

# Shrink the batch size and learning rates, and stretch the schedule, by 4x
# when starting from a checkpoint or when training the light sampler without
# the material model.
scale_fac = 4 if ckpt_dir is not None else 1
scale_fac = 4 if use_light_sampler and not use_material else scale_fac

batch_size = 65536 // scale_fac
grad_accum_steps = 1
max_steps = 25000 * scale_fac

# Global learning-rate schedule.
lr_init = 0.01 / scale_fac
lr_final = 0.001 / scale_fac
lr_delay_steps = 2500 * scale_fac

# Per-module schedules, passed to Config.extra_opt_params. Starting from a
# checkpoint uses smaller rates and no warm-up delay.
lr_init_cache = (0.01 if ckpt_dir is None else 0.0005) / scale_fac
lr_final_cache = (0.001 if ckpt_dir is None else 0.00005) / scale_fac
lr_delay_steps_cache = (2500 if ckpt_dir is None else 0) * scale_fac

lr_init_material = (0.005 if ckpt_dir is None else 0.0005) / scale_fac
lr_final_material = (0.0005 if ckpt_dir is None else 0.00005) / scale_fac
lr_delay_steps_material = (2500 if ckpt_dir is None else 0) * scale_fac

lr_init_light = (0.001 if not use_material else 0.0005) / scale_fac
lr_final_light = (0.0001 if not use_material else 0.00005) / scale_fac
# The light sampler never uses a warm-up delay.
lr_delay_steps_light = 0

extra_opt_params = {
    'Cache': {
        'lr_delay_steps': lr_delay_steps_cache,
        # A zero learning rate keeps the cache fixed unless optimize_cache.
        'lr_final': lr_final_cache if optimize_cache else 0.0,
        'lr_init': lr_init_cache if optimize_cache else 0.0,
    },
    'MaterialShader': {
        'lr_delay_steps': lr_delay_steps_material,
        'lr_final': lr_final_material,
        'lr_init': lr_init_material,
    },
    'LightSampler': {
        'lr_delay_steps': lr_delay_steps_light,
        'lr_final': lr_final_light,
        'lr_init': lr_init_light,
    },
}

## Model config

In [None]:
# Config
# Choose the gin config for the scene to train. Exactly one assignment must be
# active: without one, `config_files` is undefined and the next line fails.
config_files = ['blender_ngp_yobo_material_cornelly.gin']
# config_files = ['blender_ngp_yobo_material_lego.gin']
# config_files = ['real_ngp_yobo_material_scraperbikes.gin']

gin_configs = [config_base + f for f in config_files]

# Overrides from the hyperparameter cell. gin parses these bindings after the
# config files, so they override values set in the .gin files.
gin_bindings = [
  # NOTE(review): with ckpt_dir = None this binds the *string* "None" rather
  # than Python None, and later restore calls receive that string. Left as-is
  # because the restore cells may rely on it; confirm this is intended.
  f'Config.ckpt_dir = "{ckpt_dir}"',
  f'Config.max_steps = {max_steps * grad_accum_steps}',
  f'Config.batch_size = {batch_size}',
  f'Config.grad_accum_steps = {grad_accum_steps}',
  f'Config.lr_init = {lr_init}',
  f'Config.lr_final = {lr_final}',
  f'Config.lr_delay_steps = {lr_delay_steps}',
  f'Config.extra_opt_params = {extra_opt_params}',
  f'ProposalVolumeSampler.anneal_slope = {anneal_slope}',
  f'MaterialModel.use_material = {use_material}',
  f'MaterialModel.use_light_sampler = {use_light_sampler}',
  f'MaterialModel.resample_material = {resample_material}',
  f'MaterialModel.render_variate = {render_variate}',
  f'MaterialMLP.num_secondary_samples = {num_secondary_samples}',
  f'MaterialMLP.render_num_secondary_samples = {num_secondary_samples}',
]

# Start from an empty gin state, parse files + overrides (ignoring bindings for
# configurables that are not registered), then build and print the config.
gin.clear_config()
gin.parse_config_files_and_bindings(gin_configs, gin_bindings, skip_unknown=True)
config = configs.Config()
print(gin.config_str())

# Load dataset

In [None]:
# Load the training split from the configured data directory, then plot its
# camera poses so the capture layout can be sanity-checked.
dataset = datasets.load_dataset('train', config.data_dir, config)
train_camtoworlds = dataset.camtoworlds
multiscope_renderer.plot_poses(train_camtoworlds, eps=.05)

In [None]:
# Largest value across all training images, e.g. to check the intensity range.
max_pixel_value = jnp.max(dataset.images)
print(max_pixel_value)

# Load model

In [None]:
# Re-parse the gin config before building the model.
gin.clear_config()
gin.parse_config_files_and_bindings(gin_configs, gin_bindings, skip_unknown=True)
config = configs.Config()

# Build the model, its training state and the render/train step functions,
# seeded from a random integer in [0, 1000).
# Optional: dataset.reload_mesh(config)
init_key = random.PRNGKey(np.random.randint(1000))
model, train_state, render_eval_pfn, train_pstep, _ = train_utils.setup_model(
    config, init_key, dataset
)

# Restore selected sub-trees of the state (presumably from config.ckpt_dir):
# always the cache, plus the light sampler when both the material model and
# the light sampler are enabled. Names map onto themselves.
restore_prefixes = ['Cache']
if use_material and use_light_sampler:
  restore_prefixes.append('LightSampler')

train_state = train_utils.restore_partial_checkpoint(
    config,
    train_state,
    prefixes=restore_prefixes,
    replace_dict={
        'Cache': 'Cache',
        'LightSampler': 'LightSampler',
    },
)

In [None]:
# Optional: replace the entire training state with the latest checkpoint found
# in config.ckpt_dir (flax returns `target` unchanged if none exists).
train_state = checkpoints.restore_checkpoint(ckpt_dir=config.ckpt_dir, target=train_state)

# Training

## Model Training Loop

In [None]:
# Build the interactive Multiscope renderer for the model. When this cell is
# re-run, the camera spline of the previous renderer is kept.
multiscope.reset()
if renderer is None:
  spl = multiscope_renderer.Spliner()
else:
  spl = renderer.controller.spl

# Preview resolution: the dataset width downsampled by 2 (by 8 for LLFF-style
# loaders), rounded down to a multiple of 16.
scale_factor = 8 if 'llff' in config.dataset_loader else 2
width = ((dataset.width // scale_factor) // 16) * 16
# Third hwf entry, rescaled to the preview width from pixtocams[0, 0, 0]
# (presumably the inverse focal length).
# NOTE(review): the preview is square (height == width) regardless of the
# dataset's aspect ratio — confirm this is intended.
focal = (float(width) / dataset.width) / dataset.pixtocams[0, 0, 0]

renderer = multiscope_renderer.MultiscopeRenderer(
    dataset,
    config,
    model,
    train_state,
    train_pstep,
    spl,
    hwf_init=(width, width, focal),
)

# Training is switched on immediately; set this to False to start with it off.
renderer.training = True

# A single step up front JIT-compiles the render function.
renderer.step()

In [None]:
# Open the Multiscope dashboard, enable training, and step the renderer
# forever; interrupt the cell to stop.
frontend.OpenUrl(multiscope.get_dashboard_url(port))
renderer.training = True
while True:
  renderer.step()

# Checkpoint

In [None]:
# Save the current training state to
# <scene>/<YYYY-MM-DD>/<model size>[/light_sampler][/material]
# (a relative path).
from datetime import date

today = date.today()
# Strip trailing slashes first so a data_dir like '.../scene/' still yields
# 'scene' rather than an empty string.
scene_name = config.data_dir.rstrip('/').split('/')[-1]

# Model size is inferred from the name of the first gin config file.
if 'tiny' in config_files[0]:
  model_suffix = 'tiny'
elif 'small' in config_files[0]:
  model_suffix = 'small'
else:
  model_suffix = 'large'

ckpt_dir = f'{scene_name}/{today}/{model_suffix}'
if use_light_sampler:
  ckpt_dir += '/light_sampler'
if use_material:
  ckpt_dir += '/material'

# renderer.state is replicated across devices; keep a single copy.
train_state = flax.jax_utils.unreplicate(renderer.state)
# Read the step from the already-unreplicated state as a plain int, instead of
# unreplicating renderer.state.step a second time.
train_step = int(train_state.step)
checkpoints.save_checkpoint(
    ckpt_dir=ckpt_dir, target=train_state, step=train_step, overwrite=True
)