In [None]:
import matplotlib.pyplot as plt
import numpy as np
import skimage.data as skdata
from project.algorithms.simulation import dummy_object, ptychogram, mesh, illumination_beam
from project.algorithms.utils import circ_aperture, normalize, nrmse, ft, ift, corr
from project.algorithms.reconstruction import update_obj, update_probe, TransRefinement
import random

In [None]:
# Dark matplotlib theme, suited to dark-themed editors.
# Set to False to keep matplotlib's default white background (e.g. classic Jupyter).
background_dark = True

if background_dark:
    plt.style.use('dark_background')

In [None]:
"ground truth"
# Ground-truth complex object: the image file supplies the intensity and
# scikit-image's "camera" sample supplies the phase; dummy_object combines
# them at a 256 x 256 output shape.
intensity = np.array(plt.imread('lena.tif'))
phase = skdata.camera()
obj = dummy_object(phase=phase, intensity=intensity, output_shape=(256, 256))

# Ground-truth illumination (probe), defined on the reconstruction box.
box_shape = (161, 161)  # size of the reconstruction box, in pixels
r = 0.5                 # beam_radius handed to illumination_beam
illumination = normalize(illumination_beam(box_shape, beam_radius=r))

In [None]:
# Side-by-side view of the ground truth: magnitude and phase of the object,
# then of the illumination.
fig, axes = plt.subplots(1, 4, figsize=(12, 3))
panel_specs = (
    (np.abs(obj), 'gray', 'intensity of object'),
    (np.angle(obj), 'rainbow', 'phase of object'),
    (np.abs(illumination), 'gray', 'intensity of illumination'),
    (np.angle(illumination), 'rainbow', 'phase of illumination'),
)
for panel_ax, (panel_img, panel_cmap, panel_title) in zip(axes, panel_specs):
    panel_ax.imshow(panel_img, cmap=panel_cmap)
    panel_ax.set_title(panel_title)
plt.tight_layout()

In [None]:
"positions and ptychogram"
# Scan positions and the simulated diffraction pattern recorded at each one.
# NOTE(review): mesh() takes positional arguments that are not documented here
# (presumably object shape, step/spacing, overlap, grid size, and a random
# position error `error`) — confirm against its definition.
positions = mesh((256, 256), 40, 0.9, 7, error=3)
patterns = [ptychogram(obj, illumination, position) for position in positions]

In [None]:
"initial guess positions"
# Initial guess: every scan position starts at the object centre, as (row, col).
# Each axis uses its own dimension, so non-square objects are centred correctly
# (previously obj.shape[0] was used for both rows and columns).
guess_positions = np.tile(np.array(obj.shape[:2]) // 2, (len(positions), 1))
# Snapshot of the initial guess; guess_positions itself is overwritten in place
# by the position-correction step of the reconstruction.
ini_guess = guess_positions.copy()

In [None]:
"initial estimation and parameter"
# Starting estimates for the reconstruction: the probe comes from circ_aperture
# (presumably a circular pupil) on the reconstruction box; the object is flat.
guess_probe = circ_aperture(box_shape, radius=0.4).astype(complex)
guess_obj = np.ones(obj.shape, dtype=complex)

# Per-axis [y, x] position-correction gain: multiplies the sub-pixel shift
# estimates into pixel steps and is adapted during the reconstruction.
beta = [2000, 2000]

# Learning rates for the object update (a) and the probe update (b).
a = 1.0
b = 1.0

In [None]:
"outputs"
# Per-iteration histories filled by the reconstruction loop:
#   loss   - mean NRMSE over all scan positions
#   mpe    - mean distance between estimated and true scan positions
#   bx, by - position-correction gain beta for the x and y axes
loss = []
mpe = []
bx = []
by = []

In [None]:
"reconstruction"
# Reconstruction loop: alternating object/probe updates, per-position
# translation refinement, and an adaptive position-correction gain (beta).
# NOTE: re-running this cell resumes from the current guess_obj, guess_probe,
# guess_positions and beta, and keeps appending to loss/mpe/bx/by; re-run the
# earlier initialisation cells first for a fresh reconstruction.
N_ITERATIONS = 300
POS_CORRECTION_START = 3   # iteration from which scan positions are refined
BETA_UPDATE_START = 5      # iteration from which beta is adapted
IPS_INTERVAL = 50          # iterations between incorrect-position (IP) checks
IPS_LOSS_MARGIN = 0.1      # loss margin above the best position that flags an IP

(K, L) = guess_probe.shape
n_rows, n_cols = guess_obj.shape
n_positions = len(positions)
true_positions = np.asarray(positions)
patterns_arr = np.asarray(patterns)

# Zero-pad so that a K x L box with top-left corner (y, x) stays inside obj_pad
# for 0 <= y < n_rows and 0 <= x < n_cols.
obj_pad = np.pad(guess_obj, ((K // 2, K // 2), (L // 2, L // 2)))

# position_errors[p, axis, 0] holds the previous iteration's shift estimate for
# position p and [p, axis, 1] the current one (axis 0 = y, axis 1 = x).
position_errors = np.zeros((n_positions, 2, 2))

# Iteration of the last incorrect-position check. It must be initialised here
# because it is read before its first assignment below.
# NOTE(review): starting value 0 assumed.
n_IPs = 0

for n in range(N_ITERATIONS):
    # Per-position NRMSE, indexed by position (not by visiting order) so that it
    # lines up with guess_positions / patterns in the IP check below.
    loss_vals = np.zeros(n_positions)
    index = random.sample(range(n_positions), n_positions)  # random visiting order
    position_errors[:, :, 0] = position_errors[:, :, 1]
    for i in index:
        y, x = guess_positions[i]
        pattern = patterns[i]
        obj_scanned = obj_pad[y:y + K, x:x + L]

        # Modulus constraint: keep the phase of the exit wave's transform and
        # replace its modulus with the measured amplitude sqrt(pattern).
        psi = obj_scanned * guess_probe
        Psi = ft(psi)
        phase_Psi = np.exp(1j * np.angle(Psi))
        Psi_corrected = np.sqrt(pattern) * phase_Psi
        psi_corrected = ift(Psi_corrected)

        # Object and probe updates. The probe update uses the object patch as it
        # was before this step's object update.
        diff_psi = psi_corrected - psi
        temp_obj = obj_scanned.copy()
        obj_scanned = update_obj(obj_scanned, guess_probe, diff_psi, learning_rate=a)
        obj_pad[y:y + K, x:x + L] = obj_scanned
        guess_probe = update_probe(guess_probe, temp_obj, diff_psi, learning_rate=b)

        # NRMSE between predicted and measured Fourier amplitudes. `pattern` is an
        # intensity (its square root is used as the amplitude above), so |Psi| is
        # compared with sqrt(pattern), not with pattern itself.
        loss_vals[i] = nrmse(np.abs(Psi), np.sqrt(pattern))

        # Position correction: the shift between the object patch after and
        # before this update, scaled per axis by beta into a pixel step.
        if n >= POS_CORRECTION_START:
            syj, sxj = TransRefinement(obj_scanned, temp_obj, integer_skip=False)
            dy, dx = round(syj * beta[0]), round(sxj * beta[1])
            # Clamp to the valid range [0, size - 1] of top-left corners
            # (a negative coordinate was previously clamped to 1, not 0).
            y = min(max(y + dy, 0), n_rows - 1)
            x = min(max(x + dx, 0), n_cols - 1)
            guess_positions[i] = np.array([y, x])
            position_errors[i, :, 1] = np.array([syj, sxj])

    # Adapt beta per axis from the correlation between consecutive shift
    # estimates: grow it on a strong positive correlation, shrink it on a
    # strong negative one.
    if n >= BETA_UPDATE_START:
        ky = corr(position_errors[:, 0, 0], position_errors[:, 0, 1])
        kx = corr(position_errors[:, 1, 0], position_errors[:, 1, 1])
        for axis, k in ((0, ky), (1, kx)):
            if k >= 0.5:
                beta[axis] = int(beta[axis] * 1.1)
            elif k <= -0.5:
                beta[axis] = int(beta[axis] * 0.9)
    bx.append(beta[1])
    by.append(beta[0])

    # Every IPS_INTERVAL iterations, split positions into likely-incorrect ones
    # (IPs: loss more than IPS_LOSS_MARGIN above the best) and the rest (CPs).
    if (n + 1) > (n_IPs + IPS_INTERVAL):
        n_IPs = n
        threshold = loss_vals.min() + IPS_LOSS_MARGIN
        IPs_index = np.where(loss_vals > threshold)
        IPs = guess_positions[IPs_index]
        CPs_index = np.where(loss_vals <= threshold)
        CPs = guess_positions[CPs_index]

        intensity_IPs = patterns_arr[IPs_index]
        intensity_CPs = patterns_arr[CPs_index]
        # TODO: the per-IP refinement step (an unfinished
        # `for i, iip in zip(IPs_index)` loop) is not implemented yet; IPs, CPs
        # and their intensities are computed but not used.

    # Mean Euclidean distance between estimated and true scan positions.
    pe = guess_positions - true_positions
    pe = np.mean(np.hypot(pe[:, 0], pe[:, 1]))
    mpe.append(pe)
    loss.append(np.mean(loss_vals))

    # Crop the padding back off so guess_obj keeps the object's shape; explicit
    # stop indices stay correct for even probe sizes too.
    guess_obj = obj_pad[K // 2:K // 2 + n_rows, L // 2:L // 2 + n_cols]
