In [5]:
import pyprob
import numpy as np
import ot
import torch
import cProfile

from pyprob.dis import ModelDIS
from showerSim import invMass_ginkgo, simulator
from torch.utils.data import DataLoader
from pyprob.nn.dataset import OnlineDataset
from pyprob.util import InferenceEngine, TraceMode
from pyprob.util import to_tensor
from pyprob import Model, state
from pyprob.model import Parallel_Generator
import math
import time
import torch.multiprocessing as mp
from pyprob.distributions import Normal
from pyprob.distributions.delta import Delta


import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as mpl_cm
plt.ion()

import sklearn as skl
from sklearn.linear_model import LinearRegression

from geomloss import SamplesLoss
sinkhorn = SamplesLoss(loss="sinkhorn", p=1, blur=.05)
def sinkhorn_t(x,y):
    x = to_tensor(x)
    y = torch.stack(y)
    return sinkhorn(x,y)

def ot_dist(x,y):
    # x = to_tensor(x)
    # y = torch.stack(y)
    x = np.array(x)
    y = np.array(torch.stack(y))
    a = ot.unif(len(x))
    b = ot.unif(len(y))
    Mat = ot.dist(x, y, metric='euclidean')
    #Mat1 /= Mat1.max()
    distance = to_tensor(ot.emd2(a,b,Mat))
    return distance


device = "cpu"

from pyprob.util import set_device
set_device(device)

obs_leaves = to_tensor([[44.57652381, 26.16169856, 25.3945314 , 25.64598258],
                           [18.2146321 , 10.70465096, 10.43553391, 10.40449709],
                           [ 6.47106713,  4.0435395,  3.65545951,  3.48697568],
                           [ 8.43764314,  5.51040615,  4.60990593,  4.42270416],
                           [26.61664145, 16.55894826, 14.3357362 , 15.12215264],
                           [ 8.62925002,  3.37121204,  5.19699   ,  6.00480461],
                           [ 1.64291837,  0.74506775,  1.01003622,  1.05626017],
                           [ 0.75525072,  0.3051808 ,  0.45721085,  0.51760643],
                           [39.5749915 , 18.39638928, 24.24717939, 25.29349408],
                           [ 4.18355659,  2.11145474,  2.82071304,  2.25221316],
                           [ 0.82932922,  0.29842766,  0.5799056 ,  0.509021  ],
                           [ 3.00825023,  1.36339397,  1.99203677,  1.79428211],
                           [ 7.20024308,  4.03280868,  3.82379277,  4.57441754],
                           [ 2.09953618,  1.28473579,  1.03554351,  1.29769683],
                           [12.21401828,  6.76059035,  6.94920042,  7.42823701],
                           [ 6.91438054,  3.68417135,  3.83782514,  4.41656731],
                           [ 1.97218904,  1.01632927,  1.08008339,  1.27454585],
                           [ 8.58164301,  5.06157833,  4.79691164,  4.99553141],
                           [ 5.97809522,  3.26557958,  3.4253764 ,  3.64894791],
                           [ 5.22842301,  2.94437891,  3.10292633,  3.00551074],
                           [15.40023764,  9.10884407,  8.93836964,  8.61970667],
                           [ 1.96101346,  1.24996337,  1.06923988,  1.06743143],
                           [19.81054106, 11.90268453, 11.60989346, 10.76953856],
                           [18.79470876, 11.429855  , 10.8377334 , 10.25112761],
                           [25.74331932, 15.63430056, 14.83860792, 14.07189108],
                           [ 9.98357576,  6.10090721,  5.68664128,  5.48748692],
                           [12.34604239,  7.78770185,  6.76075998,  6.78498685],
                           [21.24998531, 12.95180254, 11.9511704 , 11.87319933],
                           [ 7.80693733,  4.83117128,  4.27443559,  4.39602348],
                           [16.28983576,  9.66683929,  9.24891886,  9.28970032],
                           [ 2.50706736,  1.53153206,  1.36060018,  1.43002765],
                           [ 3.73938645,  2.06006639,  2.31013974,  2.09378969],
                           [20.2174725 , 11.88622367, 12.05106468, 11.05325362],
                           [ 9.48660008,  5.53665456,  5.54171966,  5.34966654],
                           [ 2.65812987,  1.64102742,  1.67392209,  1.25083707]])


QCD_mass = to_tensor(30.)
#rate=to_tensor([QCD_rate,QCD_rate]) #Entries: [root node, every other node] decaying rates. Choose same values for a QCD jet
jetdir = to_tensor([1.,1.,1.])
jetP = to_tensor(400.)
jetvec = jetP * jetdir / torch.linalg.norm(jetdir) ## Jetvec is 3-momentum. JetP is relativistic p.


# Actual parameters
pt_min = to_tensor(0.3**2)
M2start = to_tensor(QCD_mass**2)
jetM = torch.sqrt(M2start) ## Mass of initial jet
jet4vec = torch.cat((torch.sqrt(jetP**2 + jetM**2).reshape(-1), jetvec))
minLeaves = 1
maxLeaves = 10000 # unachievable, to prevent rejections
maxNTry = 100



class SimulatorModelDIS(invMass_ginkgo.SimulatorModel, ModelDIS):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def dummy_bernoulli(self, jet):
        return True

    def forward(self, inputs=None):
        assert inputs is None # Modify code if this ever not met?
        # Sample parameter of interest from Unif(0,10) prior
        root_rate = pyprob.sample(pyprob.distributions.Uniform(0.01, 10.),
                                  name="decay_rate_parameter")
        decay_rate = pyprob.sample(pyprob.distributions.Uniform(0.01, 10.),
                                   name="decay_rate_parameter")
        # Simulator code needs two decay rates for (1) root note (2) all others
        # For now both are set to the same value
        inputs = [root_rate, decay_rate]
        jet = super().forward(inputs)
        delta_val = self.dummy_bernoulli(jet)
        bool_func_dist = pyprob.distributions.Bernoulli(delta_val)
        pyprob.observe(bool_func_dist, name = "dummy")
        return jet

# Make instance of the simulator
simulatorginkgo = SimulatorModelDIS(jet_p=jet4vec,  # parent particle 4-vector
                                    pt_cut=float(pt_min),  # minimum pT for resulting jet
                                    Delta_0= M2start,  # parent particle mass squared -> needs tensor
                                    M_hard=jetM,  # parent particle mass
                                    minLeaves=1,  # minimum number of jet constituents
                                    maxLeaves=10000,  # maximum number of jet constituents (a large value to stop expensive simulator runs)
                                    suppress_output=True,
                                    obs_leaves=obs_leaves,
                                    dist_fun=sinkhorn_t)

In [6]:
next(simulatorginkgo._trace_generator())

Trace(variables:89, observable:89, observed:18, tagged:0, controlled:2, uncontrolled:69, log_prob:tensor(-4.6032), log_importance_weight:0.0)

In [7]:
class Parallel_Generator(Dataset):
    """
    Generates datasets for parallelisation by PyTorch dataloader methods.
    """
    def __init__(self, model, importance_sample_size = None, observe = None):
        self._model = model
        self._generator = model._trace_generator
        self._inference_network = model._inference_network
        self._length = importance_sample_size
        self._observe = observe
    
    def __len__(self):
        return self._length

    def __getitem__(self, idx):
        if self._model._inference_network:
            return next(self._generator(trace_mode=TraceMode.POSTERIOR,inference_engine = InferenceEngine.IMPORTANCE_SAMPLING_WITH_INFERENCE_NETWORK, inference_network = self._inference_network, observe = self._observe))
        else:
            return next(self._generator(trace_mode=TraceMode.PRIOR))

In [8]:
dataset = Parallel_Generator(simulatorginkgo, importance_sample_size=1000)
dataset[0]

Trace(variables:84, observable:84, observed:17, tagged:0, controlled:2, uncontrolled:65, log_prob:tensor(-4.6032), log_importance_weight:0.0)

In [9]:
dataloader = DataLoader(dataset, num_workers=8, batch_size = None)

for i in dataloader:
    pass

RuntimeError: received 0 items of ancdata

In [13]:
gen_test = simulatorginkgo._trace_generator(trace_mode=TraceMode.PRIOR_FOR_INFERENCE_NETWORK)

In [22]:
for i in range(1000):
    next(gen_test)
    pass

In [28]:
dataset = Parallel_Generator(simulatorginkgo._trace_generator, importance_sample_size=100, inference_engine=True, observe={'dummy':1})

In [33]:
dataset[0]

Posterior Sample


Trace(variables:349, observable:349, observed:71, tagged:0, controlled:278, uncontrolled:0, log_prob:tensor(-128.0414), log_importance_weight:-1.1920928244535389e-07)

In [45]:
gen_test = simulatorginkgo._trace_generator(trace_mode=TraceMode.POSTERIOR, inference_engine=InferenceEngine.IMPORTANCE_SAMPLING_WITH_INFERENCE_NETWORK,observe= {'dummy':1}, inference_network=simulatorginkgo._inference_network)

In [46]:
next(gen_test)

Trace(variables:425, observable:425, observed:86, tagged:0, controlled:339, uncontrolled:0, log_prob:tensor(-299.3549), log_importance_weight:-140.41181634692475)

In [58]:
dataset = Parallel_Generator(simulatorginkgo, importance_sample_size=1000, observe = {'dummy':1})
dataset[0]

Posterior Sample


Trace(variables:314, observable:314, observed:64, tagged:0, controlled:250, uncontrolled:0, log_prob:tensor(-217.5419), log_importance_weight:-98.98258972167969)

In [60]:
for i in range(50):
    next(gen_test)
    pass



In [61]:
for i in range(50):
    dataset[0]
    pass

Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample




Posterior Sample
Posterior Sample
Posterior Sample




Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample




Posterior Sample




Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample




Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample




Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample




In [62]:
dataset = Parallel_Generator(simulatorginkgo, importance_sample_size=50, observe = {'dummy':1})
dataloader = DataLoader(dataset, num_workers=8, batch_size = None)

for i in dataloader:
    pass

Posterior SamplePosterior Sample
Posterior SamplePosterior Sample


Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample


  P = np.sqrt(tp)/2 * np.sqrt( 1 - 2 * (t_child+t_sib)/tp + (t_child - t_sib)**2 / tp**2 )


Posterior Sample




Posterior Sample




Posterior Sample
Posterior Sample




Posterior Sample




Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample




Posterior Sample








Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample




Posterior Sample




Posterior Sample








Posterior Sample




Posterior Sample
Posterior Sample
Posterior Sample




Posterior Sample
Posterior Sample




Posterior Sample




Posterior Sample




Posterior Sample




Posterior Sample




Posterior Sample
Posterior Sample
Posterior Sample




Posterior Sample








Posterior Sample
Posterior Sample
Posterior Sample
Posterior Sample








Posterior Sample




Posterior Sample
Posterior Sample
Posterior Sample




Posterior Sample
Posterior Sample




In [63]:
import time
import multiprocessing as mp

In [73]:
mp.cpu_count()

8

In [72]:
pool = mp.pool(mp.cpu_count())
#next(simulatorginkgo._trace_generator(trace_mode=TraceMode.POSTERIOR,inference_engine = InferenceEngine.IMPORTANCE_SAMPLING_WITH_INFERENCE_NETWORK, inference_network = simulatorginkgo._inference_network, observe = {'dummy':1}))




Trace(variables:394, observable:394, observed:80, tagged:0, controlled:314, uncontrolled:0, log_prob:tensor(-244.1229), log_importance_weight:-107.05960771720856)

In [5]:
test_traces = simulatorginkgo._dis_traces(trace_mode = TraceMode.PRIOR)

NameError: name 'simulatorginkgo' is not defined

In [2]:
simulatorginkgo._dis_traces(trace_mode=TraceMode.PRIOR, num_workers=1)

Time spent  | Time remain.| Progress             | Trace     | ESS    | Traces/sec
0d:00:01:01 | 0d:00:45:46 | -------------------- |  110/5000 | 110.00 | 1.78       

KeyboardInterrupt: 

In [3]:
dataset = Parallel_Generator(simulatorginkgo, importance_sample_size=5000, observe = {'dummy':1})

In [5]:
dataloader = DataLoader(dataset, num_workers = 1, batch_size = None)

In [6]:
for (i,j) in enumerate(dataloader):
    if i < 10:
        print(j.log_prob)
    else: break

tensor(-5.5401)
tensor(11.4845)
tensor(-72.1304)
tensor(-1.7900)
tensor(-73.5113)
tensor(-107.7310)
tensor(-9.1543)
tensor(-110.6435)
tensor(-0.0078)
tensor(-5.7799)
