### Tiny NeRF + StableDiffusion + Score Distillation Sampling (Dreamfusion)

TinyNeRF Based on: https://colab.research.google.com/drive/1rO8xo0TemN67d4mTpakrKrLp03b9bgCX#scrollTo=ptTYjWao3VsM
from https://github.com/krrish94/nerf-pytorch, which reimplements https://www.matthewtancik.com/nerf

Stable Diffusion from https://github.com/CompVis/stable-diffusion,
using parts of scripts/txt2img.py and ldm/models/diffusion/ddim.py

Running this notebook has the same installation requirements as the CompVis stable-diffusion repo.

In [None]:
from typing import Optional

import numpy as np
import torch
import matplotlib.pyplot as plt
from imageio import imsave

# Imports for stablediffusion txt2img

import argparse, os, sys, glob
import cv2
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
import time

from pytorch_lightning import seed_everything
from torch import autocast

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler


In [None]:
# Monkeypatch the ldm ddpm.py LatentDiffusion class instance
# to remove torch.no_grad decorator
def encode_first_stage(self, x):
    """Encode an image batch with the first-stage autoencoder, keeping gradients.

    This function is bound onto the loaded LatentDiffusion model in
    ``load_model_from_config`` and replaces its ``encode_first_stage``. This
    copy carries no ``torch.no_grad`` decorator, so gradients can flow from the
    latent back into the rendered image.

    Args:
      self: The LatentDiffusion model instance this function is bound to.
      x (torch.Tensor): Image batch. The patch path unpacks it as
        ``(batch, channels, height, width)``.

    Returns:
      For the plain path, whatever ``self.first_stage_model.encode`` returns.
      For the patch path, the per-patch encodings folded back together and
      divided by the fold normalization.
    """
    # Plain path: no patch-splitting configuration, or it is switched off.
    if not hasattr(self, "split_input_params") or not self.split_input_params["patch_distributed_vq"]:
        return self.first_stage_model.encode(x)

    params = self.split_input_params
    kernel = params["ks"]  # e.g. (128, 128)
    step = params["stride"]  # e.g. (64, 64)
    factor = params["vqf"]
    params['original_image_size'] = x.shape[-2:]
    _, _, img_h, img_w = x.shape

    # Clamp kernel and stride so neither exceeds the image size.
    if kernel[0] > img_h or kernel[1] > img_w:
        kernel = (min(kernel[0], img_h), min(kernel[1], img_w))
        print("reducing Kernel")
    if step[0] > img_h or step[1] > img_w:
        step = (min(step[0], img_h), min(step[1], img_w))
        print("reducing stride")

    fold, unfold, normalization, weighting = self.get_fold_unfold(x, kernel, step, df=factor)
    # (bn, nc * prod(ks), L) -> (bn, nc, ks[0], ks[1], L)
    patches = unfold(x)
    patches = patches.view((patches.shape[0], -1, kernel[0], kernel[1], patches.shape[-1]))

    # Encode each of the L patches independently and stack them on a new last axis.
    encoded = torch.stack(
        [self.first_stage_model.encode(patches[:, :, :, :, k]) for k in range(patches.shape[-1])],
        axis=-1,
    )
    encoded = encoded * weighting
    # Back to (bn, nc * ks[0] * ks[1], L), then stitch the crops together.
    encoded = encoded.view((encoded.shape[0], -1, encoded.shape[-1]))
    return fold(encoded) / normalization



In [None]:
def load_model_from_config(config, ckpt, verbose=False):
    """Build the LatentDiffusion model described by ``config`` and load ``ckpt``.

    The checkpoint's ``"state_dict"`` is loaded non-strictly. Missing and
    unexpected keys are printed only when ``verbose`` is true. The
    gradient-preserving ``encode_first_stage`` is bound onto the model as an
    instance method. The model is then moved to the GPU with ``.cuda()``,
    switched to eval mode, and returned.

    Note: ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """
    print(f"Loading model from {ckpt}")
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        print(f"Global Step: {checkpoint['global_step']}")
    state_dict = checkpoint["state_dict"]

    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if verbose:
        for header, keys in (("missing keys:", missing), ("unexpected keys:", unexpected)):
            if len(keys) > 0:
                print(header)
                print(keys)

    # Swap in the encoder variant without torch.no_grad, bound to this instance.
    model.encode_first_stage = encode_first_stage.__get__(model, type(model))
    model.cuda()
    model.eval()
    return model


In [None]:
# Initialize SD model (txt2img.py)

# Seed the RNGs via pytorch_lightning's helper.
seed_everything(123)

config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
# Loads weights, binds the gradient-preserving encode_first_stage and calls
# .cuda() (see load_model_from_config).
model = load_model_from_config(config, "sd-v1-4.ckpt")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Only the `model.model` submodule is cast to fp16 (presumably the diffusion
# UNet wrapper in ldm -- confirm); other submodules keep their loaded dtype.
model.model = model.model.to(torch.float16)
# Module.to() returns the module itself, so `sd_model` is the same object as
# `model`. A later cell rebinds the global `model` to the NeRF network, so
# later code refers to Stable Diffusion through `sd_model`.
sd_model = model.to(device)


In [None]:
# DDIM sampler around the Stable Diffusion model (the global `model` still
# refers to it at this point). It is only used for its timestep/alpha schedule.
sampler = DDIMSampler(model)

In [None]:
# Generate DDIM schedule to provide alpha values at each timestep.
# The training loop draws a timestep from sampler.ddim_timesteps and the
# matching noise level from sampler.ddim_alphas.
ddim_steps = 50
ddim_eta = 0.0
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=True)

In [None]:
# Code from pytorch-nerf tinyNeRF 

def meshgrid_xy(tensor1: torch.Tensor, tensor2: torch.Tensor) -> (torch.Tensor, torch.Tensor):
    """Return an "xy"-indexed meshgrid, like ``np.meshgrid(..., indexing="xy")``.

    ``torch.meshgrid`` is called with its default ("ij") indexing. Both grids
    are then transposed over their last two dimensions. For 1-D inputs of
    lengths ``N`` and ``M``, both outputs therefore have shape ``(M, N)``.

    Args:
      tensor1 (torch.Tensor): Values that vary along the columns of the result.
      tensor2 (torch.Tensor): Values that vary along the rows of the result.

    Returns:
      Tuple of the two transposed grids.
    """
    grids_ij = torch.meshgrid(tensor1, tensor2)
    return tuple(grid.transpose(-1, -2) for grid in grids_ij)


def cumprod_exclusive(tensor: torch.Tensor) -> torch.Tensor:
  r"""Exclusive cumulative product along the last dimension.

  Mirrors ``tf.math.cumprod(..., exclusive=True)``. Element ``k`` of the
  result is the product of input elements ``0 .. k-1``, and element ``0`` is 1.

  Args:
    tensor (torch.Tensor): Input whose last dimension is accumulated.

  Returns:
    torch.Tensor: Same shape as ``tensor``.
  """
  inclusive = torch.cumprod(tensor, dim=-1)
  # Shift the inclusive product right by one slot; the first slot becomes 1.
  leading_ones = torch.ones_like(inclusive[..., :1])
  return torch.cat((leading_ones, inclusive[..., :-1]), dim=-1)

#### Compute the "bundle" of rays through all pixels of an image (tinyNeRF)

In [None]:
# Code from pytorch-nerf tinyNeRF


def get_ray_bundle(height: int, width: int, focal_length: float, tform_cam2world: torch.Tensor):
  r"""Compute one ray per pixel for a pinhole camera.

  Args:
    height (int): Image height in pixels.
    width (int): Image width in pixels.
    focal_length (float or torch.Tensor): Focal length in pixels.
    tform_cam2world (torch.Tensor): ``(4, 4)`` camera-to-world transform. Its
      upper-left ``3x3`` block rotates directions, and the first three rows of
      its last column give the camera origin.

  Returns:
    ray_origins (torch.Tensor): ``(height, width, 3)``. Every entry is the
      camera origin.
    ray_directions (torch.Tensor): ``(height, width, 3)`` world-space
      directions; ``ray_directions[r, c]`` goes through pixel row ``r``,
      column ``c``. The directions are not normalized.
  """
  cols, rows = meshgrid_xy(
      torch.arange(width).to(tform_cam2world),
      torch.arange(height).to(tform_cam2world),
  )
  # Camera-space directions: column offset -> +x, row offset -> -y,
  # and a constant -1 on z.
  cam_dirs = torch.stack(
      [
          (cols - width * .5) / focal_length,
          -(rows - height * .5) / focal_length,
          -torch.ones_like(cols),
      ],
      dim=-1,
  )
  # Rotate into world space (R @ d for every pixel).
  rotation = tform_cam2world[:3, :3]
  ray_directions = torch.sum(cam_dirs[..., None, :] * rotation, dim=-1)
  ray_origins = tform_cam2world[:3, -1].expand(ray_directions.shape)
  return ray_origins, ray_directions

# Note that the ray directions in this bundle are not normalized (this doesn't matter for our purposes)

#### Compute "query" 3D points given the "bundle" of rays (tinyNeRF)

We assume that a _near_ and a _far_ clipping distance are provided that delineate the volume of interest. Each ray is evaluated only within these bounds. We randomly sample points along each ray, while trying to ensure most parts of the ray's trajectory are spanned.

In [None]:
# Code from pytorch-nerf tinyNeRF

def compute_query_points_from_rays(
    ray_origins: torch.Tensor,
    ray_directions: torch.Tensor,
    near_thresh: float,
    far_thresh: float,
    num_samples: int,
    randomize: Optional[bool] = True
) -> (torch.Tensor, torch.Tensor):
  r"""Sample 3D query points along every ray, starting at ``near_thresh``.

  Args:
    ray_origins (torch.Tensor): Ray origins, ``(..., 3)``; ``(height, width, 3)``
      when produced by ``get_ray_bundle``.
    ray_directions (torch.Tensor): Ray directions, same shape as ``ray_origins``.
    near_thresh (float): First depth of the evenly spaced grid.
    far_thresh (float): Last depth of the evenly spaced grid.
    num_samples (int): Number of samples per ray.
    randomize (optional, bool): Jitter only happens when this is exactly
      ``True``. Each depth then gets independent uniform noise in
      ``[0, (far_thresh - near_thresh) / num_samples)``, drawn separately per
      ray. Because of the jitter, samples can extend slightly past
      ``far_thresh``. Otherwise the evenly spaced depths are used unchanged.

  Returns:
    query_points (torch.Tensor): ``(..., num_samples, 3)``.
    depth_values (torch.Tensor): ``(..., num_samples)`` when randomized,
      otherwise ``(num_samples,)``.
  """
  depth_values = torch.linspace(near_thresh, far_thresh, num_samples).to(ray_origins)
  if randomize is True:
    # One jitter value per (ray, sample). It is drawn with default tensor
    # placement and then moved to ray_origins' device/dtype.
    jitter_shape = [*ray_origins.shape[:-1], num_samples]
    jitter = torch.rand(jitter_shape).to(ray_origins)
    depth_values = depth_values + jitter * (far_thresh - near_thresh) / num_samples

  # Broadcast (..., 1, 3) + (..., 1, 3) * (..., num_samples, 1) -> (..., num_samples, 3)
  origins = ray_origins[..., None, :]
  directions = ray_directions[..., None, :]
  query_points = origins + directions * depth_values[..., :, None]
  return query_points, depth_values

#### Volumetric rendering (TinyNeRF)
> **NOTE**: This volumetric rendering module (like the authors' tiny_nerf [Colab notebook](https://colab.research.google.com/github/bmild/nerf/blob/master/tiny_nerf.ipynb)) does not implement 5D input (which includes view directions, in addition to X, Y, Z coordinates). It also does not implement the hierarchical sampling procedure.

In [None]:
### TODO: add a (simple) illumination model, using surface normals extracted
### from the gradient of the NeRF density field. Backpropagating through
### autograd normals is probably very inefficient; finite differences are
### likely a better way to do this. The renderer below currently returns unlit colours.

def render_volume_density(
    radiance_field: torch.Tensor,
    ray_origins: torch.Tensor,
    depth_values: torch.Tensor,
    bg_color: torch.Tensor,
) -> (torch.Tensor, torch.Tensor, torch.Tensor):
  r"""Differentiably alpha-composite a radiance field along each ray.

  Args:
    radiance_field (torch.Tensor): Raw network output at each query point,
      ``(height, width, num_samples, 4)``. Channels ``0:3`` are colour logits,
      passed through a sigmoid. Channel ``3`` is the raw density, passed
      through a ReLU.
    ray_origins (torch.Tensor): Ray origins, ``(height, width, 3)``. Only its
      dtype and device are used here.
    depth_values (torch.Tensor): Sample depths, either ``(num_samples,)`` or
      ``(height, width, num_samples)``.
    bg_color (torch.Tensor): ``(3,)`` background colour. It fills whatever
      fraction of each ray is not covered by the accumulated opacity.

  Returns:
    rgb_map (torch.Tensor): Rendered image, ``(height, width, 3)``.
    depth_map (torch.Tensor): Weight-averaged depth, ``(height, width)``.
    T_map (torch.Tensor): Accumulated opacity (sum of the compositing
      weights), ``(height, width)``. It is 0 where the ray crosses empty space.
  """
  # Use a ReLU activation for the density field (rather than exp in the paper).
  # Not smooth like in the paper, but easier to make vanish (exp(-4.5) = 1%);
  # we really want to encourage the model to clear space.
  sigma_a = torch.nn.functional.relu(radiance_field[..., 3])
  rgb = torch.sigmoid(radiance_field[..., :3])

  # Distances between consecutive samples. The original tinyNeRF used a large
  # distance for the last sample; here it is zero, so the last sample never
  # contributes.
  last_gap = torch.zeros(depth_values[..., :1].shape, dtype=ray_origins.dtype, device=ray_origins.device)
  dists = torch.cat((depth_values[..., 1:] - depth_values[..., :-1], last_gap), dim=-1)

  # Opacity of each segment, from the density at the segment's front sample.
  # (Averaging front and back samples might reduce noise a little.)
  alpha = 1. - torch.exp(-sigma_a * dists)
  # Compositing weights: segment opacity times the transmittance in front of it.
  weights = alpha * cumprod_exclusive(1. - alpha + 1e-10)
  # Accumulated opacity: used for background compositing and as a measure of
  # how much of the volume is occupied.
  T_map = weights.sum(dim=-1)

  rgb_map = (weights[..., None] * rgb).sum(dim=-2) + (1 - T_map[..., None]) * bg_color[None, None, :]
  depth_map = (weights * depth_values).sum(dim=-1)

  return rgb_map, depth_map, T_map

#### Positional encoding (TinyNeRF)

In [None]:
# pytorch-tinyNeRF code
def positional_encoding(
    tensor, num_encoding_functions=6, include_input=True, log_sampling=True
) -> torch.Tensor:
  r"""Append sin/cos features of ``tensor`` at several frequencies.

  Frequencies run from ``2**0`` to ``2**(num_encoding_functions - 1)``, spaced
  either logarithmically or linearly. For each frequency ``f``,
  ``sin(f * tensor)`` and ``cos(f * tensor)`` are appended along the last
  dimension.

  Args:
    tensor (torch.Tensor): Input to encode.
    num_encoding_functions (optional, int): Number of frequencies (default: 6).
    include_input (optional, bool): Put the raw input first (default: True).
      ``VeryTinyNerfModel`` reads raw xyz from the first three columns, so keep
      this enabled when feeding that model.
    log_sampling (optional, bool): Space frequencies logarithmically
      (default: True); otherwise space them linearly.

  Returns:
    torch.Tensor: For an input with last dimension ``D``, the result has last
    dimension ``D * (include_input + 2 * num_encoding_functions)``. When only
    the raw input would be produced, it is returned as-is.
  """
  factory_kwargs = dict(dtype=tensor.dtype, device=tensor.device)
  if log_sampling:
    frequency_bands = 2.0 ** torch.linspace(
        0.0, num_encoding_functions - 1, num_encoding_functions, **factory_kwargs)
  else:
    frequency_bands = torch.linspace(
        2.0 ** 0.0, 2.0 ** (num_encoding_functions - 1), num_encoding_functions, **factory_kwargs)

  pieces = [tensor] if include_input else []
  pieces += [fn(tensor * freq) for freq in frequency_bands for fn in (torch.sin, torch.cos)]

  # Special case, for no positional encoding: skip the concatenation.
  if len(pieces) == 1:
    return pieces[0]
  return torch.cat(pieces, dim=-1)

## TinyNeRF: Network architecture (pytorch-tinyNeRF)

In [None]:
# Added an extra layer. Also added the "seed" gaussian as a constant output.
# The model is made to emit zero density outside a bounded sphere - mipNeRF 360
# has a much cleverer solution with space rescaling.
# Without this it is hard to minimize haze away from the object.

r2_max = 1**2  # squared radius; outputs are zeroed at or beyond it


class VeryTinyNerfModel(torch.nn.Module):
  r"""Small MLP NeRF with two hidden layers (Linear + LayerNorm + SiLU) and a linear head.

  The head produces 4 channels per point: 3 colour logits and 1 raw density.
  A fixed Gaussian bump is added to the density channel. It is centred at the
  origin, with amplitude ``l_s`` and standard deviation ``s_s``. Every output
  is multiplied by zero for points whose squared distance from the origin is
  not below the module-level ``r2_max``.

  Input rows must be positional encodings whose first three columns are the
  raw xyz coordinates (``positional_encoding`` with ``include_input=True``).
  """
  def __init__(self, filter_size=128, num_encoding_functions=6, l_s=5, s_s=0.2):
    super().__init__()
    self.l_s = l_s  # amplitude of the initial Gaussian density bump
    self.s_s = s_s  # standard deviation of the bump
    in_features = 3 + 3 * 2 * num_encoding_functions  # 39 by default
    self.layer1 = torch.nn.Linear(in_features, filter_size)
    self.layer1_norm = torch.nn.LayerNorm(filter_size)
    self.layer2 = torch.nn.Linear(filter_size, filter_size)
    self.layer2_norm = torch.nn.LayerNorm(filter_size)
    self.layer3 = torch.nn.Linear(filter_size, 4)
    # Hidden activation. The attribute is named `relu`, but the function is SiLU.
    self.relu = torch.nn.functional.silu

  def forward(self, x0):
    hidden = x0
    for linear, norm in ((self.layer1, self.layer1_norm), (self.layer2, self.layer2_norm)):
      hidden = self.relu(norm(linear(hidden)))
    out = self.layer3(hidden)

    # Squared distance of each raw xyz point from the origin.
    r2 = x0[:, :3].pow(2).sum(dim=1)
    out[:, 3] += self.l_s * torch.exp(-r2 / (2 * self.s_s ** 2))
    inside = r2[:, None] < r2_max
    return out * inside

## Dataloading utils (tinyNeRF)

In [None]:
# Pytorch-tinyNerf
def get_minibatches(inputs: torch.Tensor, chunksize: Optional[int] = 1024 * 8):
  r"""Split ``inputs`` along dimension 0 into a list of consecutive slices.

  Every slice except possibly the last has ``chunksize`` rows. The slices are
  views of ``inputs`` (no copy).

  Args:
    inputs (torch.Tensor): Tensor to split, e.g. the flattened query points.
    chunksize (optional, int): Rows per slice (default ``8192``). ``None``
      means no splitting: the whole tensor comes back as a single slice.

  Returns:
    list[torch.Tensor]: The slices in order. Empty if ``inputs`` has no rows.

  Raises:
    ValueError: If ``chunksize`` is zero or negative. Previously a negative
      value silently returned an empty list, dropping all data.
  """
  num_rows = inputs.shape[0]
  if chunksize is None:
    # Honour the Optional[int] annotation: one slice covering everything.
    chunksize = max(num_rows, 1)
  elif chunksize <= 0:
    raise ValueError(f"chunksize must be positive, got {chunksize}")
  return [inputs[i:i + chunksize] for i in range(0, num_rows, chunksize)]

## Determine device to run on (GPU vs CPU)

In [None]:
# Same device choice as in the Stable Diffusion initialization cell, repeated
# so this section stands alone. Many later tensors still call .cuda() directly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Set up camera intrinsics, image size and clipping bounds (no input images are loaded)

In [None]:

# Resolution multiplier for training renders (sc = 1 -> 64x64 images).
sc = 1

# Original focal length from the tiny-nerf lego dataset (in pixels), scaled by sc.
focal_length = torch.tensor(138.8889*sc).cuda()

# Height and width of each rendered training image
height, width = (64*sc, 64*sc) #images.shape[1:3]

# Near and far clipping thresholds: the depth range sampled along each ray.
near_thresh = 2.
far_thresh = 6.



## Train TinyNeRF! (tinyNerf)

In [None]:
# One iteration of TinyNeRF (forward pass).
def run_one_iter_of_tinynerf(height, width, focal_length, tform_cam2world,
                             near_thresh, far_thresh, depth_samples_per_ray,
                             encoding_function, get_minibatches_function,
                             bg_color, network=None, chunk_size=None):
  r"""Render one image of the NeRF from the camera pose ``tform_cam2world``.

  Steps:
  1. Build one ray per pixel.
  2. Sample ``depth_samples_per_ray`` jittered points along each ray.
  3. Encode the points and evaluate the network on them in chunks.
  4. Alpha-composite the results over ``bg_color``.

  Args:
    height, width (int): Output image size in pixels.
    focal_length (float or torch.Tensor): Focal length in pixels.
    tform_cam2world (torch.Tensor): ``(4, 4)`` camera-to-world transform.
    near_thresh, far_thresh (float): Depth range sampled along each ray.
    depth_samples_per_ray (int): Samples per ray.
    encoding_function (callable): Maps ``(N, 3)`` points to network inputs.
    get_minibatches_function (callable): Splits the flattened points. It is
      called as ``get_minibatches_function(points, chunksize=...)``.
    bg_color (torch.Tensor): ``(3,)`` background colour.
    network (callable, optional): NeRF to evaluate. Defaults to the
      module-level ``model``, looked up at call time (the previous behaviour).
    chunk_size (int, optional): Points per network call. Defaults to the
      module-level ``chunksize``, looked up at call time (the previous behaviour).

  Returns:
    ``(rgb, depth, opacity)`` from ``render_volume_density``, with shapes
    ``(height, width, 3)``, ``(height, width)`` and ``(height, width)``.
  """
  # Fall back to the notebook globals so existing call sites behave as before.
  # Note: the global `model` is rebound from the Stable Diffusion model to the
  # NeRF network in the training-parameters cell.
  if network is None:
    network = model
  if chunk_size is None:
    chunk_size = chunksize

  # Get the "bundle" of rays through all image pixels, then sample points on them.
  ray_origins, ray_directions = get_ray_bundle(height, width, focal_length, tform_cam2world)
  query_points, depth_values = compute_query_points_from_rays(
      ray_origins, ray_directions, near_thresh, far_thresh, depth_samples_per_ray
  )

  # Evaluate the network on the flattened points, one chunk at a time.
  flattened_query_points = query_points.reshape((-1, 3))
  batches = get_minibatches_function(flattened_query_points, chunksize=chunk_size)
  radiance_field_flattened = torch.cat(
      [network(encoding_function(batch)) for batch in batches], dim=0
  )

  # "Unflatten" to (..., num_samples, 4) and volume-render differentiably.
  radiance_field = radiance_field_flattened.reshape(list(query_points.shape[:-1]) + [4])
  return render_volume_density(radiance_field, ray_origins, depth_values, bg_color)

In [None]:
# Rescale from [-1,1] (SD) to [0,1] (NeRF) 
def rescale_dd(img):
    """Map a Stable Diffusion image from ``[-1, 1]`` to ``[0, 1]``, clamping values outside that range."""
    shifted = img.add(1.0).div(2.0)
    return shifted.clamp(min=0.0, max=1.0)

In [None]:
"""
Parameters for TinyNeRF training - works on 3090 
"""

# Number of functions used in the positional encoding (Be sure to update the
# model if this number changes).
num_encoding_functions = 6
# Specify encoding function.
encode = lambda x: positional_encoding(x, num_encoding_functions=num_encoding_functions)
# Number of depth samples along each ray.
depth_samples_per_ray = 128

# Chunksize (Note: this isn't batchsize in the conventional sense. It is the
# number of query points -- not rays -- evaluated by the network in one call.
# Backprop still happens only after all rays from the current "bundle" are
# queried and rendered). run_one_iter_of_tinynerf reads this global.
chunksize=4096

# Optimizer parameters
lr = 5e-4
num_iters = 5000

# Misc parameters
display_every = 50  # Number of iters after which stats are displayed
save_every = 500  # Checkpoint interval in iterations
"""
Model
"""
# NOTE: this rebinds the global `model`, which previously held the Stable
# Diffusion model (still reachable as `sd_model`). run_one_iter_of_tinynerf
# evaluates the global `model`, i.e. this NeRF.
model = VeryTinyNerfModel(num_encoding_functions=num_encoding_functions)
model.to(device)

"""
Optimizer
"""
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

"""
Train-Eval-Repeat!
"""

# Seed RNG, for repeatability.
# NOTE(review): seeding happens after the model is constructed, so it does not
# fix the NeRF's initial weights (those come from the earlier seed_everything
# call plus any RNG use since then).
seed = 9458
torch.manual_seed(seed)
np.random.seed(seed)

# Lists to log metrics etc.
# NOTE(review): nothing in the visible notebook appends to these.
psnrs = []
iternums = []
# Use mixed precision training - SD model evaluated in fp16,
# NeRF optimized in mixed fp16 / fp32
scaler = torch.cuda.amp.GradScaler()


In [None]:

def cam_view(radius, phi, theta, offset):
    """Build a float32 ``(4, 4)`` camera-to-world matrix for an orbiting camera.

    The transforms are applied right to left:
    1. Translate by ``radius`` along z.
    2. Rotate by ``phi`` degrees about the x axis.
    3. Rotate by ``theta`` degrees about the y axis.
    4. Translate by ``offset``.
    5. Apply a fixed axis permutation that negates x and swaps y and z.

    Args:
      radius: Camera distance along the view axis before the offset.
      phi: Elevation angle in degrees (rotation about x).
      theta: Azimuth angle in degrees (rotation about y).
      offset: Indexable ``(dx, dy, dz)`` jitter added to the camera position.

    Returns:
      np.ndarray: ``(4, 4)`` float32 homogeneous transform.
    """
    def translation(dx, dy, dz):
        mat = np.eye(4, dtype=np.float32)
        mat[:3, 3] = (dx, dy, dz)
        return mat

    def rot_x(angle):
        c, s = np.cos(angle), np.sin(angle)
        return np.array([[1, 0, 0, 0],
                         [0, c, -s, 0],
                         [0, s, c, 0],
                         [0, 0, 0, 1]], dtype=np.float32)

    def rot_y(angle):
        c, s = np.cos(angle), np.sin(angle)
        return np.array([[c, 0, -s, 0],
                         [0, 1, 0, 0],
                         [s, 0, c, 0],
                         [0, 0, 0, 1]], dtype=np.float32)

    axis_swap = np.array([[-1, 0, 0, 0],
                          [0, 0, 1, 0],
                          [0, 1, 0, 0],
                          [0, 0, 0, 1]], dtype=np.float32)

    pose = translation(0, 0, radius)
    pose = rot_x(phi / 180. * np.pi) @ pose
    pose = rot_y(theta / 180. * np.pi) @ pose
    pose = translation(offset[0], offset[1], offset[2]) @ pose
    return axis_swap @ pose

    

In [None]:
# View-angle dependent prefixes.
# Not the same as in original paper (they may have experimented more,
# may be different between Imagen and Stable Diffusion).
def view_to_prefix(radius, phi, theta, offset):
    """Choose the view-dependent prompt prefix for a camera pose.

    Only ``phi`` (elevation, degrees) and ``theta`` (azimuth, degrees, expected
    in ``[0, 360)``) are used. ``radius`` and ``offset`` are accepted so the
    signature mirrors ``cam_view``.

    Rules:
    - overhead for ``phi < -60``; bottom for ``phi > 60``;
    - otherwise front for ``135 <= theta <= 225``;
    - side for ``45 < theta < 135`` or ``225 < theta < 315``;
    - rear for all remaining azimuths.

    Previously theta exactly 135 or 225 fell through the strict inequalities
    to "rear", i.e. the opposite side of the object.
    """
    if phi < -60:
        return "overhead view of a "
    if phi > 60:
        return "bottom view of a "
    if 135 <= theta <= 225:
        return "front view of a "
    if 45 < theta < 135 or 225 < theta < 315:
        return "side view of a "
    return "rear view of a "

In [None]:
# Classifier-free guidance scale (conditional vs unconditional prediction).
# This is less than in the original paper (100). Experiment for best results.
scale = 50
# Ideally pick something with a distinct colour that's also radially symmetric
prompt = 'yellow bath duck blender render'
# Text conditioning for every view prefix returned by view_to_prefix, plus the
# bare prompt (""). The original model blended these.
# Prepending a view prefix is only weak guidance: for the duck, if there's a
# beak on one side, the model will complete the face even with "rear view".
view_prefixes = ["", "overhead view of a ", "bottom view of a ", "side view of a ", "rear view of a ", "front view of a "]
cond = {prefix: sd_model.get_learned_conditioning([prefix + prompt]) for prefix in view_prefixes}
# Unconditional (empty-prompt) embedding for classifier-free guidance.
uc = sd_model.get_learned_conditioning([""])

In [None]:
# Illumination / camera view direction
def illum(phi, theta):
    """Return a unit direction for elevation ``phi`` and azimuth ``theta`` (degrees).

    The result is a float64 ``torch.Tensor`` of shape ``(3,)`` equal to
    ``(sin(theta) * cos(phi), cos(theta) * cos(phi), -sin(phi))``.
    """
    phi_rad = phi / 180.0 * np.pi
    theta_rad = theta / 180.0 * np.pi
    horizontal = np.cos(-phi_rad)
    components = [np.sin(theta_rad) * horizontal, np.cos(theta_rad) * horizontal, np.sin(-phi_rad)]
    return torch.tensor(components)

In [None]:
start_i = 0
for i in range(start_i, start_i+num_iters):

  # Select a random DDIM timestep
  idx = torch.randint(len(sampler.ddim_timesteps),(1,))
  t = torch.tensor([sampler.ddim_timesteps[idx]]).cuda()

  lambda_T = 1e-6 # Transparency loss weight
  # Normal direction loss weight schedule.
  # NOTE(review): lambda_norm is not used below; no normal loss is computed.
  if i<2500:
    lambda_norm = 0
  elif i<9000:
    lambda_norm = 1e-4
  else:
    lambda_norm = 1e-3

  # Randomly select camera angle and position
  # Different to paper. Also no rotation of up angle - I suspect that's good to
  # add later on for regularization, but not helpful early
  theta = np.random.uniform(0, 360)
  if np.random.random() < 0.5:
      phi = np.random.uniform(-90, 0)
  else:
      phi = np.random.uniform(-50, 10)
  # Vary view distance. This perhaps should be a wider range
  radius = np.random.uniform(3, 4)
  # Jitter camera position
  offset = np.random.uniform(-0.05, 0.05, [3])

  # Illuminate from an angle near the camera direction. Probably should also
  # illuminate from above sometimes. Would like it if the lower half of the object isn't
  # made to be dark.
  # NOTE(review): illum_dir and cam_dir are not used by the current renderer.
  # They are kept so the np.random draws, and hence seeded runs, are unchanged.
  illum_dir = illum(phi+10*np.random.randn(), theta+10*np.random.randn()).half().cuda()
  cam_dir = illum(phi, theta).half().cuda()

  if i>4000:
    # Random background colours (like Dreamfields) later on
    # Not in original paper (but has a NeRF that's better for background)
    # Really not sure this is helpful - for one test ended up with more haze
    # after switching it on
    bg_color = torch.tensor(np.random.uniform(0.0, 0.4, [3])).half().cuda()
  else:
    # Dull grey background.
    bg_color = torch.tensor([0.3, 0.3, 0.3]).half().cuda()

  target_tform_cam2world = torch.tensor(cam_view(radius, phi, theta, offset)).cuda()

  prefix = view_to_prefix(radius, phi, ((theta)%360), offset)
  c = cond[prefix]
  with torch.cuda.amp.autocast():
    # Run one iteration of TinyNeRF and get the rendered RGB image.
    rgb_predicted, depth_predicted,  T_predicted \
      = run_one_iter_of_tinynerf(height, width, focal_length,
                                 target_tform_cam2world, near_thresh,
                                 far_thresh, depth_samples_per_ray,
                                 encode, get_minibatches, bg_color)

    # Rescale image components to (-1, 1)
    z0 = 2*(rgb_predicted-0.5)
    z = z0.permute(2,0,1).unsqueeze(0)
    # Upscale to 512x512 - stable diffusion gives poor results when naively
    # downsampled from trained resolution
    z = torch.nn.functional.interpolate(z, size=(512, 512), mode='bilinear')
    # Embed in SD latent space
    x = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(z))
    # Generate random noise
    eps = torch.randn(x.shape).cuda()
    # Use noise level from DDIM schedule. Note that ldm SD alpha == alpha^2 in Dreamfusion paper
    alpha_t = sampler.ddim_alphas[idx].cuda()
    # Add noise to x
    x_t = alpha_t.sqrt() * x + torch.sqrt(1-alpha_t) * eps

    # Guided diffusion - evaluate with and without conditional prompt
    x_in = torch.cat([x_t] * 2)
    t_in = torch.cat([t] * 2)
    c_in = torch.cat([uc, c])
    with torch.no_grad(): # No gradients in SD UNet required
      # Apply UNet to estimate noise vector
      e_t_uncond, e_t = sd_model.apply_model(x_in, t_in, c_in).chunk(2)
      # Combine conditional and unconditional results according to scale
      e_t = e_t_uncond + scale * (e_t - e_t_uncond)
      # SDS gradient w.r.t. the latent: predicted noise minus injected noise
      d = e_t - eps
      if i % display_every == 0:
        # Predict noise-free x using estimated noise and decode the 64x64x4
        # latent into a 512x512x3 image. It is only shown in the display
        # below, so skip this VAE decode on non-display iterations.
        pred_x0 = (x_t - torch.sqrt(1-alpha_t) * e_t) / alpha_t.sqrt()
        y = sd_model.decode_first_stage(pred_x0)

    # Regularization to try to remove hazyness from NeRF volume
    # Original paper used sqrt(lambda_T**2 + 1e-3).sum()
    loss_T = lambda_T*T_predicted.sum()
    loss = loss_T

    # Score Distillation Sampling: inject d directly as the gradient of the
    # latent x. It must be multiplied by the GradScaler scale like loss_T,
    # because scaler.step() divides *all* parameter gradients by that scale.
    # An unscaled injection would be shrunk by the current scale factor, and
    # its weight relative to loss_T would drift whenever the scale changes.
    # If the scaled gradient overflows in fp16, the scaler skips that step and
    # lowers the scale, as for any scaled loss.
    # NOTE(review): lambda_T was tuned with the previous unscaled injection and
    # may need retuning.
    x.backward(gradient=scaler.scale(d), retain_graph=True)

  ### AMP loss scaling and update
  scaler.scale(loss).backward()
  scaler.step(optimizer)
  scaler.update()
  optimizer.zero_grad()

  # Display images/plots/stats
  if i % display_every == 0:
    print(prefix + prompt, loss.detach(), loss_T.detach() )
    plt.figure(figsize=(5, 5))
    plt.subplot(221)
    plt.imshow(rgb_predicted.detach().cpu().float().numpy())
    plt.title(f"Iteration {i}")
    plt.subplot(222)
    plt.imshow(rescale_dd(y[0]).detach().cpu().float().numpy().transpose(1,2,0))
    plt.title("Target")
    plt.subplot(223)
    plt.imshow(depth_predicted.detach().cpu().float().numpy())
    plt.title("Depth")

    plt.subplot(224)
    plt.imshow(T_predicted.detach().cpu().float().numpy())
    plt.title("Transparency")

    plt.show()
  if (i+1) % save_every == 0:
    torch.save({
            'epoch': i,
            'model1_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            }, f'checkpoint_{prompt}_{i}.ckpt')




In [None]:
# From mipnerf https://github.com/google/mipnerf
def generate_spherical_cam_to_world(radius, n_poses=120, d_th=-5, d_phi=-5):
    """
    Generate a 360 degree circular camera path and matching light directions.

    ref: https://github.com/kwea123/nerf_pl/blob/master/datasets/llff.py
    ref: https://github.com/yenchenlin/nerf-pytorch/blob/master/load_blender.py

    The path has ``n_poses`` cameras at a fixed ``phi = -30`` degrees and
    evenly spaced azimuths in ``[0, 2*pi)``. Each pose uses the same matrices
    as ``cam_view``, except that the azimuth is given in radians and there is
    no offset. For each pose, a light direction is computed as
    ``illum(30 + d_phi, 180 + azimuth_in_degrees + d_th)``.

    Inputs:
        radius: Camera distance along the view axis.
        n_poses: Number of poses on the path.
        d_th, d_phi: Angular offsets (degrees) applied to the light direction.
    Outputs:
        spheric_cams: (n_poses, 4, 4) float32 camera-to-world matrices.
        illum_dir: (n_poses, 3) light directions.
    """

    def spheric_pose(theta, phi, radius):
        # theta in radians, phi in degrees.
        trans_t = lambda t: np.array([
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 1, t],
            [0, 0, 0, 1],
        ], dtype=np.float32)

        rotation_phi = lambda phi: np.array([
            [1, 0, 0, 0],
            [0, np.cos(phi), -np.sin(phi), 0],
            [0, np.sin(phi), np.cos(phi), 0],
            [0, 0, 0, 1],
        ], dtype=np.float32)

        rotation_theta = lambda th: np.array([
            [np.cos(th), 0, -np.sin(th), 0],
            [0, 1, 0, 0],
            [np.sin(th), 0, np.cos(th), 0],
            [0, 0, 0, 1],
        ], dtype=np.float32)
        cam_to_world = trans_t(radius)
        cam_to_world = rotation_phi(phi / 180. * np.pi) @ cam_to_world
        cam_to_world = rotation_theta(theta) @ cam_to_world
        cam_to_world = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]],
                                dtype=np.float32) @ cam_to_world
        return cam_to_world

    # Azimuths in radians, excluding the duplicate endpoint at 2*pi.
    azimuths = np.linspace(0, 2 * np.pi, n_poses + 1)[:-1]
    spheric_cams = [spheric_pose(th, -30, radius) for th in azimuths]
    illum_dir = [illum(30+d_phi, 180+th*180/np.pi+d_th) for th in azimuths]

    return np.stack(spheric_cams, 0), np.stack(illum_dir, 0)


In [None]:
### Really simple rendering
### Note that volume rendering for training needs to use autograd for
### gradient of NeRF field, with retain_graph=True. Very VRAM costly so
### needs to be removed for rendering at higher resolutions (sc>2)
### This outputs 120 128x128 images at different view angles

sc = 2
near_thresh=1
far_thresh=4.5
height = width = 64
depth_samples_per_ray = 128
# run_one_iter_of_tinynerf reads this global as its chunk size.
chunksize=1024

for j in range(120):
  target_tform_cam2world = torch.tensor(cam_view(3.5, -30, 360/120*j, np.array([0,0,0]))).to(device)

  with torch.no_grad():
    rgb_predicted, depth_predicted, _ = run_one_iter_of_tinynerf(
        height*sc, width*sc, focal_length*sc,
        target_tform_cam2world, near_thresh,
        far_thresh, depth_samples_per_ray,
        encode, get_minibatches,
        torch.tensor([0.0, 0.0, 0.0]).half().cuda(),
    )

  print(rgb_predicted.shape, rgb_predicted.min(), rgb_predicted.max())
  # Convert once; reused for the preview and the saved frame.
  rgb_np = rgb_predicted.detach().cpu().float().numpy()

  fig = plt.figure()
  plt.imshow(rgb_np)

  im = (rgb_np*255).astype(np.uint8)
  imsave(f'anim_{j:03d}.png', im)
  plt.title(f"View {j:03d}")
  plt.savefig(f"View_{j:03d}.png")

  plt.show()
  # Release the figure. With non-interactive backends plt.show() does not
  # close it, and 120 open figures trigger matplotlib's too-many-figures
  # warning while keeping all of them in memory.
  plt.close(fig)
