## Particle collisions


Author: [Alessio Devoto](https://alessiodevoto.github.io/)

In [1]:
# Here we just install it from pip
# If you want access to the files, you can still download the repository though
!pip install torchmetrics --quiet     # for accuracy and metrics
!pip install sparticles --quiet   --use-deprecated=legacy-resolver     # our library

In [7]:
import torch
from torch_geometric.data import InMemoryDataset, download_url
import pandas as pd
from torch_geometric.data import Data
from tqdm import tqdm
import os
import shutil
import tarfile
import glob
from sparticles.transforms import MakeHomogeneous
import numpy as np
np.__version__

'2.1.2'

In [3]:
# Random state for shuffling the dataset.
RANDOM_STATE = 42

# Names of the directories in the raw directory.
RAW_DIR_NAMES = ['signal', 'singletop', 'ttbar']

# Constant labels for noise and signal.
SIGNAL_LABEL = 1
BACKGROUND_LABEL = 0

# Match between directory and event type.
EVENT_LABELS = {
    'signal': SIGNAL_LABEL,
    'singletop': BACKGROUND_LABEL,
    'ttbar': BACKGROUND_LABEL
}

# Number of events to keep for each event type.
# The total number of events in the dataset is the sum of the values in this dictionary.
# We can use these values to have a more balanced dataset.
DEFAULT_EVENT_SUBSETS = {
    'signal': 463056,
    'singletop': 242614,
    'ttbar': 6093298
}

# These are the columns we should keep from the raw pandas dataframe.
# The nan columns are just a hack as we need to have the same number of columns for each row.
USEFUL_COLS = [
    # jet 1
    'pTj1', 'etaj1', 'phij1', 'j1_quantile', 'nan', 'nan',
    # jet 2
    'pTj2', 'etaj2', 'phij2', 'j2_quantile', 'nan', 'nan',
    # jet 3
    'pTj3', 'etaj3', 'phij3', 'j3_quantile', 'nan', 'nan',
    # b1
    'pTb1', 'etab1', 'phib1', 'b1_quantile', 'b1m', 'nan',
    # b2
    'pTb2', 'etab2', 'phib2', 'b2_quantile', 'b2m', 'nan',
    # lepton
    'pTl1', 'etal1', 'phil1', 'nan', 'nan', 'nan',
    # energy
    'ETMiss', 'nan', 'ETMissPhi', 'nan', 'nan', 'metsig_New',
]

# A markdown table to display the structure of a single event.
EVENT_TABLE = """
    Each event is a graph with 6/7 nodes. Each node is built from the raw file as follows:

    | Particle          | Feature 1 | Feature 2 | Feature 3   | Feature 4     | Feature 5 | Feature 6    |
    |-------------------|-----------|-----------|-------------|---------------|-----------|--------------|
    | jet1              |  'pTj1'   | 'etaj1'   |   'phij1'   | 'j1_quantile' |    nan    |     nan      |
    | jet2              |  'pTj2'   | 'etaj2'   |   'phij2'   | 'j2_quantile' |    nan    |     nan      |
    | jet3 (optional)   |  'pTj3'   | 'etaj3'   |   'phij3'   | 'j3_quantile' |    nan    |     nan      |
    | b1                |  'pTb1'   | 'etab1'   |   'phib1'   | 'b1_quantile' |   'b1m'   |     nan      |
    | b2                |  'pTb2'   | 'etab2'   |   'phib2'   | 'b2_quantile' |   'b2m'   |     nan      |
    | lepton            |  'pTl1'   | 'etal1'   |   'phil1'   |      nan      |    nan    |     nan      |
    | energy            | 'ETMiss'  |   nan     | 'ETMissPhi' |      nan      |    nan    | 'metsig_New' |
    """

class EventsDataset(InMemoryDataset):
    """
    Dataset of graphs representing collisions of particles.
    There are three types of event:
        - signal, label 1
        - singletop, label 0
        - ttbar, label 0

    Each event is a graph with 6 or 7 nodes and 6 attributes. Graphs are fully connected.

    Args:
        root (str): Root directory where the dataset should be saved.
        url (str): URL to download the dataset from.
        event_subsets (dict, optional): Dictionary containing the number of events to keep for each event type. Defaults to {'signal': 463056, 'singletop': 242614, 'ttbar': 6093298}.
        add_edge_index (bool, optional): Whether to add the fully connected edge index to the data objects. Defaults to True.
        delete_raw_archive (bool, optional): Whether to delete the raw archive after extracting it. Defaults to False.
        transform (callable, optional): A function/transform that takes in a `torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. Defaults to None.
        pre_transform (callable, optional): A function/transform that takes in a `torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. Defaults to None.
        pre_filter (callable, optional): A function that takes in a `torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. Defaults to None.
        download_type: If it is set to 1, it extracts all the h5 files in signal folder, if it is set to 2, it extracts the h5 file with all mixed signals.
        signal_filename: The name of the signal file to be processed.
    """
    def __init__(
            self,
            root,
            url,
            event_subsets: dict = DEFAULT_EVENT_SUBSETS,
            add_edge_index: bool = True,
            delete_raw_archive: bool = False,
            transform=None,
            pre_transform=None,
            pre_filter=None,
            download_type: int = 2,
            signal_filename: str = 'Wh_hbb_fullMix.h5'):  # Added signal_filename argument

        self.url = url
        self.delete_raw_archive = delete_raw_archive
        self.event_subsets = event_subsets
        self.add_edge_index = add_edge_index
        self.download_type = download_type  # Store download type
        self.signal_filename = signal_filename  # Store signal filename
        self.subset_string = '_'.join([f'{k}_{v}' for k, v in sorted(self.event_subsets.items())])

        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return RAW_DIR_NAMES

    @property
    def processed_file_names(self):
        # Notice the processed file names depend on the number of events we keep for each event type.
        return [f'events_{self.subset_string}.pt']

    @property
    def event_structure(self):
        """
        Returns the event structure of the dataset.
        The event structure is a table that describes the different types of events that can occur in the dataset.
        Returns:
            str: A string containing a markdown table representing the event structure of the dataset.
        """
        return EVENT_TABLE

    def download(self):
        # Download raw directories to `self.raw_dir`.
        print(f'Downloading {self.url} to {self.raw_dir}...')
        print('This may take a while...')
        raw_archive = download_url(self.url, self.raw_dir, filename='events.tar', log=False)

        print('Extracting files...')
        with tarfile.open(raw_archive) as tar:
            if self.download_type == 1:
                # Extract all files in the folder
                tar.extractall(self.raw_dir)
            elif self.download_type == 2:
                members = tar.getmembers()
                for member in members:
                    # Extract the file which contains all signals mixed or the specified signal file.
                    if 'signal' in member.name and self.signal_filename not in member.name:
                        continue
                    tar.extract(member, self.raw_dir)

        if self.delete_raw_archive:
            os.remove(raw_archive)

        # In case the compressed file contains a single directory, we move the files to the raw_dir.
        print('Moving files...')
        for dir in self.raw_file_names:
            dirpath = glob.glob(f'{self.raw_dir}/**/{dir}', recursive=True)[0]
            shutil.move(dirpath, self.raw_dir)
            print(f'Moved {dirpath} to {self.raw_dir}')

        print('Cleaning up...')
        # Remove the directories which are not in self.raw_file_names.
        for f in os.listdir(self.raw_dir):
            if f not in self.raw_file_names + ['events.tar']:
                try:
                    shutil.rmtree(os.path.join(self.raw_dir, f))
                except NotADirectoryError:
                    os.remove(os.path.join(self.raw_dir, f))

        """
        At this stage, we should have the following directory structure.
        Notice h5 file names can change.

        root
        ├── processed
        └── raw
            ├── signal
            │   └── <specified signal file>
            ├── singletop
            │   └── singletop.h5
            └── ttbar
                └── ttbar.h5
        """

    def process(self):
     # Create a dictionary of h5 files, where keys are the event types and values are the path to the h5 file.
     # We don't know the .h5 file names, so we use glob to find them.

     h5_files = {}

     for d in self.raw_file_names:
        dir_path = os.path.join(self.raw_dir, d)
        if d == 'signal':
            signal_file_path = os.path.join(dir_path, self.signal_filename)
            if os.path.exists(signal_file_path):
                h5_files[d] = signal_file_path
        else:
            h5_files[d] = glob.glob(f'{dir_path}/*.h5', recursive=True)[0]

     data_list = []

     for event_type, h5_file in h5_files.items():
        # Labels is the same for all events in the same directory.
        label = EVENT_LABELS[event_type]
        # Read data into pandas dataframe and filter out useless columns.
        graphs = pd.read_hdf(h5_file)
        graphs.drop(columns=list(set(graphs.columns) - set(USEFUL_COLS)), inplace=True)
        # Hackish way to have all rows with the same number of columns.
        graphs['nan'] = torch.nan
        # Rearrange columns to have the same order as USEFUL_COLS and create index column.
        graphs = graphs[USEFUL_COLS].reset_index()
        # Shuffle the dataframe and possibly keep only part of it.
        graphs = graphs.sample(n=self.event_subsets[event_type], random_state=RANDOM_STATE)

        for row in tqdm(graphs.values, total=graphs.shape[0], desc=f'Processing events in {h5_file}'):
            event_id = int(row[0])
            graph_features = row[1:]
            if (event_type not in 'ttbar' and event_type  not in  'singletop' and event_type  not in 'signal'):
                values_array = []

                # Iterate through every element in the array
                for element in row[1:]:
                    # If the element is a dictionary add all of the values to a new array
                    if isinstance(element, dict):
                        values_array.extend(element.values())
                    # If the element is a NaN simply add a simple NaN to the array
                    elif np.isnan(element):
                        values_array.append(np.nan)

                # Convert the array of values into a numpy array
                graph_features = np.array(values_array)

            x = torch.from_numpy(graph_features).reshape(7, -1)

            x = x[x[:, 0] > 0]

            # graphs are all fully connected
            edge_index = None
            if self.add_edge_index:
                directed_edge_index = torch.combinations(torch.arange(x.shape[0]), 2)
                edge_index = torch.cat([directed_edge_index, directed_edge_index.flip(1)], dim=0).T

            # TODO should we add the edge index here? Knowing it is fully connected, does it make sense to waste space for this ?
            # TODO make the event id a constant across multiple datasets

            data_list.append(Data(
                x=x,
                event_id=f'{event_type}_{event_id}',
                y=label,
                edge_index=edge_index,
            ))

     if self.pre_filter is not None:
        data_list = [data for data in data_list if self.pre_filter(data)]

     if self.pre_transform is not None:
        data_list = [self.pre_transform(data) for data in data_list]

     data, slices = self.collate(data_list)
     torch.save((data, slices), self.processed_paths[0])

### 3. Training a Simple GNN

Here you can define your GNN, train them on the data, and see what happens. You can do it with [pytorch-lightning](https://lightning.ai/) or in plain PyTorch. Here I do it in Pytorch.

In [6]:
dataset = EventsDataset(
    root='/hepstore/janders/Analysis_Wh',
    url='https://cernbox.cern.ch/s/0nh0g7VubM4ndoh/download',
    delete_raw_archive=False,
    add_edge_index=True,
    event_subsets={'signal': 2000, 'singletop': 1000, 'ttbar': 1000},
    transform=MakeHomogeneous(),
    download_type=2
)

Processing...
Processing events in /hepstore/janders/Analysis_Wh/raw/signal/Wh_hbb_fullMix.h5:   0%| | 0/2000 [00:00<?, ?i


TypeError: expected np.ndarray (got numpy.ndarray)

Before training, we have to split the data into train and test set as usual.

In [None]:
from torch.utils.data import Subset
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split


# generate indices: instead of the actual data we pass in integers
train_indices, test_indices = train_test_split(
    range(len(dataset)),
    train_size=0.8,
    stratify=[g.y.item() for g in dataset], # to have balanced subsets
    random_state=42
)

train_graphs = Subset(dataset, train_indices)
test_graphs = Subset(dataset, test_indices)

print(f'Train set contains {len(train_graphs)} graphs, Test set contains {len(test_graphs)} graphs')


In [None]:
# Dataloaders allow us to batch datasets together and speed up the training
# For more info about dataloaders: https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html#mini-batches

train_loader = DataLoader(train_graphs, batch_size=96, shuffle=True)
test_loader = DataLoader(test_graphs, batch_size=96, shuffle=False)

We define a simple GNN

In [None]:
from torch_geometric.nn import GCNConv, global_mean_pool
import torch

MANUAL_SEED = 1234

class GCN(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels, num_classes):
        super(GCN, self).__init__()
        torch.manual_seed(MANUAL_SEED)
        self.conv1 = GCNConv(input_channels, hidden_channels)
        self.relu = torch.nn.ReLU()
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.aggregate = global_mean_pool
        self.head = torch.nn.Linear(hidden_channels, num_classes)


    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = self.relu(x)
        x = self.conv2(x, edge_index)
        x = self.aggregate(x, batch)
        x = self.head(x)
        return x

We run the training, now saving the output

In [None]:
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy
from tqdm import tqdm # for nice bar


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(input_channels=12, hidden_channels=36, num_classes=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)


train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []

compute_acc = Accuracy(task='binary')

train_epochs = 50  # Training epochs

model.train()
for epoch in range(train_epochs):
  print(f'Training epoch: {epoch}')
  for batch in tqdm(train_loader, leave=False):
    optimizer.zero_grad()
    out = model(batch.x.float(), batch.edge_index, batch.batch)
    loss = F.binary_cross_entropy_with_logits(out.squeeze(), batch.y.float())
    loss.backward()
    train_losses.append(loss.detach().item())
    train_accuracies.append(compute_acc(out.detach().squeeze(), batch.y.float()))
    optimizer.step()

  print(f'Validation epoch: {epoch}')
  with torch.no_grad():
    for batch in tqdm(test_loader, leave=False):
      out = model(batch.x.float(), batch.edge_index, batch.batch)
      loss = F.binary_cross_entropy_with_logits(out.squeeze(), batch.y.float())
      test_losses.append(loss.detach().item())
      test_accuracies.append(compute_acc(out.detach().squeeze(), batch.y.float()))

  # NEW: Save model snapshot for every epoch... you can save it every 10?
  torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': train_losses,
        'test_losses': test_losses,
        'train_accuracies': train_accuracies,
        'test_accuracies': test_accuracies
    }, f'epoch_{epoch}_snapshot.pth')

Plot the training and test loss and acuracy

In [None]:
#NEW
import torch
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import numpy as np

# Load the saved snapshots
snapshot_paths = ["epoch_{}_snapshot.pth".format(epoch) for epoch in range(train_epochs)]
train_losses, test_losses, train_accuracies, test_accuracies = [], [], [], []
total_epochs = 0

for snapshot_path in snapshot_paths:
    snapshot = torch.load(snapshot_path)
    train_losses.extend(snapshot['train_losses'])
    test_losses.extend(snapshot['test_losses'])
    train_accuracies.extend(snapshot['train_accuracies'])
    test_accuracies.extend(snapshot['test_accuracies'])

epochs = np.arange(1, train_epochs + 1)

fig = make_subplots(
    rows=2, cols=2,
    subplot_titles=("Train loss", "Test loss", "Train acc", "Test acc"))

fig.add_trace(go.Scatter(x=epochs[:len(train_losses)], y=train_losses), row=1, col=1)
fig.add_trace(go.Scatter(x=epochs[:len(test_losses)], y=test_losses), row=1, col=2)
fig.add_trace(go.Scatter(x=epochs[:len(train_accuracies)], y=train_accuracies), row=2, col=1)
fig.add_trace(go.Scatter(x=epochs[:len(test_accuracies)], y=test_accuracies), row=2, col=2)

fig.update_layout(height=800, width=1000, title='Training results', showlegend=False)

fig.show()


## 4. Plot output score

Plot the score directly from the training

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Assuming test_loader.dataset.y contains the true labels and out.detach().squeeze() contains the predicted logits

# Concatenate true labels and predicted logits
true_labels = torch.cat([batch.y for batch in test_loader.dataset]).numpy()
predicted_logits = torch.cat([model(batch.x.float(), batch.edge_index, batch.batch).detach().squeeze() for batch in test_loader]).numpy()

# Convert logits to probabilities using sigmoid function
predicted_probs = torch.sigmoid(torch.tensor(predicted_logits)).numpy()

# Filter predictions based on true labels (signal and background)
signal_indices = true_labels == 1
background_indices = true_labels == 0

signal_probs = predicted_probs[signal_indices]
background_probs = predicted_probs[background_indices]

# Plotting
plt.figure(figsize=(10, 6))

plt.hist(signal_probs, bins=50, color='skyblue', alpha=0.7, label='Signal', density=True)
plt.hist(background_probs, bins=50, color='orange', alpha=0.7, label='Background', density=True)

plt.title('Model Output Distribution on Validation Dataset')
plt.xlabel('Predicted Probabilities')
plt.ylabel('Density')
plt.legend()
plt.show()

## 5. Load data from snapsot

You should take an epoch after the loss and acuracy have leveled off but before any overtraining.  Here I just take the last one

In [None]:
def add_score_from_snapshot(dataset, model, snap_name):
    # Load snapshot and set model to that point
    checkpoint = torch.load(snap_name)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Loop over dataset and apply model
    new_data = []
    for elem in dataset:
        output = model(elem.x.float(), elem.edge_index, elem.batch)
        prob = torch.sigmoid(output).squeeze().detach()
        # Make a new variable on the graph to add the score
        elem.score = prob
        new_data.append(elem)

    return new_data

Load snapshot and add score to graphs

In [None]:
graphs_with_score = add_score_from_snapshot(test_graphs, model, "epoch_40_snapshot.pth")

Plot score from snapshot.  Should be same as above for last epoch

In [None]:
# Plot score from the snapshot
plt.figure(figsize=(10, 6))
#plt.yscale("log")
plt.xlabel("GNN Score")
plt.ylabel("Density")
plt.hist([d.score for d in graphs_with_score if d.y == 0], alpha = 0.5, bins = 50, density = True, label = "Background", color = "orange")
plt.hist([d.score for d in graphs_with_score if d.y == 1], alpha = 0.5, bins = 50, density = True, label = "Signal", color = "skyblue")
plt.legend()
plt.show()

Plot a varibale (here lepton pT) for high and low score background, compared to signal.  The high-score background has a longer tail, more like the signal

In [None]:
# Get leptons for signal events
sig_lep= [g.x[5] if g.x.shape == (7,6) else g.x[4] for g in graphs_with_score if g.y == 1]

# Get leptons for background events with low or high score
threshold = 0.5
bkg_lep_high = [g.x[5] if g.x.shape == (7,6) else g.x[4] for g in graphs_with_score if g.y == 0 and g.score > threshold]
bkg_lep_low = [g.x[5] if g.x.shape == (7,6) else g.x[4] for g in graphs_with_score if g.y == 0 and g.score < threshold]

# Plot
bins = np.linspace(0,1000, 100)
plt.figure(figsize=(10, 6))
plt.xlabel("$p_T^{lep}$")
plt.ylabel("Density")
plt.yscale("log")
plt.hist([f[0].item() for f in sig_lep], bins = bins,
        color="skyblue", alpha = 0.5, label = "Signal", density = True)
plt.hist([f[0].item() for f in bkg_lep_high], bins = bins,
        color="orange", alpha = 0.5, label = f"Bkg (score > {threshold}", density = True)
plt.hist([f[0].item() for f in bkg_lep_low], bins = bins,
        color="green", alpha = 0.5, label =  f"Bkg (score < {threshold}", density = True)
plt.legend()
plt.show()