In [None]:
import phasenet.zernike as Z
import phasenet.psf as P
import phasenet.model as M

# Zernike

In [None]:
# A Zernike mode can be requested by Noll index, by ANSI index, by a
# 2-tuple of indices, or by its plain-text name.
for args, kwargs in (
    ((5,),                     {'order': 'noll'}),
    ((3,),                     {'order': 'ansi'}),
    (((2,-2),),                {}),
    (('oblique astigmatism',), {}),
):
    print(Z.Zernike(*args, **kwargs))

In [None]:
# Print the Zernike objects for ANSI indices 0..14, one per line.
print(*(Z.Zernike(j, order='ansi') for j in range(15)), sep='\n')

In [None]:
# 3x5 gallery: one panel per ANSI index 0..14, titled with the mode name.
fig, ax = plt.subplots(nrows=3, ncols=5, figsize=(16,10))
for i, a in enumerate(ax.flat):
    z = Z.Zernike(i, order='ansi')
    w = z.polynomial(128)  # presumably sampled on a 128x128 grid -- TODO confirm
    a.set_title(z.name)
    a.imshow(w)
    a.axis('off')

## ZernikeWavefront

In [None]:
# Wavefront built from 4 amplitudes drawn uniformly from [-1, 1), ANSI ordering.
f = Z.ZernikeWavefront(np.random.uniform(-1,1,4), order='ansi')

# Printed explicitly: Jupyter only displays the *last* expression of a cell,
# so a bare `f.zernikes` in the middle of the cell would be silently dropped.
print(f.zernikes)
print(f.amplitudes_noll)
print(f.amplitudes_ansi)

# Render the wavefront polynomial (size argument 512) with a colour scale.
plt.imshow(f.polynomial(512))
plt.colorbar()
plt.axis('off');

In [None]:
# Exercise the accepted request formats of random_zernike_wavefront:
# a list of scalars, a list containing a 2-tuple, and a dict keyed by
# mode name / index tuple. For each, show what was requested next to the
# resulting ANSI amplitudes.
for requested, extra in (
    ([1,1,1],                     {'order': 'ansi'}),
    ([0,0,(1,2)],                 {'order': 'ansi'}),
    ({'defocus':(1,2), (3,-3):5}, {}),
):
    f = Z.random_zernike_wavefront(requested, **extra)
    print(f.amplitudes_requested, f.amplitudes_ansi)

# PSF

In [None]:
# Cubic PSF volume: N samples per axis with spacing dx along every axis.
N = 64
dx = 0.1

psf = P.PsfGenerator3D(
    psf_shape=(N, N, N),
    units=(dx, dx, dx),
    na_detection=1.1,
    lam_detection=0.5,
    n=1.33,
)

In [None]:
# Random wavefront: 5 ANSI-ordered amplitudes drawn uniformly from [-0.2, 0.2).
wf = Z.ZernikeWavefront(np.random.uniform(-0.2,0.2,5), order='ansi')
# Printed explicitly: a bare expression that is not the last statement of a
# cell produces no output in Jupyter.
print(wf.zernikes)

plt.figure(figsize=(15,4))

# Phase evaluated on the generator's (krho, kphi) grid; fftshift recentres
# the array so it is displayed around the middle of the image.
phase = wf.phase(psf.krho, psf.kphi, normed=True, outside=None)
phase = np.fft.fftshift(phase)
plt.subplot(131); plt.imshow(phase); plt.title('Phase'); plt.colorbar()

# Incoherent PSF for this wavefront, likewise recentred for display.
h1 = np.fft.fftshift(psf.incoherent_psf(wf, normed=True))
plt.subplot(132); plt.imshow(h1[N//2]);   plt.title('XY section'); plt.colorbar()
plt.subplot(133); plt.imshow(h1[:,N//2]); plt.title('ZX section'); plt.colorbar()
plt.tight_layout();

## Data Generator

In [None]:
# Generator producing batches of 3 noisy PSFs for the 'vertical coma' mode.
data = M.Data({'vertical coma':.2}, batch_size=3, noise_params={'mean':(100,100),'sigma':(3,4),'snr':(1,5)})
psfs, amps = next(data.generator())

print(psfs.shape)
print(amps.shape)
# First sample of the batch, with the trailing axis dropped.
h1 = psfs[0,...,0]

# Take the central slices from the volume's own shape instead of `N` from the
# PSF section: this cell then no longer depends on that earlier cell, and the
# slices stay centred even if M.Data uses a different psf_shape.
plt.figure(figsize=(10,4))
plt.subplot(121); plt.imshow(h1[h1.shape[0]//2]);   plt.title('XY section'); plt.colorbar()
plt.subplot(122); plt.imshow(h1[:,h1.shape[1]//2]); plt.title('ZX section'); plt.colorbar()
plt.tight_layout();

# Model

In [None]:
# Build a config that overrides only `psf_n`, then show all of its fields.
c = M.Config(
    psf_n=8,
)
vars(c)

In [None]:
# Default-config model registered as 'test' under the 'models' base directory.
# Variant without a base directory:
#   model = M.PhaseNet(M.Config(), basedir=None)
model = M.PhaseNet(
    M.Config(),
    name='test',
    basedir='models',
)
vars(model.config)

In [None]:
# Short demonstration run: train the model created above for 5 epochs.
model.train(epochs=5)

## Load model

In [None]:
# Re-create the model with config=None, reusing the name and basedir from the
# training section; presumably the saved configuration is then restored from
# disk -- TODO confirm against M.PhaseNet.
model = M.PhaseNet(None, name='test', basedir='models')
# Show the configuration the re-created model ended up with.
vars(model.config)

## Create test PSFs

In [None]:
# Test-data generator that mirrors the model's PSF settings but uses its own
# amplitude range for 'vertical coma'.
data = M.Data(
    batch_size           = 128,
    amplitude_ranges     = {'vertical coma': [-0.5, 0.5]},
    #amplitude_ranges     = model.config.zernike_amplitude_ranges,
    order                = model.config.zernike_order,
    normed               = model.config.zernike_normed,
    psf_shape            = model.config.psf_shape,
    units                = model.config.psf_units,
    na_detection         = model.config.psf_na_detection,
    lam_detection        = model.config.psf_lam_detection,
    n                    = model.config.psf_n,
)
psfs, amps = next(data.generator())
# Show both shapes as a single last expression; Jupyter only displays the
# final expression of a cell, so a separate `psfs.shape` line would be lost.
psfs.shape, amps.shape

## Predict and plot

In [None]:
# using keras model directly: a single call on the whole `psfs` batch
_amps_pred = model.keras_model.predict(psfs, verbose=1)

In [None]:
# Same predictions through the PhaseNet wrapper, one PSF at a time
# (tqdm shows progress), stacked into a single array.
amps_pred = np.stack(list(map(model.predict, tqdm(psfs))))

In [None]:
# Check that the batched keras call and the per-PSF wrapper agree.
print(_amps_pred.shape, amps_pred.shape)
# Flatten BOTH sides before comparing. If only one side is flattened and the
# other has shape (n, 1), numpy broadcasts (n,) against (n, 1) into an (n, n)
# all-pairs comparison instead of an element-wise one.
np.allclose(_amps_pred.ravel(), amps_pred.ravel())

In [None]:
# Ground truth vs. prediction, with test PSFs ordered by ground-truth amplitude.
plt.figure(figsize=(10,8))
ind = np.argsort(amps.ravel())
plt.plot(amps[ind], label='gt')
plt.plot(amps_pred[ind], '--', label='pred')
# Horizontal reference lines at -0.2 and +0.2, each spanning the x-limits
# that are current at the moment it is drawn.
for level in (-0.2, +0.2):
    plt.hlines(level, *plt.axis()[:2])
plt.xlabel('test psf')
# Label the y-axis with the first mode name in the model's amplitude ranges.
plt.ylabel(f'amplitude {list(model.config.zernike_amplitude_ranges.keys())[0]}')
plt.legend();