## tf-garf

This is a TensorFlow 2 implementation of "[Gaussian Activated Radiance Fields](https://arxiv.org/abs/2204.05735)" by Chng et al.

<a href="https://twitter.com/laura_a_n_n"><img src="https://raw.githubusercontent.com/FortAwesome/Font-Awesome/6.x/svgs/brands/twitter.svg" alt="Twitter" style="width: 24px; height: 24px" /></a>
<a href="https://github.com/laura-a-n-n"><img src="https://raw.githubusercontent.com/FortAwesome/Font-Awesome/b452a2c086a5e3f319df61b1ce1db7d8e1ad2b7c/svgs/brands/github.svg" alt="GitHub" style="width: 24px; height: 24px" /></a>

### Data

In [None]:
# Third-party libraries used throughout the notebook.
from IPython.display import Image
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Project-local dataset loader.
from lib.data import load_data

# Location of the scene to train on / render. The name suggests an
# LLFF-format scene is expected -- replace the placeholder before running.
DATA_PATH = 'path/to/llff/data' # put your data here!
data = load_data(DATA_PATH)  # shared by every later cell that builds a model

### Train (optional)

In [None]:
# --- training (optional) ---
# Builds a fresh Gaussian radiance field over `data` and trains it from scratch.

from model.garf import GaussianRadianceField
from train import train

model = GaussianRadianceField(data, num_samples=128)
model.compile()

# NOTE: batch_size is per image -- the effective batch is 64 *times* the
# number of images. out_path follows a "<folder>/.<ext>" pattern; presumably
# `train` inserts a frame stem before the extension (confirm in train.py).
train(
    model,
    200_000,
    batch_size=64,
    val_idx='rand',
    out_path='movie_new/.png',
    notebook=True,
    save=True,
    overwrite=False,
)

### Render

In [None]:
# --- rendering / validation ---
# Rebuild the model with the same sampling setting as training, then restore
# pretrained weights instead of training from scratch.

from model.garf import GaussianRadianceField

model = GaussianRadianceField(data, num_samples=128)
model.compile()

# NOTE(review): the checkpoint name suggests weights for the 'flowers' scene,
# so DATA_PATH presumably must point at that same scene. `opt=False` presumably
# skips restoring optimizer state -- confirm against GaussianRadianceField.load.
model.load(
    'pretrain/flowers',
    opt=False,
)

In [None]:
# Render every training view, print its PSNR against the ground-truth image
# (pixel values compared with max value 1.0), and show the rendering followed
# by the ground truth. NaN PSNRs are excluded from the average.
avg_psnr = 0
n_psnr = 0

for i, yy in enumerate(model.img_rgb):
    xx = model.render(i)

    psnr = tf.image.psnr(xx, yy, 1.)
    # Use `not` rather than `~`: `~` is a logical NOT only because TF overloads
    # it for bool tensors; on a plain Python bool it is a bitwise NOT (always truthy).
    if not tf.math.is_nan(psnr):
        avg_psnr += psnr
        n_psnr += 1

    tf.print(f'PSNR for view index {i}: {psnr:.5f}')

    # Uncomment to also save each rendering to disk:
    # plt.imsave(f'val/{i}.png', np.array(xx))
    plt.imshow(xx)
    plt.show()
    plt.imshow(yy)
    plt.show()

# Guard the average: if every PSNR was NaN (or there were no views), dividing
# by n_psnr == 0 would raise ZeroDivisionError.
if n_psnr:
    avg_psnr /= n_psnr
    tf.print(f'Average PSNR {avg_psnr:.5f}')
else:
    tf.print('Average PSNR unavailable: no finite per-view PSNR values')

In [None]:
# --- spiral gif! ---
# Render novel views along a spiral camera path, save each one as an image in
# `file_path`, then assemble exactly those frames into an animated GIF.

import imageio
import glob
import os

from lib.rays import create_spiral_poses

# --- params ---
file_path = 'render' # path/folder name (no trailing slash)
file_type = 'png' # extension (any length, e.g. 'png' or 'jpeg')
N = 3 # filename stem length (zero-padded frame index)

frames = 55 # number of frames in output
n_circ = 2 # number of circles in total

spiral_axes = [.2, .2, .05] # x, y, z
spiral_depth = 1.2 # focus depth
# --- end params ---

# makedirs also creates missing parent folders and is a no-op if it exists.
os.makedirs(file_path, exist_ok=True)

# create poses
m_poses = tf.cast(create_spiral_poses(np.array(spiral_axes), spiral_depth, n_poses=frames, n_circ=n_circ), tf.float32)

# Build names with a format spec instead of slicing a fixed 4-char extension
# off a template, so extensions of any length work. Zero padding keeps the
# names in frame order when sorted.
frame_files = []
for i in range(int(tf.shape(m_poses)[0])):
    name = os.path.join(file_path, f'{i:0{N}d}.{file_type}')
    plt.imsave(name, np.array(model.render(pose=m_poses[i])))
    frame_files.append(name)

# render gif -- use the frames written above rather than globbing a hard-coded
# folder, which ignored `file_path` and picked up stale frames from older runs.
with imageio.get_writer('movie.gif', mode='I') as writer:
    for filename in frame_files:
        image = imageio.imread(filename)
        writer.append_data(image)

Image(filename='movie.gif')