## Comparing jaxtronomy and lenstronomy performance for lens model ray shooting

In [3]:
from jax import numpy as jnp
import numpy as np
import time

from jaxtronomy.LensModel.lens_model import LensModel
from lenstronomy.LensModel.lens_model import LensModel as LensModel_ref
from jaxtronomy.LensModel.profile_list_base import (
    _JAXXED_MODELS as JAXXED_DEFLECTOR_PROFILES,
)

# 60x60 grid
num_pix = 60
supersampling_factor = 3
supersampling = True

if supersampling:
    num_pix += supersampling_factor

x_jax = jnp.tile(jnp.linspace(-5.0, 5.0, num_pix) + 100, num_pix)
y_jax = jnp.repeat(jnp.linspace(-5.0, 5.0, num_pix) + 100, num_pix)

x = np.tile(np.linspace(-5.0, 5.0, num_pix) + 100, num_pix)
y = np.repeat(np.linspace(-5.0, 5.0, num_pix) + 100, num_pix)

for deflector_profile in JAXXED_DEFLECTOR_PROFILES:
    lensModel = LensModel([deflector_profile])
    if deflector_profile == "EPL":
        deflector_profile = "EPL_NUMBA"
    lensModel_ref = LensModel_ref([deflector_profile])
    kwargs_lens = lensModel.lens_model.func_list[0].upper_limit_default

    # Compile code/warmup
    lensModel.ray_shooting(x_jax, y_jax, [kwargs_lens])
    lensModel_ref.ray_shooting(x, y, [kwargs_lens])

    # Now time runtime after compilation/warmup
    start_time = time.perf_counter()
    for _ in range(10000):
        lensModel.ray_shooting(x_jax, y_jax, [kwargs_lens])

    middle_time = time.perf_counter()

    for _ in range(10000):
        lensModel_ref.ray_shooting(x, y, [kwargs_lens])

    end_time = time.perf_counter()

    jax_execution_time = middle_time - start_time
    lenstronomy_execution_time = end_time - middle_time
    print(f"jaxtronomy execution time for {deflector_profile}: {jax_execution_time} seconds")
    print(f"lenstronomy execution time for {deflector_profile}: {lenstronomy_execution_time} seconds")
    print(f"jaxtronomy takes {'{0:.1f}'.format(jax_execution_time/lenstronomy_execution_time * 100)}% as long\n")

jaxtronomy execution time for CONVERGENCE: 0.12459120000130497 seconds
lenstronomy execution time for CONVERGENCE: 0.21981309999682708 seconds
jaxtronomy takes 56.7% as long

jaxtronomy execution time for CSE: 0.14737099999911152 seconds
lenstronomy execution time for CSE: 0.8491871000005631 seconds
jaxtronomy takes 17.4% as long

jaxtronomy execution time for EPL_NUMBA: 7.412143699999433 seconds
lenstronomy execution time for EPL_NUMBA: 10.091441799995664 seconds
jaxtronomy takes 73.4% as long



  R_omega = Z * hyp2f1(1, t / 2, 2 - t / 2, -(1 - q) / (1 + q) * (Z / Z.conj()))


jaxtronomy execution time for EPL_Q_PHI: 7.982565800004522 seconds
lenstronomy execution time for EPL_Q_PHI: 2.33471819999977 seconds
jaxtronomy takes 341.9% as long

jaxtronomy execution time for GAUSSIAN: 0.32333610000205226 seconds
lenstronomy execution time for GAUSSIAN: 0.6915124999941327 seconds
jaxtronomy takes 46.8% as long

jaxtronomy execution time for GAUSSIAN_POTENTIAL: 0.32431639999413164 seconds
lenstronomy execution time for GAUSSIAN_POTENTIAL: 0.6254164000056335 seconds
jaxtronomy takes 51.9% as long

jaxtronomy execution time for HERNQUIST: 0.8544383000044036 seconds
lenstronomy execution time for HERNQUIST: 1.2600654999987455 seconds
jaxtronomy takes 67.8% as long

jaxtronomy execution time for HERNQUIST_ELLIPSE_CSE: 4.508605499999248 seconds
lenstronomy execution time for HERNQUIST_ELLIPSE_CSE: 21.022818199999165 seconds
jaxtronomy takes 21.4% as long

jaxtronomy execution time for LOS: 0.20142560000385856 seconds
lenstronomy execution time for LOS: 0.560108300000138

  alpha = theta_E * (r2 / theta_E**2) ** (1 - gamma / 2.0)
  f_x = fac * xt1
  f_y = fac * xt2


jaxtronomy execution time for SPP: 0.6270390999998199 seconds
lenstronomy execution time for SPP: 0.8886139999958687 seconds
jaxtronomy takes 70.6% as long



## Comparing jaxtronomy and lenstronomy performance for light model surface brightness

In [7]:
import copy
from jax import numpy as jnp
import numpy as np
import time

from jaxtronomy.LightModel.light_model import LightModel
from lenstronomy.LightModel.light_model import LightModel as LightModel_ref
from jaxtronomy.LightModel.light_model_base import (
    _JAXXED_MODELS as JAXXED_SOURCE_PROFILES,
)

# 60x60 grid
num_pix = 60
supersampling_factor = 3
supersampling = True

if supersampling:
    num_pix += supersampling_factor

x_jax = jnp.tile(jnp.linspace(-5.0, 5.0, num_pix) + 100, num_pix)
y_jax = jnp.repeat(jnp.linspace(-5.0, 5.0, num_pix) + 100, num_pix)

x = np.tile(np.linspace(-5.0, 5.0, num_pix) + 100, num_pix)
y = np.repeat(np.linspace(-5.0, 5.0, num_pix) + 100, num_pix)

for source_profile in JAXXED_SOURCE_PROFILES:
    lightModel = LightModel([source_profile])
    lightModel_ref = LightModel_ref([source_profile])
    kwargs_source = copy.deepcopy(lightModel.func_list[0].upper_limit_default)
    for key, val in kwargs_source.items():
        if source_profile in [
            "MULTI_GAUSSIAN",
            "MULTI_GAUSSIAN_ELLIPSE",
        ] and key in ["amp", "sigma"]:
            kwargs_source[key] = np.linspace(val/10, val, 5)
    if source_profile == "SHAPELETS":
        kwargs_source['amp'] = np.linspace(20., 30., 66)
        kwargs_source['n_max'] = 10

    # Compile code/warmup
    lightModel.surface_brightness(x_jax, y_jax, [kwargs_source])
    lightModel_ref.surface_brightness(x, y, [kwargs_source])

    # Now time runtime after compilation/warmup
    start_time = time.perf_counter()
    for _ in range(10000):
        lightModel.surface_brightness(x_jax, y_jax, [kwargs_source])

    middle_time = time.perf_counter()

    for _ in range(10000):
        lightModel_ref.surface_brightness(x, y, [kwargs_source])

    end_time = time.perf_counter()

    jax_execution_time = middle_time - start_time
    lenstronomy_execution_time = end_time - middle_time
    print(f"jaxtronomy execution time for {source_profile}: {jax_execution_time} seconds")
    print(f"lenstronomy execution time for {source_profile}: {lenstronomy_execution_time} seconds")
    print(f"jaxtronomy takes {'{0:.1f}'.format(jax_execution_time/lenstronomy_execution_time * 100)}% as long\n")

jaxtronomy execution time for CORE_SERSIC: 0.6009732000020449 seconds
lenstronomy execution time for CORE_SERSIC: 3.027182999998331 seconds
jaxtronomy takes 19.9% as long

jaxtronomy execution time for GAUSSIAN: 0.12626150000141934 seconds
lenstronomy execution time for GAUSSIAN: 0.400462300000072 seconds
jaxtronomy takes 31.5% as long

jaxtronomy execution time for GAUSSIAN_ELLIPSE: 0.23892630000045756 seconds
lenstronomy execution time for GAUSSIAN_ELLIPSE: 0.586217600000964 seconds
jaxtronomy takes 40.8% as long

jaxtronomy execution time for MULTI_GAUSSIAN: 0.2912997999956133 seconds
lenstronomy execution time for MULTI_GAUSSIAN: 1.7678541999994195 seconds
jaxtronomy takes 16.5% as long

jaxtronomy execution time for MULTI_GAUSSIAN_ELLIPSE: 0.34107599999697413 seconds
lenstronomy execution time for MULTI_GAUSSIAN_ELLIPSE: 1.9826394999981858 seconds
jaxtronomy takes 17.2% as long

jaxtronomy execution time for SERSIC: 0.3457901000001584 seconds
lenstronomy execution time for SERSIC:



jaxtronomy execution time for SHAPELETS: 0.693164500000421 seconds
lenstronomy execution time for SHAPELETS: 14.717462499997055 seconds
jaxtronomy takes 4.7% as long

