## Comparing jaxtronomy and lenstronomy performance for lens model ray shooting

In [1]:
from jax import numpy as jnp
import numpy as np
import time

from jaxtronomy.LensModel.lens_model import LensModel
from lenstronomy.LensModel.lens_model import LensModel as LensModel_ref
from jaxtronomy.LensModel.profile_list_base import (
    _JAXXED_MODELS as JAXXED_DEFLECTOR_PROFILES,
)

# 60x60 grid
num_pix = 60
supersampling_factor = 3
supersampling = True

if supersampling:
    num_pix *= supersampling_factor

x_jax = jnp.tile(jnp.linspace(-5.0, 5.0, num_pix) + 100, num_pix)
y_jax = jnp.repeat(jnp.linspace(-5.0, 5.0, num_pix) + 100, num_pix)

x = np.tile(np.linspace(-5.0, 5.0, num_pix) + 100, num_pix)
y = np.repeat(np.linspace(-5.0, 5.0, num_pix) + 100, num_pix)

for deflector_profile in JAXXED_DEFLECTOR_PROFILES:
    lensModel = LensModel([deflector_profile])

    if deflector_profile == "EPL":
        lensModel_ref = LensModel_ref(["EPL_NUMBA"])
    else:
        lensModel_ref = LensModel_ref([deflector_profile])

    kwargs_lens = lensModel.lens_model.func_list[0].upper_limit_default

    # Compile code/warmup
    lensModel.ray_shooting(x_jax, y_jax, [kwargs_lens])
    lensModel_ref.ray_shooting(x, y, [kwargs_lens])

    # Now time runtime after compilation/warmup
    start_time = time.perf_counter()
    for _ in range(10000):
        lensModel.ray_shooting(x_jax, y_jax, [kwargs_lens])

    middle_time = time.perf_counter()

    for _ in range(10000):
        lensModel_ref.ray_shooting(x, y, [kwargs_lens])

    end_time = time.perf_counter()

    jax_execution_time = middle_time - start_time
    lenstronomy_execution_time = end_time - middle_time
    print(
        f"jaxtronomy execution time for {deflector_profile}: {jax_execution_time} seconds"
    )
    if deflector_profile == "EPL":
        deflector_profile = "EPL_NUMBA"
    print(
        f"lenstronomy execution time for {deflector_profile}: {lenstronomy_execution_time} seconds"
    )
    print(
        f"jaxtronomy is {'{0:.1f}'.format(lenstronomy_execution_time/jax_execution_time)}x faster\n"
    )

jaxtronomy execution time for CONVERGENCE: 0.20506839999870863 seconds
lenstronomy execution time for CONVERGENCE: 1.4116154000003007 seconds
jaxtronomy is 6.9x faster

jaxtronomy execution time for CSE: 0.8139273999986472 seconds
lenstronomy execution time for CSE: 8.218596400000024 seconds
jaxtronomy is 10.1x faster

jaxtronomy execution time for EPL: 37.05310129999998 seconds
lenstronomy execution time for EPL_NUMBA: 80.56953299999986 seconds
jaxtronomy is 2.2x faster

jaxtronomy execution time for EPL_Q_PHI: 37.68870070000048 seconds
lenstronomy execution time for EPL_Q_PHI: 16.48123719999967 seconds
jaxtronomy is 0.4x faster

jaxtronomy execution time for GAUSSIAN: 1.3453974000003655 seconds
lenstronomy execution time for GAUSSIAN: 4.445075600000564 seconds
jaxtronomy is 3.3x faster

jaxtronomy execution time for GAUSSIAN_POTENTIAL: 1.3102717000001576 seconds
lenstronomy execution time for GAUSSIAN_POTENTIAL: 4.051776199999949 seconds
jaxtronomy is 3.1x faster

jaxtronomy executio

## Comparing jaxtronomy and lenstronomy performance for light model surface brightness
Note: Restarting the kernel is recommended before running the below tests on a low-memory machine, otherwise you may introduce a memory bottleneck from caching all of the compiled functions from the lens model testing in the previous cell

In [1]:
import copy
from jax import numpy as jnp
import numpy as np
import time

from jaxtronomy.LightModel.light_model import LightModel
from lenstronomy.LightModel.light_model import LightModel as LightModel_ref
from jaxtronomy.LightModel.light_model_base import (
    _JAXXED_MODELS as JAXXED_SOURCE_PROFILES,
)

# 60x60 grid
num_pix = 60
supersampling_factor = 3
supersampling = True

if supersampling:
    num_pix *= supersampling_factor

x_jax = jnp.tile(jnp.linspace(-5.0, 5.0, num_pix) + 100, num_pix)
y_jax = jnp.repeat(jnp.linspace(-5.0, 5.0, num_pix) + 100, num_pix)

x = np.tile(np.linspace(-5.0, 5.0, num_pix) + 100, num_pix)
y = np.repeat(np.linspace(-5.0, 5.0, num_pix) + 100, num_pix)

for source_profile in JAXXED_SOURCE_PROFILES:
    lightModel = LightModel([source_profile])
    lightModel_ref = LightModel_ref([source_profile])
    kwargs_source = copy.deepcopy(lightModel.func_list[0].upper_limit_default)
    for key, val in kwargs_source.items():
        if source_profile in [
            "MULTI_GAUSSIAN",
            "MULTI_GAUSSIAN_ELLIPSE",
        ] and key in ["amp", "sigma"]:
            kwargs_source[key] = np.linspace(val / 10, val, 5)

    # Do this comparison after the refactor since it is not really usable at the moment
    if source_profile == "SHAPELETS":
        continue
        #kwargs_source["amp"] = np.linspace(20.0, 30.0, 66)
        #kwargs_source["n_max"] = 10

    # Compile code/warmup
    lightModel.surface_brightness(x_jax, y_jax, [kwargs_source])
    lightModel_ref.surface_brightness(x, y, [kwargs_source])

    # Now time runtime after compilation/warmup
    start_time = time.perf_counter()
    for _ in range(10000):
        lightModel.surface_brightness(x_jax, y_jax, [kwargs_source])

    middle_time = time.perf_counter()

    for _ in range(10000):
        lightModel_ref.surface_brightness(x, y, [kwargs_source])

    end_time = time.perf_counter()

    jax_execution_time = middle_time - start_time
    lenstronomy_execution_time = end_time - middle_time
    print(
        f"jaxtronomy execution time for {source_profile}: {jax_execution_time} seconds"
    )
    print(
        f"lenstronomy execution time for {source_profile}: {lenstronomy_execution_time} seconds"
    )
    print(
        f"jaxtronomy is {'{0:.1f}'.format(lenstronomy_execution_time/jax_execution_time)}x faster\n"
    )

jaxtronomy execution time for CORE_SERSIC: 1.2848345000002155 seconds
lenstronomy execution time for CORE_SERSIC: 18.967165199999727 seconds
jaxtronomy is 14.8x faster

jaxtronomy execution time for GAUSSIAN: 0.2517599999991944 seconds
lenstronomy execution time for GAUSSIAN: 2.4892575000012584 seconds
jaxtronomy is 9.9x faster

jaxtronomy execution time for GAUSSIAN_ELLIPSE: 0.46514789999855566 seconds
lenstronomy execution time for GAUSSIAN_ELLIPSE: 3.6453911000007793 seconds
jaxtronomy is 7.8x faster

jaxtronomy execution time for MULTI_GAUSSIAN: 0.6295931999993627 seconds
lenstronomy execution time for MULTI_GAUSSIAN: 11.456636000000799 seconds
jaxtronomy is 18.2x faster

jaxtronomy execution time for MULTI_GAUSSIAN_ELLIPSE: 0.682676000000356 seconds
lenstronomy execution time for MULTI_GAUSSIAN_ELLIPSE: 12.356475099999443 seconds
jaxtronomy is 18.1x faster

jaxtronomy execution time for SERSIC: 0.8067269000002852 seconds
lenstronomy execution time for SERSIC: 8.104093699999794 sec