In [1]:
import jax
import fullstream.models as models
from fullstream.cls import cls_maker
import numpy as np
import jax.experimental.stax as stax
import jax.experimental.optimizers as optimizers
import jax.random

In [2]:
# Classifier: two 1024-unit ReLU hidden layers followed by a single sigmoid
# output. `stax.serial` returns the parameter initialiser and the forward
# function, bound here to `init_random_params` and `predict`.
init_random_params, predict = stax.serial(
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1),
    stax.Sigmoid,
)

In [68]:
def get_hists(network,s,b1,b2):
    """Convert network outputs on three samples into single-number yields.

    Every sample is passed through the module-level ``predict``; the
    flattened outputs are summed, multiplied by a hard-coded per-sample
    factor (2 for ``s``, 10 for ``b1`` and ``b2``), divided by ``len(s)``
    and multiplied by the constant ``LUMI`` (10).

    NOTE: all three sums are normalised by ``len(s)``, so the result is only
    a per-sample average when the three inputs have the same length.

    Parameters
    ----------
    network
        Parameters handed to ``predict`` as its first argument.
    s, b1, b2
        Inputs handed to ``predict`` as its second argument (signal and two
        background samples, judging by the caller).

    Returns
    -------
    tuple
        ``(signal_yield, background_mean, background_uncertainty)`` where the
        mean is ``(b1 + b2) / 2`` and the uncertainty is ``|b1 - b2| / 2`` of
        the two scaled background yields.
    """
    LUMI = 10
    n_events = len(s)

    def scaled_yield(sample, factor):
        # Same operation order as the inline expressions this replaces.
        return predict(network, sample).ravel().sum() * factor / n_events * LUMI

    sig_yield = scaled_yield(s, 2)
    bkg1_yield = scaled_yield(b1, 10)
    bkg2_yield = scaled_yield(b2, 10)

    bkg_mean = (bkg1_yield + bkg2_yield) / 2
    bkg_unc = jax.numpy.abs((bkg1_yield - bkg2_yield) / 2)
    return sig_yield, bkg_mean, bkg_unc

def hist_maker(n_mc=500, seed=None):
    """Draw toy signal/background samples and return a histogram factory.

    Three 2-D Gaussian samples with identity covariance are generated:
    ``bkg1`` centred at (2, 2), ``bkg2`` at (-1, -1) and ``sig`` at (-1, 1).

    Parameters
    ----------
    n_mc : int, optional
        Number of events drawn per sample. Defaults to 500, the value that
        was previously hard-coded.
    seed : int or None, optional
        When given, the samples are drawn from a private
        ``np.random.RandomState(seed)`` so repeated calls reproduce the same
        toy data. With ``None`` (default) the global NumPy random state is
        used, exactly as before.

    Returns
    -------
    make : callable
        ``make(network)`` returns ``get_hists(network, sig, bkg1, bkg2)``.
        The drawn samples are attached as ``make.sig``, ``make.bkg1`` and
        ``make.bkg2`` so callers (e.g. plotting code) can reuse them.
    """
    # `np.random` and `RandomState` expose the same `multivariate_normal`.
    rng = np.random if seed is None else np.random.RandomState(seed)
    identity_cov = [[1, 0], [0, 1]]

    # Draw order (bkg1, bkg2, sig) is kept so that results under an
    # externally seeded global state are unchanged.
    bkg1 = rng.multivariate_normal([2, 2], identity_cov, size=(n_mc,))
    bkg2 = rng.multivariate_normal([-1, -1], identity_cov, size=(n_mc,))
    sig = rng.multivariate_normal([-1, 1], identity_cov, size=(n_mc,))

    def make(network):
        return get_hists(network, sig, bkg1, bkg2)

    make.bkg1 = bkg1
    make.bkg2 = bkg2
    make.sig = sig
    return make

def makeNN():
    """Create a ``network -> (model, background-only parameters)`` factory.

    A fresh histogram factory is obtained from ``hist_maker()``. The returned
    function evaluates it for a given network, builds a model with
    ``models.hepdata_like`` from the one-element lists ``[s]``, ``[b]``,
    ``[db]``, and derives a second parameter vector from the model's
    suggested initial values with the entry at ``config.poi_index`` set to 0.

    Returns
    -------
    nn_model_maker : callable
        ``nn_model_maker(network)`` returns ``(model, bonlypars)``. The
        histogram factory is exposed as ``nn_model_maker.hm``.
    """
    hist_fn = hist_maker()

    def nn_model_maker(network):
        sig, bkg, bkg_unc = hist_fn(network)
        model = models.hepdata_like([sig], [bkg], [bkg_unc])
        init_pars = jax.numpy.asarray(list(model.config.suggested_init()))
        # Zero the POI entry (presumably the signal strength) to obtain the
        # background-only parameter point.
        bonly_pars = jax.ops.index_update(init_pars, model.config.poi_index, 0.0)
        return model, bonly_pars

    nn_model_maker.hm = hist_fn
    return nn_model_maker


In [69]:
# Model factory; it also carries the toy samples via `nnm.hm`.
nnm = makeNN()
# Loss callable built by fullstream's `cls_maker` (presumably a CLs value;
# see fullstream.cls for the exact definition).
loss = cls_maker(nnm, solver_kwargs={'pdf_transform': True})

In [70]:
# Draw initial network parameters from a fixed PRNG key. The input shape
# given to the initialiser is (-1, 2): two features per event, matching the
# 2-D toy samples (the leading -1 is presumably the batch placeholder).
_,network = init_random_params(jax.random.PRNGKey(1),(-1,2))
# Cell output: the loss and its gradient with respect to the network
# parameters. The second argument 1.0 is presumably the tested signal
# strength; confirm against fullstream.cls.
jax.value_and_grad(loss)(network,1.0)

(DeviceArray(0.19328345, dtype=float64),
 [(DeviceArray([[ 0.00033304, -0.00055408, -0.00579085, ..., -0.00997428,
                 -0.00318625, -0.00307893],
                [-0.00109069, -0.00033859, -0.00254815, ...,  0.00499082,
                 -0.00070228, -0.00477251]], dtype=float32),
   DeviceArray([-0.00084073,  0.0003814 , -0.00123341, ...,  0.00637819,
                 0.00213871, -0.00226459], dtype=float32)),
  (),
  (DeviceArray([[ 1.75309804e-04, -1.71597028e-04,  0.00000000e+00, ...,
                  1.38020594e-04, -2.00585487e-06, -1.72544314e-04],
                [ 1.85450772e-05,  9.54204588e-05, -3.34541146e-05, ...,
                  4.78524225e-06, -9.30776991e-07, -2.60588422e-05],
                [-1.81693031e-05,  3.61806015e-04, -4.43428253e-05, ...,
                 -8.42693771e-06,  1.90722567e-08, -3.35249424e-05],
                ...,
                [ 6.50360074e-04, -3.68370063e-04, -1.29543996e-05, ...,
                  5.19682537e-04, -4.81607685e-

In [71]:
def train_network(N, learning_rate=1e-3, seed=1):
    """Generator that runs ``N`` Adam steps on the module-level ``loss``.

    Parameters
    ----------
    N : int
        Number of optimisation steps (and of yielded items).
    learning_rate : float, optional
        Step size passed to ``optimizers.adam``. Defaults to ``1e-3``, the
        previously hard-coded value.
    seed : int, optional
        Seed for ``jax.random.PRNGKey`` used to initialise the network.
        Defaults to ``1``, the previously hard-coded value.

    Yields
    ------
    (network, metrics)
        ``network`` is the parameter set at which the loss was evaluated in
        this step, i.e. before this step's update is applied.
        ``metrics['loss']`` is the list of all loss values so far; the same
        list object is shared by every yielded dict and keeps growing.
    """
    opt_init, opt_update, get_params = optimizers.adam(learning_rate)
    _, network = init_random_params(jax.random.PRNGKey(seed), (-1, 2))
    state = opt_init(network)

    # Build the differentiated loss once rather than on every iteration.
    loss_and_grad = jax.value_and_grad(loss)

    losses = []
    for i in range(N):
        network = get_params(state)
        value, grad = loss_and_grad(network, 1.0)
        losses.append(value)
        state = opt_update(i, grad, state)
        yield network, {'loss': losses}

In [72]:
def plot(axarr,network,metrics,hm,maxN):
    """Draw one frame of the training visualisation onto three axes.

    Parameters
    ----------
    axarr
        Indexable of at least three matplotlib axes.
    network
        Network parameters passed to the module-level ``predict`` and to
        ``hm``.
    metrics : dict
        Must contain ``'loss'``, the sequence of loss values to plot.
    hm
        Histogram factory: ``hm(network)`` returns ``(s, b, db)`` and the
        attributes ``hm.sig``, ``hm.bkg1``, ``hm.bkg2`` hold the 2-D samples.
    maxN
        Upper x-limit of the loss panel.

    Notes
    -----
    The axes are not cleared here, so artists accumulate across calls. The
    caller in this notebook snapshots each frame with a celluloid ``Camera``,
    which presumably relies on that.
    """
    # Panel 0: network output over a 101x101 grid on [-5, 5]^2 plus samples.
    ax = axarr[0]
    g = np.mgrid[-5:5:101j, -5:5:101j]
    levels = np.linspace(0, 1, 21)
    # Evaluate the network on the grid once and reuse it for both the filled
    # and the line contours (it used to be evaluated twice).
    grid_pred = predict(network, np.moveaxis(g, 0, -1)).reshape(101, 101)
    ax.contourf(g[0], g[1], grid_pred, levels=levels)
    ax.contour(g[0], g[1], grid_pred, colors='w', levels=levels)
    ax.scatter(hm.sig[:, 0], hm.sig[:, 1], alpha=0.2, c='white')
    ax.scatter(hm.bkg1[:, 0], hm.bkg1[:, 1], alpha=0.1, c='maroon')
    ax.scatter(hm.bkg2[:, 0], hm.bkg2[:, 1], alpha=0.1, c='maroon')
    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)

    # Panel 1: loss history so far.
    ax = axarr[1]
    ax.plot(metrics['loss'], c='k')
    ax.set_ylim(0, 1)
    ax.set_xlim(0, maxN)

    # Panel 2: single-bin stacked yields; the translucent bar has height db
    # and is centred on the background yield b.
    ax = axarr[2]
    s, b, db = hm(network)
    ax.bar(range(1), b, color='maroon')
    ax.bar(range(1), s, bottom=b, color='grey')
    ax.bar(range(1), db, bottom=b - db / 2., alpha=0.5, color='black')
    ax.set_ylim(0, 100)

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from celluloid import Camera


# One row of three panels (decision surface, loss curve, yields), 15x5 inches.
fig, axarr = plt.subplots(nrows=1, ncols=3)
fig.set_size_inches(15, 5)
camera = Camera(fig)

maxN = 300
for i, (network, metrics) in enumerate(train_network(maxN)):
    print(i, metrics['loss'][-1])
    plot(axarr, network, metrics, nnm.hm, maxN=maxN)
    camera.snap()
    if i % 10 != 0:
        continue
    # Every 10th step the GIF is rewritten with all frames so far, presumably
    # so partial progress can be inspected while training is still running.
    camera.animate().save('animation.gif', writer='imagemagick', fps=10)

# Final write containing every frame.
camera.animate().save('animation.gif', writer='imagemagick', fps=10)

0 0.19328344891005278
1 0.43586859388311305
2 0.13280589387745256
3 0.1774936434222234
4 0.18064326604745684
5 0.10231259560316097
6 0.05174061180302836
7 0.06331696055046421
8 0.08456229248998426
9 0.08122562347770046
10 0.061838297290535804
11 0.04518645495356455
12 0.040267545452321585
13 0.04443964723810745
14 0.05039207557019698
15 0.05234674095907077
16 0.049256121056254765
17 0.04347241755515907
18 0.03813404778733376
19 0.03523540220844268
20 0.035048165610867876
21 0.03650737075697341
22 0.03796241969072711
23 0.03814877634284519
24 0.036793589779108515
25 0.034515357496008425
26 0.03228164092665842
27 0.0307947077533266
28 0.030259053336349373
29 0.030411892599020884
30 0.030723418855382656
31 0.030677454752946254
32 0.030023556713737243
33 0.0288645180895013
34 0.027542859487623916
35 0.026417434118708893
36 0.025683251252889505
37 0.025305299325004027
38 0.0250756882841483
39 0.024741598862428438
40 0.024145573087477112
41 0.02329744344589657
42 0.022343420082929777
43 0.02