# Multi-view satellite photogrammetry with Neural Radiance Fields (NeRF)
The objective of this notebook is to study the application of [Neural Radiance Fields](https://www.matthewtancik.com/nerf) to multi-view satellite photogrammetry. This is done using data from the Track 3 of the [IEEE GRSS Data Fusion Contest](http://www.grss-ieee.org/community/technical-committees/data-fusion/2019-ieee-grss-data-fusion-contest/).

In particular, the aim of this notebook is to demonstrate an extension of NeRF called SNeRF (Shadow NeRF) which implicitly models directional lighting effects.

In [1]:
import os
import time
import math
import numpy as np

import tensorflow as tf
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt

from IPython.display import HTML, display
import tabulate

from scipy import ndimage
from skimage.transform import rescale, resize
from skimage.metrics import structural_similarity as ssim
from skimage.measure import marching_cubes_lewiner

import data_handling
import models
import render
import train
from plots import plot_images, plot_view_light_directions, plot_depth_map

np.set_printoptions(precision=3,suppress=True)
def_dtype = np.float32


This notebook is best run with a GPU.

In [2]:
print(tf.__version__)
physical_devices = tf.config.experimental.list_physical_devices('GPU')
print(physical_devices)

2.2.0
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


## Dataset
This dataset is composed of multi-view images over two cities in USA. The azimuth and elevation angles of the satellite, the geographic location of the image, and the solar angles are provided as metadata.

In [3]:
# Read config from file
parser = train.config_parser()
args = parser.parse_args('--config ../configs/068_config.txt')
arg_dict = vars(args)
print(arg_dict)

dataset = data_handling.generate_dataset(arg_dict)
arg_dict['rend.rescale'] = render.calculate_rescale_factor(dataset)

{'config': '../configs/068_config.txt', 'data.image.path': '/home/derksend/snerf_release/stereo_nerf/data/068/JAX_068_df1', 'data.image.df': 8, 'data.image.sd': 0.3, 'data.depth.path': '/home/derksend/snerf_release/stereo_nerf/data/068/JAX_068_df1_dsm.tif', 'data.depth.df': 1, 'data.md.path': '/home/derksend/snerf_release/stereo_nerf/data/068/JAX_068_df1_md.txt', 'data.train_id': ['01', '02', '03', '04', '05', '06', '07', '09', '10', '11', '12', '13', '14', '15', '18', '19'], 'data.test_id': ['20', '22'], 'model.ins.light': True, 'model.ins.views': False, 'model.outs.shad': True, 'model.outs.sky': True, 'model.act': 'sin', 'model.act.sin.w0': 32.0, 'model.sigma.depth': 8, 'model.sigma.width': 100, 'model.sigma.skips': [], 'model.c.depth': 1, 'model.c.width': 50, 'model.shad.depth': 4, 'model.shad.width': 50, 'model.emb.pos': 0, 'model.emb.dir': 1, 'rend.nsamples': 64, 'rend.nimportance': 64, 'rend.mode': 'alt', 'rend.mode.nf.near': 3.0, 'rend.mode.nf.far': 10.0, 'rend.mode.alt.max': 30

Here we plot some views of a building with an interesting shape : The Jacksonville UF Health

In [4]:
# plot_images(arg_dict['data.train_id'], dataset['train_imgs'], dataset['train_view_dirs'], dataset['train_light_dirs'])

In [5]:
# plot_view_light_directions(dataset['train_view_dirs'], dataset['train_light_dirs'])

In [6]:
# plot_depth_map(dataset['depth_map'])

## Model

In [7]:
model =  models.generate_model(arg_dict)

## Load model

In [None]:
model = snerf_models.load_model(f"{arg_dict['out.path']}model.npy", arg_dict)
its, train_loss = snerf_plots.parse_train_loss(arg_dict)
snerf_plots.plot_train_loss(its, train_loss)

## Training

In [None]:
arg_dict['train.shad'] = False
arg_dict['train.shad.lambda'] = 0.05
# arg_dict['train.shad.df'] = 1
# arg_dict['rend.nimportance'] = 0
# arg_dict['rend.nsamples'] = 64

In [8]:
## Training rays
train_rays = render.generate_train_rays(dataset, arg_dict)
if arg_dict['train.shad']:
    sc_train_rays = render.generate_train_light_correction_rays(dataset, arg_dict)
    if arg_dict['train.shad.custom'] in ['linear', 'rectangle']:
        custom_sc_rays = render.generate_custom_light_correction_rays(dataset, arg_dict)
        sc_train_rays = render.concat_rays(sc_train_rays, custom_sc_rays)
else:
    sc_train_rays = None

In [9]:
optimizer = train.init_exp_decay_adam(1e-4, 1000, 0.1)

In [10]:
# N_iterations = arg_dict['train.n_epoch']
N_iterations = 100

model, train_loss, scores = train.train_model(model, optimizer, N_iterations, arg_dict, train_rays, sc_train_rays=sc_train_rays)

Begin training
0 11.593910217285156 0.03413091227412224
1 14.110355377197266 0.0356413908302784
2 14.238982200622559 0.03246796503663063
3 14.488728523254395 0.03257639333605766
4 14.663616180419922 0.03240058571100235
5 15.320385932922363 0.03386761620640755
6 14.781527519226074 0.03275985270738602
7 15.734472274780273 0.03188895806670189
8 15.28876781463623 0.030201686546206474
9 14.815862655639648 0.03031565621495247
25 16.14531135559082 0.015308404341340065
50 16.4764347076416 0.0036405466962605715
75 17.44025993347168 0.002665905514732003


## Train loss

In [None]:
# snerf_models.save_model(arg_dict['out.path'], model)

In [None]:
# # print(train_loss[0])
train_loss = [line.split(' ') for line in train_loss]
its = [int(l[0]) for l in train_loss]
rgb_loss = [float(l[1]) for l in train_loss]
loss_out = [rgb_loss]
if arg_dict['train.shad']:
    shad_loss = [float(l[2]) for l in train_loss]
    loss_out.append(shad_loss)

snerf_plots.plot_train_loss(its, loss_out)

## Rendering on training data

In [12]:
dataset_rend = render.render_dataset(dataset, model, ['rgb', 'depth', 'sky', 'no_shadow'], arg_dict)

In [13]:
snerf_plots.plot_results(dataset['train_imgs'], dataset_rend['train_rend'])

NameError: name 'snerf_plots' is not defined

## Render test data

In [None]:
snerf_plots.plot_results(dataset['test_imgs'], dataset_rend['test_rend'])

## Nadir view

In [None]:
def render_vertical_depth_comparison(model, arg_dict, dsm, dsm_df, thresh=1):
    SR = 0.5 * dsm_df
    radius = 617000.0/SR
    arg_dict_temp = arg_dict.copy()
    arg_dict_temp['data.image.sd'] = SR
    arg_dict_temp['data.image.df'] = 1
    az, el = np.pi, np.pi/2
    pose = data_handling.pose_spherical(az, -el, radius)
    hwf = dsm.shape[0], dsm.shape[1], radius
    light_dir=tf.reshape(tf.convert_to_tensor([np.deg2rad(100), np.deg2rad(80)], dtype=def_dtype), [1,2])
    view_dir=tf.reshape(tf.convert_to_tensor([az, el], dtype=def_dtype), [1,2])
    ret_dict = snerf_render.render_image(model, arg_dict_temp, hwf, pose, 1.0, light_dir, view_dir, rets=['no_shadow','depth'])
    plt.figure(figsize=(8, 8))
    plt.imshow(ret_dict['no_shadow'])
    plt.title('Rendered RGB')
    plt.figure(figsize=(20, 10))
    plt.subplot(121)
    disp = ret_dict['depth'] * SR
    plt.imshow(disp, vmin=np.min(dsm), vmax=np.max(dsm))
    m_e = np.mean(np.abs(disp-dsm))
    plt.title(f"Altitude rendering\n"
              f"Average error : {m_e:.4} m")
    plt.colorbar()
    plt.subplot(122)
    plt.imshow(dsm)
    plt.title('Ground truth altitude')
    plt.colorbar()
    plt.figure(figsize=(20, 10))
    plt.subplot(121)
    a_max = max(arg_dict['rend.mode.alt.max'],  -arg_dict['rend.mode.alt.min'])
    plt.imshow(disp-dsm, cmap = 'rainbow', vmin=-20.0, vmax=20.0)
    plt.title('Difference between estimated surface altitude and lidar DSM')
    plt.colorbar()
    plt.subplot(122)
    errors = (disp-dsm).numpy().flatten()
    plt.hist(errors, bins = 64)
    N_good_pixels = np.sum(np.where(np.abs(errors) < thresh, 1, 0))
    plt.title(f"Histogram of altitude errors\n"
              f"Pixels with error < {thresh}m : {100*N_good_pixels/hwf[0]/hwf[1]:.3}%")
    plt.show()


In [None]:
dsm = dataset['depth_map']
print(np.max(dsm))
render_vertical_depth_comparison(model, arg_dict, dsm, arg_dict['data.depth.df'], thresh=1)

## Shadow interpolation

In [None]:
hwf = [dataset['train_imgs'][0].shape[0], dataset['train_imgs'][0].shape[1], 617000.0/0.3/arg_dict["data.image.df"]]
light_start = [160.0, 33.5]
light_end =[114.7, 74.5]
view_angle=(np.pi, np.pi/2)
snerf_plots.plot_light_angle_inter(model, arg_dict, hwf, light_start, light_end, view_angle, nplots=20, 
                                   rets=['rgb','depth', 'sky', 'no_shadow'])


## Data set video

In [None]:
data_url_ds=snerf_plots.train_data_video(dataset, arg_dict['out.path'])
HTML(f"""
<video width=400 controls autoplay loop>
      <source src="{data_url_ds}" type="video/mp4">
</video>
""")

## Flyover video

In [None]:
hwf = [dataset['train_imgs'][0].shape[0], dataset['train_imgs'][0].shape[1], 617000.0/0.3/arg_dict["data.image.df"]]

data_url_f  = snerf_plots.render_flyover_video(arg_dict['out.path'], model, arg_dict, hwf, light_start, light_end, 
                                               rets=['rgb', 'depth', 'ret_sun', 'no_shadow'])
HTML(f"""
<video width=400 controls autoplay loop>
      <source src="{data_url_f[0]}" type="video/mp4">
</video>
<video width=400 controls autoplay loop>
      <source src="{data_url_f[1]}" type="video/mp4">
</video>
<video width=400 controls autoplay loop>
      <source src="{data_url_f[2]}" type="video/mp4">
</video>
<video width=400 controls autoplay loop>
      <source src="{data_url_f[3]}" type="video/mp4">
</video>
""")

## Qualitative results

In [None]:
result_table = snerf_train.score_overview(snerf_train.test_model(model, dataset, arg_dict), train_loss)
print(result_table)
result_table = [line.split(',') for line in result_table]
display(HTML(tabulate.tabulate(result_table, tablefmt='html')))
