In [1]:
import jax
import jax.numpy as jnp
import time
import ptyrodactyl.electrons as pte
import ptyrodactyl.tools as ptt
import matplotlib.pyplot as plt

In [3]:
# Reload edited Python modules (e.g. ptyrodactyl) before each cell executes,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
# Enable 64-bit dtypes globally. Without this, the float64 arrays requested in
# later cells would be downcast to float32 (JAX emits a warning).
# Set this before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)

In [5]:
SEED = 0  # root seed; every random draw in this notebook derives from `key`
key = jax.random.PRNGKey(SEED)

In [7]:
# Random benchmark inputs for pte.cbed: probe modes plus potential slices.
calibration = 0.1
num_modes = 200
grid_size = 256  # size of the first two (presumably spatial) axes — TODO confirm
num_slices = 500  # length of the slice stack's last axis
# Draw the two arrays from independent subkeys. Reusing one PRNG key for
# both draws makes them statistically correlated.
beam_key, slice_key = jax.random.split(key)
beam_test_m = jax.random.normal(
    beam_key, shape=(grid_size, grid_size, num_modes), dtype=jnp.complex64
)
# Mode weights 1..num_modes, normalized to sum to 1. Starting at 1 (not 0)
# gives every mode a non-zero weight.
beam_weights = 1 + jnp.arange(num_modes)
beam_weights = beam_weights / jnp.sum(beam_weights)
slice_test_m = jax.random.normal(
    slice_key, shape=(grid_size, grid_size, num_slices), dtype=jnp.float64
)

In [11]:
# Build the potential slices and probe modes used by the cbed benchmark, both at
# the same `calibration`. The literal `1` is presumably the slice thickness —
# TODO confirm against pte.make_potential_slices.
slices, beams = (
    pte.make_potential_slices(slice_test_m, 1, calibration),
    pte.make_probe_modes(beam_test_m, beam_weights, calibration),
)

In [12]:
# JIT-compile pte.cbed. The first call traces and compiles it with XLA; later
# calls with the same argument shapes/dtypes reuse the compiled executable.
test_cbed = jax.jit(pte.cbed)

In [13]:
%timeit test_cbed(slices, beams, jnp.asarray(60.0))

6.81 s ± 1.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
# Smaller random inputs for the 4D-STEM benchmark.
# NOTE(review): this rebinds the globals `num_modes` and `beam_weights` set by
# the CBED setup cell. Re-running that cell's make_probe_modes call after this
# one would mix 200 modes with 3 weights.
num_modes = 3
grid_size_small = 128  # size of the first two (presumably spatial) axes — TODO confirm
num_slices_small = 200
# Independent subkeys: fold_in derives a stream distinct from the CBED inputs',
# and split separates the beam draw from the slice draw.
beam_key_small, slice_key_small = jax.random.split(jax.random.fold_in(key, 1))
# Descending weights num_modes..1, normalized to sum to 1.
beam_weights = jnp.flip(1 + jnp.arange(num_modes))
beam_weights = beam_weights / jnp.sum(beam_weights)
beam_test_small = jax.random.normal(
    beam_key_small,
    shape=(grid_size_small, grid_size_small, num_modes),
    dtype=jnp.complex64,
)
slice_test_small = jax.random.normal(
    slice_key_small,
    shape=(grid_size_small, grid_size_small, num_slices_small),
    dtype=jnp.float64,
)
slices_small = pte.make_potential_slices(slice_test_small, 1, calibration)
beams_small = pte.make_probe_modes(beam_test_small, beam_weights, calibration)

In [15]:
# Square raster of probe positions with spacing `calib_ang` over [-5, 5)
# on both axes. Units presumably match calib_ang — TODO confirm.
calib_ang = 0.2
scan_start, scan_stop = -5, 5
x = jnp.arange(scan_start, scan_stop, calib_ang)
y = jnp.arange(scan_start, scan_stop, calib_ang)
xx, yy = jnp.meshgrid(x, y)
# One (x, y) pair per row, in row-major meshgrid order -> shape (len(x) * len(y), 2).
positions = jnp.stack([xx.ravel(), yy.ravel()], axis=-1)

In [16]:
# JIT-compile pte.stem_4D. Python scalars passed to it (e.g. calib_ang) are
# traced as abstract values. If stem_4D needs one as a compile-time constant
# (e.g. to fix array shapes), mark it via static_argnums/static_argnames here.
test_stem4d = jax.jit(pte.stem_4D)

In [17]:
%timeit test_stem4d(slices_small, beams_small, positions, jnp.asarray(60), calib_ang)

23.8 s ± 4.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
