In [None]:
import numpy as np

import tensorflow as tf
tf.compat.v1.enable_eager_execution()

import plotly.graph_objects as go

from google3.pyglib import gfile
from google3.experimental.users.xiuming.sim.sim import datasets, models
from google3.experimental.users.xiuming.sim.sim.util import io as ioutil, \
    logging as logutil


def query_opacity(model, pts, voxel_size, use_fine=False, mlp_chunk=65536):
    """Evaluate per-point opacity from the model's density (alpha) branch.

    Points are embedded with ``model.embedder['xyz']``, run through the coarse
    or fine encoder, and mapped to a raw density ``a`` by the matching
    ``a_out`` head. The density is converted to opacity as
    ``1 - exp(-relu(a) * voxel_size)``, i.e. the probability of being absorbed
    over a segment of length ``voxel_size``.

    Args:
        model: Restored model exposing ``net['<prefix>enc']``,
            ``net['<prefix>a_out']`` and ``embedder['xyz']`` callables, where
            ``<prefix>`` is ``'fine_'`` or ``'coarse_'``.
        pts: Points with shape ``(..., 3)``.
        voxel_size: Segment length used to integrate the density.
        use_fine: Query the fine network instead of the coarse one.
        mlp_chunk: Maximum number of points sent through the MLP at once
            (bounds peak memory).

    Returns:
        Opacity tensor with shape ``pts.shape[:-1]``. Assumes ``a_out``
        yields exactly one value per point -- TODO confirm for other models.

    Raises:
        ValueError: If ``mlp_chunk`` is not positive.
    """
    if mlp_chunk <= 0:
        raise ValueError("mlp_chunk must be positive, got %r" % (mlp_chunk,))

    pref = 'fine_' if use_fine else 'coarse_'
    enc = model.net[pref + 'enc']
    a_out = model.net[pref + 'a_out']
    embedder = model.embedder['xyz']

    # Output keeps every leading dimension of `pts` (drops the xyz axis), so
    # grids of any rank work, not only (H, W, D, 3).
    out_shape = tf.shape(pts)[:-1]
    pts_flat = tf.reshape(pts, (-1, 3))
    n_pts = pts_flat.shape[0]
    if n_pts == 0:
        # tf.concat() on an empty list would raise; nothing to evaluate.
        return tf.zeros(out_shape)

    # Chunk by chunk to avoid OOM; slicing past the end clamps automatically.
    a_chunks = []
    for start in range(0, n_pts, mlp_chunk):
        pts_chunk = pts_flat[start:start + mlp_chunk]
        a_chunks.append(a_out(enc(embedder(pts_chunk))))
    a_flat = tf.concat(a_chunks, axis=0)

    # relu clamps negative raw densities to zero opacity.
    opacity_flat = 1.0 - tf.exp(-tf.nn.relu(a_flat) * voxel_size)
    return tf.reshape(opacity_flat, out_shape)


def gen_para_rays(
        n_voxels_per_unit, x_min=-0.3, x_max=0.3, y_min=0., y_max=1.8,
        z_min=2.8, z_max=3.2):
    """Build a regular grid of 3D sample points inside an axis-aligned box.

    Each axis gets ``int((max - min) * n_voxels_per_unit)`` evenly spaced
    samples from ``min`` to ``max`` inclusive. Because of the truncation, the
    spacing along each axis is ``(max - min) / (n - 1)``, which is only
    approximately ``1 / n_voxels_per_unit``. It can differ slightly between
    axes.

    Returns:
        float32 tensor of shape ``(n_y, n_x, n_z, 3)``. The first two grid
        axes are swapped because of ``indexing='xy'``. The last axis holds
        the ``(x, y, z)`` coordinates.
    """
    # go/holodeck-output-api#cameras
    #   +X is from right to left arm
    #   +Y is from feet to head
    #   +Z is from back to chest
    per_axis = []
    for lo, hi in ((x_min, x_max), (y_min, y_max), (z_min, z_max)):
        n_samples = int((hi - lo) * n_voxels_per_unit)
        per_axis.append(tf.cast(tf.linspace(lo, hi, n_samples), tf.float32))
    grid = tf.meshgrid(*per_axis, indexing='xy')
    return tf.stack(grid, axis=-1)


def get_config_ini(ckpt_path):
    """Return the path of the ``.ini`` config that accompanies a checkpoint.

    The checkpoint file name and its immediate parent directory are stripped,
    and ``.ini`` is appended to what remains. For example,
    ``<run>/vis_test/ckpt-169`` becomes ``<run>.ini``. Paths with fewer than
    three ``/``-separated components yield just ``'.ini'``.
    """
    components = ckpt_path.split('/')
    del components[-2:]  # drop "<subdir>/<checkpoint file>"
    return '{}.ini'.format('/'.join(components))


def make_datapipe(config):
    """Build the input pipeline for the ``'vali'`` split.

    The dataset class is looked up by the ``[DEFAULT] dataset`` option. The
    ``[DEFAULT] no_batch`` flag is forwarded to ``build_pipeline``.

    Args:
        config: Parsed config exposing ``get``/``getboolean`` (ConfigParser
            style).

    Returns:
        Whatever ``Dataset.build_pipeline`` returns.
    """
    dataset_cls = datasets.get_dataset_class(config.get('DEFAULT', 'dataset'))
    vali_dataset = dataset_cls(config, 'vali')
    return vali_dataset.build_pipeline(
        no_batch=config.getboolean('DEFAULT', 'no_batch'))


def restore_model(config, ckpt_path):
    """Instantiate the configured model and load its weights from a checkpoint.

    The model class is looked up by the ``[DEFAULT] model`` option. Its
    trainable layers are registered before restoring, because only
    registered layers get restored from the checkpoint.
    ``expect_partial()`` silences warnings about checkpoint values that the
    model does not use (e.g. optimizer slots).

    Args:
        config: Parsed config exposing ``get`` (ConfigParser style).
        ckpt_path: Checkpoint prefix to restore from.

    Returns:
        The model with restored weights.
    """
    model_cls = models.get_model_class(config.get('DEFAULT', 'model'))
    model = model_cls(config)
    model.register_trainable()

    assert model.trainable_registered, (
        "Register the trainable layers to have them restored from the "
        "checkpoint")
    tf.train.Checkpoint(net=model).restore(ckpt_path).expect_partial()
    return model


_DEFAULT_CKPT_PATH = (
    '/cns/is-d/home/gcam-eng/gcam/interns/xiuming/sim/output/nerf_repro/'
    '324_20190806_134352_viewsyn_0235_transp-bg/lr:0.0001|mgm:-1/'
    'vis_test/ckpt-169')


def main(ckpt_path=_DEFAULT_CKPT_PATH, n_voxels_per_unit=128):
    """Render the opacity field learned by a checkpoint as a Plotly volume.

    Args:
        ckpt_path: Checkpoint to visualize. Its ``.ini`` config is located
            with ``get_config_ini``.
        n_voxels_per_unit: Grid resolution (samples per scene unit) passed to
            ``gen_para_rays``.
    """
    config = ioutil.read_config(get_config_ini(ckpt_path))

    # NOTE(review): the pipeline's return value is never used below. The call
    # is kept only so the script's side effects are unchanged. It is
    # presumably removable -- confirm the model does not depend on it.
    make_datapipe(config)

    model = restore_model(config, ckpt_path)

    # Generate rays starting at the XY plane, parallel to the Z axis
    xyz = gen_para_rays(n_voxels_per_unit)
    voxel_size = 1. / n_voxels_per_unit

    # Compute alpha (probability of being absorbed) at each location. This is
    # inference only, so no GradientTape is used. A tape would record the
    # forward pass over the whole grid (trainable variables are watched
    # automatically) and keep those activations alive for nothing.
    mlp_chunk = config.getint('DEFAULT', 'mlp_chunk')
    opacity = query_opacity(
        model, xyz, voxel_size, use_fine=True, mlp_chunk=mlp_chunk)
    print("Inference done")

    # Plot
    xyz = xyz.numpy()
    opacity = opacity.numpy()
    fig = go.Figure(
        data=go.Volume(
            x=xyz[..., 0].flatten(), y=xyz[..., 1].flatten(),
            z=xyz[..., 2].flatten(), value=opacity.flatten(),
            isomin=0., isomax=1., opacity=0.1, surface_count=16),
        layout=go.Layout(scene=dict(aspectmode='data')))
    fig.show()


# Guard the entry point so importing this file does not run inference. In a
# notebook kernel __name__ is '__main__', so the cell still runs main().
if __name__ == '__main__':
    main()

[36m[datasets/sim] Number of 'vali' light-view combinations: 5[0m
[36m[models/base] Layers registered as trainable:
	['net_coarse_enc_layer0', 'net_coarse_enc_layer1', 'net_coarse_enc_layer2', 'net_coarse_enc_layer3', 'net_coarse_enc_layer4', 'net_coarse_enc_layer5', 'net_coarse_enc_layer6', 'net_coarse_enc_layer7', 'net_coarse_a_out_layer0', 'net_coarse_bottleneck_layer0', 'net_coarse_rgb_out_layer0', 'net_coarse_rgb_out_layer1', 'net_fine_enc_layer0', 'net_fine_enc_layer1', 'net_fine_enc_layer2', 'net_fine_enc_layer3', 'net_fine_enc_layer4', 'net_fine_enc_layer5', 'net_fine_enc_layer6', 'net_fine_enc_layer7', 'net_fine_a_out_layer0', 'net_fine_bottleneck_layer0', 'net_fine_rgb_out_layer0', 'net_fine_rgb_out_layer1'][0m
Inference done