# GQN View Interpolation
Loads a trained GQN and performs a sequence of view interpolations.

In [None]:
'''imports'''
# stdlib
import os
import sys
# numerical computing
import numpy as np
import tensorflow as tf
# plotting
import imageio
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from IPython.display import Image, display
# GQN src
# Root directory whose packages (data_provider, gqn) are imported below.
# NOTE(review): assumes the current working directory is a direct
# sub-directory of that root (e.g. the notebooks folder) - confirm.
root_path = os.path.abspath(os.path.join(os.getcwd(), '..'))
# Re-running this cell must not keep growing sys.path with duplicate entries.
if root_path not in sys.path:
    sys.path.append(root_path)
print(sys.path)
from data_provider.gqn_provider import gqn_input_fn
from gqn.gqn_predictor import GqnViewPredictor

In [None]:
'''directory setup'''
# flags: which dataset, which trained model, which notebook scratch folder
dataset_name = 'rooms_ring_camera'
model_name = 'gqn8'
notebook_name = 'view_interpolation'
# top-level folders below root_path
data_dir, model_dir = (
    os.path.join(root_path, folder) for folder in ('data', 'models'))
tmp_dir = os.path.join(root_path, 'notebooks', 'tmp')
# dataset: <data_dir>/gqn-dataset/<dataset_name>
gqn_dataset_path = os.path.join(data_dir, 'gqn-dataset')
data_path = os.path.join(gqn_dataset_path, dataset_name)
print("Data path: %s" % (data_path, ))
# model: <model_dir>/<dataset_name>/<model_name>
gqn_model_path = os.path.join(model_dir, dataset_name)
model_path = os.path.join(gqn_model_path, model_name)
print("Model path: %s" % (model_path, ))
# tmp: <tmp_dir>/<notebook_name>, created on demand
notebook_tmp_path = os.path.join(tmp_dir, notebook_name)
os.makedirs(notebook_tmp_path, exist_ok=True)
print("Tmp path: %s" % (notebook_tmp_path, ))

In [None]:
'''data reader setup'''
# reader configuration: EVAL mode, 5 context views per tuple, 1 tuple per batch
mode = tf.estimator.ModeKeys.EVAL
ctx_size, batch_size = 5, 1
dataset = gqn_input_fn(
    dataset_name=dataset_name,
    root=gqn_dataset_path,
    mode=mode,
    context_size=ctx_size,
    batch_size=batch_size,
    num_epochs=1,
    num_threads=4,
    buffer_size=1,
)
# the iterator still has to be initialized in a session before `data` is fetched
iterator = dataset.make_initializable_iterator()
data = iterator.get_next()

In [None]:
'''video predictor & session setup'''
# Hide all CUDA devices from this process.
# NOTE(review): presumably this has to happen before the predictor creates its
# session - confirm against GqnViewPredictor.
os.environ['CUDA_VISIBLE_DEVICES'] = ''  # run on CPU only
# build the predictor from model_path
predictor = GqnViewPredictor(model_path)
# reuse the predictor's session for reading data in the cells below
sess = predictor.sess
# initialize the data iterator in that session before the first fetch
sess.run(iterator.initializer)

In [None]:
'''data visualization'''

# Number of tuples pulled from the reader; only the last one fetched is kept.
# NOTE: must be >= 1, otherwise `d` below is never (re-)assigned by this cell.
skip_load = 1
# fetch & parse
for _ in range(skip_load):
    d, _ = sess.run(data)
ctx_frames = d.query.context.frames
ctx_poses = d.query.context.cameras
tgt_frame = d.target
tgt_pose = d.query.query_camera
tuple_length = ctx_size + 1  # context points + 1 target
print(">>> Context frames:\t%s" % (ctx_frames.shape, ))
print(">>> Context poses: \t%s" % (ctx_poses.shape, ))
print(">>> Target frame:  \t%s" % (tgt_frame.shape, ))
print(">>> Target pose:   \t%s" % (tgt_pose.shape, ))

# visualization constants
MAX_COLS_PER_ROW = 6
TILE_HEIGHT, TILE_WIDTH, TILE_PAD = 3.0, 3.0, 0.8
np.set_printoptions(precision=2, suppress=True)

# visualize all data tuples in the batch
for n in range(batch_size):
    # define image grid: one tile per view, wrapped after MAX_COLS_PER_ROW tiles
    ncols = min(tuple_length, MAX_COLS_PER_ROW)
    nrows = int(np.ceil(tuple_length / MAX_COLS_PER_ROW))
    fig = plt.figure(figsize=(TILE_WIDTH * ncols, TILE_HEIGHT * nrows))
    grid = ImageGrid(
        fig, 111,  # similar to subplot(111)
        nrows_ncols=(nrows, ncols),
        axes_pad=TILE_PAD,  # pad between axes in inch.
    )
    # visualize context (tiles 0 .. ctx_size - 1)
    for ctx_idx in range(ctx_size):
        rgb = ctx_frames[n, ctx_idx]
        pose = ctx_poses[n, ctx_idx]
        caption = "ctx: %02d\nxyz:%s\nyp:%s" % \
            (ctx_idx + 1, pose[0:3], pose[3:])
        grid[ctx_idx].imshow(rgb)
        grid[ctx_idx].set_title(caption, loc='center')
    # visualize target in the tile right after the last context view.
    # grid[-1] would only coincide with that tile while the grid has no spare
    # cells; with a partially filled last row the target would end up in the
    # bottom-right corner, detached from the context tiles.
    rgb = tgt_frame[n]
    pose = tgt_pose[n]
    caption = "target\nxyz:%s\nyp:%s" % \
        (pose[0:3], pose[3:])
    grid[ctx_size].imshow(rgb)
    grid[ctx_size].set_title(caption, loc='center')
    plt.show()

In [None]:
'''run the view prediction'''


def _new_tile_grid():
    """Create a new figure and return an image grid sized for one data tuple."""
    ncols = int(np.min([tuple_length, MAX_COLS_PER_ROW]))
    nrows = int(np.ceil(tuple_length / MAX_COLS_PER_ROW))
    fig = plt.figure(figsize=(TILE_WIDTH * ncols, TILE_HEIGHT * nrows))
    return ImageGrid(
        fig, 111,  # similar to subplot(111)
        nrows_ncols=(nrows, ncols),
        axes_pad=TILE_PAD,  # pad between axes in inch.
    )


def _draw_tile(ax, frame, pose, label):
    """Show `frame` on `ax`, titled with `label` and the pose as xyz / yp."""
    caption = "%s\nxyz:%s\nyp:%s" % (label, pose[0:3], pose[3:])
    ax.imshow(frame)
    ax.set_title(caption, loc='center')


# visualize all data tuples in the batch
for n in range(batch_size):
    # Tile order shared by both figures: the query first, then every context
    # view, so predictions and ground truth line up tile by tile.
    labels = ["query"] + ["ctx: %02d" % (i + 1) for i in range(ctx_size)]
    poses = [tgt_pose[n]] + [ctx_poses[n, i] for i in range(ctx_size)]
    true_frames = [tgt_frame[n]] + [ctx_frames[n, i] for i in range(ctx_size)]

    print(">>> Predictions:")
    grid = _new_tile_grid()
    # load the scene context
    predictor.clear_context()
    for i in range(ctx_size):
        predictor.add_context_view(ctx_frames[n, i], ctx_poses[n, i])
    # render the query, then re-render the context poses (cycle-consistency)
    for tile_idx, (pose, label) in enumerate(zip(poses, labels)):
        pred_frame = predictor.render_query_view(pose)[0]
        _draw_tile(grid[tile_idx], pred_frame, pose, label)
    plt.show()

    print(">>> Ground truth:")
    grid = _new_tile_grid()
    for tile_idx, (rgb, pose, label) in enumerate(
            zip(true_frames, poses, labels)):
        _draw_tile(grid[tile_idx], rgb, pose, label)
    plt.show()

In [None]:
'''render a view interpolation trajectory'''

# One query pose per yaw value 0, 10, ..., 350; all other pose entries are 0.
# NOTE(review): the unit the predictor expects for yaw is not visible here -
# confirm against GqnViewPredictor.
query_poses = [np.array([0, 0, 0, yaw, 0]) for yaw in range(0, 360, 10)]
# This cell does not touch the predictor's context, so every pose is rendered
# against whatever context was loaded before it runs.
frame_buffer = [predictor.render_query_view(qp)[0] for qp in query_poses]

# show gif of view interpolation trajectory
gif_tmp_path = os.path.join(notebook_tmp_path, 'view_interpolation_preview.gif')
imageio.mimsave(gif_tmp_path, frame_buffer)
with open(gif_tmp_path, 'rb') as gif_file:
    gif_bytes = gif_file.read()
display(Image(gif_bytes))