In [1]:
# This is PyProb_Testing.ipynb converted to a script, fixing some errors.
# I've also reduced num_traces so that this script runs quickly.
# (See comments next to num_traces below for original values.)
import pyprob
import numpy as np
import torch

from pyprob.dis import ModelDIS
from showerSim import invMass_ginkgo

import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as mpl_cm
plt.ion()

import sklearn as skl
from sklearn.linear_model import LinearRegression



In [2]:
# Observed leaves handed to the simulator via its `obs_leaves` argument:
# a float64 tensor with 35 rows and 4 columns.
# NOTE(review): each row is presumably one leaf's 4-momentum (E, px, py, pz)
# — confirm against showerSim.invMass_ginkgo before relying on the column order.
obs_leaves = torch.tensor([[44.57652381, 26.16169856, 25.3945314 , 25.64598258],
       [18.2146321 , 10.70465096, 10.43553391, 10.40449709],
       [ 6.47106713,  4.0435395,  3.65545951,  3.48697568],
       [ 8.43764314,  5.51040615,  4.60990593,  4.42270416],
       [26.61664145, 16.55894826, 14.3357362 , 15.12215264],
       [ 8.62925002,  3.37121204,  5.19699   ,  6.00480461],
       [ 1.64291837,  0.74506775,  1.01003622,  1.05626017],
       [ 0.75525072,  0.3051808 ,  0.45721085,  0.51760643],
       [39.5749915 , 18.39638928, 24.24717939, 25.29349408],
       [ 4.18355659,  2.11145474,  2.82071304,  2.25221316],
       [ 0.82932922,  0.29842766,  0.5799056 ,  0.509021  ],
       [ 3.00825023,  1.36339397,  1.99203677,  1.79428211],
       [ 7.20024308,  4.03280868,  3.82379277,  4.57441754],
       [ 2.09953618,  1.28473579,  1.03554351,  1.29769683],
       [12.21401828,  6.76059035,  6.94920042,  7.42823701],
       [ 6.91438054,  3.68417135,  3.83782514,  4.41656731],
       [ 1.97218904,  1.01632927,  1.08008339,  1.27454585],
       [ 8.58164301,  5.06157833,  4.79691164,  4.99553141],
       [ 5.97809522,  3.26557958,  3.4253764 ,  3.64894791],
       [ 5.22842301,  2.94437891,  3.10292633,  3.00551074],
       [15.40023764,  9.10884407,  8.93836964,  8.61970667],
       [ 1.96101346,  1.24996337,  1.06923988,  1.06743143],
       [19.81054106, 11.90268453, 11.60989346, 10.76953856],
       [18.79470876, 11.429855  , 10.8377334 , 10.25112761],
       [25.74331932, 15.63430056, 14.83860792, 14.07189108],
       [ 9.98357576,  6.10090721,  5.68664128,  5.48748692],
       [12.34604239,  7.78770185,  6.76075998,  6.78498685],
       [21.24998531, 12.95180254, 11.9511704 , 11.87319933],
       [ 7.80693733,  4.83117128,  4.27443559,  4.39602348],
       [16.28983576,  9.66683929,  9.24891886,  9.28970032],
       [ 2.50706736,  1.53153206,  1.36060018,  1.43002765],
       [ 3.73938645,  2.06006639,  2.31013974,  2.09378969],
       [20.2174725 , 11.88622367, 12.05106468, 11.05325362],
       [ 9.48660008,  5.53665456,  5.54171966,  5.34966654],
       [ 2.65812987,  1.64102742,  1.67392209,  1.25083707]], dtype=torch.float64)


# --- Jet kinematics -------------------------------------------------------
QCD_mass = 30.
jetP = 400.
jetdir = np.array([1, 1, 1])
# jetvec is the jet 3-momentum: magnitude jetP along the (normalised) jetdir.
jetvec = jetP * jetdir / np.linalg.norm(jetdir)
# Decay rates are not fixed here (the old `rate` tensor is gone); they are
# sampled inside SimulatorModelDIS.forward.

# --- Actual simulator parameters -------------------------------------------
M2start = torch.tensor(QCD_mass**2)  # initial jet mass squared, kept as a tensor
pt_min = torch.tensor(0.3**2)
jetM = np.sqrt(M2start.numpy())  # mass of the initial jet, recovered from M2start
# 4-vector of the initial jet: energy sqrt(p^2 + m^2) followed by the 3-momentum.
jet4vec = np.concatenate(([np.sqrt(jetP**2 + jetM**2)], jetvec))

# --- Leaf-count bounds and retry limit ---------------------------------------
minLeaves = 1
maxLeaves = 10000  # unachievable, to prevent rejections
maxNTry = 100

def dummy_bernoulli(self, jet):
    """Always answer ``True``, ignoring both arguments.

    This function is handed to the simulator as its ``bool_func``.
    NOTE(review): judging by the name it stands in for a Bernoulli draw made
    by the simulator — confirm against ``invMass_ginkgo``.
    """
    return True

class SimulatorModelDIS(invMass_ginkgo.SimulatorModel, ModelDIS):
    """Ginkgo simulator model whose two decay rates are sampled from priors.

    Combines ``invMass_ginkgo.SimulatorModel`` with ``ModelDIS``. ``forward``
    draws the decay rates with ``pyprob.sample`` and hands them to the next
    ``forward`` in the method resolution order.
    """

    def forward(self, inputs=None):
        """Sample both decay rates and run the parent simulator with them.

        Args:
            inputs: Must be ``None``; the rates are always sampled here rather
                than supplied by the caller.

        Returns:
            Whatever ``super().forward`` returns for ``[root_rate, decay_rate]``.

        Raises:
            AssertionError: If ``inputs`` is not ``None``.
        """
        assert inputs is None # Modify code if this ever not met?
        # The two rates are independent draws from the same Unif(0.01, 10) prior.
        # Each draw carries its own name so the two latent variables can be told
        # apart when looked up by name; the non-root rate keeps the name
        # "decay_rate_parameter".
        root_rate = pyprob.sample(pyprob.distributions.Uniform(0.01, 10.),
                                  name="root_rate_parameter")
        decay_rate = pyprob.sample(pyprob.distributions.Uniform(0.01, 10.),
                                   name="decay_rate_parameter")
        # Simulator code needs two decay rates: (1) root node, (2) all other nodes.
        inputs = [root_rate, decay_rate]
        return super().forward(inputs)

# Make instance of the simulator. The leaf bounds come from the minLeaves /
# maxLeaves constants defined with the other parameters, so changing those
# constants actually changes the simulator.
simulator = SimulatorModelDIS(jet_p=jet4vec,  # parent particle 4-vector
                              pt_cut=float(pt_min),  # minimum pT for resulting jet
                              Delta_0=M2start,  # parent particle mass squared -> needs tensor
                              M_hard=jetM,  # parent particle mass
                              minLeaves=minLeaves,  # minimum number of jet constituents
                              maxLeaves=maxLeaves,  # maximum number of jet constituents (unachievably large, to prevent rejections)
                              bool_func=dummy_bernoulli,  # always returns True
                              suppress_output=True,
                              obs_leaves=obs_leaves)


In [10]:
# Test run: two training iterations, each with an importance sample of 5000.
# The recorded output below shows a single 5000-trace pass taking roughly
# ten hours, so expect this cell to be slow.
simulator.train(
    iterations=2,
    importance_sample_size=5000,
    proposal_mixture_components=5,
    # Dummy value as we currently have to observe something
    observe_embeddings={'bool_func': {'dim': 1, 'depth': 1}},
)

Time spent  | Time remain.| Progress             | Trace     | ESS    | Traces/sec
------------------------------------------------ |  103/5000 |   2.23 | 0.65       
------------------------------------------------ |  330/5000 |   6.45 | 0.61       
------------------------------#----------------- |  629/5000 |  11.25 | 0.62       
------------------------------#----------------- |  655/5000 |   2.98 | 0.62       
------------------------------#----------------- |  688/5000 |   3.03 | 0.62       
0d:00:19:01 | 0d:01:55:31 | ###----------------- |  707/5000 |   3.05 | 0.62       



------------------------------#----------------- |  819/5000 |   4.29 | 0.62       
------------------------------##---------------- |  966/5000 |   5.33 | 0.63       
0d:00:33:06 | 0d:01:39:32 | #####--------------- | 1248/5000 |   4.19 | 0.63       



------------------------------####-------------- | 1543/5000 |   4.40 | 0.63       
------------------------------#####------------- | 1729/5000 |   4.50 | 0.33       
------------------------------######------------ | 1876/5000 |   4.98 | 0.33       
------------------------------######------------ | 1914/5000 |   5.65 | 0.33       
------------------------------######------------ | 2079/5000 |   6.54 | 0.33       
------------------------------########---------- | 2560/5000 |   7.10 | 0.34       
------------------------------########---------- | 2591/5000 |   7.19 | 0.34       
------------------------------########---------- | 2593/5000 |   7.19 | 0.34       
------------------------------#########--------- | 2818/5000 |   7.22 | 0.34       
------------------------------##########-------- | 3039/5000 |   8.59 | 0.34       
------------------------------##########-------- | 3089/5000 |   8.59 | 0.34       
------------------------------###########------- | 3136/5000 |   8.71 | 0.34



------------------------------############------ | 3497/5000 |   9.42 | 0.12       
------------------------------############------ | 3526/5000 |   9.43 | 0.13       
------------------------------#############----- | 3662/5000 |  10.09 | 0.13       
------------------------------##############---- | 4014/5000 |  10.89 | 0.13       
0d:08:43:37 | 0d:01:46:18 | #################--- | 4157/5000 |  11.12 | 0.13       



------------------------------###############--- | 4224/5000 |  11.20 | 0.13       
------------------------------################-- | 4530/5000 |  12.04 | 0.14       
------------------------------################-- | 4586/5000 |  12.04 | 0.14       
------------------------------#################- | 4783/5000 |  12.50 | 0.14       
0d:10:03:19 | 0d:00:00:07 | #################### | 5000/5000 |  13.08 | 0.14       
OfflineDataset at: .
Num. traces      : 500
Sorted on disk   : False
No pre-computed hashes found, generating: ./pyprob_hashes
Hashing offline dataset for sorting
Time spent  | Time remain.| Progress             | Traces  | Traces/sec
0d:00:00:14 | 0d:00:00:00 | #################### | 500/500 | 34.39       
Sorting offline dataset
Sorting done
Num. trace types : 85
Trace hash	Count
82.55686951	3
94.10295868	5
103.18231964	210
139.11848450	16
146.72689819	1
146.82281494	5
154.12643433	2
158.16914368	15
158.19790649	15
166.14773560	1
170.16592407	1
171.16374207	9
174.10971069



------------------------------------------------ |  362/5000 |   3.18 | 0.51       
------------------------------------------------ |  408/5000 |   4.76 | 0.50       
------------------------------------------------ |  472/5000 |   4.96 | 0.49       
------------------------------------------------ |  559/5000 |   5.11 | 0.48       
------------------------------#----------------- |  730/5000 |   5.38 | 0.47       
------------------------------#----------------- |  873/5000 |   5.77 | 0.44       
------------------------------####-------------- | 1465/5000 |   6.60 | 0.43       
------------------------------######------------ | 2001/5000 |   7.63 | 0.41       
------------------------------#######----------- | 2291/5000 |  10.07 | 0.32       
0d:02:13:18 | 0d:01:48:03 | ###########--------- | 2762/5000 |  11.15 | 0.35       

In [7]:
# Save the simulator's inference network under the name 'ginkgo_network2'
# (presumably a path relative to the working directory — TODO confirm).
# NOTE(review): this cell's execution count (7) is lower than the training
# cell's (10), so the saved network may predate the training run shown above;
# re-run this cell after training on a fresh kernel.
simulator.save_inference_network('ginkgo_network2')