In [None]:
from itertools import islice

import matplotlib.pyplot as plt
import numpy as np
import awkward as ak

In [None]:
import os
import sys

# Make the particlemind repository importable (for `src.datasets...` below).
# The location can be overridden with the PARTICLEMIND_ROOT environment variable;
# the default is the original author's checkout.
PARTICLEMIND_ROOT = os.environ.get("PARTICLEMIND_ROOT", "/home/joosep/particlemind")
# Only append once, so re-running this cell does not keep growing sys.path.
if PARTICLEMIND_ROOT not in sys.path:
    sys.path.append(PARTICLEMIND_ROOT)

In [None]:
from src.datasets.CLDHits import CLDHits

In [None]:
# Parquet directory holding the CLD hit dataset (presumably ee -> tt at 365 GeV, per the
# directory name -- confirm). Only the "train" split is explored in this notebook.
DATASET_DIR = "../data/p8_ee_tt_ecm365/parquet/"
dataset_train = CLDHits(DATASET_DIR, "train")

In [None]:
# Number of events to load; kept small so the per-cluster loops below stay fast.
N_EVENTS = 100

loaded_events = []
for elem in islice(dataset_train, N_EVENTS):
    # Re-map arbitrary cluster labels onto 0..K-1 (the inverse indices of np.unique),
    # so later cells can use them directly as array indices.
    _, contiguous_labels = np.unique(elem["hit_labels"], return_inverse=True)
    elem["hit_labels_contiguous"] = contiguous_labels
    loaded_events.append(elem)
assert loaded_events, "dataset_train yielded no events"

# Wrap each event as a length-1 awkward array and stack them into one jagged array of events.
elems = ak.concatenate([[ak.from_iter(elem)] for elem in loaded_events], axis=0)

In [None]:
# Contiguous labels run 0..K-1 within each event, so the cluster count is max + 1
# (plotting the bare max under-counts by one). ak.max returns None for events with
# no hits; those events have zero clusters.
n_clusters_per_event = ak.fill_none(ak.max(elems["hit_labels_contiguous"], axis=1) + 1, 0)
plt.hist(n_clusters_per_event, bins=np.linspace(0,400,41));
plt.xlabel("Clusters per event")
plt.ylabel("Event count")

In [None]:
# Drop the event axis: one entry per calorimeter hit, pooled over all loaded events.
calo_hit_features_f, hit_labels_c_f = (
    ak.flatten(elems[field]) for field in ("calo_hit_features", "hit_labels_contiguous")
)

In [None]:
# Overlay the distributions of hit coordinates (feature columns 0, 1, 2 -> x, y, z).
position_bins = np.linspace(-5000, 5000, 100)
for column, coordinate in enumerate("xyz"):
    plt.hist(calo_hit_features_f[:, column], position_bins, histtype="step", lw=2, label=coordinate)
plt.xlabel("Hit position (mm)")
plt.ylabel("Hit count")
plt.legend()

In [None]:
# Hit energy spectrum from feature column 3, on log-spaced bins.
# NOTE(review): column 3 is scaled by 10 before being labelled GeV; the origin of this
# factor is not visible here -- confirm against the CLDHits feature preprocessing.
hit_energy_bins = np.logspace(-3, 1, 100)
plt.hist(calo_hit_features_f[:, 3] * 10, hit_energy_bins)
plt.xscale("log")
plt.xlabel("Hit energy (GeV)")
plt.ylabel("Hit count")

In [None]:
def cluster_summary(labels, features):
    """Per-cluster summary statistics for a single event.

    Vectorised with np.bincount, so the cost is O(n_hits) per event instead of building
    one boolean mask per cluster.

    Parameters
    ----------
    labels : array-like of int, shape (n_hits,)
        Cluster label of every hit.
    features : array-like indexable as features[:, k]
        Per-hit features. Columns 0, 1, 2 are treated as x, y, z and column 3 is summed
        as the cluster energy feature.

    Returns
    -------
    dict of 1D numpy arrays, one entry per distinct label in ascending label order:
    "id", "std_x", "std_y", "std_z", "sum_e", "hit_count". The std values are
    population standard deviations (ddof=0), like np.std.
    """
    labels = np.asarray(labels)
    cluster_id, inverse = np.unique(labels, return_inverse=True)
    inverse = inverse.reshape(-1)  # numpy 2.0.0 briefly returned the input's shape here
    n_clusters = len(cluster_id)
    # Every id comes from a label that is present, so no cluster has zero hits.
    hit_count = np.bincount(inverse, minlength=n_clusters)

    def column(k):
        return np.asarray(features[:, k], dtype=np.float64)

    def std_per_cluster(values):
        # Two-pass: per-cluster mean first, then mean squared deviation from it.
        mean = np.bincount(inverse, weights=values, minlength=n_clusters) / hit_count
        sq_dev = (values - mean[inverse]) ** 2
        return np.sqrt(np.bincount(inverse, weights=sq_dev, minlength=n_clusters) / hit_count)

    return {
        "id": cluster_id,
        "std_x": std_per_cluster(column(0)),
        "std_y": std_per_cluster(column(1)),
        "std_z": std_per_cluster(column(2)),
        "sum_e": np.bincount(inverse, weights=column(3), minlength=n_clusters),
        "hit_count": hit_count,
    }


cluster_summaries = [
    cluster_summary(elem["hit_labels_contiguous"], elem["calo_hit_features"]) for elem in elems
]

# One jagged awkward array per statistic: outer axis = event, inner axis = cluster.
# Converting through Python lists gives the same layout as building the lists by hand.
all_cluster_std_x = ak.Array([s["std_x"].tolist() for s in cluster_summaries])
all_cluster_std_y = ak.Array([s["std_y"].tolist() for s in cluster_summaries])
all_cluster_std_z = ak.Array([s["std_z"].tolist() for s in cluster_summaries])
all_cluster_sum_e = ak.Array([s["sum_e"].tolist() for s in cluster_summaries])
all_cluster_hit_count = ak.Array([s["hit_count"].tolist() for s in cluster_summaries])
all_cluster_id = ak.Array([s["id"].tolist() for s in cluster_summaries])

In [None]:
# Spread of hit x positions vs. cluster size, for clusters with more than 5 hits.
large_clusters = all_cluster_hit_count > 5
sizes_x = ak.to_numpy(ak.flatten(all_cluster_hit_count[large_clusters]))
spreads_x = ak.to_numpy(ak.flatten(all_cluster_std_x[large_clusters]))
size_bins = np.logspace(0, 3, 100)
spread_bins = np.logspace(-2, 4, 100)
plt.hist2d(sizes_x, spreads_x, bins=(size_bins, spread_bins))
plt.xscale("log")
plt.yscale("log")
plt.xlabel("Hits per cluster")
plt.ylabel("Hit pos x stddev")

In [None]:
# Spread of hit y positions vs. cluster size. Uses the same "more than 5 hits" selection
# as the x and z panels so the three coordinate spreads are directly comparable.
min_hits_mask = all_cluster_hit_count > 5
plt.hist2d(
    ak.to_numpy(ak.flatten(all_cluster_hit_count[min_hits_mask])),
    ak.to_numpy(ak.flatten(all_cluster_std_y[min_hits_mask])),
    bins=(np.logspace(0,3,100), np.logspace(-2,4,100))
)
plt.xscale("log")
plt.yscale("log")
plt.xlabel("Hits per cluster")
plt.ylabel("Hit pos y stddev")

In [None]:
# Spread of hit z positions vs. cluster size, for clusters with more than 5 hits.
z_selection = all_cluster_hit_count > 5
z_sizes = ak.to_numpy(ak.flatten(all_cluster_hit_count[z_selection]))
z_spreads = ak.to_numpy(ak.flatten(all_cluster_std_z[z_selection]))
plt.hist2d(z_sizes, z_spreads, bins=(np.logspace(0, 3, 100), np.logspace(-2, 4, 100)))
plt.xscale("log")
plt.yscale("log")
plt.xlabel("Hits per cluster")
plt.ylabel("Hit pos z stddev")

In [None]:
# Summed energy feature per cluster vs. cluster size (all clusters, no hit-count cut).
hits_per_cluster = ak.to_numpy(ak.flatten(all_cluster_hit_count))
sum_e_per_cluster = ak.to_numpy(ak.flatten(all_cluster_sum_e))
hit_count_log_bins = np.logspace(0, 3, 100)
sum_e_log_bins = np.logspace(-2, 3, 100)
plt.figure(figsize=(5, 5))
plt.hist2d(hits_per_cluster, sum_e_per_cluster, bins=(hit_count_log_bins, sum_e_log_bins))
plt.xscale("log")
plt.yscale("log")
plt.xlabel("Hits per cluster")
plt.ylabel("Sum energy per cluster")

In [None]:
# Cluster-size distribution pooled over all events; log y-axis exposes the long tail.
flat_cluster_sizes = ak.flatten(all_cluster_hit_count)
cluster_size_bins = np.linspace(0, 1500, 100)
plt.hist(flat_cluster_sizes, bins=cluster_size_bins);
plt.yscale("log")
plt.xlabel("Number of hits per cluster")
plt.ylabel("Cluster count")

In [None]:
# Event displays: x-y view of the hits of the first few events, coloured by cluster and
# sized by 100 * feature column 3 (clipped to [0.1, 10] so every hit stays visible).
cmap = plt.get_cmap('viridis')
fig, axs = plt.subplots(3,3, figsize=(10,10))
axs = axs.flatten()
# Guard against loading fewer events than there are panels.
n_displayed = min(len(axs), len(elems))
for ax, elem in zip(axs, elems[:n_displayed]):
    # Reuse the contiguous labels computed at load time instead of recomputing np.unique.
    contiguous_labels = np.asarray(elem["hit_labels_contiguous"])
    n_clusters = len(np.unique(contiguous_labels))
    distinct_colors = cmap(np.linspace(0, 1, n_clusters))

    ax.scatter(
        elem["calo_hit_features"][:, 0],
        elem["calo_hit_features"][:, 1],
        s=np.clip(100*elem["calo_hit_features"][:, 3], 0.1, 10),
        c=distinct_colors[contiguous_labels])
    ax.set_xlim(-6000, 6000)
    ax.set_ylim(-6000, 6000)
    ax.set_title("$N_{{hit}}$={}, $N_{{cl}}$={}".format(len(elem["calo_hit_features"]), n_clusters))
    ax.set_xticks([])
    ax.set_yticks([])
# Hide any panels left without an event.
for ax in axs[n_displayed:]:
    ax.set_axis_off()