In [None]:
# Imports
import uproot
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, InMemoryDataset, DataLoader, Batch
from torch.utils.data import Subset
from torch_geometric.nn import GCNConv, global_mean_pool, BatchNorm, GINEConv
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    confusion_matrix, precision_score, recall_score, f1_score, 
    roc_auc_score, average_precision_score, precision_recall_curve, roc_curve,
    classification_report, auc, accuracy_score
)
from collections import Counter
from sklearn.utils import resample
import math
import random
import seaborn as sns
from itertools import product
import os
import copy

# --- Run-wide configuration ---------------------------------------------------
# Seed every RNG the notebook touches (Python, NumPy, PyTorch) so that splits,
# shuffling and weight initialisation repeat between runs.
RANDOM_SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(RANDOM_SEED)

# Prefer the GPU whenever PyTorch can see one.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f"Using device: {device}")

# Show full cell contents (no truncation) when displaying DataFrames.
pd.set_option('display.max_colwidth', None)

## GNN training on the graphs saved by the graph-generation notebook

In [None]:
# Load the graphs saved by the graph-generation notebook and build the data loaders.
# NOTE(review): torch.load unpickles arbitrary Python objects, so only load trusted,
# locally generated files. On PyTorch >= 2.6 (weights_only=True by default), loading
# lists of PyG Data objects may require torch.load(path, weights_only=False).
from loadgraphs import force_label_long, extract_unique_real, save_unique_real

BATCH_SIZE = 128
# Truth-level attributes that must not be collated into model batches. 'mc_tid' is
# presumably attached only to real graphs (see the deduplication cell), and PyG cannot
# batch graphs whose attribute sets differ. It is therefore excluded for EVERY dataset;
# previously the IC / Michel / beam loaders only dropped 'mc_pid'.
TRUTH_ONLY_KEYS = ['mc_pid', 'mc_tid']

graphs = torch.load('Graphs_eee_May31.pt')
test_graphs_IC = torch.load('Graphs_eeevv_May31.pt')
test_graphs_M = torch.load('Graphs_evv_May31.pt')
test_graphs_beam = torch.load('Graphs_beam_May31.pt')

# Presumably casts every graph label to an integer (long) tensor, as required for
# CrossEntropyLoss class targets. TODO confirm against loadgraphs.force_label_long.
graphs = force_label_long(graphs)
test_graphs_IC = force_label_long(test_graphs_IC)
test_graphs_M = force_label_long(test_graphs_M)
test_graphs_beam = force_label_long(test_graphs_beam)


def _class_counts(graph_list):
    """Number of graphs per integer label (0 = e+, 1 = e-, 2 = fake)."""
    return Counter(data.label.item() for data in graph_list)


# eee signal: 64 % train / 16 % validation / 20 % test. The split seed (43) is fixed
# independently of RANDOM_SEED.
train_graphs, test_graphs = train_test_split(graphs, test_size=0.2, random_state=43)
train_graphs, val_graphs = train_test_split(train_graphs, test_size=0.2, random_state=43)
print('signal: train',len(train_graphs),'val',len(val_graphs),'test',len(test_graphs))
train_loader = DataLoader(train_graphs, batch_size=BATCH_SIZE, shuffle=True, exclude_keys=TRUTH_ONLY_KEYS)
val_loader = DataLoader(val_graphs, batch_size=BATCH_SIZE, exclude_keys=TRUTH_ONLY_KEYS)
test_loader = DataLoader(test_graphs, batch_size=BATCH_SIZE, exclude_keys=TRUTH_ONLY_KEYS)

label_counts = _class_counts(graphs)
print("Graph counts per class (signal):", label_counts)

# eeevv internal conversion
test_loader_IC = DataLoader(test_graphs_IC, batch_size=BATCH_SIZE, shuffle=False, exclude_keys=TRUTH_ONLY_KEYS)

label_counts = _class_counts(test_graphs_IC)
print("Graph counts per class (IC):", label_counts)

# evv Michel
test_loader_M = DataLoader(test_graphs_M, batch_size=BATCH_SIZE, shuffle=False, exclude_keys=TRUTH_ONLY_KEYS)

print(f"Loaded {len(test_graphs_M)} Michel (evv) graphs for testing.")
label_counts = _class_counts(test_graphs_M)
print("Graph counts per class (Michel):", label_counts)

# beam
test_loader_beam = DataLoader(test_graphs_beam, batch_size=BATCH_SIZE, shuffle=False, exclude_keys=TRUTH_ONLY_KEYS)

print(f"Loaded {len(test_graphs_beam)} beam graphs for testing.")
label_counts = _class_counts(test_graphs_beam)
print("Graph counts per class (beam):", label_counts)

# Save the mc_tids (and frameIds) of the real graphs (labels 0 and 1). This gives the
# constraint / graph-generation efficiency directly, before any GNN efficiency is computed.
save_unique_real(graphs,            "postconstraints_ids_signal.csv")
save_unique_real(test_graphs,       "testsubset_ids_signal.csv")
save_unique_real(test_graphs_IC,    "postconstraints_ids_IC.csv")
save_unique_real(test_graphs_M,     "postconstraints_ids_michel.csv")
save_unique_real(test_graphs_beam,  "postconstraints_ids_beam.csv")

In [None]:
# Train (GCN or GINE)
from defineGNNmodel import GCNMultiClass, GINEMultiClass
# Purity: correct predictions out of all predictions made for a class (e+, e- or fake).
# Efficiency: correct predictions out of all available graphs of that class.
# The report uses the GCN with the 31 May graphs as set up here:
#   in_channels=11, hidden_channels=64, extra_features=10.
# A GINE model was also tested and gives very similar results.
# Models only need training once and can then be reloaded; save the model in the next cell.
# `device` is defined in the configuration cell at the top of the notebook.

model = GCNMultiClass(in_channels=11, hidden_channels=64, extra_features=10).to(device)
# model = GINEMultiClass(in_channels=11, hidden_channels=64, edge_channels=5, extra_features=10).to(device)

# Class-balanced loss weights from the training split: w_c = N / (K * n_c).
train_class_counts = Counter(data.label.item() for data in train_graphs)
num_p = train_class_counts[0]     # positron
num_e = train_class_counts[1]     # electron
num_fake = train_class_counts[2]  # fake
print("Positron count:", num_p, "Electron count:", num_e, "Fake count:", num_fake)

class_counts = np.array([num_p, num_e, num_fake], dtype=np.float32)
total_samples = class_counts.sum()
num_classes = 3
if np.any(class_counts == 0):
    # An empty class would get an infinite weight (N / 0), making the loss of any
    # batch that contains it inf/NaN. Clamp its count to 1 instead.
    print("Warning: at least one class is absent from the training split; clamping its count to 1.")
class_weights = total_samples / (num_classes * np.maximum(class_counts, 1.0))
print("Class weights [e+, e-, fake]:", class_weights)

class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)
# NOTE: the author observed better performance WITHOUT class weighting. To try that,
# build nn.CrossEntropyLoss() without the `weight=` argument.
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training / early-stopping parameters. The model outputs one logit per class, and
# predictions are taken as the argmax class.
epochs = 100
patience = 10  # epochs without validation-loss improvement before stopping
best_val_loss = float('inf')
best_model_state = copy.deepcopy(model.state_dict())
patience_counter = 0

train_losses, val_losses = [], []

for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    all_train_preds = []
    all_train_labels = []

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        outputs = model(batch)  # logits, presumably [batch_size, 3]
        loss = criterion(outputs, batch.label)  # batch.label: integer class per graph
        loss.backward()
        optimizer.step()
        # Weight by graph count so the epoch average is per graph, not per batch.
        epoch_loss += loss.item() * batch.num_graphs

        # Predictions: argmax of the logits.
        preds = torch.argmax(outputs, dim=1).detach().cpu().numpy()
        labels = batch.label.detach().cpu().numpy()
        all_train_preds.extend(preds)
        all_train_labels.extend(labels)

    avg_train_loss = epoch_loss / len(train_loader.dataset)
    train_losses.append(avg_train_loss)

    # Validation phase
    model.eval()
    val_loss = 0
    val_preds = []
    val_labels = []
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            outputs = model(batch)
            loss = criterion(outputs, batch.label)
            val_loss += loss.item() * batch.num_graphs
            preds = torch.argmax(outputs, dim=1).detach().cpu().numpy()
            labels = batch.label.detach().cpu().numpy()
            val_preds.extend(preds)
            val_labels.extend(labels)
    avg_val_loss = val_loss / len(val_loader.dataset)
    val_losses.append(avg_val_loss)

    train_acc = accuracy_score(all_train_labels, all_train_preds)
    val_acc = accuracy_score(val_labels, val_preds)
    precision_per_class = precision_score(val_labels, val_preds, average=None, zero_division=0)
    recall_per_class = recall_score(val_labels, val_preds, average=None, zero_division=0)
    f1_per_class = f1_score(val_labels, val_preds, average=None, zero_division=0)

    print(f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
    print(f"Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")
    print("Purity per class:", precision_per_class)
    print("Efficiency per class:", recall_per_class)
    print("F1 per class:", f1_per_class)

    # Early stopping: keep the weights with the lowest validation loss.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_model_state = copy.deepcopy(model.state_dict())
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

# Restore the best (lowest validation loss) weights.
model.load_state_dict(best_model_state)
print("Training complete.")

In [None]:
# torch.save(model.state_dict(), 'TrainedGCNModelJun02.pt')

In [None]:
# Reload a previously trained model instead of re-running the training cell.
from defineGNNmodel import GCNMultiClass, GINEMultiClass

trained_weights_path = 'TrainedGCNModelJun02.pt'
model = GCNMultiClass(in_channels=11, hidden_channels=64, extra_features=10).to(device)
model.load_state_dict(torch.load(trained_weights_path, map_location=device))

# GINE alternative. Note: hidden_channels=128 here, while the GINE line in the
# training cell uses 64.
# model = GINEMultiClass(in_channels=11, hidden_channels=128, edge_channels=5, extra_features=10).to(device)
# model.load_state_dict(torch.load('TrainedGINEModelJun02.pt', map_location=device))

# Evaluation mode; affects layers such as dropout / batch norm, if the model has them.
model.eval()

In [None]:
# Quick look at the first signal graph.
first_graph = graphs[0]
print("first graph", first_graph)

In [None]:
# Training vs. validation loss per epoch.
# Requires the training cell to have run in this kernel; train_losses / val_losses
# do not exist when the model was reloaded from disk instead.
fig, ax = plt.subplots(figsize=(4.5, 3))
epochs_range = range(1, len(train_losses) + 1)
for series, series_name in ((train_losses, "Training Loss"), (val_losses, "Validation Loss")):
    ax.plot(epochs_range, series, label=series_name)
ax.set_xlabel("Epoch", fontsize=12)
ax.set_ylabel("Loss", fontsize=12)
ax.tick_params(axis="both", labelsize=11)
ax.legend(fontsize=11)
ax.grid(True)
plt.show()

# Epochs are 1-indexed, hence the +1 on the argmin position.
best_loss, best_epoch = min(val_losses), np.argmin(val_losses) + 1
print(f"Best Validation Loss: {best_loss:.4f} at Epoch: {best_epoch}")

In [None]:
# DEDUPLICATE the real graphs (this is not overlap removal).
# A double hit can produce two graphs of the same particle that differ by one hit.
# Both hits lie on the particle's bending circle, so either graph is acceptable; the
# one with the higher GNN output is kept, so each mc_tid appears only once among the
# real (0 = e+, 1 = e-) graphs.
# This needs truth information (mc_tid), so it is analysis-only. On real beam data,
# another way to tell that two graphs belong to the same particle would be needed.
# Only real graphs are deduplicated, which keeps the efficiency denominator correct.
from deduplicategraphs import deduplicate_real_graphs, compute_real_graph_preds

# GNN predictions on the real graphs decide which duplicate survives.
for graph_set in (test_graphs, test_graphs_beam, test_graphs_M, test_graphs_IC):
    compute_real_graph_preds(graph_set, model, device)

d_test_graphs = deduplicate_real_graphs(test_graphs)
d_test_graphs_beam = deduplicate_real_graphs(test_graphs_beam)
d_test_graphs_M = deduplicate_real_graphs(test_graphs_M)
d_test_graphs_IC = deduplicate_real_graphs(test_graphs_IC)

for heading, before, after in (
    ("Total test graphs before,after deduplication (eee):", test_graphs, d_test_graphs),
    ("Total test graphs before,after deduplication (beam):", test_graphs_beam, d_test_graphs_beam),
    ("Total test graphs before,after deduplication (M):", test_graphs_M, d_test_graphs_M),
    ("Total test graphs before,after deduplication (IC):", test_graphs_IC, d_test_graphs_IC),
):
    print(heading, len(before), ',', len(after))

# Loaders over the deduplicated graphs. Attributes that only the real graphs carry
# (truth ids plus per-graph fields such as pred_confidence) are excluded so that all
# graphs can be batched together.
_real_only_keys = ('mc_pid', 'mc_tid', 'pred_confidence', 'pred_label', 'removal_counts')
d_test_loader = DataLoader(d_test_graphs, batch_size=128, exclude_keys=list(_real_only_keys))
d_test_loader_beam = DataLoader(d_test_graphs_beam, batch_size=128, exclude_keys=list(_real_only_keys))
d_test_loader_M = DataLoader(d_test_graphs_M, batch_size=128, exclude_keys=list(_real_only_keys))
d_test_loader_IC = DataLoader(d_test_graphs_IC, batch_size=128, exclude_keys=list(_real_only_keys))

In [None]:
# Pre-O.R. evaluation; Clopper–Pearson 1σ errors are computed in the next cell.
# Efficiency: real graphs predicted real / all real graphs available.
# Purity:     real graphs predicted real / all graphs predicted real.
# classification_report calls these recall (efficiency) and precision (purity).
from evaluategraphs import (
    evaluate_combined,
    evaluate_for_class,
    evaluate_with_argmax_combined,
    evaluate_class_with_argmax,
    efficiency_with_CP,
)

# Argmax predictions per deduplicated test set. Only the signal run keeps its
# confusion matrix (cm) and classification report (cr).
y_true_sig, y_pred_sig, cm, cr = evaluate_with_argmax_combined(model, d_test_loader, device)
y_true_beam, y_pred_beam, _, _ = evaluate_with_argmax_combined(model, d_test_loader_beam, device)
y_true_m, y_pred_m, _, _ = evaluate_with_argmax_combined(model, d_test_loader_M, device)
y_true_ic, y_pred_ic, _, _ = evaluate_with_argmax_combined(model, d_test_loader_IC, device)

for report_line in ("signal: confusion matrix and classification report", cm, cr):
    print(report_line)

In [None]:
# Plot pre-O.R. efficiency per dataset with Clopper–Pearson 1σ intervals.
pre_metrics_eee  = efficiency_with_CP(y_true_sig,  y_pred_sig,  "Signal")
pre_metrics_beam = efficiency_with_CP(y_true_beam, y_pred_beam, "Beam")
pre_metrics_evv  = efficiency_with_CP(y_true_m,    y_pred_m,    "Michel")
pre_metrics_eeevv = efficiency_with_CP(y_true_ic,   y_pred_ic,   "I.C.")

all_pre_metrics = {
    'Beam': pre_metrics_beam,
    'Signal': pre_metrics_eee,
    'I.C.': pre_metrics_eeevv,
    'Michel': pre_metrics_evv
}

# Extract efficiencies and interval edges, skipping datasets with empty metrics.
datasets = []
efficiencies = []
lower_bounds = []
upper_bounds = []

for ds, met in all_pre_metrics.items():
    if met:
        datasets.append(ds)
        efficiencies.append(met['efficiency'])
        lower_bounds.append(met['lower_bound'])
        upper_bounds.append(met['upper_bound'])

# Clopper–Pearson intervals are asymmetric, strongly so for efficiencies near 1. Draw
# the actual interval [lower, upper] around the point estimate: a symmetric half-width
# ((upper - lower) / 2) misplaces the band and can extend beyond an efficiency of 1.
# Assumes lower_bound / upper_bound are absolute interval edges, not offsets — TODO confirm.
eff_values = np.asarray(efficiencies, dtype=float)
errors = np.clip(
    np.vstack([eff_values - np.asarray(lower_bounds, dtype=float),
               np.asarray(upper_bounds, dtype=float) - eff_values]),
    0.0, None,  # matplotlib rejects negative error lengths (e.g. from float noise)
)

y_pos = np.arange(len(datasets))

plt.figure(figsize=(6.3, 1.33))
plt.errorbar(efficiencies, y_pos, xerr=errors, fmt='o', color='black', ecolor='gray', capsize=5, label='Clopper–Pearson\n1σ errors')
plt.yticks(y_pos, datasets, fontsize=12)
plt.tick_params(axis='both', labelsize=12)
plt.xlabel("Pre-O.R. Efficiency", fontsize=14)
plt.legend(fontsize=11, framealpha=0.7, bbox_to_anchor=(1.4,1))
plt.grid(True, axis='x', linestyle='--', alpha=0.7)
plt.xlim(0.993, 1.00)  # zoomed to the high-efficiency region; adjust as needed
plt.ylim(-0.5, len(datasets) - 0.5)
plt.show()

In [None]:
# Pre-O.R. per-class results on the signal test set, used by the BRR / ROC plots
# and the threshold plots further down.
results_electrons = evaluate_for_class(
    model, d_test_loader,
    class_index=1,  # e-
    dataset_name="Test Set - Electrons",
)
results_positrons = evaluate_for_class(
    model, d_test_loader,
    class_index=0,  # e+
    dataset_name="Test Set - Positrons",
)

In [None]:
# Signal (eee): background rejection vs. efficiency, then ROC curves, for e- and e+.
print("eee")
brr_curves = (
    (results_electrons, 'Electrons', 'blue'),
    (results_positrons, 'Positrons', 'red'),
)

fig, ax = plt.subplots(figsize=(4.5, 3))
for res, curve_name, colour in brr_curves:
    ax.plot(res['efficiency'], res['brr'], label=curve_name, color=colour, lw=2)
ax.axhline(y=1, color='gray', linestyle='--', lw=2, label='No Discriminatory Rejection')
ax.set_yscale('log')
ax.set_xlabel('Efficiency', fontsize=14)
ax.set_ylabel('Background Rejection Rate', fontsize=14)
ax.tick_params(axis='both', labelsize=12)
ax.set_ylim(top=1e5)
ax.set_xlim(-0.05, 1.05)
ax.legend(loc="lower left", fontsize=11, framealpha=0.3)
ax.grid(True)
plt.show()

# ROC curves (one-vs-rest) and their areas.
fpr_e, tpr_e, _ = roc_curve(results_electrons['y_true'], results_electrons['y_probs'])
roc_auc_e = auc(fpr_e, tpr_e)
fpr_p, tpr_p, _ = roc_curve(results_positrons['y_true'], results_positrons['y_probs'])
roc_auc_p = auc(fpr_p, tpr_p)
# Random-guess diagonal, started just above 0 so it can be drawn on a log x-axis.
epsilon = 1e-6
fpr_diag = np.linspace(epsilon, 1, 100)

fig, ax = plt.subplots(figsize=(4.5, 3))
ax.plot(fpr_e, tpr_e, color='blue', lw=2, label=f'Electrons AUC = {roc_auc_e:.4f}')
ax.plot(fpr_p, tpr_p, color='red', lw=2, label=f'Positrons AUC = {roc_auc_p:.4f}')
ax.plot(fpr_diag, np.linspace(0, 1, 100), color='grey', linestyle='--', label='')
ax.set_xscale('log', base=10)
ax.set_xlim(5e-5, 1)
ax.set_xlabel('False Positive Rate', fontsize=14)
ax.set_ylabel('Efficiency', fontsize=14)
ax.tick_params(axis='both', labelsize=12)
ax.legend(loc="lower right", fontsize=11, framealpha=0.6)
ax.grid(True)
plt.show()

In [None]:
# e+ vs e- separation confusion matrices for each deduplicated test set (pre-O.R.).
from evaluategraphs import e_separation

separation_inputs = [
    (d_test_graphs, "eee"),
    (d_test_graphs_beam, "Beam"),
    (d_test_graphs_M, "evv"),
    (d_test_graphs_IC, "eeevv"),
]
for graph_set, set_name in separation_inputs[:-1]:
    e_separation(graph_set, set_name, model, device)
# The last call stays the cell's final expression so its return value, if any, is still displayed.
e_separation(*separation_inputs[-1], model, device)

In [None]:
# Apply overlap removal to every deduplicated test set.
from overlapremoval import perform_overlap_removal

# Each call returns five values, unpacked as: surviving graphs, removal log,
# removal confidence differences, removed true-real graphs, removed true-fake graphs.
(OR_test_graphs, removal_log, removal_confidence_diffs,
 removed_true_real, removed_true_fake) = perform_overlap_removal(d_test_graphs, model, dataset_name="eee")
(OR_test_graphs_beam, removal_log_beam, removal_confidence_diffs_beam,
 removed_true_real_beam, removed_true_fake_beam) = perform_overlap_removal(d_test_graphs_beam, model, dataset_name="Beam")
(OR_test_graphs_M, removal_log_M, removal_confidence_diffs_M,
 removed_true_real_M, removed_true_fake_M) = perform_overlap_removal(d_test_graphs_M, model, dataset_name="evv")
(OR_test_graphs_IC, removal_log_IC, removal_confidence_diffs_IC,
 removed_true_real_IC, removed_true_fake_IC) = perform_overlap_removal(d_test_graphs_IC, model, dataset_name="eeevv")

print("Overlap Removal Log (eee):", removal_log)

In [None]:
# Save the mc_tids and frameIds of tracks that survived the full chain:
# unique hits > constraints > overlap removal > GNN output above the threshold on survivors.
from evaluategraphs import save_OR_mc_tids

POST_OR_THRESHOLD = 0.5
# The other datasets can be saved the same way if needed, e.g.
# save_OR_mc_tids(OR_test_graphs_IC, "postOR_ids_signal_IC.csv", confidence_threshold=POST_OR_THRESHOLD)
# save_OR_mc_tids(OR_test_graphs_M, "postOR_ids_signal_michel.csv", confidence_threshold=POST_OR_THRESHOLD)
# save_OR_mc_tids(OR_test_graphs_beam, "postOR_ids_signal_beam.csv", confidence_threshold=POST_OR_THRESHOLD)
save_OR_mc_tids(OR_test_graphs, "postOR_ids_signal.csv", confidence_threshold=POST_OR_THRESHOLD)

In [None]:
# Post-O.R. results with Clopper–Pearson 1σ errors.
from evaluategraphs import compute_OR_metrics

# (pre-O.R. graphs, post-O.R. graphs) per dataset name, evaluated in this order.
_or_pairs = {
    "eee": (d_test_graphs, OR_test_graphs),
    "evv": (d_test_graphs_M, OR_test_graphs_M),
    "eeevv": (d_test_graphs_IC, OR_test_graphs_IC),
    "Beam": (d_test_graphs_beam, OR_test_graphs_beam),
}
_or_metrics = {name: compute_OR_metrics(pre, post, name) for name, (pre, post) in _or_pairs.items()}

post_metrics_eee = _or_metrics["eee"]
post_metrics_evv = _or_metrics["evv"]
post_metrics_eeevv = _or_metrics["eeevv"]
post_metrics_beam = _or_metrics["Beam"]

In [None]:
# Plot post-O.R. efficiency per dataset with Clopper–Pearson 1σ intervals.
all_metrics = {
    'Beam': {},  # beam (post_metrics_beam) is deliberately left out of this plot
    'Michel': post_metrics_evv,
    'I.C.': post_metrics_eeevv,
    'Signal': post_metrics_eee
}

# Extract efficiencies and interval edges, skipping datasets with empty metrics.
datasets = []
efficiencies = []
lower_bounds = []
upper_bounds = []

for ds, met in all_metrics.items():
    if met:
        datasets.append(ds)
        efficiencies.append(met['final_efficiency'])
        lower_bounds.append(met['lower_bound'])
        upper_bounds.append(met['upper_bound'])

# Clopper–Pearson intervals are asymmetric, strongly so for efficiencies near 1. Draw
# the actual interval [lower, upper] around the point estimate: a symmetric half-width
# ((upper - lower) / 2) misplaces the band and can extend beyond an efficiency of 1.
# Assumes lower_bound / upper_bound are absolute interval edges, not offsets — TODO confirm.
eff_values = np.asarray(efficiencies, dtype=float)
errors = np.clip(
    np.vstack([eff_values - np.asarray(lower_bounds, dtype=float),
               np.asarray(upper_bounds, dtype=float) - eff_values]),
    0.0, None,  # matplotlib rejects negative error lengths (e.g. from float noise)
)

y_pos = np.arange(len(datasets))

plt.figure(figsize=(6.3, 1.0))
plt.errorbar(efficiencies, y_pos, xerr=errors, fmt='o', color='black', ecolor='gray', capsize=5, label='Clopper–Pearson\n1σ errors')
plt.yticks(y_pos, datasets, fontsize=12)
plt.tick_params(axis='both', labelsize=12)
plt.xlabel("Post-O.R. Efficiency", fontsize=14)
plt.legend(fontsize=11, framealpha=0.7, bbox_to_anchor=(1.4,1))
plt.grid(True, axis='x', linestyle='--', alpha=0.7)
plt.xlim(0.993, 1.00)  # zoomed to the high-efficiency region; adjust as needed
plt.ylim(-0.5, len(datasets) - 0.5)
plt.show()

In [None]:
# First ten graphs that were truly real and predicted real, but were then removed by
# a higher-confidence true fake sharing a hit.
for removed_graph in removed_true_real[:10]:
    print(f"Frame {removed_graph.frameId}, tid {removed_graph.mc_tid}")

In [None]:
# Plotly view of a single frame, e.g. one listed by the removed_true_real cell above.
from plotlyhelper import visualise_truth_tracks, wireframe_traces

target_frame_id = 22205  # replace with the frame id to inspect
# Truth tracks for the eee sample (torch.load unpickles, so only use trusted files).
tracks_df = torch.load('true_tracks_eee_May31.pt')
visualise_truth_tracks(target_frame_id, tracks_df, wireframe_traces)

### Distributions

In [None]:
# One DataFrame per deduplicated test set, used by the histograms below (efficiency,
# purity, graph features and truth information).
from create_dfs import create_df_from_graphs

_graph_sets = {
    'eee': d_test_graphs,
    'IC': d_test_graphs_IC,
    'michel': d_test_graphs_M,
    'beam': d_test_graphs_beam,
}
_graph_frames = {name: create_df_from_graphs(gs, model) for name, gs in _graph_sets.items()}
df_eee, df_IC, df_michel, df_beam = (_graph_frames[k] for k in ('eee', 'IC', 'michel', 'beam'))

In [None]:
# Chord-length distributions of the signal graphs, split by true class
# (0 = e+, 1 = e-, 2 = fake).
df_positrons, df_electrons, df_fake = (df_eee[df_eee['label'] == cls] for cls in (0, 1, 2))

# ----- Binning for chord length -----
n_bins = 20
chord = df_eee['chord_length']
min_val, max_val = chord.min(), chord.max()
bins = np.linspace(min_val, max_val, n_bins + 1)
bin_centers = 0.5 * (bins[:-1] + bins[1:])
bar_width = bins[1] - bins[0]

# ----- Figure 1: total counts per chord-length bin for each class -----
(total_e, _), (total_p, _), (total_f, _) = (
    np.histogram(part['chord_length'], bins=bins) for part in (df_electrons, df_positrons, df_fake)
)

fig, ax = plt.subplots(figsize=(4.5, 3))
# Electrons: blue open circles; positrons: red plus markers; fakes: grey step histogram.
ax.scatter(bin_centers, total_e, marker='o', s=50, linewidth=1.5,
           facecolors='none', edgecolors='blue', label='Total Electrons')
ax.scatter(bin_centers, total_p, marker='+', s=75, color='red', label='Total Positrons')
ax.hist(df_fake['chord_length'], bins=bins, histtype='step',
        color='grey', linewidth=2, label='Total Fakes')
ax.set_xlabel("Chord Length (mm)", fontsize=14)
ax.set_ylabel("Count", fontsize=14)
ax.tick_params(axis='both', labelsize=12)
ax.legend(loc='upper right', bbox_to_anchor=(1.2,1))
ax.grid(True, linestyle='--', alpha=0.7)
plt.show()

# ----- Figure 2: per-class purity and efficiency vs. a binned feature -----
def compute_pred_counts_and_purity(df, bins, column, label):
    """Per-bin purity and efficiency of one class as a function of ``column``.

    Parameters
    ----------
    df : DataFrame with columns ``label`` (true class), ``pred_label`` (predicted
        class) and ``column``.
    bins : histogram bin edges for ``column``.
    column : name of the feature to bin.
    label : the class to evaluate.

    Returns
    -------
    tuple (total_pred_label, correct_pred_label, purity_label, total_true_label, efficiency_label)
        Per-bin counts of graphs predicted as ``label``; of those, how many are truly
        ``label``; their ratio (purity); graphs truly in ``label``; and correct / true
        (efficiency). Ratios are NaN in bins whose denominator is zero.
    """
    is_pred = df['pred_label'] == label
    is_true = df['label'] == label

    def _binned(mask):
        counts, _ = np.histogram(df.loc[mask, column], bins=bins)
        return counts

    def _safe_ratio(numerator, denominator):
        out = np.full(denominator.shape, np.nan)
        np.divide(numerator, denominator, out=out, where=denominator != 0)
        return out

    total_pred_label = _binned(is_pred)
    correct_pred_label = _binned(is_pred & is_true)
    total_true_label = _binned(is_true)
    purity_label = _safe_ratio(correct_pred_label, total_pred_label)
    efficiency_label = _safe_ratio(correct_pred_label, total_true_label)
    return total_pred_label, correct_pred_label, purity_label, total_true_label, efficiency_label

# Bin edges must come from the feature that is actually histogrammed. The previous
# version binned 'path_length' with the chord-length edges of Figure 1, so path lengths
# outside the chord-length range were silently dropped from every bin.
purity_feature = 'path_length'  # e.g. 'chord_length' to match Figure 1
purity_feature_label = "Path Length (mm)"
feature_values = df_eee[purity_feature]
feature_bins = np.linspace(feature_values.min(), feature_values.max(), n_bins + 1)
feature_centers = (feature_bins[:-1] + feature_bins[1:]) / 2

tot_pred_e, corr_pred_e, purity_e, tot_true_e, eff_e = compute_pred_counts_and_purity(df_eee, feature_bins, purity_feature, label=1)
tot_pred_p, corr_pred_p, purity_p, tot_true_p, eff_p = compute_pred_counts_and_purity(df_eee, feature_bins, purity_feature, label=0)
tot_pred_f, corr_pred_f, purity_f, tot_true_f, eff_f = compute_pred_counts_and_purity(df_eee, feature_bins, purity_feature, label=2)


def _plot_per_class_metric(x, metric_e, metric_p, metric_f, metric_name, y_label, x_label):
    """Overlay the electron / positron / fake values of one per-bin metric against ``x``."""
    plt.figure(figsize=(4.5,3))
    plt.scatter(x, metric_e, marker='o', edgecolor='blue', facecolors='none', label=f'Electron {metric_name}')
    plt.scatter(x, metric_p, marker='+', color='red', label=f'Positron {metric_name}')
    plt.step(x, metric_f, where='mid', color='gray', linewidth=2, label=f'Fake {metric_name}')
    plt.xlabel(x_label, fontsize=14)
    plt.ylabel(y_label, fontsize=14)
    plt.ylim(0,1.05)
    plt.legend(loc='lower right')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.show()


# Purity vs. feature, then efficiency vs. feature.
_plot_per_class_metric(feature_centers, purity_e, purity_p, purity_f, 'Purity', "Purity (Precision)", purity_feature_label)
_plot_per_class_metric(feature_centers, eff_e, eff_p, eff_f, 'Efficiency', "Efficiency (Recall)", purity_feature_label)

In [None]:
# Efficiency against truth information, overlaying all four samples.
from GNNplots import plot_efficiency_vs_feature_step


def _plot_truth_efficiency(feature, x_label):
    """Efficiency vs. one truth-level column for signal, I.C., Michel and beam."""
    return plot_efficiency_vs_feature_step(
        df_signal=df_eee, df_IC=df_IC, df_michel=df_michel, df_beam=df_beam,
        feature=feature, x_label=x_label,
        label_signal="Signal", label_IC="I.C.", label_michel="Michel", label_beam="Beam",
        bins=40,
    )


_plot_truth_efficiency("mc_p", r"$p_{\mathrm{true}}$ [MeV]")                 # true momentum
_plot_truth_efficiency("mc_pt", r"$p_{T,\mathrm{true}}$ [MeV]")              # true transverse momentum
_plot_truth_efficiency("mc_phi", r"$\phi_{\mathrm{true}}$ [rad]")            # true phi
_plot_truth_efficiency("mc_lam", r"$\lambda_{\mathrm{true}}$ [rad]")         # true lambda

In [None]:
# average purity vs n graphs in a frame
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

def plot_purity_vs_graphcount_multi(
    dfs, labels, colors, linestyles,
    max_graphs=800, bin_width=40, show_legend=False
):
    """
    Overlaid step plots of average per-frame purity vs. number of graphs per frame.

    For each DataFrame, rows are grouped by ``frameId``. A frame's purity is
    TP / (TP + FP) summed over its rows, and frames with TP + FP == 0 are dropped.
    The number of rows per frame (presumably one row per graph) is capped at
    ``max_graphs`` and grouped into bins of width ``bin_width``. The mean purity
    per bin is drawn as a step curve.

    Parameters
    ----------
    dfs : list of DataFrames with columns ``frameId``, ``TP`` and ``FP``.
    labels, colors, linestyles : per-DataFrame legend label, line colour and line style.
    max_graphs : counts above this value are clamped into the last bin.
    bin_width : number of consecutive graph counts combined into one bin.
    show_legend : draw a legend from ``labels``. Off by default, so the figure
        matches earlier behaviour.

    Note: a later cell imports a function with the same name from ``GNNplots``,
    which shadows this definition from that point on.
    """
    plt.figure(figsize=(4.5,3))

    # 1) Bin edges shifted by 0.5 so integer graph counts never fall on an edge.
    edges = np.arange(0.5, max_graphs + bin_width + 0.5, bin_width)
    # 2) Bin centres, used for the tick positions.
    centers = edges[:-1] + bin_width/2

    for df, label, c, ls in zip(dfs, labels, colors, linestyles):
        # Per-frame totals of true / false positives and the number of rows.
        by_frame = (
            df.groupby('frameId')
              .agg(TP=('TP','sum'),
                   FP=('FP','sum'),
                   n_graphs=('frameId','size'))
              .reset_index()
        )
        # Frames without any positive prediction give 0/0 = NaN and are dropped.
        by_frame['purity'] = by_frame['TP'] / (by_frame['TP'] + by_frame['FP'])
        by_frame = by_frame.dropna(subset=['purity'])

        # Clamp very busy frames into the last bin.
        by_frame.loc[by_frame['n_graphs'] > max_graphs, 'n_graphs'] = max_graphs

        # Counts 1..bin_width -> bin 0, bin_width+1..2*bin_width -> bin 1, ...
        bin_idx = np.floor((by_frame['n_graphs'] - 1) / bin_width).astype(int)
        by_frame['bin'] = bin_idx

        # Mean purity per bin; empty bins become NaN (gaps in the curve).
        summary = (
            by_frame.groupby('bin')['purity']
                    .mean()
                    .reindex(range(len(centers)), fill_value=np.nan)
        )

        # Explicit step coordinates: each bin's value spans from its left edge to its right edge.
        xs = np.repeat(edges, 2)[1:-1]
        ys = np.repeat(summary.values, 2)

        plt.step(xs, ys, where='pre', color=c, linestyle=ls, label=label)

    plt.xlabel("Number of Graphs per Frame", fontsize=14)
    plt.ylabel("Average Purity",              fontsize=14)
    plt.tick_params(axis='both', labelsize=12)
    plt.xlim(0.5, max_graphs + 0.5)

    # Label only every other bin centre to keep the axis readable.
    tick_centres = centers[::2]
    tick_labels  = [f"{int(c)}" for c in tick_centres]
    plt.xticks(tick_centres, tick_labels, rotation=45)

    plt.ylim(0, 1.05)
    plt.grid(True, linestyle='--', alpha=0.7)
    if show_legend:
        plt.legend(fontsize=11, loc="lower right", bbox_to_anchor=(1.33,0))
    plt.show()

# Usage: one entry per dataset as (DataFrame, legend label, colour, line style).
_purity_series = [
    (df_eee,    "Signal", "blue",   "-"),
    (df_IC,     "I.C.",   "orange", "--"),
    (df_michel, "Michel", "red",    ":"),
    (df_beam,   "Beam",   "green",  "-."),
]
dfs, labels, colors, linestyles = (list(column) for column in zip(*_purity_series))

plot_purity_vs_graphcount_multi(
    dfs, labels, colors, linestyles,
    max_graphs=800,
    bin_width=40,  # combine every 40 consecutive graph counts into one bin
)

In [None]:
# Average purity vs. number of hits and number of graphs in a frame.
from GNNplots import plot_purity_vs_graphcount_multi, plot_purity_vs_hitcount_multi
# NOTE: this import shadows the plot_purity_vs_graphcount_multi defined in the cell above.

_purity_series = [
    (df_eee,    "Signal", "blue",   "-"),
    (df_IC,     "I.C.",   "orange", "--"),
    (df_michel, "Michel", "red",    ":"),
    (df_beam,   "Beam",   "green",  "-."),
]
dfs, labels, colors, linestyles = (list(column) for column in zip(*_purity_series))

plot_purity_vs_graphcount_multi(dfs, labels, colors, linestyles, max_graphs=800, bin_width=40)

# Unique-hit tables per sample (torch.load unpickles, so only use trusted files).
hits_df_eee, hits_df_eeevv, hits_df_beam, hits_df_evv = (
    torch.load(path)
    for path in ('unique_hits_eee.pt', 'unique_hits_eeevv.pt', 'unique_hits_beam.pt', 'unique_hits_evv.pt')
)
# Same sample order as `dfs`: Signal, I.C., Michel, Beam.
hits_dfs = [hits_df_eee, hits_df_eeevv, hits_df_evv, hits_df_beam]
plot_purity_vs_hitcount_multi(dfs, hits_dfs, labels, colors, linestyles, start_hit=6, max_hits=20, michel_max=6)

#### Purity and efficiency vs. score threshold

In [None]:
# save MULTI; eee; purity, efficiency, threshold
def compute_purity(y_true, y_probs, thresholds):
    """Compute purity (precision) of the positive class at each threshold.

    A sample counts as predicted-positive when ``y_probs >= threshold``.
    Purity is TP / (TP + FP), where true positives have ``y_true == 1`` and
    false positives have ``y_true == 0``; any other label value is ignored.

    Args:
        y_true: Ground-truth labels (array-like: list, NumPy array or CPU tensor).
        y_probs: Positive-class scores, same length as ``y_true``.
        thresholds: Iterable of score thresholds to evaluate.

    Returns:
        np.ndarray of floats with one purity per threshold; NaN where no sample
        passes the threshold (TP + FP == 0).
    """
    # Convert once so plain lists / tensors work with the vectorised comparisons.
    y_true = np.asarray(y_true)
    y_probs = np.asarray(y_probs)
    # Label masks do not depend on the threshold, so build them outside the loop.
    is_pos = y_true == 1
    is_neg = y_true == 0

    purity = []
    for thresh in thresholds:
        passed = y_probs >= thresh
        tp = np.count_nonzero(passed & is_pos)
        fp = np.count_nonzero(passed & is_neg)
        purity.append(tp / (tp + fp) if tp + fp > 0 else np.nan)
    return np.array(purity)

def _unpack_class_results(results):
    """Return (sampled_thresholds, efficiency, y_true, y_probs) from a per-class results dict."""
    return (results['sampled_thresholds'], results['efficiency'],
            results['y_true'], results['y_probs'])


# Number of highest sampled thresholds left off the purity-vs-threshold curves.
# NOTE(review): presumably trimmed because very few graphs pass the top thresholds,
# so purity there is noisy or NaN (compute_purity returns NaN when nothing passes) -- confirm.
CLASS_PURITY_TAIL_TRIM = 3

# Electron and positron curves; pass a different per-class results dict
# (e.g. an I.C. one) to plot another sample.
(test_thresholds_e, test_efficiency_e,
 test_y_true_e, test_y_probs_e) = _unpack_class_results(results_electrons)
test_purity_e = compute_purity(test_y_true_e, test_y_probs_e, test_thresholds_e)

(test_thresholds_p, test_efficiency_p,
 test_y_true_p, test_y_probs_p) = _unpack_class_results(results_positrons)
test_purity_p = compute_purity(test_y_true_p, test_y_probs_p, test_thresholds_p)

# Purity vs threshold (no legend drawn; blue = electrons, red = positrons).
plt.figure(figsize=(4.5, 3))
plt.plot(test_thresholds_e[:-CLASS_PURITY_TAIL_TRIM], test_purity_e[:-CLASS_PURITY_TAIL_TRIM],
         label='Electrons', color='blue', lw=2)
plt.plot(test_thresholds_p[:-CLASS_PURITY_TAIL_TRIM], test_purity_p[:-CLASS_PURITY_TAIL_TRIM],
         label='Positrons', color='red', lw=2)
plt.xlabel("Required Score", fontsize=14)
plt.ylabel("Purity", fontsize=14)
plt.tick_params(axis='both', labelsize=12)
plt.grid(True)
plt.ylim(0.88, 1.002)
plt.xlim(-0.02, 1.02)
plt.show()

# Efficiency vs threshold, zoomed on the high-score region.
plt.figure(figsize=(4.5, 3))
plt.plot(test_thresholds_e, test_efficiency_e, label='Electrons', color='blue', lw=2)
plt.plot(test_thresholds_p, test_efficiency_p, label='Positrons', color='red', lw=2)
plt.xlabel("Required Score", fontsize=14)
plt.ylabel("Efficiency", fontsize=14)
plt.tick_params(axis='both', labelsize=12)
plt.legend(loc="best", fontsize=11)
plt.grid(True)
plt.ylim(0.795, 1.005)
plt.xlim(0.75, 1.005)
plt.show()

# Efficiency targets 0, 0.05, 0.10, ..., 1.0 (21 points) and, for each, the index of
# the closest positron efficiency. Only used by the optional marker overlay below.
target_eff = np.linspace(0, 1, 21)
selected_indices_p = [np.argmin(np.abs(test_efficiency_p - te)) for te in target_eff]

# Purity vs efficiency.
plt.figure(figsize=(4.5, 3))
plt.plot(test_efficiency_e, test_purity_e, label='Electrons', color='blue', lw=2)
plt.plot(test_efficiency_p, test_purity_p, label='Positrons', color='red', lw=2)
# Optional markers at the target efficiencies:
# plt.scatter(test_efficiency_p[selected_indices_p], test_purity_p[selected_indices_p],
#             label='Positrons', color='red', marker='+', s=50, zorder=3)
plt.xlabel("Efficiency", fontsize=14)
plt.ylabel("Purity", fontsize=14)
plt.tick_params(axis='both', labelsize=12)
plt.legend(loc="lower left", fontsize=11)
plt.grid(True)
plt.ylim(0.95, 1.001)
plt.xlim(0.65, 1.01)
plt.show()

In [None]:
# plot information for efficiency, purity, threshold
# find results (return y_true, y_probs, y_pred_final) for all datasets below

def _evaluate_set(loader, dataset_name):
    """evaluate_combined on the shared ``model`` with 1000 thresholds."""
    return evaluate_combined(model, loader, dataset_name=dataset_name, num_thresholds=1000)


def _evaluate_class(loader, class_index, dataset_name):
    """evaluate_for_class on the shared ``model`` with 1000 thresholds."""
    return evaluate_for_class(model, loader, class_index=class_index,
                              dataset_name=dataset_name, num_thresholds=1000)


def _unpack_combined(results):
    """Pull (optimal_threshold, y_true, y_probs) out of an evaluate_combined result."""
    return results["optimal_threshold"], results["y_true"], results["y_probs"]


# Combined evaluation, one call per test loader.
results_eee    = _evaluate_set(d_test_loader,      "eee Set")
results_IC     = _evaluate_set(d_test_loader_IC,   "eeevv Set")
results_beam   = _evaluate_set(d_test_loader_beam, "beam Set")
results_michel = _evaluate_set(d_test_loader_M,    "michel Set")

opt_th_eee,    y_true_eee,    y_probs_eee    = _unpack_combined(results_eee)
opt_th_IC,     y_true_IC,     y_probs_IC     = _unpack_combined(results_IC)
opt_th_beam,   y_true_beam,   y_probs_beam   = _unpack_combined(results_beam)
opt_th_michel, y_true_michel, y_probs_michel = _unpack_combined(results_michel)

# Per-class evaluation; class indices used here: 1 = electrons, 0 = positrons, 2 = fakes.
# Signal (eee) set
results_electrons_eee = _evaluate_class(d_test_loader, 1, "Signal Set - Electrons")
results_positrons_eee = _evaluate_class(d_test_loader, 0, "Signal Set - Positrons")
results_fakes_eee     = _evaluate_class(d_test_loader, 2, "Signal Set - Fakes")

# Internal Conversion (IC) set
results_electrons_IC = _evaluate_class(d_test_loader_IC, 1, "IC Set - Electrons")
results_positrons_IC = _evaluate_class(d_test_loader_IC, 0, "IC Set - Positrons")
results_fakes_IC     = _evaluate_class(d_test_loader_IC, 2, "IC Set - Fakes")

# Michel (M) set; the fake-class evaluation is not run for this set.
results_electrons_M = _evaluate_class(d_test_loader_M, 1, "Michel Set - Electrons")
results_positrons_M = _evaluate_class(d_test_loader_M, 0, "Michel Set - Positrons")

# Beam set
results_electrons_beam = _evaluate_class(d_test_loader_beam, 1, "Beam Set - Electrons")
results_positrons_beam = _evaluate_class(d_test_loader_beam, 0, "Beam Set - Positrons")
results_fakes_beam     = _evaluate_class(d_test_loader_beam, 2, "Beam Set - Fakes")

In [None]:
def compute_efficiency(y_true, y_probs, thresholds):
    """Compute efficiency (true-positive rate / recall) at each threshold.

    A sample counts as predicted-positive when ``y_probs >= threshold``.
    Efficiency is TP / (TP + FN), i.e. the fraction of ``y_true == 1``
    samples whose score passes the threshold.

    Args:
        y_true: Ground-truth labels (array-like: list, NumPy array or CPU tensor).
        y_probs: Positive-class scores, same length as ``y_true``.
        thresholds: Iterable of score thresholds to evaluate.

    Returns:
        np.ndarray of floats with one efficiency per threshold; NaN everywhere
        when there are no positive samples.
    """
    # Convert once so plain lists / tensors work with the vectorised comparisons.
    y_true = np.asarray(y_true)
    y_probs = np.asarray(y_probs)
    is_pos = y_true == 1
    # TP + FN is simply the number of positives: it does not depend on the threshold.
    n_pos = np.count_nonzero(is_pos)

    eff = []
    for t in thresholds:
        tp = np.count_nonzero((y_probs >= t) & is_pos)
        eff.append(tp / n_pos if n_pos > 0 else np.nan)
    return np.array(eff)

import numpy as np
import matplotlib.pyplot as plt

# Efficiency vs required score for the four samples on a common threshold grid.
thresholds = np.linspace(0, 1, 1000)

eff_eee    = compute_efficiency(y_true_eee, y_probs_eee, thresholds)
eff_IC     = compute_efficiency(y_true_IC, y_probs_IC, thresholds)
eff_beam   = compute_efficiency(y_true_beam, y_probs_beam, thresholds)
eff_michel = compute_efficiency(y_true_michel, y_probs_michel, thresholds)

# (curve, label, line style, colour) per sample, drawn in this order.
efficiency_curves = [
    (eff_eee,    "Signal", "-",  "blue"),
    (eff_IC,     "I.C.",   "--", "orange"),
    (eff_michel, "Michel", ":",  "red"),
    (eff_beam,   "Beam",   "-.", "green"),
]

plt.figure(figsize=(4.5, 3))
for _curve, _label, _style, _color in efficiency_curves:
    plt.plot(thresholds, _curve, linestyle=_style, label=_label, color=_color, lw=2)
plt.xlabel("Required Score", fontsize=14)
plt.ylabel("Efficiency", fontsize=14)
plt.tick_params(axis='both', labelsize=12)
# No legend is drawn; add plt.legend(loc="best", fontsize=11) to show the labels.
plt.grid(True)
# Zoom on the high-score, high-efficiency corner.
plt.ylim(0.895, 1.005)
plt.xlim(0.745, 1.005)
plt.show()

In [None]:
def compute_purity(y_true, y_probs, thresholds):
    """Compute purity (precision) of the positive class at each threshold.

    NOTE(review): same contract as the compute_purity defined in the earlier
    purity/efficiency cell; consider defining it only once.

    A sample counts as predicted-positive when ``y_probs >= threshold``.
    Purity is TP / (TP + FP), where true positives have ``y_true == 1`` and
    false positives have ``y_true == 0``; any other label value is ignored.

    Args:
        y_true: Ground-truth labels (array-like: list, NumPy array or CPU tensor).
        y_probs: Positive-class scores, same length as ``y_true``.
        thresholds: Iterable of score thresholds to evaluate.

    Returns:
        np.ndarray of floats with one purity per threshold; NaN where no sample
        passes the threshold (TP + FP == 0).
    """
    # Convert once so plain lists / tensors work with the vectorised comparisons.
    y_true = np.asarray(y_true)
    y_probs = np.asarray(y_probs)
    # Label masks do not depend on the threshold, so build them outside the loop.
    is_pos = y_true == 1
    is_neg = y_true == 0

    purity = []
    for t in thresholds:
        passed = y_probs >= t
        tp = np.count_nonzero(passed & is_pos)
        fp = np.count_nonzero(passed & is_neg)
        purity.append(tp / (tp + fp) if tp + fp > 0 else np.nan)
    return np.array(purity)

import numpy as np
import matplotlib.pyplot as plt

# Purity vs required score for the four samples on a common threshold grid.
thresholds = np.linspace(0, 1, 1000)

purity_eee    = compute_purity(y_true_eee, y_probs_eee, thresholds)
purity_IC     = compute_purity(y_true_IC, y_probs_IC, thresholds)
purity_beam   = compute_purity(y_true_beam, y_probs_beam, thresholds)
purity_michel = compute_purity(y_true_michel, y_probs_michel, thresholds)

# Keep only thresholds <= max_thresh (on this grid that drops just the final point, 1.0).
# These reassignments deliberately replace the module-level arrays with the masked ones.
max_thresh = 0.999
mask = thresholds <= max_thresh

thresholds = thresholds[mask]
purity_eee = purity_eee[mask]
purity_IC = purity_IC[mask]
purity_beam = purity_beam[mask]
purity_michel = purity_michel[mask]

# Additional points dropped from the high-threshold end of every curve.
# NOTE(review): presumably because very few graphs pass there and purity is noisy
# or NaN (compute_purity returns NaN when nothing passes) -- confirm.
SAMPLE_PURITY_TAIL_TRIM = 5

# (curve, label, line style, colour) per sample, drawn in this order.
purity_curves = [
    (purity_eee,    "Signal", "-",  "blue"),
    (purity_IC,     "I.C.",   "--", "orange"),
    (purity_michel, "Michel", ":",  "red"),
    (purity_beam,   "Beam",   "-.", "green"),
]

plt.figure(figsize=(4.5, 3))
for _curve, _label, _style, _color in purity_curves:
    plt.plot(thresholds[:-SAMPLE_PURITY_TAIL_TRIM], _curve[:-SAMPLE_PURITY_TAIL_TRIM],
             linestyle=_style, label=_label, color=_color, lw=2)
plt.xlabel("Required Score", fontsize=14)
plt.ylabel("Purity", fontsize=14)
plt.tick_params(axis='both', labelsize=12)
# No legend is drawn; add plt.legend(loc="best", fontsize=11) to show the labels.
plt.grid(True)
# NOTE(review): the 0.99 upper limit clips any curve approaching purity 1;
# use plt.ylim(0.395, 1.02) for the full range.
plt.ylim(0.945, 0.99)
plt.xlim(-0.02, 1.02)
plt.show()

In [None]:
# confidence histograms for each class
# if the GNN is confident in a graph being fake, it will have a high score in class 2, but for the plot below
# it applies (1 - score) so that the fake peak is near 0.0
from GNNplots import plot_confidence_histograms_multi

# (title, electron results, positron results, fake results) per sample.
# Michel has no fake-class results (pass fake_results=None); Michel and Beam
# are not plotted here but can be appended in the same form.
_confidence_specs = (
    ("Signal ", results_electrons_eee, results_positrons_eee, results_fakes_eee),
    ("I.C. ",   results_electrons_IC,  results_positrons_IC,  results_fakes_IC),
)
for _title, _elec, _pos, _fake in _confidence_specs:
    plot_confidence_histograms_multi(_title, _elec, _pos, _fake)

In [None]:
# zoomed confidence histograms for each decay type
from GNNplots import plot_class_histograms_with_errors

# Electron-class (class 1) results for each sample, in the order Signal, I.C., Michel, Beam.
dataset_labels_elec = ["Signal – e⁻", "I.C. – e⁻", "Michel – e⁻", "Beam – e⁻"]
elec_results_list = [
    results_electrons_eee,
    results_electrons_IC,
    results_electrons_M,
    results_electrons_beam,
]

# One colour / line style per sample, shared by both plots.
colors      = ["blue", "orange", "red", "green"]
line_styles = ["-", "--", ":", "-."]

# Positron-class (class 0) results, same sample order.
dataset_labels_pos = ["Signal – e⁺", "I.C. – e⁺", "Michel – e⁺", "Beam – e⁺"]
pos_results_list = [
    results_positrons_eee,
    results_positrons_IC,
    results_positrons_M,
    results_positrons_beam,
]

# (labels, results, x-axis label, y-limits) for each histogram, drawn in this order.
_histogram_specs = (
    (dataset_labels_elec, elec_results_list, "GNN output for electrons", (2e-3, 260)),
    (dataset_labels_pos,  pos_results_list,  "GNN output for positrons", (2e-3, 170)),
)
for _labels, _results, _axis_label, _ylim in _histogram_specs:
    plot_class_histograms_with_errors(
        _labels,
        _results,
        axis_label=_axis_label,
        y_label="Frequency Density",
        color_list=colors,
        style_list=line_styles,
        ylim=_ylim,
    )