In [9]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import astropy.units as u
from astropy.io import fits
from pathlib import Path
from IPython.display import clear_output, display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))
from importlib import reload

import poppy

import logging, sys
poppy_log = logging.getLogger('poppy')
poppy_log.setLevel('DEBUG')
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
poppy_log.disabled = True

import ray

import scoobpsf
from scoobpsf.math_module import xp, _scipy
from scoobpsf.imshows import *
import scoobpsf.compact_scoob as cscoob

import lina

pupil_diam = 6.75*u.mm 
lyot_ratio = 0.9
lyot_diam = 3.6*u.mm

# The flattest wavefront obtained for the system in the lab by
# performing phase diversity

wavelength_c = 632.8e-9*u.m # central wavelength

In [30]:
import numpy as np
arr = xp.ones((2048,2048))

np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(arr)))

array([[0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       ...,
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j]])

In [31]:
%%timeit
ft = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(arr)))

393 ms ± 555 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [35]:
import cupy as cp
arr = cp.ones((2048,2048))

cp.fft.ifftshift(cp.fft.fft2(cp.fft.fftshift(arr)))

array([[0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       ...,
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j]])

In [36]:
%%timeit
ft = cp.fft.ifftshift(cp.fft.fft2(cp.fft.fftshift(arr)))

465 µs ± 68 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [2]:
import jax
import jax.numpy as xp
jax.config.update("jax_enable_x64", True)

In [32]:
dtype = xp.float64
arr = xp.ones((2048,2048), dtype=dtype)
type(arr)

jaxlib.xla_extension.Array

In [34]:
%%timeit
ft = xp.fft.ifftshift(xp.fft.fft2(xp.fft.fftshift(arr)))

1.21 ms ± 6.53 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [3]:
jax.devices()[0].platform
jax.devices()[0].device_kind

'NVIDIA A100 80GB PCIe'

In [37]:
jax.devices()[0].__dir__()

['__init__',
 '__doc__',
 '__module__',
 'id',
 'process_index',
 'host_id',
 'task_id',
 'platform',
 'device_kind',
 'client',
 '__str__',
 '__repr__',
 'transfer_to_infeed',
 'transfer_from_outfeed',
 'live_buffers',
 '__getattr__',
 '__new__',
 '__hash__',
 '__getattribute__',
 '__setattr__',
 '__delattr__',
 '__lt__',
 '__le__',
 '__eq__',
 '__ne__',
 '__gt__',
 '__ge__',
 '__reduce_ex__',
 '__reduce__',
 '__getstate__',
 '__subclasshook__',
 '__init_subclass__',
 '__format__',
 '__sizeof__',
 '__dir__',
 '__class__']

In [17]:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

gpu
