## Comparing jaxtronomy and lenstronomy performance for lens model ray shooting

In [None]:
from jax import numpy as jnp
import numpy as np
import time

from jaxtronomy.LensModel.lens_model import LensModel
from lenstronomy.LensModel.lens_model import LensModel as LensModel_ref
from jaxtronomy.LensModel.profile_list_base import (
    _JAXXED_MODELS as JAXXED_DEFLECTOR_PROFILES,
)

# 60x60 grid
num_pix = 60
supersampling_factor = 3
supersampling = True

if supersampling:
    num_pix *= supersampling_factor

x_jax = jnp.tile(jnp.linspace(-5.0, 5.0, num_pix) + 100, num_pix)
y_jax = jnp.repeat(jnp.linspace(-5.0, 5.0, num_pix) + 100, num_pix)

x = np.tile(np.linspace(-5.0, 5.0, num_pix) + 100, num_pix)
y = np.repeat(np.linspace(-5.0, 5.0, num_pix) + 100, num_pix)

for deflector_profile in JAXXED_DEFLECTOR_PROFILES:
    lensModel = LensModel([deflector_profile])
    if deflector_profile == "EPL":
        deflector_profile = "EPL_NUMBA"
    lensModel_ref = LensModel_ref([deflector_profile])
    kwargs_lens = lensModel.lens_model.func_list[0].upper_limit_default

    # Compile code/warmup
    lensModel.ray_shooting(x_jax, y_jax, [kwargs_lens])
    lensModel_ref.ray_shooting(x, y, [kwargs_lens])

    # Now time runtime after compilation/warmup
    start_time = time.perf_counter()
    for _ in range(10000):
        lensModel.ray_shooting(x_jax, y_jax, [kwargs_lens])

    middle_time = time.perf_counter()

    for _ in range(10000):
        lensModel_ref.ray_shooting(x, y, [kwargs_lens])

    end_time = time.perf_counter()

    jax_execution_time = middle_time - start_time
    lenstronomy_execution_time = end_time - middle_time
    print(
        f"jaxtronomy execution time for {deflector_profile}: {jax_execution_time} seconds"
    )
    print(
        f"lenstronomy execution time for {deflector_profile}: {lenstronomy_execution_time} seconds"
    )
    print(
        f"jaxtronomy takes {'{0:.1f}'.format(jax_execution_time/lenstronomy_execution_time * 100)}% as long\n"
    )

jaxtronomy execution time for CONVERGENCE: 0.21539789999951608 seconds
lenstronomy execution time for CONVERGENCE: 1.4262556999965454 seconds
jaxtronomy takes 15.1% as long

jaxtronomy execution time for CSE: 0.8345672999930684 seconds
lenstronomy execution time for CSE: 4.805201100003615 seconds
jaxtronomy takes 17.4% as long

jaxtronomy execution time for EPL_NUMBA: 49.048782100006065 seconds
lenstronomy execution time for EPL_NUMBA: 81.85929719999694 seconds
jaxtronomy takes 59.9% as long

jaxtronomy execution time for EPL_Q_PHI: 44.40125360000093 seconds
lenstronomy execution time for EPL_Q_PHI: 17.591442200005986 seconds
jaxtronomy takes 252.4% as long

jaxtronomy execution time for GAUSSIAN: 1.3684808000034536 seconds
lenstronomy execution time for GAUSSIAN: 5.109279099997366 seconds
jaxtronomy takes 26.8% as long

jaxtronomy execution time for GAUSSIAN_POTENTIAL: 1.3590571000022464 seconds
lenstronomy execution time for GAUSSIAN_POTENTIAL: 4.187931799999205 seconds
jaxtronomy ta

## Comparing jaxtronomy and lenstronomy performance for light model surface brightness

In [1]:
import copy
from jax import numpy as jnp
import numpy as np
import time

from jaxtronomy.LightModel.light_model import LightModel
from lenstronomy.LightModel.light_model import LightModel as LightModel_ref
from jaxtronomy.LightModel.light_model_base import (
    _JAXXED_MODELS as JAXXED_SOURCE_PROFILES,
)

# 60x60 grid
num_pix = 60
supersampling_factor = 3
supersampling = False

if supersampling:
    num_pix *= supersampling_factor

x_jax = jnp.tile(jnp.linspace(-5.0, 5.0, num_pix) + 100, num_pix)
y_jax = jnp.repeat(jnp.linspace(-5.0, 5.0, num_pix) + 100, num_pix)

x = np.tile(np.linspace(-5.0, 5.0, num_pix) + 100, num_pix)
y = np.repeat(np.linspace(-5.0, 5.0, num_pix) + 100, num_pix)

for source_profile in JAXXED_SOURCE_PROFILES:
    lightModel = LightModel([source_profile])
    lightModel_ref = LightModel_ref([source_profile])
    kwargs_source = copy.deepcopy(lightModel.func_list[0].upper_limit_default)
    for key, val in kwargs_source.items():
        if source_profile in [
            "MULTI_GAUSSIAN",
            "MULTI_GAUSSIAN_ELLIPSE",
        ] and key in ["amp", "sigma"]:
            kwargs_source[key] = np.linspace(val / 10, val, 5)
    if source_profile == "SHAPELETS":
        kwargs_source["amp"] = np.linspace(20.0, 30.0, 66)
        kwargs_source["n_max"] = 10

    # Compile code/warmup
    lightModel.surface_brightness(x_jax, y_jax, [kwargs_source])
    lightModel_ref.surface_brightness(x, y, [kwargs_source])

    # Now time runtime after compilation/warmup
    start_time = time.perf_counter()
    for _ in range(10000):
        lightModel.surface_brightness(x_jax, y_jax, [kwargs_source])

    middle_time = time.perf_counter()

    for _ in range(10000):
        lightModel_ref.surface_brightness(x, y, [kwargs_source])

    end_time = time.perf_counter()

    jax_execution_time = middle_time - start_time
    lenstronomy_execution_time = end_time - middle_time
    print(
        f"jaxtronomy execution time for {source_profile}: {jax_execution_time} seconds"
    )
    print(
        f"lenstronomy execution time for {source_profile}: {lenstronomy_execution_time} seconds"
    )
    print(
        f"jaxtronomy takes {'{0:.1f}'.format(jax_execution_time/lenstronomy_execution_time * 100)}% as long\n"
    )

jaxtronomy execution time for CORE_SERSIC: 0.4202519999962533 seconds
lenstronomy execution time for CORE_SERSIC: 2.389370400000189 seconds
jaxtronomy takes 17.6% as long

jaxtronomy execution time for GAUSSIAN: 0.12592279999807943 seconds
lenstronomy execution time for GAUSSIAN: 0.3873317000034149 seconds
jaxtronomy takes 32.5% as long

jaxtronomy execution time for GAUSSIAN_ELLIPSE: 0.18123300000297604 seconds
lenstronomy execution time for GAUSSIAN_ELLIPSE: 0.5806130999990273 seconds
jaxtronomy takes 31.2% as long

jaxtronomy execution time for MULTI_GAUSSIAN: 0.20797440000023926 seconds
lenstronomy execution time for MULTI_GAUSSIAN: 1.8185205000045244 seconds
jaxtronomy takes 11.4% as long

jaxtronomy execution time for MULTI_GAUSSIAN_ELLIPSE: 0.26332089999777963 seconds
lenstronomy execution time for MULTI_GAUSSIAN_ELLIPSE: 1.9643479000005755 seconds
jaxtronomy takes 13.4% as long

jaxtronomy execution time for SERSIC: 0.3364309999960824 seconds
lenstronomy execution time for SERS



jaxtronomy execution time for SHAPELETS: 1.4494135999993887 seconds
lenstronomy execution time for SHAPELETS: 14.34100330000365 seconds
jaxtronomy takes 10.1% as long

