In [None]:
%matplotlib inline
import bz2
import pickle
import pandas
import numpy as np
import matplotlib.pyplot as plt

import plotly.graph_objects as go
import plotly
import glob

In [None]:
def _track_polylines(Xelem, typ, eta_range, radius_mult):
    """Build (x, y, z) polyline coordinate lists for every element of type ``typ``.

    Each selected row contributes a segment from the origin to (phi, eta) at radius
    ``radius_mult``. When ``phi_ecal`` / ``phi_hcal`` is non-zero, the line is
    extended to that point at 1.5x / 2x the radius. A ``None`` is appended after
    every polyline so plotly draws them as disjoint lines inside one trace.

    Returns:
        Three lists ``(xs, ys, zs)`` suitable for ``go.Scatter3d(mode="lines")``.
    """
    xs, ys, zs = [], [], []
    for _, row in Xelem[Xelem["typ"] == typ].iterrows():
        xs += [0, radius_mult * np.cos(row["phi"])]
        zs += [0, radius_mult * np.sin(row["phi"])]
        ys += [0, eta_range * row["eta"]]
        # A phi of exactly 0 is treated as "no point at this layer" and skipped.
        for layer, scale in (("ecal", 1.5), ("hcal", 2)):
            phi = row["phi_" + layer]
            if phi != 0:
                xs.append(scale * radius_mult * np.cos(phi))
                zs.append(scale * radius_mult * np.sin(phi))
                ys.append(eta_range * row["eta_" + layer])
        xs.append(None)
        zs.append(None)
        ys.append(None)
    return xs, ys, zs


def _element_markers(Xelem, typ, name, radius_scale, marker_size, eta_range, radius_mult):
    """Square-marker scatter trace for every element of type ``typ``.

    Args:
        Xelem: element table with at least ``typ``, ``phi``, ``eta`` and ``e`` columns.
        typ: element type code to select.
        name: legend name of the trace.
        radius_scale: radius of the markers in units of ``radius_mult``.
        marker_size: callable mapping the selected ``e`` Series to marker sizes.
        eta_range: y-axis units per unit of eta.
        radius_mult: base radius in the (x, z) plane.
    """
    sel = Xelem[Xelem["typ"] == typ]
    phi = sel["phi"].values
    energy = sel["e"]
    return go.Scatter3d(
        x=radius_scale * radius_mult * np.cos(phi),
        z=radius_scale * radius_mult * np.sin(phi),
        y=eta_range * sel["eta"].values,
        mode="markers",
        marker={"symbol": "square", "opacity": 0.8, "size": marker_size(energy)},
        name=name,
        hovertemplate="<b>%{hovertext}</b>",
        hovertext=["E={:.2f}".format(e) for e in energy],
    )


def _particle_markers(df, name, radius_scale, symbol, color, eta_range, radius_mult):
    """Scatter trace for the non-empty (``typ != 0``) rows of a particle table.

    ``df`` needs ``typ``, ``e``, ``eta``, ``cos_phi`` and ``sin_phi`` columns. Marker
    size is ``5 * log10(e + 5)`` clipped to [1, 10]. The hover text shows the type
    code, energy, eta and phi, where phi is reconstructed with ``arctan2``.
    """
    sel = df[df["typ"] != 0]
    cos_phi = sel["cos_phi"].values
    sin_phi = sel["sin_phi"].values
    phi = np.arctan2(sin_phi, cos_phi)
    return go.Scatter3d(
        x=radius_scale * radius_mult * cos_phi,
        z=radius_scale * radius_mult * sin_phi,
        y=eta_range * sel["eta"].values,
        mode="markers",
        marker={
            "symbol": symbol,
            "opacity": 0.8,
            "color": color,
            "size": np.clip(5 * np.log10(sel["e"].values + 5.0), 1, 10),
        },
        hovertemplate="<b>%{hovertext}</b>",
        hovertext=[
            "{}<br>E={:.2f}<br>eta={:.2f}<br>phi={:.2f}".format(int(t), e, eta, p)
            for t, e, eta, p in zip(sel["typ"], sel["e"], sel["eta"], phi)
        ],
        name=name,
    )


def visualize(sample, data, iev, trk_opacity=0.8):
    """Render one event as an interactive 3D plotly scene and dump its inputs as HTML.

    In the display, phi is mapped onto the (x, z) plane at a radius that depends on
    the element type. Eta is scaled onto the y axis.

    Writes two files into the current working directory:
      * ``plot_{sample}_{iev}.html``: the 3D event display.
      * ``plot_{sample}_{iev}_data.html``: the full Xelem table, followed by the
        ``typ != 0`` rows of ycand and ygen.

    Args:
        sample: label used only to build the output file names.
        data: indexable collection of events. ``data[iev]`` must provide "Xelem",
            "ycand" and "ygen" entries convertible to DataFrames.
        iev: index of the event to draw.
        trk_opacity: alpha channel of the track / GSF line colour.
    """
    Xelem = pandas.DataFrame(data[iev]["Xelem"])
    ycand = pandas.DataFrame(data[iev]["ycand"])
    ygen = pandas.DataFrame(data[iev]["ygen"])

    eta_range = 1000  # y-axis units per unit of eta
    radius_mult = 2000  # base radius in the (x, z) plane; layers are multiples of it
    line_color = "rgba(10, 10, 10, {})".format(trk_opacity)

    trk_x, trk_y, trk_z = _track_polylines(Xelem, 1, eta_range, radius_mult)
    points_trk = go.Scatter3d(
        x=trk_x,
        y=trk_y,
        z=trk_z,
        mode="lines",
        line=dict(color=line_color),
        name="Tracks",
        hoverinfo="skip",
    )

    # NOTE(review): GSF lines share the track colour; only the legend tells them apart.
    gsf_x, gsf_y, gsf_z = _track_polylines(Xelem, 6, eta_range, radius_mult)
    points_gsf = go.Scatter3d(
        x=gsf_x,
        y=gsf_y,
        z=gsf_z,
        mode="lines",
        line=dict(color=line_color),
        name="GSF",
    )

    def log_size(scale):
        # Compress the energy range so large deposits do not swamp the display.
        return lambda e: scale * np.log10(e + 1.0)

    # (typ code, legend name, radius in units of radius_mult, marker size from energy)
    element_layers = [
        (2, "PS1", 1, lambda e: 1000 * e),
        (3, "PS2", 1, lambda e: 1000 * e),
        (4, "ECAL", 1.5, log_size(10)),
        (5, "HCAL", 2, log_size(5)),
        (8, "HFEM", 2, log_size(5)),
        (9, "HFHAD", 2, log_size(5)),
        (10, "SC", 1.5, log_size(5)),
        (11, "HO", 2.1, log_size(5)),
    ]
    element_traces = [
        _element_markers(Xelem, typ, name, scale, size_fn, eta_range, radius_mult)
        for typ, name, scale, size_fn in element_layers
    ]

    points_cand = _particle_markers(ycand, "PFCand", 2.2, "x", "rgba(0, 0, 0, 0.8)", eta_range, radius_mult)
    points_gen = _particle_markers(ygen, "MLPF truth", 2.5, "circle", "rgba(50, 0, 0, 0.4)", eta_range, radius_mult)

    fig = go.Figure(data=[points_trk, points_gsf, *element_traces, points_cand, points_gen])

    # Same hidden-tick style on all three axes. The range reaches the outermost (gen) radius, 2.5 * radius_mult.
    axis_style = dict(
        nticks=1, range=[-5000, 5000], showaxeslabels=False, showticklabels=False, showgrid=False, visible=True
    )
    fig.update_layout(
        autosize=True,
        scene_camera={"eye": dict(x=0.8, y=0.8, z=0.8)},
        scene={axis: dict(axis_style) for axis in ("xaxis", "yaxis", "zaxis")},
        legend={"itemsizing": "constant"},
    )

    # plotly's full-HTML output declares <meta charset="utf-8">, so write UTF-8
    # explicitly rather than relying on the platform's default encoding.
    with open("plot_{}_{}.html".format(sample, iev), "w", encoding="utf-8") as fi:
        fi.write(fig.to_html(default_width="1200px", default_height="800px"))

    with open("plot_{}_{}_data.html".format(sample, iev), "w", encoding="utf-8") as fi:
        fi.write("X")
        fi.write(Xelem.to_html())

        fi.write("ycand")
        fi.write(ycand[ycand["typ"] != 0].to_html())

        fi.write("ygen")
        fi.write(ygen[ygen["typ"] != 0].to_html())

In [None]:
# for sample in [
#     "SinglePiMinusFlatPt0p7To1000_cfi",
#     "SinglePi0Pt1To1000_pythia8_cfi",
#     "SingleTauFlatPt1To1000_cfi",
#     "SingleElectronFlatPt1To1000_pythia8_cfi",
#     "SingleGammaFlatPt1To1000_pythia8_cfi",
#     "SingleNeutronFlatPt0p7To1000_cfi",
#     "SingleMuFlatLogPt_100MeVto2TeV_cfi",
# ]:
#     filelist = sorted(glob.glob("/hdfs/local/joosep/mlpf/gen/v2/{}/raw/*.pkl.bz2".format(sample)))
#     data = pickle.load(bz2.BZ2File(filelist[0], "r"))
#     for iev in range(0, 10):
#         visualize(sample, data, iev)

In [None]:
# Render the first event of the first input file for each sample.
# NOTE: pickle.load can execute arbitrary code. Only point this at trusted files.
for sample in [
    "TTbar_14TeV_TuneCUETP8M1_cfi",
]:
    pattern = "/local/joosep/mlpf/cms/v3/nopu/{}/raw/*.pkl.bz2".format(sample)
    filelist = sorted(glob.glob(pattern))
    if not filelist:
        # Fail with a message naming the pattern instead of an opaque IndexError.
        raise FileNotFoundError("no input files match {}".format(pattern))
    # Close the compressed stream once the event list has been unpickled.
    with bz2.BZ2File(filelist[0], "r") as fi:
        data = pickle.load(fi)
    for iev in range(0, 1):
        visualize(sample, data, iev, trk_opacity=0.1)