In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from DT import DT
from time import time
import scipy.io

# Shared TF1 session used by every cell below: variable initialization, training steps,
# checkpoint save/restore, and stochastic stitching. An InteractiveSession installs itself
# as the default session, which is what lets `a.RI.eval()` run without an explicit session.
sess = tf.InteractiveSession()

### Define settings for optimization
Define experimental hyperparameters (e.g., scattering model, regularization settings) for the optimization run as a Python dictionary. Below are the settings used for the experimental results in the paper (Figs. 7-10) -- choose one of these or create your own. Reduce batch size if your GPU runs out of memory (the batch sizes specified below should work for a 16-GB GPU). 

```python
# 2-um bead sample settings using first Born model (Fig. 7):
experiment = {'sample': '2umbeads', 'model': 'born',
              'reg': 'None', 'num_iter': 500, 'batch_size': 961}
experiment = {'sample': '2umbeads', 'model': 'born',
              'reg': 'DIP', 'activation': 'leaky', 'num_iter': 20000, 'batch_size': 961}
experiment = {'sample': '2umbeads', 'model': 'born',
              'reg': 'positivity', 'coeff': 1e-3, 'num_iter': 500, 'batch_size': 961}
experiment = {'sample': '2umbeads', 'model': 'born',
              'reg': 'TV', 'coeff': 1e-7, 'num_iter': 500, 'batch_size': 961}

# 1-um bead (actually 800-nm) sample settings using first Born model (Fig. 8):
experiment = {'sample': '1umbeads', 'model': 'born',
              'reg': 'None', 'num_iter': 500, 'batch_size': 961}
experiment = {'sample': '1umbeads', 'model': 'born',
              'reg': 'DIP', 'activation': 'linear', 'num_iter': 20000, 'batch_size': 961}
experiment = {'sample': '1umbeads', 'model': 'born',
              'reg': 'negativity', 'coeff': 1e-3, 'num_iter': 500, 'batch_size': 961}
experiment = {'sample': '1umbeads', 'model': 'born',
              'reg': 'TV', 'coeff': 1e-7, 'num_iter': 500, 'batch_size': 961}

# starfish sample settings using first Born model (Fig. 9):
experiment = {'sample': 'starfish', 'model': 'born',
              'reg': 'None', 'num_iter': 500, 'batch_size': 481}
experiment = {'sample': 'starfish', 'model': 'born',
              'reg': 'DIP', 'activation': 'linear', 'num_iter': 10000, 'batch_size': 481}
experiment = {'sample': 'starfish', 'model': 'born',
              'reg': 'positivity', 'coeff': 1e-2, 'num_iter': 500, 'batch_size': 481}
experiment = {'sample': 'starfish', 'model': 'born',
              'reg': 'TV', 'coeff': 1e-7, 'num_iter': 500, 'batch_size': 481}

# 2-um bead sample settings using multislice model (Fig. 10):
experiment = {'sample': '2umbeads', 'model': 'multislice',
              'reg': 'None', 'num_iter': 5000, 'batch_size': 961}
experiment = {'sample': '2umbeads', 'model': 'multislice',
              'reg': 'DIP', 'activation': 'leaky', 'num_iter': 60000, 'batch_size': 481}
experiment = {'sample': '2umbeads', 'model': 'multislice',
              'reg': 'positivity', 'coeff': 1e-3, 'num_iter': 5000, 'batch_size': 961}
experiment = {'sample': '2umbeads', 'model': 'multislice',
              'reg': 'TV', 'coeff': 1e-7, 'num_iter': 30000, 'batch_size': 961}
```

In [None]:
# Settings for this optimization run: copy one of the configurations listed above,
# or build your own with the same keys (sample, model, reg, num_iter, batch_size,
# plus 'coeff' for TV/positivity/negativity and 'activation' for DIP).
experiment = dict(
    sample='starfish',
    model='born',
    reg='None',
    num_iter=500,
    batch_size=481,
)

## Load multi-angle dataset
The raw data includes the full 500x500-pixel field of view (for 961 LEDs), but only a reduced field of view is reconstructed, depending on the experiment.

In [None]:
# Load the multi-angle intensity stack for the chosen sample from the DPDT_raw_data folder.
if '2um' in experiment['sample']:
    path = './DPDT_raw_data/2umbead2layer.mat'
elif '1um' in experiment['sample']:
    path = './DPDT_raw_data/1umbead.mat'
elif 'star' in experiment['sample']:
    path = './DPDT_raw_data/starfish.mat'
else:
    # Fail loudly: otherwise `path` is undefined on a fresh kernel, or silently
    # left over from a previous run of this cell with a different sample.
    raise ValueError('invalid sample: %r' % (experiment['sample'],))

# The raw images are stored under the 'hdrarray_out' key; the last axis is moved to the
# front (it is indexed as the LED axis below), then singleton axes are added so that
# stack has shape (1, n_LEDs, 1, rows, cols).
stack = scipy.io.loadmat(path)['hdrarray_out'].astype(np.float32)
stack = np.transpose(stack, (2, 0, 1))
stack = stack[None, :, None]

# Reduce the field of view: sample-specific crops for the first Born model,
# a fixed 256x256 crop for the multislice model.
if 'born' in experiment['model']:
    if '1um' in experiment['sample']:
        stack = stack[..., 200:363, 0:163]
    elif '2um' in experiment['sample']:
        stack = stack[..., 240:403, 20:183]
    elif 'star' in experiment['sample']:
        stack = stack[..., 0:450, 20:470]
else:
    stack = stack[..., 240:240+256, :256]

# Plot the image from the center LED (middle index of the LED axis; 480 for 961 LEDs).
center_LED = stack.shape[1] // 2
plt.imshow(stack[0, center_LED, 0])
plt.title('center LED image')
plt.show()

### Define tensorflow graph based on specified settings
See the DT.py file for more detailed explanations of the various settings modified below.

In [None]:
# Build the DT model for the chosen scattering model and sample; see DT.py for what each
# attribute set below controls.
if 'born' in experiment['model']:
    if 'starfish' in experiment['sample']:
        a = DT(im_size=stack.shape[-1], xy_upsamp=1, z_upsamp=2, z_fov_upsamp=0.5)
        a.coordinate_offset = np.array([0, 0, 0], dtype=np.float32)
    elif '1um' in experiment['sample']:
        a = DT(im_size=stack.shape[-1], xy_upsamp=2, z_upsamp=4, z_fov_upsamp=1)
        a.coordinate_offset = np.array([1, 1, 0], dtype=np.float32)
    elif '2um' in experiment['sample']:
        a = DT(im_size=stack.shape[-1], xy_upsamp=1, z_upsamp=4, z_fov_upsamp=1)
        a.coordinate_offset = np.array([.5, .5, 0], dtype=np.float32)
    else:
        raise ValueError('invalid sample')

    a.force_pass_thru_DC = True
    a.train_DC = True

    # optimize_k_directly is used only when the deep image prior is not used
    a.optimize_k_directly = 'DIP' not in experiment['reg']

elif 'multislice' in experiment['model']:
    # here, the spatial patch size is 128x128, but the full reconstruction size is 256x256
    a = DT(im_size=128, im_size_full=256, scattering_model='multislice', xy_upsamp=1,
           sample_thickness_MS=10, sample_pix_MS=32, use_spatial_patching=True)

    a.force_pass_thru_DC = False
    a.train_DC = False
    a.focus_init = -3
    a.optimize_k_directly = False

else:
    # Without this, `a` would be undefined on a fresh kernel, or stale from a previous run.
    raise ValueError('invalid model: %r' % (experiment['model'],))

# set regularization hyperparameters:
a.use_deep_image_prior = 'DIP' in experiment['reg']
if a.use_deep_image_prior:
    if experiment['activation'] == 'leaky':
        a.linear_DIP_output = False
    elif experiment['activation'] == 'linear':
        a.linear_DIP_output = True
    else:
        raise ValueError('invalid activation')

# Each penalty receives experiment['coeff'] only if it is named in experiment['reg'];
# all other penalties are disabled (coefficient 0).
a.TV_reg_coeff = experiment['coeff'] if 'TV' in experiment['reg'] else 0
a.positivity_reg_coeff = experiment['coeff'] if 'positivity' in experiment['reg'] else 0
a.negativity_reg_coeff = experiment['coeff'] if 'negativity' in experiment['reg'] else 0

a.batch_size = experiment['batch_size']

# For ignoring (presumably dark-field) LEDs: the LEDs are assumed to lie on a square
# n_side x n_side grid (31 x 31 = 961 here). LEDs whose normalized grid position lies
# outside radius sqrt(0.9) are marked in data_ignore.
n_LEDs = stack.shape[1]
n_side = int(round(np.sqrt(n_LEDs)))
assert n_side ** 2 == n_LEDs, 'expected a square LED grid, got %d LEDs' % n_LEDs
x = np.linspace(-1, 1, n_side)
x, y = np.meshgrid(x, x)
background = x**2 + y**2 < .9
background = background.flatten()
a.data_ignore = ~background

# Normalize each LED's images by that LED's median intensity, then create the graph and
# initialize its variables for optimization.
median = np.median(stack, (0, 2, 3, 4))
a.format_DT_data(stack / median[None, :, None, None, None], DC=np.ones_like(median))
a.reconstruct()
sess.run(tf.global_variables_initializer())

### Optimize
Run the optimization loop, monitoring progress periodically.

In [None]:
losses = list()  # loss history; each entry is the list of loss terms for one step
feed = {}  # feed_dict for the gradient step; may be used to modify learning rate for DIP
num_iter = experiment['num_iter']  # number of gradient steps
# Plot cross-sections of the reconstruction every plot_iter iterations. The value is kept
# at least 1 so that num_iter < 10 does not cause a modulo-by-zero below.
plot_iter = max(1, num_iter // 10)

# Roll back if DIP optimization diverges and anneal the learning rate (first Born model only).
# Two checkpoint files are written alternately every save_iter iterations, so a rollback
# always restores a state at least save_iter iterations old.
save_iter = 15
last_saved = None  # which checkpoint ('model' or 'model_') was written most recently
use_DIP_rollback = a.use_deep_image_prior and not a.scattering_model == 'multislice'

for i in range(num_iter):

    start = time()
    _, loss_i = sess.run([a.train_op, a.loss], feed_dict=feed)
    losses.append(loss_i)

    if i % 10 == 0:
        # loss_i is a list of all the loss terms, with the first being the data-dependent loss
        print(i, loss_i, time()-start)

    # monitor results periodically:
    if i % plot_iter == 0 or i == num_iter-1:
        RI = a.RI.eval()

        # plot an xz and xy cross section:
        plt.figure(figsize=(15,8))
        plt.subplot(121)
        plt.imshow(np.real(RI[19].T))
        plt.title('RI (xz slice)')
        plt.colorbar()
        plt.subplot(122)
        plt.imshow(np.real(RI[:, :, a.side_kz//2 + 2]))
        plt.title('RI (xy slice)')
        plt.colorbar()
        plt.show()

        # plot loss curve on linear and log scale; each loss term is drawn as its own line,
        # and only the first (data-dependent) one is labeled in the legend:
        plt.figure(figsize=(10,5))
        plt.subplot(121)
        plt.plot(losses)
        plt.title('loss')
        plt.subplot(122)
        plt.plot(np.log(losses))
        plt.title('log loss')
        plt.legend(['data-dependent loss'])
        plt.show()

    if use_DIP_rollback:
        # average the last few losses, excluding the current one (only the data loss portion):
        last_few_DIP_losses = [loss[0] for loss in losses[-6:-1]]
        if i > 500 and losses[-1][0] > 4 * np.mean(last_few_DIP_losses):
            print('Optimization diverged; rolling back ...')
            # don't restore the last one, but the last last one:
            if last_saved == 'model':
                a.saver.restore(sess, '/tmp/model_.ckpt')
            elif last_saved == 'model_':
                a.saver.restore(sess, '/tmp/model.ckpt')
            # `feed` starts empty, so seed the learning rate from the graph before annealing.
            # The train step above already runs without feeding DIP_lr, so it is evaluable.
            if a.DIP_lr not in feed:
                feed[a.DIP_lr] = sess.run(a.DIP_lr)
            feed[a.DIP_lr] *= .9  # anneal learning rate
        elif i % (save_iter * 2) == 0:
            a.saver.save(sess, '/tmp/model.ckpt')
            last_saved = 'model'
        elif i % save_iter == 0:
            # save 2 checkpoints to be sure to roll back at least save_iter iterations
            a.saver.save(sess, '/tmp/model_.ckpt')
            last_saved = 'model_'

### Extra step for DIP reconstructions using the multislice model and spatial patching
It is too computationally intensive for the DIP to generate the whole reconstruction field of view. Instead, we use "spatial patching" where at every iteration only a small, randomly chosen patch within the full field of view contributes to the loss. The DIP output matches the patch size and uses a common network for all spatial patches. To reconstruct the whole field of view, we use our stochastic stitching algorithm, which involves selecting many patches randomly and stitching them together.

With our spatial patching scheme, we can in principle seamlessly reconstruct an arbitrarily large field of view with a fixed memory budget. For more details, see the appendix of our paper, which is listed in the README file.

In [None]:
if 'DIP' in experiment['reg'] and 'multislice' in experiment['model']:
    # Build the full-size reconstruction by stitching many randomly placed DIP patches
    # (the number of patches is decided inside DT.stochastic_stitch; depad=35 is passed on).
    RI_stochastic_stitch = a.stochastic_stitch(sess, depad=35)

    # Show two cross sections of the stitched 3D volume side by side.
    stitch_panels = [
        (RI_stochastic_stitch[:, 19].T, 'RI, full FOV (xz slice)'),
        (RI_stochastic_stitch[:, :, a.side_kz//2 + 2], 'RI, full FOV (xy slice)'),
    ]
    plt.figure(figsize=(15,8))
    for panel_idx, (panel_img, panel_title) in enumerate(stitch_panels, start=1):
        plt.subplot(1, 2, panel_idx)
        plt.imshow(np.real(panel_img))
        plt.title(panel_title)
        plt.colorbar()
    plt.show()