## Inference Code for 'Neural Nano-Optics for High-quality Thin Lens Imaging'

#### This notebook can be used to produce the experimental reconstructions shown in the manuscript and in the supplemental information.

In [1]:
import tensorflow as tf
# import keras
import numpy as np
from networks.select import select_G
from args import parse_args
import metasurface.solver as solver
import metasurface.conv as conv
import matplotlib.pyplot as plt
import sys

In [2]:
# Set up the arguments for real inference.
# parse_args() is called without arguments, so it presumably parses sys.argv
# (argparse-style); the notebook therefore overwrites sys.argv before calling it.
# Note this replaces the kernel's sys.argv for the rest of the session.
inference_flags = {
    '--train_dir': '.',
    '--test_dir':  '.',
    '--save_dir':  '.',
    '--ckpt_dir':  'experimental/ckpt/',
    '--real_psf':  './experimental/data/psf/psf.npy',
    '--psf_mode':  'REAL_PSF',
    '--conv_mode': 'REAL',
    '--conv':      'full_size',
}
# Flatten {flag: value} into ['', flag1, value1, flag2, value2, ...]; the empty
# first entry stands in for the program name argv[0].
sys.argv = [''] + [token for flag, value in inference_flags.items() for token in (flag, value)]
args = parse_args()

Namespace(train_dir='.', test_dir='.', save_dir='.', save_freq=1000, log_freq=500, ckpt_dir='experimental/ckpt/', max_to_keep=2, loss_mode='L1', batch_weights=[1.0], Norm_loss_weight=1.0, P_loss_weight=0.0, Spatial_loss_weight=0.0, vgg_layers='block2_conv2,block3_conv2', steps=1000000000, aug_rotate=False, real_psf='./experimental/data/psf/psf.npy', psf_mode='REAL_PSF', conv_mode='REAL', conv='full_size', do_taper=True, offset=True, normalize_psf=False, theta_base=[0.0, 5.0, 10.0, 15.0], num_coeffs=8, use_general_phase=False, metasurface='zeros', s1=0.0009, s2=0.0014, alpha=270.176968209, target_wavelength=5.11e-07, bound_val=1000.0, a_poisson=4e-05, b_sqrt=1e-05, mag=8.1, Phase_iters=1, Phase_lr=0.005, Phase_beta1=0.9, G_iters=1, G_lr=0.0001, G_beta1=0.9, G_network='FP', snr_opt=False, snr_init=4.0)


In [3]:
# Initialize and restore deconvolution method
params = solver.initialize_params(args)
params['conv_fn'] = conv.convolution_tf(params, args)
params['deconv_fn'] = conv.deconvolution_tf(params, args)

snr = tf.Variable(args.snr_init, dtype=tf.float32)
# NOTE(review): the saved output of this cell ends in a Keras "A KerasTensor cannot
# be used as input to a TensorFlow function" ValueError, presumably raised while
# select_G builds the network. That message comes from Keras 3, so running this
# notebook likely requires TF <= 2.15 / Keras 2 (or tf_keras) — TODO confirm.
G = select_G(params, args)
checkpoint = tf.train.Checkpoint(G=G, snr=snr)

# tf.train.latest_checkpoint returns None when ckpt_dir holds no checkpoint, and
# Checkpoint.restore(None) then "succeeds" without loading anything, leaving G and
# snr at their initial values. Fail loudly instead of reconstructing with them.
ckpt_path = tf.train.latest_checkpoint(args.ckpt_dir)
if ckpt_path is None:
    raise FileNotFoundError(f'No checkpoint found in {args.ckpt_dir!r}')
status = checkpoint.restore(ckpt_path)
# Only G and snr are tracked by this Checkpoint; expect_partial() silences the
# warnings for any other saved objects (presumably optimizer state) left unmatched.
status.expect_partial()

def phase_func(x, a2, a4, a6, a8, a10, a12, a14, a16):
    """Evaluate the even polynomial a2*x**2 + a4*x**4 + ... + a16*x**16.

    Works element-wise when ``x`` is a NumPy array. The argument layout (``x``
    first, then the coefficients) matches what a curve fitter expects; it is not
    called anywhere in this notebook, and what ``x`` represents (presumably a
    normalised radius for a radially symmetric phase profile) is not shown here
    — TODO confirm.
    """
    higher_terms = ((a4, 4), (a6, 6), (a8, 8), (a10, 10), (a12, 12), (a14, 14), (a16, 16))
    # Accumulate left to right so the floating-point summation order is unchanged.
    phase = a2 * x**2
    for coeff, power in higher_terms:
        phase = phase + coeff * x**power
    return phase
360
Image width: 720
PSF width: 360
Load width: 1440
Network width: 1080
Out width: 720


ValueError: A KerasTensor cannot be used as input to a TensorFlow function. A KerasTensor is a symbolic placeholder for a shape and dtype, used when constructing Keras Functional models or Keras Functions. You can only use it as input to a Keras layer or a Keras operation (from the namespaces `keras.layers` and `keras.operations`). You are likely doing something like:

```
x = Input(...)
...
tf_fn(x)  # Invalid.
```

What you should do instead is wrap `tf_fn` in a layer:

```
class MyLayer(Layer):
    def call(self, x):
        return tf_fn(x)

x = MyLayer()(x)
```


### Perform Deconvolution

In [None]:
# Check that the dimensions agree with experimental captures.
# The captures and PSF loaded below were recorded at these sizes; a mismatch means
# the arguments above do not describe the experimental prototype.
expected_dims = {'image_width': 720, 'psf_width': 360, 'network_width': 1080}
for dim_name, expected in expected_dims.items():
    assert params[dim_name] == expected, (
        f"params['{dim_name}'] is {params[dim_name]}, expected {expected} "
        f"for the experimental captures")

In [None]:
# Load in experimentally measured PSFs. This is the same file passed as
# --real_psf when the arguments were parsed, so read the path from args rather
# than repeating it.
psf = np.load(args.real_psf)
psf = tf.constant(psf)
# Crop or zero-pad to the PSF width the model was configured with.
# NOTE(review): assumes psf.npy is 4-D (batch, height, width, channels), so that
# axes 1 and 2 below are the spatial axes — TODO confirm against the data file.
psf = tf.image.resize_with_crop_or_pad(psf, params['psf_width'], params['psf_width'])
# Normalise so each PSF sums to 1 over the spatial axes.
psf = psf / tf.reduce_sum(psf, axis=(1, 2), keepdims=True)

In [None]:
def _postprocess_for_display(img, vig_factor, gain=1.2, low_percentile=5, high_percentile=95):
    """Vignetting-correct, brighten and contrast-stretch a reconstruction for display.

    Args:
        img: Reconstructed image as a NumPy array.
        vig_factor: Per-pixel vignetting correction, broadcastable to ``img``.
        gain: Multiplicative brightness gain; the result saturates at 1.0.
        low_percentile: Percentile mapped to 0 by the contrast stretch.
        high_percentile: Percentile mapped to 1 by the contrast stretch.

    Returns:
        A new array with the shape of ``img`` and values in [0, 1].
    """
    # Vignetting correction followed by gain, saturating highlights at 1.0.
    out = np.minimum(img * vig_factor * gain, 1.0)

    black = np.percentile(out, low_percentile)
    white = np.percentile(out, high_percentile)
    span = white - black
    if span <= 0:
        # Flat image: the stretch would divide by zero and turn every pixel into
        # NaN, so just clamp to the displayable range instead.
        return np.clip(out, 0.0, 1.0)

    stretched = (np.clip(out, black, white) - black) / span
    # Guard against rounding pushing values marginally above 1.
    return np.minimum(stretched, 1.0)


def reconstruct(img_name, psf, snr, G,
                vignette_path='experimental/data/vignette_factor.npy',
                gain=1.2, low_percentile=5, high_percentile=95):
    """Deconvolve one experimental capture and plot the post-processed result.

    The capture is passed through ``params['deconv_fn']`` (configured in the
    initialisation cell). The reconstruction is then vignetting-corrected,
    multiplied by ``gain`` (clipped at 1.0) and contrast-stretched between the
    given percentiles before being shown in a new figure.

    Args:
        img_name: Path to a ``.npy`` capture.
        psf: PSF tensor passed through to ``params['deconv_fn']``.
        snr: SNR variable passed through to ``params['deconv_fn']``.
        G: Reconstruction network passed through to ``params['deconv_fn']``.
        vignette_path: Path to the ``.npy`` vignetting-correction factor.
        gain: Brightness gain applied after vignetting correction.
        low_percentile: Percentile used as the black point of the stretch.
        high_percentile: Percentile used as the white point of the stretch.
    """
    img = np.load(img_name)
    # deconv_fn returns a 3-tuple; only the middle element (presumably the
    # network's reconstruction) is displayed.
    _, G_img, _ = params['deconv_fn'](img, psf, snr, G, training=False)
    # Take the first (only) image of the leading batch axis.
    recon = G_img.numpy()[0, :, :, :]
    # Assumes the stored factor also carries a leading batch axis — TODO confirm.
    vig_factor = np.load(vignette_path)[0, :, :, :]

    display_img = _postprocess_for_display(recon, vig_factor, gain,
                                           low_percentile, high_percentile)

    _, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(display_img)
    ax.set_title(img_name)

### Reconstruct Images

In [None]:
# Figure 3: reconstruct each experimental capture shown in the figure, in order.
for capture_path in ('./experimental/data/captures/138301.npy',
                     './experimental/data/captures/102302.npy',
                     './experimental/data/captures/110802.npy'):
    reconstruct(capture_path, psf, snr, G)