In [36]:
import random  
random.seed(42)  

import numpy as np  
import matplotlib.pyplot as plt 
import seaborn as sns  

from collections import Counter  

import torch  
import torch.nn as nn  
import torch.nn.functional as F  
import torch.optim as optim  

from torch.nn import BatchNorm1d  
from torch_geometric.data import Data  
from torch_geometric.nn import GCNConv  
from torch_geometric.transforms import ToUndirected  
from torch_geometric.utils import add_self_loops
from torch.utils.data import DataLoader, IterableDataset

from sklearn.metrics import roc_curve, auc, confusion_matrix  
from sklearn.preprocessing import MinMaxScaler  

import uproot  

### Preparing the Data Dictionary

In [37]:
# Path of the input ROOT file, kept in one place so it is easy to retarget
root_file_path = '/home/mxg1065/MyxAODAnalysis_super3D.outputs.root'
file = uproot.open(root_file_path)
print(file.keys())

['analysis;1']


In [38]:
# 'analysis;1' is the only key the opened file lists
tree = file['analysis;1']
# arrays() is called without arguments, i.e. no branch selection is applied here
branches = tree.arrays()
print(tree.keys()) # Variables per event

['RunNumber', 'EventNumber', 'cell_eta', 'cell_phi', 'cell_x', 'cell_y', 'cell_z', 'cell_subCalo', 'cell_sampling', 'cell_size', 'cell_hashID', 'neighbor', 'seedCell_id', 'cell_e', 'cell_noiseSigma', 'cell_SNR', 'cell_time', 'cell_weight', 'cell_truth', 'cell_truth_indices', 'cell_shared_indices', 'cell_cluster_index', 'cluster_to_cell_indices', 'cluster_to_cell_weights', 'cell_to_cluster_e', 'cell_to_cluster_eta', 'cell_to_cluster_phi', 'cluster_eta', 'cluster_phi', 'cluster_e', 'cellsNo_cluster', 'clustersNo_event', 'jetEnergyWtdTimeAve', 'jetEta', 'jetPhi', 'jetE', 'jetPt', 'jetNumberPerEvent', 'cellIndices_per_jet']


In [39]:
# 100 events and 187652 cells
# Per-event arrays holding the energy, noise sigma, signal-to-noise ratio,
# eta and phi of every cell, each converted to a NumPy array.
cell_e, cell_noise, cell_snr, cell_eta, cell_phi = (
    np.array(branches[branch_name])
    for branch_name in ('cell_e', 'cell_noiseSigma', 'cell_SNR', 'cell_eta', 'cell_phi')
)

# Index of the cluster each cell is assigned to. An index of 0 means the
# cell does not belong to any cluster.
cell_to_cluster_index = np.array(branches['cell_cluster_index'])

# For each entry, the IDs of the cells neighboring a given cell
# (kept as read from the tree, not converted to NumPy)
neighbor = branches['neighbor']

num_of_events = len(cell_e)  # 100 events

In [40]:
# Build a data dictionary with one entry per event. Each entry is a
# (cells, 5) matrix whose columns are SNR, energy, noise, eta and phi.
feature_arrays = (cell_snr, cell_e, cell_noise, cell_eta, cell_phi)
data = {
    f'data_{i}': np.stack([feature[i] for feature in feature_arrays], axis=1)
    for i in range(num_of_events)
}

# All events are stacked into one array so that a single MinMaxScaler is
# fitted on (and applied to) the whole sample.
combined_data = np.vstack(list(data.values()))
scaler = MinMaxScaler()
scaled_combined_data = scaler.fit_transform(combined_data)

# Cut the scaled array back into per-event pieces so that scaled_data has
# the same structure (keys and row counts) as the original data dict.
scaled_data = {}
start_idx = 0
for key, event_block in data.items():
    end_idx = start_idx + event_block.shape[0]
    scaled_data[key] = scaled_combined_data[start_idx:end_idx]
    start_idx = end_idx

print(scaled_data["data_0"])
print(scaled_data["data_0"].shape)

[[0.02640459 0.24943821 0.09700817 0.23466232 0.5085498 ]
 [0.02567646 0.24627675 0.09700815 0.23466876 0.52418756]
 [0.02508186 0.24369499 0.09700817 0.23467347 0.5398244 ]
 ...
 [0.02632877 0.24791998 0.03177016 0.62017125 0.4921177 ]
 [0.02705116 0.2511391  0.07512318 0.6461531  0.4921177 ]
 [0.02638626 0.24820389 0.04149057 0.6731673  0.4921177 ]]
(187652, 5)


### Preparing Neighbor Pairs

In [41]:
# The IDs of the broken cells (those with zero noise) are collected.
# The IDs found in every event are merged, so a cell that is broken in any
# event is excluded; previously each iteration overwrote the result and only
# the last event's broken cells were kept.
broken_cell_ids = set()

for i in range(num_of_events):
    broken_cell_ids.update(np.flatnonzero(cell_noise[i] == 0).tolist())

# Always a sorted 1-D integer array, even when zero or one cell is broken
# (np.squeeze would have collapsed a single ID into a 0-d array).
broken_cells = np.array(sorted(broken_cell_ids), dtype=int)

print(broken_cells)

[186986 187352]


In [42]:
# Since the values associated with neighbor[0] and neighbor[1] are all equal
# we will just work with the neighbor lists of event 0 to simplify our calculations.
# NOTE(review): the equality across events is the author's observation; it is
# not checked in this notebook.
# The entry is taken from `branches` rather than from `neighbor` itself so that
# re-running this cell does not index a second time (which would silently
# replace the per-cell lists with the neighbor list of cell 0).
neighbor = branches['neighbor'][0]

In [43]:
# We loop through the neighbor awkward array and drop every pair that involves
# a broken cell: broken cells are skipped both as the cell in question (i) and
# as a neighbor (j). The final list contains tuples (i, j) where i is the cell
# ID in question and the js are the neighboring cell IDs.
#
# The broken IDs are put in a Python set once, so each membership test is a
# constant-time lookup instead of a scan over a NumPy array. np.atleast_1d
# keeps this working when broken_cells is a 0-d array or an empty list.
broken_cell_lookup = set(np.atleast_1d(broken_cells).tolist())

neighbor_pairs_list = []
num_of_cells = len(neighbor) # 187652 cells

for i in range(num_of_cells):
    if i in broken_cell_lookup:
        continue
    for j in neighbor[i]:
        j = int(j)  # plain int so the tuple holds Python ints only
        if j in broken_cell_lookup:
            continue
        neighbor_pairs_list.append((i, j))

In [44]:
# Sanity check: gather every cell ID in the neighbor pairs that is still
# listed as broken. The list is empty when the exclusion worked.
found_broken_cells = [
    cell
    for pair in neighbor_pairs_list
    for cell in pair
    if cell in broken_cells
]

message = (
    "Error: Broken cells are still present in neighbor pairs."
    if found_broken_cells
    else "Successfully excluded broken cells."
)
print(message)

Successfully excluded broken cells.


In [45]:
# Helpers that collapse order-only duplicates such as (i, j) and (j, i)
def canonical_form(t):
    """Return the elements of ``t`` as a tuple in ascending order."""
    ordered = sorted(t)
    return tuple(ordered)

def remove_permutation_variants(tuple_list):
    """Return one ascending tuple per distinct unordered combination in ``tuple_list``.

    The order of the returned list follows set iteration order, so it is not
    related to the order of the input.
    """
    seen = set()
    for item in tuple_list:
        seen.add(canonical_form(item))
    # Every member is already an ascending tuple, so no second sort is needed
    return list(seen)

# Collapse mirrored pairs, then rebind the name to an array with one pair per row
neighbor_pairs_list = np.array(remove_permutation_variants(neighbor_pairs_list))
print(neighbor_pairs_list)
print(neighbor_pairs_list.shape)

[[ 90345 119588]
 [  4388  17680]
 [ 39760  39825]
 ...
 [159757 168717]
 [ 62911  78974]
 [135353 135609]]
(1250242, 2)


### Plotting the Features

### Creating Labels for the Neighbor Pairs

For a given pair of neighboring cells, let (i, j) be the indices of the clusters that the first and the second cell belong to (an index of 0 means the cell belongs to no cluster). If
1. i=j and both are nonzero, then both cells are part of the same cluster. 
    * We call these True-True pairs and label them with 1
2. i=j and both are zero, then neither cell is part of any cluster. 
    * We call these Lone-Lone pairs and label them with 0
3. i is nonzero and j=0, then the first cell is part of a cluster while the second cell is not. 
    * We call these Cluster-Lone pairs and label them with 2
4. i=0 and j is nonzero, then the first cell is not part of a cluster while the second cell is. 
    * We call these Lone-Cluster pairs and label them with 3
5. i is not the same as j and both are nonzero, then the cells are part of different clusters. 
    * We call these Cluster-Cluster pairs and label them with 4

In [46]:
# Cluster index of the first and of the second cell of every pair, per event
cell_0 = cell_to_cluster_index[:, neighbor_pairs_list[:, 0]]
cell_1 = cell_to_cluster_index[:, neighbor_pairs_list[:, 1]]

# Element-wise comparisons shared by the label masks below
same_cluster = cell_0 == cell_1
both_nonzero = (cell_0 != 0) & (cell_1 != 0)

# Every pair starts with label 0. The four masks below are mutually exclusive,
# so the only pairs left at 0 are the Lone-Lone ones (both indices zero).
labels_for_neighbor_pairs = np.zeros((num_of_events, len(neighbor_pairs_list)), dtype=int)

# True-True (1): equal, nonzero cluster indices
labels_for_neighbor_pairs[same_cluster & (cell_0 != 0)] = 1

# Cluster-Lone (2): only the first cell is clustered (which implies the indices differ)
labels_for_neighbor_pairs[(cell_0 != 0) & (cell_1 == 0)] = 2

# Lone-Cluster (3): only the second cell is clustered (which implies the indices differ)
labels_for_neighbor_pairs[(cell_0 == 0) & (cell_1 != 0)] = 3

# Cluster-Cluster (4): both clustered, in different clusters
labels_for_neighbor_pairs[both_nonzero & ~same_cluster] = 4

print(labels_for_neighbor_pairs.shape)
print(labels_for_neighbor_pairs)
print(np.unique(labels_for_neighbor_pairs[0]))

(100, 1250242)
[[0 0 0 ... 0 0 2]
 [3 3 0 ... 0 0 2]
 [1 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 1 0 ... 0 0 0]
 [0 0 0 ... 2 0 0]]
[0 1 2 3 4]


### Building the Multi-Class Batch Generator

#### New Version

In [None]:
class MultiClassBatchGenerator(IterableDataset):
    def __init__(self, x, neighbor_pairs, labels_for_neighbor_pairs, batch_size, class_counts, seed=None):
        '''
        Class Purpose: 
        This class is a custom PyTorch IterableDataset designed to generate mini-batches of data 
        for multi-class classification tasks. It supports fixed class-based sampling per batch.

        How it works:
        1. Initializes with event data, labels, and the exact number of samples per class.
        2. Precomputes, for every class and event, one random permutation of the pair
           indices carrying that class, avoiding per-batch shuffling.
        3. Each batch takes, for every requested class, the next unused slice of one
           event's permutation. When a class has gone through all of its events the
           permutations of that class are reshuffled and consumption starts over, so
           successive batches differ.

        Arguments:
        x: (np.ndarray) Feature array indexed along its first axis by event id.
            NOTE(review): each sample contributes x[event], i.e. the features of the
            whole event, not of the selected pair - confirm this is intended.
        neighbor_pairs: (np.ndarray) Array of cell pair indices. Stored on the instance
            but not read by this class.
        labels_for_neighbor_pairs: (np.ndarray) 2-D array of labels for each pair in each
            event, shape (num_of_events, num_pairs).
        batch_size: (int) The total number of samples per batch. Stored on the instance
            but not enforced; the batch length is determined by class_counts.
        class_counts: (dict) Specifies the exact number of samples to include per class in
            each batch, e.g., {1: 15, 2: 10, 3: 5}. Classes that are absent from the
            labels, or for which no single event holds enough pairs, are skipped.
        seed: (int, optional) Seed for a private random generator, making the
            permutations reproducible. With the default None the global NumPy random
            state is used.

        Raises:
        ValueError: if labels_for_neighbor_pairs is not 2-dimensional.
        '''
        labels_for_neighbor_pairs = np.asarray(labels_for_neighbor_pairs)
        if labels_for_neighbor_pairs.ndim != 2:
            raise ValueError(
                "labels_for_neighbor_pairs must be 2-D (num_of_events, num_pairs), "
                f"got {labels_for_neighbor_pairs.ndim} dimension(s)"
            )

        self.x = x
        self.neighbor_pairs = neighbor_pairs
        self.labels_for_neighbor_pairs = labels_for_neighbor_pairs
        self.batch_size = batch_size
        self.class_counts = class_counts  # Dictionary {class_label: num_samples_per_batch}

        # A private RandomState only when a seed is given; the np.random module
        # itself is deliberately not stored on the instance.
        self.seed = seed
        self._rng = None if seed is None else np.random.RandomState(seed)

        # Store indices per class as (event_ids, pair_ids)
        self.indices = {cls: np.where(labels_for_neighbor_pairs == cls) for cls in np.unique(labels_for_neighbor_pairs)}

        # Generate precomputed permutations for each event
        self.permutations = self._precompute_permutations()

    def _permute(self, values):
        """Return a shuffled copy of `values` using the seeded generator if there is one."""
        rng = np.random if self._rng is None else self._rng
        return rng.permutation(values)

    def _precompute_permutations(self):
        """Precompute, per class, a list of (event_id, shuffled pair indices) tuples."""
        permutations = {}

        for cls, (event_ids, pair_ids) in self.indices.items():
            # np.where reports hits in row-major order, so event_ids is
            # non-decreasing and every event occupies one contiguous run.
            # Splitting at the run starts avoids one full-length mask per event.
            events, run_starts = np.unique(event_ids, return_index=True)
            runs = np.split(pair_ids, run_starts[1:])

            permutations[cls] = [
                (event, self._permute(run)) for event, run in zip(events, runs)
            ]

        return permutations

    def _draw(self, cls, n_samples, cursors):
        """Return (event, n_samples unused pair indices) for `cls`, or None.

        `cursors` maps a class to (position in the event list, offset inside that
        event's permutation) and is advanced in place. None is returned when no
        single event holds at least n_samples pairs of the class.
        """
        event_perms = self.permutations[cls]
        if not any(len(pairs) >= n_samples for _, pairs in event_perms):
            return None

        slot, offset = cursors.get(cls, (0, 0))
        while True:
            while slot < len(event_perms):
                event, pairs = event_perms[slot]
                if len(pairs) - offset >= n_samples:
                    cursors[cls] = (slot, offset + n_samples)
                    return event, pairs[offset:offset + n_samples]
                # Not enough unused pairs left in this event: move to the next one
                slot, offset = slot + 1, 0

            # All events used up: reshuffle and start over. The check at the top
            # guarantees that the next pass finds an event.
            event_perms = [(event, self._permute(pairs)) for event, pairs in event_perms]
            self.permutations[cls] = event_perms
            slot, offset = 0, 0

    def _batch_generator(self):
        """Yield (batch_data, batch_labels) arrays built from the precomputed permutations.

        The generator ends when no requested class can contribute a sample,
        instead of yielding empty batches indefinitely.
        """
        cursors = {}  # per-iterator progress, so every iterator starts at the beginning

        while True:
            batch_data = []
            batch_labels = []

            for cls, n_samples in self.class_counts.items():
                if n_samples <= 0:
                    continue
                if cls not in self.permutations or len(self.permutations[cls]) == 0:
                    continue

                drawn = self._draw(cls, n_samples, cursors)
                if drawn is None:
                    continue  # Skip if no event has enough pairs of this class

                event, selected_pairs = drawn
                selected_events = [event] * n_samples

                # One feature entry and one label per selected pair
                batch_data.extend(self.x[selected_events])
                batch_labels.extend(self.labels_for_neighbor_pairs[selected_events, selected_pairs])

            if not batch_labels:
                return

            # Convert to numpy arrays before yielding
            yield np.array(batch_data), np.array(batch_labels)

    def __iter__(self):
        return iter(self._batch_generator())


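#### Old Version (superseded — running this cell rebinds `MultiClassBatchGenerator` and shadows the new version above)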

In [None]:
def randomize_data(indices, labels):
    """Shuffle each row of `indices` and `labels` in place with a shared permutation.

    One permutation of the column positions is drawn per row from the global
    NumPy random state and applied to that row of both arrays, so entries that
    were aligned before the call stay aligned. The two (modified) input arrays
    are returned.
    """
    for row in range(indices.shape[0]):  # one row per event
        order = np.random.permutation(indices.shape[1])
        indices[row] = np.take(indices[row], order, axis=0)
        labels[row] = np.take(labels[row], order, axis=0)
    return indices, labels

# List of how many of each class I want in a batch - fix this instead of class ratios.
# Notice that every batch is reshuffled; instead, get one permutation for each event,
# then select n_i pairs of each class from that permutation and move the data out.

class MultiClassBatchGenerator(IterableDataset):
    def __init__(self, x, neighbor_pairs, labels_for_neighbor_pairs, batch_size, class_sampling_ratio):
        '''
        Class Purpose: 
        This class is a custom PyTorch IterableDataset designed to generate mini-batches of data 
        for multi-class classification tasks. It supports class-specific sampling, where certain
        classes can be more prevalent than others.

        NOTE(review): this notebook defines another class with the same name; whichever
        cell is executed last determines what `MultiClassBatchGenerator` refers to.

        How it works:
        1. Initializes with event data, labels, and class sampling ratios.
        2. Shuffles the stored (event, pair) indices of every class once per iterator.
        3. For each batch, draws the requested number of samples per class without
           replacement, shuffles the batch across classes, and yields data and labels.

        Arguments:
        x: (np.ndarray) Feature array indexed along its first axis by event id.
        neighbor_pairs: (np.ndarray) Array of cell pair indices. Stored on the instance
            but not read by this class.
        labels_for_neighbor_pairs: (np.ndarray) Array of labels for each pair in each event, shape (num_of_events, num_pairs).
        batch_size: (int) The minimum number of samples a batch must contain.
        class_sampling_ratio: (dict) Dictionary defining the number of samples for each class per batch, e.g., {1: 15, 2: 10, 3: 5}
        '''
        self.x = x
        self.neighbor_pairs = neighbor_pairs
        self.labels_for_neighbor_pairs = labels_for_neighbor_pairs
        self.batch_size = batch_size
        self.class_sampling_ratio = class_sampling_ratio

        # The (event_ids, pair_ids) indices of each class are stored in a dictionary
        self.indices = {cls: np.where(labels_for_neighbor_pairs == cls) for cls in np.unique(labels_for_neighbor_pairs)}
        
    def _shuffle_indices(self):
        """Shuffles the event and pair indices of every class with one shared permutation,
        maintaining the relationship between the two"""
        for cls in self.indices:
            event_ids, pair_ids = self.indices[cls]
            # Fancy indexing builds new arrays, so nothing is modified in place
            perm = np.random.permutation(len(pair_ids))
            self.indices[cls] = (event_ids[perm], pair_ids[perm])

            # NOTE(review): the author's plan is to stop reshuffling everything and to
            # hand back, per class and event, a permutation from which each batch takes
            # only the number of samples it needs.

    def _batch_generator(self):
        """Generates batches dynamically according to the class_sampling_ratio

        Raises:
        KeyError: if a requested class does not occur in the labels.
        ValueError: if the requested per-class counts add up to fewer samples than
            batch_size. No batch could ever be yielded in that case, and the loop
            below would otherwise run forever without producing anything.
        """
        missing = [cls for cls in self.class_sampling_ratio if cls not in self.indices]
        if missing:
            raise KeyError(f"classes {missing} requested in class_sampling_ratio do not occur in the labels")

        samples_per_batch = sum(self.class_sampling_ratio.values())
        if samples_per_batch < self.batch_size:
            raise ValueError(
                f"class_sampling_ratio provides {samples_per_batch} samples per batch, "
                f"fewer than batch_size={self.batch_size}"
            )

        self._shuffle_indices()
        
        while True:
            batch_data = []
            batch_labels = []
            
            for cls, num_samples in self.class_sampling_ratio.items():
                event_ids, pair_ids = self.indices[cls]

                # Select the required number of samples and gather the corresponding data and labels to the batch
                selected_indices = np.random.choice(len(pair_ids), num_samples, replace=False)
                selected_events = event_ids[selected_indices]
                selected_pairs = pair_ids[selected_indices]
                
                batch_data.extend(self.x[selected_events])
                batch_labels.extend(self.labels_for_neighbor_pairs[selected_events, selected_pairs])

            # Shuffle entire batch (cross-class shuffle)
            perm = np.random.permutation(len(batch_data))
            batch_data = np.array(batch_data)[perm]
            batch_labels = np.array(batch_labels)[perm]
            
            # The batch always holds samples_per_batch entries, which was checked
            # against batch_size above
            yield batch_data, batch_labels

    def __iter__(self):
        return iter(self._batch_generator())