# SMERF Demo

This notebook demonstrates how to train and render a (prebaked) SMERF model. 
It should produce identical results to `scripts/demo.sh`.

The notebook requires a CUDA-capable GPU with at least 12 GB of VRAM, such as an RTX 3080 Ti.
See `README.md` for instructions.

Licensed under the Apache License, Version 2.0

## Setup

This section imports necessary libraries and initializes configuration.

In [None]:
# @title Imports
import datetime
import functools
import os
import sys
import time

from absl import app
from absl import flags
import camp_zipnerf.internal
from etils import ecolab
from etils import epath
import flax
from flax.training import checkpoints
import gin
import jax
import jax.numpy as jnp
import mediapy as media
import numpy as np
import pycolmap
import smerf.internal

# Let Gin configurables be re-registered when notebook cells are re-run.
gin.enter_interactive_mode()
# Activate etils' ecolab auto-display. Later cells end expressions with
# `;`-suffix display directives (e.g. `pcameras;s`); without auto-display,
# plain IPython parses `;s` as a second statement and raises NameError on `s`.
ecolab.auto_display()

In [None]:
# @title Setup config

# Create a timestamp (UTC) for this execution. `datetime.utcnow()` is
# deprecated since Python 3.12; an aware UTC datetime formats identically.
TIMESTAMP = datetime.datetime.now(datetime.timezone.utc).strftime('%Y%m%d_%H%M')

# Setup Gin-related flags.
# define_common_flags() presumably registers the absl flags assigned below
# (gin_configs, gin_bindings). os.chdir() only runs if that call succeeds, so
# when this cell is re-run and flag registration fails (presumably because the
# flags already exist), the error is printed and the working directory is not
# moved up a second time.
try:
  camp_zipnerf.internal.configs.define_common_flags()
  # Jupyter will set the working directory to notebooks/. This sets it to the project root.
  os.chdir('../')
except Exception as err:
  print(err)

# Gin config files (paths relative to the project root).
flags.FLAGS.gin_configs = [
    'configs/models/smerf.gin',
    'configs/mipnerf360/bicycle.gin',
    'configs/mipnerf360/extras.gin',
    'configs/mipnerf360/rtx3080ti.gin',
]
# Extra Gin binding: give SMERF's checkpoint_dir a per-run directory name.
flags.FLAGS.gin_bindings = f"""
smerf.internal.configs.Config.checkpoint_dir = 'checkpoints/{TIMESTAMP}-notebook'
"""

# Run absl's (private) flag initialization on argv[0] only, so the Jupyter
# kernel's own command-line arguments are not parsed as flags. There is no
# direct way to pass Gin configs to load_config(); presumably it reads them
# from the flags assigned above.
app._run_init(sys.argv[:1], flags.FLAGS)

# Parse Gin configs. clear_config() drops any Gin state left over from a
# previous run of this cell before the configs are loaded again.
gin.clear_config()
# `smerf_config` configures the SMERF student; `config` is the config used for
# the camp_zipnerf teacher and the datasets in later cells.
# NOTE(review): meaning of the positional `False` argument is not visible here.
smerf_config, config = smerf.internal.distill.load_config(False)

# Echo the key resolved settings for this run.
print(f'{TIMESTAMP=}')
print(f'{smerf_config.checkpoint_dir=}')
print(f'{smerf_config.baking_checkpoint_dir=}')
print(f'{smerf_config.distill_teacher_ckpt_dir=}')
print(f'{config.data_dir=}')
print(f'{config.batch_size=}')

In [None]:
# @title Print Gin config
# Dump every Gin binding that is currently in effect.
print(
    gin.config_str(),
)

In [None]:
# @title Print GPU RAM usage
def memory_stats(device):
  """Formats the memory usage of one JAX device as a single line.

  Args:
    device: A JAX device, as returned by `jax.devices()`.

  Returns:
    A summary such as "GPU_0: 1.00 GiB of 4.00 GiB (25.0%)". Devices whose
    `memory_stats()` returns nothing, or that report no positive
    `bytes_limit`, get a reduced summary instead of raising.
  """
  stats = device.memory_stats()
  # Some backends do not report memory statistics and return None.
  if not stats or 'bytes_in_use' not in stats:
    return f'{device}: memory stats unavailable'
  gb_in_use = stats['bytes_in_use'] / 2 ** 30
  gb_available = stats.get('bytes_limit', 0) / 2 ** 30
  if gb_available <= 0:
    # No usable limit: avoid dividing by zero.
    return f"GPU_{device.id}: {gb_in_use:0.2f} GiB in use (limit unknown)"
  percent_bytes_in_use = 100 * gb_in_use / gb_available
  return f"GPU_{device.id}: {gb_in_use:0.2f} GiB of {gb_available:0.2f} GiB ({percent_bytes_in_use:0.1f}%)"

# Same result as mapping over the flat device list, but also works on JAX
# versions that predate `jax.tree.map`.
[memory_stats(device) for device in jax.devices()]

## Dataset

This section loads the training dataset into memory.
Make sure that `config.data_dir` points to a directory containing the data for a single scene.

In [None]:
# @title Load Train
# Load the training split from `config.data_dir`.
dataset = camp_zipnerf.internal.datasets.load_dataset('train', config.data_dir, config)
# Ray source for training: the train loop wraps it with prefetch_to_device()
# and draws batches from it with next().
raybatcher = camp_zipnerf.internal.datasets.RayBatcher(dataset)
# NOTE(review): in plain IPython a trailing `;` only suppresses output;
# presumably this relies on etils ecolab auto-display to show the dataset.
dataset;

In [None]:
# @title Construct camera parameters
def np_to_jax(x):
  """Converts a NumPy array to a JAX array; returns any other value as-is."""
  if isinstance(x, np.ndarray):
    return jnp.array(x)
  return x


# Training cameras as JAX arrays, then replicated once per local device.
cameras = jax.tree_util.tree_map(np_to_jax, dataset.get_train_cameras(config))
pcameras = flax.jax_utils.replicate(cameras)
pcameras;s  # intrinsics, extrinsics, ...

## Teacher

This section initializes the teacher model. 
Make sure that `config.checkpoint_dir` points to a directory with a pretrained `camp_zipnerf` model checkpoint.

To verify that the checkpoint loaded successfully, a single training camera is rendered.

In [None]:
# @title Setup Teacher
# Build the camp_zipnerf teacher from `config`. setup_model returns five
# values; only the model, its initial train state, and the eval render
# function (later called with replicated params) are kept. Pretrained weights
# are restored from disk in the next cell.
rng = jax.random.PRNGKey(config.jax_rng_seed)
teacher_model, teacher_state, teacher_render_eval_pfn, _, _ = camp_zipnerf.internal.train_utils.setup_model(
    config, rng, dataset=dataset
)
# NOTE(review): the `;s` suffix presumably relies on etils ecolab auto-display;
# in plain IPython `s` is an undefined name.
teacher_state.params;s

In [None]:
# @title Reload teacher's state from disk
# Load the pretrained teacher weights from `config.checkpoint_dir`.
teacher_state = checkpoints.restore_checkpoint(config.checkpoint_dir, teacher_state)
print('step:', teacher_state.step)
# flax's restore_checkpoint returns the target unchanged when no checkpoint is
# found, which would silently leave the teacher at its initial (presumably
# step-0, randomly initialized) state. Make that failure visible.
if int(teacher_state.step) == 0:
  print(
      'WARNING: no teacher checkpoint appears to have been restored from '
      f'{config.checkpoint_dir!r}; the teacher is still at step 0.'
  )

In [None]:
# @title Replicate state across devices.
# Replicate the restored teacher state across local devices; later cells pass
# `teacher_pstate` (and `teacher_pstate.params`) to the parallel render/train
# functions.
teacher_pstate = flax.jax_utils.replicate(teacher_state)
# NOTE(review): `teacher_pvariables` is not referenced by any later cell in
# this notebook; confirm it is still needed.
teacher_pvariables = teacher_pstate.params

In [None]:
# @title Render a frame.

CAM_IDXS = [0]  # Which train images to render.

def main():
  """Renders each camera in CAM_IDXS with the teacher and shows it next to GT.

  For every camera, prints the render time and the output resolution. The
  first call presumably includes compilation time, so it is not
  representative of steady-state speed.
  """
  for cam_idx in CAM_IDXS:
    # Prepare rays
    rays = smerf.internal.datasets.cam_to_rays(dataset, cam_idx, xnp=jnp)

    # Render rays. time.perf_counter() is monotonic, so the measurement is not
    # affected by system clock adjustments the way time.time() is.
    start = time.perf_counter()
    teacher_rendering = camp_zipnerf.internal.models.render_image(
        functools.partial(
            teacher_render_eval_pfn,
            teacher_pstate.params,
            1.0,
            None,  # No cameras needed
        ),
        rays=rays,
        rng=None,
        config=config,
        return_all_levels=True,
    )
    # Copy results to host; this blocks on outstanding device work, so the
    # transfer is included in the measured time.
    teacher_rendering = jax.device_get(teacher_rendering)

    # Print elapsed time
    elapsed = time.perf_counter() - start
    print(f'Elapsed time: {elapsed:0.2f}')
    print(f'Resolution: {teacher_rendering["rgb"].shape}')

    media.show_images(
        {'gt': dataset.images[cam_idx], 'teacher': teacher_rendering['rgb']},
        ylabel=f'{cam_idx=}',
        width=800,
    )


main()

## SMERF

### Setup

This section loads the test dataset, finalizes SMERF's config, and initializes the SMERF model.

In [None]:
# @title Load test dataset
# Load the test split. It feeds initialize_grid_config() in the next cell and
# supplies the preview camera rendered periodically by the train loop.
test_dataset = camp_zipnerf.internal.datasets.load_dataset('test', config.data_dir, config)
# NOTE(review): `test_raybatcher` is not referenced by any later cell in this
# notebook; confirm it is still needed.
test_raybatcher = camp_zipnerf.internal.datasets.RayBatcher(test_dataset)

In [None]:
# @title Initialize grid_config
# Finalize SMERF's config using both the train and test datasets (presumably so
# the grid covers every camera that will be rendered).
smerf_config = smerf.internal.grid_utils.initialize_grid_config(
    smerf_config, [dataset, test_dataset]
)
hash(smerf_config)  # Make sure hashing works

# Trailing expression: displayed as the cell's output for inspection.
{
    'grid_config': smerf_config.grid_config,
    'exposure_config': smerf_config.exposure_config,
}

In [None]:
# @title Initialize model
# Build the SMERF student. Of setup_model's five return values, the model, its
# initial state, and the parallel train step are kept.
model, state, _, train_pstep, _ = smerf.internal.train_utils.setup_model(
    smerf_config, jax.random.PRNGKey(smerf_config.model_seed), dataset
)
# Student render function. render_example() calls it with the teacher params,
# the student params, and 1.0.
smerf_render_eval_pfn = smerf.internal.distill.create_prender_student(
    teacher_model=teacher_model,
    student_model=model,
    merf_config=smerf_config,
    alpha_threshold=smerf.internal.baking.final_alpha_threshold(smerf_config),
    return_ray_results=True,
)
# NOTE(review): the `;s` suffix presumably relies on etils ecolab auto-display;
# in plain IPython `s` is an undefined name.
state.params;s

In [None]:
# @title Replicate state across devices
# Replicate the SMERF train state across local devices; this replicated
# `pstate` is what the train loop updates and yields.
pstate = flax.jax_utils.replicate(state)
# NOTE(review): the `;s` suffix presumably relies on etils ecolab auto-display;
# in plain IPython `s` is an undefined name.
pstate.params;s

### Train

This is the main training loop.

In [None]:
# @title Library

# Function for generating teacher supervision. The train loop calls it as
# prender_teacher(prng_or_None, teacher_pstate, pbatch) and also passes it to
# pshift_batch_forward().
prender_teacher = smerf.internal.distill.create_prender_teacher(teacher_model, config)


def mse_to_psnr(v):
  """Converts a mean-squared-error value to PSNR in dB: -10 * log10(v).

  Args:
    v: Scalar loss value.

  Returns:
    The PSNR for positive `v`. NaN when `v` is NaN, so a diverged loss stays
    visible in the printed stats. 0.0 for any other non-positive value, where
    the logarithm is undefined.
  """
  if np.isnan(v):
    return float('nan')
  return -10 * np.log10(v) if v > 0 else 0.0

def render_example(dataset, cam_idx, teacher_pstate, smerf_pstate):
  """Renders a single camera with the teacher and student and displays both.

  Args:
    dataset: Dataset providing the camera and its ground-truth image.
    cam_idx: Index of the camera within `dataset.images`.
    teacher_pstate: Replicated teacher train state.
    smerf_pstate: Replicated SMERF (student) train state.

  Raises:
    IndexError: If `cam_idx` is not a valid index into `dataset.images`.
  """
  # Validate explicitly rather than with `assert`, which is stripped when
  # Python runs with -O.
  if not 0 <= cam_idx < len(dataset.images):
    raise IndexError(f'{cam_idx=} is not in this dataset.')

  # Construct camera rays
  val_rays = smerf.internal.datasets.cam_to_rays(dataset, cam_idx)
  val_rays = smerf.internal.datasets.preprocess_rays(
      rays=val_rays, mode='test', merf_config=smerf_config, dataset=dataset
  )

  # Render the teacher.
  teacher_rendering = camp_zipnerf.internal.models.render_image(
      functools.partial(
          teacher_render_eval_pfn,
          teacher_pstate.params,
          1.0,
          None,  # No cameras needed
      ),
      rays=val_rays,
      rng=None,
      config=config,
      return_all_levels=True,
  )

  # Render the student.
  smerf_rendering = smerf.internal.models.render_image(
      functools.partial(
          smerf_render_eval_pfn,
          teacher_pstate.params,
          smerf_pstate.params,
          1.0,
      ),
      rays=val_rays,
      rng=None,
      config=smerf_config,
      verbose=False,
  )

  # Calculate step: the stored counter (first device's copy) divided by the
  # number of gradient-accumulation steps.
  step = int(smerf_pstate.step[0] // smerf_config.gradient_accumulation_steps)

  # Visualize
  media.show_images(
      {
          'gt': dataset.images[cam_idx],
          'teacher': teacher_rendering['rgb'],
          'student': smerf_rendering['rgb'],
      },
      ylabel=f'{cam_idx=} {step=}',
      width=2000 // 3,
  )

In [None]:
# @title Train
#
# This section will train a SMERF model, printing the train losses every `PRINT_EVERY` 
# steps and rendering a single test camera every `RENDER_EVERY` steps.
#

PRINT_EVERY = 10     # Print loss every N steps
RENDER_EVERY = 100   # Render a test image every N steps

def train(pstate):
  """Distills the teacher into the SMERF student, yielding state every step.

  Reads the notebook globals `raybatcher`, `dataset`, `pcameras`,
  `teacher_pstate`, `prender_teacher`, `train_pstep`, `smerf_config` and
  `test_dataset`.

  Args:
    pstate: Replicated SMERF train state to start from. Training resumes at
      the step recorded in this state, so passing the state of an interrupted
      run continues that run instead of restarting the schedule.

  Yields:
    The replicated SMERF train state after each training step, i.e. after
    `smerf_config.gradient_accumulation_steps` calls to `train_pstep`.
  """
  # Prepare dataset for iterating
  p_raybatcher = flax.jax_utils.prefetch_to_device(raybatcher, 3)
  prng = jax.random.split(jax.random.PRNGKey(1234567), jax.local_device_count())

  # Resume from the step stored in `pstate` itself, not in the global,
  # unreplicated initial `state`. That global is never updated, so reading it
  # would restart the step count (and the train_frac schedule) whenever this
  # cell is re-run with a partially trained `pstate`. The stored counter is
  # divided by the accumulation steps, matching render_example().
  step = int(pstate.step[0]) // smerf_config.gradient_accumulation_steps
  for i in range(step, smerf_config.max_steps):
    # Compute fraction of training complete.
    # NOTE(review): `i` starts at 0, so `(i - 1)` lags by one step and the
    # final value stays below 1.0. Left unchanged in case it mirrors the
    # reference training script; confirm.
    train_frac = np.clip((i - 1) / (smerf_config.max_steps - 1), 0, 1)

    for _ in range(smerf_config.gradient_accumulation_steps):
      pbatch = next(p_raybatcher)

      # Cast rays
      pbatch = pbatch.replace(
          rays=smerf.internal.datasets.preprocess_rays(
              rays=pbatch.rays,
              mode='train',
              merf_config=smerf_config,
              dataset=dataset,
              pcameras=pcameras,
              prng=prng,
          ),
      )

      # Push ray origins forward along camera rays.
      pbatch = smerf.internal.train_utils.pshift_batch_forward(
          prng=prng,
          pbatch=pbatch,
          teacher_pstate=teacher_pstate,
          prender_teacher=prender_teacher,
          config=smerf_config,
      )

      # Render teacher.
      teacher_prng = prng if smerf_config.distill_teacher_use_rng else None
      pteacher_history = prender_teacher(teacher_prng, teacher_pstate, pbatch)

      # Update SMERF parameters. The train step also returns the next prng.
      pstate, pstats, prng = train_pstep(
          prng, pstate, pbatch, pteacher_history, train_frac
      )

    # Print
    if i % PRINT_EVERY == 0:
      stats = flax.jax_utils.unreplicate(pstats)
      psnr = {k: f"{mse_to_psnr(v):0.2f}" for k, v in stats['losses'].items()}
      print(f'{i:05d}: {psnr}')

    # Render a test frame.
    if i % RENDER_EVERY == 0:
      render_example(test_dataset, 0, teacher_pstate, pstate)

    yield pstate

# Run train loop. Iterating rebinds the global `pstate` after every step, so
# if this cell is interrupted, `pstate` still holds the most recent state.
pstate_iter = train(pstate)
for pstate in pstate_iter:
  pass
print('Done!')