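# Zoom-in sampling of the coronagraph focal-plane mask

This notebook checks `prop_fpm_regional_sampling`, which propagates through a four-quadrant phase mask (FQPM) by sampling the focal plane at several zoom levels, against a step-by-step re-implementation of the same steps and against a single direct MFT.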

In [None]:
import os
from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

from Asterix.optics import butterworth_circle, fqpm_mask, mft, prop_fpm_regional_sampling, fft_choosecenter
from Asterix.optics import phase_amplitude_functions as paf
from Asterix.utils.plotting import display_complex

## Set up inputs

In [None]:
# Array size in pixels and pupil radius in pixels (pupil fills the array)
dim = 512
ray = dim * 0.5

# Focal-plane sampling levels passed as `nbres` to the regional-sampling propagation
res_list = np.array([1e-1, 1e0, 1e1, 1e2])

In [None]:
# Entrance pupil with grey (anti-aliased) edges, and a Lyot stop undersized to 95 % of the radius
pup = paf.roundpupil(dim, ray, grey_pup_bin_factor=10)
lyot_stop = paf.roundpupil(dim, ray*0.95)

# Show both apertures side by side
plt.figure(figsize=(16,8))
for panel_idx, (aperture, panel_title) in enumerate([(pup, 'Grey pupil'), (lyot_stop, 'Lyot stop')], start=1):
    plt.subplot(1, 2, panel_idx)
    plt.imshow(aperture, origin='lower', cmap='Greys_r')
    plt.title(panel_title)
    plt.colorbar()

In [None]:
# Four-quadrant phase mask with the same number of pixels as the pupil array
fpm = fqpm_mask(pup.shape[0])

plt.imshow(fpm, cmap='Reds', origin='lower')
plt.title('Phase')
plt.colorbar()

## Test built-in function

In [None]:
# Library implementation: E-field just before the Lyot stop, with the FQPM applied
# as a complex transmission and the focal plane sampled at the levels in `res_list`
psam_pre_ls = prop_fpm_regional_sampling(
    pup,
    np.exp(1j*fpm),
    nbres=res_list,
    samp_outer=4,
)

display_complex(psam_pre_ls)
plt.suptitle('Pre-LS E-field')

In [None]:
# Log-scale intensity of the library's pre-LS E-field
plt.imshow(np.abs(psam_pre_ls) ** 2, norm=LogNorm(), cmap='inferno', origin='lower')

In [None]:
# Propagate the pre-LS field (no Lyot stop applied) back to a focal plane with nbres=64
this = mft(psam_pre_ls, real_dim_input=dim, dim_output=dim, nbres=64)

In [None]:
# Log-scale focal-plane intensity of the (un-stopped) pre-LS field
plt.imshow(np.abs(this) ** 2, norm=LogNorm(), cmap='inferno', origin='lower')

## Break up the function into its individual steps

In [None]:
# Parameters of the step-by-step re-implementation below
alpha = 1.5         # filters are sized as dim / alpha
filter_order = 15   # order passed to every butterworth_circle call
samp_outer = 2      # outer part uses nbres = dim / samp_outer

dim = pup.shape[0]
fpm_z = np.exp(1j*fpm)  # complex transmission of the FQPM
nbres = np.array([0.1,1,2,5], dtype=float)  # sampling of each focal-plane layer, innermost first

### Inner part of FPM

In [None]:
# Butterworth filter of the innermost layer, size dim / alpha, shifted by half a pixel
but0 = butterworth_circle(dim, dim / alpha, filter_order, xshift=-0.5, yshift=-0.5)
display_complex(but0)

In [None]:
# E-field just before the FPM in the innermost layer, sampled with the first (smallest) nbres
efield_before_fpm = mft(pup, nbres=nbres[0], real_dim_input=dim, dim_output=dim)
display_complex(efield_before_fpm)

In [None]:
# Innermost layer's contribution to the E-field before the LS:
# apply FPM and inner filter, then inverse MFT with the same sampling
efield_before_ls = mft(
    efield_before_fpm * fpm_z * but0,
    real_dim_input=dim,
    dim_output=dim,
    nbres=nbres[0],
    inverse=True,
)
display_complex(efield_before_ls)

In [None]:
# Log-scale intensity of the innermost layer's pre-LS contribution
plt.imshow(np.abs(efield_before_ls) ** 2, norm=LogNorm(), cmap='inferno', origin='lower')
plt.colorbar()

### Layers of FPM sampling

In [None]:
# Per-layer results: filters, pre-FPM fields and pre-LS fields
but_list, pre_fpm_list, pre_ls_list = [], [], []

In [None]:
# Intermediate layers of the FPM. Layer k is sampled with nbres[k + 1]. Its filter
# is the common low-pass `const_but` times a high-pass whose size is dim / alpha
# rescaled by nbres[k] / nbres[k + 1], i.e. an annulus.
# The result lists are reset here so that re-running this cell does not append a
# second copy of every layer to lists left over from a previous run.
but_list = []
pre_fpm_list = []
pre_ls_list = []

const_but = butterworth_circle(dim, dim / alpha, filter_order, xshift=-0.5, yshift=-0.5)

for k in range(nbres.shape[0] - 1):
    # Butterworth filter in each layer: (1 - inner disc) * outer disc
    sizebut_here = dim / alpha * nbres[k] / nbres[k + 1]
    but = (1 - butterworth_circle(dim, sizebut_here, filter_order, xshift=-0.5, yshift=-0.5)) * const_but
    but_list.append(but)

    # E-field before the FPM in each layer
    ef_pre_fpm = mft(pup, real_dim_input=dim, dim_output=dim, nbres=nbres[k + 1])
    pre_fpm_list.append(ef_pre_fpm)

    # E-field before the LS in each layer
    ef_pre_ls = mft(ef_pre_fpm * fpm_z * but, real_dim_input=dim, dim_output=dim, nbres=nbres[k + 1],
                    inverse=True)
    pre_ls_list.append(ef_pre_ls)

In [None]:
# Plot all layer filters: squared modulus on the top row, phase on the bottom row.
# The number of columns follows the number of layers actually computed (instead of a
# hard-coded 3), so the figure stays correct if `nbres` gets more or fewer entries.
n_layers = len(but_list)
plt.figure(figsize=(15,7))
for k, but_k in enumerate(but_list):
    plt.subplot(2, n_layers, k + 1)
    plt.imshow(np.abs(but_k)**2, origin='lower', cmap='Greys_r')
    plt.colorbar()
    plt.title(f"Intensity, k={k}")
    plt.subplot(2, n_layers, n_layers + k + 1)
    plt.imshow(np.angle(but_k), origin='lower', cmap='RdBu')
    plt.colorbar()
    plt.title(f"Phase, k={k}")

In [None]:
print(dim / samp_outer)
# MFT of the first three layer filters.
# NOTE(review): 256 presumably stands for dim / samp_outer (printed above) — confirm.
that0, that1, that2 = [mft(but_list[i], dim, dim, 256) for i in range(3)]

In [None]:
# Log-scale |MFT|^2 of the first layer filter
plt.imshow(np.abs(that0) ** 2, norm=LogNorm(), cmap='inferno', origin='lower')

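The filters are supposed to be the same when the ratio between consecutive elements of `nbres` is the same (e.g. `[0.1, 1, 10, 100]`), because each filter size scales as `nbres[k] / nbres[k + 1]`. With the `nbres = [0.1, 1, 2, 5]` used above, the ratios differ, so the filters differ too. In any case, the sampling is not the same when the filters are applied, so the resulting E-fields will not be the same.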

In [None]:
# Plot all E-fields before the FPM: intensity on the top row, phase on the bottom row.
# One column per computed layer, rather than a hard-coded 3.
n_layers = len(pre_fpm_list)
plt.figure(figsize=(15,7))
for k, field_k in enumerate(pre_fpm_list):
    plt.subplot(2, n_layers, k + 1)
    plt.imshow(np.abs(field_k)**2, origin='lower', cmap='inferno')
    plt.colorbar()
    plt.title(f"Intensity, k={k}")
    plt.subplot(2, n_layers, n_layers + k + 1)
    plt.imshow(np.angle(field_k), origin='lower', cmap='RdBu')
    plt.colorbar()
    plt.title(f"Phase, k={k}")

In [None]:
# Plot filter over pre-FPM PSF: intensity of layer n's pre-FPM field seen through its filter
n = 0  # layer index to inspect
plt.imshow(np.abs(but_list[n] * pre_fpm_list[n]) ** 2, norm=LogNorm(), cmap='inferno', origin='lower')
plt.colorbar()

In [None]:
# Plot all E-fields before the LS: log-scale intensity on the top row, phase on the bottom row.
# One column per computed layer, rather than a hard-coded 3.
n_layers = len(pre_ls_list)
plt.figure(figsize=(15,7))
for k, field_k in enumerate(pre_ls_list):
    plt.subplot(2, n_layers, k + 1)
    plt.imshow(np.abs(field_k)**2, origin='lower', cmap='inferno', norm=LogNorm())
    plt.colorbar()
    plt.title(f"Intensity, k={k}")
    plt.subplot(2, n_layers, n_layers + k + 1)
    plt.imshow(np.angle(field_k), origin='lower', cmap='RdBu')
    plt.colorbar()
    plt.title(f"Phase, k={k}")

In [None]:
# Plot all E-fields before the LS on a linear scale: modulus on the top row,
# real part on the bottom row. One column per computed layer, rather than a hard-coded 3.
n_layers = len(pre_ls_list)
plt.figure(figsize=(15,7))
for k, field_k in enumerate(pre_ls_list):
    plt.subplot(2, n_layers, k + 1)
    plt.imshow(np.abs(field_k), origin='lower', cmap='inferno')
    plt.colorbar()
    plt.title(f"Absolute value, k={k}")
    plt.subplot(2, n_layers, n_layers + k + 1)
    plt.imshow(np.real(field_k), origin='lower', cmap='inferno')
    plt.colorbar()
    plt.title(f"Real part, k={k}")

In [None]:
# E-field before the LS from the innermost part plus every intermediate layer
# (the outer part of the focal plane is added further down)
summed_layers_before_ls = efield_before_ls + np.stack(pre_ls_list).sum(axis=0)
display_complex(summed_layers_before_ls)
plt.suptitle("Summed pre-LS E-field, through all layers (except outermost)")

In [None]:
# Log-scale intensity of the summed pre-LS field (inner part + intermediate layers, outer part not yet added)
plt.imshow(np.abs(summed_layers_before_ls)**2, origin='lower', cmap='inferno', norm=LogNorm())
plt.colorbar()
plt.title("Summed layers before LS (includes inner part)")

In [None]:
# Compare every pre-LS contribution: top row is the log-scale intensity, bottom row the phase.
# The phase panels use a linear scale fixed to [-pi, pi]: np.angle returns values in that
# range, and LogNorm would mask every non-positive value, hiding half of each phase map.
layers_sum = np.sum(np.array(pre_ls_list), axis=0)
panels = [
    ('efield_before_ls', efield_before_ls),
    ('pre_ls_list[0]', pre_ls_list[0]),
    ('pre_ls_list[1]', pre_ls_list[1]),
    ('pre_ls_list[2]', pre_ls_list[2]),
    ('np.sum(np.array(pre_ls_list), axis=0)', layers_sum),
    ('summed_layers_before_ls', summed_layers_before_ls),
]
n_panels = len(panels)

plt.figure(figsize=(16,8))
for i, (label, field) in enumerate(panels):
    plt.subplot(2, n_panels, i + 1)
    plt.imshow(np.abs(field)**2, cmap='inferno', origin='lower', norm=LogNorm())
    plt.title(label)

    plt.subplot(2, n_panels, n_panels + i + 1)
    plt.imshow(np.angle(field), cmap='RdBu', origin='lower', vmin=-np.pi, vmax=np.pi)
    plt.title(label)

### Outer part of FPM

In [None]:
# Outer part of the focal plane: sampled with nbres = dim / samp_outer, and filtered with a
# high-pass whose size is dim / alpha rescaled from the last layer's nbres to this one
nbres_outer = dim / samp_outer
sizebut_outer = dim / alpha * nbres[-1] / nbres_outer
low_pass_outer = butterworth_circle(dim, sizebut_outer, filter_order, xshift=-0.5, yshift=-0.5)
but_outer = 1 - low_pass_outer
display_complex(but_outer)

In [None]:
# E-field before the FPM in outer part of focal plane.
# This is a pupil -> focal plane propagation, so it uses the forward MFT, like the inner
# part, every intermediate layer and the direct propagation further down. `inverse=True`
# is reserved for the focal plane -> pupil (pre-LS) step.
ef_pre_fpm_outer = mft(pup, real_dim_input=dim, dim_output=dim, nbres=nbres_outer)
display_complex(ef_pre_fpm_outer)

In [None]:
# Outer part's contribution to the E-field before the LS:
# apply FPM and outer high-pass filter, then inverse MFT with the outer sampling
ef_pre_ls_outer = mft(
    ef_pre_fpm_outer * fpm_z * but_outer,
    real_dim_input=dim,
    dim_output=dim,
    nbres=nbres_outer,
    inverse=True,
)
display_complex(ef_pre_ls_outer)

In [None]:
# Total E-field before the LS: outer part + (inner part and intermediate layers).
# NOTE: this overwrites `psam_pre_ls` from the built-in-function test near the top.
psam_pre_ls = ef_pre_ls_outer + summed_layers_before_ls
display_complex(psam_pre_ls)

In [None]:
# Log-scale intensity of the total (manually assembled) pre-LS field
plt.imshow(np.abs(psam_pre_ls) ** 2, norm=LogNorm(), cmap='inferno', origin='lower')

## Calculate PSFs

In [None]:
# Multiply the total pre-LS field by the Lyot stop transmission
post_ls = psam_pre_ls * lyot_stop

# Inspect the stopped field
display_complex(post_ls)
plt.suptitle('E-field after LS')

In [None]:
# Final PSFs are computed with nbres = dim / lamD_psf; presumably pixels per lambda/D — confirm
lamD_psf = 4

In [None]:
# Non-coronagraphic PSF through pupil and Lyot stop; its peak normalises the PSFs below
direct_ef = mft(pup * lyot_stop, nbres=dim / lamD_psf, real_dim_input=dim, dim_output=dim)
direct_psf = np.abs(direct_ef) ** 2
norm = np.max(direct_psf)

plt.imshow(direct_psf / norm, norm=LogNorm(), cmap='inferno', origin='lower')
plt.title('Direct PSF')
plt.colorbar()

In [None]:
# Coronagraphic PSF: post-LS field to the focal plane, same sampling as the direct PSF
coro_ef = mft(post_ls, nbres=dim / lamD_psf, real_dim_input=dim, dim_output=dim)
coro_psf = np.abs(coro_ef) ** 2

plt.imshow(coro_psf / norm, norm=LogNorm(), cmap='inferno', origin='lower')
plt.title('Coronagraphic PSF')
plt.colorbar()

In [None]:
# Zoom on the central pixels (200 to 309 on both axes) of the normalised coronagraphic PSF
plt.imshow(coro_psf[200:310, 200:310] / norm, norm=LogNorm(), cmap='inferno', origin='lower')
plt.title('Coronagraphic PSF')
plt.colorbar()

## Improving on the residual energy - testing

In [None]:
# Self-contained set-up: compare the regional (zoom-in) FPM sampling with one direct MFT
dim = 512
rad = dim / 2
samp_outer = 4
res_list = np.array([0.1, 1, 10, 100])

# Grey-edged pupil, Lyot stop at 95 % of the radius, FQPM phase
pup = paf.roundpupil(dim, rad, grey_pup_bin_factor=10)
lyot_stop = paf.roundpupil(dim, rad * 0.95)
fpm = fqpm_mask(dim)

# Regional sampling of the FPM, then Lyot stop
pre_ls_areas = prop_fpm_regional_sampling(pup, np.exp(1j * fpm), nbres=res_list, samp_outer=samp_outer)
post_ls_areas = pre_ls_areas * lyot_stop

# Direct propagation: a single forward/inverse MFT pair with nbres = dim / samp_outer
nbres_direct = dim / samp_outer
pre_fpm = mft(pup, real_dim_input=dim, dim_output=dim, nbres=nbres_direct)
post_fpm = pre_fpm * np.exp(1j * fpm)
pre_ls_direct = mft(post_fpm, real_dim_input=dim, dim_output=dim, nbres=nbres_direct, inverse=True)
post_ls_direct = pre_ls_direct * lyot_stop

In [None]:
# Normalisation: peak of the non-coronagraphic PSF at the direct-propagation sampling
direct_ef = mft(pup * lyot_stop, nbres=nbres_direct, real_dim_input=dim, dim_output=dim)
direct_psf = np.abs(direct_ef) ** 2
norm = np.max(direct_psf)

In [None]:
# Post-LS amplitude: regional FPM sampling (left) vs. single direct MFT (right).
# Titles make it explicit which method each panel shows.
plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(np.abs(post_ls_areas), cmap='inferno', origin='lower', norm=LogNorm())
plt.title('Post-LS |E|, regional sampling')
plt.colorbar()
plt.subplot(1,2,2)
plt.imshow(np.abs(post_ls_direct), cmap='inferno', origin='lower', norm=LogNorm())
plt.title('Post-LS |E|, direct MFT')
plt.colorbar()

In [None]:
print(np.sum(np.abs(post_ls_areas)**2))
print(np.sum(np.abs(post_ls_direct)**2))

In [None]:
# Coronagraphic PSF from the regional-sampling propagation, normalised by the direct PSF peak
coro_ef_areas = mft(
    post_ls_areas,
    real_dim_input=dim,
    dim_output=dim,
    nbres=nbres_direct,
)
coro_psf_areas = np.abs(coro_ef_areas) ** 2 / norm

In [None]:
# Coronagraphic PSF from the single direct MFT, normalised by the direct PSF peak
coro_ef_direct = mft(
    post_ls_direct,
    real_dim_input=dim,
    dim_output=dim,
    nbres=nbres_direct,
)
coro_psf_direct = np.abs(coro_ef_direct) ** 2 / norm

In [None]:
# Normalised coronagraphic PSFs: regional FPM sampling (left) vs. single direct MFT (right).
# Titles make it explicit which method each panel shows.
plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(coro_psf_areas, cmap='inferno', origin='lower', norm=LogNorm())
plt.title('Coronagraphic PSF, regional sampling')
plt.colorbar()
plt.subplot(1,2,2)
plt.imshow(coro_psf_direct, cmap='inferno', origin='lower', norm=LogNorm())
plt.title('Coronagraphic PSF, direct MFT')
plt.colorbar()

In [None]:
assert (np.max(coro_psf_direct) / np.max(coro_psf_areas)) > 1e4