In [1]:
# set seed for reproducibility
import torch
torch.manual_seed(42)

<torch._C.Generator at 0x7f350c36dfb0>

In [3]:
import os
import uproot
import numpy as np

def load_root_file(file_path, branches=None, print_branches=False):
    all_branches = {}
    with uproot.open(file_path) as file:
        tree = file["tree"]
        # Load all ROOT branches into array if not specified
        if branches is None:
            branches = tree.keys()
        # Option to print the branch names
        if print_branches:
            print("Branches:", tree.keys())
        # Each branch is added to the dictionary
        for branch in branches:
            try:
                all_branches[branch] = (tree[branch].array(library="np"))
            except uproot.KeyInFileError as e:
                print(f"KeyInFileError: {e}")
        # Number of events in file
        all_branches['event'] = tree.num_entries
    return all_branches

def load_root_files(file_path1, file_path2, branches=None, print_branches=False):
    all_branches = {}
    def load_file(file_path, all_branches):
        with uproot.open(file_path) as file:
                tree = file["tree"]
                # Load all ROOT branches into array if not specified
                if branches is None:
                    file_branches = tree.keys()
                else:
                    file_branches = branches
                # Option to print the branch names
                if print_branches:
                    print(f"Branches in {file_path}:", tree.keys())
                # Each branch is added to the dictionary
                for branch in file_branches:
                    try:
                        if branch in all_branches:
                            all_branches[branch] = np.concatenate(
                                (all_branches[branch], tree[branch].array(library="np"))
                            )
                        else:
                            all_branches[branch] = tree[branch].array(library="np")
                    except uproot.KeyInFileError as e:
                        print(f"KeyInFileError in {file_path}: {e}")
                # Number of events in file
                all_branches['event'] = all_branches.get('event', 0) + tree.num_entries
    load_file(file_path1, all_branches)
    load_file(file_path2, all_branches)

    return all_branches

branches_list = [
    't4_innerRadius',
    't4_outerRadius',
    't4_pt',
    't4_eta',
    't4_phi',
    't4_isFake',
    't4_t3_idx0',
    't4_t3_idx1',
    't4_pMatched',
    't4_sim_vxy',
    't4_sim_vz',
    't4_t3_fakeScore1',
    't4_t3_promptScore1',
    't4_t3_displacedScore1',
    't4_t3_fakeScore2',
    't4_t3_promptScore2',
    't4_t3_displacedScore2',
    't4_regressionRadius',
    't4_nonAnchorRegressionRadius'
]

# Hit-dependent branches
suffixes = ['r', 'z', 'eta', 'phi', 'layer']
branches_list += [f't4_t3_{i}_{suffix}' for i in [0, 2, 4] for suffix in suffixes]

PU_file_path = "noCuts_Current_150925_500ev.root"
cube_file_path = "noCuts_cube50_cpu_debugfull.root"
branches = load_root_files(PU_file_path, cube_file_path, branches_list)

In [4]:
z_max = np.max([np.max(event) for event in branches[f't4_t3_4_z'] if event.size>0])
r_max = np.max([np.max(event) for event in branches[f't4_t3_4_r'] if event.size>0])
eta_max = 2.5
phi_max = np.pi

print(f'Z max: {z_max}, R max: {r_max}, Eta max: {eta_max}')

def delta_phi(phi1, phi2):
    delta = phi1 - phi2
    # Adjust delta to be within the range [-pi, pi]
    if delta > np.pi:
        delta -= 2 * np.pi
    elif delta < -np.pi:
        delta += 2 * np.pi
    return delta

Z max: 267.2349853515625, R max: 110.10993957519531, Eta max: 2.5


In [5]:
features_list = []
eta_list = [] # Used for DNN cut values

for event in range(branches['event']):
    # Determine the number of elements in this event
    num_elements = len(branches['t4_t3_idx0'][event])

    for i in range(num_elements):
        features_iter = []
        eta_iter = []
        
        idx0 = branches['t4_t3_idx0'][event][i]
        idx1 = branches['t4_t3_idx1'][event][i]

        eta1 = np.abs(branches['t4_t3_0_eta'][event][idx0])
        eta2 = np.abs(branches['t4_t3_2_eta'][event][idx0])
        eta3 = np.abs(branches['t4_t3_4_eta'][event][idx0])
        eta4 = np.abs(branches['t4_t3_4_eta'][event][idx1])

        phi1 = (branches['t4_t3_0_phi'][event][idx0])
        phi2 = (branches['t4_t3_2_phi'][event][idx0])
        phi3 = (branches['t4_t3_4_phi'][event][idx0])
        phi4 = (branches['t4_t3_4_phi'][event][idx1])

        z1 = np.abs(branches['t4_t3_0_z'][event][idx0])
        z2 = np.abs(branches['t4_t3_2_z'][event][idx0])
        z3 = np.abs(branches['t4_t3_4_z'][event][idx0])
        z4 = np.abs(branches['t4_t3_4_z'][event][idx1])

        r1 = branches['t4_t3_0_r'][event][idx0]
        r2 = branches['t4_t3_2_r'][event][idx0]
        r3 = branches['t4_t3_4_r'][event][idx0]
        r4 = branches['t4_t3_4_r'][event][idx1]

        innerRad = branches['t4_innerRadius'][event][i]
        outerRad = branches['t4_outerRadius'][event][i]

        regRad = branches['t4_regressionRadius'][event][i]
        nonAnchorRegRad = branches['t4_nonAnchorRegressionRadius'][event][i]

        f1 = branches['t4_t3_fakeScore1'][event][i]
        f2 = branches['t4_t3_fakeScore2'][event][i]
        p1 = branches['t4_t3_promptScore1'][event][i]
        p2 = branches['t4_t3_promptScore2'][event][i]
        d1 = branches['t4_t3_displacedScore1'][event][i]
        d2 = branches['t4_t3_displacedScore2'][event][i]


        # Construct the input feature vector using pairwise differences
        features_iter = [
            eta1 / eta_max,                      # First hit eta, normalized
            np.abs(phi1) / phi_max,              # First hit phi, normalized
            z1 / z_max,                          # First hit z, normalized
            r1 / r_max,                          # First hit r, normalized

            eta2 - eta1,                         # Difference in eta between hit 2 and 1
            delta_phi(phi2, phi1) / phi_max,     # Difference in phi between hit 2 and 1
            (z2 - z1) / z_max,                   # Difference in z between hit 2 and 1, normalized
            (r2 - r1) / r_max,                   # Difference in r between hit 2 and 1, normalized

            eta3 - eta2,                         # Difference in eta between hit 3 and 2
            delta_phi(phi3, phi2) / phi_max,     # Difference in phi between hit 3 and 2
            (z3 - z2) / z_max,                   # Difference in z between hit 3 and 2, normalized
            (r3 - r2) / r_max,                   # Difference in r between hit 3 and 2, normalized

            eta4 - eta3,                         # Difference in eta between hit 4 and 3
            delta_phi(phi4, phi3) / phi_max,     # Difference in phi between hit 4 and 3
            (z4 - z3) / z_max,                   # Difference in z between hit 4 and 3, normalized
            (r4 - r3) / r_max,                   # Difference in r between hit 4 and 3, normalized

            1.0/innerRad,
            1.0/outerRad,
            innerRad/outerRad,
            1.0/regRad,
            1.0/nonAnchorRegRad,

            f1,
            p1,
            d1,

            f2,
            p2,
            d2,

            (f2 - f1),
            (p2 - p1),
            (d2 - d1),
        ]

        # Use the abs eta value of first hit to select cut thresholds
        eta_iter.extend([np.abs(branches['t4_t3_0_eta'][event][idx0])])
        
        # Append the feature vector to the list
        features_list.append(features_iter)
        eta_list.append(eta_iter)

# Convert the list of features to a NumPy array
features = np.array(features_list).T
eta_list = np.array(eta_list).T

In [6]:
import torch

# Stack features along a new axis to form a single array suitable for NN input
input_features_numpy = np.stack(features, axis=-1)

# Identify rows with NaN or Inf values
mask = ~np.isnan(input_features_numpy) & ~np.isinf(input_features_numpy)

# Apply mask across all columns: retain a row only if all its entries are neither NaN nor Inf
filtered_input_features_numpy = input_features_numpy[np.all(mask, axis=1)]
t4_isFake_filtered = (np.concatenate(branches['t4_pMatched']) <= 0.75)[np.all(mask, axis=1)]
t4_sim_vxy_filtered = np.concatenate(branches['t4_sim_vxy'])[np.all(mask, axis=1)]

# Convert to PyTorch tensor when ready to use with NN
input_features_tensor = torch.tensor(filtered_input_features_numpy, dtype=torch.float32)

In [None]:
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset, random_split

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create labels tensor
def create_multiclass_labels(t4_isFake, t4_sim_vxy, displacement_threshold=0.1):
    num_samples = len(t4_isFake)
    labels = torch.zeros((num_samples, 3))
    
    # Fake tracks (class 0)
    fake_mask = t4_isFake
    labels[fake_mask, 0] = 1
    
    # Real tracks
    real_mask = ~fake_mask 
    
    # Split real tracks into prompt (class 1) and displaced (class 2)
    prompt_mask = (t4_sim_vxy <= displacement_threshold) & real_mask
    displaced_mask = (t4_sim_vxy > displacement_threshold) & real_mask
    
    labels[prompt_mask, 1] = 1
    labels[displaced_mask, 2] = 1

    print(f"Total samples: {num_samples}")
    print(f"Fake count: {fake_mask.sum().item()}")
    print(f"Real count: {real_mask.sum().item()}")
    print(f"Prompt count: {prompt_mask.sum().item()}")
    print(f"Displaced count: {displaced_mask.sum().item()}")
    
    return labels

labels_tensor = create_multiclass_labels(
    t4_isFake_filtered,
    t4_sim_vxy_filtered
)

# Neural network for multi-class classification
class MultiClassNeuralNetwork(nn.Module):
    def __init__(self):
        super(MultiClassNeuralNetwork, self).__init__()
        self.layer1 = nn.Linear(input_features_numpy.shape[1], 32)
        self.layer2 = nn.Linear(32, 32)
        self.output_layer = nn.Linear(32, 3)
        
    def forward(self, x):
        x = self.layer1(x)
        x = nn.ReLU()(x)
        x = self.layer2(x)
        x = nn.ReLU()(x)
        x = self.output_layer(x)
        return nn.functional.softmax(x, dim=1)

# Weighted loss function for multi-class
class WeightedCrossEntropyLoss(nn.Module):
    def __init__(self):
        super(WeightedCrossEntropyLoss, self).__init__()
        
    def forward(self, outputs, targets, weights):
        eps = 1e-7
        log_probs = torch.log(outputs + eps)
        losses = -weights * torch.sum(targets * log_probs, dim=1)
        return losses.mean()


# Calculate class weights (each sample gets a weight to equalize class contributions)
def calculate_class_weights(labels):
    class_counts = torch.sum(labels, dim=0)
    total_samples = len(labels)
    class_weights = total_samples / (3 * class_counts)  # Normalize across 3 classes
    
    sample_weights = torch.zeros(len(labels))
    for i in range(3):
        sample_weights[labels[:, i] == 1] = class_weights[i]
    
    return sample_weights

# Print initial dataset size
print(f"Initial dataset size: {len(labels_tensor)}")

# Remove rows with NaN and update weights accordingly
nan_mask = torch.isnan(input_features_tensor).any(dim=1) | torch.isnan(labels_tensor).any(dim=1)
filtered_inputs = input_features_tensor[~nan_mask]
filtered_labels = labels_tensor[~nan_mask]

# Count samples in each class before downsampling
class_counts_before = torch.sum(filtered_labels, dim=0)
print(f"Class distribution before downsampling - Fake: {class_counts_before[0]}, Prompt: {class_counts_before[1]}, Displaced: {class_counts_before[2]}")

# Option to downsample each class (binary-class)
downsample_classes = True  # Set to False to disable downsampling
if downsample_classes:
    # Define downsampling ratios for each class:
    # For example, downsample fakes (class 0) to 50% and keep prompt (class 1) and displaced (class 2) at 100%
    downsample_ratios = {0: 0.5, 1: 1.0, 2: 1.0}
    indices_list = []
    for cls in range(3):
        # Find indices for the current class
        cls_mask = (filtered_labels[:, cls] == 1)
        cls_indices = torch.nonzero(cls_mask).squeeze()
        ratio = downsample_ratios.get(cls, 1.0)
        num_cls = cls_indices.numel()
        num_to_sample = int(num_cls * ratio)
        # Ensure at least one sample is kept if available
        if num_to_sample < 1 and num_cls > 0:
            num_to_sample = 1
        # Shuffle and select the desired number of samples
        cls_indices_shuffled = cls_indices[torch.randperm(num_cls)]
        sampled_cls_indices = cls_indices_shuffled[:num_to_sample]
        indices_list.append(sampled_cls_indices)
    
    # Combine the indices from all classes
    selected_indices = torch.cat(indices_list)
    filtered_inputs = filtered_inputs[selected_indices]
    filtered_labels = filtered_labels[selected_indices]

# Print class distribution after downsampling
class_counts_after = torch.sum(filtered_labels, dim=0)
print(f"Class distribution after downsampling - Fake: {class_counts_after[0]}, Prompt: {class_counts_after[1]}, Displaced: {class_counts_after[2]}")

# Recalculate sample weights after downsampling (equal weighting per class based on new counts)
sample_weights = calculate_class_weights(filtered_labels)
filtered_weights = sample_weights    

# Create dataset with weights
dataset = TensorDataset(filtered_inputs, filtered_labels, filtered_weights)

# Split into train and test sets
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True, num_workers=10, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False, num_workers=10, pin_memory=True)

# Initialize model and optimizer
model = MultiClassNeuralNetwork().to(device)
loss_function = WeightedCrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.0025)

def evaluate_loss(loader):
    model.eval()
    total_loss = 0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets, weights in loader:
            inputs, targets, weights = inputs.to(device), targets.to(device), weights.to(device)
            outputs = model(inputs)
            loss = loss_function(outputs, targets, weights)
            total_loss += loss.item()
            num_batches += 1
    return total_loss / num_batches

# Training loop
num_epochs = 300
train_loss_log = []
test_loss_log = []

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    num_batches = 0

    for inputs, targets, weights in train_loader:
        inputs, targets, weights = inputs.to(device), targets.to(device), weights.to(device)
    
        # Forward pass
        outputs = model(inputs)
        loss = loss_function(outputs, targets, weights)
        epoch_loss += loss.item()
        num_batches += 1

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Calculate average losses
    train_loss = epoch_loss / num_batches
    test_loss = evaluate_loss(test_loader)
    
    train_loss_log.append(train_loss)
    test_loss_log.append(test_loss)
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')

Using device: cuda
Total samples: 1946967
Fake count: 1932985
Real count: 13982
Prompt count: 2190
Displaced count: 11792
Initial dataset size: 1946967
Class distribution before downsampling - Fake: 1932985.0, Prompt: 2190.0, Displaced: 11792.0
Class distribution after downsampling - Fake: 966492.0, Prompt: 2190.0, Displaced: 11792.0


Epoch [1/300], Train Loss: 0.8262, Test Loss: 0.7307
Epoch [2/300], Train Loss: 0.6984, Test Loss: 0.6723
Epoch [3/300], Train Loss: 0.6558, Test Loss: 0.6354
Epoch [4/300], Train Loss: 0.6422, Test Loss: 0.6504
Epoch [5/300], Train Loss: 0.6285, Test Loss: 0.6006
Epoch [6/300], Train Loss: 0.5979, Test Loss: 0.5772
Epoch [7/300], Train Loss: 0.5928, Test Loss: 0.5758
Epoch [8/300], Train Loss: 0.5960, Test Loss: 0.5796
Epoch [9/300], Train Loss: 0.5725, Test Loss: 0.5631
Epoch [10/300], Train Loss: 0.5847, Test Loss: 0.5568
Epoch [11/300], Train Loss: 0.5568, Test Loss: 0.5725
Epoch [12/300], Train Loss: 0.5551, Test Loss: 0.5485
Epoch [13/300], Train Loss: 0.5623, Test Loss: 0.5463
Epoch [14/300], Train Loss: 0.5651, Test Loss: 0.5555
Epoch [15/300], Train Loss: 0.5495, Test Loss: 0.5448
Epoch [16/300], Train Loss: 0.5534, Test Loss: 0.5455
Epoch [17/300], Train Loss: 0.5334, Test Loss: 0.6103
Epoch [18/300], Train Loss: 0.5272, Test Loss: 0.5251
Epoch [19/300], Train Loss: 0.5258, T

In [9]:
import torch
import numpy as np
from sklearn.metrics import accuracy_score

# Convert tensors to numpy for simplicity if you want to manipulate them outside of PyTorch
input_features_np = input_features_tensor.numpy()
labels_np = torch.argmax(labels_tensor, dim=1).numpy()  # Convert one-hot to class indices

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def model_accuracy(features, labels, model):
    """
    Compute accuracy for a multi-class classification model
    that outputs probabilities of size [batch_size, num_classes].
    """
    model.eval()  # Set the model to evaluation mode
    
    # Move the features and labels to the correct device
    inputs = features.to(device)
    labels = labels.to(device)
    
    with torch.no_grad():
        outputs = model(inputs)  # shape: [batch_size, num_classes]
        # For multi-class, the predicted class is argmax of the probabilities
        predicted = torch.argmax(outputs, dim=1)
        # Convert one-hot encoded labels to class indices if needed
        if len(labels.shape) > 1:
            labels = torch.argmax(labels, dim=1)
        # Compute mean accuracy
        accuracy = (predicted == labels).float().mean().item()
    
    return accuracy

# Compute baseline accuracy
baseline_accuracy = model_accuracy(input_features_tensor, labels_tensor, model)
print(f"Baseline accuracy: {baseline_accuracy:.4f}")

# Initialize array to store feature importances
feature_importances = np.zeros(input_features_tensor.shape[1])

# Iterate over each feature for permutation importance
for i in range(input_features_tensor.shape[1]):
    # Create a copy of the original features
    permuted_features = input_features_tensor.clone()
    
    # Permute feature i across all examples
    # We do this by shuffling the rows for that specific column
    permuted_features[:, i] = permuted_features[torch.randperm(permuted_features.size(0)), i]
    
    # Compute accuracy after permutation
    permuted_accuracy = model_accuracy(permuted_features, labels_tensor, model)
    
    # The drop in accuracy is used as a measure of feature importance
    feature_importances[i] = baseline_accuracy - permuted_accuracy

# Sort features by descending importance
important_features_indices = np.argsort(feature_importances)[::-1]
important_features_scores = np.sort(feature_importances)[::-1]

# Print out results
print("\nFeature importances:")
for idx, score in zip(important_features_indices, important_features_scores):
    print(f"Feature {idx} importance: {score:.4f}")

Baseline accuracy: 0.8717

Feature importances:
Feature 28 importance: 0.0235
Feature 27 importance: 0.0229
Feature 0 importance: 0.0202
Feature 14 importance: 0.0136
Feature 18 importance: 0.0093
Feature 6 importance: 0.0086
Feature 16 importance: 0.0086
Feature 23 importance: 0.0068
Feature 24 importance: 0.0062
Feature 13 importance: 0.0047
Feature 17 importance: 0.0040
Feature 11 importance: 0.0034
Feature 22 importance: 0.0031
Feature 5 importance: 0.0023
Feature 25 importance: 0.0023
Feature 15 importance: 0.0011
Feature 10 importance: 0.0010
Feature 20 importance: 0.0008
Feature 1 importance: 0.0003
Feature 19 importance: 0.0001
Feature 8 importance: 0.0000
Feature 21 importance: -0.0001
Feature 2 importance: -0.0005
Feature 9 importance: -0.0006
Feature 3 importance: -0.0009
Feature 7 importance: -0.0026
Feature 29 importance: -0.0026
Feature 26 importance: -0.0030
Feature 12 importance: -0.0035
Feature 4 importance: -0.0044


In [10]:
def print_formatted_weights_biases(weights, biases, layer_name):
    # Print biases
    print(f"ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_{layer_name}[{len(biases)}] = {{")
    print(", ".join(f"{b:.7f}f" for b in biases) + " };")
    print()

    # Print weights
    print(f"ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_{layer_name}[{len(weights[0])}][{len(weights)}] = {{")
    for row in weights.T:
        formatted_row = ", ".join(f"{w:.7f}f" for w in row)
        print(f"{{ {formatted_row} }},")
    print("};")
    print()

def print_model_weights_biases(model):
    # Make sure the model is in evaluation mode
    model.eval()

    # Iterate through all named modules in the model
    for name, module in model.named_modules():
        # Check if the module is a linear layer
        if isinstance(module, nn.Linear):
            # Get weights and biases
            weights = module.weight.data.cpu().numpy()
            biases = module.bias.data.cpu().numpy()

            # Print formatted weights and biases
            print_formatted_weights_biases(weights, biases, name.replace('.', '_'))

print_model_weights_biases(model)


ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_layer1[32] = {
0.3130181f, -0.3157252f, -0.0845900f, -0.2268437f, 0.1305549f, 0.3839142f, 0.3933745f, -0.6758229f, -0.4188058f, -0.2523611f, 1.4036129f, 0.8239079f, 0.1575654f, 0.2041763f, 0.8787493f, 0.2706699f, -0.1112185f, 0.8988609f, 0.9274163f, -0.1023219f, 0.2916122f, -0.2606929f, 0.3098971f, -0.0602703f, -0.6031470f, -0.0826582f, 0.3605700f, 0.4836628f, -0.3951748f, 0.0171050f, 0.5156327f, 0.0655813f };

ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_layer1[30][32] = {
{ -0.3409404f, -0.2000102f, -0.0890483f, -0.6186467f, -1.7570605f, 0.7890699f, -0.8753229f, 0.5488843f, -0.5376814f, -0.2228569f, -0.3573552f, 2.1554949f, 0.2248887f, -0.6073594f, 0.2075009f, -0.1408760f, -0.7051892f, -0.0664303f, 0.2747473f, 0.1450685f, -2.2709231f, -0.4088669f, 0.5452566f, 0.3086576f, -0.1213564f, -0.9737034f, 0.2679004f, -0.1193755f, 0.9693206f, -0.7785844f, 0.4612639f, 0.7628022f },
{ 0.0309823f, -0.5052885f, 0.0509167f, 0.1379152f, 0.2392345f, 0.09

In [11]:
# Ensure input_features_tensor is moved to the appropriate device
input_features_tensor = input_features_tensor.to(device)
filtered_inputs = input_features_tensor[~nan_mask]
filtered_labels = labels_tensor[~nan_mask]

# Make predictions
with torch.no_grad():
    model.eval()
    outputs = model(input_features_tensor)
    predictions = outputs.squeeze().cpu().numpy()

full_tracks = (np.concatenate(branches['t4_pMatched']) > 0.95)


In [None]:
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import LogNorm
import torch

# Ensure input_features_tensor is on the right device
input_features_tensor = input_features_tensor.to(device)

t4_pt = np.concatenate(branches['t4_pt'])

# Get model predictions
with torch.no_grad():
    model.eval()
    outputs = model(input_features_tensor)
    predictions = outputs.cpu().numpy()  # Shape will be [n_samples, 3]


def plot_for_pt_bin(pt_min, pt_max, percentiles, eta_bin_edges, t4_pt, predictions, t4_sim_vxy, eta_list):
    """
    Calculate and plot cut values for specified percentiles in a given pt bin, separately for prompt and displaced tracks
    """
    # Filter data based on pt bin
    pt_mask = (t4_pt > pt_min) & (t4_pt <= pt_max)
    
    # Get absolute eta values for all tracks in pt bin
    abs_eta = np.abs(eta_list[0][pt_mask])
    
    # Get predictions for all tracks in pt bin
    pred_filtered = predictions[pt_mask]
    
    # Get track types using pMatched and t4_sim_vxy
    matched = (np.concatenate(branches['t4_pMatched']) > 0.95)[pt_mask]
    fake_tracks = (np.concatenate(branches['t4_pMatched']) <= 0.75)[pt_mask]
    true_displaced = (t4_sim_vxy[pt_mask] > 0.1) & matched
    
    # Separate plots for prompt and displaced tracks
    for track_type, true_mask, pred_idx, title_suffix in [
        ("Displaced", true_displaced, 2, "Displaced Real Tracks"),
        ("Fake", true_displaced, 0, "Displaced Real Tracks")
    ]:
        # Dictionaries to store values
        cut_values = {p: [] for p in percentiles}
        fake_rejections = {p: [] for p in percentiles}
        
        # Get probabilities for this class
        probs = pred_filtered[:, pred_idx]
        
        # Create two side-by-side plots
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
        
        # Plot probability distribution (only for true tracks of this type)
        h = ax1.hist2d(abs_eta[true_mask], 
                      probs[true_mask], 
                      bins=[eta_bin_edges, 50], 
                      norm=LogNorm())
        plt.colorbar(h[3], ax=ax1, label='Counts')
        
        # For each eta bin
        bin_centers = []
        for i in range(len(eta_bin_edges) - 1):
            eta_min, eta_max = eta_bin_edges[i], eta_bin_edges[i+1]
            bin_center = (eta_min + eta_max) / 2
            bin_centers.append(bin_center)
            
            # Get tracks in this eta bin
            eta_mask = (abs_eta >= eta_min) & (abs_eta < eta_max)
            
            # True tracks of this type in this bin
            true_type_mask = eta_mask & true_mask
            # Fake tracks in this bin
            fake_mask = eta_mask & fake_tracks
            
            if track_type == "Displaced":
                print(f"Eta bin {eta_min:.2f}-{eta_max:.2f}: {np.sum(fake_mask)} fakes, {np.sum(true_type_mask)} true {track_type}")
            
            if np.sum(true_type_mask) > 0:  # If we have true tracks in this bin
                for percentile in percentiles:
                    # Calculate cut value to keep desired percentage of true tracks
                    if track_type == "Fake":
                        cut_value = np.percentile(probs[true_type_mask], percentile)
                    else:
                        cut_value = np.percentile(probs[true_type_mask], 100 - percentile)
                    cut_values[percentile].append(cut_value)
                    
                    # Calculate fake rejection for this cut
                    if np.sum(fake_mask) > 0:
                        if track_type == "Fake":
                            fake_rej = 100 * np.mean(probs[fake_mask] > cut_value)
                        else:
                            fake_rej = 100 * np.mean(probs[fake_mask] < cut_value)
                        fake_rejections[percentile].append(fake_rej)
                    else:
                        fake_rejections[percentile].append(np.nan)
            else:
                for percentile in percentiles:
                    cut_values[percentile].append(np.nan)
                    fake_rejections[percentile].append(np.nan)
        
        # Plot cut values and fake rejections
        colors = plt.cm.rainbow(np.linspace(0, 1, len(percentiles)))
        bin_centers = np.array(bin_centers)
        
        for (percentile, color) in zip(percentiles, colors):
            values = np.array(cut_values[percentile])
            mask = ~np.isnan(values)
            if np.any(mask):
                # Plot cut values
                ax1.plot(bin_centers[mask], values[mask], '-', color=color, marker='o',
                        label=f'{percentile}% Retention Cut')
                # Plot fake rejections
                rej_values = np.array(fake_rejections[percentile])
                ax2.plot(bin_centers[mask], rej_values[mask], '-', color=color, marker='o',
                        label=f'{percentile}% Cut')
        
        # Set plot labels and titles
        ax1.set_xlabel("Absolute Eta")
        ax1.set_ylabel(f"DNN {track_type} Probability")
        ax1.set_title(f"DNN Score vs Eta ({title_suffix})\npt: {pt_min:.1f} to {pt_max:.1f} GeV")
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        ax2.set_xlabel("Absolute Eta")
        ax2.set_ylabel("Fake Rejection (%)")
        ax2.set_title(f"Fake Rejection vs Eta\npt: {pt_min:.1f} to {pt_max:.1f} GeV")
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        ax2.set_ylim(0, 100)
        
        plt.tight_layout()
        plt.show()
        
        # Print statistics
        print(f"\n{track_type} tracks, pt: {pt_min:.1f} to {pt_max:.1f} GeV")
        print(f"Number of true {track_type.lower()} tracks: {np.sum(true_mask)}")
        print(f"Number of fake tracks in pt bin: {np.sum(fake_tracks)}")
        
        for percentile in percentiles:
            print(f"\n{percentile}% Retention Cut Values:",
                  '{' + ', '.join(f"{x:.4f}" if not np.isnan(x) else 'nan' for x in cut_values[percentile]) + '}',
                  f"Mean: {np.round(np.nanmean(cut_values[percentile]), 4)}")
            print(f"{percentile}% Cut Fake Rejections:",
                  '{' + ', '.join(f"{x:.1f}" if not np.isnan(x) else 'nan' for x in fake_rejections[percentile]) + '}',
                  f"Mean: {np.round(np.nanmean(fake_rejections[percentile]), 1)}%")

def analyze_pt_bins(pt_bins, percentiles, eta_bin_edges, t4_pt, predictions, t4_sim_vxy, eta_list):
    """
    Analyze and plot for multiple pt bins and percentiles
    """
    for i in range(len(pt_bins) - 1):
        plot_for_pt_bin(pt_bins[i], pt_bins[i + 1], percentiles, eta_bin_edges,
                       t4_pt, predictions, t4_sim_vxy, eta_list)

# Run the analysis with same parameters as before
percentiles = [65, 70, 75, 80, 85, 90, 95]
pt_bins = [0, 5, np.inf]
eta_bin_edges = np.arange(0, 2.6, 0.1)

analyze_pt_bins(
    pt_bins=pt_bins,
    percentiles=percentiles,
    eta_bin_edges=eta_bin_edges,
    t4_pt=t4_pt,
    predictions=predictions,
    t4_sim_vxy=np.concatenate(branches['t4_sim_vxy']),
    eta_list=eta_list
)

Eta bin 0.00-0.10: 343773 fakes, 783 true Displaced
Eta bin 0.10-0.20: 277198 fakes, 685 true Displaced
Eta bin 0.20-0.30: 236575 fakes, 679 true Displaced
Eta bin 0.30-0.40: 243786 fakes, 803 true Displaced
Eta bin 0.40-0.50: 236255 fakes, 682 true Displaced
Eta bin 0.50-0.60: 215018 fakes, 856 true Displaced
Eta bin 0.60-0.70: 157631 fakes, 989 true Displaced
Eta bin 0.70-0.80: 117039 fakes, 746 true Displaced
Eta bin 0.80-0.90: 103566 fakes, 893 true Displaced
Eta bin 0.90-1.00: 63672 fakes, 710 true Displaced
Eta bin 1.00-1.10: 46689 fakes, 803 true Displaced
Eta bin 1.10-1.20: 54520 fakes, 765 true Displaced


Eta bin 1.20-1.30: 69878 fakes, 988 true Displaced
Eta bin 1.30-1.40: 28034 fakes, 447 true Displaced
Eta bin 1.40-1.50: 10484 fakes, 458 true Displaced
Eta bin 1.50-1.60: 16192 fakes, 1550 true Displaced
Eta bin 1.60-1.70: 26335 fakes, 1499 true Displaced
Eta bin 1.70-1.80: 29781 fakes, 1367 true Displaced
Eta bin 1.80-1.90: 45547 fakes, 1894 true Displaced
Eta bin 1.90-2.00: 41836 fakes, 1748 true Displaced
Eta bin 2.00-2.10: 11233 fakes, 1706 true Displaced
Eta bin 2.10-2.20: 8634 fakes, 1151 true Displaced
Eta bin 2.20-2.30: 4555 fakes, 1104 true Displaced
Eta bin 2.30-2.40: 7739 fakes, 1940 true Displaced
Eta bin 2.40-2.50: 1872 fakes, 1428 true Displaced
<Figure size 2000x800 with 3 Axes>

Displaced tracks, pt: 0.0 to 5.0 GeV
Number of true displaced tracks: 26674
Number of fake tracks in pt bin: 2397842

65% Retention Cut Values: {0.6532, 0.6635, 0.7246, 0.7622, 0.6931, 0.7138, 0.7964, 0.7897, 0.8166, 0.8314, 0.7027, 0.6347, 0.6674, 0.7002, 0.7721, 0.8819, 0.8457, 0.8013, 0.7852