In [1]:
from ptyrodactyl import simul, tools
import jax
import jax.numpy as jnp



In [2]:
# Reload edited modules (e.g. ptyrodactyl) automatically before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Enable 64-bit dtypes in JAX. Without this, the dtype=jnp.float64 requests in the
# test-data cells would be truncated to float32 (with a warning).
# Must run before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

In [4]:
# Fixed seed so every run of the notebook generates identical random test data.
SEED = 0
key = jax.random.PRNGKey(SEED)

In [5]:
# Large synthetic problem for the CBED benchmark: 200 probe modes and 500
# potential slices on a 256x256 grid.
# NOTE(review): units of `calibration` follow the tools.make_* helpers —
# presumably angstroms per pixel; confirm.
calibration = 0.1
num_modes = 200

# Strictly positive, decreasing mode weights (N, N-1, ..., 1) normalised to sum to 1.
# A plain jnp.arange(num_modes) would give mode 0 a weight of exactly 0.
beam_weights_raw = jnp.flip(1 + jnp.arange(num_modes))
beam_weights = beam_weights_raw / jnp.sum(beam_weights_raw)

# Independent keys for the two draws: a JAX PRNG key must not be reused, or the
# generated samples can be correlated.
beam_key, slice_key = jax.random.split(key)
beam_test_m = jax.random.normal(beam_key, shape=(256, 256, num_modes), dtype=jnp.complex64)
slice_test_m = jax.random.normal(slice_key, shape=(256, 256, 500), dtype=jnp.float64)

In [6]:
# Wrap the raw arrays in ptyrodactyl containers (presumably the PotentialSlices /
# ProbeModes types that cbed/stem_4d take), both at the shared `calibration`.
# NOTE(review): the positional `1` is presumably the slice thickness — confirm
# against tools.make_potential_slices.
slices = tools.make_potential_slices(slice_test_m, 1, calibration)
beams = tools.make_probe_modes(beam_test_m, beam_weights, calibration)

In [7]:
# JIT-compiled CBED simulator; XLA compilation happens on the first call.
cbed_fn = simul.cbed
test_cbed = jax.jit(cbed_fn)

In [8]:
%timeit test_cbed(slices, beams, jnp.asarray(60.0))

124 μs ± 78.3 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
# Smaller synthetic problem for the 4D-STEM benchmark: 3 probe modes with weights
# 3:2:1 (normalised to sum to 1) and 200 potential slices on a 128x128 grid.
num_modes = 3
beam_weights_raw = jnp.flip(1 + jnp.arange(num_modes))
beam_weights = beam_weights_raw / jnp.sum(beam_weights_raw)

# Derive a fresh stream from `key` and split it for the two draws.
# Reusing `key` directly for every draw violates JAX's no-reuse rule for PRNG keys.
beam_key_small, slice_key_small = jax.random.split(jax.random.fold_in(key, 1))
beam_test_small = jax.random.normal(beam_key_small, shape=(128, 128, num_modes), dtype=jnp.complex64)
slice_test_small = jax.random.normal(slice_key_small, shape=(128, 128, 200), dtype=jnp.float64)

slices_small = tools.make_potential_slices(slice_test_small, 1, calibration)
beams_small = tools.make_probe_modes(beam_test_small, beam_weights, calibration)

In [10]:
# Square raster-scan grid of probe positions from -4 to 4 (exclusive) in steps of `calib_ang`.
# NOTE(review): `calib_ang` is also passed to stem_4d as its `calib_ang`
# ("The calibration in angstroms"), while the probes/slices were built with
# `calibration = 0.1` — confirm which value stem_4d expects.
calib_ang = 0.2
x = jnp.arange(-4, 4, calib_ang)
y = jnp.arange(-4, 4, calib_ang)
xx, yy = jnp.meshgrid(x, y)
# stem_4d documents `positions` as (y, x) pairs, so y goes in column 0.
# With meshgrid's default 'xy' indexing, ravel() makes x vary fastest (row-major scan).
positions = jnp.stack((yy.ravel(), xx.ravel()), axis=-1)

In [11]:
# JIT-compiled 4D-STEM simulator; XLA compilation happens on the first call.
stem4d_fn = simul.stem_4d
test_stem4d = jax.jit(stem4d_fn)

In [12]:
# Show the stem_4d signature and docstring.
# Note that `positions` is documented as (y, x) pairs.
help(simul.stem_4d)

Help on function stem_4d in module ptyrodactyl.simul.simulations:

stem_4d(pot_slice: ptyrodactyl.tools.electron_types.PotentialSlices, beam: ptyrodactyl.tools.electron_types.ProbeModes, positions: jaxtyping.Num[Array, '#P 2'], voltage_kv: Union[int, float, jaxtyping.Num[Array, '']], calib_ang: Union[float, jaxtyping.Float[Array, '']]) -> ptyrodactyl.tools.electron_types.STEM4D
    Simulate CBED patterns for multiple beam positions by shifting the beam and
    running CBED simulations.

    Parameters
    ----------
    pot_slice : PotentialSlices
        The potential slice(s).
    beam : ProbeModes
        The electron beam mode(s).
    positions : Num[Array, "#P 2"]
        The (y, x) positions to shift the beam to.
        With P being the number of positions.
    voltage_kv : ScalarNumeric
        The accelerating voltage in kilovolts.
    calib_ang : ScalarFloat
        The calibration in angstroms.

    Returns
    -------
    STEM4D
        Complete 4D-STEM dataset containing:


In [13]:
%timeit test_stem4d(slices_small, beams_small, positions, jnp.asarray(60), calib_ang)

131 μs ± 45 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
