In [None]:
import torch
import numpy as np
from matplotlib import pyplot as plt

# from optics_util.backend import select_backend
import optics_utils

# Backend selection. `lib` is passed to optics_utils.backend.select_backend; the rest of the
# notebook branches on lib == 'torch' (any other supported name presumably selects a numpy-like
# backend — confirm against optics_utils).
lib = 'torch'

optics_utils.backend.select_backend(lib)

# Use the GPU when one is present, otherwise fall back to CPU so the notebook still runs on
# machines without CUDA (torch.device('cuda') alone fails at first use there).
if lib == 'torch':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
    device = None

# Backend function namespace (arange, meshgrid, exp, sqrt, pi, float32, ... as used below).
fn = optics_utils.backend.functions
dtype = fn.float32


# Transverse sampling: nx x ny samples with spacing dx, dy.
# Lengths are assumed to be in µm (lamb0 = 0.532 ↔ 532 nm) — confirm.
nx, ny = 512, 512
dx, dy = 0.3, 0.3

# Propagation axis: nz planes spaced dz apart, starting at z = 0.
dz = 1.0
nz = 512

# Lens: focal length f, aperture radius R, design wavelength lamb0.
f = 300.
R = 50.
lamb0 = 0.532


def _axis(n, step, centered=True):
    """Return `n` samples spaced by `step`, centred on 0 when `centered`, else starting at 0.

    On the torch backend the samples are moved to `device` before the arithmetic, so the
    result lives on the same device as the propagators.
    """
    samples = fn.arange(0, n, 1, dtype=dtype)
    if lib == 'torch':
        samples = samples.to(device)
    if centered:
        # Shift so the grid is symmetric about 0: index (n-1)/2 maps to coordinate 0.
        samples = samples - (n - 1) / 2
    return samples * step


xa = _axis(nx, dx)
ya = _axis(ny, dy)
za = _axis(nz, dz, centered=False)

# 2-D transverse coordinates; indexing='ij' (numpy/torch convention) puts x along axis 0, y along axis 1.
x, y = fn.meshgrid(xa, ya, indexing='ij')

# Propagators (argument meanings are defined by optics_utils; presumably):
#  - asm: full 2-D angular-spectrum propagation on the nx x ny grid to the distance(s) [f].
#  - rs_asm: rotationally symmetric variant with nx radial samples out to nx*dx/2, evaluated at every z in za.
asm = optics_utils.Fourier_optics.ASM(
    [nx, ny], [dx, dy], [f], [lamb0],
    sign=1, dtype=dtype, device=device,
)
rs_asm = optics_utils.Fourier_optics.RotationalSymmetricASM(
    nx, nx * dx / 2, za, [lamb0],
    sign=1, dtype=dtype, device=device,
)

# Hyperbolic focusing-lens transmission exp(-i k0 (sqrt(r^2 + f^2) - f)), zeroed outside radius R.
# Note: despite its name, `phase` holds the complex transmission, not the phase angle.
k0 = 2 * fn.pi / lamb0
r_sq = x**2 + y**2

phase = fn.exp(1.j * (-k0 * (fn.sqrt(r_sq + f**2) - f)))
phase[r_sq > R**2] = 0

# The same lens sampled on the radial grid of the rotationally symmetric propagator.
rs_phase = fn.exp(1.j * (-k0 * (fn.sqrt(rs_asm.r**2 + f**2) - f)))
rs_phase[rs_asm.r > R] = 0

In [None]:
# Wrapped phase of the lens transmission; samples outside the aperture are 0, whose angle is 0.
# `.cpu()` only exists on torch tensors, so convert per backend.
# Lengths assumed in µm (lamb0 = 0.532 ↔ 532 nm) — confirm.
lens_phase_np = np.angle(np.asarray(phase.cpu() if lib == 'torch' else phase))

fig, ax = plt.subplots()
# With 'ij' indexing x runs along rows, so x is the vertical axis; origin='lower' puts xa[0] at the bottom.
im = ax.imshow(
    lens_phase_np,
    origin='lower',
    extent=[float(ya[0]), float(ya[-1]), float(xa[0]), float(xa[-1])],
)
fig.colorbar(im, ax=ax, label='phase (rad)')
ax.set(xlabel='y (µm)', ylabel='x (µm)', title='Lens phase profile')
plt.show()

In [None]:
# Propagate the lens field with the 2-D ASM (to the distance configured in `asm`, i.e. [f]) and
# show the intensity there. Named asm_field so it is not confused with the radial result below.
asm_field = asm.propagate(phase)
asm_field_np = np.asarray(asm_field.cpu() if lib == 'torch' else asm_field)

# NOTE(review): [0, 0, 0] assumes a 5-D output with three leading axes (presumably distance /
# wavelength / batch) before (x, y) — confirm against optics_utils.
# Assumes ASM keeps the input sampling (xa, ya); lengths assumed in µm — confirm.
fig, ax = plt.subplots()
im = ax.imshow(
    np.abs(asm_field_np[0, 0, 0, :, :])**2,
    origin='lower',
    extent=[float(ya[0]), float(ya[-1]), float(xa[0]), float(xa[-1])],
)
fig.colorbar(im, ax=ax, label='|E|^2 (a.u.)')
ax.set(xlabel='y (µm)', ylabel='x (µm)', title=f'ASM intensity at z = {f:g} µm')
plt.show()

In [None]:
plt.plot(np.angle(rs_phase.cpu()))

In [None]:
# Propagate the radial lens profile to every z in za with the rotationally symmetric ASM.
# `plane` is kept as the name because the focal-profile cell below reads it.
plane = rs_asm.propagate(rs_phase)

# imshow assumes uniformly spaced pixels, which distorts the non-uniform r axis; pcolormesh
# draws each sample at its actual (r, z) coordinate instead.
rs_intensity = np.abs(np.asarray(plane.cpu() if lib == 'torch' else plane).squeeze())**2
r_np = np.asarray(rs_asm.r.cpu() if lib == 'torch' else rs_asm.r)
z_np = np.asarray(za.cpu() if lib == 'torch' else za)
# NOTE(review): assumes the squeezed field is laid out (z, r) — confirm against optics_utils.
assert rs_intensity.shape == (len(z_np), len(r_np)), rs_intensity.shape

# Lengths assumed in µm (lamb0 = 0.532 ↔ 532 nm) — confirm.
fig, ax = plt.subplots()
mesh = ax.pcolormesh(r_np, z_np, rs_intensity, shading='auto')
fig.colorbar(mesh, ax=ax, label='|E|^2 (a.u.)')
ax.set(xlabel='r (µm)', ylabel='z (µm)', title='Rotationally symmetric ASM: intensity vs (r, z)')
plt.show()

In [None]:
# Radial intensity profile at the focal plane. za = arange(nz) * dz, so z = f sits at index
# f / dz (300 for the defaults) — derived here instead of hard-coded so changing f or dz stays correct.
iz_focus = int(round(f / dz))
assert 0 <= iz_focus < nz, f'focal plane z = {f} lies outside the propagated range [0, {(nz - 1) * dz}]'

r_np = np.asarray(rs_asm.r.cpu() if lib == 'torch' else rs_asm.r)
# NOTE(review): assumes axis 0 of the squeezed field is z and axis 1 is r — confirm.
field_np = np.asarray(plane.cpu() if lib == 'torch' else plane).squeeze()

# Lengths assumed in µm (lamb0 = 0.532 ↔ 532 nm) — confirm.
fig, ax = plt.subplots()
ax.plot(r_np, np.abs(field_np[iz_focus, :])**2)
ax.set(xlabel='r (µm)', ylabel='|E|^2 (a.u.)', title=f'Radial intensity at z = {iz_focus * dz:g} µm')
plt.show()