In [None]:
from __future__ import absolute_import, division, print_function

import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (20, 18)

from matplotlib import patches
from matplotlib import animation, rc
import matplotlib.lines as mlines

import ROOT
import fastjet as fj
import fjext
import fjcontrib
import fjtools

import pythia8
import pythiafjext
import pythiaext
from heppy.pythiautils import configuration as pyconf

from tqdm.notebook import tqdm
import argparse
import os
import sys

# standard numerical library imports
import numpy as np
# seeded NumPy random generator (seed 0) for reproducible random draws
rng = np.random.RandomState(0)

# Enlarge the height of scrolled output areas (CSS class div.output_scroll) so
# long outputs and big figures are not squeezed into a small scroll box.
# display/HTML come from IPython.display: importing display from
# IPython.core.display is deprecated since IPython 7.14.
from IPython.display import display, HTML
display(HTML("<style>div.output_scroll { height: 500em; }</style>"))


In [None]:
def get_args_from_settings(ssettings):
    """Parse a command-line style settings string into an argparse namespace.

    The string is split on whitespace and parsed with the standard pythia
    options (added by ``pyconf.add_standard_pythia_args``) plus the
    notebook-specific ``--output`` and ``--user-seed`` options.

    Parameters
    ----------
    ssettings : str
        Whitespace-separated options, e.g. ``"--py-ecm 5000 --nev 1000"``.

    Returns
    -------
    argparse.Namespace
        The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='pythia8 fastjet on the fly')
    pyconf.add_standard_pythia_args(parser)
    parser.add_argument('--output', default="test_ang_ue.root", type=str)
    parser.add_argument('--user-seed', help='pythia seed', default=1111, type=int)
    # Hand the tokens to argparse directly instead of overwriting the kernel's
    # process-global sys.argv, which other code may still read.
    return parser.parse_args(ssettings.split())

In [None]:
# extra configuration handed to the pythia factory (presumably a list of
# PYTHIA setting strings; empty here)
mycfg = list()
# command-line style settings, assembled option by option
ssettings = " ".join(["--py-ecm 5000", "--user-seed=100000", "--nev 1000"])
args = get_args_from_settings(ssettings)
pythia_hard = pyconf.create_and_init_pythia_from_args(args, mycfg)

In [None]:
jet_R0 = 0.4
max_eta_hadron = 2
# final-state hadrons with |eta| up to max_eta_hadron
parts_selector_h = fj.SelectorAbsEtaMax(max_eta_hadron)
# jets with pt in [100, 105] whose cone, widened by a 5% margin, stays inside
# the hadron acceptance
jet_eta_max = max_eta_hadron - 1.05 * jet_R0
jet_pt_window = fj.SelectorPtMin(100.0) & fj.SelectorPtMax(105.0)
jet_selector = jet_pt_window & fj.SelectorAbsEtaMax(jet_eta_max)

In [None]:
# print the banner first
fj.ClusterSequence.print_banner()
print()
# anti-kt definition for the full jets (same jet_R0 value as the selection cell)
jet_R0 = 0.4
jet_def = fj.JetDefinition(fj.antikt_algorithm, jet_R0)
print(jet_def)

# one anti-kt subjet definition per radius in sj_rs (same order as sj_rs)
sj_rs = [0.1, 0.2, 0.3]
sj_defs = []
for r in sj_rs:
    # use the loop radius; previously every definition was built with R=0.1
    _sj_def = fj.JetDefinition(fj.antikt_algorithm, r)
    print(_sj_def)
    sj_defs.append(_sj_def)


In [None]:
def print_jet_constits(j):
    """Print a jet's pt, phi and eta, then the user indices of its constituents."""
    indices = list(map(lambda constit: constit.user_index(), j.constituents()))
    print('[jet]', j.perp(), j.phi(), j.eta(), ':', indices)

In [None]:
def next_event(min_subjets=3):
    """Generate PYTHIA events until one yields a suitable leading jet.

    An event is accepted when both of the following hold:

    * After selecting final-state particles with ``parts_selector_h`` and
      clustering with ``jet_def``, at least one jet passes ``jet_selector``.
    * The leading such jet, re-clustered with every definition in
      ``sj_defs``, gives at least ``min_subjets`` subjets for each definition.

    Parameters
    ----------
    min_subjets : int, optional
        Minimum number of subjets required for every subjet definition
        (default 3, the previously hard-coded value).

    Returns
    -------
    tuple
        ``(j, subjets)``. ``j`` is the leading selected jet. ``subjets`` holds
        one pt-sorted subjet list per entry of ``sj_defs``, in the same order.

    Notes
    -----
    Uses the notebook globals ``pythia_hard``, ``parts_selector_h``,
    ``jet_def``, ``jet_selector`` and ``sj_defs``. Loops until an event is
    accepted; there is no upper bound on the number of attempts.
    """
    while True:
        if not pythia_hard.next():
            # generation failed for this event; try the next one
            continue
        parts_pythia_h = pythiafjext.vectorize_select(pythia_hard, [pythiafjext.kFinal], 0, False)
        parts_pythia_h_selected = parts_selector_h(parts_pythia_h)
        jets_h = fj.sorted_by_pt(jet_selector(jet_def(parts_pythia_h_selected)))
        if len(jets_h) < 1:
            continue
        j = jets_h[0]

        # re-cluster the leading jet's constituents with each subjet definition
        subjets = []
        for sj_def in sj_defs:
            _sjets = fj.sorted_by_pt(sj_def(j.constituents()))
            if len(_sjets) < min_subjets:
                # one definition has too few subjets: reject the whole event
                break
            subjets.append(_sjets)
        else:
            # no break: every definition passed
            return j, subjets

In [None]:
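# NOTE: stale example. next_event() returns (jet, subjets), and jbg / jbgsub
# are not produced anywhere in this notebook.
#j, jbg, jbgsub = next_event()
#print_jet_constits(j)
#print_jet_constits(jbg)
#print_jet_constits(jbgsub)
# Working equivalent:
#j, subjets = next_event()
#print_jet_constits(j)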

In [None]:
def draw_jet(j):
    """Scatter-plot the constituents of jet ``j`` around the jet axis (no subjets).

    Each constituent is drawn at its (Δφ, Δy) offset from the jet. Marker size
    is ``1000 * z`` and colour is ``int(100 * z)`` on the "magma" colormap,
    where ``z = pt / j.perp()``. Two zero-pt dummy points at (±jet_R0, ±jet_R0)
    fix the plotted range to the jet radius.

    NOTE(review): Δφ comes from ``j.delta_phi_to(p)``, which in FastJet is
    φ(p) - φ(j). Δy is computed as y(j) - y(p), so the picture is mirrored in
    y. This is kept as-is to match ``draw_subjets``.

    Parameters
    ----------
    j : fastjet.PseudoJet
        Jet whose constituents are plotted.
    """
    # sort once and reuse, instead of re-sorting for every coordinate list
    constits = fj.sorted_by_pt(j.constituents())
    pts = [p.perp() for p in constits]
    ys = [j.rapidity() - p.rapidity() for p in constits]
    phis = [j.delta_phi_to(p) for p in constits]

    # invisible (pt = 0) corner points pin the axis range to +-jet_R0
    phis.extend([jet_R0, -jet_R0])
    ys.extend([jet_R0, -jet_R0])
    pts.extend([0, 0])
    zs = [pt / j.perp() for pt in pts]
    zs_sized = [z * 1000. for z in zs]
    cs = [int(z * 100.) for z in zs]

    plt.figure()
    plt.scatter(phis, ys, c=cs, s=zs_sized, alpha=0.4, cmap="magma")
    # offsets from the jet axis, labelled as in draw_subjets
    plt.xlabel(r"$\Delta\varphi$")
    plt.ylabel(r"$\Delta y$")
    plt.colorbar()  # show color scale
    plt.show()
    #plt.rcdefaults()

In [None]:
import importlib
import subjets_geometry as sjgeom
importlib.reload(sjgeom)


In [None]:
def _subjet_color(i):
    """Return the RGBA colour (alpha 0.3) for the i-th pt-ordered subjet.

    The three hardest subjets are red, green and blue. Softer subjets are grey
    with level 0.1*(i-2), i.e. lighter with increasing index, capped at 0.9.
    """
    if i == 0:
        return [1, 0, 0, 0.3]
    if i == 1:
        return [0.1, .75, 0.1, 0.3]
    if i == 2:
        return [0, 0, 1, 0.3]
    gr_col = 1.0 - (0.1 * (i - 2))
    if gr_col <= 0.1:
        gr_col = 0.1
    gr_col = 1.0 - gr_col
    return [gr_col, gr_col, gr_col, 0.3]


def draw_subjets(j, scale_pt=0, sj_r=0.1):
    """Plot jet ``j`` decomposed into anti-kt subjets of radius ``sj_r``.

    The constituents of ``j`` are re-clustered with anti-kt (R = ``sj_r``).
    Each constituent is drawn at its (Δφ, Δy) offset from the jet axis:

    * marker size is ``1000 * pt / scale_pt``;
    * colour is that of its subjet (see ``_subjet_color``);
    * a line joins it to its subjet axis.

    Each subjet area is drawn as a ``sjgeom.SubjetPatch``. Before plotting,
    ``clip_to_sj`` is applied to each patch against every harder (earlier)
    subjet's patch. Afterwards the function prints the jet pt, the
    leading-subjet pt and their ratio.

    NOTE(review): Δφ = φ(p) - φ(j) (FastJet ``delta_phi_to``) but
    Δy = y(j) - y(p), so the picture is mirrored in y. This is kept for
    consistency with the rest of the notebook.

    Parameters
    ----------
    j : fastjet.PseudoJet
        Jet to decompose.
    scale_pt : float, optional
        pt used to normalise marker sizes; 0 (default) means ``j.perp()``.
    sj_r : float, optional
        Subjet radius (default 0.1).
    """
    sj_def = fj.JetDefinition(fj.antikt_algorithm, sj_r)
    sjs = fj.sorted_by_pt(sj_def(j.constituents()))

    pts, ys, phis, cs = [], [], [], []
    lines = []    # [sj_dphi, part_dphi, sj_deta, part_deta, rgba] per constituent
    circles = []  # [sj_dphi, sj_deta, sj_r, rgba] per subjet
    for i, sj in enumerate(sjs):
        _col = _subjet_color(i)
        sj_dphi = j.delta_phi_to(sj)
        sj_deta = j.rapidity() - sj.rapidity()
        for p in fj.sorted_by_pt(sj.constituents()):
            part_dphi = j.delta_phi_to(p)
            part_deta = j.rapidity() - p.rapidity()
            pts.append(p.perp())
            ys.append(part_deta)
            phis.append(part_dphi)
            cs.append(_col)
            lines.append([sj_dphi, part_dphi, sj_deta, part_deta, _col])
        circles.append([sj_dphi, sj_deta, sj_r, _col])

    # zero-pt, fully transparent corner points pin the axis range to +-jet_R0
    phis.extend([jet_R0, -jet_R0])
    ys.extend([jet_R0, -jet_R0])
    pts.extend([0, 0])
    cs.extend([[0, 0, 0, 0], [0, 0, 0, 0]])
    if scale_pt == 0:
        scale_pt = j.perp()
    zs = [pt / scale_pt for pt in pts]
    zs_sized = [z * 1000. for z in zs]

    fig, ax = plt.subplots()
    # cs holds explicit RGBA colours, so no colormap is involved (matplotlib
    # ignores, and warns about, a cmap in that case)
    ax.scatter(phis, ys, c=cs, s=zs_sized, alpha=0.4)
    ax.set_xlabel(r"$\Delta\varphi$")
    ax.set_ylabel(r"$\Delta y$")

    # constituent -> subjet-axis connectors, in data coordinates
    for seg in lines:
        _line = mlines.Line2D([seg[0], seg[1]], [seg[2], seg[3]], color=seg[4])
        _line.set_transform(ax.transData)
        ax.add_line(_line)

    # NOTE(review): the third SubjetPatch argument is presumably the jet
    # radius (equal to jet_R0 here); confirm against subjets_geometry.
    c_patches = [sjgeom.SubjetPatch([c[0], c[1]], sj_r, 0.4) for c in circles]
    # clip each patch against all harder (earlier) subjet patches
    for ic, pc in enumerate(c_patches):
        for pcx in c_patches[:ic]:
            pc.clip_to_sj(pcx)
    for pc, c in zip(c_patches, circles):
        pc.plot_patch_ax(ax, c[3], 0)

    ax.set_xlim(-jet_R0, jet_R0)
    ax.set_ylim(-jet_R0, jet_R0)

    plt.show()
    if len(sjs) > 0:
        print(j.perp(), sjs[0].perp(), sjs[0].perp() / j.perp())


In [None]:
# draw five accepted jets, one figure per subjet radius in sj_rs
save_jets = []
for n in range(5):
    j, subjets = next_event()
    save_jets.append(j)
    # marker sizes are normalised to a quarter of the jet pt
    quarter_pt = j.perp() / 4.
    for sj_r in sj_rs:
        draw_subjets(j, quarter_pt, sj_r)


In [None]:
# print the generator's run statistics (Pythia8 Pythia::stat) for pythia_hard
pythia_hard.stat()

In [None]:
# dump each saved jet, followed by its pt-ordered constituents (indented)
with open('nice_jets', 'w') as f:
    for j in save_jets:
        print(j, file=f)
        ordered = fj.sorted_by_pt(j.constituents())
        for c in ordered:
            print(' ', c, file=f)


In [None]:
import importlib
import plot_subjets_utils as sjplt
importlib.reload(sjplt)


In [None]:
save_jets = []
for n in range(1):
    j, subjets = next_event()
    save_jets.append(j)
    # one SubjetPlot per subjet radius, each drawn right after it is built
    _subjet_plots = []
    for radius in (0.1, 0.2, 0.3):
        plot_obj = sjplt.SubjetPlot(j, 0.4, 1000)
        plot_obj.plot(scale_pt=0, sj_r=radius)
        _subjet_plots.append(plot_obj)
    p01, p02, p03 = _subjet_plots


In [None]:
# NOTE: same file name and mode 'w' as the earlier export cell, so this
# replaces that file with the jets collected by the cell above.
with open('nice_jets', 'w') as f:
    for j in save_jets:
        f.write(str(j) + '\n')
        for c in fj.sorted_by_pt(j.constituents()):
            f.write('  ' + str(c) + '\n')

In [None]:
# now this one is for making movies...
# Sweep the subjet radius from sj_r_step up to (but excluding) sj_r_max and
# save one PNG per radius as '<frame index>_subjets.png'.
number_of_frames = 10
sj_r_step = 0.001   # radius increment between consecutive frames
sj_r_max = 0.4      # sweep stops before this radius
# 0.001, 0.002, ..., 0.399 -> 399 frames
n_radius_frames = int(round(sj_r_max / sj_r_step)) - 1
save_jets = []
for n in range(1):
    j, subjets = next_event()
    save_jets.append(j)
    for counter in range(n_radius_frames):
        # Derive the radius from the integer frame index. Repeatedly adding
        # 0.001 drifts in floating point and can add or drop a frame near
        # sj_r_max.
        sj_r = (counter + 1) * sj_r_step
        # NOTE(review): the cells above pass 1000 as the third argument; here
        # it is number_of_frames (10). Confirm this is intended. It does not
        # control how many frames are written (n_radius_frames does).
        p = sjplt.SubjetPlot(j, 0.4, number_of_frames)
        p.plot(scale_pt=0, sj_r=sj_r)
        # NOTE(review): if SubjetPlot.plot calls plt.show(), the inline
        # backend may already have closed the figure and this saves a blank
        # image. Confirm against plot_subjets_utils.
        plt.savefig('{}_subjets.png'.format(counter))
        # close the current figure so hundreds of frames do not pile up in memory
        plt.close()