In [1]:
import random  
random.seed(42)  

import numpy as np  
import matplotlib.pyplot as plt  
import seaborn as sns  

from collections import Counter  

import torch  
import torch.nn as nn  
import torch.nn.functional as F  
import torch.optim as optim  

from torch.nn import BatchNorm1d  
from torch_geometric.data import Data  
from torch_geometric.nn import GCNConv  
from torch_geometric.transforms import ToUndirected  
from torch_geometric.utils import add_self_loops  

from sklearn.metrics import roc_curve, auc, confusion_matrix  
from sklearn.preprocessing import MinMaxScaler  

import uproot  

## Preparing the Data Dictionary

In [2]:
file = uproot.open('/home/mxg1065/MyxAODAnalysis_super3D.outputs.root')
print(file.keys())

['analysis;1']


In [3]:
tree = file['analysis;1']
branches = tree.arrays()
print(tree.keys()) # Variables per event

['RunNumber', 'EventNumber', 'cell_eta', 'cell_phi', 'cell_x', 'cell_y', 'cell_z', 'cell_subCalo', 'cell_sampling', 'cell_size', 'cell_hashID', 'neighbor', 'seedCell_id', 'cell_e', 'cell_noiseSigma', 'cell_SNR', 'cell_time', 'cell_weight', 'cell_truth', 'cell_truth_indices', 'cell_shared_indices', 'cell_cluster_index', 'cluster_to_cell_indices', 'cluster_to_cell_weights', 'cell_to_cluster_e', 'cell_to_cluster_eta', 'cell_to_cluster_phi', 'cluster_eta', 'cluster_phi', 'cluster_e', 'cellsNo_cluster', 'clustersNo_event', 'jetEnergyWtdTimeAve', 'jetEta', 'jetPhi', 'jetE', 'jetPt', 'jetNumberPerEvent', 'cellIndices_per_jet']


In [4]:
# 100 events and 187652 cells
# Arrays containing information about the energy, noise, and snr for each cell
cell_e = np.array(branches['cell_e'])
cell_noise = np.array(branches['cell_noiseSigma'])
cell_snr = np.array(branches['cell_SNR'])

# Represents the index of the cluster that each cell corresponds to. If the index
# is 0, that means that the given cell does not belong to a cluster.
cell_to_cluster_index = np.array(branches['cell_cluster_index'])

# For each entry, contains the IDs of cells neighboring a given cell
neighbor = branches['neighbor']

num_of_events = len(cell_e) # 100 events

In [5]:
# We use the data arrays to crete a data dictionary, where each entry corresponds
# to the data of a given event; we scale this data.
data = {}

for i in range(num_of_events):
    data[f'data_{i}'] = np.concatenate((np.expand_dims(cell_snr[i], axis=1),
                                        np.expand_dims(cell_e[i], axis=1),
                                        np.expand_dims(cell_noise[i], axis=1)), axis=1)
    
# We combine the data into one array and apply the MinMaxScaler
combined_data = np.vstack([data[key] for key in data])
scaler = MinMaxScaler()
scaled_combined_data = scaler.fit_transform(combined_data)

# The scaled data is split to have the save structure as the original data dict
scaled_data = {}
start_idx = 0
for i in range(num_of_events):
    end_idx = start_idx + data[f"data_{i}"].shape[0]
    scaled_data[f"data_{i}"] = scaled_combined_data[start_idx:end_idx]
    start_idx = end_idx

print(scaled_data["data_0"])
print(scaled_data["data_0"].shape)

[[0.02640459 0.24943821 0.09700817]
 [0.02567646 0.24627675 0.09700815]
 [0.02508186 0.24369499 0.09700817]
 ...
 [0.02632877 0.24791998 0.03177016]
 [0.02705116 0.2511391  0.07512318]
 [0.02638626 0.24820389 0.04149057]]
(187652, 3)


## Preparing Neighbor Pairs

In [6]:
# The IDs of the broken cells (those with zero noise) are collected
broken_cells = []

for i in range(num_of_events):
    cells = np.argwhere(cell_noise[i]==0)
    broken_cells = np.squeeze(cells)

print(broken_cells)

[186986 187352]


In [7]:
# Since the values associated with neighbor[0] and neighbor[1] are all equal
# we will just work with neighbor[0] to simplify our calculations
neighbor = neighbor[0]

In [8]:
len(neighbor)

187652

In [10]:
# We loop through the neighbor awkward array and remove the IDs associated
# with the broken cells.  Loops through all cells in the neighbor list. If the loop 
# reaches the cell numbers 186986 or 187352, loop skips over these inoperative cells. 
# The final list contains tuples (i,j) where i is the cell ID in question and the 
# js are the neighboring cell IDs
neighbor_pairs_list = []
num_of_cells = len(neighbor) # 187652 cells

for i in range(num_of_cells):
    if i in broken_cells:
        continue
    for j in neighbor[i]:
        if j in broken_cells:
            continue
        neighbor_pairs_list.append((i, int(j)))


# # This code checks to see if the broken cells were removed
# found_broken_cells = []

# for pair in neighbor_pairs_list:
#     # Loop through each cell in pair
#     for cell in pair:
#         # If the cell is broken, appends to list
#         if cell in broken_cells:
#             found_broken_cells.append(cell)

# if found_broken_cells:
#     print("Error: Broken cells are still present in neighbor pairs.")
# else:
#     print("Successfully excluded broken cells.")

In [11]:
# These functions remove permutation variants
def canonical_form(t):
    return tuple(sorted(t))

def remove_permutation_variants(tuple_list):
    unique_tuples = set(canonical_form(t) for t in tuple_list)
    return [tuple(sorted(t)) for t in unique_tuples]

neighbor_pairs_list = np.array(remove_permutation_variants(neighbor_pairs_list))
print(neighbor_pairs_list)
print(neighbor_pairs_list.shape)

[[ 90345 119588]
 [  4388  17680]
 [ 39760  39825]
 ...
 [159757 168717]
 [ 62911  78974]
 [135353 135609]]
(1250242, 2)


## Creating Labels for the Neighbor Pairs

For a given pair of cells and the IDs of the clusters that they belong to (i, j), if
1. i=j and both are nonzero, then both cells are part of the same cluster. 
    * We call these True-True pairs and label them with 1
2. i=j and both are zero, then both cells are not part of any cluster. 
    * We call these Lone-Lone pairs and label them with 0
3. i is nonzero and j=0, then cell i is part of a cluster while cell j is not. 
    * We call these Cluster-Lone pairs and label them with 2
4. i=0 and j is nonzero, then cell i is not part of a cluste while cell j is. 
    * We call these Lone-Cluster pairs and label them with 3
5. i is not the same as j and both are nonzero, then both cells are part of different clusters. 
    * We call these Cluster-Cluster pairs and label them with 4

In [12]:
labels_for_neighbor_pairs = []
for i in range(num_of_events):
    labels_for_neighbor_pairs_for_event_i = []
    for pair in neighbor_pairs_list:
        if cell_to_cluster_index[i][pair[0]] == cell_to_cluster_index[i][pair[1]]:
            if cell_to_cluster_index[i][pair[0]] != 0:
                labels_for_neighbor_pairs_for_event_i.append(1) # True-True
            else:
                labels_for_neighbor_pairs_for_event_i.append(0) # Lone-Lone
        else:
            if cell_to_cluster_index[i][pair[0]] != 0 and cell_to_cluster_index[i][pair[1]] != 0:
                labels_for_neighbor_pairs_for_event_i.append(4) # Cluster-Cluster
            elif cell_to_cluster_index[i][pair[0]] == 0 and cell_to_cluster_index[i][pair[1]] != 0:
                labels_for_neighbor_pairs_for_event_i.append(3) # Lone-Cluster
            else:
                labels_for_neighbor_pairs_for_event_i.append(2) # Cluster-Lone
    labels_for_neighbor_pairs.append(labels_for_neighbor_pairs_for_event_i)

labels_for_neighbor_pairs = np.array(labels_for_neighbor_pairs)
print(labels_for_neighbor_pairs.shape)
print(labels_for_neighbor_pairs)
print(np.unique(labels_for_neighbor_pairs[0]))

(100, 1250242)
[[0 0 0 ... 0 0 2]
 [3 3 0 ... 0 0 2]
 [1 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 1 0 ... 0 0 0]
 [0 0 0 ... 2 0 0]]
[0 1 2 3 4]


## Creating plots for the features

In [13]:
# # Here we create a dictionary to store features by class
# features_by_class = {cls: {'SNR': [], 'Energy': [], 'Noise': []}
#                      for cls in range(5)}

# # The features for all events are computed
# for i in range(num_of_events):  # Iterate through events
#     for pair_idx, pair in enumerate(neighbor_pairs_list):
#         class_label = labels_for_neighbor_pairs[i][pair_idx]

#         # Vectorized feature extraction for cell pairs
#         cell_1_features = [cell_snr[i][pair[0]], cell_e[i][pair[0]], cell_noise[i][pair[0]]]
#         cell_2_features = [cell_snr[i][pair[1]], cell_e[i][pair[1]], cell_noise[i][pair[1]]]

#         # Append features to the corresponding class
#         features_by_class[class_label]['SNR'] += [cell_1_features[0],cell_2_features[0]]
#         features_by_class[class_label]['Energy'] += [cell_1_features[1], cell_2_features[1]]
#         features_by_class[class_label]['Noise'] += [cell_1_features[2],cell_2_features[2]]

In [14]:
# _ = plt.hist(cell_snr[:][neighbor_pairs_list[labels_for_neighbor_pairs[:]==1][:,0]], bins=np.arange(-20,20,1))

In [15]:
# (cell_snr[0][neighbor_pairs_list[labels_for_neighbor_pairs[0]==1]][:,0]).shape
# # cell_snr[0][neighbor_pairs_list[0][labels_for_neighbor_pairs[0]==1][0]]

In [16]:
# _ =plt.hist(cell_snr[0], bins=np.arange(-10,10,.25))

In [17]:
# # We create a function to precompute bin edges for all features
# def precompute_bins(features_by_class, snr_xlim=None, energy_xlim=None, noise_xlim=None, bins=50):
#     bins_dict = {}

#     # Compute bins for SNR
#     if snr_xlim:
#         bins_dict['SNR'] = np.linspace(snr_xlim[0], snr_xlim[1], bins + 1)
#     else:
#         all_snr = np.concatenate([features['SNR']for features in features_by_class.values()])
#         bins_dict['SNR'] = np.histogram_bin_edges(all_snr, bins=bins)

#     # Compute bins for Energy
#     if energy_xlim:
#         bins_dict['Energy'] = np.linspace(energy_xlim[0], energy_xlim[1], bins + 1)
#     else:
#         all_energy = np.concatenate([features['Energy']for features in features_by_class.values()])
#         bins_dict['Energy'] = np.histogram_bin_edges(all_energy, bins=bins)

#     # Compute bins for Noise
#     if noise_xlim:
#         bins_dict['Noise'] = np.linspace(noise_xlim[0], noise_xlim[1], bins + 1)
#     else:
#         all_noise = np.concatenate([features['Noise']for features in features_by_class.values()])
#         bins_dict['Noise'] = np.histogram_bin_edges(all_noise, bins=bins)

#     return bins_dict


# # Optimized function to plot histograms using precomputed bins
# def plot_histograms_optimized(features_by_class, bins_dict):
#     for class_label, features in features_by_class.items():
#         if len(features['SNR']) > 0:  # Ensure there are features to plot
#             plt.figure(figsize=(15, 5))

#             # Plot SNR histogram with precomputed bins
#             plt.subplot(1, 3, 1)
#             plt.hist(features['SNR'], bins=bins_dict['SNR'],
#                      alpha=0.6, label=f'Class {class_label}')
#             plt.xlabel('SNR')
#             plt.ylabel('Frequency')
#             plt.title(f'Class {class_label}: SNR')
#             plt.grid(True)

#             # Plot Energy histogram with precomputed bins
#             plt.subplot(1, 3, 2)
#             plt.hist(features['Energy'], bins=bins_dict['Energy'],
#                      alpha=0.6, label=f'Class {class_label}')
#             plt.xlabel('Energy')
#             plt.ylabel('Frequency')
#             plt.title(f'Class {class_label}: Energy')
#             plt.grid(True)

#             # Plot Noise histogram with precomputed bins
#             plt.subplot(1, 3, 3)
#             plt.hist(features['Noise'], bins=bins_dict['Noise'],
#                      alpha=0.6, label=f'Class {class_label}')
#             plt.xlabel('Noise')
#             plt.ylabel('Frequency')
#             plt.title(f'Class {class_label}: Noise')
#             plt.grid(True)

#             plt.tight_layout()
#             plt.suptitle(
#                 f'Feature Distributions for Class {class_label}', fontsize=16)
#             plt.subplots_adjust(top=0.85)  # Adjust title to prevent overlap
#             plt.show()


# # Precompute bins for SNR, Energy, and Noise
# bins_dict = precompute_bins(features_by_class, snr_xlim=(-6.5, 6.5),
#                             energy_xlim=(-500, 500), noise_xlim=(0, 500))

# # Call the optimized function to plot histograms
# plot_histograms_optimized(features_by_class, bins_dict)

## Preparing the Data for Multi-Class Classification

In [18]:
# Here we collect the indices of the neighbor pairs by the pair type
indices_for_tt_pairs = []  # Label 1
indices_for_ll_pairs = []  # Label 0
indices_for_cl_pairs = []  # Label 2
indices_for_lc_pairs = []  # Label 3
indices_for_cc_pairs = []  # Label 4

for i in range(num_of_events):
    indices_for_tt_pairs.append(list(np.where(labels_for_neighbor_pairs[i] == 1)[0]))
    indices_for_ll_pairs.append(list(np.where(labels_for_neighbor_pairs[i] == 0)[0]))
    indices_for_cl_pairs.append(list(np.where(labels_for_neighbor_pairs[i] == 2)[0]))
    indices_for_lc_pairs.append(list(np.where(labels_for_neighbor_pairs[i] == 3)[0]))
    indices_for_cc_pairs.append(list(np.where(labels_for_neighbor_pairs[i] == 4)[0]))

In [19]:
# Here we collect the number of each pair type across the events
number_of_tt_pairs = [len(indices_for_tt_pairs[i])for i in range(num_of_events)]
number_of_ll_pairs = [len(indices_for_ll_pairs[i])for i in range(num_of_events)]
number_of_cl_pairs = [len(indices_for_cl_pairs[i])for i in range(num_of_events)]
number_of_lc_pairs = [len(indices_for_lc_pairs[i])for i in range(num_of_events)]
number_of_cc_pairs = [len(indices_for_cc_pairs[i])for i in range(num_of_events)]

In [20]:
# Here we perform a 70-30 split on the indices of neighbor pairs
training_indices_tt = indices_for_tt_pairs[:70]
training_indices_ll = indices_for_ll_pairs[:70]
training_indices_cl = indices_for_cl_pairs[:70]
training_indices_lc = indices_for_lc_pairs[:70]
training_indices_cc = indices_for_cc_pairs[:70]

testing_indices_tt = indices_for_tt_pairs[70:]
testing_indices_ll = indices_for_ll_pairs[70:]
testing_indices_cl = indices_for_cl_pairs[70:]
testing_indices_lc = indices_for_lc_pairs[70:]
testing_indices_cc = indices_for_cc_pairs[70:]

# Here we perform a 70-30 split on the number of neighbor pairs
training_num_tt = number_of_tt_pairs[:70]
training_num_ll = number_of_ll_pairs[:70]
training_num_cl = number_of_cl_pairs[:70]
training_num_lc = number_of_lc_pairs[:70]
training_num_cc = number_of_cc_pairs[:70]

testing_num_tt = number_of_tt_pairs[70:]
testing_num_ll = number_of_ll_pairs[70:]
testing_num_cl = number_of_cl_pairs[70:]
testing_num_lc = number_of_lc_pairs[70:]
testing_num_cc = number_of_cc_pairs[70:]

In [21]:
# We check the minimum number of each pair type across the events. When we
# randomly sample from the indices, if our sample is greater than the minimum
# numbers, then we will run into errors
print("Minimum number of pairs for training:")
print("True-True", min(training_num_tt))
print("Lone-Lone", min(training_num_ll))
print("Cluster-Lone", min(training_num_cl))
print("Lone-Cluster", min(training_num_lc))
print("Cluster-Cluster", min(training_num_cc))
print('\nMinimum number of pairs for testing:')
print("True-True", min(testing_num_tt))
print("Lone-Lone", min(testing_num_ll))
print("Cluster-Lone", min(testing_num_cl))
print("Lone-Cluster", min(testing_num_lc))
print("Cluster-Cluster", min(testing_num_cc))

Minimum number of pairs for training:
True-True 45600
Lone-Lone 926119
Cluster-Lone 41654
Lone-Cluster 45689
Cluster-Cluster 3334

Minimum number of pairs for testing:
True-True 51518
Lone-Lone 906630
Cluster-Lone 44444
Lone-Cluster 48069
Cluster-Cluster 4936


In [22]:
def sample_and_concatenate(indices, sample_size):
    """Sample indices and concatenate them into a single array."""
    # Ensure sample_size does not exceed the minimum number of pairs
    min_pairs = min(len(row) for row in indices)
    sample_size = min(sample_size, min_pairs)
    
    # Sample indices and reshape to 2D
    sampled_pairs = np.array([random.sample(row, sample_size) for row in indices])
    return sampled_pairs

def create_total_indices(training_indices, testing_indices, sample_sizes_train, sample_sizes_test):
    """Create training and testing indices for all pair types."""
    # Sample and concatenate training indices
    train_indices_pairs = {
        key: sample_and_concatenate(indices, size) 
        for key, indices, size in zip(["tt", "ll", "cl", "lc", "cc"], training_indices, sample_sizes_train)
    }
    train_indices_bkg = np.concatenate([train_indices_pairs[key] for key in ["ll", "cl", "lc", "cc"]], axis=1)
    total_training_indices = np.concatenate((train_indices_pairs["tt"], train_indices_bkg), axis=1)
    
    # Sample and concatenate testing indices
    test_indices_pairs = {
        key: sample_and_concatenate(indices, size) 
        for key, indices, size in zip(["tt", "ll", "cl", "lc", "cc"], testing_indices, sample_sizes_test)
    }
    test_indices_bkg = np.concatenate([test_indices_pairs[key] for key in ["ll", "cl", "lc", "cc"]], axis=1)
    total_testing_indices = np.concatenate((test_indices_pairs["tt"], test_indices_bkg), axis=1)
    
    return train_indices_pairs, test_indices_pairs, total_training_indices, total_testing_indices

def create_labels(num_events, num_samples, label_value):
    """Create a label array for a specific pair type."""
    return np.full((num_events, num_samples), label_value, dtype=int)

def create_total_labels(num_events_train, num_events_test, sample_sizes_train, sample_sizes_test):
    """Create training and testing labels for all pair types."""
    # Define label values for each pair type
    label_values = {
        "tt": 1,  # True-True
        "ll": 0,  # Lone-Lone
        "cl": 2,  # Cluster-Lone
        "lc": 3,  # Lone-Cluster
        "cc": 4,  # Cluster-Cluster
    }
    
    # Create training labels
    labels_train = {
        key: create_labels(num_events_train, size, value) 
        for key, value, size in zip(label_values.keys(), label_values.values(), sample_sizes_train)
    }
    labels_bkg_train = np.concatenate([labels_train[key] for key in ["ll", "cl", "lc", "cc"]], axis=1)
    labels_training = np.concatenate((labels_train["tt"], labels_bkg_train), axis=1)
    
    # Create testing labels
    labels_test = {
        key: create_labels(num_events_test, size, value) 
        for key, value, size in zip(label_values.keys(), label_values.values(), sample_sizes_test)
    }
    labels_bkg_test = np.concatenate([labels_test[key] for key in ["ll", "cl", "lc", "cc"]], axis=1)
    labels_testing = np.concatenate((labels_test["tt"], labels_bkg_test), axis=1)
    
    return labels_train, labels_test, labels_training, labels_testing

def randomize_data(indices, labels):
    """Randomize indices and labels while keeping the same permutation."""
    for i in range(indices.shape[0]):
        perm = np.random.permutation(indices.shape[1])
        indices[i] = indices[i, perm]
        labels[i] = labels[i, perm]
    return indices, labels

In [23]:
# Define sample sizes for training and testing for each version
sample_sizes_train_list = [
    [12000, 3000, 3000, 3000, 3000],  # Version 0
    [3000, 12000, 3000, 3000, 3000],  # Version 1
    [3000, 3000, 12000, 3000, 3000],  # Version 2
    [3000, 3000, 3000, 12000, 3000],  # Version 3
]

sample_sizes_test_list = [
    [12000, 3000, 3000, 3000, 3000],  # Version 0
    [3000, 12000, 3000, 3000, 3000],  # Version 1
    [3000, 3000, 12000, 3000, 3000],  # Version 2
    [3000, 3000, 3000, 12000, 3000],  # Version 3
]

# Combine training and testing indices into lists
training_indices = [training_indices_tt, training_indices_ll, training_indices_cl, training_indices_lc, training_indices_cc]
testing_indices = [testing_indices_tt, testing_indices_ll, testing_indices_cl, testing_indices_lc, testing_indices_cc]

# Number of events for training and testing
num_events_train = 70
num_events_test = 30

# Dictionaries to store the four versions of the indices and labels
indices_versions = {}
labels_versions = {}

# Loop through sample sizes
for i, (sample_sizes_train, sample_sizes_test) in enumerate(zip(sample_sizes_train_list, sample_sizes_test_list)):
    # Create indices
    train_indices_pairs, test_indices_pairs, total_training_indices, total_testing_indices = create_total_indices(
        training_indices, testing_indices, sample_sizes_train, sample_sizes_test
    )
    
    # Create labels
    labels_train, labels_test, labels_training, labels_testing = create_total_labels(
        num_events_train, num_events_test, sample_sizes_train, sample_sizes_test
    )
    
    # Randomize training data
    total_training_indices, labels_training = randomize_data(total_training_indices, labels_training)
    
    # Randomize testing data
    total_testing_indices, labels_testing = randomize_data(total_testing_indices, labels_testing)
    
    # Store the randomized data in the dictionaries
    indices_versions[f"version_{i}"] = {
        "training_indices": total_training_indices,
        "testing_indices": total_testing_indices,
        "testing_tt": test_indices_pairs["tt"],
        "testing_ll": test_indices_pairs["ll"],
        "testing_cl": test_indices_pairs["cl"],
        "testing_lc": test_indices_pairs["lc"],
        "testing_cc": test_indices_pairs["cc"],
    }
    
    labels_versions[f"version_{i}"] = {
        "training_labels": labels_training,
        "testing_labels": labels_testing,
        "testing_tt": labels_test["tt"],
        "testing_ll": labels_test["ll"],
        "testing_cl": labels_test["cl"],
        "testing_lc": labels_test["lc"],
        "testing_cc": labels_test["cc"],
    }
    
    # Print shapes for verification
    print(f"Version {i}:")
    print(f"Training indices shape: {total_training_indices.shape}")
    print(f"Training labels shape: {labels_training.shape}")
    print(f"Testing indices shape: {total_testing_indices.shape}")
    print(f"Testing labels shape: {labels_testing.shape}\n")

Version 0:
Training indices shape: (70, 24000)
Training labels shape: (70, 24000)
Testing indices shape: (30, 24000)
Testing labels shape: (30, 24000)

Version 1:
Training indices shape: (70, 24000)
Training labels shape: (70, 24000)
Testing indices shape: (30, 24000)
Testing labels shape: (30, 24000)

Version 2:
Training indices shape: (70, 24000)
Training labels shape: (70, 24000)
Testing indices shape: (30, 24000)
Testing labels shape: (30, 24000)

Version 3:
Training indices shape: (70, 24000)
Training labels shape: (70, 24000)
Testing indices shape: (30, 24000)
Testing labels shape: (30, 24000)



In [24]:
indices_versions["version_0"]["testing_tt"]

array([[ 650927,  720647,  961651, ...,  502670,  812034,  605121],
       [  86023,  600807,  957669, ...,   35772,  938896,  215158],
       [ 402697,  711389, 1240295, ...,   72706,  585196,  271946],
       ...,
       [  20387,  223281,   64746, ...,  558189,   50685,  609182],
       [ 944789,   72571,  406632, ...,  676971,  556269,   22966],
       [1172955,  492553, 1172522, ...,  363329,   62132,  241743]])

In [25]:
# Create a dictionary to store the neighbor pairs for each version
neighbor_pairs_versions = {}

# Loop through each version
for i in range(len(sample_sizes_train_list)):
    # Get the indices for the current version
    train_indices_version = indices_versions[f"version_{i}"]["training_indices"]
    test_indices_version = indices_versions[f"version_{i}"]["testing_indices"]
    
    # Efficiently index into neighbor_pairs_list for training and testing indices
    total_train_neighbor_random = neighbor_pairs_list[train_indices_version]
    total_test_neighbor_random = neighbor_pairs_list[test_indices_version]
    
    # Store the neighbor pairs in the dictionary for each version
    neighbor_pairs_versions[f"version_{i}"] = {
        "train_neighbors": total_train_neighbor_random,
        "test_neighbors": total_test_neighbor_random,
    }
    
    # Print shapes for verification
    print(f"Version {i}:")
    print(f"Training neighbors shape: {total_train_neighbor_random.shape}")
    print(f"Testing neighbors shape: {total_test_neighbor_random.shape}\n")


Version 0:
Training neighbors shape: (70, 24000, 2)
Testing neighbors shape: (30, 24000, 2)

Version 1:
Training neighbors shape: (70, 24000, 2)
Testing neighbors shape: (30, 24000, 2)

Version 2:
Training neighbors shape: (70, 24000, 2)
Testing neighbors shape: (30, 24000, 2)

Version 3:
Training neighbors shape: (70, 24000, 2)
Testing neighbors shape: (30, 24000, 2)



In [26]:
# This function assists in creating data arrays associated with bi- and
# uni-directional arrays
def createArray(input_data, num_of_data, is_source, is_bi_directional):
    # Initialize an empty list to store the output data
    data = []

    # Loop through each set of data in input_data
    for i in range(num_of_data):
        _data = []

        # Loop through each pair of data in the current data set
        for pair in input_data[i]:

            # Process data depending on is_bi_directional flag
            if is_bi_directional:
                # If is_source is True, append both elements in original order
                if is_source:
                    _data.append(pair[0])
                    _data.append(pair[1])
                else:
                    # If is_source is False, append elements in reversed order
                    _data.append(pair[1])
                    _data.append(pair[0])
            else:
                # If is_bi_directional is False, append only one element depending on is_source flag
                if is_source:
                    _data.append(pair[0])
                else:
                    _data.append(pair[1])

        # Add the processed data set to the output list
        data.append(_data)

    # Return the final processed list of data
    data = np.array(data)
    return data

In [27]:
# Create a dictionary to store bi-/uni-directional arrays for each version
neighbor_arrays_versions = {}

# Loop through each version
for i in range(len(sample_sizes_train_list)):
    # Get the neighbor pairs for the current version
    total_train_neighbor_random = neighbor_pairs_versions[f"version_{i}"]["train_neighbors"]
    total_test_neighbor_random = neighbor_pairs_versions[f"version_{i}"]["test_neighbors"]
    
    # Create bi- and uni-directional arrays for training
    train_edge_source_bi = createArray(total_train_neighbor_random, 70, True, True)
    train_edge_dest_bi = createArray(total_train_neighbor_random, 70, False, True)
    train_edge_source_uni = createArray(total_train_neighbor_random, 70, True, False)
    train_edge_dest_uni = createArray(total_train_neighbor_random, 70, False, False)
    
    # Create bi- and uni-directional arrays for testing
    test_edge_source_bi = createArray(total_test_neighbor_random, 30, True, True)
    test_edge_dest_bi = createArray(total_test_neighbor_random, 30, False, True)
    test_edge_source_uni = createArray(total_test_neighbor_random, 30, True, False)
    test_edge_dest_uni = createArray(total_test_neighbor_random, 30, False, False)
    
    # Store the arrays in the dictionary for each version
    neighbor_arrays_versions[f"version_{i}"] = {
        "train_edge_source_bi": train_edge_source_bi,
        "train_edge_dest_bi": train_edge_dest_bi,
        "train_edge_source_uni": train_edge_source_uni,
        "train_edge_dest_uni": train_edge_dest_uni,
        "test_edge_source_bi": test_edge_source_bi,
        "test_edge_dest_bi": test_edge_dest_bi,
        "test_edge_source_uni": test_edge_source_uni,
        "test_edge_dest_uni": test_edge_dest_uni,
    }
    
    # # Print the shapes of the arrays for verification
    # print(f"Version {i}:")
    # print(f"train_edge_source_bi shape: {train_edge_source_bi.shape}")
    # print(f"train_edge_dest_bi shape: {train_edge_dest_bi.shape}")
    # print(f"train_edge_source_uni shape: {train_edge_source_uni.shape}")
    # print(f"train_edge_dest_uni shape: {train_edge_dest_uni.shape}")
    # print(f"test_edge_source_bi shape: {test_edge_source_bi.shape}")
    # print(f"test_edge_dest_bi shape: {test_edge_dest_bi.shape}")
    # print(f"test_edge_source_uni shape: {test_edge_source_uni.shape}")
    # print(f"test_edge_dest_uni shape: {test_edge_dest_uni.shape}\n")

In [28]:
neighbor_pairs_versions['version_0']['test_neighbors'].shape

(30, 24000, 2)

In [29]:
labels_testing.shape

(30, 24000)

In [30]:
indices_versions['version_0']['testing_tt'].shape

(30, 12000)

In [31]:
# Function to create edge index tensors with correct permutation
def make_edge_index_tensor(source, dest):
    source = np.array(source)
    dest = np.array(dest)
    edge_index = torch.tensor([source, dest], dtype=torch.long)
    return edge_index.permute(1, 0, 2)

# Function to generate edge arrays for all versions, including training
def generate_edge_arrays_for_versions(neighbor_pairs_list, indices_versions, labels_testing, sample_size):
    edge_arrays_versions = {}
    valid_neighbor_types = ['tt', 'll', 'cl', 'lc', 'cc']

    for version_key in indices_versions:
        version_name = version_key  
        edge_arrays_versions[version_name] = {}

        # Handle training edges
        train_pairs = [neighbor_pairs_list[i] for i in indices_versions[version_key]['training_indices']]
        train_pairs = np.array(train_pairs)

        edge_arrays_versions[version_name]['train'] = {
            "source_bi": createArray(train_pairs, sample_size, True, True),
            "dest_bi": createArray(train_pairs, sample_size, False, True),
            "source_uni": createArray(train_pairs, sample_size, True, False),
            "dest_uni": createArray(train_pairs, sample_size, False, False),
        }

        # Handle testing edges by neighbor type
        test_neighbor_pairs = {ntype: [] for ntype in valid_neighbor_types}

        for neighbor_type in valid_neighbor_types:
            indices = indices_versions[version_key].get(f'testing_{neighbor_type}')
            if indices is not None:
                for i in range(len(labels_testing)):
                    test_neighbor_pairs[neighbor_type].append(neighbor_pairs_list[indices[i]])
                test_neighbor_pairs[neighbor_type] = np.array(test_neighbor_pairs[neighbor_type])

        # Generate edge arrays for each neighbor type in testing
        for neighbor_type, pairs in test_neighbor_pairs.items():
            edge_arrays_versions[version_name][neighbor_type] = {
                "source_bi": createArray(pairs, sample_size, True, True),
                "dest_bi": createArray(pairs, sample_size, False, True),
                "source_uni": createArray(pairs, sample_size, True, False),
                "dest_uni": createArray(pairs, sample_size, False, False),
            }

    return edge_arrays_versions

# Generate edge arrays for all versions (training + testing)
edge_arrays_versions = generate_edge_arrays_for_versions(neighbor_pairs_list, indices_versions, labels_testing, sample_size=30)

In [32]:
# Dictionary to store edge index tensors for all versions
edge_indices_versions = {}

for version in edge_arrays_versions:
    edge_indices_versions[version] = {}

    # Process training data
    edge_indices_versions[version]['train'] = {
        "bi": make_edge_index_tensor(edge_arrays_versions[version]['train']["source_bi"],
                                     edge_arrays_versions[version]['train']["dest_bi"]),
        "uni": make_edge_index_tensor(edge_arrays_versions[version]['train']["source_uni"],
                                      edge_arrays_versions[version]['train']["dest_uni"])
    }

    # Process testing data by neighbor type
    for neighbor_type in edge_arrays_versions[version]:
        if neighbor_type == 'train':  # Skip training since it's already processed
            continue
        
        edge_indices_versions[version][neighbor_type] = {
            "bi": make_edge_index_tensor(edge_arrays_versions[version][neighbor_type]["source_bi"],
                                         edge_arrays_versions[version][neighbor_type]["dest_bi"]),
            "uni": make_edge_index_tensor(edge_arrays_versions[version][neighbor_type]["source_uni"],
                                          edge_arrays_versions[version][neighbor_type]["dest_uni"])
        }

  edge_index = torch.tensor([source, dest], dtype=torch.long)


In [33]:
edge_arrays_versions['version_0'].keys()

dict_keys(['train', 'tt', 'll', 'cl', 'lc', 'cc'])

In [34]:
# indices_versions = {
#     'version_0': {
#         'training_indices': indices_versions['version_0']['training_indices'],  # Replace with your actual training indices
#         'testing_indices': indices_versions['version_0']['testing_indices'],  # Replace with your actual testing indices
#         'testing_tt': indices_versions['version_0']['testing_tt'],  # Indices for version 0 for testing_tt
#         'testing_ll': indices_versions['version_0']['testing_ll'],  # Indices for version 0 for testing_ll
#         'testing_cl': indices_versions['version_0']['testing_cl'],  # Indices for version 0 for testing_cl
#         'testing_lc': indices_versions['version_0']['testing_lc'],  # Indices for version 0 for testing_lc
#         'testing_cc': indices_versions['version_0']['testing_cc'],  # Indices for version 0 for testing_cc
#     },
#     'version_1': {
#         'training_indices': indices_versions['version_1']['training_indices'],  
#         'testing_indices': indices_versions['version_1']['testing_indices'], 
#         'testing_tt': indices_versions['version_1']['testing_tt'], 
#         'testing_ll': indices_versions['version_1']['testing_ll'], 
#         'testing_cl': indices_versions['version_1']['testing_cl'],
#         'testing_lc': indices_versions['version_1']['testing_lc'],
#         'testing_cc': indices_versions['version_1']['testing_cc'],
#     },
#     'version_2': {
#         'training_indices': indices_versions['version_2']['training_indices'],  
#         'testing_indices': indices_versions['version_2']['testing_indices'], 
#         'testing_tt': indices_versions['version_2']['testing_tt'], 
#         'testing_ll': indices_versions['version_2']['testing_ll'], 
#         'testing_cl': indices_versions['version_2']['testing_cl'],
#         'testing_lc': indices_versions['version_2']['testing_lc'],
#         'testing_cc': indices_versions['version_2']['testing_cc'],
#     },
#     'version_3': {
#         'training_indices': indices_versions['version_3']['training_indices'],  
#         'testing_indices': indices_versions['version_3']['testing_indices'], 
#         'testing_tt': indices_versions['version_3']['testing_tt'], 
#         'testing_ll': indices_versions['version_3']['testing_ll'], 
#         'testing_cl': indices_versions['version_3']['testing_cl'],
#         'testing_lc': indices_versions['version_3']['testing_lc'],
#         'testing_cc': indices_versions['version_3']['testing_cc'],
#     }
# }

In [35]:
# # Generate edge arrays for all versions
# edge_arrays_versions = generate_edge_arrays_for_versions(neighbor_pairs_list, indices_versions, labels_testing, sample_size=30)


In [36]:
print("Version 0, TT Pair Edge Arrays:")
print(f"Source Bi-directional: {edge_arrays_versions['version_0']['tt']['source_bi'].shape}")
print(f"Destination Bi-directional: {edge_arrays_versions['version_0']['tt']['dest_bi'].shape}")
print(f"Source Uni-directional: {edge_arrays_versions['version_0']['tt']['source_uni'].shape}")
print(f"Destination Uni-directional: {edge_arrays_versions['version_0']['tt']['dest_uni'].shape}")

Version 0, TT Pair Edge Arrays:
Source Bi-directional: (30, 24000)
Destination Bi-directional: (30, 24000)
Source Uni-directional: (30, 12000)
Destination Uni-directional: (30, 12000)


In [37]:
# Print shapes for Version 0 as an example
print("Version 0, TT Pair Edge Index Tensors:")
print(f"Bi-directional: {edge_indices_versions['version_0']['tt']['bi'].shape}")
print(f"Uni-directional: {edge_indices_versions['version_0']['tt']['uni'].shape}")

Version 0, TT Pair Edge Index Tensors:
Bi-directional: torch.Size([30, 2, 24000])
Uni-directional: torch.Size([30, 2, 12000])


In [38]:
# Here we create the features array for training and testing

# Here we create a rearranged dictionary from the scaled data dictionary
keys = list(scaled_data.keys())
values = list(scaled_data.values())
features_dict = dict(zip(keys, values))

# Here these features are split into training and testing sets
features_training = np.concatenate([value for key, value in list(features_dict.items())[:70]])
features_testing = np.concatenate([value for key, value in list(features_dict.items())[70:]])

In [39]:
print(features_training.shape)
print(features_testing.shape)

(13135640, 3)
(5629560, 3)


In [40]:
features_training = features_training.reshape(70, 187652, 3)
features_testing = features_testing.reshape(30, 187652, 3)

## Creation of the NN Model

In [41]:
# Make the scaled features into a torch tensor (inputs)
x_train = torch.tensor(features_training, dtype=torch.float)
x_test = torch.tensor(features_testing, dtype=torch.float)

In [42]:
# # Here the currect dimension permutations are applied for the model
# def make_edge_index_tensor(source, dest):
#     source = np.array(source)
#     dest = np.array(dest)
#     edge_index = torch.tensor([source, dest], dtype=torch.long)
#     return edge_index.permute(1, 0, 2)

# # Training set (Bi-directional and Uni-directional)
# train_edge_indices_bi = make_edge_index_tensor(train_edge_source_bi, train_edge_dest_bi)
# train_edge_indices_uni = make_edge_index_tensor(train_edge_source_uni, train_edge_dest_uni)
# print(train_edge_indices_bi.shape)
# print(train_edge_indices_uni.shape)

In [43]:
# Dictionary to store edge index tensors for all versions
edge_indices_versions_generalized = {}

# Iterate over each version
for version in edge_indices_versions:
    edge_indices_versions_generalized[version] = {}

    # Iterate over each neighbor type (tt, ll, cl, lc, cc)
    for neighbor_type in edge_indices_versions[version]:
        
        # Get bi-directional and uni-directional tensors
        bi_tensor = edge_indices_versions[version][neighbor_type]["bi"]
        uni_tensor = edge_indices_versions[version][neighbor_type]["uni"]
        
        # Store them with standardized keys
        edge_indices_versions_generalized[version][f"{neighbor_type}_bi"] = bi_tensor
        edge_indices_versions_generalized[version][f"{neighbor_type}_uni"] = uni_tensor

# Example: Print edge index tensor shapes for version_0
print("Edge index tensor shapes for version_0:")
for key, tensor in edge_indices_versions_generalized["version_0"].items():
    print(f"{key}: {tensor.shape}")


Edge index tensor shapes for version_0:
train_bi: torch.Size([30, 2, 48000])
train_uni: torch.Size([30, 2, 24000])
tt_bi: torch.Size([30, 2, 24000])
tt_uni: torch.Size([30, 2, 12000])
ll_bi: torch.Size([30, 2, 6000])
ll_uni: torch.Size([30, 2, 3000])
cl_bi: torch.Size([30, 2, 6000])
cl_uni: torch.Size([30, 2, 3000])
lc_bi: torch.Size([30, 2, 6000])
lc_uni: torch.Size([30, 2, 3000])
cc_bi: torch.Size([30, 2, 6000])
cc_uni: torch.Size([30, 2, 3000])


In [44]:
# edge_index_data = {
#     "tt_bi": (test_edge_tt_source_bi, test_edge_tt_dest_bi),
#     "tt_uni": (test_edge_tt_source_uni, test_edge_tt_dest_uni),
#     "ll_bi": (test_edge_ll_source_bi, test_edge_ll_dest_bi),
#     "ll_uni": (test_edge_ll_source_uni, test_edge_ll_dest_uni),
#     "cl_bi": (test_edge_cl_source_bi, test_edge_cl_dest_bi),
#     "cl_uni": (test_edge_cl_source_uni, test_edge_cl_dest_uni),
#     "lc_bi": (test_edge_lc_source_bi, test_edge_lc_dest_bi),
#     "lc_uni": (test_edge_lc_source_uni, test_edge_lc_dest_uni),
#     "cc_bi": (test_edge_cc_source_bi, test_edge_cc_dest_bi),
#     "cc_uni": (test_edge_cc_source_uni, test_edge_cc_dest_uni),
# }

# # Create and permute tensors for all edge types
# edge_indices = {key: make_edge_index_tensor(
#     sources, dests) for key, (sources, dests) in edge_index_data.items()}

In [45]:
y_train = np.expand_dims(labels_training, axis=1)
y_train = torch.tensor(y_train)
y_train.shape

torch.Size([70, 1, 24000])

In [46]:
y_test = np.expand_dims(labels_testing, axis=1)
y_test = torch.tensor(y_test)

## Creation of custom data lists, collate functions, and data loaders

In [47]:
# Create a class that inherents from the torch.utils.data.Dataset class
# The pytorch class is abstract, meaning we need to define certain methods
# like __len__() and __getitem__()
class custom_dataset(torch.utils.data.Dataset):
    # Class constructor that takes in data list and
    # stores it as an instance, making it avaliable
    # to other methods in the class
    def __init__(self, data_list):
        self.data_list = data_list

    # Method return length of data set
    def __len__(self):
        return len(self.data_list)

    # Method returns data point at index idx
    def __getitem__(self, idx):
        return self.data_list[idx]

# Used to handle batch loading, shuffling, and parallel loading during
# training and testing in the ML pipeline

In [48]:
# Create a list with information regarding a homogenous graph (a graph
# where all nodes represent instances of the same type [cells in the
# detector] and all edges represent relations of the same type [connections
# between cells])
def create_data_list(bi_edge_indices, uni_edge_indices, x, y):
    data_list = []
    for i in range(len(bi_edge_indices)):
        # Create the feature matrix
        x_mat = x[i]
        # Create graph connectivity matrix
        edge_index = bi_edge_indices[i]
        edge_index, _ = add_self_loops(edge_index)

        # Convert y[i] to a PyTorch tensor
        y_tensor = torch.tensor(y[i], dtype=torch.long) if not isinstance(
            y[i], torch.Tensor) else y[i]

        # Create the data object describing a homogeneous graph
        data = Data(x=x_mat, edge_index=edge_index,
                    edge_index_out=uni_edge_indices[i], y=y_tensor)
        data = ToUndirected()(data)
        data_list.append(data)
    return data_list


def collate_data(data_list):
    return ([data.x for data in data_list],
            [data.edge_index for data in data_list],
            [data.edge_index_out for data in data_list],
            torch.cat([data.y for data in data_list], dim=0))

In [49]:
edge_indices_versions_generalized['version_0'].keys()

dict_keys(['train_bi', 'train_uni', 'tt_bi', 'tt_uni', 'll_bi', 'll_uni', 'cl_bi', 'cl_uni', 'lc_bi', 'lc_uni', 'cc_bi', 'cc_uni'])

In [50]:
labels_versions['version_0']['testing_tt']

array([[1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
       ...,
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1]])

In [51]:
# # Create the data lists for all edge types and categories

# data_list_train0 = create_data_list(edge_indices_versions['version_0']['train']['bi'], 
#                                     edge_indices_versions['version_0']['train']['uni'], 
#                                     x_train, 
#                                     y_train)  # Training Edges

# data_list_tt0 = create_data_list(edge_indices_versions_generalized['version_0']['tt_bi'], 
#                                  edge_indices_versions_generalized['version_0']['tt_uni'], 
#                                  x_test, 
#                                  np.expand_dims(labels_versions['version_0']['testing_tt'], axis=1))  # True-True Edges

# data_list_ll0 = create_data_list(edge_indices_versions_generalized['version_0']['ll_bi'], 
#                                  edge_indices_versions_generalized['version_0']['ll_uni'], 
#                                  x_test, 
#                                  np.expand_dims(labels_versions['version_0']['testing_ll'], axis=1))  # Lone-lone Edges

# data_list_cl0 = create_data_list(edge_indices_versions_generalized['version_0']['cl_bi'], 
#                                  edge_indices_versions_generalized['version_0']['cl_uni'], 
#                                  x_test, 
#                                  np.expand_dims(labels_versions['version_0']['testing_cl'], axis=1))  # Cluster-Lone Edges

# data_list_lc0 = create_data_list(edge_indices_versions_generalized['version_0']['lc_bi'],
#                                   edge_indices_versions_generalized['version_0']['lc_uni'],
#                                     x_test, 
#                                     np.expand_dims(labels_versions['version_0']['testing_lc'], axis=1))  # Lone-Cluster Edges

# data_list_cc0 = create_data_list(edge_indices_versions_generalized['version_0']['cc_bi'], 
#                                  edge_indices_versions_generalized['version_0']['cc_uni'], 
#                                  x_test, 
#                                  np.expand_dims(labels_versions['version_0']['testing_cc'], axis=1))  # Cluster-Cluster Edges

In [52]:
# Dictionary to store data lists for all versions
data_lists_versions = {}

# Define all neighbor types
neighbor_types = ['tt', 'll', 'cl', 'lc', 'cc']

# Iterate over all versions
for version in ['version_0', 'version_1', 'version_2', 'version_3']:
    data_lists_versions[version] = {}

    # Training edges
    data_lists_versions[version]['train'] = create_data_list(
        edge_indices_versions[version]['train']['bi'], 
        edge_indices_versions[version]['train']['uni'], 
        x_train, 
        y_train
    )

    # Testing edges for each neighbor type
    for neighbor_type in neighbor_types:
        data_lists_versions[version][neighbor_type] = create_data_list(
            edge_indices_versions[version][neighbor_type]['bi'], 
            edge_indices_versions[version][neighbor_type]['uni'], 
            x_test, 
            np.expand_dims(labels_versions[version][f'testing_{neighbor_type}'], axis=1)
        )


In [53]:
data_lists_versions['version_0']['train'][0]

Data(x=[187652, 3], edge_index=[2, 235637], y=[1, 24000], edge_index_out=[2, 24000])

In [54]:
# # Batch size value
# batch_size = 1

# # Create the data loaders
# data_loader = {}
# data_list_mapping = {
#     "train": data_list_train,  # Training Edges
#     "tt": data_list_tt,           # True-True Edges
#     "ll": data_list_ll,           # Lone-Lone Edges
#     "lc": data_list_lc,           # Lone-Cluster Edges
#     "cl": data_list_cl,           # Cluster-Lone Edges
#     "cc": data_list_cc            # Cluster-Cluster Edges
# }

# # Total background dataset
# data_list_total_bkg = data_list_ll + data_list_cl + data_list_lc + data_list_cc
# data_loader_total_bkg = torch.utils.data.DataLoader(
#     custom_dataset(data_list_total_bkg),
#     batch_size=batch_size,
#     collate_fn=lambda batch: collate_data(batch)
# )
# # For the other datasets
# for key, data_list in data_list_mapping.items():
#     data_loader[key] = torch.utils.data.DataLoader(
#         custom_dataset(data_list),
#         batch_size=batch_size,
#         collate_fn=lambda batch: collate_data(batch)
#     )

In [55]:
# Batch size
batch_size = 1

# Dictionary to store data loaders for all versions
data_loaders_versions = {}

# Iterate over all versions
for version in ['version_0', 'version_1', 'version_2', 'version_3']:
    data_loaders_versions[version] = {}

    # Define dataset for each category
    data_list_mapping = {
        "train": data_lists_versions[version]['train'],  # Training Edges
        "tt": data_lists_versions[version]['tt'],       # True-True Edges
        "ll": data_lists_versions[version]['ll'],       # Lone-Lone Edges
        "lc": data_lists_versions[version]['lc'],       # Lone-Cluster Edges
        "cl": data_lists_versions[version]['cl'],       # Cluster-Lone Edges
        "cc": data_lists_versions[version]['cc']        # Cluster-Cluster Edges
    }

    # Create background dataset (sum of all background types)
    data_list_total_bkg = (
        data_lists_versions[version]['ll'] + 
        data_lists_versions[version]['cl'] + 
        data_lists_versions[version]['lc'] + 
        data_lists_versions[version]['cc']
    )

    # Create data loader for total background dataset
    data_loaders_versions[version]['total_bkg'] = torch.utils.data.DataLoader(
        custom_dataset(data_list_total_bkg),
        batch_size=batch_size,
        collate_fn=lambda batch: collate_data(batch)
    )

    # Create data loaders for all other datasets
    for key, data_list in data_list_mapping.items():
        data_loaders_versions[version][key] = torch.utils.data.DataLoader(
            custom_dataset(data_list),
            batch_size=batch_size,
            collate_fn=lambda batch: collate_data(batch)
        )


## Creation of Model, Weights, and Functions to Train and Test the Model

#### Model without learnable layer weights

In [56]:
# class MultiEdgeClassifier(nn.Module):
#     def __init__(self, input_dim, hidden_dim, output_dim, num_layers, debug=False):
#         super(MultiEdgeClassifier, self).__init__()
#         # Set the debug mode
#         self.debug = debug

#         # Node embedding layer
#         self.node_embedding = nn.Linear(input_dim, hidden_dim)

#         # Initialize first convolution and batch norm layers
#         self.convs = nn.ModuleList()
#         self.bns = nn.ModuleList()
#         self.convs.append(GCNConv(hidden_dim, 128))
#         self.bns.append(BatchNorm1d(128))

#         # Additional conv and bn layers based on 'num_layers' param
#         for i in range(1, num_layers):
#             in_channels = 128 if i == 1 else 64
#             out_channels = 64
#             self.convs.append(GCNConv(in_channels, out_channels))
#             self.bns.append(BatchNorm1d(out_channels))

#         # Edge classification layer
#         self.fc = nn.Linear(128, output_dim)

#     def debug_print(self, message):
#         if self.debug:
#             print(message)

#     def forward(self, x, edge_index, edge_index_out):
#         # Node embedding
#         x = self.node_embedding(x)
#         self.debug_print(f"Node embedding output shape: {x.shape}")

#         if x.dim() == 3 and x.size(0) == 1:  # Check and remove batch dimension
#             x = x.squeeze(0)

#         # Loop through convolution layers
#         for i, conv in enumerate(self.convs):
#             x = conv(x, edge_index)
#             self.debug_print(f"After GCNConv {i+1}: {x.shape}")
#             if x.dim() == 3 and x.size(0) == 1:  # Check and remove batch dimension
#                 x = x.squeeze(0)
#             x = self.bns[i](x)
#             x = torch.relu(x)
#             self.debug_print(f"After BatchNorm {i+1}: {x.shape}")

#         # Edge representations
#         edge_rep = torch.cat(
#             [x[edge_index_out[0]], x[edge_index_out[1]]], dim=1)
#         self.debug_print(f"Edge representation shape: {edge_rep.shape}")

#         # Return Logits
#         edge_scores = self.fc(edge_rep)
#         return edge_scores

#### Model with learnable layer weights

In [57]:
# Have different weights for the different layers

class MultiEdgeClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, debug=False):
        super(MultiEdgeClassifier, self).__init__()
        # Set the debug mode
        self.debug = debug

        # Node embedding layer
        self.node_embedding = nn.Linear(input_dim, hidden_dim)

        # Initialize first convolution and batch norm layers
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        # List to store learnable weights for each layer
        self.layer_weights = nn.ParameterList()

        self.convs.append(GCNConv(hidden_dim, 128))
        self.bns.append(BatchNorm1d(128))
        self.layer_weights.append(nn.Parameter(torch.tensor(
            1.0, requires_grad=True)))  # Weight for layer 1

        # Additional conv and bn layers based on 'num_layers' param
        for i in range(1, num_layers):
            in_channels = 128 if i == 1 else 64
            out_channels = 64
            self.convs.append(GCNConv(in_channels, out_channels))
            self.bns.append(BatchNorm1d(out_channels))
            self.layer_weights.append(nn.Parameter(torch.tensor(
                1.0, requires_grad=True)))  # Weight for each layer

        # Edge classification layer
        self.fc = nn.Linear(128, output_dim)

    def debug_print(self, message):
        if self.debug:
            print(message)

    def forward(self, x, edge_index, edge_index_out):
        # Node embedding
        x = self.node_embedding(x)
        self.debug_print(f"Node embedding output shape: {x.shape}")

        if x.dim() == 3 and x.size(0) == 1:  # Check and remove batch dimension
            x = x.squeeze(0)

        # Loop through convolution layers
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            self.debug_print(f"After GCNConv {i+1}: {x.shape}")
            if x.dim() == 3 and x.size(0) == 1:  # Check and remove batch dimension
                x = x.squeeze(0)
            x = self.bns[i](x)
            x = torch.relu(x)
            # Multiply output of each layer by its corresponding weight
            x = x * self.layer_weights[i]
            self.debug_print(f"After Layer Weight {i+1}: {x.shape}")

        # Edge representations
        edge_rep = torch.cat(
            [x[edge_index_out[0]], x[edge_index_out[1]]], dim=1)
        self.debug_print(f"Edge representation shape: {edge_rep.shape}")

        # Return Logits
        edge_scores = self.fc(edge_rep)
        return edge_scores

In [58]:
# Model parameters
input_dim = 3
hidden_dim = 256
output_dim = 5  # Multiclass classification
num_layers = 5

# Assigning different GPUs for different versions
devices = {
    'version_0': torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    'version_1': torch.device("cuda:1" if torch.cuda.is_available() else "cpu"),
    'version_2': torch.device("cuda:2" if torch.cuda.is_available() else "cpu"),
    'version_3': torch.device("cuda:3" if torch.cuda.is_available() else "cpu"),
}

# Creating models for each version on its respective device
models = {
    version: MultiEdgeClassifier(input_dim, hidden_dim, output_dim, num_layers).to(devices[version])
    for version in devices
}

In [59]:
# Flatten labels for training
labels_training_flat = labels_training.flatten()

# Compute class frequencies
train_label_counts = Counter(labels_training_flat)

# Total number of training samples
total_train_samples = labels_training_flat.size

# Compute class weights
class_weights = {label: total_train_samples / count for label, count in train_label_counts.items()}

# Normalize class weights
total_weight = sum(class_weights.values())
normalized_class_weights = {label: weight / total_weight for label, weight in class_weights.items()}

# Convert weights to tensors (for each version on the correct device)
unique_classes = np.unique(labels_training_flat)
weight_tensors = {
    version: torch.tensor([class_weights[label] for label in unique_classes], dtype=torch.float).to(devices[version])
    for version in devices
}
normalized_weight_tensors = {
    version: torch.tensor([normalized_class_weights[label] for label in unique_classes], dtype=torch.float).to(devices[version])
    for version in devices
}

In [60]:
# This is a modified version of cross entropy loss by Lin et al. 2017
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        # Apply softmax to logits to get probabilities
        probs = F.softmax(logits, dim=1)

        # Get the probabilities of the true class
        targets_one_hot = F.one_hot(
            targets, num_classes=logits.size(1)).float()
        probs = probs * targets_one_hot
        probs = probs.sum(dim=1)  # Shape: (batch_size,)

        # Compute the focal loss components
        log_probs = torch.log(probs + 1e-8)  # Add epsilon to avoid log(0)
        focal_weights = (1 - probs) ** self.gamma

        # Apply class weights (if provided)
        if self.alpha is not None:
            alpha_weights = self.alpha[targets]  # Shape: (batch_size,)
            focal_loss = -alpha_weights * focal_weights * log_probs
        else:
            focal_loss = -focal_weights * log_probs

        # Reduce the loss based on the reduction method
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:  # 'none'
            return focal_loss

In [61]:
# List of possible losses:
# 1. nn.CrossEntropyLoss(weight=weight_tensors['version_something that matches the version number'])
# 2. nn.CrossEntropyLoss(weight=normalized_weight_tensors['version_something that matches the version number'])
# 3. FocalLoss(alpha=weight_tensors['version_3'], gamma=2.0, reduction='mean')

In [62]:
num_epochs = 300

# Loss functions per version
criterions = {
    'version_0': nn.CrossEntropyLoss(),
    'version_1': nn.CrossEntropyLoss(),
    'version_2': nn.CrossEntropyLoss(),
    'version_3': nn.CrossEntropyLoss()
}

# Optimizers per version
optimizers = {
    version: optim.Adam(models[version].parameters(), lr=0.1)
    for version in devices
}

# Learning rate schedulers per version
schedulers = {
    version: torch.optim.lr_scheduler.ExponentialLR(optimizers[version], gamma=0.99)
    for version in devices
}

# I HAVE STOPED HERE, GET THIS FIXED BELOW

In [63]:
def train_model(model, device, data_loader, optimizer, criterion):
    # Sets the model into training mode
    model.train()
    # Sends model to GPU if available, otherwise uses the CPU
    model.to(device)

    # Assumes there is only one batch in the data loader
    # Retrieve the single batch from the data loader
    batch_x, batch_edge_index, batch_edge_index_out, batch_y = next(iter(data_loader))

    # Sends the input features, the edge indices, and target
    # labels to the GPU if available, otherwise the CPU
    batch_x = torch.stack(batch_x).to(device)
    batch_edge_index = [edge_index.to(device) for edge_index in batch_edge_index]
    batch_edge_index_out = [edge_index.to(device) for edge_index in batch_edge_index_out]

    # Convert target labels to LongTensor (torch.int64)
    batch_y = [y.long().to(device) for y in batch_y]

    # Clears the gradients of the model parameters to ensure
    # they are not accumulated across batches
    optimizer.zero_grad()

    # Initialize loss tracking for subgraphs in the single batch
    loss_per_batch = []

    # Model processes each graph in the batch one by one
    for i in range(len(batch_edge_index)):
        # Pass the features and the edge indices into the model and store
        # the output (logits)
        _output = model(batch_x[i], batch_edge_index[i], batch_edge_index_out[i])

        # Ensure that model outputs (logits) are of type float32
        _output = _output.float()

        # Calculate the difference between the model output and the targets
        # via the provided criterion (loss function)
        loss = criterion(_output.squeeze(), batch_y[i].squeeze())

        # This difference is stored in the loss_per_batch list
        loss_per_batch.append(loss)

    # The average loss across all subgraphs within the single batch is calculated
    total_loss_per_batch = sum(loss_per_batch) / len(loss_per_batch)

    # Computes the loss gradients with respect to the model parameters
    total_loss_per_batch.backward()

    # Updates the model parameters using the gradients
    optimizer.step()

    # Returns the total loss for the single batch
    return total_loss_per_batch

In [64]:
def test_model(model, device, data_loader_true, data_loader_bkg_dict):
    all_scores = []
    true_labels = []

    with torch.no_grad():
        model.eval()
        model.to(device)

        # Process true edges (positive class, label 1)
        batch_x, batch_edge_index, batch_edge_index_out, batch_y = next(iter(data_loader_true))
        batch_x = torch.stack(batch_x).to(device)
        batch_edge_index = [edge_index.to(device) for edge_index in batch_edge_index]
        batch_edge_index_out = [edge_index_out.to(device) for edge_index_out in batch_edge_index_out]

        for i in range(len(batch_edge_index)):
            test_edge_scores = model(batch_x[i], batch_edge_index[i], batch_edge_index_out[i])
            test_edge_scores = F.softmax(test_edge_scores, dim=1)

            all_scores.append(test_edge_scores)
            true_labels.append(torch.ones(test_edge_scores.size(0), dtype=torch.long, device=device))

        # Process background edges
        for background_type, data_loader_bkg in data_loader_bkg_dict.items():
            batch_x, batch_edge_index, batch_edge_index_out, batch_y = next(iter(data_loader_bkg))
            batch_x = torch.stack(batch_x).to(device)
            batch_edge_index = [edge_index.to(device) for edge_index in batch_edge_index]
            batch_edge_index_out = [edge_index_out.to(device) for edge_index_out in batch_edge_index_out]

            for i in range(len(batch_edge_index)):
                test_edge_scores = model(batch_x[i], batch_edge_index[i], batch_edge_index_out[i])
                test_edge_scores = F.softmax(test_edge_scores, dim=1)

                all_scores.append(test_edge_scores)
                true_labels.append(torch.full((test_edge_scores.size(0),), background_type, dtype=torch.long, device=device))

    # Concatenate all scores and labels
    all_scores = torch.cat(all_scores, dim=0)
    true_labels = torch.cat(true_labels, dim=0)

    # Reorder scores and labels to match the desired order (0, 1, 2, 3, 4)
    desired_order = [0, 1, 2, 3, 4]
    # Ensure reordering matches desired labels
    mask = torch.argsort(true_labels).argsort()
    all_scores = all_scores[mask]
    true_labels = true_labels[mask]

    return all_scores.cpu().numpy(), true_labels.cpu().numpy()

In [65]:
def loss_for_train_and_test(model, loader, loss_fn, optimizer, training, device):
    if training:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    num_batches = 0
    all_logits = []  # Store raw logits in testing mode

    for batch in loader:
        if training:
            batch_x, batch_edge_index, batch_edge_index_out, batch_y = batch
            batch_y = batch_y.to(device)
        else:
            batch_x, batch_edge_index, batch_edge_index_out, *_ = batch

        # Move features and edge indices to the device
        batch_x = torch.stack(batch_x).to(device)
        batch_edge_index = [edge_index.to(device)for edge_index in batch_edge_index]
        batch_edge_index_out = [edge_index_out.to(
            device) for edge_index_out in batch_edge_index_out]

        if training:
            for i in range(len(batch_edge_index)):
                optimizer.zero_grad()

                # Forward pass
                logits = model(
                    batch_x[i], batch_edge_index[i], batch_edge_index_out[i])

                # Compute loss
                loss = loss_fn(logits, batch_y[i])
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1
        else:
            with torch.no_grad():
                for i in range(len(batch_edge_index)):
                    logits = model(
                        batch_x[i], batch_edge_index[i], batch_edge_index_out[i])
                    all_logits.append(logits)  # Store raw logits
                    num_batches += 1

    average_loss = total_loss / num_batches if training else None
    all_logits = torch.cat(all_logits, dim=0).cpu().numpy() if all_logits else None

    return average_loss, all_logits

In [70]:
# Mapping background data loaders to the correct labels per version
data_loader_bkg_dict = {
    'version_0': {
        0: data_loaders_versions['version_0']['ll'],  # Label 0
        2: data_loaders_versions['version_0']['cl'],  # Label 2
        3: data_loaders_versions['version_0']['lc'],  # Label 3
        4: data_loaders_versions['version_0']['cc']   # Label 4
    },
    'version_1': {
        0: data_loaders_versions['version_1']['ll'],
        2: data_loaders_versions['version_1']['cl'],
        3: data_loaders_versions['version_1']['lc'],
        4: data_loaders_versions['version_1']['cc']
    },
    'version_2': {
        0: data_loaders_versions['version_2']['ll'],
        2: data_loaders_versions['version_2']['cl'],
        3: data_loaders_versions['version_2']['lc'],
        4: data_loaders_versions['version_2']['cc']
    },
    'version_3': {
        0: data_loaders_versions['version_3']['ll'],
        2: data_loaders_versions['version_3']['cl'],
        3: data_loaders_versions['version_3']['lc'],
        4: data_loaders_versions['version_3']['cc']
    }
}


In [73]:
# Number of models (versions)
num_models = 4
versions = ['version_0', 'version_1', 'version_2', 'version_3']

results = {version: {
    "loss_per_epoch": [],
    "scores": [],
    "truth_labels": [],
    "avg_loss_training_true": [],
    "logits_training_true": [],
    "avg_loss_testing_true": [],
    "logits_testing_true": [],
    "avg_loss_training_bkg": [],
    "logits_training_bkg": [],
    "avg_loss_testing_bkg": [],
    "logits_testing_bkg": []
} for version in devices}

for epoch in range(num_epochs):
    losses = []
    for version in devices:
        model = models[version]
        device = devices[version]
        optimizer = optimizers[version]
        criterion = criterions[version]
        scheduler = schedulers[version]
        
        # Train the model
        total_loss = train_model(model, device, data_loaders_versions[version]['train'], optimizer, criterion)
        results[version]["loss_per_epoch"].append(total_loss.cpu().detach().numpy())

        # Update learning rate
        scheduler.step()

        # Test the model
        epoch_scores, epoch_true_labels = test_model(model, device, data_loaders_versions[version]['tt'], data_loader_bkg_dict[version])
        results[version]["scores"].append(epoch_scores)
        results[version]["truth_labels"].append(epoch_true_labels)

        # Compute the average loss for true and background edges
        avgLossTrueTrain, logitsTrueTrain = loss_for_train_and_test(model, data_loaders_versions[version]['train'], criterion, optimizer, True, device)
        avgLossTrueTest, logitsTrueTest = loss_for_train_and_test(model, data_loaders_versions[version]['tt'], criterion, optimizer, False, device)
        avgLossBkgTrain, logitsBkgTrain = loss_for_train_and_test(model, data_loaders_versions[version]['train'], criterion, optimizer, True, device)
        avgLossBkgTest, logitsBkgTest = loss_for_train_and_test(model, data_loaders_versions[version]['total_bkg'], criterion, optimizer, False, device)

        results[version]["avg_loss_training_true"].append(avgLossTrueTrain)
        results[version]["logits_training_true"].append(logitsTrueTrain)
        results[version]["avg_loss_testing_true"].append(avgLossTrueTest)
        results[version]["logits_testing_true"].append(logitsTrueTest)
        results[version]["avg_loss_training_bkg"].append(avgLossBkgTrain)
        results[version]["logits_training_bkg"].append(logitsBkgTrain)
        results[version]["avg_loss_testing_bkg"].append(avgLossBkgTest)
        results[version]["logits_testing_bkg"].append(logitsBkgTest)

        losses.append(f"Loss {version}: {total_loss.item():.4f}")
    
    print(f"Epoch: {epoch+1} | " + " | ".join(losses))

# Convert all lists to numpy arrays
def convert_to_numpy(results):
    for key, value in results.items():
        for sub_key in value:
            value[sub_key] = np.array(value[sub_key])
    return results

results = convert_to_numpy(results)

Epoch: 1 | Loss version_0: 1.6575 | Loss version_1: 2.1245 | Loss version_2: 1.6683 | Loss version_3: 1.6496
Epoch: 2 | Loss version_0: 1.3865 | Loss version_1: 1.3863 | Loss version_2: 1.3865 | Loss version_3: 1.2908
Epoch: 3 | Loss version_0: 1.3863 | Loss version_1: 1.3863 | Loss version_2: 1.3866 | Loss version_3: 1.2006
Epoch: 4 | Loss version_0: 1.3863 | Loss version_1: 1.3863 | Loss version_2: 1.3865 | Loss version_3: 1.1551
Epoch: 5 | Loss version_0: 1.3863 | Loss version_1: 1.3863 | Loss version_2: 1.3863 | Loss version_3: 1.1381
Epoch: 6 | Loss version_0: 1.3863 | Loss version_1: 1.3863 | Loss version_2: 1.3863 | Loss version_3: 1.1171
Epoch: 7 | Loss version_0: 1.3863 | Loss version_1: 1.3863 | Loss version_2: 1.3863 | Loss version_3: 1.1120
Epoch: 8 | Loss version_0: 1.3863 | Loss version_1: 1.3863 | Loss version_2: 1.3863 | Loss version_3: 1.1144
Epoch: 9 | Loss version_0: 1.3863 | Loss version_1: 1.3863 | Loss version_2: 1.3863 | Loss version_3: 1.1047
Epoch: 10 | Loss ve

KeyboardInterrupt: 

In [None]:
# Define colors for different loss plots
colors = ['dimgray', 'silver', 'olive', 'lightgreen']

# Plot loss per epoch for each model version
plt.figure(figsize=(12, 8))

for version_idx, version in enumerate(results.keys()):
    epoch_number = np.arange(1, len(results[version]["loss_per_epoch"]) + 1)

    plt.plot(epoch_number, results[version]["loss_per_epoch"], label=f'Loss per Epoch - {version}', linestyle='-', marker='o')
    plt.plot(epoch_number, results[version]["avg_loss_training_true"], colors[0], label=f'Avg Loss True Train - {version}', linestyle='--')
    plt.plot(epoch_number, results[version]["avg_loss_testing_true"], colors[1], label=f'Avg Loss True Test - {version}', linestyle='--')
    plt.plot(epoch_number, results[version]["avg_loss_training_bkg"], colors[2], label=f'Avg Loss Bkg Train - {version}', linestyle='-.')
    plt.plot(epoch_number, results[version]["avg_loss_testing_bkg"], colors[3], label=f'Avg Loss Bkg Test - {version}', linestyle='-.')

plt.xlabel("Epoch Number")
plt.ylabel("Loss")
plt.title("Loss per Epoch for Multi-Class Case with Three Layers")
plt.legend()
plt.show()

In [74]:
# # Define the class names
# class_names = ['Lone-Lone', 'True-True', 'Cluster-Lone', 'Lone-Cluster', 'Cluster-Cluster']

# # Define colormap
# cmap = 'YlGnBu'

# # Loop through each model version in results
# for version in results.keys():
#     print(f"Plotting Confusion Matrix for {version}...")

#     # Extract scores and truth labels for this version
#     scores = results[version]["scores"]  # Shape (500, 24000, 5)
#     truth_labels = results[version]["truth_labels"]  # Shape (500, 24000)

#     # Get predicted labels
#     predicted_labels = np.argmax(scores, axis=-1)  # Shape (500, 24000)

#     # Flatten arrays to compute confusion matrix across all samples
#     predicted_labels_flat = predicted_labels.flatten()  # (500 * 24000,)
#     truth_labels_flat = truth_labels.flatten()  # (500 * 24000,)

#     # Compute confusion matrix
#     conf_matrix = confusion_matrix(truth_labels_flat, predicted_labels_flat, labels=[0, 1, 2, 3, 4])

#     # Normalize the confusion matrix to get percentages
#     conf_matrix_percent = conf_matrix.astype('float') / conf_matrix.sum(axis=1)[:, np.newaxis] * 100

#     # Create figure for confusion matrices
#     plt.figure(figsize=(16, 6))

#     # Plot raw count confusion matrix
#     plt.subplot(1, 2, 1)
#     sns.heatmap(conf_matrix, annot=True, fmt='d', cmap=cmap, xticklabels=class_names, yticklabels=class_names)
#     plt.xlabel('Predicted Labels')
#     plt.ylabel('True Labels')
#     plt.title(f'Confusion Matrix - {version}')

#     # Plot percentage confusion matrix
#     plt.subplot(1, 2, 2)
#     sns.heatmap(conf_matrix_percent, annot=True, fmt='.2f', cmap=cmap, xticklabels=class_names, yticklabels=class_names)
#     plt.xlabel('Predicted Labels')
#     plt.ylabel('True Labels')
#     plt.title(f'Confusion Matrix (Percentage) - {version}')

#     # Adjust layout and show
#     plt.tight_layout()
#     plt.show()

In [71]:
# Define custom class labels
class_names = ["Lone-Lone", "True Cluster", "Cluster-Lone", "Lone-Cluster", "Cluster-Cluster"]
class_order = [0, 1, 2, 3, 4]  # Labels to iterate over

def plot_roc_curves(data_loaders_versions, model_names, num_epochs, interval=50, num_classes=5):
    """
    Plots ROC curves for each model and version, showing curves at every 50th epoch.

    Parameters:
    - data_loaders_versions: Dictionary with keys ('version_0', 'version_1', ...) containing the data for each version.
    - model_names: List of model names.
    - num_epochs: Total number of epochs.
    - interval: Epoch interval for plotting (default is 50).
    - num_classes: Number of classes (default is 5).
    """
    
    epoch_indices = list(range(0, num_epochs, interval))  # Select epochs at intervals
    class_order = list(range(num_classes))  # Class order for iteration
    
    for version_key in data_loaders_versions.keys():
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))  # 2x3 grid for ROC curves
        axes = axes.flatten()  # Flatten for easier subplot indexing

        for model_idx, model_name in enumerate(model_names):
            for plot_idx, epoch in enumerate(epoch_indices):
                if plot_idx >= 6:  # Ensure we only plot within a 2x3 layout
                    break

                ax = axes[plot_idx]  # Get subplot axis for the current plot

                # Get scores and labels for the given epoch and version
                epoch_scores = data_loaders_versions[version_key]['scores'][epoch]  # Shape (N, 5)
                epoch_truth_labels = data_loaders_versions[version_key]['truth_labels'][epoch]  # Shape (N,)

                # Iterate over each class and plot the ROC curve
                for class_idx in class_order:
                    binary_truth_labels = (epoch_truth_labels == class_idx).astype(int)  # True labels for this class
                    class_scores = epoch_scores[:, class_idx]  # Scores for this class

                    fpr, tpr, _ = roc_curve(binary_truth_labels, class_scores)
                    roc_auc = auc(fpr, tpr)

                    ax.plot(fpr, tpr, label=f'{class_names[class_idx]} (AUC = {roc_auc:.3f})')

                ax.set_title(f"{model_name} - {version_key} - Epoch {epoch + 1}")
                ax.set_xlabel("False Positive Rate")
                ax.set_ylabel("True Positive Rate")
                ax.legend(loc="lower right")

        # Adjust layout for each version and model
        plt.suptitle(f"ROC Curves for {version_key}", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()

# Example usage:
model_names = ["Model A", "Model B", "Model C", "Model D"]
plot_roc_curves(data_loaders_versions, model_names, num_epochs, interval=50, num_classes=5)


In [None]:
def plot_confusion_matrices(data_loaders_versions, model_names, num_epochs, interval=50, thresholds=None):
    """
    Plots confusion matrices for multiple models at specified epoch intervals.

    Parameters:
    - data_loaders_versions: Dictionary with keys ('version_0', 'version_1', ...) containing the data for each version.
    - model_names: List of model names.
    - num_epochs: Total number of epochs.
    - interval: Epoch interval for plotting (default is 50).
    - thresholds: List of threshold values per class (default is None, uses argmax).
    """

    # Default thresholds to 0.5 for all classes if not provided
    if thresholds is None:
        thresholds = [0.5] * len(class_names)

    epoch_indices = list(range(0, num_epochs, interval))  # Select epochs at intervals

    for version_key in data_loaders_versions.keys():
        for model_idx, model_name in enumerate(model_names):
            for epoch in epoch_indices:
                # Extract scores and truth labels for the given epoch
                epoch_scores = data_loaders_versions[version_key]['scores'][epoch]  # Shape (N, 5)
                epoch_truth_labels = data_loaders_versions[version_key]['truth_labels'][epoch]  # Shape (N,)

                # Convert scores to predicted labels using custom thresholds
                predicted_labels = np.full(epoch_scores.shape[0], -1)  # Initialize with -1 (unclassified)
                
                for i in range(len(class_order)):  # Iterate over each class
                    mask = (epoch_scores[:, i] >= thresholds[i]) & (predicted_labels == -1)
                    predicted_labels[mask] = i  # Assign class index based on threshold

                # Assign unclassified samples to the most probable class if all thresholds failed
                unclassified_mask = predicted_labels == -1
                predicted_labels[unclassified_mask] = np.argmax(epoch_scores[unclassified_mask], axis=-1)

                # Compute confusion matrix
                conf_matrix = confusion_matrix(epoch_truth_labels, predicted_labels, labels=class_order)

                # Normalize confusion matrix (percentage)
                conf_matrix_percent = conf_matrix.astype(float) / conf_matrix.sum(axis=1, keepdims=True) * 100

                # Plot confusion matrices
                fig, axes = plt.subplots(1, 2, figsize=(16, 6))

                # Raw Counts Confusion Matrix
                sns.heatmap(conf_matrix, annot=True, fmt='d', cmap="YlGnBu", xticklabels=class_names, yticklabels=class_names, ax=axes[0])
                axes[0].set_xlabel('Predicted Labels')
                axes[0].set_ylabel('True Labels')
                axes[0].set_title(f'{model_name} - {version_key} - Epoch {epoch+1} (Raw Counts)')

                # Percentage Confusion Matrix
                sns.heatmap(conf_matrix_percent, annot=True, fmt='.2f', cmap="YlGnBu", xticklabels=class_names, yticklabels=class_names, ax=axes[1])
                axes[1].set_xlabel('Predicted Labels')
                axes[1].set_ylabel('True Labels')
                axes[1].set_title(f'{model_name} - {version_key} - Epoch {epoch+1} (Percentage)')

                plt.tight_layout()
                plt.show()

# Example usage:
model_names = ["Model A", "Model B", "Model C", "Model D"]
# custom_thresholds = [0.5, 0.6, 0.7, 0.4, 0.5]  # Set per-class thresholds
plot_confusion_matrices(data_loaders_versions, model_names, num_epochs, interval=50)


In [None]:
colors = ['Blue', 'Orange', 'Green', 'Red', 'Purple']
output_classes = np.arange(5)  # Classes 0, 1, 2, 3, 4

for model_idx in range(num_models):  # Loop over 4 models
    epoch_index = -1  # Last epoch
    
    # Extract scores and truth labels for the current model
    epoch_scores = results[model_idx]["scores"][epoch_index]  # Shape: (24000, 5)
    epoch_truth_labels = results[model_idx]["truth_labels"][epoch_index]  # Shape: (24000,)

    # --- First plot: Class-wise Score Distributions ---
    rows, cols = 2, 3
    fig, axes = plt.subplots(rows, cols, figsize=(15, 10))
    axes = axes.flatten()

    fig.suptitle(f"Model {model_idx}: Class-wise Score Distributions at Final Epoch", fontsize=16)

    for class_idx in range(len(output_classes)):
        ax = axes[class_idx]
        ax.set_title(f'Output Class {class_idx}')
        ax.set_xlabel('Score')
        ax.set_ylabel('Density')

        for truth_type in sorted(np.unique(epoch_truth_labels)):
            scores_for_truth_type = epoch_scores[epoch_truth_labels == truth_type, class_idx]
            fraction_above_0_5 = np.mean(scores_for_truth_type > 0.5)

            ax.hist(
                scores_for_truth_type,
                bins=50,
                density=True,
                alpha=0.6,
                label=f'Truth {truth_type} (>{fraction_above_0_5:.2%})',
                color=colors[truth_type % len(colors)]
            )

        ax.legend()

    if len(axes) > len(output_classes):
        fig.delaxes(axes[-1])

    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

    # --- Second plot: Truth Type-wise Score Distributions ---
    rows, cols = 2, 3
    fig, axes = plt.subplots(rows, cols, figsize=(15, 10))
    axes = axes.flatten()

    unique_truth_types = sorted(np.unique(epoch_truth_labels))

    fig.suptitle(f"Model {model_idx}: Truth Type-wise Score Distributions at Final Epoch", fontsize=16)

    if len(unique_truth_types) == 1:
        axes = [axes]

    for truth_type_idx, truth_type in enumerate(unique_truth_types):
        ax = axes[truth_type_idx]
        ax.set_title(f'Truth Type {truth_type}')
        ax.set_xlabel('Score')
        ax.set_ylabel('Density')

        for class_idx in output_classes:
            scores_for_truth_type_class = epoch_scores[epoch_truth_labels == truth_type, class_idx]
            fraction_above_0_5 = np.mean(scores_for_truth_type_class > 0.5)

            ax.hist(
                scores_for_truth_type_class,
                bins=50,
                density=True,
                alpha=0.6,
                label=f'Class {class_idx} (>{fraction_above_0_5:.2%})',
                color=colors[class_idx % len(colors)]
            )

        ax.legend()

    for idx in range(len(unique_truth_types), len(axes)):
        fig.delaxes(axes[idx])

    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

In [1]:
# Colors for different truth types and output classes
colors = ['Blue', 'Orange', 'Green', 'Red', 'Purple']

def plot_classwise_score_distributions(results, model_names, num_epochs, thresholds=[0.5], output_dim=5):
    """
    Plots the class-wise score distributions for each model version at the final epoch with custom thresholds.
    """
    epoch_index = -1  # Use the last epoch

    # Loop through each version in the results
    for version_key, version_data in results.items():
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        axes = axes.flatten()

        # Add a global title for the entire figure
        fig.suptitle(f"Class-wise Score Distributions with Truth Label Overlays at Final Epoch ({version_key})", fontsize=16)

        for class_idx in range(output_dim):
            ax = axes[class_idx]
            ax.set_title(f'Output Class {class_idx}')
            ax.set_xlabel('Score')
            ax.set_ylabel('Density')

            # Get the scores and truth labels for the final epoch in this version
            epoch_scores = version_data["scores"][epoch_index]  # Shape: (24000, 5)
            epoch_truth_labels = version_data["truth_labels"][epoch_index]  # Shape: (24000,)

            # For each truth type (i.e., label), plot the score distribution
            for truth_type in sorted(np.unique(epoch_truth_labels)):
                # Get scores for the current truth type and current class
                scores_for_truth_type = epoch_scores[epoch_truth_labels == truth_type, class_idx]

                # Plot normalized histogram
                for threshold in thresholds:
                    # Calculate the fraction of samples with scores > threshold
                    fraction_above_threshold = np.mean(scores_for_truth_type > threshold)
                    ax.hist(
                        scores_for_truth_type,
                        bins=50,
                        density=True,  # Normalize the histogram
                        alpha=0.6,
                        label=f'Truth {truth_type} (>{fraction_above_threshold:.2%} > {threshold})',
                        color=colors[truth_type % len(colors)]
                    )

            ax.legend()

        # Remove the third column in the second row if output_dim < rows * cols
        if len(axes) > output_dim:
            fig.delaxes(axes[-1])

        plt.tight_layout(rect=[0, 0, 1, 0.95])  # Adjust layout to fit the title
        plt.show()

In [2]:
def plot_truth_typewise_score_distributions(results, model_names, num_epochs, thresholds=[0.5], output_dim=5):
    """
    Plots the truth type-wise score distributions for each model version at the final epoch with custom thresholds.
    """
    epoch_index = -1  # Use the last epoch

    # Loop through each version in the results
    for version_key, version_data in results.items():
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        axes = axes.flatten()

        # Add a global title for the entire figure
        fig.suptitle(f"Truth Type-wise Score Distributions for All Output Classes at Final Epoch ({version_key})", fontsize=16)

        unique_truth_types = sorted(np.unique(version_data["truth_labels"][epoch_index]))

        # Handle cases where only one truth type is present
        if len(unique_truth_types) == 1:
            axes = [axes]

        for truth_type_idx, truth_type in enumerate(unique_truth_types):
            ax = axes[truth_type_idx]
            ax.set_title(f'Truth Type {truth_type}')
            ax.set_xlabel('Score')
            ax.set_ylabel('Density')

            # Get the scores and truth labels for the final epoch in this version
            epoch_scores = version_data["scores"][epoch_index]  # Shape: (24000, 5)
            epoch_truth_labels = version_data["truth_labels"][epoch_index]  # Shape: (24000,)

            for class_idx in range(output_dim):
                # Get scores for the current output class and the current truth type
                scores_for_truth_type_class = epoch_scores[epoch_truth_labels == truth_type, class_idx]

                # Plot normalized histogram for the current output class
                for threshold in thresholds:
                    # Calculate the fraction of scores > threshold
                    fraction_above_threshold = np.mean(scores_for_truth_type_class > threshold)
                    ax.hist(
                        scores_for_truth_type_class,
                        bins=50,
                        density=True,  # Normalize the histogram
                        alpha=0.6,
                        label=f'Class {class_idx} (>{fraction_above_threshold:.2%} > {threshold})',
                        color=colors[class_idx % len(colors)]
                    )

            # Add legend with updated labels
            ax.legend()

        # Remove unused subplots
        for idx in range(len(unique_truth_types), len(axes)):
            fig.delaxes(axes[idx])

        plt.tight_layout(rect=[0, 0, 1, 0.95])  # Adjust layout to fit the title
        plt.show()

In [None]:
# Example usage with custom thresholds:
model_names = ["Model A", "Model B", "Model C", "Model D"]
thresholds = [0.3, 0.5, 0.7]  # Set custom thresholds for score distribution
plot_classwise_score_distributions(results, model_names, num_epochs, thresholds=thresholds, output_dim=5)
plot_truth_typewise_score_distributions(results, model_names, num_epochs, thresholds=thresholds, output_dim=5)