## Scratch notebook for Thomson-scattering (TS) form-factor calculations

### Initialize

In [None]:
import time

import jax
import matplotlib.pyplot as plt
import numpy as np

from inverse_thomson_scattering.jax import form_factor as jnp_ff
from inverse_thomson_scattering.utils import plotting
from inverse_thomson_scattering.v0 import form_factor as np_ff

In [None]:
# Abscissa on which the distribution function `distf` is sampled.
# np.arange already returns an ndarray, so no extra np.array(...) copy is needed.
# With a float step this gives 160 points spanning [-8, 8) (8 itself excluded).
x = np.arange(-8, 8, 0.1)
# Unit-variance Gaussian, normalized so that sum(distf) * 0.1 ~= 1 on this grid.
distf = np.exp(-(x**2) / 2) / np.sqrt(2 * np.pi)
# 10 evenly spaced values from 55 to 65, passed to both form-factor
# implementations as `sa` (presumably scattering angles — confirm against the library).
sa = np.linspace(55, 65, 10)

In [None]:
# Benchmark the NumPy/SciPy reference implementation of the form factor.
# The positional parameters are literals whose meaning is defined by
# np_ff.nonMaxwThomson's signature; [400, 700] and 526.5 match the values
# given to jnp_ff.get_form_factor_fn in the JAX benchmark.
# perf_counter() is a monotonic, high-resolution clock meant for timing
# intervals; time.time() is wall-clock time and can jump.
t0 = time.perf_counter()
formf, lams = np_ff.nonMaxwThomson(1.0, 1.0, 1.0, 1.0, 1.0, 0.3e20, 0.0, 0.0, [400, 700], 526.5, sa, distf, x)
t1 = time.perf_counter()
print(f"numpy/scipy form factor calculation {np.round(t1 - t0, 4)} s")

In [None]:
# Benchmark the JAX implementation of the form factor.
# Build the forward function and its value-and-gradient counterpart.
ff_fn, vg_ff_fn = jnp_ff.get_form_factor_fn([400, 700], 526.5)

# Every call below uses the same arguments, so define them once.
ff_args = (1.0, 1.0, 1.0, 1.0, 1.0, 0.3e20, 0.0, 0.0, sa, (distf, x))

# Warm-up calls so compilation is excluded from the timing. JAX dispatches work
# asynchronously, so wait for the results; otherwise leftover warm-up work could
# be absorbed into the timed call below.
_ = jax.block_until_ready(ff_fn(*ff_args))
_ = jax.block_until_ready(vg_ff_fn(*ff_args))

# Timed run. Without block_until_ready the interval would only measure the
# (asynchronous) dispatch, not the computation itself.
t0 = time.perf_counter()
formf_jax, lams_jax = jax.block_until_ready(ff_fn(*ff_args))
t1 = time.perf_counter()
print(f"jax form factor calculation took {np.round(t1 - t0, 4)} s")

In [None]:
# Benchmark the combined value-and-gradient call. This relies on vg_ff_fn having
# already been called once in the previous cell, so compilation is not timed.
t0 = time.perf_counter()
# JAX dispatch is asynchronous: block on the whole (value, gradient) result so the
# interval covers the computation, not just the dispatch.
val, grad = jax.block_until_ready(vg_ff_fn(1.0, 1.0, 1.0, 1.0, 1.0, 0.3e20, 0.0, 0.0, sa, (distf, x)))
t1 = time.perf_counter()
print(f"value and gradient took {np.round(t1 - t0, 4)} s")
print(f"gradient was {grad}")

In [None]:
# Plot the NumPy (formf) and JAX (formf_jax) form factors for comparison.
# NOTE(review): `plotting` is bound by `from inverse_thomson_scattering.utils import plotting`;
# if that name is a submodule rather than a function, this call raises TypeError — confirm.
# NOTE(review): the grids `lams` / `lams_jax` returned alongside the form factors are
# not passed here; confirm `plotting` does not need them.
plotting(formf, formf_jax)