In [None]:
import jax
import jax.numpy as jnp
import ptyrodactyl.electrons as pte
from jaxtyping import Array, Float, Shaped, Int, Complex

In [None]:
# Enable 64-bit floats/complex numbers in JAX. With this on, float64 inputs
# promote results to complex128; mixing those with complex64 arrays is what
# produces the scan carry dtype mismatch recorded for the multi-slice benchmark.
# NOTE(review): JAX asks for this flag to be set at startup; since it runs after
# `import ptyrodactyl.electrons`, any arrays that package creates at import time
# may already be 32-bit — consider moving this right after `import jax`.
jax.config.update("jax_enable_x64", True)

In [None]:
# Test wavefields and slices. Use the widest complex dtype the current JAX
# config allows: complex128 when jax_enable_x64 is on, complex64 otherwise.
# With x64 enabled and float64 parameters the multi-slice routine produces
# complex128, so complex64 inputs failed its lax.scan carry check
# (complex64 carry in, complex128 carry out). Note this also means the
# single-slice benchmarks run in complex128 when x64 is on.
COMPLEX_DTYPE = jax.dtypes.canonicalize_dtype(jnp.complex128)

beam_test_s = jnp.ones((256, 256), dtype=COMPLEX_DTYPE)
slice_test_s = jnp.ones((256, 256), dtype=COMPLEX_DTYPE)

beam_test_m = jnp.ones((256, 256, 32), dtype=COMPLEX_DTYPE)
slice_test_m = jnp.ones((256, 256, 50), dtype=COMPLEX_DTYPE)

2024-08-12 14:49:08.144746: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.5 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [None]:
# jit-wrapped handles for the four CBED entry points; the suffix encodes
# (slice, beam) multiplicity: s = single, m = multi.
test_func_ss, test_func_sm, test_func_ms, test_func_mm = (
    jax.jit(cbed_fn)
    for cbed_fn in (
        pte.cbed_single_slice_single_beam,
        pte.cbed_single_slice_multi_beam,
        pte.cbed_multi_slice_single_beam,
        pte.cbed_multi_slice_multi_beam,
    )
)

In [None]:
%timeit test_func_ss(slice_test_s, beam_test_s)

44.5 μs ± 800 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
%timeit test_func_sm(slice_test_s, beam_test_m)

332 μs ± 81.5 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
%timeit test_func_ms(slice_test_m, beam_test_s, jnp.float64(1.003), jnp.float64(60.001), jnp.float64(0.01))

TypeError: Scanned function carry input and carry output must have equal types (e.g. shapes and dtypes of arrays), but they differ:
  * the input carry component carry[0] has type complex64[256,256] but the corresponding output carry component has type complex128[256,256], so the dtypes do not match

Revise the scanned function so that all output types (e.g. shapes and dtypes) match the corresponding input types.

In [None]:
# Show the library propagator's signature/docstring; presumably for comparison
# with the local redefinition in the following cell.
# NOTE(review): interactive inspection — consider removing before sharing.
help(pte.propagation_func)

In [None]:
def propagation_func(
    imsize_y: int,
    imsize_x: int,
    thickness_ang: Float[Array, "*"],
    voltage_kV: Float[Array, "*"],
    calib_ang: Float[Array, "*"],
    *,
    dtype=None,
) -> Complex[Array, "H W"]:
    """
    Calculates the complex propagation function that results
    in the phase shift of the exit wave when it travels from
    one slice to the next in the multislice algorithm.

    The propagator is exp(-i * pi * lambda * dz * |q|^2), evaluated on the
    unshifted FFT frequency grid produced by `jnp.fft.fftfreq` (zero
    frequency at index [0, 0]).

    Args:
    - `imsize_y`, (int):
        Number of pixels along y (rows, H). Must be a Python int because
        it sets the output shape; mark it static when jitting.
    - `imsize_x`, (int):
        Number of pixels along x (columns, W). Must be a Python int;
        mark it static when jitting.
    - `thickness_ang`, (Float[Array, "*"])
        Distance between the slices in angstroms
    - `voltage_kV`, (Float[Array, "*"])
        Accelerating voltage in kilovolts
    - `calib_ang`, (Float[Array, "*"])
        Calibration or pixel size in angstroms
    - `dtype`, (optional, keyword-only)
        Complex dtype to cast the result to. The default (None) keeps the
        dtype from type promotion, which can be complex128 when
        jax_enable_x64 is on; pass e.g. `jnp.complex64` to match complex64
        wavefields. When jitting, declare it static
        (`static_argnames=("dtype",)`).

    Returns:
    - `prop` Complex[Array, "H W"]:
        The propagation function with shape (imsize_y, imsize_x)
    """
    # Spatial frequencies (cycles per angstrom) along each axis.
    qy: Float[Array, "H"] = jnp.fft.fftfreq(imsize_y, d=calib_ang)
    qx: Float[Array, "W"] = jnp.fft.fftfreq(imsize_x, d=calib_ang)

    # |q|^2 on the (H, W) grid via broadcasting, avoiding two full meshgrids.
    q_sq: Float[Array, "H W"] = jnp.square(qx)[None, :] + jnp.square(qy)[:, None]

    # Electron wavelength in angstroms; an array (traced under jit), not a float.
    lambda_angstrom: Float[Array, "*"] = pte.wavelength_ang(voltage_kV)

    prop: Complex[Array, "H W"] = jnp.exp(
        (-1j) * jnp.pi * lambda_angstrom * thickness_ang * q_sq
    )
    if dtype is not None:
        # Cast after computing the phase so the exponent keeps full precision.
        prop = prop.astype(dtype)
    return prop

In [None]:
# Compiled propagator. The image sizes fix the output shape, so they must be
# static; naming them lets jit resolve them to positional slots 0 and 1.
pf = jax.jit(propagation_func, static_argnames=("imsize_y", "imsize_x"))

In [None]:
pf(
    256,
    256,
    thickness_ang=jnp.float64(1),
    voltage_kV=jnp.float64(60),
    calib_ang=jnp.float64(0.01),
)