# Initialize

## Basic libraries

In [None]:
import os
from os.path import join
from tqdm import tqdm

import numpy as np
import scipy as sp
#import pyqtgraph_extended as pg
import matplotlib.pyplot as plt
from matplotlib import pyplot
from matplotlib.colors import LogNorm
#import holography
import array as array
from scipy import signal

import sys
#import fthtools.masks as masks
#import fthtools.fth as fth
#import fthtools.PhaseRetrieval as PhR
import matplotlib.image as mpimg

# Self-written libraries
sys.path.append(join(os.getcwd(), "library"))
import mask_lib
import helper_functions as helper
import interactive
from interactive import cimshow

# Correct phase retrieval library
from importlib.util import spec_from_loader, module_from_spec
from importlib.machinery import SourceFileLoader 
spec = spec_from_loader("module.name", SourceFileLoader("module.name", join(os.getcwd(), "paper_analysis_code","analysis_code","fthtools","PhaseRetrieval.py")))
PhR = module_from_spec(spec)
spec.loader.exec_module(PhR)

In [None]:
# interactive plotting
# The "widget" backend (ipympl) makes the figures below zoomable/pannable in the notebook.
import ipywidgets

%matplotlib widget

# Apply constrained layout to every new figure so titles and labels do not overlap
plt.rcParams["figure.constrained_layout.use"] = True  # replaces plt.tight_layout

# Auto formatting of cells (disabled; uncomment the next line to enable jupyter_black)
#%load_ext jupyter_black

## Custom functions

In [None]:
## Function to define stochastic noise (a single Gaussian blob)
def gauss_2D(xx, yy, amp, sigma, x0, y0):
    """Isotropic 2D Gaussian blob centred on (x0, y0).

    Evaluates ``amp * exp(-r**2 / sigma**2)`` where ``r`` is the distance of
    (xx, yy) from (x0, y0). There is no factor 2 in the denominator, so
    ``sigma`` is NOT the standard deviation (that would be sigma / sqrt(2)).

    Works element-wise on scalars or broadcastable arrays such as meshgrids.
    """
    r_squared = (xx - x0)**2 + (yy - y0)**2
    return amp * np.exp(-r_squared / sigma**2)

##Function to convert rgb image to gray scale
def rgb2gray(rgb):
    """Convert an RGB(A) image to grayscale with weights 0.2989 R + 0.5870 G + 0.1140 B.

    Parameters
    ----------
    rgb : array_like
        Image whose last axis holds the colour channels. Only the first three
        channels (R, G, B) are used, so an alpha channel is ignored.
        Inputs that are already single-channel are passed through:
        a 2-D array whose last axis is not 3 or 4 long is treated as a gray
        image, and a 3-D array with 1 or 2 channels (L / LA) returns its
        first channel.

    Returns
    -------
    numpy.ndarray
        Float grayscale image with the colour axis removed.
    """
    rgb = np.asarray(rgb)
    # imread returns a 2-D array for grayscale PNGs; without this check the
    # [..., :3] slice below would pick the first three image *columns*.
    # (N, 3)/(N, 4) arrays keep the original meaning: a list of colour pixels.
    if rgb.ndim == 2 and rgb.shape[-1] not in (3, 4):
        return rgb.astype(float)
    # Luminance (+ alpha) images: the first channel already is the gray level.
    if rgb.ndim == 3 and rgb.shape[-1] < 3:
        return rgb[..., 0].astype(float)
    return np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140])

# Loading of files

In [None]:
## Import the sample image and reduce it to one grayscale channel
img = mpimg.imread(
    os.path.join(os.getcwd(), 'paper_analysis_code', 'sample_image.png')
)
image1 = rgb2gray(img)  # grayscale object amplitude used by the simulation below

cimshow(image1)

# Create holography aperture mask

In [None]:
## Define circular aperture (object support) on a grid centred on the image
rows, cols = image1.shape
# Downstream cells (phase-ramp normalisation via pixeln, window extraction)
# assume a square grid; fail early with a clear message instead of a
# broadcasting error further down.
if rows != cols:
    raise ValueError(f"Expected a square grayscale image, got shape {image1.shape}")
pixeln = rows
clip_radius = 40
row_vec = np.arange(rows, dtype=np.double)
col_vec = np.arange(cols, dtype=np.double)
# xx varies along axis 0 (rows), yy along axis 1 (columns); both are 0 at
# pixel (rows//2, cols//2).
yy, xx = np.meshgrid(cols // 2 - col_vec, rows // 2 - row_vec)
# 1 inside the disc of radius clip_radius, 0 outside (integer array)
mask = 1 - ((xx)**2 + (yy)**2 > clip_radius**2)
mask1 = mask

cimshow(mask)

In [None]:
## Define the two holography (reference) holes as small discs
r_hole_radius1 = 2
r_hole_radius2 = 2
location_r1x = 180
location_r1y = 150
location_r2x = 100
location_r2y = -150
# Each hole is 1 where (xx, yy) lies within its radius of (-location_x, -location_y), 0 elsewhere
holography_hole1 = np.where((xx + location_r1x)**2 + (yy + location_r1y)**2 > r_hole_radius1**2, 0, 1)
holography_hole2 = np.where((xx + location_r2x)**2 + (yy + location_r2y)**2 > r_hole_radius2**2, 0, 1)
holography_hole = holography_hole1 + holography_hole2

cimshow(holography_hole)

In [None]:
## Add holography holes to the aperture-limited object.
# The holes must not overlap the aperture. Then the sample equals
# image1*mask + holography_hole. Building it with np.where, instead of adding
# to image1, means re-running this cell does not stack the holes a second time.
if np.any((mask != 0) & (holography_hole != 0)):
    raise ValueError("Holography holes overlap the object aperture")
image1 = np.where(holography_hole != 0, holography_hole, image1 * mask)

## Define phase: linear ramp in (xx + yy), zeroed outside the aperture
phase1 = (100 * (xx + yy) / (pixeln * np.pi)) * mask1

cimshow(image1)
cimshow(phase1)

# Calculate holograms

In [None]:
##number of diffraction patterns in the ensemble
n=100

##maximum number of vortices present in the system (currently unused)
m=1

## Random generator with a fixed seed so that Restart & Run All reproduces the ensemble
rng = np.random.default_rng(0)

## Initialization of matrices (ensemble accumulators and containers used by later cells)
incoherent_CDI=np.zeros_like(mask)
av_pattern=np.zeros_like(mask)
ER2s=np.zeros_like(mask, dtype=complex)
EsR1=np.zeros_like(mask, dtype=complex)
ER1s=np.zeros_like(mask, dtype=complex)
EsR2=np.zeros_like(mask, dtype=complex)
R1R2s=np.zeros_like(mask, dtype=complex)
R1sR2=np.zeros_like(mask, dtype=complex)
Sktav=np.zeros_like(mask)
FER2s=np.zeros_like(mask)
FEsR1=np.zeros_like(mask)
FER1s=np.zeros_like(mask)
FEsR2=np.zeros_like(mask)
FR1R2s=np.zeros_like(mask)
FR1sR2=np.zeros_like(mask)
F_Sktav=np.zeros_like(mask)
F_fluctuation=np.zeros_like(mask)
F_fluctuation_mean=np.zeros_like(mask)
F_fluctuation_sq_mean=np.zeros_like(mask)
FAutoCorl1=np.zeros_like(mask)
FAutoCorl=np.zeros_like(mask)
FAutoCorlN=np.zeros_like(mask)
FAutoCorlN2=np.zeros_like(mask)
Av_fluctuation=np.zeros_like(mask)
AvSq_fluctuation=np.zeros_like(mask)

mask_scatter = np.ones_like(mask) #for CDI part

##Averaging dataset
for ii in tqdm(range(n)):
    # Gaussian blob at a random position within +-40 px of the centre,
    # restricted to the object aperture
    x_f = rng.random() * 80 - 40
    y_f = rng.random() * 80 - 40
    fluctuation = gauss_2D(xx, yy, 1, 3, x_f, y_f) * mask

    # Ensemble sum of fluctuations (real space)
    Av_fluctuation = Av_fluctuation + fluctuation

    # Single pattern: fluctuation on top of the object.
    # The phase factor is currently disabled (0.0j * phase1 == 0).
    pattern = fluctuation + image1 * np.exp(0.0j * phase1)
    av_pattern = av_pattern + pattern  # ensemble sum

    # Far-field amplitude of the single pattern. ifftshift before / fftshift
    # after the FFT keeps the centre pixel centred for odd sizes as well.
    diffraction = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(pattern)))
    incoherent_CDI = incoherent_CDI + np.abs(diffraction)**2  # ensemble sum of intensities

    # Far-field amplitude of the Gaussian alone
    F_fluctuation = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(fluctuation)))

    # Ensemble sums (divided by n below)
    F_fluctuation_mean = F_fluctuation_mean + F_fluctuation  # amplitudes
    F_fluctuation_sq_mean = F_fluctuation_sq_mean + np.abs(F_fluctuation)**2  # intensities

##Reconstruction of averaged hologram (object plane; correlation terms are extracted from it below)
hologram2 = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(incoherent_CDI/n)))
## Numerical stochastic term of the fluctuations alone: <|F|^2> - |<F>|^2
F_stochastic_numerical = np.abs(F_fluctuation_sq_mean/n) - np.abs(F_fluctuation_mean/n)**2

In [None]:
# Overview of the simulated data: real-space images (top-left 2x2) and
# Fourier-space intensities on a log scale (right column and bottom row)
def _show_panel(ax_, img, title, log=False, fontsize=8):
    """Draw `img` on `ax_` with percentile-based colour limits.

    Linear panels clip to the 1st-99th percentile of the non-zero pixels.
    Log panels use the 0.1-99.9th percentile of the strictly positive pixels,
    because LogNorm cannot represent values <= 0. The stochastic estimate
    <|F|^2> - |<F>|^2 may contain tiny negative round-off values.
    """
    ax_.set_title(title, fontsize=fontsize)
    if log:
        vmin, vmax = np.percentile(img[img > 0], [.1, 99.9])
        return ax_.imshow(img, norm=LogNorm(vmin=vmin, vmax=vmax))
    vmin, vmax = np.percentile(img[img != 0], [1, 99])
    return ax_.imshow(img, vmin=vmin, vmax=vmax)


fig, ax = plt.subplots(3, 3, figsize=(9, 9), sharex=True, sharey=True)

_show_panel(ax[0, 0], fluctuation, "Single fluctuation in real space")
_show_panel(ax[0, 1], pattern.real, "Single pattern in real space")
_show_panel(ax[1, 0], Av_fluctuation, "Ensemble fluctuations in real space")
_show_panel(ax[1, 1], av_pattern.real, "Ensemble pattern in real space")

_show_panel(ax[0, 2], np.abs(diffraction)**2, "Single diffraction in fourier space", log=True)
_show_panel(ax[1, 2], incoherent_CDI, "Ensemble diffraction in fourier space", log=True)

_show_panel(ax[2, 0], np.abs(F_fluctuation_mean)**2,
            "Ensemble diffract amplitudes squared of gaussian fourier space", log=True, fontsize=6)
_show_panel(ax[2, 1], np.abs(F_fluctuation_sq_mean)/n,
            "Ensemble diffract intensities of  gaussian fourier space", log=True, fontsize=6)
_show_panel(ax[2, 2], F_stochastic_numerical, "Stochastic contribution", log=True);


# Extraction of auto-correlations and cross-correlations

## Object cross-correlations

In [None]:
## Half-widths of the extraction windows (each includes a 10 px margin):
## object/reference terms use aperture radius + hole radius,
## reference/reference terms use the two hole radii.
RadER1 = clip_radius + r_hole_radius1 + 10
RadER2 = clip_radius + r_hole_radius2 + 10
RadRR = r_hole_radius1 + r_hole_radius2 + 10


def _periodic_window(center_row, center_col, half, shape):
    """Index tuple for the (2*half x 2*half) window centred on (center_row, center_col).

    Indices wrap modulo `shape`. The DFT is periodic, so a term pushed past one
    edge reappears at the opposite edge. Plain slicing would instead silently
    give an empty or misplaced window. For windows fully inside the array this
    selects exactly [center_row-half:center_row+half, center_col-half:center_col+half].
    """
    offsets = np.arange(-half, half)
    return np.ix_((center_row + offsets) % shape[0], (center_col + offsets) % shape[1])


def _extract_term(center_row, center_col, half):
    """Return (windowed copy of hologram2 zero elsewhere, its centred inverse FFT)."""
    window = _periodic_window(center_row, center_col, half, hologram2.shape)
    term = np.zeros_like(hologram2)  # fresh array: re-running the cell leaves no stale data
    term[window] = hologram2[window]
    return term, np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(term)))


##E*R1
EsR1, FEsR1 = _extract_term(rows//2 + location_r1x, cols//2 + location_r1y, RadER1)
##ER1*
ER1s, FER1s = _extract_term(rows//2 - location_r1x, cols//2 - location_r1y, RadER1)
##E*R2
EsR2, FEsR2 = _extract_term(rows//2 + location_r2x, cols//2 + location_r2y, RadER2)
##ER2*
ER2s, FER2s = _extract_term(rows//2 - location_r2x, cols//2 - location_r2y, RadER2)

In [None]:
cross_corr_windows = [EsR1, ER1s, EsR2, ER2s]  # the four selected object/reference windows
fig, ax = cimshow(cross_corr_windows)
ax.set_title("Selected cross-correlation areas of recos")

In [None]:
# Magnitude of the Fourier transform of each selected window, stacked along a new first axis
fig, ax = cimshow(np.abs(np.stack([FEsR1, FER1s, FEsR2, FER2s])))
ax.set_title("Diffraction of selected cross-correlation areas of recos")

## Reference cross-correlations

In [None]:
## Reference/reference terms sit at +-(R1 - R2) offsets from the centre.
## Indices wrap modulo the array size (the DFT is periodic), so a term pushed
## past the edge is read from the opposite side. Plain slicing would silently
## give an empty slice (offset > 0) or a wrapped one (offset < 0).
## For windows fully inside the array this equals the plain slice.
d_row = location_r1x - location_r2x
d_col = location_r1y - location_r2y
rr_offsets = np.arange(-RadRR, RadRR)

##R1R2*
win_R1R2s = np.ix_((rows//2 + d_row + rr_offsets) % rows, (cols//2 + d_col + rr_offsets) % cols)
R1R2s = np.zeros_like(hologram2)  # fresh array: re-running the cell leaves no stale data
R1R2s[win_R1R2s] = hologram2[win_R1R2s]
FR1R2s = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(R1R2s)))

##R1*R2
win_R1sR2 = np.ix_((rows//2 - d_row + rr_offsets) % rows, (cols//2 - d_col + rr_offsets) % cols)
R1sR2 = np.zeros_like(hologram2)
R1sR2[win_R1sR2] = hologram2[win_R1sR2]
FR1sR2 = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(R1sR2)))


In [None]:
reference_windows = [R1R2s, R1sR2]  # the two selected reference/reference windows
fig, ax = cimshow(reference_windows)
ax.set_title("Selected cross-correlation areas of references")

In [None]:
# Magnitude of the Fourier transform of both reference/reference windows
fig, ax = cimshow(np.abs(np.stack([FR1R2s, FR1sR2])))
ax.set_title("Diffraction of selected cross-correlation areas of references")

## Calculate auto-correlations

In [None]:
##|E|^2 : product of the four object/reference terms divided by the two
## reference/reference terms, then the square root of its magnitude.
## Where the denominator is exactly 0, numpy yields inf/nan (with a warning).
object_reference_product = FER2s * FEsR1 * FER1s * FEsR2
reference_reference_product = FR1R2s * FR1sR2
FAutoCorl = np.abs(np.sqrt(object_reference_product / reference_reference_product))

##|R1|^2 and |R2|^2 : matching object/reference pair divided by |E|^2
FR1R1s = np.abs(FER1s * FEsR1 / FAutoCorl)
FR2R2s = np.abs(FER2s * FEsR2 / FAutoCorl)

In [None]:
# Plot the recovered object autocorrelation |E|^2
fig, ax = cimshow(
    FAutoCorl,
)
ax.set_title("Auto correlation reco")

In [None]:
reference_autocorrelations = [FR1R1s, FR2R2s]  # recovered |R1|^2 and |R2|^2
fig, ax = cimshow(reference_autocorrelations)
ax.set_title("Auto correlation references")

## Stochastic term from CIDI

In [None]:
##Stochastic term from CIDI: mean ensemble intensity minus every recovered
## auto-correlation and cross-correlation term
autocorrelation_terms = FAutoCorl + FR1R1s + FR2R2s
crosscorrelation_terms = FR1R2s + FR1sR2 + FER2s + FEsR1 + FER1s + FEsR2
F_Sktav = incoherent_CDI / n - autocorrelation_terms - crosscorrelation_terms

# Plotting
fig, ax = cimshow(F_Sktav)
ax.set_title("Stochastic term from CIDI")

# CDI

In [None]:
##Image of the isolated stochastic term in object plane (used to define the support for CDI)
Sktav = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(F_Sktav)))
SUPPORT_THRESHOLD = 2e-6  # pixels with |Sktav| above this level form the support
support = (np.abs(Sktav) > SUPPORT_THRESHOLD).astype(int)

## Square root of the stochastic diffraction pattern: the amplitude passed to phase retrieval
SQRT_F_Sktav = np.abs(F_Sktav**0.5)

## Zero the annulus 230 < r <= 240 px that contains numerical zero errors.
## np.where instead of multiplication, because inf*0 and nan*0 are still nan.
r_squared = xx**2 + yy**2
mask_scatter = 1 - (r_squared > 230**2) + (r_squared > 240**2)
SQRT_F_Sktav = np.where(mask_scatter != 0, SQRT_F_Sktav, 0.0)

## Initial guess: start from the Fourier transform of the aperture mask
initial_guess = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(mask)))

# Fit |initial_guess| ~ slope * SQRT_F_Sktav + intercept and map the guess onto
# the scale of the measured amplitude. Subtracting the (real) intercept also
# shifts the phase that is kept below.
x = np.clip(SQRT_F_Sktav, 0, None)
y = np.clip(np.abs(initial_guess), 0, None)
res = sp.stats.linregress(x.flatten(), y.flatten())
initial_guess -= res.intercept
initial_guess /= res.slope

# Keep only the phase of the rescaled guess; the magnitude is the measured amplitude
initial_guess = SQRT_F_Sktav * np.exp(1j * np.angle(initial_guess))

#Plot
fig, ax = plt.subplots(1, 3, figsize=(12, 4), sharex=True, sharey=True)

sktav_abs = np.abs(Sktav)
vmin, vmax = np.percentile(sktav_abs, [.1, 99.9])
ax[0].imshow(support)
ax[0].imshow(sktav_abs, vmin=vmin, vmax=vmax, alpha=0.5)
ax[0].set_title("Real space image with support mask overlay")

ax[1].imshow(support)
ax[1].set_title("Support mask")

ax[2].imshow(np.abs(initial_guess))
ax[2].set_title("Initial guess");


In [None]:
SW_freq = 1e4  # disable: larger than Nit, so presumably the shrink-wrap update never triggers
Nit = 500      # number of phase-retrieval iterations

##CDI on the stochastic amplitude; support*mask is passed as PhR's `mask`
## (presumably the real-space support constraint — see PhaseRetrieval.py)
retrieved_res0, Error_diff_p, Error_supp, supportmask = PhR.PhaseRtrv_CPU(
    diffract=SQRT_F_Sktav,
    mask=support*mask,
    mode="mine",
    beta_zero=0.5,
    Nit=Nit,
    beta_mode='arctan',
    plot_every=20,
    Phase=initial_guess,  # alternative tried: Phase=0
    seed=False,
    real_object=False,
    bsmask=(1 - mask_scatter),  # alternative tried: bsmask=0*(1 - mask_scatter)
    average_img=20,
    Fourier_last=True,
    SW_freq=SW_freq
)

# Fourier transform of the retrieved field for the image reconstruction.
# ifftshift before / fftshift after the FFT, consistent with the other cells
# (identical for even sizes; keeps the centre pixel centred for odd sizes).
ret_pattern = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(retrieved_res0)))

In [None]:
# Plotting: retrieved stochastic term vs. the numerically simulated one.
# Colour limits are left to matplotlib's autoscaling.
fig, ax = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(8, 8))

ax[0, 0].imshow(np.abs(retrieved_res0))
ax[0, 0].set_title("diffraction pattern retrieved from CIDI", fontsize=8)

ax[0, 1].imshow(np.abs(F_stochastic_numerical))
ax[0, 1].set_title("diffraction pattern from numerical data", fontsize=8)

ax[1, 0].imshow(np.abs(ret_pattern))
ax[1, 0].set_title("Absolute value of retrieved image (object plane) from CIDI", fontsize=8)

# Last single fluctuation drawn in the ensemble loop
ax[1, 1].imshow(np.abs(fluctuation))
ax[1, 1].set_title("Input fluctuation", fontsize=8);