In [None]:
from neos.fit import get_solvers
from neos import models

import jax
from jax.experimental import stax

import numpy as np

In [None]:
# Toy classifier: two 1024-unit ReLU hidden layers followed by a 2-way softmax.
init_random_params, predict = stax.serial(
    stax.Dense(1024), stax.Relu,   # hidden layer 1
    stax.Dense(1024), stax.Relu,   # hidden layer 2
    stax.Dense(2), stax.Softmax,   # 2-class output probabilities
)

# Inputs have 2 features (batch dim left free); the returned output shape is unused.
_, network = init_random_params(rng=jax.random.PRNGKey(seed=2), input_shape=(-1, 2))



In [None]:
def get_hists(network, s, b1, b2, lumi=10, sig_scale=2, bkg_scale=10):
    """Build soft histograms for signal and background from the network output.

    Each sample is passed through ``predict(network, ...)`` and the per-event
    outputs are summed over axis 0 (one entry per output node). Each sum is
    then normalised by that sample's own size, scaled by ``sig_scale`` or
    ``bkg_scale``, and multiplied by ``lumi``.

    Args:
        network: parameters passed to ``predict``.
        s: signal sample.
        b1, b2: two background samples (e.g. two background variations).
        lumi: luminosity factor applied to every yield (default 10).
        sig_scale: expected-signal normalisation per unit ``lumi`` (default 2).
        bkg_scale: expected-background normalisation per unit ``lumi`` (default 10).

    Returns:
        Tuple ``(signal_hist, background_mean, background_unc)``. The last two
        are the element-wise mean and ``jax.numpy.std`` (ddof=0) over the two
        background histograms.
    """

    def _yields(sample, scale):
        # Normalise by the size of *this* sample. Using the signal size for the
        # backgrounds is only correct when all samples have equal length.
        return predict(network, sample).sum(axis=0) * scale / len(sample) * lumi

    sh = _yields(s, sig_scale)
    bh1 = _yields(b1, bkg_scale)
    bh2 = _yields(b2, bkg_scale)

    # Stack once and reuse for both the mean and the spread.
    bkg_stack = jax.numpy.asarray([bh1, bh2])
    b_mean = jax.numpy.mean(bkg_stack, axis=0)
    b_unc = jax.numpy.std(bkg_stack, axis=0)
    return sh, b_mean, b_unc


def hist_maker(nmc=500, rng=None):
    """Sample toy 2-D signal/background datasets and return a histogram builder.

    Three samples are drawn from unit-covariance 2-D Gaussians: ``bkg1``
    centred at (2, 2), ``bkg2`` at (-1, -1) and ``sig`` at (-1, 1).

    Args:
        nmc: number of events drawn for each sample (default 500).
        rng: object providing ``multivariate_normal`` (e.g. a
            ``numpy.random.Generator`` for reproducible draws). Defaults to
            the global legacy ``numpy.random`` state, which is unseeded here.

    Returns:
        A function ``make(network)`` returning
        ``get_hists(network, sig, bkg1, bkg2)``. The raw samples are exposed
        as the attributes ``make.bkg1``, ``make.bkg2`` and ``make.sig``.
    """
    if rng is None:
        rng = np.random
    identity = [[1, 0], [0, 1]]
    # Keep the draw order (bkg1, bkg2, sig) fixed so a given seed is reproducible.
    bkg1 = rng.multivariate_normal([2, 2], identity, size=(nmc,))
    bkg2 = rng.multivariate_normal([-1, -1], identity, size=(nmc,))
    sig = rng.multivariate_normal([-1, 1], identity, size=(nmc,))

    def make(network):
        return get_hists(network, sig, bkg1, bkg2)

    make.bkg1 = bkg1
    make.bkg2 = bkg2
    make.sig = sig
    return make

In [None]:
# Quick look: first two entries of row 0 of the first layer's first parameter array, offset by 10.
network[0][0][0, :2] + 10

DeviceArray([ 9.97184 , 10.008986], dtype=float32)

In [None]:
import jax
from jax.experimental import optimizers
from fax.implicit import twophase


def log_likelihood(pars):
    """Placeholder objective: ignores ``pars`` and returns the scalar 1.0.

    The value is irrelevant here; ``get_fit`` only ever takes its gradient
    (which is zero).
    """
    ones_vector = jax.numpy.ones(1)
    return ones_vector[0]


def get_fit(
    default_rtol=1e-10,
    default_atol=1e-10,
    default_max_iter=int(1e7),
    learning_rate = 0.01
):
    """Build a fixed-point "fit" that takes Adam steps on ``log_likelihood``.

    Args:
        default_rtol, default_atol, default_max_iter: convergence settings
            forwarded to ``twophase.two_phase_solver``.
        learning_rate: Adam step size.

    Returns:
        ``global_fit(init, ignored_param)``, which runs the solver from
        ``init`` and returns the ``.value`` of its result.
    """
    # Use the caller-supplied step size; it used to be silently ignored in favour of 1e-6.
    adam_init, adam_update, adam_get_params = optimizers.adam(learning_rate)

    def global_bestfit_minimized(ignored_param):
        # The hyper-parameter is unused: log_likelihood depends only on `param`.

        def bestfit_via_grad_descent(i, param):
            """One fixed-point iteration: a single Adam step from ``param``."""
            g = jax.grad(log_likelihood)(param)
            # NOTE(review): the Adam state is re-created from `param` on every
            # call, so no moment estimates carry over between iterations,
            # while the step counter `i` still feeds Adam's bias correction.
            # Confirm this is intended.
            state = adam_update(i, g, adam_init(param))
            return adam_get_params(state)

        return bestfit_via_grad_descent

    global_solve = twophase.two_phase_solver(
        param_func=global_bestfit_minimized,
        default_rtol=default_rtol,
        default_atol=default_atol,
        default_max_iter=default_max_iter
    )

    def global_fit(init, ignored_param):
        solve = global_solve(init, ignored_param)
        return solve.value

    return global_fit

def do_fit(ignored_param):
    """Build the toy fit with ``get_fit()`` and run it starting from 1.0.

    ``ignored_param`` is forwarded as the solver hyper-parameter, which the
    objective does not use.
    """
    # NOTE(review): per the original notes, the AssertionError recorded below
    # appears when the inner fit is NOT wrapped in jax.jit before being called.
    return get_fit()(1., ignored_param)

# Reproduce the failure: jit-compiling `do_fit` raises the AssertionError recorded below.
ignored_param = 1
jax.jit(do_fit)(ignored_param)

AssertionError: If you see this error, please let us know by opening an issue at
https://github.com/google/jax/issues 
since we thought this was unreachable!

In [None]:
def implicit_fit(value):
    """Call the solver returned by ``get_solvers()`` from the start point
    ``[0.5, 0.5]``, passing ``value`` as its second argument."""
    # NOTE(review): per the original notes, the AssertionError recorded below
    # appears when the solver is NOT wrapped in jax.jit before being called.
    solver = get_solvers()
    start = jax.numpy.array([0.5, 0.5])
    return solver(start, value)

# Reproduce the failure: jit-compiling `implicit_fit` raises the AssertionError recorded below.
jax.jit(implicit_fit)(1)  # produces error

AssertionError: If you see this error, please let us know by opening an issue at
https://github.com/google/jax/issues 
since we thought this was unreachable!

In [None]:
def whatwhat(network):
    """Run the toy ``get_solver`` fit from 1.0, using ``network[0][0][0][0] + 1``
    as the hyper-parameter.

    With stax parameters, ``network[0][0][0][0]`` is presumably the first
    weight of the first Dense kernel (confirm).

    NOTE: ``get_solver`` is defined in a later cell, so that cell must be run
    before this one on a fresh kernel.
    """
    c_fitter = get_solver()
    hyper_par = network[0][0][0][0] + 1
    return c_fitter(1.0, hyper_par)

# NOTE: requires the cell defining `get_solver` (further down) to have been run first.
jax.jit(whatwhat)(network)

DeviceArray(1., dtype=float64)

In [None]:
import scipy.optimize
from fax.implicit import twophase

def get_solver(default_rtol=1e-10,
    default_atol=1e-10,
    default_max_iter=int(1e7),
    learning_rate = 0.01):
    """Build a fixed-point solver minimising ``(param - hparam**2)**2``.

    Each fixed-point iteration is one gradient-descent step of size
    ``learning_rate``. The fixed point is ``param == hparam**2``.

    Args:
        default_rtol, default_atol, default_max_iter: convergence settings
            forwarded to ``twophase.two_phase_solver``.
        learning_rate: gradient-descent step size. Iterations contract toward
            the minimum for ``0 < learning_rate < 1``.

    Returns:
        ``g_fitter(init, hyper_pars)``, which runs the solver from ``init``
        with hyper-parameter ``hyper_pars`` and returns the result's ``.value``.
    """

    def func_to_minimize(hparam):
        def func(param):
            return (param - hparam**2)**2
        return func

    def minimize_from(hparam):
        grad_func = jax.grad(func_to_minimize(hparam))

        def fixed_point_func(i, param):
            # grad = 2*(param - hparam**2). A unit step would map
            # param -> 2*hparam**2 - param, which oscillates and never
            # converges. Scaling by learning_rate gives the contraction
            # param -> (1 - 2*lr)*param + 2*lr*hparam**2.
            return param - learning_rate * grad_func(param)

        return fixed_point_func

    global_solve = twophase.two_phase_solver(
            param_func=minimize_from,
            default_rtol=default_rtol,
            default_atol=default_atol,
            default_max_iter=default_max_iter
        )

    def g_fitter(init, hyper_pars):
        solve = global_solve(init, hyper_pars)
        return solve.value

    return g_fitter

In [None]:
# Build a solver with default settings.
# NOTE(review): an assignment displays nothing; the DeviceArray output recorded below looks stale.
x = get_solver()

DeviceArray(1., dtype=float64)