In [2]:
# Standard library imports
import random  
import numpy as np  

# Third-party libraries
from torch.utils.data import IterableDataset
from sklearn.preprocessing import MinMaxScaler  
import uproot  

# Set random seed for reproducibility
random.seed(42)

### Preparing the Data Dictionary

In [3]:
file = uproot.open('/storage/mxg1065/MyxAODAnalysis_super3D.outputs.root')
print(file.keys())

['analysis;1']


In [4]:
tree = file['analysis;1']
branches = tree.arrays()
print(tree.keys()) # Variables per event

['RunNumber', 'EventNumber', 'cell_eta', 'cell_phi', 'cell_x', 'cell_y', 'cell_z', 'cell_subCalo', 'cell_sampling', 'cell_size', 'cell_hashID', 'neighbor', 'seedCell_id', 'cell_e', 'cell_noiseSigma', 'cell_SNR', 'cell_time', 'cell_weight', 'cell_truth', 'cell_truth_indices', 'cell_shared_indices', 'cell_cluster_index', 'cluster_to_cell_indices', 'cluster_to_cell_weights', 'cell_to_cluster_e', 'cell_to_cluster_eta', 'cell_to_cluster_phi', 'cluster_eta', 'cluster_phi', 'cluster_e', 'cellsNo_cluster', 'clustersNo_event', 'jetEnergyWtdTimeAve', 'jetEta', 'jetPhi', 'jetE', 'jetPt', 'jetNumberPerEvent', 'cellIndices_per_jet']


In [5]:
# 100 events and 187652 cells
# Arrays containing information about the energy, noise, snr, 
cell_e = np.array(branches['cell_e'])
cell_noise = np.array(branches['cell_noiseSigma'])
cell_snr = np.array(branches['cell_SNR'])
cell_eta = np.array(branches['cell_eta'])
cell_phi = np.array(branches['cell_phi'])

# Represents the index of the cluster that each cell corresponds to. If the index
# is 0, that means that the given cell does not belong to a cluster.
cell_to_cluster_index = np.array(branches['cell_cluster_index'])

# For each entry, contains the IDs of cells neighboring a given cell
neighbor = branches['neighbor']

num_of_events = len(cell_e) # 100 events

In [6]:
# We use the data arrays to crete a data dictionary, where each entry corresponds
# to the data of a given event; we scale this data.
data = {}

for i in range(num_of_events):
    data[f'data_{i}'] = np.concatenate((np.expand_dims(cell_snr[i], axis=1),
                                        np.expand_dims(cell_e[i], axis=1),
                                        np.expand_dims(cell_noise[i], axis=1),
                                        np.expand_dims(cell_eta[i], axis=1),
                                        np.expand_dims(cell_phi[i], axis=1)), axis=1)
    
# We combine the data into one array and apply the MinMaxScaler
combined_data = np.vstack([data[key] for key in data])
scaler = MinMaxScaler()
scaled_combined_data = scaler.fit_transform(combined_data)

# The scaled data is split to have the save structure as the original data dict
scaled_data = {}
start_idx = 0
for i in range(num_of_events):
    end_idx = start_idx + data[f"data_{i}"].shape[0]
    scaled_data[f"data_{i}"] = scaled_combined_data[start_idx:end_idx]
    start_idx = end_idx

print(scaled_data["data_0"])
print(scaled_data["data_0"].shape)

[[0.02640459 0.24943821 0.09700817 0.23466232 0.5085498 ]
 [0.02567646 0.24627675 0.09700815 0.23466876 0.52418756]
 [0.02508186 0.24369499 0.09700817 0.23467347 0.5398244 ]
 ...
 [0.02632877 0.24791998 0.03177016 0.62017125 0.4921177 ]
 [0.02705116 0.2511391  0.07512318 0.6461531  0.4921177 ]
 [0.02638626 0.24820389 0.04149057 0.6731673  0.4921177 ]]
(187652, 5)


### Prepairing Neighbor Pairs

In [7]:
# The IDs of the broken cells (those with zero noise) are collected
broken_cells = []

for i in range(num_of_events):
    cells = np.argwhere(cell_noise[i]==0).flatten()
    broken_cells = np.squeeze(cells)

print(broken_cells)

[186986 187352]


In [8]:
# Since the values associated with neighbor[0] and neighbor[1] are all equal
# we will just work with neighbor[0] to simplify our calculations
neighbor = neighbor[0]

In [9]:
# We loop through the neighbor awkward array and remove the IDs associated
# with the broken cells.  Loops through all cells in the neighbor list. If the loop 
# reaches the cell numbers 186986 or 187352, loop skips over these inoperative cells. 
# The final list contains tuples (i,j) where i is the cell ID in question and the 
# js are the neighboring cell IDs
neighbor_pairs_list = []
num_of_cells = len(neighbor) # 187652 cells

for i in range(num_of_cells):
    if i in broken_cells:
        continue
    for j in neighbor[i]:
        if j in broken_cells:
            continue
        neighbor_pairs_list.append((i, int(j)))

In [10]:
# This code checks to see if the broken cells were removed
found_broken_cells = []

for pair in neighbor_pairs_list:
    # Loop through each cell in pair
    for cell in pair:
        # If the cell is broken, appends to list
        if cell in broken_cells:
            found_broken_cells.append(cell)

if found_broken_cells:
    print("Error: Broken cells are still present in neighbor pairs.")
else:
    print("Successfully excluded broken cells.")

Successfully excluded broken cells.


In [11]:
# These functions remove permutation variants
def canonical_form(t):
    return tuple(sorted(t))

def remove_permutation_variants(tuple_list):
    unique_tuples = set(canonical_form(t) for t in tuple_list)
    return [tuple(sorted(t)) for t in unique_tuples]

neighbor_pairs_list = np.array(remove_permutation_variants(neighbor_pairs_list))
print(neighbor_pairs_list)
print(neighbor_pairs_list.shape)

[[ 90345 119588]
 [  4388  17680]
 [ 39760  39825]
 ...
 [159757 168717]
 [ 62911  78974]
 [135353 135609]]
(1250242, 2)


### Plotting the Features

### Creating Labels for the Neighbor Pairs

For a given pair of cells and the IDs of the clusters that they belong to (i, j), if
1. i=j and both are nonzero, then both cells are part of the same cluster. 
    * We call these True-True pairs and label them with 1
2. i=j and both are zero, then both cells are not part of any cluster. 
    * We call these Lone-Lone pairs and label them with 0
3. i is nonzero and j=0, then cell i is part of a cluster while cell j is not. 
    * We call these Cluster-Lone pairs and label them with 2
4. i=0 and j is nonzero, then cell i is not part of a cluste while cell j is. 
    * We call these Lone-Cluster pairs and label them with 3
5. i is not the same as j and both are nonzero, then both cells are part of different clusters. 
    * We call these Cluster-Cluster pairs and label them with 4

In [12]:
# Initialize array for labels
labels_for_neighbor_pairs = np.zeros((num_of_events, len(neighbor_pairs_list)), dtype=int)

# Extracting the individual cells within a cell pair
cell_0 = cell_to_cluster_index[:, neighbor_pairs_list[:, 0]]
cell_1 = cell_to_cluster_index[:, neighbor_pairs_list[:, 1]]

# Computing labels using vectorized operations
same_cluster = cell_0 == cell_1 
both_nonzero = (cell_0 != 0) & (cell_1 != 0)

# Lone-Lone (0)
labels_for_neighbor_pairs[same_cluster & (cell_0 == 0)] = 0

# True-True (1)
labels_for_neighbor_pairs[same_cluster & (cell_0 != 0)] = 1

# Cluster-Cluster (4)
labels_for_neighbor_pairs[~same_cluster & both_nonzero] = 4

# Lone-Cluster (3)
labels_for_neighbor_pairs[~same_cluster & (cell_0 == 0) & (cell_1 != 0)] = 3

# Cluster-Lone (2)
labels_for_neighbor_pairs[~same_cluster & (cell_0 != 0) & (cell_1 == 0)] = 2

print(labels_for_neighbor_pairs.shape)
print(labels_for_neighbor_pairs)
print(np.unique(labels_for_neighbor_pairs[0]))

(100, 1250242)
[[0 0 0 ... 0 0 2]
 [3 3 0 ... 0 0 2]
 [1 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 1 0 ... 0 0 0]
 [0 0 0 ... 2 0 0]]
[0 1 2 3 4]


### Building the Multi-Class Batch Generator

In [13]:
class MultiClassBatchGenerator(IterableDataset):
    def __init__(self, x, neighbor_pairs, labels_for_neighbor_pairs, batch_size, class_counts):
        """
        Custom IterableDataset for multi-class batch generation.

        Arguments:
        x: (np.ndarray) Feature matrix of shape (num_of_events, num_features).
        neighbor_pairs: (np.ndarray) Array of cell pair indices, shape (num_of_events, num_pairs, 2).
        labels_for_neighbor_pairs: (np.ndarray) Array of labels per pair, shape (num_of_events, num_pairs).
        batch_size: (int) Total number of samples per batch.
        class_counts: (dict) Number of samples per class per batch, {class_label: num_samples_per_batch} e.g., {1: 15, 2: 10, 3: 5}.
        """
        self.x = x
        self.neighbor_pairs = neighbor_pairs
        self.labels_for_neighbor_pairs = labels_for_neighbor_pairs
        self.batch_size = batch_size
        self.class_counts = class_counts

        # Store indices per class for each event
        self.indices_per_event = self._compute_event_class_indices()

    def _compute_event_class_indices(self):
        """Precompute indices of pairs for each event and class."""
        event_class_indices = {}
        num_events = self.labels_for_neighbor_pairs.shape[0]

        for event_id in range(num_events):
            event_class_indices[event_id] = {}
            for cls in np.unique(self.labels_for_neighbor_pairs[event_id]):
                mask = self.labels_for_neighbor_pairs[event_id] == cls
                event_class_indices[event_id][cls] = np.where(mask)[0]

        return event_class_indices
    '''
    If you make the above a dict of lists, you don't need to do the event_id loop
    '''

    def _batch_generator(self):
        """Generator function that iterates over events and selects samples per class."""
        num_events = self.labels_for_neighbor_pairs.shape[0]

        for event_id in range(num_events):  # Iterate over events first
            selected_pairs = []
            selected_labels = []

            for cls, n_samples in self.class_counts.items():  # Iterate over classes
                if cls not in self.indices_per_event[event_id]:
                    continue  # Skip if no pairs for this class

                pair_indices = self.indices_per_event[event_id][cls]
                if len(pair_indices) < n_samples:
                    selected_idx = pair_indices  # Take all available samples
                else:
                    selected_idx = np.random.choice(pair_indices, size=n_samples, replace=False)  # Sample without replacement

                # Gather the selected pairs and corresponding labels
                selected_pairs.append(self.neighbor_pairs[event_id][selected_idx])
                selected_labels.append(np.full(len(selected_idx), cls))  # Assign class labels

            # If no samples were selected, skip this event
            if len(selected_pairs) == 0:
                continue

            # Concatenate results
            selected_pairs = np.concatenate(selected_pairs, axis=0)
            selected_labels = np.concatenate(selected_labels, axis=0)

            # Shuffle within the event before yielding
            perm = np.random.permutation(len(selected_labels))
            selected_pairs = selected_pairs[perm]
            selected_labels = selected_labels[perm]

            # Yield x for this event along with (x1, x2) pairs and labels
            yield self.x[event_id], selected_pairs, selected_labels

    def __iter__(self):
        return iter(self._batch_generator())

In [None]:
# Define the parameters for instantiating the generator
batch_size = 1
class_counts = {0: 40000, 1: 40000, 2: 40000, 3: 40000, 4: 3000}
num_features = scaled_combined_data.shape[1]

# Create the generator instance
batch_generator = MultiClassBatchGenerator(
    x=scaled_combined_data, 
    neighbor_pairs=neighbor_pairs_list, 
    labels_for_neighbor_pairs=labels_for_neighbor_pairs, 
    batch_size=batch_size, 
    class_counts=class_counts
)

# Explicitely get an iterator from the dataset
batch_iterator = iter(batch_generator)
# Retrieve the first batch
batch = next(batch_iterator)

# Unpack the batch to inspect
x_batch, neighbor_pairs_batch, labels_batch = batch
print(f"x_batch shape: {x_batch.shape}")
print(f"neighbor_pairs_batch shape: {neighbor_pairs_batch.shape}")
print(f"labels_batch shape: {labels_batch.shape}")


IndexError: index 428328 is out of bounds for axis 0 with size 2

intantiate it, then call next, it will get a batch, validate by eye, then go through it and perform checks to make sure everything makes sense
* especially check to see that the pairs and labels are not being messed up, that the labels and pairs are correctly associated with one another

Can run a generator that takes another generator as an argument, takes one event to multiple events (concatenation)