# Ratio of modified Bessel functions of the first kind

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from scipy.special import ive


from spin_model_transformers.bessel import bessel_iv_ratio


def asymptotic_ratio(nu, x, n=1):
    """Leading-order large-order approximation of ``I_{nu+1} / I_nu``.

    For large order ``nu``, the ratio ``I_{nu+1}(nu * x) / I_nu(nu * x)`` of
    modified Bessel functions of the first kind tends to
    ``x / (1 + sqrt(1 + x**2))``. This function evaluates that closed form.

    Parameters
    ----------
    nu : float
        Bessel order. Accepted but not used by this leading-order formula.
    x : float or array_like
        *Scaled* argument ``x_full / nu``, not the raw Bessel argument.
        For example, ``x_full = nu`` corresponds to ``x = 1``.
    n : int, optional
        Power to which the approximation is raised (default 1).

    Returns
    -------
    float or array
        ``(x / (1 + sqrt(1 + x**2))) ** n``, computed elementwise for arrays.
    """
    # Plain arithmetic operators keep this usable with Python floats, NumPy
    # arrays and JAX arrays alike.
    return (x / (1 + (1 + x**2) ** 0.5)) ** n


# Vectorise bessel_iv_ratio over a batch of first arguments (the x values it
# is evaluated at), passing the second and third arguments through unbatched.
_batched_bessel_iv_ratio = jax.vmap(bessel_iv_ratio, in_axes=(0, None, None))
# Those two unbatched arguments are declared static to jax.jit: they must be
# hashable, and each distinct pair of values is compiled separately.
jit_bessel_iv_ratio = jax.jit(_batched_bessel_iv_ratio, static_argnums=(1, 2))

# Orders nu = 2**0, 2**1, ..., 2**9 (i.e. 1 to 512).
steps = np.logspace(0, 9, num=10, endpoint=True, base=2.0)

r_scipy, r_jax, r_asym = [], [], []
for nu in steps:
    # Evaluate every order at the same three arguments:
    # x = sqrt(nu), x = nu and x = nu**2.
    xs = np.array([nu**0.5, nu, nu**2])

    # SciPy reference using the exponentially scaled `ive`. The common
    # exp(-|x|) factor cancels in the ratio and avoids overflow at x = nu**2.
    # For x much smaller than nu (x = sqrt(nu) at the largest orders) both
    # values can underflow to 0. The ratio is then NaN and matplotlib simply
    # omits that point.
    r_scipy.append(ive(nu + 1, xs) / ive(nu, xs))

    # NOTE(review): the trailing 2 is forwarded as bessel_iv_ratio's third
    # (static) argument. Its meaning is defined in
    # spin_model_transformers.bessel; confirm it matches the I_{nu+1}/I_nu
    # ratio computed by SciPy above.
    r_jax.append(jit_bessel_iv_ratio(jnp.array(xs), nu, 2))

    # The asymptotic form takes the scaled argument x / nu. Deriving it from
    # `xs` keeps it tied to the same evaluation points.
    r_asym.append(asymptotic_ratio(nu, xs / nu))

# One entry per curve family: (ratios per order, matplotlib plot kwargs).
# Each stacked array has one column per argument (x = sqrt(nu), nu, nu**2).
# Only the first column gets a visible legend label, because matplotlib
# leaves "_"-prefixed labels out of the legend.
_curve_styles = [
    (
        r_scipy,
        {
            "label": ["SciPy", "_", "_"],
            "color": "tab:red",
            "marker": "o",
            "markerfacecolor": "none",
            "linewidth": 1.5,
            "linestyle": "none",
        },
    ),
    (
        r_jax,
        {
            "label": ["JAX", "_", "_"],
            "color": "tab:blue",
            "marker": "x",
            "linewidth": 1.5,
            "linestyle": "dashed",
        },
    ),
    (
        r_asym,
        {
            "label": ["asym", "_", "_"],
            "color": "tab:green",
            "marker": "x",
            "linewidth": 1.5,
            "linestyle": "dotted",
        },
    ),
]

# In-plot captions naming each argument family, placed in data coordinates
# at nu = 200 and the given height.
_curve_captions = [
    (0.96, r"$x=\nu^2$"),
    (0.45, r"$x=\nu$"),
    (0.08, r"$x=\sqrt{\nu}$"),
]

with plt.style.context("ggplot"):
    for _ratios, _plot_kwargs in _curve_styles:
        plt.plot(steps, np.stack(_ratios), **_plot_kwargs)
    plt.legend()
    plt.xlabel(r"$\nu$")
    plt.ylabel(r"$I_{\nu+1}(x)/I_{\nu}(x)$")
    for _height, _caption in _curve_captions:
        plt.text(
            200,
            _height,
            _caption,
            horizontalalignment="center",
            verticalalignment="center",
        )
    plt.show()