In [None]:
# Module metadata: author, credited contributors, contact e-mail and
# development status.
__author__ = "Jose David Marroquin Toledo"
__credits__ = [__author__]
__email__ = "jose@marroquin.cl"
__status__ = "Development"

## 1. Forward Imaging Model
In this process, the sample is illuminated by one LED of the grid at a time, and a camera captures Lo-Res images under different incident angles. To simulate it, we have to:

1. Create a Hi-Res complex image.

2. Generate the incident wave vectors.

3. Produce the output Lo-Res images.

##  [2. The Recovery Process](phaseretrieval.ipynb)

In [None]:
from PIL import Image
import numpy as np
import math
import matplotlib.pyplot as plt
import decimal
import scipy.misc
import os

In [None]:
%matplotlib inline

In [None]:
# 1. Create the Hi-Res complex image.
def generate_input_obj(**kwargs):
    """Build the Hi-Res complex object from an amplitude and a phase image.

    Keyword arguments:
        amplitude (str): path of the image used as the object amplitude.
            Defaults to '../img/cameraman.tif'.
        phase (str): path of the image used as the object phase. It is
            resized (bicubic) to the size of the amplitude image and scaled
            to the range [0, pi]. Defaults to
            '../img/westconcordorthophoto.png'.
        show (bool): if True, display the amplitude of the object in the
            notebook. Defaults to False.

    Returns:
        tuple: (obj, m, n) where obj ('numpy.ndarray', complex) is the
        Hi-Res complex object amplitude * exp(1j * phase), and m and n
        ('int') are its number of rows and columns.
    """
    amplitude = kwargs.pop('amplitude', '../img/cameraman.tif')
    phase = kwargs.pop('phase', '../img/westconcordorthophoto.png')
    show = kwargs.pop('show', False)
    amplitude = Image.open(amplitude)
    phase = Image.open(phase)
    # Match the phase image to the amplitude image ('size' is
    # (width, height) in PIL), so both arrays have the same shape.
    # In MATLAB, imresize uses bicubic interpolation by default.
    phase = phase.resize(amplitude.size, resample=Image.BICUBIC)
    # 'd' ('str') is the character code for a double-precision
    # floating-point number.
    arr_amplitude = np.array(amplitude, dtype='d')
    arr_phase = np.array(phase, dtype='d')
    phase_max = np.amax(arr_phase)
    # Scale the phase to [0, pi]; an all-zero phase image stays zero
    # instead of turning into NaNs.
    if phase_max > 0:
        arr_phase = math.pi * arr_phase / phase_max
    # Keep the object complex: taking its absolute value here would throw
    # away the phase information the simulation is meant to carry.
    obj = arr_amplitude * np.exp(1j * arr_phase)
    if show:
        # matplotlib cannot draw complex values, so show the amplitude.
        # With 'Greys_r' ('str') as value for the 'cmap' key [1], the
        # object is shown like in MATLAB R2009b (PIL shows it almost
        # completely white).
        #
        # [1] unutbu. (2010). Display image as grayscale using matplotlib [Msg 1]. Message posted to
        # http://stackoverflow.com/questions/3823752/display-image-as-grayscale-using-matplotlib?answertab=votes#tab-top
        plt.imshow(np.absolute(obj), cmap='Greys_r')
    m, n = obj.shape
    return obj, m, n

In [None]:
# 2. Generate the incident wave vectors.
def generate_wave_vectors(**kwargs):
    """Compute the relative incident wave vectors of a square LED grid.

    Keyword arguments:
        nleds (int): number of LEDs per side of the grid. Defaults to 15.
        dist (float): distance in mm between adjacent LEDs. Defaults to 4.
        h (float): distance in mm between the LED grid and the sample.
            Defaults to 90.

    Returns:
        tuple: (arr, nleds) where arr ('numpy.ndarray') has shape
        (nleds, nleds, 2); arr[r, c, 0] is the x-component, kx, and
        arr[r, c, 1] the y-component, ky, of the wave vector (relative to
        k_0) of the LED at row r and column c. nleds is the grid length.

    Raises:
        ValueError: if nleds is smaller than 1.
    """
    nleds = kwargs.pop('nleds', 15)
    dist = kwargs.pop('dist', 4)
    h = kwargs.pop('h', 90)
    if nleds < 1:
        raise ValueError('nleds must be at least 1, got {!r}'.format(nleds))
    # LED offsets from the grid centre, e.g. -28, -24, ..., 28 for 15 LEDs
    # 4 mm apart. Centring on (nleds - 1) / 2 keeps exactly nleds LEDs per
    # side (also for even nleds) and allows a non-integer 'dist'.
    offsets = (np.arange(nleds) - (nleds - 1) / 2.0) * dist
    # Columns run left to right (x grows) and rows top to bottom
    # (y decreases), so positions[r, c] is the (x, y) of that LED.
    x, y = np.meshgrid(offsets, -offsets)
    positions = np.stack((x, y), axis=-1)
    # Relative wave vector component: -sin(atan(position / h)).
    arr = -np.sin(np.arctan(positions / h))
    return arr, nleds

In [None]:
# Python 3's round() rounds halves to the nearest even integer: round(2.5)
# returns 2 ('int'), the same as round(1.5). The decimal module can round
# halves up instead, giving 3 ('int') for 2.5 [2].
#
# [2] Barthelemy. (2014). Python 3.x rounding behavior. Message posted to
# http://stackoverflow.com/questions/10825926/python-3-x-rounding-behavior
def round_half_up(num):
    """Round num to the nearest integer, sending ties away from zero.

    For example 2.5 -> 3, 1.5 -> 2 and -2.5 -> -3 (the convention of
    MATLAB's round()). A float argument is converted to Decimal exactly,
    so no extra rounding error is introduced before quantizing.

    Returns:
        int: the rounded value.
    """
    unit = decimal.Decimal(1)
    exact = decimal.Decimal(num)
    nearest = exact.quantize(unit, rounding=decimal.ROUND_HALF_UP)
    return int(nearest)

In [None]:
# 3. Produce the output images.
def _bytescale(img):
    """Linearly map img onto 0..255 and return it as 'uint8'.

    Mirrors the min/max scaling that scipy.misc.imsave (removed in
    SciPy 1.2) applied before saving a float image as 'uint8' [3]: the
    minimum becomes 0 and the maximum 255; a constant image maps to 0.

    [3] Olsson, T. (2015). Saving 16-bit tiff files using Python. Retrieved from
    http://tjelvarolsson.com/blog/saving-16bit-tiff-files-using-python/
    """
    low = img.min()
    span = img.max() - low
    if span == 0:
        span = 1
    scaled = (img - low) * (255.0 / span)
    return (scaled.clip(0, 255) + 0.5).astype(np.uint8)


def simulate_lores_img_set(hires_obj, m, n, wave_vectors, grid_len, **kwargs):
    """Simulate the Lo-Res amplitude images captured under each LED.

    For every LED, a p-by-q window of the centred Hi-Res spectrum, offset
    by the LED's wave vector, is multiplied by the coherent transfer
    function and inverse-transformed; the amplitude of the result is the
    Lo-Res image. Each image is also written to disk as an 8-bit file.

    Args:
        hires_obj ('numpy.ndarray'): m-by-n Hi-Res object (real or complex).
        m, n ('int'): number of rows and columns of hires_obj.
        wave_vectors ('numpy.ndarray'): relative wave vectors indexed as
            [row, col, component] (component 0 = kx, 1 = ky), as returned
            by generate_wave_vectors.
        grid_len ('int'): number of LEDs per side of the grid; the first
            grid_len ** 2 LEDs are simulated.

    Keyword arguments:
        outpath (str): output directory, created if missing.
            Defaults to '../img-lores/'.
        prefix (str): file name prefix. Defaults to 'lores_'.
        extension (str): file extension. Defaults to '.tiff'.
        wavelen (float): illumination wavelength, in the same length unit
            as ccdpx and hirespx. Defaults to 0.63e-6.
        ccdpx (float): sampling pixel size of the CCD. Defaults to 2.75e-6.
        hirespx (float): pixel size of the reconstruction.
            Defaults to ccdpx / 4.
        na (float): numerical aperture of the objective lens.
            Defaults to 0.08.

    Returns:
        tuple: (p, q, kx, ky, dkx, dky, coherent_transfer_funct, images)
        where p and q are the rows and columns of each Lo-Res image, kx and
        ky are 1-D arrays of k_0 * wave_vectors components (one per LED),
        dkx and dky the spectral step of the Hi-Res image along x and y,
        coherent_transfer_funct a p-by-q array of 0.0/1.0 and images an
        array of shape (grid_len ** 2, p, q) with the Lo-Res amplitudes.

    Raises:
        ValueError: if an LED places the p-by-q window outside the
            Hi-Res spectrum.
    """
    outpath = kwargs.pop('outpath', '../img-lores/')
    prefix = kwargs.pop('prefix', 'lores_')
    extension = kwargs.pop('extension', '.tiff')
    wavelen = kwargs.pop('wavelen', 0.63e-6)
    # Sampling pixel size of the CCD.
    ccdpx = kwargs.pop('ccdpx', 2.75e-6)
    # Pixel size of the reconstruction.
    hirespx = kwargs.pop('hirespx', ccdpx / 4)
    # Numerical aperture of the employed objective lens.
    na = kwargs.pop('na', 0.08)
    k_0 = 2 * math.pi / wavelen
    p = int(m / (ccdpx / hirespx))  # Number of rows for the output.
    q = int(n / (ccdpx / hirespx))  # Number of columns for the output.
    k = k_0 * wave_vectors
    # Flatten the x- and y-components into 1-D arrays, one entry per LED.
    kx = k[:, :, 0].reshape(-1)
    ky = k[:, :, 1].reshape(-1)
    dkx = 2 * math.pi / (hirespx * n)
    dky = 2 * math.pi / (hirespx * m)
    cutoff_freq = na * k_0
    k_max = math.pi / ccdpx
    # Exactly q samples along x (columns) and p along y (rows) spanning
    # [-k_max, k_max]; arange with a float step and a '+ 1' stop would
    # yield a sample count that depends on the magnitude of k_max.
    kxm, kym = np.meshgrid(np.linspace(-k_max, k_max, q),
                           np.linspace(-k_max, k_max, p))
    # 1.0 where kxm ** 2 + kym ** 2 < cutoff_freq ** 2, 0.0 elsewhere.
    coherent_transfer_funct = (
        (kxm ** 2 + kym ** 2) < cutoff_freq ** 2).astype(float)
    obj_ft = np.fft.fftshift(np.fft.fft2(hires_obj))
    os.makedirs(outpath, exist_ok=True)
    img_seq_lores = []
    for i in range(grid_len ** 2):
        # Window centre and bounds in the shifted spectrum, as 1-based
        # (MATLAB-style) indices with halves rounded up.
        kxc = round_half_up((n + 1) / 2.0 + kx[i] / dkx)
        kyc = round_half_up((m + 1) / 2.0 + ky[i] / dky)
        kyl = round_half_up(kyc - (p - 1) / 2.0)
        kyh = round_half_up(kyc + (p - 1) / 2.0)
        kxl = round_half_up(kxc - (q - 1) / 2.0)
        kxh = round_half_up(kxc + (q - 1) / 2.0)
        if kyl < 1 or kxl < 1 or kyh > m or kxh > n:
            raise ValueError(
                'LED {} places the window outside the Hi-Res spectrum'
                .format(i))
        img_lores_ft = ((p / m) ** 2 * obj_ft[kyl - 1:kyh, kxl - 1:kxh]
                        * coherent_transfer_funct)
        img_lores = np.absolute(np.fft.ifft2(np.fft.ifftshift(img_lores_ft)))
        img_seq_lores.append(img_lores)
        filename = os.path.join(outpath, prefix + str(i) + extension)
        Image.fromarray(_bytescale(img_lores)).save(filename)
    return (p, q, kx, ky, dkx, dky, coherent_transfer_funct,
            np.array(img_seq_lores))

## References
Zheng, G. (2015). *Fourier Ptychographic Imaging: A MATLAB&reg; tutorial*. San Rafael, CA: Morgan &amp; Claypool Publishers.  