In [None]:
# Imports
import uproot
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, InMemoryDataset, DataLoader
from torch.utils.data import Subset
from torch_geometric.nn import GCNConv, global_mean_pool, BatchNorm, GINEConv
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    confusion_matrix, precision_score, recall_score, f1_score, 
    roc_auc_score, average_precision_score, precision_recall_curve, roc_curve,
    classification_report, auc
)
from sklearn.utils import resample
import math
import random
import seaborn as sns
from itertools import product
import os
import copy

# --- Compute device ---------------------------------------------------------
# Prefer a CUDA GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device: {}".format(device))

# --- Reproducibility --------------------------------------------------------
# Seed Python's `random`, NumPy's legacy global RNG and PyTorch with one shared
# value (kept at 42 so results stay comparable with earlier runs).
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Show complete cell contents (e.g. long hit lists) when printing DataFrames.
pd.set_option('display.max_colwidth', None)

### Load data


In [None]:
# Open the four Mu3e ROOT files, list their trees/branches and load the
# branches needed later into pandas DataFrames.
#   - segs      : MC triplet information; used to find the individual hits in
#                 a frame, which are later turned into graphs.
#   - frames    : per-frame info, loaded for the signal file only.
#   - mc_tracks : tracks found by the standard reconstruction algorithm
#                 (used only for the final efficiency comparison).
# Datasets: signal, internal conversion (IC), Michel and beam.
import uproot
import numpy as np
import pandas as pd

# File paths
root_file_path   = 'mu3e_5E4_signal_reco.root'
root_file_path2  = 'mu3e_reco_2E4_alltrk_IC.root'
root_file_path3  = 'mu3e_reco_5E4_beam.root'
root_file_path4  = 'mu3e_reco_5E4_alltrk_michel.root'

# Open files
# NOTE(review): the files are never closed explicitly; they stay open for the
# lifetime of the kernel.
file_signal = uproot.open(root_file_path)
file_IC     = uproot.open(root_file_path2)
file_beam   = uproot.open(root_file_path3)
file_michel = uproot.open(root_file_path4)

# Print available trees/classes
print(f"Signal file classes:   {file_signal.classnames()}")
print(f"IC     file classes:   {file_IC.classnames()}")
print(f"Beam   file classes:   {file_beam.classnames()}")
print(f"Michel file classes:   {file_michel.classnames()}\n")

# Select trees from each dataset.
# The ROOT cycle numbers (";10", ";6", ...) are hard-coded per file; check them
# against the classnames() output above if a file is regenerated.
segs10     = file_signal['segs;10'] # signal
frames1    = file_signal['frames;1']
mc_tracks1 = file_signal['mc_tracks;1']

segsIC     = file_IC['segs;6'] # internal conversion
mc_tracksIC = file_IC['mc_tracks;1']

segsbeam   = file_beam['segs;24'] # beam
mc_tracksbeam = file_beam['mc_tracks;2']

segsmichel = file_michel['segs;4'] # Michel
mc_tracksM = file_michel['mc_tracks;1']

# Print available branches
print("Branches in 'segs;10' (signal):      ", segs10.keys())
print("Branches in 'frames;1' (signal):     ", frames1.keys())
print("Branches in 'mc_tracks;1' (signal):  ", mc_tracks1.keys(), "\n")

# Define which segs branches to load
segs_branches = [
    'x00','x10','x20', # triplet hit positions (x,y,z)
    'y00','y10','y20',
    'z00','z10','z20',
    'frameId','mc_tid','mc_pid','mc_type','mc_p','mc_pt',
    'mc_phi','mc_lam','mc_theta'
    # frameId  – frame number; usually runs up to a round number, e.g. 50,000 or 20,000.
    # mc_tid   – track id: all hits of one MC track share the same value.
    # mc_pid   – particle id: 11 for electron, -11 for positron.
    # mc_type  – the decay the particle originates from (92 is mu -> eee;
    #            the meaning of the other values is not yet known – TODO confirm).
    # mc_p     – momentum p.
    # mc_pt    – transverse momentum pT.
    # mc_phi   – azimuthal angle phi about the z axis.
    # mc_lam   – pitch angle lambda, measured from the x-y plane towards +z.
    # mc_theta – the theta and phi histograms look identical but mirrored on the x axis.
    # Note: hits with tid, type, p, pT, ... = 0 are noise.
    #       A noise hit can coincide with a real MC hit, so any code that
    #       collects all hits in a frame must keep the MC info when available.
]
# x0, y0, z0 is the position of the very first hit, needed for first_hit
# flagging when finding 8-hit graphs.
frames_branches = ['frameId','mc_tid','x0','y0','z0','t0']
tracks_branches = [
    'runId','frameId','n','type','pid','tid','nhit','nfb','nfbc',
    'ntl','n3','n4','n6','n8','p','pt','phi','lam','theta'
    # n6  – whether an MC track (id) has at least 6 hits.
    # n8  – whether an MC track (id) has at least 8 hits.
    #       n4, n6, n8 are usually 1 or 0, but can be >1 if the track recurls.
    # ntl – whether the particle reaches the scintillating tile, i.e. it is a
    #       6-hit track in the forward recurl layers.
    # Each row of mc_tracks corresponds to one frameId.
    # Used for the final efficiency comparison of standard reco against the GNN.
]

# Load the segs DataFrames
segs10_data      = segs10.arrays(segs_branches,   library='pd')
frames1_data     = frames1.arrays(frames_branches, library='pd')

segsIC_data      = segsIC.arrays(segs_branches,   library='pd')
segsmichel_data  = segsmichel.arrays(segs_branches, library='pd')

segsbeam_data    = segsbeam.arrays(segs_branches, library='pd')
# Limit the beam sample to its 4000 lowest distinct frameIds.
first4k = np.sort(segsbeam_data['frameId'].unique())[:4000]
segsbeam_data    = segsbeam_data[segsbeam_data['frameId'].isin(first4k)].reset_index(drop=True)

# Now load the mc_tracks DataFrames (branches listed in tracks_branches)
mc_tracks1_data    = mc_tracks1.arrays(tracks_branches, library='pd')
mc_tracksIC_data   = mc_tracksIC.arrays(tracks_branches, library='pd')
mc_tracksbeam_data = mc_tracksbeam.arrays(tracks_branches, library='pd')
# Apply the *same* first4k frame list to the beam mc_tracks so both beam
# tables cover identical frames.
mc_tracksbeam_data = mc_tracksbeam_data[
    mc_tracksbeam_data['frameId'].isin(first4k)].reset_index(drop=True)
mc_tracksM_data    = mc_tracksM.arrays(tracks_branches, library='pd')

# Peek at each DataFrame (signal)
print("segs10_data.head():\n", segs10_data.head(), "\n")
print("frames1_data.head():\n", frames1_data.head(), "\n")
print("mc_tracks1_data.head():\n", mc_tracks1_data.head(15), "\n")

# frameId range of each segs sample
print("Lowest frameId in segs10_data:", segs10_data['frameId'].min(), "to highest", segs10_data['frameId'].max())
print("Lowest frameId in segsIC_data:", segsIC_data['frameId'].min(), "to highest", segsIC_data['frameId'].max())
print("Lowest frameId in segsbeam_data:", segsbeam_data['frameId'].min(), "to highest", segsbeam_data['frameId'].max())
print("Lowest frameId in segsmichel_data:", segsmichel_data['frameId'].min(), "to highest", segsmichel_data['frameId'].max())

In [None]:
# Save to CSV the tracks found by the standard reconstruction.
# For each mc_tracks DataFrame this writes the unique (mc_tid, frameId) pairs
# of MC tracks with n6 == 1, n8 == 0 and ntl > 0.
import numpy as np
import pandas as pd

# Per-track columns of mc_tracks: each row (one frame) holds one entry per MC
# track in these columns, so they are exploded together.
_list_cols = ['tid', 'n6', 'n8', 'ntl']


def _as_list(x):
    """Return ``x`` as a Python list; scalars (incl. NaN and NumPy scalars) become ``[x]``."""
    if isinstance(x, list):
        return x
    if isinstance(x, tuple):
        return list(x)
    # NumPy scalars have .tolist() but are not iterable, so test for real
    # iterability / dimensionality rather than for a tolist attribute.
    if isinstance(x, (str, bytes)) or not hasattr(x, "__iter__") or np.ndim(x) == 0:
        return [x]
    return list(x)


def explode_and_save(df: pd.DataFrame, out_csv: str) -> None:
    """Explode per-frame track lists and save the selected standard-reco tracks.

    Parameters
    ----------
    df : pd.DataFrame
        mc_tracks table with a scalar ``frameId`` column and list-like (or
        scalar) ``tid``, ``n6``, ``n8`` and ``ntl`` columns.
    out_csv : str
        Path of the CSV written, with integer columns ``mc_tid`` and ``frameId``.

    Raises
    ------
    KeyError
        If ``df`` lacks ``frameId`` or any of the per-track columns.
    """
    missing = [c for c in (*_list_cols, 'frameId') if c not in df.columns]
    if missing:
        raise KeyError(f"explode_and_save: missing column(s) {missing}")

    # 1) Copy (with a fresh, unique index) and turn every per-track cell into
    #    a plain Python list.
    df2 = df.copy().reset_index(drop=True)
    for col in _list_cols:
        df2[col] = df2[col].apply(_as_list)

    # 2) Pair up the i-th entries of the per-track lists: one tuple per MC
    #    track. NOTE: zip() truncates to the shortest list of the row, so
    #    surplus entries in longer lists are dropped (nothing is padded).
    zipped = [list(zip(*cells)) for cells in zip(*(df2[c] for c in _list_cols))]
    df2['_zipped'] = pd.Series(zipped, index=df2.index, dtype=object)

    # 3) One row per MC track; frames without tracks explode to NaN and are
    #    dropped (they could never pass the selection anyway).
    df3 = (
        df2.explode('_zipped')
        .dropna(subset=['_zipped'])
        .reset_index(drop=True)
    )

    # 4) Unpack the tuples back into the original per-track columns.
    for i, col in enumerate(_list_cols):
        df3[col] = [t[i] for t in df3['_zipped']]
    df3 = df3.drop(columns=['_zipped'])

    # 5) Standard-reco selection.
    sel = (
        df3.loc[
            (df3['n6'] == 1) &
            (df3['n8'] == 0) &
            (df3['ntl'] > 0),
            ['tid', 'frameId']
        ]
        .drop_duplicates()
        .rename(columns={'tid': 'mc_tid'})
        .astype({'mc_tid': int, 'frameId': int})
    )

    sel.to_csv(out_csv, index=False)
    print(f"→ wrote {len(sel)} rows to {out_csv}")

# Write one CSV per dataset: signal, internal conversion, beam and Michel.
for _mc_tracks, _csv_path in (
    (mc_tracks1_data,    "stdreco_6hit_tracks_signal1.csv"),
    (mc_tracksIC_data,   "stdreco_6hit_tracks_IC1.csv"),
    (mc_tracksbeam_data, "stdreco_6hit_tracks_beam1.csv"),
    (mc_tracksM_data,    "stdreco_6hit_tracks_michel1.csv"),
):
    explode_and_save(_mc_tracks, _csv_path)

hits_df_unique contains the unique hits that are used to generate graphs. 

tracks_df contains the truth tracks determined before any track generation (they are built from triplets, so each has at least three hits).

sixhit_tracks_df is the subset of these truth tracks that follow the layer sequence 1 → 2 → 3 → 4 → 4± → 3±, where ± denotes the forward recurl layers far out in the ±z direction.
These are exactly the tracks that the 6-hit graph-generating algorithm attempts to find.

real_tracks_df and fake_tracks_df come from the 'validation' (constraints) algorithm. A built track that follows the same layer sequence as above is 'real' if all of its hits share the same track id, and 'fake' if they do not.
Both are taken from validated_tracks_df, which contains all tracks that pass the constraints.

### Find truth tracks for later comparison/plots/plotly 3D view

In [None]:
# Build the truth tracks of one dataset and keep the six-hit recurl subset.
#
# tracks_df        : every truth track that can be assembled from the segs
#                    triplets (so each has at least three hits); tracks whose
#                    mc_tid is 0 (noise) are dropped here.
# sixhit_tracks_df : the tracks with exactly six hits whose layer sequence is
#                    1,2,3,4,4±,3± (± = forward recurl layer far out in ±z).
#
# To process another dataset, replace segs10_data below with segsIC_data,
# segsbeam_data or segsmichel_data, and update the output file names.

import numpy as np
import pandas as pd
from truetrackbuilding import assign_layer1, explode_triplets_to_tracks

tracks_df = explode_triplets_to_tracks(segs10_data)
tracks_df = tracks_df[tracks_df['mc_tid'].ne(0)].copy()
print("Built {:,} total tracks (length ≥3 hits).".format(len(tracks_df)))
tracks_df["layer_sequence"] = tracks_df["hits"].apply(
    lambda track_hits: list(map(assign_layer1, track_hits))
)

# Accepted orderings: "+" or "-" recurl, starting from the very first hit.
desired_sequences = [
    ['1', '2', '3', '4', '4+', '3+'],
    ['1', '2', '3', '4', '4-', '3-'],
]

# Require exactly six hits: a longer track can contain a matching
# sub-sequence (e.g. due to multiple scattering) but is not a six-hit track.
sixhit_candidates = tracks_df[tracks_df["hits"].apply(len).eq(6)].copy()
sixhit_candidates = sixhit_candidates[
    sixhit_candidates["layer_sequence"].apply(lambda layers: layers in desired_sequences)
].reset_index(drop=True)

sixhit_tracks_df = sixhit_candidates.loc[
    :, ["frameId", "mc_tid", "mc_pid", "mc_type", "mc_p", "mc_pt", "layer_sequence", "hits"]
].copy()

print("\nFinal Filtered 6hit Tracks:")
print(sixhit_tracks_df.head())
print("Built {:,} six-hit recurl tracks.".format(len(sixhit_tracks_df)))

# Output names encode the dataset; change them together with the input above.
torch.save(tracks_df, "true_tracks_eee_May31.pt")
torch.save(sixhit_tracks_df, "sixhit_tracks_eee_May31.pt")

In [None]:
# Unique (mc_tid, frameId) pairs of the truth six-hit tracks, kept for the
# final six-hit efficiency comparison. Noise (mc_tid == 0) is excluded.
import pandas as pd

filtered_df = sixhit_tracks_df.loc[sixhit_tracks_df['mc_tid'].ne(0)]
unique_df = filtered_df.loc[:, ['mc_tid', 'frameId']].drop_duplicates()
# To store the list for the dataset being processed, uncomment:
# unique_df.to_csv('true6hitmctidsframeidsMichel.csv', index=False)

In [None]:
# Find the loosest constraint values that still 'validate' every real (truth)
# six-hit track. Each constraint quantity is histogrammed over the truth
# six-hit tracks, and the cut values currently in use are drawn as red dashed
# lines. The z12/z34 ratio is very broad, so that constraint needs to be
# looser or replaced with something else entirely.
import numpy as np
import matplotlib.pyplot as plt
from build6hittracks import collect_sixhit_distributions

# The file holds a pickled DataFrame (written with torch.save in the
# truth-track cell), not tensors. PyTorch >= 2.6 defaults to
# torch.load(weights_only=True), which refuses to unpickle arbitrary objects,
# so weights_only=False is required. Unpickling can execute arbitrary code:
# only load files you produced yourself.
sixhit_tracks_df = torch.load("sixhit_tracks_eee_May31.pt", weights_only=False)

# Dict of per-constraint value lists measured on the truth six-hit tracks.
constraints = collect_sixhit_distributions(sixhit_tracks_df)

# Pull each list out of the dictionary
d12_list         = constraints["d12"]
d34_list         = constraints["d34"]
d56_list         = constraints["d56"]
zratio56_list    = constraints["zratio56"]
zratio12_list    = constraints["zratio12"]
center_diff_list = constraints["center_diff"]
radius_diff_list = constraints["radius_diff"]
pitch_diff_list  = constraints["pitch_diff"]

# Axis label -> values to histogram.
constraints_for_plot = {
    'N$_1$–N$_2$ Transverse Distance [mm]':      d12_list,
    'N$_3$–N$_4$ Transverse Distance [mm]':      d34_list,
    'N$_5$–N$_6$ Transverse Distance [mm]':      d56_list,
    'Min. Centre Transverse Difference [mm]':    center_diff_list,
    'Radius Tolerance [mm]':                     radius_diff_list,
    'Min. Pitch Difference [mm/rad]':            pitch_diff_list,
    'z$_{5-6}$/z$_{3-4}$ [%]':                   zratio56_list,
    'z$_{1-2}$/z$_{3-4}$ [%]':                   zratio12_list,
}

# Axis label -> (xmin, xmax) histogram range.
x_ranges = {
    'N$_1$–N$_2$ Transverse Distance [mm]':   (0, 30),
    'N$_3$–N$_4$ Transverse Distance [mm]':   (0, 60),
    'N$_5$–N$_6$ Transverse Distance [mm]':   (0, 60),
    'Min. Centre Transverse Difference [mm]': (0, 100),
    'Radius Tolerance [mm]':                  (0, 50),
    'Min. Pitch Difference [mm/rad]':         (0, 40),
    'z$_{5-6}$/z$_{3-4}$ [%]':                (0, 200),
    'z$_{1-2}$/z$_{3-4}$ [%]':                (0, 150)
}

for label, data in constraints_for_plot.items():
    # dtype=float turns None entries into NaN so that isfinite() works.
    arr = np.asarray(data, dtype=float)
    arr = arr[np.isfinite(arr)]   # drop NaN and ±inf

    plt.figure(figsize=(4.5, 3))
    if label in x_ranges:
        lo, hi = x_ranges[label]
        plt.hist(arr, bins=50, range=(lo, hi),
                 histtype='step', color='k', linewidth=1.5)
        plt.xlim(lo, hi)
    else:
        plt.hist(arr, bins=50, histtype='step', color='k', linewidth=1.5)

    # Red dashed lines mark the cut values currently in use.
    if label == 'Min. Centre Transverse Difference [mm]':
        plt.axvline(50, color='red', linestyle='--', linewidth=1.5,
                    label='set to 50 mm')
        plt.legend(fontsize=11)

    # z56/z34 window: 100% ± 65% => [35%, 165%]
    if label == 'z$_{5-6}$/z$_{3-4}$ [%]':
        plt.axvline(35, color='red', linestyle='--', linewidth=1.5,
                    label='set to ±65%')
        plt.axvline(165, color='red', linestyle='--', linewidth=1.5)
        plt.legend(fontsize=11, framealpha=0.3)

    # z12/z34 window: 40% × (1 ± 0.65) => [14%, 66%]
    if label == 'z$_{1-2}$/z$_{3-4}$ [%]':
        plt.axvline(14, color='red', linestyle='--', linewidth=1.5,
                    label='set to 40%×(1−0.65) = 14%')
        plt.axvline(66, color='red', linestyle='--', linewidth=1.5,
                    label='set to 40%×(1+0.65) = 66%')
        plt.legend(fontsize=11, framealpha=0.3)

    plt.xlabel(label, fontsize=14)
    plt.ylabel("Frequency", fontsize=14)
    plt.yscale('log')
    plt.tick_params(labelsize=12)
    plt.grid(linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.show()

In [None]:
# Inspect every truth track built for a single frame.
from IPython.display import HTML

frame_id = 4202
tracks_in_frame = tracks_df.loc[tracks_df['frameId'] == frame_id]
print(tracks_in_frame[['mc_tid', 'layer_sequence', 'hits']])
# Render the full table inside a fixed-height, scrollable box.
HTML(f'<div style="height:300px; overflow:auto">{tracks_in_frame.to_html()}</div>')

In [None]:
# Find the pT range where multiple-scattering effects are minimised, i.e.
# where the recurl hits occur near Δφ = π ± tol (cf. Fig. 2.4 in the TDR).
import numpy as np
from circlefitanalysis import recurl_angle

tol = 0.5             # half-width [rad] of the accepted |Δφ| window around π
pt_anomaly_cut = 50   # [MeV] recurling tracks above this pT are printed below

# Recurl angle Δφ of every truth track, as returned by recurl_angle().
dphis = [recurl_angle(hits) for hits in tracks_df['hits']]
# One boolean per row of tracks_df: |Δφ| lies within tol of π.
mask = np.array([abs(abs(dphi) - np.pi) < tol for dphi in dphis], dtype=bool)

tracks_df['dphi_recurl'] = dphis
recurling = tracks_df[mask]

print(f"Found {len(recurling)} tracks with Δφ≈π ± {tol} rad")

# Drop missing pT: np.histogram cannot auto-range data that contains NaN.
pt_recurl = tracks_df.loc[mask, 'mc_pt'].dropna()

counts, bin_edges = np.histogram(pt_recurl, bins=25)
bin_centers = 0.5*(bin_edges[:-1] + bin_edges[1:])
errors = np.sqrt(counts)   # Poisson errors

plt.figure(figsize=(6,4.5))
plt.hist(pt_recurl,
         bins=bin_edges,
         histtype='step',
         linewidth=1.4,
         color='navy',
         label='Recurling Tracks')
plt.errorbar(bin_centers, counts,
             yerr=errors,
             fmt='none',
             ecolor='black',
             capsize=3)

plt.xlabel(r"$p_{\mathrm{T,true}}$ [MeV]", fontsize=14)
# The window half-width in the label follows `tol`.
plt.ylabel(rf"Number of Tracks ($\Delta\phi \approx \pi \pm {tol}\,$rad)", fontsize=14)
plt.tick_params(axis="both", labelsize=12)
plt.grid(linestyle="--", alpha=0.4)
plt.legend()
plt.show()

# NOTE(review): the high-pT tail is unexpected. In the frame viewer these
# tracks seem to have the same bending radius as low-pT tracks, so either the
# true momentum is mis-assigned or it is quoted before energy loss and only
# some final hits survive a filter. TODO: investigate with the plotly view in
# the next cell.
for _, row in tracks_df[mask & (tracks_df['mc_pt'] > pt_anomaly_cut)].head(5).iterrows():
    print(f"Frame {row['frameId']}, mc_tid = {row['mc_tid']}, pT = {row['mc_pt']:.1f} MeV")

In [None]:
# Interactive 3-D (plotly) view of the truth tracks in one frame.
# Example: frame 473, track id 67924 has a large true pT (51 MeV) but appears
# to have a small bending radius (just large enough to reach layer 4).
from plotlyhelper import visualise_truth_tracks, wireframe_traces
target_frame_id = 473  # Replace with desired frame id
# Draws the tracks of target_frame_id from tracks_df, plus the detector wireframe.
visualise_truth_tracks(target_frame_id, tracks_df, wireframe_traces)

In [None]:
# Double-hit probability vs true pT for six-hit recurl tracks.
#
# For each (frameId, mc_tid), the 'full' track is a six-hit track that follows
# one of the recurl layer sequences. An 'ambiguous' track has the same mc_tid
# but is NOT six hits long: it follows the same trajectory, but it picked up a
# different hit in some layer (overlapping ladders -> double hits), so it does
# not complete a six-hit track. Layers where the two tracks hold different
# hits are recorded in full_info / amb_info and phi_by_layer.
#
# The plot shows, per pT bin, the fraction of true six-hit tracks with at
# least one double hit. This is not the same as the probability of producing
# a duplicate six-hit graph for an mc_tid, which also depends on the number of
# hits found per layer.
import numbers
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Two recurl sequences (same definition as in the truth-track cell).
desired_sequences = [
    ['1', '2', '3', '4', '4+', '3+'],
    ['1', '2', '3', '4', '4-', '3-']
]


def _is_recurl_sequence(seq):
    """True if a layer sequence (list of layer labels) is one of desired_sequences."""
    return seq in desired_sequences


# Containers:
phi_by_layer = defaultdict(lambda: {'full': [], 'amb': []})
full_info = []
amb_info  = []

# Single pass over all (frameId, mc_tid) groups
for (frame_id, mc_tid), group in tracks_df.groupby(['frameId', 'mc_tid']):
    n_hits = group['hits'].apply(len)
    # Series.isin() works by hashing its targets, and layer sequences are
    # (unhashable) lists, so membership is tested explicitly, as elsewhere.
    full = group.loc[n_hits.eq(6) & group['layer_sequence'].apply(_is_recurl_sequence)]
    amb  = group.loc[n_hits.ne(6)]

    if full.empty or amb.empty:
        continue

    # Compare the first full track with the first ambiguous track.
    track_full = full.iloc[0]
    track_amb  = amb.iloc[0]

    # layer -> hit (if a layer repeats, its last hit is kept)
    full_hits = dict(zip(track_full['layer_sequence'], track_full['hits']))
    amb_hits  = dict(zip(track_amb['layer_sequence'], track_amb['hits']))

    # Metadata shared by all records of this (frameId, mc_tid)
    base_meta = {
        'frameId': frame_id,
        'mc_tid':  mc_tid,
        'mc_pid':  track_full.get('mc_pid',  None),
        'mc_type': track_full.get('mc_type', None),
        'mc_p':    track_full.get('mc_p',    None),
        'mc_pt':   track_full.get('mc_pt',   None),
    }

    # Layers present in both tracks, in the full track's layer order
    # (deterministic, unlike iterating over a set intersection).
    for layer in full_hits:
        if layer not in amb_hits:
            continue
        hf = np.asarray(full_hits[layer])
        ha = np.asarray(amb_hits[layer])

        # Same hit in both tracks (within tolerance): not a double hit.
        if np.allclose(hf, ha, atol=1e-6):
            continue

        # Azimuthal angle of each hit from its (x, y) components.
        phi_f = np.arctan2(hf[1], hf[0])
        phi_a = np.arctan2(ha[1], ha[0])

        phi_by_layer[layer]['full'].append(phi_f)
        phi_by_layer[layer]['amb'].append(phi_a)

        meta = {**base_meta, 'layer': layer}
        full_info.append({**meta, 'hit': hf, 'phi': phi_f})
        amb_info.append({**meta, 'hit': ha, 'phi': phi_a})


def _clean_tracks(df):
    """Keep rows with a non-zero pT, an integer non-zero mc_tid and a definite mc_type."""
    df = df.dropna(subset=['mc_pt'])
    df = df[df['mc_pt'] != 0]
    # numbers.Integral also accepts NumPy integer scalars (plain `int` does not).
    df = df.loc[[bool(isinstance(x, numbers.Integral) and x != 0) for x in df['mc_tid']]]
    return df[~df['mc_type'].isin(['Multiple', 'Unknown'])]


def _double_hit_probability(all_pt, dup_pt, bins):
    """Per-bin histograms, double-hit fraction and its binomial error.

    Returns (counts_all, counts_dup, probability, error); bins with no tracks
    get probability 0 and error 0.
    """
    counts_all, _ = np.histogram(all_pt, bins=bins)
    counts_dup, _ = np.histogram(dup_pt, bins=bins)
    probability = np.divide(counts_dup, counts_all,
                            out=np.zeros(len(counts_all), dtype=float),
                            where=counts_all != 0)
    error = np.zeros_like(probability)
    filled = counts_all > 0
    error[filled] = np.sqrt(probability[filled] * (1 - probability[filled]) / counts_all[filled])
    return counts_all, counts_dup, probability, error


# --- Denominator: true six-hit recurl tracks, one per (frameId, mc_tid) ---
df_all = tracks_df[tracks_df['hits'].apply(len) == 6]
df_all = df_all[df_all['layer_sequence'].apply(_is_recurl_sequence)]
df_all = _clean_tracks(df_all).drop_duplicates(subset=['frameId', 'mc_tid'])

# --- Numerator: tracks with at least one double hit, one per (frameId, mc_tid).
# amb_info has one record per differing layer; counting records would count a
# track several times and could push the "probability" above 1.
df_dup = pd.DataFrame(
    amb_info,
    columns=['frameId', 'mc_tid', 'mc_pid', 'mc_type', 'mc_p', 'mc_pt', 'layer', 'hit', 'phi'],
)
df_dup = _clean_tracks(df_dup).drop_duplicates(subset=['frameId', 'mc_tid'])

# --- Split into Electrons and Positrons ---
df_all_e = df_all[df_all['mc_pid'] == 11]
df_all_p = df_all[df_all['mc_pid'] == -11]
print(len(df_all_e), 'e-', len(df_all_p), 'e+')
df_dup_e = df_dup[df_dup['mc_pid'] == 11]
df_dup_p = df_dup[df_dup['mc_pid'] == -11]

# --- Common pT bins spanning the six-hit electron and positron tracks ---
all_six_pt = pd.concat([df_all_e['mc_pt'], df_all_p['mc_pt']])
pt_min, pt_max = all_six_pt.min(), all_six_pt.max()
bins = np.linspace(pt_min, pt_max, 31)  # 31 edges -> 30 bins
bin_centers = 0.5 * (bins[:-1] + bins[1:])

counts_all_e, counts_dup_e, probability_e, err_e = _double_hit_probability(
    df_all_e['mc_pt'], df_dup_e['mc_pt'], bins)
counts_all_p, counts_dup_p, probability_p, err_p = _double_hit_probability(
    df_all_p['mc_pt'], df_dup_p['mc_pt'], bins)

# --- Plot the Overlapping Probability Distributions ---
plt.figure(figsize=(6, 4))
plt.errorbar(bin_centers, probability_e, yerr=err_e, fmt='o', color='blue', capsize=3, label='Electrons')
plt.errorbar(bin_centers, probability_p, yerr=err_p, fmt='o', color='red', capsize=3, label='Positrons')

plt.xlabel(r"$p_{T}^{\rm true}$ [MeV]", fontsize=14)
plt.ylabel("Double Hit Probability", fontsize=14)
plt.tick_params(axis='both', labelsize=12)
plt.grid(True, linestyle='--', alpha=0.7)
plt.legend(fontsize=12, loc='upper center')
plt.show()

# To list double-hit tracks above 50 MeV:
# df_dup_highpt = df_dup[df_dup['mc_pt'] > 50]
# print(df_dup_highpt.head(5))

In [None]:
# plotly visualisation: interactive 3-D view of the truth tracks in one frame.
# visualise_truth_tracks and wireframe_traces come from plotlyhelper, which is
# imported in an earlier plotly cell; run that cell first.
target_frame_id = 878  # Replace with desired frame id
visualise_truth_tracks(target_frame_id, tracks_df, wireframe_traces)

In [None]:
# Average number of hits per detector layer per frame; the error bars are the
# standard error of the mean over frames.
#
# Needs `hits_df_unique`, which is built further down in the notebook (the
# graph-generation cell). Run that cell first, or load a saved copy, e.g.
#   hits_df_unique = torch.load('unique_hits_eee.pt', weights_only=False)
# (saved files: unique_hits_eee, unique_hits_eeevv, unique_hits_evv, unique_hits_beam)
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# Step 1: hits per (frame, layer)
hits_per_frame_layer = hits_df_unique.groupby(['frameId', 'layer']).size().reset_index(name='hit_count')
fmin = hits_per_frame_layer['frameId'].min()
fmax = hits_per_frame_layer['frameId'].max()

# Step 2: frames as rows, layers as columns. Every frame in [fmin, fmax] is
# included, and every missing (frame, layer) cell counts as 0 hits. This covers
# both frames with no hits at all (reindex) and frames that only have hits in
# other layers (fillna). Without the fillna, the per-layer mean would only
# average over frames with at least one hit in that layer.
all_frames = pd.Index(range(fmin, fmax + 1), name='frameId')
hits_pivot = (
    hits_per_frame_layer
    .pivot(index='frameId', columns='layer', values='hit_count')
    .reindex(all_frames)
    .fillna(0)
)

# Step 3: mean and standard deviation per layer
mean_hits = hits_pivot.mean()
std_hits = hits_pivot.std()
N = len(hits_pivot)  # number of frames

# Step 4: standard error of the mean, sqrt(s^2 / N)
sem_hits = std_hits / np.sqrt(N)

# Step 5: plotting
plt.figure(figsize=(4.5, 4))
bar_container = plt.bar(
    mean_hits.index,
    mean_hits.values,
    yerr=sem_hits.values,
    capsize=5,
    color=sns.color_palette('magma', len(mean_hits)),
)

# Value label centred on top of each bar.
for bar in bar_container:
    height = bar.get_height()
    plt.text(
        bar.get_x() + bar.get_width() / 2,
        height,
        f'{height:.2f}',
        ha='center',
        va='bottom',
        fontsize=12,
        color='black'
    )

plt.xlabel('Detector Layer', fontsize=14)
plt.ylabel('Average Number of Hits per Frame', fontsize=13)
# Leave headroom above the tallest bar + error (NaN SEMs, e.g. N == 1, ignored).
plt.ylim(0, np.nanmax(mean_hits.values + sem_hits.values) * 1.1)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.tick_params(axis='both', labelsize=13)
plt.grid(axis='y', linestyle='-', linewidth=0.5, alpha=0.3)
plt.show()

In [None]:
# Number of truth six-hit recurl tracks, and of all ("long", >= 3-hit) truth
# tracks, per frame. Frames without such a track count as 0. Frames with more
# than MAX_TRACKS_BIN tracks are added to the last (overflow) bin.
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

MAX_TRACKS_BIN = 4  # last bin of the per-frame multiplicity plots

# Full frame-ID range of the dataset in use. segs10_data is the signal sample;
# swap it for segsIC_data, segsbeam_data or segsmichel_data to match tracks_df.
min_id = int(segs10_data['frameId'].min())
max_id = int(segs10_data['frameId'].max())
full_frame_index = np.arange(min_id, max_id + 1)  # every frameId from min to max


def _count_frames_per_multiplicity(per_frame, max_bin=MAX_TRACKS_BIN):
    """Number of frames with k = 0..max_bin tracks; frames with more go into k = max_bin."""
    return (
        per_frame.clip(upper=max_bin)
        .value_counts()
        .reindex(range(max_bin + 1), fill_value=0)
    )


def _plot_multiplicity(counts, errors, edges, centers, tick_labels,
                       color, label, xlabel, ymax=None):
    """Step histogram of frames per track multiplicity with error bars."""
    plt.figure(figsize=(3.6, 4))
    plt.stairs(counts.values, edges, fill=False, color=color, linewidth=1.5, label=label)
    plt.errorbar(centers, counts.values, yerr=errors, fmt='none', color=color, capsize=3)
    plt.xlabel(xlabel, fontsize=14)
    plt.ylabel('Number of Frames', fontsize=14)
    plt.tick_params(axis='both', labelsize=13)
    plt.grid(axis='both', linestyle='-', linewidth=0.5, alpha=0.3)
    if ymax is not None:
        plt.ylim(0, ymax)
    plt.tight_layout()
    plt.xticks(centers, tick_labels)
    plt.show()


# Unique non-zero mc_tids per frame, reindexed over every frame (0 if absent).
tracks_per_frame = (
    sixhit_tracks_df[sixhit_tracks_df['mc_tid'] != 0]
    .groupby('frameId')['mc_tid']
    .nunique()
)
tracks_per_frame_full = tracks_per_frame.reindex(full_frame_index, fill_value=0)

unf_tracks_per_frame = (
    tracks_df[tracks_df['mc_tid'] != 0]
    .groupby('frameId')['mc_tid']
    .nunique()
)
unf_tracks_per_frame_full = unf_tracks_per_frame.reindex(full_frame_index, fill_value=0)

# Frames per multiplicity k = 0..MAX_TRACKS_BIN, with overflow in the last bin.
tracks_counts = _count_frames_per_multiplicity(tracks_per_frame_full)
unf_tracks_counts = _count_frames_per_multiplicity(unf_tracks_per_frame_full)

# Unit-width bins centred on the integers 0..MAX_TRACKS_BIN.
bin_edges   = np.arange(-0.5, MAX_TRACKS_BIN + 1)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0
tick_labels = [str(k) for k in range(MAX_TRACKS_BIN)] + [f'≥{MAX_TRACKS_BIN}']

# Poisson errors = sqrt(count)
errors_tracks = np.sqrt(tracks_counts.values)
errors_unf   = np.sqrt(unf_tracks_counts.values)

# The y-limits were tuned by hand for the signal sample; pass None to autoscale.
_plot_multiplicity(tracks_counts, errors_tracks, bin_edges, bin_centers, tick_labels,
                   color='green', label='Six-Hit Recurls',
                   xlabel='Number of Six-Hit \n Tracks per Frame', ymax=31200)
_plot_multiplicity(unf_tracks_counts, errors_unf, bin_edges, bin_centers, tick_labels,
                   color='orange', label='Long Tracks',
                   xlabel='Number of Long \n Tracks per Frame', ymax=16100)

## Graph Generation from Hits (triplets)

In [None]:
# Build the unique-hit tables from the triplets of one dataset.
#   hits_df_unique  : every distinct hit, INCLUDING noise hits (mc_tid, mc_p,
#                     ... = 0); graphs are generated from it and it is used in plots.
#   hits_df_unique2 : the same hits with noise removed; used for plots only.

from flatten_hits import apply_layer_assignment, dedupe_hits, flatten_hits

# Dataset to process: segs10_data, segsIC_data, segsbeam_data or segsmichel_data.
flat_hits_df = flatten_hits(segsbeam_data)
hits_df_unique = apply_layer_assignment(dedupe_hits(flat_hits_df))
hits_df_unique2 = hits_df_unique.loc[hits_df_unique['mc_tid'] != 0].reset_index(drop=True)

print("Total Hits After Flattening: {}".format(len(flat_hits_df)))
print("\nDeduplicated Hits DataFrame (hits_df_unique):")
print(hits_df_unique.head())
print("Total Unique Hits in hits_df_unique: {}".format(len(hits_df_unique)))
print("Total Unique Hits in hits_df_unique2: {}".format(len(hits_df_unique2)))

# Saved per dataset as unique_hits_eee / _eeevv / _evv / _beam (.pt);
# the noise-free variant as unique_hits2_<dataset>.pt.
torch.save(hits_df_unique, 'unique_hits_beam.pt')
# torch.save(hits_df_unique2, 'unique_hits2_eee.pt')

In [None]:
# overlapping number of electrons seen per frame
# Overlaid distributions of the number of e-/e+ (distinct mc_tids) seen per
# frame for the four samples, built from the noise-free unique-hit tables.
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# --- load pre-saved DataFrames without noise ---
# These .pt files are pickled DataFrames, not tensors, so torch.load needs
# weights_only=False (PyTorch >= 2.6 defaults to weights_only=True and refuses
# to unpickle them). Unpickling can execute arbitrary code: only load files
# you produced yourself.
df_eee    = torch.load("unique_hits2_eee.pt",   weights_only=False)
df_eeevv  = torch.load("unique_hits2_eeevv.pt", weights_only=False)
df_evv    = torch.load("unique_hits2_evv.pt",   weights_only=False)
df_beam   = torch.load("unique_hits2_beam.pt",  weights_only=False)

def get_multiplicity(df):
    """Histogram of the number of distinct mc_tids per frame.

    Every frame between the lowest and highest ``frameId`` present in ``df``
    is counted; frames without hits count as 0. Returns a Series indexed by
    k = 0..max whose values are the number of frames with exactly k mc_tids.
    """
    tids_per_frame = df.groupby('frameId')['mc_tid'].nunique()
    first_frame = tids_per_frame.index.min()
    last_frame = tids_per_frame.index.max()
    tids_per_frame = tids_per_frame.reindex(range(first_frame, last_frame + 1), fill_value=0)

    frames_with_k = tids_per_frame.value_counts().sort_index()
    # Fill in multiplicities that never occur, from 0 up to the largest seen.
    k_axis = range(frames_with_k.index.max() + 1)
    return frames_with_k.reindex(k_axis, fill_value=0)

# Multiplicity histogram of each sample, then padded to one common k axis.
mult_eee, mult_eeevv, mult_evv, mult_beam = (
    get_multiplicity(sample) for sample in (df_eee, df_eeevv, df_evv, df_beam)
)

max_k = max(m.index.max() for m in (mult_eee, mult_eeevv, mult_evv, mult_beam))
mult_eee, mult_eeevv, mult_evv, mult_beam = (
    hist.reindex(range(max_k + 1), fill_value=0)
    for hist in (mult_eee, mult_eeevv, mult_evv, mult_beam)
)

def normalize_and_extend(counts):
    """Normalise a count histogram to unit area and prepare it for a step plot.

    Parameters
    ----------
    counts : Series of non-negative counts.

    Returns
    -------
    (norm, err, ext) where norm = counts / total, err = sqrt(counts) / total
    (Poisson error) and ext is norm with its last value repeated, so that
    plt.step(edges, ext, where='post') also draws the final bin.
    """
    raw = counts.values
    n_total = counts.sum()
    fractions = raw / n_total
    errors = np.sqrt(raw) / n_total
    return fractions, errors, np.concatenate([fractions, fractions[-1:]])

(n_eee, e_eee, x_eee), (n_eeevv, e_eeevv, x_eeevv), (n_evv, e_evv, x_evv), (n_beam, e_beam, x_beam) = [
    normalize_and_extend(hist) for hist in (mult_eee, mult_eeevv, mult_evv, mult_beam)
]

# x axis: half-integer edges for the step outline, integer k for the error bars
x_edges = np.arange(-0.5, max_k + 1.5, 1.0)
x_centers = np.arange(max_k + 1)

multiplicity_curves = (
    ('Signal', 'blue',   n_eee,   e_eee,   x_eee),
    ('I.C.',   'orange', n_eeevv, e_eeevv, x_eeevv),
    ('Michel', 'red',    n_evv,   e_evv,   x_evv),
    ('Beam',   'green',  n_beam,  e_beam,  x_beam),
)

plt.figure(figsize=(4, 4))
# step outlines first (they carry the legend labels) ...
for label, colour, density, error, extended in multiplicity_curves:
    plt.step(x_edges, extended, where='post', label=label, color=colour, lw=1.5)
# ... then Poisson error bars at each integer k
for label, colour, density, error, extended in multiplicity_curves:
    plt.errorbar(x_centers, density, yerr=error, fmt='none', ecolor=colour, capsize=3)

plt.xlabel('Number of $e^-/e^+$ per Frame', fontsize=14)
plt.ylabel('Frequency Density', fontsize=14)
# switch to plt.yscale('log') to inspect the high-multiplicity tail
plt.xticks(x_centers, fontsize=12)
plt.yticks(fontsize=12)
plt.xlim(-0.5, 11 + 0.5)
plt.legend(fontsize=11)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()

In [None]:
# number of hits per frame overlapping
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Noise-free unique-hit tables, loaded in the order signal, internal conversion, Michel, beam.
df_eee, df_eeevv, df_evv, df_beam = (
    torch.load(path)
    for path in (
        "unique_hits2_eee.pt",
        "unique_hits2_eeevv.pt",
        "unique_hits2_evv.pt",
        "unique_hits2_beam.pt",
    )
)

# --- number of frames spanned by a sample, inferred from its frameId range ---
def infer_total_frames(df):
    """Return max(frameId) - min(frameId) + 1, i.e. the count of frame IDs in the observed range."""
    frame_ids = df['frameId']
    return int(frame_ids.max()) - int(frame_ids.min()) + 1

frame_sources = (
    ("Signal", df_eee),
    ("Internal Conversion", df_eeevv),
    ("Michel", df_evv),
    ("Beam", df_beam),
)
total_frames = {name: infer_total_frames(frame_df) for name, frame_df in frame_sources}

print("Inferred total frames per sample:", total_frames)

def get_hits_per_frame_with_zeros(df, total, first_frame=None):
    """Hits per frame for `total` consecutive frames, with empty frames included as 0.

    Slot i of the result holds the number of rows of df whose frameId equals
    ``first_frame + i``. Rows whose frame falls outside [first_frame, first_frame + total)
    are ignored.

    Parameters
    ----------
    df : DataFrame with a 'frameId' column (one row per hit).
    total : number of frames (length of the returned array).
    first_frame : frame ID mapped to slot 0. Defaults to the smallest frameId in df,
        which matches sizing ``total`` as max(frameId) - min(frameId) + 1.

    Returns
    -------
    numpy int array of length ``total``.
    """
    hits = np.zeros(total, dtype=int)
    if df.empty:
        return hits
    if first_frame is None:
        first_frame = int(df['frameId'].min())
    counts = df.groupby("frameId").size()
    # Index relative to the first frame: using raw frame IDs as array indices drops every
    # frame >= total and shifts the rest whenever the first frame ID is not 0
    # (and negative IDs would wrap around to the end of the array).
    slots = counts.index.to_numpy(dtype=np.int64) - first_frame
    in_range = (slots >= 0) & (slots < total)
    hits[slots[in_range]] = counts.to_numpy()[in_range]
    return hits

# Hits-per-frame arrays (empty frames included) for each sample
hits_eee, hits_eeevv, hits_evv, hits_beam = (
    get_hits_per_frame_with_zeros(sample, total_frames[name])
    for sample, name in (
        (df_eee, "Signal"),
        (df_eeevv, "Internal Conversion"),
        (df_evv, "Michel"),
        (df_beam, "Beam"),
    )
)

# 16 equal bins of width 2 spanning -0.5 .. 31.5 (hit counts 0-31); shared by the helper below.
bins = np.linspace(-0.5, 31.5, 17)
width = bins[1] - bins[0]
centers = (bins[:-1] + bins[1:]) / 2


def compute_density_and_error(arr, total):
    """Histogram per-frame hit counts into `bins`, normalised by frames x bin width.

    Parameters
    ----------
    arr : array-like of hits per frame.
    total : number of frames used for the normalisation.

    Returns
    -------
    (density, err) : arrays of length len(bins). The final entry repeats the last
        bin so they can go straight into plt.step(bins, ..., where="post");
        err is the Poisson sqrt(N) error scaled like the density.
    """
    counts = np.histogram(arr, bins=bins)[0]
    norm = total * width
    density = counts / norm
    err = np.sqrt(counts) / norm
    return np.append(density, density[-1]), np.append(err, err[-1])

(den_eee, err_eee), (den_eeevv, err_eeevv), (den_evv, err_evv), (den_beam, err_beam) = [
    compute_density_and_error(per_frame, total_frames[name])
    for per_frame, name in (
        (hits_eee, "Signal"),
        (hits_eeevv, "Internal Conversion"),
        (hits_evv, "Michel"),
        (hits_beam, "Beam"),
    )
]

hit_curves = (
    (den_eee, err_eee, "blue", "Signal"),
    (den_eeevv, err_eeevv, "orange", "I.C."),
    (den_evv, err_evv, "red", "Michel"),
    (den_beam, err_beam, "green", "Beam"),
)

plt.figure(figsize=(5, 4))

# step outlines over the full bin edges (each density carries one repeated trailing value)
for density, error, colour, label in hit_curves:
    plt.step(bins, density, where="post", color=colour, lw=1.5, label=label)

# error bars at the bin centres, skipping the repeated trailing value
for density, error, colour, label in hit_curves:
    plt.errorbar(centers, density[:-1], yerr=error[:-1], fmt="none", ecolor=colour, capsize=3)

plt.xlabel("Number of Hits per Frame", fontsize=14)
plt.ylabel("Frequency Density", fontsize=14)
plt.xlim(bins[0], bins[-1])
plt.tick_params(labelsize=12)
plt.grid(axis="y", linestyle="--", alpha=0.4)
plt.legend(fontsize=11)
plt.show()

In [None]:
# Count how many 6-hit graphs could be formed in every frame (all layer combinations).
# Works for any of the four samples: swap the file for 'unique_hits_eee', 'unique_hits_eeevv',
# 'unique_hits_evv' or 'unique_hits_beam'.
hits_df_unique = torch.load('unique_hits_eee.pt')
import numpy as np
import matplotlib.pyplot as plt

# 1) Per frame, count every hit combination that follows one of the two layer sequences
#    1-2-3-4-4+-3+ or 1-2-3-4-4--3-.
frames = hits_df_unique.groupby('frameId')

total_graphs = 0
frame_totals = []
max_frame = None
max_count = -1

layer_names = ('1', '2', '3', '4', '4+', '3+', '4-', '3-')

for frame_id, group in frames:
    per_layer = group['layer'].value_counts()
    hits_in = {name: int(per_layer.get(name, 0)) for name in layer_names}
    # layers 1-4 are shared by both trajectory hypotheses
    common = hits_in['1'] * hits_in['2'] * hits_in['3'] * hits_in['4']
    frame_total = common * hits_in['4+'] * hits_in['3+'] + common * hits_in['4-'] * hits_in['3-']

    frame_totals.append(frame_total)
    total_graphs += frame_total
    if frame_total > max_count:
        max_frame, max_count = frame_id, frame_total

print(f"Total possible graphs across all frames: {total_graphs}")
print(f"Frame with most possible graphs: {max_frame} → {max_count} trajectories")

# 2) Frames with no hits never appear in the groupby above. Treat every frame ID between the
#    smallest and largest observed one as existing, and count the ones that are absent.
unique_frames = hits_df_unique['frameId'].unique()
min_f, max_f = unique_frames.min(), unique_frames.max()
total_frames = max_f - min_f + 1
n_seen = len(unique_frames)
num_empty = total_frames - n_seen

print(f"Frame IDs range from {min_f} to {max_f} → {total_frames} total frames")
# n_seen counts frames with at least one *hit*; some of those still have 0 possible graphs.
print(f"{n_seen} frames with ≥1 hit, so {num_empty} empty frames")

# 3) Combine per-frame counts with zeros for the hit-less frames
all_totals = frame_totals + [0] * num_empty

# 4) Full-range histogram of possible graphs per frame, log y-axis, Poisson sqrt(N) error bars.
counts, bin_edges = np.histogram(all_totals, bins=35)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
errors = np.sqrt(counts)

fig, ax = plt.subplots(figsize=(4.5, 3))
ax.hist(all_totals, bins=35, histtype='step', linewidth=1.4, edgecolor='black')
ax.errorbar(bin_centers, counts, yerr=errors, fmt='none', ecolor='black', capsize=3)
ax.set_yscale('log')
ax.set_xlabel('Number of Possible Graphs per Frame', fontsize=14)
ax.set_ylabel('Number of Frames', fontsize=14)
ax.tick_params(axis='both', labelsize=12)
# move the x-axis scientific-notation offset label to the right of the axis
offset = ax.xaxis.get_offset_text()
offset.set_position((1.1, 0))
offset.set_ha('center')
ax.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()

# 5) Zoom on 0-50 possible graphs per frame (linear y-axis).
counts_z, bin_edges_z = np.histogram(all_totals, bins=35, range=(0, 50))
bin_centers_z = (bin_edges_z[:-1] + bin_edges_z[1:]) / 2
errors_z = np.sqrt(counts_z)

fig, ax = plt.subplots(figsize=(4.5, 3))
ax.hist(all_totals, bins=35, range=(0, 50), histtype='step', linewidth=1.4, edgecolor='black')
ax.errorbar(bin_centers_z, counts_z, yerr=errors_z, fmt='none', ecolor='black', capsize=3)
ax.set_xlabel('Number of Possible Graphs per Frame', fontsize=14)
ax.set_ylabel('Number of Frames', fontsize=14)
ax.tick_params(axis='both', labelsize=12)
ax.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()

In [None]:
# 6hit only; All Frame Execution: Building and Validating Tracks
from build6hittracks import build_and_validate_tracks
import pandas as pd
import torch 
from tqdm import tqdm  # For progress tracking

# Sample tag shared by the input hits file and the output tracks file, so the two names
# cannot drift apart. One of: 'eee' (signal), 'eeevv' (internal conversion), 'evv' (Michel), 'beam'.
DATASET = 'beam'

hits_df_unique = torch.load(f'unique_hits_{DATASET}.pt')

# Build and validate 6-hit track candidates frame by frame.
all_tracks = []
for frame_id in tqdm(hits_df_unique['frameId'].unique()):
    validated = build_and_validate_tracks(
        hits_df_unique,
        frame_id,
        layer_sequences = [
            ['1','2','3','4','4+','3+'],
            ['1','2','3','4','4-','3-']
        ],
        center_tolerance=50, # mm
        radius_tolerance=50, # mm
        pitch_tolerance=40, # mm/rad
        distance_constraints={'1-2':30,'3-4':60,'5-6':60}, # mm
        z_distance_error_margin=0.65
        # look at the plots showing the true value for these constraints for 6hit graphs if you must adjust
        # the bad constraint with arbitrary value 0.4 for zratio z12/z34 is found in build6hittracks.py; def build_and_validate_tracks
    )
    all_tracks.extend(validated)

print(f"\nTotal Number of Validated Tracks: {len(all_tracks)}")
validated_tracks_df = pd.DataFrame(all_tracks)
torch.save(validated_tracks_df, f"validated_tracks_{DATASET}_May31.pt")

In [None]:
# histogram of LOG number of post-validation graphs per frame
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

hits_df_unique = torch.load('unique_hits_eee.pt')
validated_tracks_df = torch.load('validated_tracks_eee_May31.pt')

# --- 1) Build per-frame totals -------------------------------------
# One entry per frame ID in the observed hit range [min_f, max_f]; frames without any
# validated track count as 0. Reindex on the actual frame IDs: np.arange(total_frames)
# would assume the first ID is 0 and drop / misplace frames otherwise.
unique_frames = hits_df_unique['frameId'].unique()
min_f, max_f = unique_frames.min(), unique_frames.max()
total_frames = max_f - min_f + 1
frame_counts = (
    validated_tracks_df
      .groupby('frameId')
      .size()
      .reindex(np.arange(min_f, max_f + 1), fill_value=0)
      .values
)

# --- 2) Histogram + Poisson sqrt(N) errors ---------------------------
counts, bin_edges = np.histogram(frame_counts, bins=35)
bin_centers      = 0.5 * (bin_edges[:-1] + bin_edges[1:])
errors           = np.sqrt(counts)

# --- 3) Full-range (log-y) counts plot --------------------------------
plt.figure(figsize=(4.5,3))
plt.hist(frame_counts,
         bins=35,
         histtype='step',
         edgecolor='navy',
         linewidth=1.4)
plt.errorbar(bin_centers, counts,
             yerr=errors,
             fmt='none',
             ecolor='black',
             capsize=3,
             elinewidth=1)
plt.tick_params(axis='both', labelsize=12)
plt.xlabel('Number of Generated Graphs per Frame', fontsize=13)
plt.ylabel('Counts',                  fontsize=14)
plt.yscale('log')
plt.grid(axis='y', linestyle='--', alpha=0.7)
ax     = plt.gca()
# move the x-axis scientific-notation offset label to the right of the axis
offset = ax.xaxis.get_offset_text()
offset.set_position((1.1, 0))
offset.set_ha('center')
plt.show()

### Track to graph

In [None]:
# Find Real and Fakes based on Track IDs
# looks up the mc/truth info from the hit ids. ensures all hits share the same non-zero track id in order to be labelled 'real'
# real tracks have meaningful helical/momentum information, so this information is also saved.
# i recently saw that i was accepting any track with matching mc_tids with any number of noise hits as real. this is fixed now;
# it only resulted in a shift (real to fake) of a few hundred graphs out of the 110,000 saved.
import torch; import pandas as pd
from determinerealorfake import all_hits_zero_tid, determine_track_tid, determine_track_type, determine_track_pt, angle_from_map

hits_df_unique        = torch.load("unique_hits_beam.pt")
validated_tracks_df   = torch.load("validated_tracks_beam_May31.pt")

# 1) Lookup maps hit_id -> mc_tid, hit_id -> mc_type and hit_id -> {mc_phi, mc_theta, mc_lam}.
#    Index by hit_id once and reuse it for all three maps.
hits_by_id      = hits_df_unique.set_index("hit_id")
hit_to_tid_map  = hits_by_id["mc_tid"].to_dict()
hit_to_type_map = hits_by_id["mc_type"].to_dict()
hit_angle_map   = hits_by_id[["mc_phi","mc_theta","mc_lam"]].to_dict("index")

# 2) Drop candidates for which all_hits_zero_tid is true (presumably every hit has mc_tid == 0,
#    i.e. the candidate is made purely of noise hits -- see determinerealorfake).
mask_all_zero = validated_tracks_df["hits"].apply(lambda hits: all_hits_zero_tid(hits, hit_to_tid_map))
validated_tracks_df = validated_tracks_df[~mask_all_zero].copy()


def _is_real_tid(tid):
    """True for a single non-zero integer track id.

    NumPy integer ids count as integers too: isinstance(np.int64(5), int) is False, so a
    plain ``int`` check would silently label such tracks as fake. Non-integer values
    (presumably a list of ids for mixed candidates) are not real.
    """
    return isinstance(tid, (int, np.integer)) and bool(tid != 0)


# 3) Truth id / type per candidate. A candidate is real only when all its hits share one
#    non-zero mc_tid (a candidate with any noise hit is no longer accepted as real).
validated_tracks_df["mc_tid"]      = validated_tracks_df["hits"].apply(lambda hits: determine_track_tid(hits, hit_to_tid_map))
validated_tracks_df["mc_type"]     = validated_tracks_df["hits"].apply(lambda hits: determine_track_type(hits, hit_to_type_map))
validated_tracks_df["is_real_track"] = validated_tracks_df["mc_tid"].apply(_is_real_tid)

real_tracks_df = validated_tracks_df[validated_tracks_df['is_real_track']].copy().reset_index(drop=True)
fake_tracks_df = validated_tracks_df[~validated_tracks_df['is_real_track']].copy().reset_index(drop=True)

# Display summaries
print(f"Total Real Tracks: {len(real_tracks_df)}")
print(f"Total Fake Tracks: {len(fake_tracks_df)}")

# Truth kinematics are only meaningful for real tracks, so they are attached to real_tracks_df only.
real_tracks_df[["mc_pt","mc_p"]] = (real_tracks_df["hits"].apply(determine_track_pt).apply(pd.Series))
for angle in ("mc_phi","mc_theta","mc_lam"):
    real_tracks_df[angle] = real_tracks_df["hits"].apply(
        lambda hits: angle_from_map(hits, angle, hit_angle_map))

# Identify frames with the most fakes
most_fakes_per_frame = fake_tracks_df['frameId'].value_counts()
print("\nFrames with Most Fake Tracks:")
print(most_fakes_per_frame.head())

# Total Real Tracks. eee_May31: 31370. eeevv_May31: 13066. evv_May31: 11884. beam_May31: 5912
# Total Fake Tracks. eee_May31: 87082. eeevv_May31: 48974. evv_May31: 58   . beam_May31: 364874

In [None]:
# Quick look at the first real tracks (swap in fake_tracks_df to inspect the fakes).
print(real_tracks_df.head(5))

In [None]:
# this code finds the graph (global) feature and edge feature means and stdevs for normalisation.
# node features (hit positions and onehot layer ids) are not normalised.
# this only needs to be run once with the signal real tracks loaded as 'real_tracks_df_temp' 
# as future normalisation is with respect to signal in the following cell.
# perhaps this could use sixhit_tracks_df instead of real_tracks_df.
from graphgeneration import extract_geom_features_from_df6, compute_global_edge_features_stats6
import torch
# Normalisation statistics always come from the *signal* real tracks.
real_tracks_df_temp = torch.load('real_tracks_eee_May31.pt')

# A) Column-wise mean / std of the graph-level geometric features (one row per real track).
geom_array = extract_geom_features_from_df6(real_tracks_df_temp)   # shape = (N_real, 10)
geom_means, geom_stds = geom_array.mean(axis=0), geom_array.std(axis=0)
graph_norm_stats = {'geom_means': geom_means, 'geom_stds': geom_stds}
torch.save(graph_norm_stats, "graph_norm_stats.pt")

# B) Edge-feature statistics; the helper expects a list of per-track dicts.
tracks_list6 = real_tracks_df_temp.to_dict("records")
edge_stats = compute_global_edge_features_stats6(tracks_list6)

# each entry of edge_stats is indexed as [0] = mean, [1] = std
dist_stats, lam_stats = edge_stats['distance'], edge_stats['lambda']
t_stats, z_stats = edge_stats['tdist'], edge_stats['zdist']
edge_norm_stats = {
    'global_distance_means': dist_stats[0],
    'global_distance_stds':  dist_stats[1],
    'global_lambda_means':   lam_stats[0],
    'global_lambda_stds':    lam_stats[1],
    'global_t_means':        t_stats[0],
    'global_t_stds':         t_stats[1],
    'global_z_means':        z_stats[0],
    'global_z_stds':         z_stats[1],
}
torch.save(edge_norm_stats, "edge_norm_stats.pt")
print("Saved graph_norm_stats.pt and edge_norm_stats.pt")

In [None]:
# 6hit; track to graph
# label 0: positron, 1: electron, 2: fake.
# beam takes me 5min to save
# make sure real and fake_tracks_df are loaded with the correct set
# and change saved file name at the bottom of this cell.

import torch
import pandas as pd
from graphgeneration import (layer_map, extract_constraints6, extract_geom_features_from_df6, compute_global_edge_features_stats6, track_to_graph6)

graph_norm_stats = torch.load("graph_norm_stats.pt")
geom_means, geom_stds = graph_norm_stats["geom_means"], graph_norm_stats["geom_stds"]

edge_norm_stats = torch.load("edge_norm_stats.pt")
(global_distance_means, global_distance_stds,
 global_lambda_means,   global_lambda_stds,
 global_t_means,        global_t_stds,
 global_z_means,        global_z_stds) = (edge_norm_stats[key] for key in (
    "global_distance_means", "global_distance_stds",
    "global_lambda_means",   "global_lambda_stds",
    "global_t_means",        "global_t_stds",
    "global_z_means",        "global_z_stds",
))

# Normalisation arguments passed to every track_to_graph6 call, in its positional order.
norm_args = (
    geom_means, geom_stds,
    global_distance_means, global_distance_stds,
    global_lambda_means, global_lambda_stds,
    global_t_means, global_t_stds,
    global_z_means, global_z_stds,
)

graphs6 = []


# (a) Charge of each real track from its first hit's mc_pid (11 = e-, -11 = e+).
#     Counted so class weighting could be applied; it is not used here.
def _first_hit_pid(hits):
    """mc_pid of the first hit, or None when hits is not a non-empty list/tuple."""
    if not isinstance(hits, (list, tuple)) or len(hits) == 0:
        return None
    return hits[0]["mc_pid"]


real_tracks_df["mc_pid"] = real_tracks_df["hits"].apply(_first_hit_pid)
num_elec = int((real_tracks_df["mc_pid"] == 11).sum())
num_pos = int((real_tracks_df["mc_pid"] == -11).sum())
total_real = num_elec + num_pos
print("total_real_tracks", total_real, "num_electrons", num_elec, "num_positrons", num_pos)

# (b) Real tracks -> label 0 (e+) or 1 (e-); any other mc_pid is skipped (not expected).
for track in real_tracks_df.to_dict("records"):
    mc_pid = track.get("mc_pid", None)
    if mc_pid not in (-11, 11):
        continue
    label = 0 if mc_pid == -11 else 1
    data = track_to_graph6(track, label, layer_map, *norm_args)
    # uniform placeholder weight; class weighting can be applied in the GNN notebook instead
    data.charge_weight = 1.0
    graphs6.append(data)

# (c) Fake tracks -> label 2.
for track in fake_tracks_df.to_dict("records"):
    data = track_to_graph6(track, 2, layer_map, *norm_args)
    data.charge_weight = 1.0
    graphs6.append(data)

print(f"Created {len(graphs6)} graphs.")
torch.save(graphs6, 'Graphs_beam_May31.pt')  # change the file name to match the sample

In [None]:
# inspect a graph
# graphs6 = torch.load('Graphs_eee_May31.pt')  # slow; only needed if graphs6 is not already in memory
print(f"Loaded {len(graphs6)} graphs.")

first_graph = graphs6[0]
print("Frame ID:")
print(first_graph.frameId)
print("First graph details:")
print(first_graph)
print("Attributes:", first_graph.keys())

# Each header is followed by the listed attributes, one print per attribute.
for header, attr_names in (
    ("Node features (x):", ("x",)),
    ("Edge index:", ("edge_index",)),
    ("Label:", ("label",)),
    ("Graph attributes (raw and normalized parameters):", ("graph_attr", "raw_graph_attr")),
    ("Edge attributes:", ("edge_attr",)),
    ("Hit Positions:", ("hit_pos",)),
):
    print(header)
    for attr_name in attr_names:
        print(getattr(first_graph, attr_name))

# Truth information carried on the graph.
for prefix, attr_name in (
    ('mc_pt =', 'mc_pt'),
    ('mc_p =', 'mc_p'),
    ('mc_phi =', 'mc_phi'),
    ('mc_lam =', 'mc_lam'),
    ('mc_theta =', 'mc_theta'),
    ('mc_type =', 'mc_type'),
    ('mc_tid =', 'mc_tid'),
):
    print(prefix, getattr(first_graph, attr_name))
print(" extra features required", first_graph.graph_attr.shape)  # expected torch.Size([1, 10])

In [None]:
# number of different tids used to generate graphs
def count_unique_tids(mc_tid):
    """Number of distinct truth track ids behind one validated track.

    Parameters
    ----------
    mc_tid : the per-track value from determine_track_tid. It is either a sequence of ids
        (list or tuple; counted by its number of distinct values) or a single integer id
        (1 if non-zero, 0 for the noise id 0). NumPy integers are accepted like Python ints,
        because isinstance(np.int64(5), int) is False.

    Returns
    -------
    int; 0 for any other value (e.g. None).
    """
    if isinstance(mc_tid, (list, tuple)):
        return len(set(mc_tid))
    if isinstance(mc_tid, (int, np.integer)):
        return 1 if mc_tid != 0 else 0
    return 0
    
# Per-track number of distinct truth ids (0 = mc_tid of 0 or an unrecognised value).
validated_tracks_df['num_unique_tids'] = validated_tracks_df['mc_tid'].apply(count_unique_tids)

print(validated_tracks_df['num_unique_tids'].value_counts().sort_index())

# Same distribution as a two-column table, excluding the 0 bucket.
tid_counts = validated_tracks_df['num_unique_tids'].value_counts().sort_index().reset_index()
tid_counts.columns = ['num_unique_tids', 'count']
tid_counts = tid_counts[tid_counts['num_unique_tids'] > 0]

print("\nDistribution of Unique TIDs per Track:")
print(tid_counts)

plt.figure(figsize=(4, 4))
# NOTE: sns.set changes the global plotting style for all later figures in the session.
sns.set(style="ticks")

# NOTE(review): newer seaborn warns when palette is given without hue; the colours are
# overwritten below anyway.
ax = sns.barplot(
    x='num_unique_tids',
    y='count',
    data=tid_counts,
    palette='magma'
)

# Recolour the bars (green = track from a single truth id, red = mixed ids) and print each
# bar's count just above it. Assumes ax.patches follow the row order of tid_counts.
colours = ['green' if n == 1 else 'red' for n in tid_counts['num_unique_tids']]
for patch, col in zip(ax.patches, colours):
    patch.set_facecolor(col)
    height = patch.get_height()
    ax.text(
        patch.get_x() + patch.get_width() / 2,
        height + (tid_counts['count'].max() * 0.01),
        f"{int(height):,}",
        ha='center',
        va='bottom',
        fontsize=12,
        color='black'
    )

plt.xlabel('N Unique TIDs per Track', fontsize=14)
plt.ylabel('N Validated Graphs', fontsize=14)
# Bars sit at categorical positions 0..n-1, so tick positions are given by index.
plt.xticks(ticks=range(len(tid_counts)), labels=tid_counts['num_unique_tids'], rotation=0, fontsize=12)
plt.yticks(fontsize=12)
# 10% headroom above the tallest bar for the count labels
plt.ylim(0, tid_counts['count'].max() * 1.1)
plt.grid(axis='y', linestyle='-', linewidth=0.5, alpha=0.3)
plt.tick_params(axis='both', labelsize=13)
plt.tight_layout()
plt.show() 

In [None]:
# Find the missing real tracks (lost through constraints)
found_ids = set(real_tracks_df['mc_tid'])
truth_ids = set(sixhit_tracks_df['mc_tid'])
missing_ids = truth_ids - found_ids
print(f"Total truth tracks: {len(truth_ids)}")
print(f"Total found by constraints: {len(found_ids)}")
print(f"Tracks missing: {len(missing_ids)}")

missing_real_tracks_df = sixhit_tracks_df[
    sixhit_tracks_df['mc_tid'].isin(missing_ids)
].drop_duplicates(subset=['mc_tid', 'frameId']).reset_index(drop=True)

In [None]:
# visualise a frame for the generated graphs (and missing real tracks)!
from plotlyhelper import visualise_frame1, wireframe_traces
# Frame 474 of the signal (eee) sample shows two very close hits in layer 1. One of them has
# mc_tid = 0 (noise) and is used when building fakes. It is not an overlapping-ladder double
# hit: for such double hits the electron path would stay at a fixed radius from the
# trajectory's bending centre, so this labelling is correct. In the same frame, however,
# the real track appears to be missed.
target_frame_id = 474  # frame ID to display

frame_plot_inputs = dict(
    hits_df_unique=hits_df_unique,
    real_tracks_df=real_tracks_df,
    fake_tracks_df=fake_tracks_df,
    missing_real_tracks_df=missing_real_tracks_df,
    wireframe_traces=wireframe_traces,
)
visualise_frame1(frame_id=target_frame_id, **frame_plot_inputs)

### Other diagnostic plots: real vs fake track parameters

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Assuming 'real_tracks_df' and 'fake_tracks_df' are already defined

# Function to explode list columns into separate rows
def explode_list_column(df, column_name):
    """Return df with one row per element of the list-valued `column_name`, index reset to 0..n-1."""
    exploded = df.explode(column_name)
    return exploded.reset_index(drop=True)

# Pitch and bending-radius distributions, real vs fake.
# Each track stores lists under 'pitches' and 'radii'; explode them to one value per row.
def _exploded_floats(df, column):
    """Exploded values of a list-valued column as float.

    Non-numeric entries (e.g. the NaN left by exploding an empty list) become NaN, which
    np.histogram with explicit edges excludes. With object dtype np.histogram would have to
    sort Python objects, which is unreliable once NaN is present.
    """
    return pd.to_numeric(explode_list_column(df, column)[column], errors='coerce')


def _plot_real_vs_fake_density(real_values, fake_values, bin_edges, xlabel, xlim):
    """Overlay normalised real (filled green histogram + outline) and fake (red triangles) densities.

    Returns (real_counts, fake_counts, bin_centres) of the plotted histograms.
    """
    real_counts = np.histogram(real_values, bins=bin_edges, density=True)[0]
    fake_counts = np.histogram(fake_values, bins=bin_edges, density=True)[0]
    bin_centres = 0.5 * (bin_edges[:-1] + bin_edges[1:])

    plt.figure(figsize=(6, 4))
    plt.hist(real_values, bins=bin_edges, density=True, color='green', alpha=0.3,
             histtype='stepfilled', label='Real')
    # Repeat the last density so the 'post' step also draws the top edge of the final bin
    # (drawing over bin_edges[:-1] leaves that bin without a top).
    plt.step(bin_edges, np.append(real_counts, real_counts[-1]), where='post',
             color='green', linewidth=1.5)
    plt.scatter(bin_centres, fake_counts, color='red', edgecolor='black', marker='^', label='Fake')
    plt.tick_params(axis='both', labelsize=14)
    plt.xlabel(xlabel, fontsize=14)
    plt.ylabel('Density', fontsize=14)
    plt.xlim(*xlim)
    plt.grid(True)
    plt.legend()
    plt.show()
    return real_counts, fake_counts, bin_centres


# Histogram the exploded values directly. A shared long-format frame labelled by the pitch
# counts would mislabel radii whenever a track's 'radii' and 'pitches' lists differ in length.
real_pitches = _exploded_floats(real_tracks_df, 'pitches')
fake_pitches = _exploded_floats(fake_tracks_df, 'pitches')
real_radii = _exploded_floats(real_tracks_df, 'radii')
fake_radii = _exploded_floats(fake_tracks_df, 'radii')

pitch_bins = np.linspace(-250, 250, 51)
real_pitches_counts, fake_pitches_counts, pitch_centers = _plot_real_vs_fake_density(
    real_pitches, fake_pitches, pitch_bins, 'Pitch (mm/rad)', (-250, 250))

radius_bins = np.linspace(0, 300, 31)
real_radius_counts, fake_radius_counts, radius_centers = _plot_real_vs_fake_density(
    real_radii, fake_radii, radius_bins, 'Bending Radius (mm)', (0, 300))

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def extract_constraints(track):
    """Transverse distance and z difference for hit pairs (1,2), (3,4), (5,6) of a 6-hit track.

    Parameters
    ----------
    track : mapping whose 'hits' entry is a sequence of at least 6 hits, each with
        'x', 'y', 'z' keys.

    Returns
    -------
    pd.Series with radial_12, radial_34, radial_56 (x-y distance) followed by
    z_diff_12, z_diff_34, z_diff_56 (z of the first hit minus z of the second).
    """
    hits = track['hits']
    pair_specs = (
        ('radial_12', 'z_diff_12', 0, 1),
        ('radial_34', 'z_diff_34', 2, 3),
        ('radial_56', 'z_diff_56', 4, 5),
    )
    radial, z_diff = {}, {}
    for radial_key, z_key, i, j in pair_specs:
        first, second = hits[i], hits[j]
        radial[radial_key] = np.sqrt((first['x'] - second['x'])**2 + (first['y'] - second['y'])**2)
        z_diff[z_key] = first['z'] - second['z']
    return pd.Series({**radial, **z_diff})

# Extract constraint values for all tracks and attach them as columns.
# Constraint columns left over from an earlier run of this cell are dropped first, so
# re-running it does not create duplicate columns (which would double the melted data below).
constraint_cols = ['radial_12', 'radial_34', 'radial_56', 'z_diff_12', 'z_diff_34', 'z_diff_56']
constraints_df = validated_tracks_df.apply(extract_constraints, axis=1)
validated_tracks_df = pd.concat(
    [validated_tracks_df.drop(columns=constraint_cols, errors='ignore'), constraints_df],
    axis=1,
)

# Separate real and fake tracks.
# NOTE: this overwrites real_tracks_df / fake_tracks_df from the real/fake labelling cell;
# those versions had a reset index and (for real tracks) mc_pt, mc_p and angle columns.
real_tracks_df = validated_tracks_df[validated_tracks_df['is_real_track']]
fake_tracks_df = validated_tracks_df[~validated_tracks_df['is_real_track']]

def melt_constraints(df, keys):
    """Long-format view of the `keys` columns: one row per (track, constraint), columns 'constraint' and 'value'."""
    return pd.melt(df[keys], var_name='constraint', value_name='value')

radial_keys = ['radial_12', 'radial_34', 'radial_56']
z_keys = ['z_diff_34', 'z_diff_56']  # z_diff_12 is extracted but not plotted here

real_radial = melt_constraints(real_tracks_df, radial_keys)
fake_radial = melt_constraints(fake_tracks_df, radial_keys)
real_z = melt_constraints(real_tracks_df, z_keys)
fake_z = melt_constraints(fake_tracks_df, z_keys)

# Bin edges for radial distances and z differences
radial_bins = np.linspace(0, 30, 31)  # 30 bins from 0 to 30 mm
z_bins = np.linspace(-55, 55, 31)     # 30 bins from -55 to 55 mm

real_radial_data = real_radial['value']
fake_radial_data = fake_radial['value']
real_z_data = real_z['value']
fake_z_data = fake_z['value']

# Normalised densities (all constraints of one kind pooled together)
real_radial_counts = np.histogram(real_radial_data, bins=radial_bins, density=True)[0]
fake_radial_counts = np.histogram(fake_radial_data, bins=radial_bins, density=True)[0]
radial_centers = 0.5 * (radial_bins[:-1] + radial_bins[1:])

real_z_counts = np.histogram(real_z_data, bins=z_bins, density=True)[0]
fake_z_counts = np.histogram(fake_z_data, bins=z_bins, density=True)[0]
z_centers = 0.5 * (z_bins[:-1] + z_bins[1:])

# Real: filled histogram + outline; fake: red triangles at bin centres.
for real_data, real_counts, fake_counts, edges, centres, xlabel in (
    (real_radial_data, real_radial_counts, fake_radial_counts, radial_bins, radial_centers,
     'Radial Distance between Nodes (mm)'),
    (real_z_data, real_z_counts, fake_z_counts, z_bins, z_centers,
     'Change in z between Nodes (mm)'),
):
    plt.figure(figsize=(6, 4))
    plt.hist(real_data, bins=edges, density=True, color='green', alpha=0.3,
             histtype='stepfilled', label='Real')
    # Repeat the last density so the 'post' step also draws the top edge of the final bin
    # (drawing over edges[:-1] leaves that bin without a top).
    plt.step(edges, np.append(real_counts, real_counts[-1]), where='post',
             color='green', linewidth=1.5)
    plt.scatter(centres, fake_counts, color='red', edgecolor='black', marker='^', label='Fake')
    plt.tick_params(axis='both', labelsize=12)
    plt.xlabel(xlabel, fontsize=14)
    plt.ylabel('Density', fontsize=14)
    plt.xlim(edges[0], edges[-1])
    plt.grid(True)
    plt.legend()
    plt.show()