In [1]:
from scipy.optimize import leastsq, minimize
import jax
import jax.numpy as jnp
from jax import grad, jit
from numba import njit
import numpy as np

In [None]:
import jax
from jax import Array, jit, numpy as jnp
from typing import Callable


@jit
def func(a: Array, arg2: int) -> Array:
    """Return ``a`` shifted element-wise by the scalar ``arg2`` (jit-compiled)."""
    shifted = jnp.add(a, arg2)
    return shifted


@jit
def myjittedfun(f: Callable, a) -> Array:
    """Apply ``f`` to ``a`` inside a jit-compiled function.

    ``f`` is passed as a traced argument, so it is expected to be a pytree-compatible
    callable such as ``jax.tree_util.Partial`` rather than a plain Python function.
    """
    output = f(a)
    return output


# Bind arg2 with a pytree-compatible Partial so the callable can be handed to a jitted function.
a = jnp.array([3, 4])
closure = jax.tree_util.Partial(func, arg2=0)
print(myjittedfun(closure, a))

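## Cubic polynomial fit

Model: $a_0 x + a_1 + a_2 x^2 + a_3 x^3$ (note that $a_1$ is the constant term), fitted to noisy synthetic data with `scipy.optimize.leastsq`.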

In [2]:
def func_to_fit(x, a):
    """Cubic model ``a[0]*x + a[1] + a[2]*x**2 + a[3]*x**3``.

    The coefficients are not in power order: a[1] is the constant term and a[0]
    the linear one. Only the first four entries of ``a`` are used.
    """
    lin, const, quad, cube = a[0], a[1], a[2], a[3]
    return lin * x + const + quad * x**2 + cube * x**3


x = np.linspace(0, 10, 1_000_000)
x_right = [50, 30, 60, -100]  # true coefficients used to generate the synthetic data
x0 = [1.0, 1.0, 1.0, 1.0]  # starting guess for leastsq

# Seeded generator so the noise, and therefore the fitted values, are reproducible
# on Restart & Run All.
rng = np.random.default_rng(0)
data = func_to_fit(x, x_right) + rng.standard_normal(len(x))

def to_minimize(args):
    """Residual vector (model evaluated at the global ``x`` minus the global ``data``) for ``leastsq``."""
    model = func_to_fit(x, args)
    return model - data



res, _ = leastsq(to_minimize, x0,)
if np.abs(np.sum(res - x_right))>0.5:
    print("Failed to fit. Right parameters:", x_right, "Fitted parameters:", res)
else:
    print("Success to fit. Right parameters:", x_right, "Fitted parameters:", res)
    
%timeit leastsq(to_minimize, x0)


Success to fit. Right parameters: [50, 30, 60, -100] Fitted parameters: [ 50.00270706  29.99870055  59.99915987 -99.99993824]
235 ms ± 5.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


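## Sine fit

Model: $a_0 \sin(a_1 x) + a_2$ (amplitude, angular frequency, offset). It is fitted first with a NumPy model function and then with a `jax.jit`-compiled one.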

In [15]:
def func_to_fit(x, a):
    """Sine model ``a[0] * sin(a[1] * x) + a[2]`` (amplitude, angular frequency, offset)."""
    amplitude, omega, offset = a[0], a[1], a[2]
    return amplitude * np.sin(omega * x) + offset


x = np.linspace(0, 10, 1_000_000)
x_right = [5, 2, 6]  # true [amplitude, angular frequency, offset]
# Float starting guess, consistent with the polynomial cell.
# NOTE(review): the starting frequency equals the true one (2); sine fits are
# non-convex in the frequency, so the success below likely depends on this choice.
x0 = [1.0, 2.0, 2.0]

# Seeded generator so the noise, and therefore the fitted values, are reproducible.
rng = np.random.default_rng(0)
data = func_to_fit(x, x_right) + rng.standard_normal(len(x))

def to_minimize(args):
    """Residual vector (model evaluated at the global ``x`` minus the global ``data``) for ``leastsq``."""
    model = func_to_fit(x, args)
    return model - data



res, _ = leastsq(to_minimize, x0,)
if np.abs(np.sum(res - x_right))>0.5:
    print("Failed to fit. Right parameters:", x_right, "Fitted parameters:", res)
else:
    print("Success to fit. Right parameters:", x_right, "Fitted parameters:", res)
    
%timeit leastsq(to_minimize, x0)


Success to fit. Right parameters: [5, 2, 6] Fitted parameters: [5.00248733 1.9999943  5.99942826]
156 ms ± 1.89 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [14]:
from jax import jit as jax_jit
import jax.numpy as jnp

@jax_jit
def func_to_fit(x, a):
    """jit-compiled sine model ``a[0] * sin(a[1] * x) + a[2]`` (amplitude, angular frequency, offset)."""
    amplitude, omega, offset = a[0], a[1], a[2]
    return amplitude * jnp.sin(omega * x) + offset

x = np.linspace(0, 10, 1_000_000)
x_right = [5, 2, 6]  # true [amplitude, angular frequency, offset]
# Float starting guess. An integer list would likely make the jitted model trace an
# extra int-dtype variant on leastsq's first call.
# NOTE(review): the starting frequency equals the true one (2); success likely depends on it.
x0 = [1.0, 2.0, 2.0]

# Seeded generator so the noise, and therefore the fitted values, are reproducible.
rng = np.random.default_rng(0)
data = func_to_fit(x, x_right) + rng.standard_normal(len(x))

def to_minimize(args):
    """Residual vector (model evaluated at the global ``x`` minus the global ``data``) for ``leastsq``."""
    model = func_to_fit(x, args)
    return model - data



res, _ = leastsq(to_minimize, x0,)
if np.abs(np.sum(res - x_right))>0.5:
    print("Failed to fit. Right parameters:", x_right, "Fitted parameters:", res)
else:
    print("Success to fit. Right parameters:", x_right, "Fitted parameters:", res)
    
%timeit leastsq(to_minimize, x0)


Success to fit. Right parameters: [5, 2, 6] Fitted parameters: [4.999143   1.99997798 5.99919216]
117 ms ± 1.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [72]:
# NOTE(review): execution count 72 is out of order. The saved output below shows the
# cubic parameters (~[50, 30, 60, -100]), so it was produced with the polynomial
# `to_minimize`/`x0`, not the sine cells above. On Restart & Run All this cell fits
# whichever model was defined last.
# full_output=True also returns the covariance estimate and an info dict (fvec, nfev, fjac, ...).
leastsq(func=to_minimize, x0=x0, full_output=True)

(array([  49.99922858,   30.00203631,   60.00004841, -100.00000127]),
 array([[ 1.19999225e-05, -1.19999045e-05, -2.69998300e-06,
          1.67998915e-07],
        [-1.19999045e-05,  1.59998805e-05,  2.39997969e-06,
         -1.39998746e-07],
        [-2.69998300e-06,  2.39997969e-06,  6.47996168e-07,
         -4.19997494e-08],
        [ 1.67998915e-07, -1.39998746e-07, -4.19997494e-08,
          2.79998330e-09]]),
 {'fvec': array([-0.57746262,  0.706797  , -0.96950354, ..., -0.26087287,
         -0.42506463, -0.01513807]),
  'nfev': 11,
  'fjac': array([[-3.77964944e+05,  0.00000000e+00,  0.00000000e+00, ...,
           2.64573211e-03,  2.64574005e-03,  2.64574800e-03],
         [-4.40958883e+04, -7.45356553e+03,  5.33119314e-13, ...,
          -2.23601695e-03, -2.23603767e-03, -2.23605854e-03],
         [-5.29150395e+03, -2.23606854e+03, -5.77350411e+02, ...,
           1.73198846e-03,  1.73202985e-03,  1.73203776e-03],
         [-6.61437669e+02, -5.59016865e+02, -4.33012597e+02, ..

In [None]:
import jax.numpy as jnp
from numba import njit
import numpy as np
from jax import jit as jax_jit

def jax_function_nocc(x):
    """Un-jitted benchmark kernel: add ``sin(x) @ cos(x)`` to a copy of ``x`` 100 times.

    Without ``jit``, each iteration recomputes the product and runs its ops one by one.
    """
    acc = jnp.copy(x)
    for _step in range(100):
        acc = acc + jnp.sin(x) @ jnp.cos(x)
    return acc

@jax_jit
def jax_function_compile(x):
    """jit-compiled version of the benchmark kernel: add ``sin(x) @ cos(x)`` to a copy of ``x`` 100 times.

    The Python loop runs while the function is traced, so all 100 additions end up in
    one compiled computation.
    NOTE(review): the ``sin(x) @ cos(x)`` term is the same in every iteration, so the
    compiler may compute it once. The measured speedup may therefore not be purely
    reduced dispatch overhead.
    """
    acc = jnp.copy(x)
    for _step in range(100):
        acc = acc + jnp.sin(x) @ jnp.cos(x)
    return acc

# Compile functions
jax_function_nocc(x_jax)
jax_function_compile(x_jax)

# time the functions
%timeit jax_function_nocc(x_jax)
%timeit jax_function_compile(x_jax)

2.03 s ± 372 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
26.4 ms ± 3.57 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
