In [1]:
%load_ext autoreload
%autoreload 2

import os
os.environ['KMP_WARNINGS'] = 'off'
import sys
import git

import uproot as ut
import awkward as ak
import numpy as np
import math
import vector
import sympy as sp

import re
from tqdm import tqdm
import timeit
import re

sys.path.append( git.Repo('.', search_parent_directories=True).working_tree_dir )
from utils import *

In [2]:
%load_ext line_profiler
%load_ext memory_profiler

In [3]:
from utils.FeynNet.Feynman import Feynman

In [4]:
diagram = Feynman('x').decays(
    Feynman('t').decays(
        Feynman('b').decays('j'),
        Feynman('w').decays('j','j')
    ),
    Feynman('t').decays(
        Feynman('b').decays('j'),
        Feynman('w').decays('j','j')
    )
).build_diagram()

In [5]:
def factorial(n):
    if n == 0:
        return 1
    return n * factorial(n - 1)

In [6]:
from utils.FeynNet.FeynNet import FeynNet

In [66]:
import itertools

In [16]:
f(j=10)

6


KeyboardInterrupt: 

In [11]:
def f(**nfinalstates):
    diagram._permutation_cache_.clear()
    return diagram.get_finalstate_permutations(**nfinalstates)
%timeit f(j=6)
%timeit f(j=7)
%timeit f(j=8)


74.8 ms ± 2.17 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
176 ms ± 4.22 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
838 ms ± 26.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
%timeit f(j=9)
%timeit f(j=10)

7.76 s ± 38.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


KeyboardInterrupt: 

In [79]:
single_permutations = f()


In [85]:
def f():
    diagram._permutation_cache_.clear()
    return diagram.get_finalstate_permutations(**nfinalstates, n_jobs=-1)
%timeit f()
# parallel_permutations = f()

3 s ± 206 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
assert np.allclose(single_permutations['j'], parallel_permutations['j'])
assert np.allclose(single_permutations['b'], parallel_permutations['b'])


In [14]:
np.allclose(trueMap, paraMap)

True

In [17]:
permiter = itertools.permutations(np.arange(10))

class SliceIterator:
    def __init__(self, it, njobs, ijob):
        self.it = it
        self.njobs = njobs
        self.ijob = ijob
    def __iter__(self):
        for i, p in enumerate(self.it):
            if i%self.njobs == self.ijob:
                yield i, p

def slice_iter(it, njobs=2):
    its = itertools.tee(it, njobs)
    return [ SliceIterator(it, njobs, i) for i, it in enumerate(its) ]

def get_permutations(diagram, permit):
    permMap = {}
    for iperm ,perm in permit:
        perm = perm[:6]
        reco_id = diagram.get_reco_id(j=perm)
        if reco_id not in permMap:
            permMap[reco_id] = (iperm, perm)
    return permMap

njobs = 5
iter_slices = slice_iter(permiter, njobs=njobs)

from multiprocessing import Pool

class Worker:
    @staticmethod 
    def start(worker):
        return worker()

    def __init__(self, diagram, permit):
        self.diagram = diagram
        self.permit = permit

    def __call__(self):
        self.result = get_permutations(self.diagram, self.permit)
        return self

with Pool(njobs) as pool:
    workers = list( tqdm( pool.starmap(get_permutations, ((diagram, permit) for permit in iter_slices)), total=njobs) )

100%|██████████| 5/5 [00:00<00:00, 17175.69it/s]


In [18]:
workers

[{-785540713582192026: (0, (0, 1, 2, 3, 4, 5)),
  5459826315071848245: (25, (0, 1, 2, 3, 4, 6)),
  7639047042285513462: (50, (0, 1, 2, 3, 4, 7)),
  -203112063285740678: (75, (0, 1, 2, 3, 4, 8)),
  -6364507362123037761: (100, (0, 1, 2, 3, 4, 9)),
  4070019417317376963: (145, (0, 1, 2, 3, 5, 6)),
  8832964218323503357: (170, (0, 1, 2, 3, 5, 7)),
  9181182885108433504: (195, (0, 1, 2, 3, 5, 8)),
  -8219171639012249673: (220, (0, 1, 2, 3, 5, 9)),
  2436597756508557713: (290, (0, 1, 2, 3, 6, 7)),
  2623716134459799649: (315, (0, 1, 2, 3, 6, 8)),
  5522312379146233137: (340, (0, 1, 2, 3, 6, 9)),
  -2498076612794249943: (435, (0, 1, 2, 3, 7, 8)),
  7820759771392396245: (460, (0, 1, 2, 3, 7, 9)),
  4880638220450592878: (580, (0, 1, 2, 3, 8, 9)),
  5478145358627943706: (720, (0, 1, 2, 4, 3, 5)),
  8925968446663913157: (745, (0, 1, 2, 4, 3, 6)),
  5670844194188701195: (770, (0, 1, 2, 4, 3, 7)),
  6391498434043397251: (795, (0, 1, 2, 4, 3, 8)),
  5539268978315546835: (820, (0, 1, 2, 4, 3, 9)),
  

In [26]:
from collections import defaultdict 

def merge_permMaps(permMaps):
    master_permMap = defaultdict(list)
    for permMap in permMaps:
        for reco_id, perm in permMap.items():
            master_permMap[reco_id].append(perm)

    perms = [ min(perms,key=lambda kv:kv[0])  for perms in master_permMap.values() ]
    perms = sorted(perms, key=lambda kv:kv[0])
    return np.array([ perm[1] for perm in perms ])

In [27]:
master_permMap = merge_permMaps([ worker.result for worker in workers ])

In [28]:
master_permMap.shape

(18900, 6)

In [29]:
np.all(master_permMap == trueMap)

True

In [17]:
feynnet = FeynNet(diagram=diagram,nfinalstates=dict(j=8))

In [78]:
jet_x = torch.rand(1024, 32, 10)

In [79]:
%timeit feynnet(j=jet_x)

9.67 s ± 430 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
