In [None]:
# Notebook metadata: authorship, contact address and maturity status.
__author__ = "Jose David Marroquin Toledo"
__credits__ = [__author__]
__email__ = "jose@marroquin.cl"
__status__ = "Development"

# Fourier Ptychographic Imaging

## 1 Forward Imaging Model

In this process, the LED grid illuminates the sample with one LED at a time and the camera takes a photo. Each illumination has a different incident angle. This process consists of three stages:

1. Create a Hi-Res complex image.

2. Generate the incident wave vectors.

3. Produce the output Lo-Res images.

#####  [2 The Recovery Process](phaseretrieval.ipynb)

In [None]:
from PIL import Image
import numpy as np
import math
import decimal
import scipy.misc

#### 1.1 Create the Hi-Res complex image.

In [None]:
def sim_sample(amplitude, phase):
    """Simulates a sample, a Hi-Res complex object, to be used as the
    input of the forward imaging process of Fourier Ptychography and
    returns it as a numpy.ndarray.

    Args:
        amplitude: The amplitude image, given as anything accepted by
            PIL.Image.open() (e.g. a file path).
        phase: The phase image, given as anything accepted by
            PIL.Image.open(). It is resized to the amplitude image's
            size and scaled so that its maximum value maps to pi.

    Returns:
        obj: The complex object arr_ampl * exp(1j * arr_phase)
            (numpy.ndarray).
        w: The width in pixels (int) of the object.
        h: The height in pixels (int) of the object.
    """
    print('<Amplitude-FileOpen>')
    # The context managers close both files even if reading fails. The
    # file opened for the phase is closed too, not only its resized
    # in-memory copy.
    with Image.open(amplitude) as ampl_img:
        print('<Phase-FileOpen>')
        with Image.open(phase) as phase_img:
            w, h = ampl_img.size  # Size of the complex input object.
            # Resize the phase image to the amplitude image's size. In
            # some programming languages, such as MATLAB, the resize
            # methods use bicubic interpolation by default.
            resized_phase = phase_img.resize((w, h), resample=Image.BICUBIC)
            # 'd' (str) for a double-precision floating-point number.
            arr_ampl = np.array(ampl_img, dtype='d')
            arr_phase = np.array(resized_phase, dtype='d')
            print('<Amplitude-FileClose>')
            print('<Phasee-FileClose>')
    max_phase = np.amax(arr_phase)
    if max_phase != 0:
        arr_phase = math.pi * arr_phase / max_phase
    # Otherwise the phase cannot be normalized (e.g. an all-black phase
    # image): it is used as it is instead of dividing by zero, which
    # would fill the object with NaN.

    # Keep the object complex. Taking its absolute value here would
    # throw the phase away and return the amplitude only.
    obj = arr_ampl * np.exp(1j * arr_phase)
    return obj, w, h

In [None]:
def sim_sqr_grid(leds_per_row, distance):
    """Simulates a square LED matrix of leds_per_row-by-leds_per_row
    LEDs centered on (0, 0).

    Args:
        leds_per_row: The number of LEDs in each row and each column.
        distance: The spacing between two adjacent LEDs.

    Returns:
        A list of rows, each row being a list of (x, y) coordinates
        (tuple). Within a row, x increases; the first row holds the
        largest y. The central LEDs are the closest to (0, 0).
    """
    if leds_per_row % 2 != 0:
        # In Python 3, "/" keeps the fraction, so it is dropped with
        # math.floor() (which returns an int).
        half_span = math.floor(leds_per_row / 2) * distance
    else:
        half_span = (leds_per_row - 1) / 2 * distance
    # Positions along one axis, built by repeatedly adding the spacing.
    positions = []
    current = -half_span
    for _ in range(leds_per_row):
        positions.append(current)
        current += distance
    return [[(x, -y) for x in positions] for y in positions]

#### 1.2 Generate the incident wave vectors.

In [None]:
def gen_wave_vectors(xy_list, height):
    """Generates the incident wave vectors from the LED positions and
    the distance between the sample and the LED lamp. Every coordinate
    c is mapped to -sin(arctan(c / height)), so the coordinates and
    height must share the same unit of length.

    Args:
        xy_list: A list of tuples that represents the (x, y) coordinate
            of each LED in the lamp.
        height: The distance in mm between the sample and the LED lamp.

    Returns:
        k_xy_tuple_list: A list of tuples that represents the components
            of the incident wave vectors. There is one tuple per item
            found along the first axis of xy_list.
    """
    # Work on the coordinates as a numpy.ndarray.
    ratios = np.array(xy_list) / float(height)
    components = np.sin(np.arctan(ratios))
    components *= -1
    # One tuple per entry along the first axis of the array.
    k_xy_tuple_list = [tuple(entry) for entry in components]
    return k_xy_tuple_list

In [None]:
def round_half_up(num):
    """Rounds num to the nearest integer, with ties going away from
    zero, and returns it as an int.

    The round() function changed from Python 2 to Python 3: in
    Python 3, round(2.5) returns 2 (int), the same as round(1.5). With
    the code below it is possible to get 3 (int) instead of 2 [1].

    [1] Barthelemy. (2014). Python 3.x rounding behavior. Message
    posted to
    http://stackoverflow.com/questions/10825926/python-3-x-rounding-behavior
    """
    exact_value = decimal.Decimal(num)
    whole = exact_value.quantize(decimal.Decimal(1),
                                 rounding=decimal.ROUND_HALF_UP)
    return int(whole)

In [None]:
def gen_cft(wave_vectors, wavelength, ccd_px, na, hires_w, hires_h):
    """Generates the coherent transfer function of the coherent imaging
    system.

    Args:
        wave_vectors: The incident wave vectors: an array-like (e.g. a
            numpy.ndarray or nested lists/tuples) whose last axis holds
            the (x, y) components. Both a grid layout (rows, columns, 2)
            and a flat layout (n_leds, 2) are accepted.
        wavelength: The wavelength in mm.
        ccd_px: The sampling pixel size in mm of the CCD.
        na: The numerical aperture of the objective lens.
        hires_w: The width in pixels (int) of the Hi-Res output image.
        hires_h: The height (int) of the Hi-Res output image.

    Returns:
        cft: The coherent transfer function (numpy.ndarray of floats
            with shape (lores_h, lores_w)): 1.0 inside the circle of
            radius na * k_0 and 0.0 elsewhere.
        lores_w: The width of the Lo-Res output images in pixels (int).
        lores_h: The height of the Lo-Res output images in pixels (int).
        dkx: The step 2 * pi / (hires_px * hires_w) (float).
        dky: The step 2 * pi / (hires_px * hires_h) (float).
        kx: A 1-D numpy.ndarray with the x-components of
            k_0 * wave_vectors.
        ky: A 1-D numpy.ndarray with the y-components of
            k_0 * wave_vectors.
    """
    k_max = math.pi / ccd_px
    k_0 = 2 * math.pi / wavelength
    cutoff_freq = na * k_0
    hires_px = ccd_px / 4  # Pixel size of the reconstruction.
    dkx = 2 * math.pi / (hires_px * hires_w)
    dky = 2 * math.pi / (hires_px * hires_h)
    lores_w = int(hires_w / (ccd_px / hires_px))
    lores_h = int(hires_h / (ccd_px / hires_px))
    # np.asarray() lets plain lists of tuples be scaled as well; a float
    # cannot be multiplied by a list.
    k = k_0 * np.asarray(wave_vectors)
    # The last axis holds the components; flatten the rest into 1-D.
    kx = k[..., 0].ravel()
    ky = k[..., 1].ravel()
    # np.linspace() always yields exactly lores_w (lores_h) samples in
    # [-k_max, k_max]. np.arange() with a fractional step and a
    # "k_max + 1" stop returns a different count when the step is not
    # greater than 1 or when rounding errors accumulate, which breaks
    # the (lores_h, lores_w) shape of the transfer function.
    kxm, kym = np.meshgrid(np.linspace(-k_max, k_max, lores_w),
                           np.linspace(-k_max, k_max, lores_h))
    cft = ((kxm ** 2 + kym ** 2) < cutoff_freq ** 2)
    cft = cft.astype(float)
    return cft, lores_w, lores_h, dkx, dky, kx, ky

#### 1.3 Produce the output Lo-Res images.

In [None]:
def sim_lores_set(hires_obj, hires_w, hires_h, n_leds, cft, lores_w,
                  lores_h, dkx, dky, kx, ky, output_path, **kwargs):
    """Simulates the multiple intensity Lo-Res images capturing them
    under different incident angles and writes them to output_path.

    Args:
        hires_obj: A Hi-Res complex input image.
        hires_w: The width of the Hi-Res output image in pixels (int).
        hires_h: The height of the Hi-Res output image in pixels (int).
        n_leds: The total number of LEDs of the lamp.
        cft: The coherent transfer function (numpy.ndarray).
        lores_w: The width of the Lo-Res output images in pixels (int).
        lores_h: The height of the Lo-Res output images in pixels (int).
        dkx: (float).
        dky: (float).
        kx: A numpy.ndarray with the x-components of k_0 * wave vectors.
        ky: A numpy.ndarray with the y-components of k_0 * wave vectors.
        output_path: The path of the directory for the Lo-Res image set.
        **kwargs: Keyword arguments:
            prefix: The prefix (str) of the file names. 'IMG_' by
                default.
            format: The image format (str). Its lowercase form is used
                as the file extension. 'TIF' by default.

    Raises:
        TypeError: If an unexpected keyword argument is given.
        ValueError: If the region of the spectrum selected by an LED
            falls outside the spectrum of the Hi-Res image.
    """
    import os  # Local import: only needed to build the file paths.

    img_prefix = kwargs.pop('prefix', 'IMG_')
    img_format = kwargs.pop('format', 'TIF')
    if kwargs:
        raise TypeError(
            '{!s}() got an unexpected keyword argument {!r}'.format(
                sim_lores_set.__name__, list(kwargs.keys())[-1]))
    obj_ft = np.fft.fftshift(np.fft.fft2(hires_obj))
    # Values that do not depend on the LED are computed only once.
    scale = (lores_w / hires_w) ** 2
    idx_width = len(str(n_leds))
    extension = '.' + img_format.lower()
    for i in range(n_leds):
        # 1-based, inclusive bounds of the window of the spectrum that
        # is centered on (kxc, kyc).
        kxc = round_half_up((hires_w + 1) / 2 + kx[i] / dkx)
        kyc = round_half_up((hires_h + 1) / 2 + ky[i] / dky)
        kxl = round_half_up(kxc - (lores_w - 1) / 2)
        kyl = round_half_up(kyc - (lores_h - 1) / 2)
        kxh = round_half_up(kxc + (lores_w - 1) / 2)
        kyh = round_half_up(kyc + (lores_h - 1) / 2)
        # A negative start index would silently wrap around to the
        # other end of the spectrum and a too large stop index would
        # truncate the window, so reject both explicitly.
        if (kxl < 1 or kyl < 1 or kxh > obj_ft.shape[1]
                or kyh > obj_ft.shape[0]):
            raise ValueError(
                'the spectrum region of LED {} (columns {}..{}, rows '
                '{}..{}) is outside the Hi-Res spectrum'.format(
                    i + 1, kxl, kxh, kyl, kyh))
        lores_img_ft = scale * obj_ft[kyl - 1:kyh, kxl - 1:kxh]
        lores_img_ft *= cft
        lores_img = np.fft.ifft2(np.fft.ifftshift(lores_img_ft))
        lores_img = np.absolute(lores_img)
        # Zero-pad the index so that all file names have equal length.
        img_idx = str(i + 1).zfill(idx_width)
        filename = img_prefix + img_idx + extension
        file_path = os.path.join(output_path, filename)
        print('<LoResImgWrite-' + img_idx + '>')
        # The pixel values are saved without rescaling their dynamic
        # range: they are clipped to [0, 255] and rounded to 8-bit
        # integers. This is what
        # scipy.misc.toimage(lores_img, cmin=0, cmax=255) did [2];
        # that function was removed in SciPy 1.2.
        #
        # [2] Lippens, S. (2016). Saving of images in scipy and
        # preventing dynamic range rescaling.
        pixels = (np.clip(lores_img, 0, 255) + 0.5).astype(np.uint8)
        Image.fromarray(pixels).save(file_path)