In [None]:
# Seed every RNG library the notebook touches so the torch-driven randomness
# (random_split, randperm-based downsampling, nn.Linear initialisation) starts
# from a fixed state on Restart & Run All.
import random

import numpy as np
import torch

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [1]:
import os
import uproot
import numpy as np

def load_root_file(file_path, branches=None, print_branches=False, tree_name="tree"):
    """Read branches of a TTree in a ROOT file into NumPy arrays.

    Args:
        file_path: Path to the ROOT file.
        branches: Iterable of branch names to read. ``None`` reads every branch
            of the tree.
        print_branches: If True, print the names of all branches in the tree.
        tree_name: Name of the TTree inside the file. Defaults to ``"tree"``,
            the name this function previously hard-coded.

    Returns:
        dict mapping each successfully read branch name to
        ``tree[branch].array(library="np")``, plus the key ``'event'`` holding
        ``tree.num_entries``. Branches missing from the tree are reported with
        a print and skipped.
    """
    all_branches = {}
    with uproot.open(file_path) as root_file:
        tree = root_file[tree_name]
        if branches is None:
            branches = tree.keys()
        if print_branches:
            print("Branches:", tree.keys())
        for branch in branches:
            try:
                all_branches[branch] = tree[branch].array(library="np")
            except uproot.KeyInFileError as e:
                # Best-effort: report the missing branch and keep loading the rest.
                print(f"KeyInFileError: {e}")
        # Number of entries in the tree. Note this overwrites a branch literally named 'event'.
        all_branches['event'] = tree.num_entries
    return all_branches

# Per-T3 properties (from TripletsSoA) read from the ROOT tree.
branches_list = [
    't3_betaIn', 't3_centerX', 't3_centerY', 't3_radius',
    't3_partOfPT5', 't3_partOfT5', 't3_partOfPT3', 't3_layer_binary',
    't3_pMatched', 't3_matched_simIdx', 't3_sim_vxy', 't3_sim_vz',
]

# Per-hit branches t3_hit_{i}_{suffix} for hits 0..5, all suffixes of one hit before the next hit.
suffixes = ['r', 'z', 'eta', 'phi', 'layer']
for hit_idx in range(6):
    branches_list.extend(f't3_hit_{hit_idx}_{suffix}' for suffix in suffixes)

file_path = "600_t3_dnn_relval_fix.root"
branches = load_root_file(file_path, branches_list)

In [2]:
# Normalisation scales for the hit features. z_max and r_max are the largest
# hit-3 z and r over all T3s. Concatenating before the max gives the same value
# as a max of per-event maxima, but does not raise on events with zero T3s.
# NOTE(review): z_max is the max of signed z, while the features use |z|; confirm intended.
z_max = np.max(np.concatenate(branches['t3_hit_3_z']))
r_max = np.max(np.concatenate(branches['t3_hit_3_r']))
eta_max = 2.5
phi_max = np.pi

print(f'Z max: {z_max}, R max: {r_max}, Eta max: {eta_max}')

Z max: 224.14950561523438, R max: 98.93299102783203, Eta max: 2.5


In [3]:
def delta_phi(phi1, phi2):
    """Return phi1 - phi2 with a single 2*pi wrap.

    Differences above pi are shifted down by 2*pi, and differences below -pi
    are shifted up by 2*pi. Works element-wise on NumPy arrays and scalars.
    """
    two_pi = 2 * np.pi
    diff = phi1 - phi2
    diff = np.where(diff > np.pi, diff - two_pi, diff)
    return np.where(diff < -np.pi, diff + two_pi, diff)

n_events = branches['event']


def _stack_events(branch_name):
    """Flatten the per-event arrays of one branch into a single 1-D array, in event order."""
    return np.concatenate([branches[branch_name][evt] for evt in range(n_events)])


# Hits 0, 2 and 4 of every T3. Absolute values are taken for eta and z; phi keeps its sign.
all_eta0, all_eta2, all_eta4 = [np.abs(_stack_events(f't3_hit_{h}_eta')) for h in (0, 2, 4)]
all_phi0, all_phi2, all_phi4 = [_stack_events(f't3_hit_{h}_phi') for h in (0, 2, 4)]
all_z0, all_z2, all_z4 = [np.abs(_stack_events(f't3_hit_{h}_z')) for h in (0, 2, 4)]
all_r0, all_r2, all_r4 = [_stack_events(f't3_hit_{h}_r') for h in (0, 2, 4)]

all_radius = _stack_events('t3_radius')
all_betaIn = _stack_events('t3_betaIn')

# One row per feature; the row index is the feature index reported by the
# importance study. The model is trained on this exact order, so do not
# reorder it without retraining. Eta differences are left unnormalised.
features = np.array([
    all_eta0 / eta_max,                       # 0: hit 0 |eta|
    np.abs(all_phi0) / phi_max,               # 1: hit 0 |phi|
    all_z0 / z_max,                           # 2: hit 0 |z|
    all_r0 / r_max,                           # 3: hit 0 r
    all_eta2 - all_eta0,                      # 4: |eta| hit2 - hit0
    delta_phi(all_phi2, all_phi0) / phi_max,  # 5: delta phi hit2 - hit0
    (all_z2 - all_z0) / z_max,                # 6: |z| hit2 - hit0
    (all_r2 - all_r0) / r_max,                # 7: r hit2 - hit0
    all_eta4 - all_eta2,                      # 8: |eta| hit4 - hit2
    delta_phi(all_phi4, all_phi2) / phi_max,  # 9: delta phi hit4 - hit2
    (all_z4 - all_z2) / z_max,                # 10: |z| hit4 - hit2
    (all_r4 - all_r2) / r_max,                # 11: r hit4 - hit2
    np.log10(all_radius),                     # 12: log10 circle radius
    all_betaIn,                               # 13: beta angle
])

eta_list = np.array([all_eta0])

In [4]:
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset, random_split
import numpy as np

# Build the (n_T3, n_features) matrix and keep only rows whose features are all
# finite. The same row selection is applied to the truth arrays so they stay aligned.
input_features_numpy = np.stack(features, axis=-1)
mask = np.isfinite(input_features_numpy)
finite_rows = mask.all(axis=1)

filtered_input_features_numpy = input_features_numpy[finite_rows]
t3_isFake_filtered = (np.concatenate(branches['t3_pMatched']) < 0.75)[finite_rows]
t3_sim_vxy_filtered = np.concatenate(branches['t3_sim_vxy'])[finite_rows]

input_features_tensor = torch.tensor(filtered_input_features_numpy, dtype=torch.float32)

In [5]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, random_split, DataLoader
from torch.optim import Adam

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

def create_multiclass_labels(t3_isFake, t3_sim_vxy, displacement_threshold=0.1):
    """Build one-hot labels of shape (n, 3) for the T3 classifier.

    Column 0 marks fakes (``t3_isFake`` True). Non-fake tracks are split by
    ``t3_sim_vxy``: ``<= displacement_threshold`` goes to column 1 (prompt),
    ``> displacement_threshold`` goes to column 2 (displaced). A non-fake row
    satisfying neither comparison (e.g. NaN vxy) is left all-zero.
    """
    is_real = ~t3_isFake
    per_class_selectors = (
        t3_isFake,                                              # class 0: fake
        (t3_sim_vxy <= displacement_threshold) & is_real,       # class 1: prompt
        (t3_sim_vxy > displacement_threshold) & is_real,        # class 2: displaced
    )

    labels = torch.zeros((len(t3_isFake), 3))
    for column, selector in enumerate(per_class_selectors):
        labels[selector, column] = 1
    return labels

# One-hot (fake, prompt, displaced) labels for the rows kept by the finite-feature filter.
labels_tensor = create_multiclass_labels(t3_isFake_filtered, t3_sim_vxy_filtered)

class MultiClassNeuralNetwork(nn.Module):
    """Two-hidden-layer MLP that returns softmax probabilities per class.

    With the defaults, the three outputs follow the label convention of
    ``create_multiclass_labels``: 0 = fake, 1 = prompt, 2 = displaced.
    The attribute names ``layer1``, ``layer2`` and ``output_layer`` become the
    array-name suffixes when weights are exported via ``named_modules()``,
    so keep them stable.

    Args:
        input_dim: Number of input features. ``None`` (the default, and the
            original behaviour) reads it from the notebook-level
            ``input_features_numpy.shape[1]``.
        hidden_dim: Width of both hidden layers (default 32).
        num_classes: Number of output classes (default 3).
    """

    def __init__(self, input_dim=None, hidden_dim=32, num_classes=3):
        super().__init__()
        if input_dim is None:
            # Backward-compatible fallback to the global feature matrix.
            input_dim = input_features_numpy.shape[1]
        # Creation order is unchanged, so seeded initialisation is reproduced.
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = nn.functional.relu(self.layer1(x))
        x = nn.functional.relu(self.layer2(x))
        # Probabilities, not logits: downstream code thresholds these directly.
        return nn.functional.softmax(self.output_layer(x), dim=1)

class WeightedCrossEntropyLoss(nn.Module):
    """Per-sample weighted cross-entropy on probability outputs.

    ``outputs`` are probabilities (not logits) and ``targets`` are one-hot rows.
    Each sample's negative log-likelihood is scaled by its weight, and the mean
    over the batch is returned. A 1e-7 offset keeps ``log`` finite at zero.
    """

    def __init__(self):
        super().__init__()

    def forward(self, outputs, targets, weights):
        eps = 1e-7
        per_sample_nll = -(targets * torch.log(outputs + eps)).sum(dim=1)
        return (weights * per_sample_nll).mean()

def calculate_class_weights(labels):
    """Per-sample weights that give every class the same total weight.

    Args:
        labels: One-hot label tensor of shape (n_samples, n_classes). The number
            of classes is read from the tensor; it was previously fixed at 3.

    Returns:
        1-D float tensor of length n_samples. A sample of class c gets
        ``n_samples / (n_classes * count_c)``. Rows with no positive label get 0.
        A class with zero samples gets an infinite class weight, but no sample
        ever receives it.
    """
    num_classes = labels.shape[1]
    class_counts = torch.sum(labels, dim=0)
    total_samples = len(labels)
    class_weights = total_samples / (num_classes * class_counts)

    sample_weights = torch.zeros(len(labels))
    for cls in range(num_classes):
        sample_weights[labels[:, cls] == 1] = class_weights[cls]
    return sample_weights

# Print initial dataset size
print(f"Initial dataset size: {len(labels_tensor)}")

# Drop any row whose features contain NaN. The feature-building cell already
# removed non-finite rows, so this is a safety net that normally removes nothing.
nan_mask = torch.isnan(input_features_tensor).any(dim=1)
filtered_inputs = input_features_tensor[~nan_mask]
filtered_labels = labels_tensor[~nan_mask]

# Print class distribution before downsampling
class_counts_before = torch.sum(filtered_labels, dim=0)
print(f"Class distribution before downsampling - Fake: {class_counts_before[0]}, Prompt: {class_counts_before[1]}, Displaced: {class_counts_before[2]}")

# Option to downsample each class
downsample_classes = True  # Set to False to disable downsampling
if downsample_classes:
    # Fraction of each class to keep: fakes (class 0) are cut to 20%,
    # prompt (class 1) and displaced (class 2) are kept in full.
    downsample_ratios = {0: 0.2, 1: 1.0, 2: 1.0}
    indices_list = []
    for cls in range(filtered_labels.shape[1]):
        cls_mask = (filtered_labels[:, cls] == 1)
        # as_tuple=True always yields a 1-D index tensor. `.squeeze()` would
        # collapse a single match to a 0-d tensor, which cannot be indexed by
        # the randperm below.
        cls_indices = torch.nonzero(cls_mask, as_tuple=True)[0]
        ratio = downsample_ratios.get(cls, 1.0)
        num_cls = cls_indices.numel()
        num_to_sample = int(num_cls * ratio)
        # Keep at least one sample of any class that is present
        if num_to_sample < 1 and num_cls > 0:
            num_to_sample = 1
        # Random subset without replacement
        cls_indices_shuffled = cls_indices[torch.randperm(num_cls)]
        sampled_cls_indices = cls_indices_shuffled[:num_to_sample]
        indices_list.append(sampled_cls_indices)

    # Indices are grouped by class here; random_split below mixes them into train/test.
    selected_indices = torch.cat(indices_list)
    filtered_inputs = filtered_inputs[selected_indices]
    filtered_labels = filtered_labels[selected_indices]

# Print class distribution after downsampling
class_counts_after = torch.sum(filtered_labels, dim=0)
print(f"Class distribution after downsampling - Fake: {class_counts_after[0]}, Prompt: {class_counts_after[1]}, Displaced: {class_counts_after[2]}")

# Recalculate sample weights after downsampling (equal weighting per class based on new counts)
sample_weights = calculate_class_weights(filtered_labels)
filtered_weights = sample_weights

# Create dataset with weights
dataset = TensorDataset(filtered_inputs, filtered_labels, filtered_weights)

# 80/20 train/test split
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True, num_workers=10, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False, num_workers=10, pin_memory=True)

# Initialize model and optimizer
model = MultiClassNeuralNetwork().to(device)
loss_function = WeightedCrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.0025)

def evaluate_loss(loader):
    """Mean weighted loss of the notebook-level ``model`` over ``loader``.

    Uses the globals ``model``, ``loss_function`` and ``device``. Each batch from
    ``loader`` must unpack as ``(inputs, targets, weights)``. The model is put in
    eval mode and left there, and no gradients are tracked.

    The result is the average of per-batch mean losses, so a smaller final batch
    counts as much as a full one. Returns NaN for an empty loader rather than
    raising ZeroDivisionError.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets, weights in loader:
            inputs, targets, weights = inputs.to(device), targets.to(device), weights.to(device)
            outputs = model(inputs)
            loss = loss_function(outputs, targets, weights)
            total_loss += loss.item()
            num_batches += 1
    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

# Training loop: one optimiser step per batch, and the average train and test
# losses are logged after every epoch.
num_epochs = 400
train_loss_log = []
test_loss_log = []

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    batch_count = 0

    for batch_inputs, batch_targets, batch_weights in train_loader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)
        batch_weights = batch_weights.to(device)

        optimizer.zero_grad()
        loss = loss_function(model(batch_inputs), batch_targets, batch_weights)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        batch_count += 1

    train_loss = running_loss / batch_count
    test_loss = evaluate_loss(test_loader)
    train_loss_log.append(train_loss)
    test_loss_log.append(test_loss)

    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')

Using device: cuda
Initial dataset size: 55072926
Class distribution before downsampling - Fake: 49829032.0, Prompt: 4472777.0, Displaced: 771119.0
Class distribution after downsampling - Fake: 9965806.0, Prompt: 4472777.0, Displaced: 771119.0
Epoch [1/400], Train Loss: 0.6515, Test Loss: 0.5908
Epoch [2/400], Train Loss: 0.5771, Test Loss: 0.5659
Epoch [3/400], Train Loss: 0.5647, Test Loss: 0.5549
Epoch [4/400], Train Loss: 0.5588, Test Loss: 0.5627
Epoch [5/400], Train Loss: 0.5553, Test Loss: 0.5541
Epoch [6/400], Train Loss: 0.5528, Test Loss: 0.5664
Epoch [7/400], Train Loss: 0.5508, Test Loss: 0.5574
Epoch [8/400], Train Loss: 0.5492, Test Loss: 0.5503
Epoch [9/400], Train Loss: 0.5478, Test Loss: 0.5447
Epoch [10/400], Train Loss: 0.5472, Test Loss: 0.5538
Epoch [11/400], Train Loss: 0.5459, Test Loss: 0.5534
Epoch [12/400], Train Loss: 0.5454, Test Loss: 0.5487
Epoch [13/400], Train Loss: 0.5445, Test Loss: 0.5366
Epoch [14/400], Train Loss: 0.5441, Test Loss: 0.5387
Epoch [15

In [6]:
import torch
import numpy as np
from sklearn.metrics import accuracy_score

# NumPy versions of the inputs and labels for work outside PyTorch.
input_features_np = input_features_tensor.numpy()
labels_np = labels_tensor.argmax(dim=1).numpy()  # one-hot -> class index

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def model_accuracy(features, labels, model, batch_size=None):
    """
    Compute accuracy for a multi-class classification model
    that outputs scores of size [batch_size, num_classes].

    Args:
        features: Input tensor of shape (n_samples, n_features).
        labels: Class indices of shape (n_samples,), or one-hot rows of shape
            (n_samples, num_classes), which are converted with argmax.
        model: Module whose argmax over dim 1 is the predicted class. It is put
            in eval mode, and inputs are moved to the device of its parameters
            (CPU if it has none).
        batch_size: Optional chunk size. ``None`` (default) runs everything in a
            single forward pass, as before. Set it to bound memory on large inputs.

    Returns:
        Fraction of correctly classified samples (NaN for empty input).
    """
    model.eval()
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    # Convert one-hot encoded labels to class indices if needed
    if labels.dim() > 1:
        labels = torch.argmax(labels, dim=1)

    n_samples = features.shape[0]
    if n_samples == 0:
        return float("nan")
    step = n_samples if batch_size is None else max(int(batch_size), 1)

    correct = 0
    with torch.no_grad():
        for start in range(0, n_samples, step):
            inputs = features[start:start + step].to(target_device)
            batch_labels = labels[start:start + step].to(target_device)
            predicted = torch.argmax(model(inputs), dim=1)
            correct += (predicted == batch_labels).sum().item()

    return correct / n_samples

# Permutation importance: shuffle one feature column at a time and record how
# much the accuracy drops relative to the unshuffled baseline.
baseline_accuracy = model_accuracy(input_features_tensor, labels_tensor, model)
print(f"Baseline accuracy: {baseline_accuracy:.4f}")

n_input_features = input_features_tensor.shape[1]
feature_importances = np.zeros(n_input_features)

for feature_idx in range(n_input_features):
    permuted_features = input_features_tensor.clone()
    row_order = torch.randperm(permuted_features.size(0))
    permuted_features[:, feature_idx] = permuted_features[row_order, feature_idx]
    permuted_accuracy = model_accuracy(permuted_features, labels_tensor, model)
    feature_importances[feature_idx] = baseline_accuracy - permuted_accuracy

# Most important (largest accuracy drop) first.
important_features_indices = np.argsort(feature_importances)[::-1]
important_features_scores = feature_importances[important_features_indices]

print("\nFeature importances:")
for idx, score in zip(important_features_indices, important_features_scores):
    print(f"Feature {idx} importance: {score:.4f}")

Baseline accuracy: 0.8611

Feature importances:
Feature 0 importance: 0.0541
Feature 2 importance: 0.0480
Feature 5 importance: 0.0434
Feature 7 importance: 0.0242
Feature 6 importance: 0.0223
Feature 3 importance: 0.0206
Feature 11 importance: 0.0167
Feature 10 importance: 0.0148
Feature 13 importance: 0.0140
Feature 12 importance: 0.0128
Feature 9 importance: 0.0114
Feature 8 importance: 0.0046
Feature 4 importance: 0.0016
Feature 1 importance: 0.0000


In [8]:
def print_formatted_weights_biases(weights, biases, layer_name):
    """Print one linear layer as C++ ``ALPAKA_STATIC_ACC_MEM_GLOBAL`` initialisers.

    ``weights`` is indexed [out][in] (as in ``nn.Linear.weight``). It is emitted
    transposed as ``wgtT_<layer_name>[in][out]``, after ``bias_<layer_name>[out]``.
    Values are printed with 7 decimals and an ``f`` suffix.
    """
    def fmt(values):
        return ", ".join(f"{v:.7f}f" for v in values)

    n_out = len(weights)
    n_in = len(weights[0])
    lines = [
        f"ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_{layer_name}[{len(biases)}] = {{",
        fmt(biases) + " };",
        "",
        f"ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_{layer_name}[{n_in}][{n_out}] = {{",
    ]
    lines.extend(f"{{ {fmt(column)} }}," for column in weights.T)
    lines.extend(["};", ""])
    print("\n".join(lines))

def print_model_weights_biases(model):
    """Print every ``nn.Linear`` in ``model`` as C++ array initialisers.

    Puts the model in eval mode, walks ``named_modules()`` in order, and hands
    each linear layer's weight/bias (as NumPy arrays on the CPU) to
    ``print_formatted_weights_biases``. The module path, with dots replaced by
    underscores, is used as the array-name suffix.
    """
    model.eval()
    for module_path, layer in model.named_modules():
        if not isinstance(layer, nn.Linear):
            continue
        print_formatted_weights_biases(
            layer.weight.data.cpu().numpy(),
            layer.bias.data.cpu().numpy(),
            module_path.replace('.', '_'),
        )

# Print every linear layer of the trained model as C++ array initialisers.
print_model_weights_biases(model=model)


ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_layer1[32] = {
-0.9152892f, 3.2650192f, -0.4164221f, -0.1210157f, -2.4165483f, -1.0984275f, -2.1654966f, -0.8991888f, -0.0503724f, 7.1305695f, -5.2781415f, 3.2997849f, 1.0025330f, -0.5117974f, 0.2957068f, -0.1811045f, -2.7853479f, 1.8040915f, -2.8807588f, -4.6462102f, 1.2869841f, -0.0526987f, 0.4946094f, 2.6554070f, -0.1360572f, 0.2122774f, 4.7361507f, -1.4605266f, 0.1759245f, -0.7966636f, -0.0401897f, -0.2652957f };

ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_layer1[14][32] = {
{ 0.2570587f, 1.5017653f, -1.8436140f, 1.6314303f, -0.1464428f, 1.2261974f, 2.8629315f, -0.0778951f, -0.0007868f, -2.4665442f, 3.7231014f, -0.4062112f, 5.0222125f, -0.4256854f, -0.8145034f, -0.0993065f, 1.1874412f, 3.7737985f, -2.0898068f, 5.0041976f, -0.4184950f, 0.0133298f, -1.1757115f, 0.8953519f, -0.2589224f, 3.4567924f, -1.0867721f, -0.0325336f, -0.1398652f, 5.9361205f, -0.2938714f, 0.0110872f },
{ 0.0062326f, 0.0294117f, -0.1038531f, -0.1871421f, 0.0092176f

In [9]:
# Score every T3 candidate with the trained model.
input_features_tensor = input_features_tensor.to(device)

model.eval()
with torch.no_grad():
    outputs = model(input_features_tensor)
predictions = outputs.squeeze().cpu().numpy()

# T3s with pMatched > 0.95.
full_tracks = np.concatenate(branches['t3_pMatched']) > 0.95

# pT estimate from the T3 circle radius; 3.8 is presumably the magnetic field in T (assumption, confirm).
# NOTE(review): t3_pt covers every T3, while predictions covers only rows that survived
# the NaN/inf filter; the two line up only if that filter dropped nothing.
t3_pt = np.concatenate(branches['t3_radius']) * 2 * (2.99792458e-3 * 3.8) / 2

In [10]:
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import LogNorm
import torch

# Score all T3s, keeping all three class probabilities per row.
input_features_tensor = input_features_tensor.to(device)

model.eval()
with torch.no_grad():
    outputs = model(input_features_tensor)
predictions = outputs.cpu().numpy()  # shape (n_samples, 3)

# pT estimate from the T3 circle radius.
t3_pt = np.concatenate(branches['t3_radius']) * 2 * (2.99792458e-3 * 3.8) / 2

def plot_for_pt_bin(pt_min, pt_max, percentiles, eta_bin_edges, t3_pt, predictions, t3_sim_vxy, eta_list,
                    t3_pMatched=None):
    """
    Calculate and plot cut values for specified percentiles in a given pt bin, separately for prompt and displaced tracks.

    For every |eta| bin and retention percentile, this finds the cut on the
    class probability that keeps that percentage of the true tracks of the
    class. It then reports the percentage of fake tracks (pMatched < 0.75)
    falling below that cut. True tracks have pMatched > 0.95 and are split into
    displaced (sim vxy > 0.1) and prompt (all other matched tracks).

    Args:
        pt_min, pt_max: pt bin; tracks with pt_min < pt <= pt_max are used.
        percentiles: Signal-retention percentages to evaluate.
        eta_bin_edges: |eta| bin edges; each bin is [low, high).
        t3_pt: Per-T3 pt values.
        predictions: Per-T3 class probabilities; column 1 = prompt, column 2 = displaced.
        t3_sim_vxy: Per-T3 simulated transverse vertex displacement.
        eta_list: Sequence whose first element is the per-T3 eta array (abs is taken here).
        t3_pMatched: Optional per-T3 match fraction. Defaults to concatenating
            the notebook-level ``branches['t3_pMatched']``.

    Raises:
        ValueError: If the per-T3 inputs do not all have the same length, e.g.
            predictions made on NaN-filtered rows while the other arrays are unfiltered.
    """
    if t3_pMatched is None:
        # Concatenate once rather than once per derived mask.
        t3_pMatched = np.concatenate(branches['t3_pMatched'])

    # Every per-T3 input must describe the same rows; otherwise the boolean
    # masks below fail with an opaque IndexError.
    n_t3 = len(t3_pt)
    per_t3_inputs = {
        'predictions': predictions,
        't3_sim_vxy': t3_sim_vxy,
        'eta_list[0]': eta_list[0],
        't3_pMatched': t3_pMatched,
    }
    mismatched = {name: len(arr) for name, arr in per_t3_inputs.items() if len(arr) != n_t3}
    if mismatched:
        raise ValueError(f"Per-T3 inputs must all have length len(t3_pt)={n_t3}; got {mismatched}")

    # Filter data based on pt bin
    pt_mask = (t3_pt > pt_min) & (t3_pt <= pt_max)

    # Absolute eta and model outputs for all tracks in the pt bin
    abs_eta = np.abs(eta_list[0][pt_mask])
    pred_filtered = predictions[pt_mask]

    # Truth categories from pMatched and t3_sim_vxy
    matched = (t3_pMatched > 0.95)[pt_mask]
    fake_tracks = (t3_pMatched < 0.75)[pt_mask]
    true_displaced = (t3_sim_vxy[pt_mask] > 0.1) & matched
    true_prompt = ~(t3_sim_vxy[pt_mask] > 0.1) & matched

    # Separate plots for prompt and displaced tracks
    for track_type, true_mask, pred_idx, title_suffix in [
        ("Prompt", true_prompt, 1, "Prompt Real Tracks"),
        ("Displaced", true_displaced, 2, "Displaced Real Tracks")
    ]:
        cut_values = {p: [] for p in percentiles}
        fake_rejections = {p: [] for p in percentiles}

        # Probabilities for this class
        probs = pred_filtered[:, pred_idx]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

        # Probability distribution of true tracks of this type
        h = ax1.hist2d(abs_eta[true_mask],
                       probs[true_mask],
                       bins=[eta_bin_edges, 50],
                       norm=LogNorm())
        plt.colorbar(h[3], ax=ax1, label='Counts')

        bin_centers = []
        for i in range(len(eta_bin_edges) - 1):
            # Local names avoid shadowing the notebook-level eta_max.
            eta_lo, eta_hi = eta_bin_edges[i], eta_bin_edges[i + 1]
            bin_centers.append((eta_lo + eta_hi) / 2)

            eta_mask = (abs_eta >= eta_lo) & (abs_eta < eta_hi)
            true_type_mask = eta_mask & true_mask
            fake_mask = eta_mask & fake_tracks

            print(f"Eta bin {eta_lo:.2f}-{eta_hi:.2f}: {np.sum(fake_mask)} fakes, {np.sum(true_type_mask)} true {track_type}")

            if np.sum(true_type_mask) > 0:
                for percentile in percentiles:
                    # Cut that keeps `percentile` % of the true tracks in this bin
                    cut_value = np.percentile(probs[true_type_mask], 100 - percentile)
                    cut_values[percentile].append(cut_value)

                    # Percentage of fakes in this bin below the cut
                    if np.sum(fake_mask) > 0:
                        fake_rej = 100 * np.mean(probs[fake_mask] < cut_value)
                        fake_rejections[percentile].append(fake_rej)
                    else:
                        fake_rejections[percentile].append(np.nan)
            else:
                for percentile in percentiles:
                    cut_values[percentile].append(np.nan)
                    fake_rejections[percentile].append(np.nan)

        # Plot cut values and fake rejections
        colors = plt.cm.rainbow(np.linspace(0, 1, len(percentiles)))
        bin_centers = np.array(bin_centers)

        for (percentile, color) in zip(percentiles, colors):
            values = np.array(cut_values[percentile])
            mask = ~np.isnan(values)
            if np.any(mask):
                ax1.plot(bin_centers[mask], values[mask], '-', color=color, marker='o',
                         label=f'{percentile}% Retention Cut')
                rej_values = np.array(fake_rejections[percentile])
                ax2.plot(bin_centers[mask], rej_values[mask], '-', color=color, marker='o',
                         label=f'{percentile}% Cut')

        ax1.set_xlabel("Absolute Eta")
        ax1.set_ylabel(f"DNN {track_type} Probability")
        ax1.set_title(f"DNN Score vs Eta ({title_suffix})\npt: {pt_min:.1f} to {pt_max:.1f} GeV")
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        ax2.set_xlabel("Absolute Eta")
        ax2.set_ylabel("Fake Rejection (%)")
        ax2.set_title(f"Fake Rejection vs Eta\npt: {pt_min:.1f} to {pt_max:.1f} GeV")
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        ax2.set_ylim(0, 100)

        plt.tight_layout()
        plt.show()
        # Release the figure; this loop creates one per track type and pt bin.
        plt.close(fig)

        # Print statistics
        print(f"\n{track_type} tracks, pt: {pt_min:.1f} to {pt_max:.1f} GeV")
        print(f"Number of true {track_type.lower()} tracks: {np.sum(true_mask)}")
        print(f"Number of fake tracks in pt bin: {np.sum(fake_tracks)}")

        for percentile in percentiles:
            print(f"\n{percentile}% Retention Cut Values:",
                  '{' + ', '.join(f"{x:.4f}" if not np.isnan(x) else 'nan' for x in cut_values[percentile]) + '}',
                  f"Mean: {np.round(np.nanmean(cut_values[percentile]), 4)}")
            print(f"{percentile}% Cut Fake Rejections:",
                  '{' + ', '.join(f"{x:.1f}" if not np.isnan(x) else 'nan' for x in fake_rejections[percentile]) + '}',
                  f"Mean: {np.round(np.nanmean(fake_rejections[percentile]), 1)}%")

def analyze_pt_bins(pt_bins, percentiles, eta_bin_edges, t3_pt, predictions, t3_sim_vxy, eta_list):
    """Run ``plot_for_pt_bin`` on every consecutive pair of edges in ``pt_bins``."""
    for pt_lo, pt_hi in zip(pt_bins[:-1], pt_bins[1:]):
        plot_for_pt_bin(pt_lo, pt_hi, percentiles, eta_bin_edges,
                        t3_pt, predictions, t3_sim_vxy, eta_list)

# Retention working points, two pt bins split at 5 GeV, and 0.25-wide |eta|
# bins from 0 to 2.5.
percentiles = [80, 90, 93, 96, 97, 98, 99, 99.5]
pt_bins = [0, 5, np.inf]
eta_bin_edges = np.arange(0, 2.75, 0.25)

analyze_pt_bins(
    pt_bins,
    percentiles,
    eta_bin_edges,
    t3_pt,
    predictions,
    np.concatenate(branches['t3_sim_vxy']),  # per-T3 sim vertex displacement
    eta_list,
)

Eta bin 0.00-0.25: 9409714 fakes, 313231 true Prompt
Eta bin 0.25-0.50: 9242595 fakes, 323051 true Prompt
Eta bin 0.50-0.75: 7849380 fakes, 410185 true Prompt
Eta bin 0.75-1.00: 4293980 fakes, 322065 true Prompt
Eta bin 1.00-1.25: 4343023 fakes, 374215 true Prompt
Eta bin 1.25-1.50: 2725728 fakes, 351420 true Prompt
Eta bin 1.50-1.75: 1368266 fakes, 425819 true Prompt
Eta bin 1.75-2.00: 1413754 fakes, 467604 true Prompt
Eta bin 2.00-2.25: 448439 fakes, 419450 true Prompt
Eta bin 2.25-2.50: 124212 fakes, 247704 true Prompt


<Figure size 2000x800 with 3 Axes>


Prompt tracks, pt: 0.0 to 5.0 GeV
Number of true prompt tracks: 3654744
Number of fake tracks in pt bin: 41219091

80% Retention Cut Values: {0.6326, 0.6415, 0.6506, 0.6526, 0.5321, 0.5651, 0.5747, 0.5765, 0.6243, 0.6320} Mean: 0.6082
80% Cut Fake Rejections: {99.1, 99.1, 99.1, 99.2, 97.4, 97.1, 96.8, 96.8, 95.6, 92.4} Mean: 97.3%

90% Retention Cut Values: {0.4957, 0.5052, 0.5201, 0.5340, 0.4275, 0.4708, 0.4890, 0.4932, 0.5400, 0.5449} Mean: 0.502
90% Cut Fake Rejections: {98.5, 98.6, 98.6, 98.7, 95.5, 95.4, 95.0, 95.0, 93.1, 88.9} Mean: 95.7%

93% Retention Cut Values: {0.3874, 0.3978, 0.4201, 0.4502, 0.3750, 0.4117, 0.4394, 0.4466, 0.4916, 0.4899} Mean: 0.431
93% Cut Fake Rejections: {98.1, 98.2, 98.2, 98.3, 94.5, 94.3, 94.0, 94.1, 91.8, 86.8} Mean: 94.8%

96% Retention Cut Values: {0.1332, 0.1616, 0.2094, 0.2898, 0.2909, 0.3109, 0.3561, 0.3626, 0.4016, 0.3775} Mean: 0.2894
96% Cut Fake Rejections: {96.3, 96.7, 96.9, 97.1, 92.5, 92.4, 92.2, 92.4, 89.6, 82.9} Mean: 92.9%

97% Retent

<Figure size 2000x800 with 3 Axes>


Displaced tracks, pt: 0.0 to 5.0 GeV
Number of true displaced tracks: 611820
Number of fake tracks in pt bin: 41219091

80% Retention Cut Values: {0.2362, 0.2377, 0.2496, 0.2579, 0.2940, 0.3057, 0.3273, 0.3224, 0.2944, 0.2890} Mean: 0.2814
80% Cut Fake Rejections: {83.1, 81.5, 78.7, 76.4, 75.6, 68.3, 54.0, 49.9, 34.7, 18.3} Mean: 62.1%

90% Retention Cut Values: {0.1809, 0.1855, 0.1998, 0.2098, 0.2173, 0.2267, 0.2391, 0.2507, 0.2356, 0.2317} Mean: 0.2177
90% Cut Fake Rejections: {79.7, 78.0, 75.4, 73.2, 70.1, 61.5, 46.5, 43.7, 29.3, 12.4} Mean: 57.0%

93% Retention Cut Values: {0.1527, 0.1622, 0.1814, 0.1933, 0.1952, 0.1889, 0.2041, 0.2280, 0.2216, 0.2166} Mean: 0.1944
93% Cut Fake Rejections: {77.6, 76.1, 74.0, 72.1, 68.4, 58.2, 43.5, 41.8, 28.0, 10.8} Mean: 55.0%

96% Retention Cut Values: {0.1133, 0.1274, 0.1514, 0.1662, 0.1657, 0.1576, 0.1601, 0.2030, 0.2068, 0.2028} Mean: 0.1654
96% Cut Fake Rejections: {73.9, 72.8, 71.5, 70.0, 65.9, 55.1, 39.2, 39.6, 26.6, 9.6} Mean: 52.4%

97% 

<Figure size 2000x800 with 3 Axes>


Prompt tracks, pt: 5.0 to inf GeV
Number of true prompt tracks: 99138
Number of fake tracks in pt bin: 8609939

80% Retention Cut Values: {0.1353, 0.1569, 0.2083, 0.2407, 0.2910, 0.3552, 0.4221, 0.4197, 0.3174, 0.2681} Mean: 0.2815
80% Cut Fake Rejections: {98.8, 99.0, 99.3, 99.4, 96.1, 94.8, 93.5, 93.7, 88.5, 74.4} Mean: 93.7%

90% Retention Cut Values: {0.0302, 0.0415, 0.0994, 0.1791, 0.1960, 0.2467, 0.3227, 0.3242, 0.2367, 0.2187} Mean: 0.1895
90% Cut Fake Rejections: {90.2, 94.5, 98.1, 99.0, 93.4, 91.4, 89.8, 89.7, 80.1, 63.8} Mean: 89.0%

93% Retention Cut Values: {0.0188, 0.0261, 0.0502, 0.1487, 0.1606, 0.2022, 0.2772, 0.2797, 0.2182, 0.2063} Mean: 0.1588
93% Cut Fake Rejections: {82.0, 89.1, 94.9, 98.7, 91.8, 89.5, 87.9, 87.3, 77.5, 60.9} Mean: 85.9%

96% Retention Cut Values: {0.0095, 0.0138, 0.0243, 0.0620, 0.1096, 0.1450, 0.2184, 0.2276, 0.1949, 0.1769} Mean: 0.1182
96% Cut Fake Rejections: {67.7, 78.4, 86.6, 95.6, 88.5, 86.0, 84.6, 83.8, 74.2, 54.1} Mean: 79.9%

97% Retenti

<Figure size 2000x800 with 3 Axes>


Displaced tracks, pt: 5.0 to inf GeV
Number of true displaced tracks: 26375
Number of fake tracks in pt bin: 8609939

80% Retention Cut Values: {0.5019, 0.4954, 0.5197, 0.5256, 0.4161, 0.3509, 0.3778, 0.3667, 0.4388, 0.4915} Mean: 0.4484
80% Cut Fake Rejections: {98.6, 98.6, 98.7, 98.7, 97.0, 93.3, 87.6, 86.9, 89.0, 88.0} Mean: 93.6%

90% Retention Cut Values: {0.3265, 0.3967, 0.4493, 0.4656, 0.3459, 0.2910, 0.3486, 0.3356, 0.3693, 0.4448} Mean: 0.3773
90% Cut Fake Rejections: {97.6, 98.1, 98.2, 98.3, 95.8, 90.8, 85.4, 84.0, 83.3, 82.4} Mean: 91.4%

93% Retention Cut Values: {0.1804, 0.2502, 0.4033, 0.4375, 0.2958, 0.2556, 0.3353, 0.3242, 0.3405, 0.4194} Mean: 0.3242
93% Cut Fake Rejections: {96.1, 97.0, 97.9, 98.2, 94.6, 89.1, 84.5, 83.1, 80.3, 78.8} Mean: 90.0%

96% Retention Cut Values: {0.0464, 0.0366, 0.2521, 0.3317, 0.2324, 0.2021, 0.3165, 0.2992, 0.3096, 0.3568} Mean: 0.2383
96% Cut Fake Rejections: {90.2, 88.7, 96.8, 97.5, 92.9, 86.3, 83.3, 81.2, 77.1, 69.3} Mean: 86.3%

97% R