# TensorFlow integration

HCIPy supports not only NumPy as its numerical backend, but also TensorFlow. This allows for GPU acceleration as well as automatic differentiation for fast gradient evaluation. In this tutorial, we will propagate through a few optical systems with TensorFlow and perform a gradient-based optimization.

In [None]:
from hcipy import *
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

import time

# Enable on-demand GPU memory growth on every visible GPU. Looping keeps this
# cell working on CPU-only machines (indexing gpus[0] on an empty list would
# raise an IndexError) and keeps the setting consistent across all GPUs.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

## Basics

In [None]:
# Sampling grids for the two planes, and the propagator that maps wavefronts
# from the pupil grid to the focal grid.
pupil_grid = make_pupil_grid(256)
# NOTE(review): 8 and 16 are presumably the focal-plane sampling (q) and the
# extent (num_airy) -- confirm against the make_focal_grid documentation.
focal_grid = make_focal_grid(8, 16)
prop = FraunhoferPropagator(pupil_grid, focal_grid)

In [None]:
# Numpy
# Circular aperture evaluated on the pupil grid.
# NOTE(review): the trailing 8 is presumably the supersampling factor -- confirm.
aperture_np = evaluate_supersampled(circular_aperture(1), pupil_grid, 8)

# TensorFlow
# The same aperture, converted to the TensorFlow backend.
aperture_tf = aperture_np.as_backend('tensorflow')

In [None]:
# Print the class name of each aperture object, Numpy first, then TensorFlow.
for backend_aperture in (aperture_np, aperture_tf):
    print(type(backend_aperture).__name__)

In [None]:
# Benchmark one propagation of the bare aperture on each backend.
# time.perf_counter() is monotonic and high-resolution, which makes it the
# appropriate clock for measuring short intervals.

# Numpy
# Untimed first call (presumably a warm-up so one-off setup cost is excluded).
img_np = prop(Wavefront(aperture_np))

start = time.perf_counter()
for i in range(100):
    img_np = prop(Wavefront(aperture_np))
end = time.perf_counter()
print((end - start) / 100, 's')  # mean seconds per call

# TensorFlow
# Untimed first call (presumably a warm-up so one-off setup cost is excluded).
img_tf = prop(Wavefront(aperture_tf))

start = time.perf_counter()
for i in range(100):
    img_tf = prop(Wavefront(aperture_tf))
end = time.perf_counter()
print((end - start) / 100, 's')  # mean seconds per call

In [None]:
# Show the Numpy and TensorFlow images side by side in a 1x2 layout, both
# drawn with 'peak' normalization.
panels = ((img_np, 'Numpy'), (img_tf, 'TensorFlow'))
for column, (image, label) in enumerate(panels, start=1):
    plt.subplot(1, 2, column)
    imshow_psf(image, normalization='peak', title=label)
plt.show()

In [None]:
# Build a deformable mirror from Xinetics-style influence functions defined
# on the pupil grid.
# NOTE(review): 32 and 1 / 32 are presumably the number of actuators across
# the pupil and the actuator spacing -- confirm against the HCIPy docs.
influence_functions = make_xinetics_influence_functions(pupil_grid, 32, 1 / 32)
dm = DeformableMirror(influence_functions)

In [None]:
# Put a random shape on the deformable mirror.
# NOTE(review): 0.03 is presumably the amplitude (rms) of the random
# actuator values -- confirm the meaning and units against the HCIPy docs.
dm.random(0.03)

In [None]:
# Benchmark propagation through the deformable mirror on each backend.
# time.perf_counter() is monotonic and high-resolution, which makes it the
# appropriate clock for measuring short intervals.

# Numpy
# Untimed first call (presumably a warm-up so one-off setup cost is excluded).
img_np = prop(dm(Wavefront(aperture_np)))

start = time.perf_counter()
for i in range(100):
    img_np = prop(dm(Wavefront(aperture_np)))
end = time.perf_counter()
print((end - start) / 100, 's')  # mean seconds per call

# TensorFlow
# Untimed first call (presumably a warm-up so one-off setup cost is excluded).
img_tf = prop(dm(Wavefront(aperture_tf)))

start = time.perf_counter()
for i in range(100):
    img_tf = prop(dm(Wavefront(aperture_tf)))
end = time.perf_counter()
print((end - start) / 100, 's')  # mean seconds per call

In [None]:
# Show the images obtained through the deformable mirror for both backends
# side by side in a 1x2 layout, both drawn with 'peak' normalization.
panels = ((img_np, 'Numpy'), (img_tf, 'TensorFlow'))
for column, (image, label) in enumerate(panels, start=1):
    plt.subplot(1, 2, column)
    imshow_psf(image, normalization='peak', title=label)
plt.show()

In [None]:
# Vortex coronagraph on the pupil grid, with a circular Lyot stop.
# NOTE(review): the 2 is presumably the vortex charge and 0.95 the Lyot stop
# diameter -- confirm against the HCIPy docs.
lyot_stop = circular_aperture(0.95)(pupil_grid)
coro = VortexCoronagraph(pupil_grid, 2, lyot_stop=lyot_stop)

# Benchmark the full chain (deformable mirror -> coronagraph -> propagator)
# on each backend. time.perf_counter() is monotonic and high-resolution,
# which makes it the appropriate clock for measuring short intervals.

# Numpy
# Untimed first call (presumably a warm-up so one-off setup cost is excluded).
coro_img_np = prop(coro(dm(Wavefront(aperture_np))))

start = time.perf_counter()
for i in range(100):
    coro_img_np = prop(coro(dm(Wavefront(aperture_np))))
end = time.perf_counter()
print((end - start) / 100, 's')  # mean seconds per call

# TensorFlow
# Untimed first call (presumably a warm-up so one-off setup cost is excluded).
coro_img_tf = prop(coro(dm(Wavefront(aperture_tf))))

start = time.perf_counter()
for i in range(100):
    coro_img_tf = prop(coro(dm(Wavefront(aperture_tf))))
end = time.perf_counter()
print((end - start) / 100, 's')  # mean seconds per call

In [None]:
# Coronagraphic images side by side, each scaled by the maximum power of the
# non-coronagraphic image computed with the same backend.
panels = (
    (coro_img_np, img_np, 'Numpy'),
    (coro_img_tf, img_tf, 'TensorFlow'),
)
for column, (coro_image, reference_image, label) in enumerate(panels, start=1):
    plt.subplot(1, 2, column)
    imshow_psf(coro_image, normalization=reference_image.power.max(), title=label)
plt.show()

# Mean coronagraphic power relative to the non-coronagraphic maximum.
print(coro_img_tf.power.mean() / img_tf.power.max())

In [None]:
@tf.function(autograph=False)
def img():
    """Return the total power of the coronagraphic focal-grid wavefront.

    The TensorFlow aperture is wrapped in a wavefront and passed through the
    deformable mirror, the coronagraph and the propagator, in that order.
    """
    pupil_wavefront = Wavefront(aperture_tf)
    post_coronagraph = coro(dm(pupil_wavefront))
    focal_wavefront = prop(post_coronagraph)
    return focal_wavefront.total_power

In [None]:
# Benchmark the tf.function-wrapped pipeline. time.perf_counter() is
# monotonic and high-resolution, which makes it the appropriate clock for
# measuring short intervals.
start = time.perf_counter()
for i in range(100):
    coro_img_tf = img()
# Fetch the last result as a NumPy value before stopping the clock.
# NOTE(review): presumably this forces any pending device work to finish so
# that it is included in the measurement -- confirm.
coro_img_tf.numpy()
end = time.perf_counter()
print((end - start) / 100, 's')  # mean seconds per call

In [None]:
import tensorflow_probability as tfp

@tf.function(autograph=False)
def f(x):
    """Objective for the optimizer: total power and its actuator gradient.

    Assigns `x` to the deformable mirror actuators, evaluates `img()` under a
    gradient tape, and returns the tuple `(power, grad)`, where `grad` is the
    gradient of `power` with respect to `dm.actuators` as a dense tensor.
    """
    # Load the trial actuator values into the mirror before recording.
    dm.actuators.assign(x)
    with tf.GradientTape() as tape:
        power = img()
    
    grad = tape.gradient(power, dm.actuators)
    # The gradient is accessed through .indices / .values / .dense_shape, i.e.
    # it is treated as a sparse result (presumably tf.IndexedSlices -- confirm).
    # Scatter it into a dense tensor of shape dense_shape; the reshape appends
    # a trailing axis of length 1 to the indices, as tf.scatter_nd expects.
    grad = tf.scatter_nd(tf.reshape(grad.indices, grad.indices.shape + [1]), grad.values, grad.dense_shape)
    
    return power, grad

# Initial guess: one normally distributed float64 value per actuator,
# scaled by 0.03.
start = tf.random.normal([dm.num_actuators], dtype='float64') * 0.03

# Minimize f (total power, with its gradient) using L-BFGS, starting from
# the random guess above.
optim_results = tfp.optimizer.lbfgs_minimize(f, initial_position=start, num_correction_pairs=10, tolerance=1e-3)
# Print the final actuator values and the number of iterations used.
print(optim_results.position.numpy())
print(optim_results.num_iterations.numpy())

In [None]:
# Apply the optimized actuator values and recompute the coronagraphic image.
dm.actuators.assign(optim_results.position)
corrected_wavefront = dm(Wavefront(aperture_tf))
coro_img_tf = prop(coro(corrected_wavefront))

# Display it scaled by the maximum power of the non-coronagraphic image.
imshow_psf(coro_img_tf, normalization=img_tf.power.max(), title='TensorFlow')
plt.show()

# Mean coronagraphic power relative to the non-coronagraphic maximum.
mean_coronagraphic_power = coro_img_tf.power.mean()
print(mean_coronagraphic_power / img_tf.power.max())