## Comparing jaxtronomy and lenstronomy performance for lens model ray shooting

In [2]:
from jax import numpy as jnp, config
import numpy as np
import time

from jaxtronomy.LensModel.lens_model import LensModel
from lenstronomy.LensModel.lens_model import LensModel as LensModel_ref
from jaxtronomy.LensModel.profile_list_base import (
    _JAXXED_MODELS as JAXXED_DEFLECTOR_PROFILES,
)

config.update("jax_enable_x64", True)


# 60x60 grid
num_pix = 60
supersampling_factor = 3
supersampling = False

if supersampling:
    num_pix *= supersampling_factor

x_jax = jnp.tile(jnp.linspace(-5.0, 5.0, num_pix) + 100, num_pix)
y_jax = jnp.repeat(jnp.linspace(-5.0, 5.0, num_pix) + 100, num_pix)

x = np.tile(np.linspace(-5.0, 5.0, num_pix) + 100, num_pix)
y = np.repeat(np.linspace(-5.0, 5.0, num_pix) + 100, num_pix)

for deflector_profile in JAXXED_DEFLECTOR_PROFILES:
    lensModel_ref = LensModel_ref([deflector_profile])
    lensModel = LensModel([deflector_profile])

    # Get parameter names
    kwargs_lens = lensModel.lens_model.func_list[0].upper_limit_default

    # Compile code/warmup
    result = lensModel.ray_shooting(x_jax, y_jax, [kwargs_lens])
    result_ref = lensModel_ref.ray_shooting(x, y, [kwargs_lens])
    np.testing.assert_allclose(result, result_ref, rtol=1e-3, atol=1e-3)

    # Now time runtime after compilation/warmup
    start_time = time.perf_counter()
    for _ in range(10000):
        lensModel.ray_shooting(x_jax, y_jax, [kwargs_lens])

    middle_time = time.perf_counter()

    for _ in range(10000):
        lensModel_ref.ray_shooting(x, y, [kwargs_lens])

    end_time = time.perf_counter()

    jax_execution_time = middle_time - start_time
    lenstronomy_execution_time = end_time - middle_time
    print(
        f"jaxtronomy execution time for {deflector_profile}: {jax_execution_time} seconds"
    )
    print(
        f"lenstronomy execution time for {deflector_profile}: {lenstronomy_execution_time} seconds"
    )
    print(
        f"jaxtronomy is {'{0:.1f}'.format(lenstronomy_execution_time/jax_execution_time)}x faster\n"
    )

    # Additional performance comparison with EPL_NUMBA
    if deflector_profile == "EPL":
        lensModel_epl_numba = LensModel_ref(["EPL_NUMBA"])
        lensModel_epl_numba.ray_shooting(x, y, [kwargs_lens])
        start_time = time.perf_counter()
        for _ in range(10000):
            lensModel_epl_numba.ray_shooting(x, y, [kwargs_lens])
        end_time = time.perf_counter()
        numba_execution_time = end_time - start_time

        print(f"jaxtronomy execution time for EPL: {jax_execution_time} seconds")
        print(
            f"lenstronomy execution time for EPL_NUMBA: {numba_execution_time} seconds"
        )
        print(
            f"jaxtronomy is {'{0:.1f}'.format(numba_execution_time/jax_execution_time)}x faster\n"
        )

jaxtronomy execution time for CONVERGENCE: 0.12926510000033886 seconds
lenstronomy execution time for CONVERGENCE: 0.24261440000009316 seconds
jaxtronomy is 1.9x faster

jaxtronomy execution time for CSE: 0.1693755999986024 seconds
lenstronomy execution time for CSE: 0.8399465000002238 seconds
jaxtronomy is 5.0x faster

jaxtronomy execution time for EPL: 7.1486386000015045 seconds
lenstronomy execution time for EPL: 83.05072649999966 seconds
jaxtronomy is 11.6x faster

jaxtronomy execution time for EPL: 7.1486386000015045 seconds
lenstronomy execution time for EPL_NUMBA: 8.942944800000987 seconds
jaxtronomy is 1.3x faster

jaxtronomy execution time for EPL_Q_PHI: 6.708962799999426 seconds
lenstronomy execution time for EPL_Q_PHI: 2.124331400000301 seconds
jaxtronomy is 0.3x faster

jaxtronomy execution time for GAUSSIAN: 0.29490329999862297 seconds
lenstronomy execution time for GAUSSIAN: 0.6275099000013142 seconds
jaxtronomy is 2.1x faster

jaxtronomy execution time for GAUSSIAN_POTEN

## Comparing jaxtronomy and lenstronomy performance for light model surface brightness
Note: Restarting the kernel is recommended before running the below tests on a low-memory machine, otherwise you may introduce a memory bottleneck from caching all of the compiled functions from the lens model testing in the previous cell

In [1]:
import copy
from jax import numpy as jnp
import numpy as np
import time

from jaxtronomy.LightModel.light_model import LightModel
from lenstronomy.LightModel.light_model import LightModel as LightModel_ref
from jaxtronomy.LightModel.light_model_base import (
    _JAXXED_MODELS as JAXXED_SOURCE_PROFILES,
)

# 60x60 grid
num_pix = 60
supersampling_factor = 3
supersampling = False

if supersampling:
    num_pix *= supersampling_factor

x_jax = jnp.tile(jnp.linspace(-5.0, 5.0, num_pix) + 100, num_pix)
y_jax = jnp.repeat(jnp.linspace(-5.0, 5.0, num_pix) + 100, num_pix)

x = np.tile(np.linspace(-5.0, 5.0, num_pix) + 100, num_pix)
y = np.repeat(np.linspace(-5.0, 5.0, num_pix) + 100, num_pix)

for source_profile in JAXXED_SOURCE_PROFILES:
    lightModel = LightModel([source_profile])
    lightModel_ref = LightModel_ref([source_profile])
    kwargs_source = copy.deepcopy(lightModel.func_list[0].upper_limit_default)
    if source_profile in ["MULTI_GAUSSIAN", "MULTI_GAUSSIAN_ELLIPSE"]:
        kwargs_source["amp"] = np.linspace(10, 20, 5)
        kwargs_source["sigma"] = np.linspace(0.3, 1.0, 5)
    elif source_profile == "SHAPELETS":
        n_max = 10
        num_param = int((n_max + 1) * (n_max + 2) / 2)
        kwargs_source["n_max"] = n_max
        kwargs_source["amp"] = np.linspace(20.0, 30.0, num_param)

    # Compile code/warmup
    lightModel.surface_brightness(x_jax, y_jax, [kwargs_source])
    lightModel_ref.surface_brightness(x, y, [kwargs_source])

    # Now time runtime after compilation/warmup
    start_time = time.perf_counter()
    for _ in range(10000):
        lightModel.surface_brightness(x_jax, y_jax, [kwargs_source])

    middle_time = time.perf_counter()

    for _ in range(10000):
        lightModel_ref.surface_brightness(x, y, [kwargs_source])

    end_time = time.perf_counter()

    jax_execution_time = middle_time - start_time
    lenstronomy_execution_time = end_time - middle_time
    print(
        f"jaxtronomy execution time for {source_profile}: {jax_execution_time} seconds"
    )
    print(
        f"lenstronomy execution time for {source_profile}: {lenstronomy_execution_time} seconds"
    )
    print(
        f"jaxtronomy is {'{0:.1f}'.format(lenstronomy_execution_time/jax_execution_time)}x faster\n"
    )

jaxtronomy execution time for CORE_SERSIC: 0.45475860000078683 seconds
lenstronomy execution time for CORE_SERSIC: 2.4216302999993786 seconds
jaxtronomy is 5.3x faster

jaxtronomy execution time for GAUSSIAN: 0.1258683000014571 seconds
lenstronomy execution time for GAUSSIAN: 0.3730510999994294 seconds
jaxtronomy is 3.0x faster

jaxtronomy execution time for GAUSSIAN_ELLIPSE: 0.20272420000037528 seconds
lenstronomy execution time for GAUSSIAN_ELLIPSE: 0.5507977000015671 seconds
jaxtronomy is 2.7x faster

jaxtronomy execution time for MULTI_GAUSSIAN: 0.21117310000045109 seconds
lenstronomy execution time for MULTI_GAUSSIAN: 1.6327565999999933 seconds
jaxtronomy is 7.7x faster

jaxtronomy execution time for MULTI_GAUSSIAN_ELLIPSE: 0.25613220000013825 seconds
lenstronomy execution time for MULTI_GAUSSIAN_ELLIPSE: 1.936380300001474 seconds
jaxtronomy is 7.6x faster

jaxtronomy execution time for SERSIC: 0.3276308999993489 seconds
lenstronomy execution time for SERSIC: 1.2243885999996564 se

## Comparing jaxtronomy and lenstronomy performance for image convolution
We start by comparing Gaussian convolution. Due to the fact that there is no jax.scipy.ndimage.gaussian_filter function, in jaxtronomy we implement gaussian convolution by constructing a gaussian kernel and performing a pixel kernel fft convolution. This is done in such a way to match scipy.ndimage.guassian_filter with mode="nearest" as closely as possible. The jaxtronomy implementation is slower, but the benefit of this JAX implementation is that we can use autodifferentiation for model fitting.

In [1]:
from lenstronomy.ImSim.Numerics.convolution import MultiGaussianConvolution
from jaxtronomy.ImSim.Numerics.convolution import GaussianConvolution

from jax import numpy as jnp
import numpy as np
import time

num_pix = 60
sigma = 0.5
pixel_scale = 0.11
supersampling_factor = 3
supersampling_convolution = False

for truncation in range(1, 10):

    jax_conv = GaussianConvolution(
        sigma, pixel_scale, supersampling_factor, supersampling_convolution, truncation
    )
    lenstronomy_conv = MultiGaussianConvolution(
        [sigma],
        [1],
        pixel_scale,
        supersampling_factor,
        supersampling_convolution,
        truncation,
    )

    kernel_radius = round(sigma * truncation / pixel_scale)
    if supersampling_convolution:
        kernel_radius *= supersampling_factor

    image_jax = jnp.tile(jnp.linspace(1.0, 100.0, num_pix), num_pix).reshape(
        (num_pix, num_pix)
    )
    image = np.tile(np.linspace(1.0, 100.0, num_pix), num_pix).reshape(
        (num_pix, num_pix)
    )

    # Compile code/warmup
    result_jax = jax_conv.convolution2d(image_jax)
    result = lenstronomy_conv.convolution2d(image)
    np.testing.assert_allclose(result_jax, result, rtol=1e-5, atol=1e-5)

    # Now time runtime after compilation/warmup
    start_time = time.perf_counter()

    for _ in range(10000):
        jax_conv.convolution2d(image_jax)

    middle_time = time.perf_counter()

    for _ in range(10000):
        lenstronomy_conv.convolution2d(image)

    end_time = time.perf_counter()

    jax_execution_time = middle_time - start_time
    lenstronomy_execution_time = end_time - middle_time
    print(
        f"jaxtronomy execution time for gaussian convolution with kernel radius {kernel_radius}: {jax_execution_time} seconds"
    )
    print(
        f"lenstronomy execution time for gaussian convolution with kernel radius {kernel_radius}: {lenstronomy_execution_time} seconds"
    )
    print(
        f"lenstronomy is {'{0:.1f}'.format(jax_execution_time/lenstronomy_execution_time)}x faster\n"
    )

jaxtronomy execution time for gaussian convolution with kernel radius 5: 0.6235142000004998 seconds
lenstronomy execution time for gaussian convolution with kernel radius 5: 0.4927905999975337 seconds
lenstronomy is 1.3x faster

jaxtronomy execution time for gaussian convolution with kernel radius 9: 0.7403319999975793 seconds
lenstronomy execution time for gaussian convolution with kernel radius 9: 0.5890407000006235 seconds
lenstronomy is 1.3x faster

jaxtronomy execution time for gaussian convolution with kernel radius 14: 2.004389199999423 seconds
lenstronomy execution time for gaussian convolution with kernel radius 14: 0.7015000000028522 seconds
lenstronomy is 2.9x faster

jaxtronomy execution time for gaussian convolution with kernel radius 18: 1.489607800001977 seconds
lenstronomy execution time for gaussian convolution with kernel radius 18: 0.7453275999978359 seconds
lenstronomy is 2.0x faster

jaxtronomy execution time for gaussian convolution with kernel radius 23: 2.400533

Now we compare fft convolution with a given kernel. Generally, jax.scipy.signal.fftconvolve is about twice as fast as the normal scipy version, but there are some specific kernel sizes where the jax version is slower than usual for some unknown reason.

In [None]:
from lenstronomy.ImSim.Numerics.convolution import (
    PixelKernelConvolution as PixelKernelConvolution_ref,
)
from jaxtronomy.ImSim.Numerics.convolution import PixelKernelConvolution

from jax import numpy as jnp
import numpy as np
import time

num_pix = 60
kernel_size_list = np.linspace(3, 45, 22, dtype=int)

supersampling_factor = 3
supersampling_convolution = False

if supersampling_convolution:
    num_pix *= supersampling_factor
    kernel_size_list *= supersampling_factor

for kernel_size in kernel_size_list:
    kernel_jax = jnp.tile(jnp.linspace(0.1, 1.0, kernel_size), kernel_size).reshape(
        (kernel_size, kernel_size)
    )
    kernel_jax = kernel_jax / jnp.sum(kernel_jax)
    kernel = np.tile(np.linspace(0.1, 1.0, kernel_size), kernel_size).reshape(
        (kernel_size, kernel_size)
    )
    kernel = kernel / np.sum(kernel)

    jax_conv = PixelKernelConvolution(kernel=kernel_jax, convolution_type="fft")
    lenstronomy_conv = PixelKernelConvolution_ref(kernel=kernel, convolution_type="fft")

    image_jax = jnp.tile(jnp.linspace(1.0, 20.0, num_pix), num_pix).reshape(
        (num_pix, num_pix)
    )
    image = np.tile(np.linspace(1.0, 20.0, num_pix), num_pix).reshape(
        (num_pix, num_pix)
    )

    # Compile code/warmup
    result_jax = jax_conv.convolution2d(image_jax)
    result = lenstronomy_conv.convolution2d(image)
    np.testing.assert_allclose(result_jax, result, rtol=1e-5, atol=1e-5)

    # Now time runtime after compilation/warmup
    start_time = time.perf_counter()

    for _ in range(10000):
        jax_conv.convolution2d(image_jax)

    middle_time = time.perf_counter()

    for _ in range(10000):
        lenstronomy_conv.convolution2d(image)

    end_time = time.perf_counter()

    jax_execution_time = middle_time - start_time
    lenstronomy_execution_time = end_time - middle_time
    print(
        f"jaxtronomy execution time for fft convolution with kernel size {kernel_size}: {jax_execution_time} seconds"
    )
    print(
        f"lenstronomy execution time for fft convolution with kernel size {kernel_size}: {lenstronomy_execution_time} seconds"
    )
    print(
        f"jaxtronomy is {'{0:.1f}'.format(lenstronomy_execution_time/jax_execution_time)}x faster\n"
    )

jaxtronomy execution time for fft convolution with kernel size 3: 0.7521321000021999 seconds
lenstronomy execution time for fft convolution with kernel size 3: 0.932906499998353 seconds
jaxtronomy is 1.2x faster

jaxtronomy execution time for fft convolution with kernel size 5: 0.44844689999808907 seconds
lenstronomy execution time for fft convolution with kernel size 5: 0.9823906000019633 seconds
jaxtronomy is 2.2x faster

jaxtronomy execution time for fft convolution with kernel size 7: 0.5849750000015774 seconds
lenstronomy execution time for fft convolution with kernel size 7: 1.0667126000007556 seconds
jaxtronomy is 1.8x faster

jaxtronomy execution time for fft convolution with kernel size 9: 0.6611798000012641 seconds
lenstronomy execution time for fft convolution with kernel size 9: 0.9508036999977776 seconds
jaxtronomy is 1.4x faster

jaxtronomy execution time for fft convolution with kernel size 11: 0.6411372000002302 seconds
lenstronomy execution time for fft convolution wit