In [3]:
# Standard Library Imports
import os
import pickle
import random
from collections import Counter

# Third-Party Libraries
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import IterableDataset
from torch.nn import BatchNorm1d
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.transforms import ToUndirected
from torch_geometric.utils import add_self_loops
from sklearn.preprocessing import MinMaxScaler, label_binarize
from sklearn.metrics import roc_curve, auc, confusion_matrix
import uproot

# Set random seed for reproducibility
random.seed(42)
np.random.seed(42)

In [4]:
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [5]:
file = uproot.open('/storage/mxg1065/MyxAODAnalysis_super3D.outputs.root')
tree = file['analysis;1']
branches = tree.arrays()

# 100 events and 187652 cells
# Arrays containing information about the energy, noise, snr, 
cell_e = np.array(branches['cell_e'])
cell_noise = np.array(branches['cell_noiseSigma'])
cell_snr = np.array(branches['cell_SNR'])
cell_eta = np.array(branches['cell_eta'])
cell_phi = np.array(branches['cell_phi'])

# Represents the index of the cluster that each cell corresponds to. If the index
# is 0, that means that the given cell does not belong to a cluster.
cell_to_cluster_index = np.array(branches['cell_cluster_index'])

# For each entry, contains the IDs of cells neighboring a given cell
neighbor = branches['neighbor']

num_of_events = len(cell_e) # 100 events

In [6]:
# We use the data arrays to crete a data dictionary, where each entry corresponds
# to the data of a given event; we scale this data.
data = {}

for i in range(num_of_events):
    data[f'data_{i}'] = np.concatenate((np.expand_dims(cell_snr[i], axis=1),
                                        np.expand_dims(cell_e[i], axis=1),
                                        np.expand_dims(cell_noise[i], axis=1),
                                        np.expand_dims(cell_eta[i], axis=1),
                                        np.expand_dims(cell_phi[i], axis=1)), axis=1)
    
# We combine the data into one array and apply the MinMaxScaler
combined_data = np.vstack([data[key] for key in data])
scaler = MinMaxScaler()
scaled_combined_data = scaler.fit_transform(combined_data)

# The scaled data is split to have the save structure as the original data dict
scaled_data = {}
start_idx = 0
for i in range(num_of_events):
    end_idx = start_idx + data[f"data_{i}"].shape[0]
    scaled_data[f"data_{i}"] = scaled_combined_data[start_idx:end_idx]
    start_idx = end_idx

print(scaled_data["data_0"])
print(scaled_data["data_0"].shape)

[[0.02640459 0.24943821 0.09700817 0.23466232 0.5085498 ]
 [0.02567646 0.24627675 0.09700815 0.23466876 0.52418756]
 [0.02508186 0.24369499 0.09700817 0.23467347 0.5398244 ]
 ...
 [0.02632877 0.24791998 0.03177016 0.62017125 0.4921177 ]
 [0.02705116 0.2511391  0.07512318 0.6461531  0.4921177 ]
 [0.02638626 0.24820389 0.04149057 0.6731673  0.4921177 ]]
(187652, 5)


In [7]:
# Removing any broken cells within the list of cells (These cells have
# no noise, which doesn't make sense)

broken_cells = []
for i in range(num_of_events):
    cells = np.argwhere(cell_noise[i]==0).flatten()
    broken_cells = np.squeeze(cells)

neighbor = neighbor[0]

neighbor_pairs_list = []
num_of_cells = len(neighbor) # 187652 cells

for i in range(num_of_cells):
    if i in broken_cells:
        continue
    for j in neighbor[i]:
        if j in broken_cells:
            continue
        neighbor_pairs_list.append((i, int(j)))

In [8]:
# These functions remove permutation variants
def canonical_form(t):
    return tuple(sorted(t))

def remove_permutation_variants(tuple_list):
    unique_tuples = set(canonical_form(t) for t in tuple_list)
    return [tuple(sorted(t)) for t in unique_tuples]

neighbor_pairs_list = np.array(remove_permutation_variants(neighbor_pairs_list))
print(neighbor_pairs_list.shape)

(1250242, 2)


In [9]:
# Initialize array for labels
labels_for_neighbor_pairs = np.zeros((num_of_events, len(neighbor_pairs_list)), dtype=int)

# Extracting the individual cells within a cell pair
cell_0 = cell_to_cluster_index[:, neighbor_pairs_list[:, 0]]
cell_1 = cell_to_cluster_index[:, neighbor_pairs_list[:, 1]]

# Computing labels using vectorized operations
same_cluster = cell_0 == cell_1 
both_nonzero = (cell_0 != 0) & (cell_1 != 0)

labels_for_neighbor_pairs[same_cluster & (cell_0 == 0)] = 0 # Lone-Lone (0)
labels_for_neighbor_pairs[same_cluster & (cell_0 != 0)] = 1 # True-True (1)
labels_for_neighbor_pairs[~same_cluster & both_nonzero] = 4 # Cluster-Cluster (4)
labels_for_neighbor_pairs[~same_cluster & (cell_0 == 0) & (cell_1 != 0)] = 3 # Lone-Cluster (3)
labels_for_neighbor_pairs[~same_cluster & (cell_0 != 0) & (cell_1 == 0)] = 2 # Cluster-Lone (2)

print(labels_for_neighbor_pairs.shape)
print(labels_for_neighbor_pairs)
print(np.unique(labels_for_neighbor_pairs[0]))

(100, 1250242)
[[0 0 0 ... 0 0 2]
 [3 3 0 ... 0 0 2]
 [1 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 1 0 ... 0 0 0]
 [0 0 0 ... 2 0 0]]
[0 1 2 3 4]


In [10]:
class MultiClassBatchGenerator(IterableDataset):
    def __init__(self, x, neighbor_pairs, labels_for_neighbor_pairs, batch_size, class_counts):
        """
        Custom IterableDataset for multi-class batch generation.

        Arguments:
        x: (np.ndarray) Feature matrix from dict of shape data["data_i"].shape = (num_of_cells, num_features).
        neighbor_pairs: (np.ndarray) Array of cell pair indices, shape (num_pairs, 2).
        labels_for_neighbor_pairs: (np.ndarray) Array of labels per pair, shape (num_of_events, num_pairs).
        batch_size: (int) Total number of samples per batch.
        class_counts: (dict) Number of samples per class per batch, {class_label: num_samples_per_batch}.
        """
        self.x = x
        self.neighbor_pairs = neighbor_pairs
        self.labels_for_neighbor_pairs = labels_for_neighbor_pairs
        self.batch_size = batch_size
        self.class_counts = class_counts

        # Store indices per class for each event
        self.indices_per_event = self._compute_event_class_indices()

    def _compute_event_class_indices(self):
        """Precompute indices of pairs for each event and class."""
        event_class_indices = {}
        num_events = self.labels_for_neighbor_pairs.shape[0]

        for event_id in range(num_events):
            event_class_indices[event_id] = {}
            for cls in np.unique(self.labels_for_neighbor_pairs[event_id]):
                mask = self.labels_for_neighbor_pairs[event_id] == cls
                event_class_indices[event_id][cls] = np.where(mask)[0]

        return event_class_indices

    def _batch_generator(self):
        """Generator function that iterates over events and selects samples per class."""
        num_events = self.labels_for_neighbor_pairs.shape[0]

        for event_id in range(num_events):  # Iterate over events first
            selected_pairs = []
            selected_labels = []

            for cls, n_samples in self.class_counts.items():  # Iterate over classes
                if cls not in self.indices_per_event[event_id]:
                    continue  # Skip if no pairs for this class

                pair_indices = self.indices_per_event[event_id][cls]
                if len(pair_indices) < n_samples:
                    selected_idx = pair_indices  # Take all available samples
                else:
                    selected_idx = np.random.choice(pair_indices, size=n_samples, replace=False)  # Sample without replacement

                # Gather the selected pairs and corresponding labels
                # selected_pairs.append(self.neighbor_pairs[event_id][selected_idx])
                selected_pairs.append(self.neighbor_pairs[selected_idx])
                selected_labels.append(np.full(len(selected_idx), cls))  # Assign class labels

            # If no samples were selected, skip this event
            if len(selected_pairs) == 0:
                continue

            # Concatenate results
            selected_pairs = np.concatenate(selected_pairs, axis=0)
            selected_labels = np.concatenate(selected_labels, axis=0)

            # Shuffle within the event before yielding
            perm = np.random.permutation(len(selected_labels))
            selected_pairs = selected_pairs[perm]
            selected_labels = selected_labels[perm]

            # Yield x for this event along with (x1, x2) pairs and labels
            yield self.x[f"data_{event_id}"], selected_pairs, selected_labels

    def __iter__(self):
        return iter(self._batch_generator())

### Performing some checks

In [11]:
print(f"scaled_combined_data.shape: {scaled_combined_data.shape}")  # Expecting (total_data, num_features)
print(f"neighbor_pairs_list.shape: {neighbor_pairs_list.shape}")  # Expecting (num_pairs, 2)
print(f"labels_for_neighbor_pairs.shape: {labels_for_neighbor_pairs.shape}")  # Expecting (num_of_events, num_pairs)

# Define the parameters for instantiating the generator
batch_size = 20
class_counts = {0: 40000, 1: 40000, 2: 40000, 3: 40000, 4: 3000}
num_features = scaled_combined_data.shape[1]

# Create the generator instance
batch_generator = MultiClassBatchGenerator(
    x=scaled_data,  # Pass dictionary instead of concatenated array
    neighbor_pairs=neighbor_pairs_list, 
    labels_for_neighbor_pairs=labels_for_neighbor_pairs, 
    batch_size=batch_size, 
    class_counts=class_counts
)

# Explicitely get an iterator from the dataset
batch_iterator = iter(batch_generator)

# Retrieve the first batch
batch = next(batch_iterator)

# Unpack the batch according to the generator's output
x_batch, neighbor_pairs_batch, labels_batch = batch  # Only three outputs

print(f"x_batch shape: {x_batch.shape}")  
print(f"neighbor_pairs_batch shape: {neighbor_pairs_batch.shape}")  
print(f"labels_batch shape: {labels_batch.shape}")

# Display class distribution in the batch
label_counts = Counter(labels_batch.tolist())
print("Class distribution in batch:", label_counts)

num_batches = 3

for batch_num in range(num_batches):
    try:
        x_batch, neighbor_pairs_batch, labels_batch = next(batch_iterator)
    except StopIteration:
        print("No more batches available.")
        break

    mismatch_count = 0

    print(f"\nChecking Batch {batch_num + 1}:")

    # Ensure shape is (num_pairs, 2)
    neighbor_pairs_batch = neighbor_pairs_batch.reshape(-1, 2)

    for idx, (pair, assigned_label) in enumerate(zip(neighbor_pairs_batch, labels_batch)):
        i, j = pair
        true_labels = labels_for_neighbor_pairs[:, i]  # Ensure event filtering

        if assigned_label not in true_labels:
            mismatch_count += 1

    print(f"Total mismatches in Batch {batch_num + 1}: {mismatch_count} / {len(labels_batch)} | {mismatch_count/len(labels_batch):.3f}%")

scaled_combined_data.shape: (18765200, 5)
neighbor_pairs_list.shape: (1250242, 2)
labels_for_neighbor_pairs.shape: (100, 1250242)
x_batch shape: (187652, 5)
neighbor_pairs_batch shape: (163000, 2)
labels_batch shape: (163000,)
Class distribution in batch: Counter({2: 40000, 0: 40000, 1: 40000, 3: 40000, 4: 3000})

Checking Batch 1:
Total mismatches in Batch 1: 4684 / 163000 | 0.029%

Checking Batch 2:
Total mismatches in Batch 2: 4608 / 163000 | 0.028%

Checking Batch 3:
Total mismatches in Batch 3: 4632 / 163000 | 0.028%


1. **`scaled_combined_data.shape: (18765200, 5)`**  
   - This is the **feature matrix** containing data for all cells across events.  
   - **18765200** represents the total number of cells.  
   - **5** represents the number of features per cell.  

2. **`neighbor_pairs_list.shape: (1250242, 2)`**  
   - This represents the **neighbor (edge) pairs** of cells.  
   - **1250242** is the total number of pairs (edges) in the dataset.  
   - **2** indicates that each pair consists of two cell indices.

3. **`labels_for_neighbor_pairs.shape: (100, 1250242)`**  
   - This stores the **labels for cell pairs across events**.  
   - **100** represents the number of events.  
   - **1250242** represents the number of cell pairs in each event.  

4. **`x_batch.shape: (163000, 5)`**  
   - This is the **feature matrix for the current batch** of selected pairs.  
   - **163000** is the number of selected cell pairs in the batch.  
   - **5** is the number of features per cell.

5. **`neighbor_pairs_batch.shape: (163000, 2)`**  
   - This is the **subset of `neighbor_pairs_list`** used in the batch.  
   - **163000** is the number of selected pairs in the batch.  
   - **2** represents the two cell indices per pair.

6. **`edge_index_batch.shape: (2, 1250242)`**  
   - This represents the **edge index tensor** for the full dataset.  
   - **2** represents the row-wise format:  
     - Row 1: Source nodes  
     - Row 2: Target nodes  
   - **1250242** represents the total number of edges (pairs of cells).  

7. **`edge_index_out_batch.shape: (2, 163000)`**  
   - This is the **subset of `edge_index_batch`** corresponding to the batch.  
   - **2** follows the same format as `edge_index_batch`.  
   - **163000** represents the number of edges in the batch.  

8. **`labels_batch.shape: (163000,)`**  
   - This is the **batch of labels** corresponding to the selected edges.  
   - **163000** ensures each selected edge has one assigned label.

---

### **Overall Structure:**
- **Full Dataset:**  
  - `scaled_combined_data`: Stores all cell features.  
  - `neighbor_pairs_list`: Stores all neighbor relationships.  
  - `labels_for_neighbor_pairs`: Stores labels for all edges across events.  

- **Batch Data:**  
  - `x_batch`: Features of selected cells.  
  - `neighbor_pairs_batch`: Selected pairs of cells in the batch.  
  - `edge_index_out_batch`: Graph representation of selected pairs.  
  - `labels_batch`: Ground truth labels for the selected pairs.

### Model incorporation

In [46]:
class MultiClassBatchGenerator(IterableDataset):
    def __init__(self, x, neighbor_pairs, labels_for_neighbor_pairs, batch_size, class_counts):
        self.x = x
        self.neighbor_pairs = neighbor_pairs
        self.labels_for_neighbor_pairs = labels_for_neighbor_pairs
        self.batch_size = batch_size
        self.class_counts = class_counts
        self.indices_per_event = self._compute_event_class_indices()

    def _compute_event_class_indices(self):
        event_class_indices = {}
        num_events = self.labels_for_neighbor_pairs.shape[0]
        for event_id in range(num_events):
            event_class_indices[event_id] = {}
            for cls in np.unique(self.labels_for_neighbor_pairs[event_id]):
                mask = self.labels_for_neighbor_pairs[event_id] == cls
                event_class_indices[event_id][cls] = np.where(mask)[0]
        return event_class_indices

    def _batch_generator(self):
        num_events = self.labels_for_neighbor_pairs.shape[0]
        for event_id in range(num_events):
            selected_pairs = []
            selected_labels = []

            for cls, n_samples in self.class_counts.items():
                if cls not in self.indices_per_event[event_id]:
                    continue

                pair_indices = self.indices_per_event[event_id][cls]
                if len(pair_indices) < n_samples:
                    selected_idx = pair_indices
                else:
                    selected_idx = np.random.choice(pair_indices, size=n_samples, replace=False)

                selected_pairs.append(self.neighbor_pairs[selected_idx])
                selected_labels.append(np.full(len(selected_idx), cls))

            if len(selected_pairs) == 0:
                continue

            selected_pairs = np.concatenate(selected_pairs, axis=0)
            selected_labels = np.concatenate(selected_labels, axis=0)

            perm = np.random.permutation(len(selected_labels))
            selected_pairs = selected_pairs[perm]
            selected_labels = selected_labels[perm]

            edge_index_out = selected_pairs.T  # Transpose to match PyTorch Geometric format
            
            # **Debugging: Print the shapes before yielding**
            print(f"x[event_id] shape: {self.x[event_id].shape}")  # Should be (num_nodes, num_features)
            print(f"selected_pairs shape: {selected_pairs.shape}")  # Should be (num_edges, 2)
            print(f"edge_index_out shape: {edge_index_out.shape}")  # Should be (2, num_edges)
            print(f"selected_labels shape: {selected_labels.shape}")  # Should be (num_edges,)

            yield self.x[f"data_{event_id}"], selected_pairs, edge_index_out, selected_labels

    def __iter__(self):
        return iter(self._batch_generator())

In [47]:
class MultiEdgeClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=6, layer_weights=True, debug=False):
        super(MultiEdgeClassifier, self).__init__()
        self.debug = debug
        self.layer_weights_enabled = layer_weights  # Store setting

        # Node embedding layer
        self.node_embedding = nn.Linear(input_dim, hidden_dim)

        # Initialize convolution and batch norm layers
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.layer_weights = nn.ParameterList() if layer_weights else None  # Only create if enabled

        # First layer
        self.convs.append(GCNConv(hidden_dim, 128))
        self.bns.append(BatchNorm1d(128))
        if layer_weights:
            self.layer_weights.append(nn.Parameter(torch.tensor(1.0, requires_grad=True)))

        # Additional layers
        for i in range(1, num_layers):
            in_channels = 128 if i == 1 else 64
            out_channels = 64
            self.convs.append(GCNConv(in_channels, out_channels))
            self.bns.append(BatchNorm1d(out_channels))
            if layer_weights:
                self.layer_weights.append(nn.Parameter(torch.tensor(1.0, requires_grad=True)))

        # Edge classification layer (now output_dim is passed as a parameter)
        self.fc = nn.Linear(2 * hidden_dim, output_dim)

    def debug_print(self, message):
        if self.debug:
            print(message)

    def forward(self, x, edge_index, edge_index_out):
        if isinstance(x, np.ndarray):
            x = torch.tensor(x, dtype=torch.float32)
        if isinstance(edge_index, np.ndarray):
            edge_index = torch.tensor(edge_index, dtype=torch.float32)
        edge_index = edge_index.T
        if isinstance(edge_index_out, np.ndarray):
            edge_index_out = torch.tensor(edge_index_out, dtype=torch.float32)
        self.debug_print(f"Input x shape: {x.shape}")
        self.debug_print(f"Input edge_index shape: {edge_index.shape}")
        self.debug_print(f"Input edge_index_out shape: {edge_index_out.shape}")

        # Node embedding
        x = self.node_embedding(x)
        self.debug_print(f"Node embedding output shape: {x.shape}")

        if x.dim() == 3 and x.size(0) == 1:  # Check and remove batch dimension
            x = x.squeeze(0)

        # Loop through convolution layers
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            self.debug_print(f"After GCNConv {i+1}: {x.shape}")
            if x.dim() == 3 and x.size(0) == 1:
                x = x.squeeze(0)
            x = self.bns[i](x)
            x = torch.relu(x)

            # Apply layer weight if enabled
            if self.layer_weights_enabled:
                x = x * self.layer_weights[i]
                self.debug_print(f"After Layer Weight {i+1}: {x.shape}")

        # Edge representations
        edge_rep = torch.cat([x[edge_index_out[0]], x[edge_index_out[1]]], dim=1)
        self.debug_print(f"Edge representation shape: {edge_rep.shape}")

        # Return Logits (size depends on output_dim)
        edge_scores = self.fc(edge_rep)
        return edge_scores

In [48]:
def loss_for_train_and_test(logits, labels, criterion=nn.CrossEntropyLoss()):
    return criterion(logits, labels)

In [49]:
def train_model(model, train_loader, optimizer, criterion):
    model.train()  # Set the model to training mode
    total_loss = 0
    correct = 0
    total = 0

    for data in train_loader:
        x, edge_index, edge_index_out, labels = data  # Make sure you unpack edge_index_out here
        optimizer.zero_grad()
        logits = model(x, edge_index, edge_index_out)  # Pass edge_index_out along with edge_index
        loss = loss_for_train_and_test(logits, labels, criterion)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        predicted = logits.argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

    accuracy = correct / total
    return total_loss / len(train_loader), accuracy


def test_model(model, test_loader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for data in test_loader:
            x, edge_index, labels = data
            logits = model(x, edge_index)
            loss = loss_for_train_and_test(logits, labels, criterion)

            total_loss += loss.item()
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    avg_loss = total_loss / len(test_loader)
    accuracy = correct / total
    return avg_loss, accuracy

In [50]:
def run_model(model, train_loader, test_loader, epochs=10, batch_size=20, class_counts=None, learning_rate=1e-3):
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
    loss_per_epoch = []
    scores = []
    truth_labels = []

    avg_loss_training_true_class = []
    logits_training_true_class = []
    avg_loss_testing_true_class = []
    logits_testing_true_class = []
    avg_loss_training_bkg_classes = []
    logits_training_bkg_classes = []
    avg_loss_testing_bkg_classes = []
    logits_testing_bkg_classes = []

    for epoch in range(epochs):
        train_loss, train_accuracy = train_model(model, train_loader, optimizer, criterion)
        test_loss, test_accuracy = test_model(model, test_loader, criterion)

        loss_per_epoch.append({'train': train_loss, 'test': test_loss})
        scores.append({'train': train_accuracy, 'test': test_accuracy})

        # Track per-class losses and logits (assuming classes are numbered 0 to 4)
        avg_loss_training_true_class.append(train_loss)  # Example, adapt based on your data structure
        logits_training_true_class.append(None)  # Placeholder, replace with actual logits for true classes
        avg_loss_testing_true_class.append(test_loss)  # Example, adapt based on your data structure
        logits_testing_true_class.append(None)  # Placeholder, replace with actual logits for testing

        # Assuming background class labels exist from the previous configuration
        avg_loss_training_bkg_classes.append(None)  # Placeholder
        logits_training_bkg_classes.append(None)  # Placeholder
        avg_loss_testing_bkg_classes.append(None)  # Placeholder
        logits_testing_bkg_classes.append(None)  # Placeholder

        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

    return {
        'loss_per_epoch': loss_per_epoch,
        'scores': scores,
        'truth_labels': truth_labels,
        'avg_loss_training_true_class': avg_loss_training_true_class,
        'logits_training_true_class': logits_training_true_class,
        'avg_loss_testing_true_class': avg_loss_testing_true_class,
        'logits_testing_true_class': logits_testing_true_class,
        'avg_loss_training_bkg_classes': avg_loss_training_bkg_classes,
        'logits_training_bkg_classes': logits_training_bkg_classes,
        'avg_loss_testing_bkg_classes': avg_loss_testing_bkg_classes,
        'logits_testing_bkg_classes': logits_testing_bkg_classes
    }


In [51]:
# Custom collate function to handle the data
def collate_fn(data):
    # Assuming each `data` element is a tuple (x, edge_index, edge_index_out, labels)
    x = [d.x for d in data]
    edge_index = [d.edge_index for d in data]
    edge_index_out = [d.edge_index_out for d in data]  # Ensure this is included
    labels = [d.y for d in data]
    
    # Return a tuple of 4 elements
    return torch.stack(x, dim=0), torch.stack(edge_index, dim=0), torch.stack(edge_index_out, dim=0), torch.tensor(labels)


In [52]:
# Assume data has been loaded into `scaled_combined_data`, `neighbor_pairs_list`, `labels_for_neighbor_pairs`
batch_size = 20
class_counts = {0: 40000, 1: 40000, 2: 40000, 3: 40000, 4: 3000}

# Initialize the generator and model
batch_generator = MultiClassBatchGenerator(
    x=scaled_combined_data, 
    neighbor_pairs=neighbor_pairs_list, 
    labels_for_neighbor_pairs=labels_for_neighbor_pairs, 
    batch_size=batch_size, 
    class_counts=class_counts
)

train_loader = batch_generator

# Model definition
num_features = scaled_combined_data.shape[1]
num_classes = 5
model = MultiEdgeClassifier(input_dim=num_features, hidden_dim=128, output_dim=num_classes, debug=True)

# Run the model
metrics = run_model(model, train_loader, train_loader, epochs=10, batch_size=batch_size)

# Print or save metrics
print(metrics)

x[event_id] shape: (5,)
selected_pairs shape: (163000, 2)
edge_index_out shape: (2, 163000)
selected_labels shape: (163000,)


IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices