Copyright 2021 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


# Infinite Nature example Colab

This Colab is part of the code for the paper ___Infinite Nature: Perpetual View Generation of Natural Scenes from a Single Image___, and may be found at <br>https://github.com/google-research/google-research/tree/master/infinite_nature.

The project page is at https://infinite-nature.github.io/.

Choose __Run all__ from the Runtime menu to:
* download our code and the tf_mesh_renderer library which we use,
* set up the network and load our trained model,
* run a demo generating frames indefinitely with autopilot.

You can also interact with our demo by clicking on the image to steer the camera.

If an error is encountered while running the demo, please send an email to the authors.


## Download code, model weights, and example data, and install dependencies.

In [None]:
%%shell
echo Fetching code from github...

# GitHub no longer serves repositories over Subversion, so the former
# `svn export .../trunk/infinite_nature` fails. A shallow, blob-less sparse
# git checkout fetches just the infinite_nature directory instead.
rm -rf infinite_nature google-research
git clone --depth 1 --filter=blob:none --sparse \
    https://github.com/google-research/google-research.git
git -C google-research sparse-checkout set infinite_nature
mv google-research/infinite_nature .
rm -rf google-research

echo
echo Fetching trained model weights...
rm -f autocruise_input*.pkl
rm -f ckpt.tar.gz
rm -rf ckpt
BASE_URL=https://storage.googleapis.com/gresearch/infinite_nature_public
for f in autocruise_input1.pkl autocruise_input2.pkl autocruise_input3.pkl ckpt.tar.gz; do
  wget "$BASE_URL/$f"
done
tar -xf ckpt.tar.gz

echo
echo Installing required dependency...
pip install tensorflow-addons

echo
echo Fetching tf_mesh_renderer and compiling kernels...
cd infinite_nature
rm -rf tf_mesh_renderer
source download_tf_mesh_renderer.sh

echo Done.

In [None]:
import tensorflow as tf
import os
import subprocess
import sys

# Make sure dynamic linking can find tensorflow libraries. The library path is
# passed as an argument list (no shell), so spaces or shell metacharacters in
# it cannot break the command. Like the previous os.system call, this is
# best-effort: a non-zero exit status is not treated as fatal.
subprocess.run(['ldconfig', tf.sysconfig.get_lib()], check=False)

# Make sure python can find our libraries.
sys.path.append('infinite_nature')
sys.path.append('infinite_nature/tf_mesh_renderer/mesh_renderer')

# Make sure the mesh renderer library knows where to load its .so file from.
os.environ['TEST_SRCDIR'] = 'infinite_nature'

In [None]:
import IPython
import numpy as np
import pickle
import tensorflow_hub as hub

import config
import fly_camera
import infinite_nature_lib

# Build model and restore checkpoint.
config.set_training(False)
model_path = "ckpt/model.ckpt-6935893"
render_refine, style_encoding = infinite_nature_lib.load_model(model_path)


def _load_input_rgbd(path):
  """Returns the 'input_rgbd' entry of the pickled dict stored at `path`.

  The file is opened in a `with` block so its handle is closed promptly.
  """
  # NOTE: pickle.load can execute arbitrary code. These files are downloaded
  # by the setup cell from the project's storage bucket; never point this at
  # untrusted data.
  with open(path, "rb") as f:
    return pickle.load(f)['input_rgbd']


initial_rgbds = [
    _load_input_rgbd(f"autocruise_input{i}.pkl") for i in (1, 2, 3)]

# Code for an autopilot demo. The HTML/JS frontend invokes the functions
# below through Colab callbacks: reset and step drive the fly-through, while
# load_initial and load_image choose its starting frame.

# The state that we need to remember while flying:
state = {
  'intrinsics': None,
  'pose': None,
  'rgbd': None,
  'start_rgbd': None,
  'style_noise': None,
  'next_pose_function': None,
  'direction_offset': None,  # Direction controlled by user's mouse clicks.
}


def current_image_as_png():
  """Encodes the RGB channels of the current frame as a displayable PNG.

  Returns:
    An `IPython.display.Image` holding the PNG bytes of
    `state['rgbd'][..., :3]` converted to uint8.
  """
  # saturate=True clips to the representable range before casting. Without it,
  # float values slightly outside [0, 1] are cast to uint8 without clipping
  # instead of saturating at 0/255.
  rgb_uint8 = tf.image.convert_image_dtype(
      state['rgbd'][..., :3], dtype=tf.uint8, saturate=True)
  return IPython.display.Image(data=tf.image.encode_png(rgb_uint8).numpy())


def reset(rgbd=None):
  """Starts a new fly-through from `rgbd`, or restarts the current one.

  Args:
    rgbd: Starting RGB-D frame. Its shape is unpacked as (height, width, _) to
      derive the camera aspect ratio before it is resized to 160x256. When
      None, the previously stored starting frame is reused.

  Returns:
    The PNG preview of the (resized) starting frame.
  """
  if rgbd is None:
    rgbd = state['start_rgbd']

  in_height, in_width, _ = rgbd.shape
  aspect_ratio = in_width / float(in_height)

  frame = tf.image.resize(rgbd, [160, 256])
  state['rgbd'] = state['start_rgbd'] = frame
  # The pose starts as the 3x4 identity. 0.8 focal_x corresponds to a FOV of
  # ~64 degrees; focal_y is 0.8 scaled by the input's aspect ratio
  # (presumably the layout is normalized [fx, fy, cx, cy] — confirm).
  state.update(
      pose=np.eye(3, 4, dtype=np.float32),
      intrinsics=np.array(
          [0.8, 0.8 * aspect_ratio, .5, .5], dtype=np.float32),
      direction_offset=(0.0, 0.0),
  )
  state['style_noise'] = style_encoding(frame)
  # The turn function looks up state['direction_offset'] each time it is
  # called, so step() steers simply by updating that entry.
  state['next_pose_function'] = fly_camera.fly_dynamic(
      state['intrinsics'],
      state['pose'],
      turn_function=lambda _: state['direction_offset'])
  return current_image_as_png()


def step(offsetx, offsety):
  """Advances the fly-through by one rendered frame.

  Args:
    offsetx: Horizontal steering offset sent by the front end.
    offsety: Vertical steering offset sent by the front end.

  Returns:
    The PNG preview of the newly rendered frame.
  """
  # Stored where the turn function installed by reset() reads it.
  state['direction_offset'] = (offsetx, offsety)
  previous_rgbd = state['rgbd']
  new_pose = state['next_pose_function'](previous_rgbd)
  camera = state['intrinsics']
  # Source and target views share the same intrinsics.
  new_rgbd = render_refine(
      previous_rgbd, state['style_noise'],
      state['pose'], camera,
      new_pose, camera)
  state.update(pose=new_pose, rgbd=new_rgbd)
  return current_image_as_png()


# MiDaS V2, loaded from TF Hub, supplies the initial disparity when the demo is
# started from a user-supplied image (see midas_disparity / load_image).
midas_model = hub.load(
    'https://tfhub.dev/intel/midas/v2/2',
    tags=['serve'])


def midas_disparity(rgb):
  """Computes MiDaS v2 disparity on an RGB input image.

  Args:
    rgb: [H, W, 3] Range [0.0, 1.0].
  Returns:
    [H, W, 1] MiDaS disparity resized to the input size and min/max-normalized
    to the range [0.0, 1.0]. If the prediction is constant (zero range), the
    result is all zeros rather than NaN.
  """
  size = rgb.shape[:2]
  resized = tf.image.resize(rgb, [384, 384], tf.image.ResizeMethod.BICUBIC)
  # The MiDaS network expects a batched, channels-first [1, C, H, W] input.
  midas_input = tf.transpose(resized, [2, 0, 1])[tf.newaxis]
  prediction = midas_model.signatures['serving_default'](midas_input)['default'][0]
  disp_min = tf.reduce_min(prediction)
  disp_max = tf.reduce_max(prediction)
  # divide_no_nan yields 0 where the denominator is 0, so a flat prediction
  # (disp_max == disp_min) cannot push NaNs into the rest of the pipeline.
  prediction = tf.math.divide_no_nan(prediction - disp_min, disp_max - disp_min)
  return tf.image.resize(
      prediction[..., tf.newaxis], size, method=tf.image.ResizeMethod.AREA)


def load_initial(i):
  """Restarts the demo from bundled example frame `i` (index into initial_rgbds).

  Returns:
    Whatever reset() returns for that frame (its PNG preview).
  """
  chosen_rgbd = initial_rgbds[i]
  return reset(chosen_rgbd)


def load_image(data):
  """Restarts the demo from an image uploaded through the HTML front end.

  Args:
    data: Encoded image file contents as a str with one character per byte
      (the front end reads the file with FileReader.readAsBinaryString).

  Returns:
    Whatever reset() returns for the new starting frame (its PNG preview).
  """
  # Data converted from JS ends up as a string, needs to be converted to
  # bytes using Latin-1 encoding (which just maps 0-255 to 0-255).
  data = data.encode('Latin-1')
  # expand_animations=False makes decode_image always return a single
  # [H, W, 3] frame (e.g. the first frame of a GIF) instead of a 4-D stack,
  # which midas_disparity's rank-3 transpose cannot handle.
  rgb = tf.image.decode_image(
      data, channels=3, dtype=tf.float32, expand_animations=False)
  resized = tf.image.resize(rgb, [160, 256], tf.image.ResizeMethod.AREA)
  rgbd = tf.concat([resized, midas_disparity(resized)], axis=-1)
  # NOTE(review): reset() derives the camera aspect ratio from the shape of
  # the rgbd it receives, which here is always 160x256, so the upload's
  # original aspect ratio is not used — confirm this is intended.
  return reset(rgbd=rgbd)

In [None]:
from google.colab import output

# The front-end for our interactive demo.

html = '''
<style>
#view {
  width: 512px;
  height: 320px;
  background-color: #aaa;
  background-size: 100% 100%;
  border: 1px solid #000;
  margin: 20px;
  position: relative;
}
#rgb {
  height: 100%;
}
#cursor {
  position: absolute;
  height: 0; width: 0;
  left: 50%; top: 50%;
  opacity: .5;
}
#cursor::before, #cursor::after {
  content: '';
  position: absolute;
  background: #f04;
  pointer-events: none;
}
#cursor::before {
  left: -10px; top: -1px; width: 20px; height: 2px;
}
#cursor::after {
  left: -1px; top: -10px; width: 2px; height: 20px;
}
.buttons {
  margin: 20px;
}
.buttons div {
  display: inline-block;
  cursor: pointer;
  padding: 20px;
  background: #eee;
  border: 2px solid #aaa;
  border-radius: 3px;
  margin-right: 10px;
  font-weight: bold;
  text-transform: uppercase;
  letter-spacing: 1px;
  color: #444;
}
.buttons div:active {
  background: #444;
  color: #fff;
}
h3 {
  margin-left: 20px;
}
</style>
<h3>Infinite Nature interactive demo</h3>
<div id=view><img id=rgb><div id=cursor></div></div>
<div class=buttons>
Click <b>Play</b> to run or <b>Step</b> to advance frame by frame.
Click mouse over image to steer.<br><br>
<div id=restart>Restart</div><div id=play>Play</div><div id=pause>Pause</div><div id=step>Step</div>
<br><br>
Select starting image (be patient…):<br><br>
<div id=image1>Image 1</div><div id=image2>Image 2</div><div id=image3>Image 3</div><div id=upload>Upload…</div><br>
<input style="display:none" type=file id=chooser accept=".png,.jpg">
</div>
<script>
let playing = true;
let pending = false;
let x = 0.5;
let y = 0.5;
let cursor_count = 0;

// Invokes a registered Python callback and shows the PNG it returns. While a
// call is in flight, `pending` makes step() skip instead of overlapping calls.
async function call(name, ...parms) {
  pending = true;
  let result;
  try {
    result = await google.colab.kernel.invokeFunction(name, parms, {});
  } catch (err) {
    // Stop autoplay so a failing callback is not re-invoked frame after frame.
    playing = false;
    console.error(`Call to ${name} failed:`, err);
    return;
  } finally {
    // Always clear the flag; otherwise one failed call would turn every
    // later step() into a no-op.
    pending = false;
  }
  const url = `data:image/png;base64,${result.data['image/png']}`;
  document.querySelector('#rgb').src = url;
  if (!playing) { return; }
  step();
}

async function reset() {
  playing = false;
  await call('reset');
}

async function selectImage(i) {
  playing = false;
  await call('load_initial', i);
}

function upload() {
  playing = false;
  document.querySelector('#chooser').click();
}

function uploadFile(file) {
  if (file.type != 'image/png' && file.type != 'image/jpeg') {
    // Tell the user why the file was rejected.
    alert('Only PNG or JPEG files accepted.');
    return;
  }
  const reader = new FileReader();
  reader.onload = (e) => {
    const imagebytes = e.target.result;
    call('load_image', imagebytes);
  };
  document.querySelector('#rgb').src = '';
  reader.readAsBinaryString(file);
}

async function step() {
  if (pending) { return; }
  await call('step', 2*x - 1, 2*y - 1);
  // Cursor moves back towards center.
  if (cursor_count) {
    cursor_count--;
  } else {
    x = 0.5 + (x - 0.5) * .9;
    y = 0.5 + (y - 0.5) * .9;
    update_cursor();
  }
}

async function play() {
  playing = true;
  await step();
}

async function pause() {
  playing = false;
}

function update_cursor() {
  let cursor = document.querySelector('#cursor');
  cursor.style.left = `${(100 * x).toFixed(2)}%`;
  cursor.style.top = `${(100 * y).toFixed(2)}%`;
}

function cursor(e) {
  x = e.offsetX / e.target.clientWidth;
  y = e.offsetY / e.target.clientHeight;
  cursor_count = 1;
  update_cursor();
}

document.querySelector('#restart').addEventListener('click', reset);
document.querySelector('#image1').addEventListener('click', () => selectImage(0));
document.querySelector('#image2').addEventListener('click', () => selectImage(1));
document.querySelector('#image3').addEventListener('click', () => selectImage(2));
document.querySelector('#upload').addEventListener('click', upload);
document.querySelector('#play').addEventListener('click', play);
document.querySelector('#pause').addEventListener('click', pause);
document.querySelector('#step').addEventListener('click', () => { playing = false; step(); });
document.querySelector('#view').addEventListener('click', cursor);
document.querySelector('#chooser').addEventListener('change', (e) => {
  if (e.target.files.length > 0) {
    uploadFile(e.target.files[0]);
  }
  // Clear the selection so choosing the same file again still fires 'change'.
  e.target.value = '';
});
selectImage(0);
</script>
'''

# Register the Python callbacks before the front end is displayed: its script
# immediately calls load_initial (via selectImage(0)) and invokes the others
# in response to button clicks.
output.register_callback('load_initial', load_initial)
output.register_callback('load_image', load_image)
output.register_callback('reset', reset)
output.register_callback('step', step)

IPython.display.display(IPython.display.HTML(html))