In [None]:
# !!! Make sure to use pytorch-2.3.1 kernel !!!

# Only need to run this once. May need to restart kernel after first run!
!pip install --user energyflow seaborn

In [None]:
from typing import NamedTuple, Dict, Optional, Tuple

import numpy as np
import energyflow as ef
import h5py

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn import metrics

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms


from matplotlib import pyplot as plt
import seaborn as sb

from tqdm.notebook import tqdm
from scipy.ndimage import gaussian_filter1d

# For Typing
from numpy import ndarray as Array
from torch import Tensor

In [None]:
# Global plotting style for every figure in this notebook.
sb.set_theme(
    context="notebook",
    style="whitegrid",
    font_scale=1.5,
    rc={"figure.figsize": (9, 6)},
)

# Dataset

We will load in and use a dataset of quark and gluon jets generated by Pythia.

This dataset comes from the EnergyFlow authors, big thanks to them! \
https://energyflow.network/ \
https://arxiv.org/abs/1810.05165

## Download Data
Each jet is loaded as a list of constituents, storing $(p_T, y, \phi, \texttt{PID})$ for each constituent, where $y$ is the rapidity.

From the dataset description:
>Each dataset consists of two components:
>
>X : a three-dimensional numpy array of the jets with shape (num_data,max_num_particles,4).\
>y : a numpy array of quark/gluon jet labels (quark=1 and gluon=0).

This will be a **Set Classification Problem**: mapping a set of objects to a single label for the entire set.

In [None]:
X, y = ef.qg_jets.load(cache_dir="/global/cfs/cdirs/ntrain1/attention/energyflow/")

# X is padded to a fixed number of constituents per jet.
num_events, max_constituents, num_basic_features = X.shape

for caption, count in (
    ("Number of events   : ", num_events),
    ("Max constituents   : ", max_constituents),
    ("Number of basic features : ", num_basic_features),
):
    print(caption, count)

In [None]:
# Perform a simple train-test split for example purposes.
# A fixed random_state makes the split (and therefore every downstream result) reproducible
# under Restart & Run All.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=0)

## Preprocessing

Split the data into the kinematics and pid parts.

1. Constituents are stored as arrays padded to the maximum length. Use the PID to determine the masking vector (padded constituents have PID 0).
2. Normalize kinematics to make it easier to learn
3. Convert PIDs into one-hot vectors.
4. Append the two back together to form the final dataset.

This function will also convert from the numpy arrays into torch tensors, which will make it easier to feed into neural networks later.

In [None]:
# Fitted preprocessing objects. They are created from the training split and reused unchanged
# for the test split:
#   normalizer      - StandardScaler fit on the valid constituents' kinematic features.
#   one_hot_encoder - OneHotEncoder fit on the valid constituents' PIDs.
PreprocessingData = NamedTuple(
    "PreprocessingData",
    [("normalizer", StandardScaler), ("one_hot_encoder", OneHotEncoder)],
)
    
def preprocess_data(X: Array, y: Array, preprocessing_data: Optional[PreprocessingData] = None) -> Tuple[TensorDataset, PreprocessingData]:
    """
    Turn padded raw jets into a torch dataset of normalized per-constituent features.

    Parameters
    ----------
    X: [num_jets, max_constituents, 4] array of (pt, rapidity, phi, pid) per constituent.
        Padded constituents are identified by pid == 0.
    y: [num_jets] labels; values > 0.5 become the positive class.
    preprocessing_data: normalizer / one-hot encoder to reuse (e.g. the ones fit on the training
        split). When None, new ones are fit on the valid constituents of X.

    Returns
    -------
    dataset: TensorDataset of
        X:    [num_jets, max_constituents, 4 + num_one_hot_columns] float32 features
              (standardized pt fraction, relative rapidity, sin / cos of relative phi,
              followed by the one-hot PID). Padded constituents are all zeros.
        y:    [num_jets] float32 binary labels.
        mask: [num_jets, max_constituents] bool NEGATIVE mask (True = padded / invalid).
        Constituents within each jet are ordered by decreasing pt.
    preprocessing_data: the normalizer / encoder that were applied (reuse them for other splits).
    """
    pt, rapidity, phi, pid = np.split(X, 4, axis=-1)
    pid = pid.astype(np.int64)

    # Average jet kinematics, weighted by pt.
    # Assumes padded constituents carry pt == 0 so they do not contribute.
    total_pt = np.sum(pt, axis=1, keepdims=True)
    average_rapidity = np.average(rapidity, weights=pt, axis=1, keepdims=True)
    average_phi = np.average(phi, weights=pt, axis=1, keepdims=True)

    # Express constituents relative to the jet: momentum fraction and offsets from the jet axis.
    pt = pt / total_pt
    rapidity = rapidity - average_rapidity
    phi = phi - average_phi

    # Split phi into sin and cos components so the periodic angle has no discontinuity.
    sin_phi, cos_phi = np.sin(phi), np.cos(phi)

    # Combine features into a single array.
    X_kinematics = np.concatenate([pt, rapidity, sin_phi, cos_phi], axis=-1)
    X_pid = pid[..., 0]

    # Sort constituents by decreasing pt.
    # This doesn't change training, but makes it easier to visualize the data later.
    # (Assuming padded constituents have pt 0, it also moves them to the end of each jet.)
    order = np.flip(np.argsort(X_kinematics[..., 0], axis=1), axis=1)
    X_kinematics = np.take_along_axis(X_kinematics, order[..., None], axis=1)
    X_pid = np.take_along_axis(X_pid, order, axis=1)

    # Convert targets into binary boolean values.
    y = y > 0.5

    # Determine mask from PID. This mask is a NEGATIVE mask.
    # If mask is True -> Constituent is INVALID.
    # This is torch convention.
    mask = X_pid == 0

    # Only fit a normalizer on training data, not on testing data.
    # Reuse normalizer from training data for testing data.
    if preprocessing_data is None:
        # Fit normalizer on valid constituents only, so padding does not skew the statistics.
        normalizer = StandardScaler()
        normalizer.fit(X_kinematics[~mask])

        # Fit one-hot encoder on valid PIDs. With handle_unknown='ignore', unseen values
        # (including the padding PID 0) encode as all-zero rows.
        one_hot_encoder = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
        one_hot_encoder.fit(X_pid[~mask].reshape(-1, 1))

        preprocessing_data = PreprocessingData(
            normalizer=normalizer,
            one_hot_encoder=one_hot_encoder
        )

    # Apply transformation to data (padding included; it is zeroed out below).
    X_kinematics = preprocessing_data.normalizer.transform(X_kinematics.reshape(-1, 4)).reshape(X_kinematics.shape)
    X_pid = preprocessing_data.one_hot_encoder.transform(X_pid.reshape(-1, 1)).reshape(*X_pid.shape, -1)

    # Construct final features
    X = np.concatenate([X_kinematics, X_pid], axis=-1)

    # Mask out bad features using the mask we computed earlier.
    X = np.where(mask[..., None], 0, X)

    # Convert to torch tensor format.
    X = torch.from_numpy(X).float()
    y = torch.from_numpy(y).float()
    mask = torch.from_numpy(mask)

    dataset = TensorDataset(X, y, mask)

    return dataset, preprocessing_data

In [None]:
# Create datasets.
# The test set is transformed with the normalizer / encoder fit on the TRAINING data only,
# so no information from the test set leaks into preprocessing.
train_dataset, preprocessing_data = preprocess_data(X_train, y_train)
test_dataset = preprocess_data(X_test, y_test, preprocessing_data=preprocessing_data)[0]

In [None]:
# Grab an example batch and examine its contents.
# Slicing a TensorDataset yields one tensor per stored field: (X, y, mask).
batch = test_dataset[:32]
for i in range(len(batch)):
    element = batch[i]
    print("batch[{}].shape = {}\n        .dtype = {}".format(i, element.shape, element.dtype))

# Width of each per-constituent feature vector (kinematics + one-hot PID).
num_features = batch[0].shape[-1]

# Attention

Here we implement a basic attention mechanism on constituents.

A couple of helper functions are provided for working with padded arrays and their masks.

In [None]:
# A utility function we provide for you.
# A masked version of the softmax function so that the invalid elements are ignored and set to 0 at all times.
def masked_softmax(similarities: Tensor, mask: Tensor, dim: int = -1) -> Tensor:
    """
    Softmax over attention similarities that ignores padded elements.

    Input Shapes
    ------------
    similarities: [B, N, N]
    mask: [B, N], a NEGATIVE mask: True marks an invalid (padded) element.

    Output Shape
    ------------
    masked_softmax_similarities: [B, N, N]
        Entries whose query or key is invalid are exactly 0, so rows of invalid
        queries are all zeros and valid queries never attend to padding.

    """

    B, N = mask.shape

    # Because `mask` is a NEGATIVE mask, pair (i, j) must be removed when EITHER the query i
    # OR the key j is invalid. (`&` would only remove pairs where both are padding, letting
    # valid queries attend to padded keys.)
    square_mask = mask.reshape(B, N, 1) | mask.reshape(B, 1, N)

    # Set masked logits to negative infinity.
    # When performing softmax, these logits will become zero.
    similarities = similarities.masked_fill(square_mask, float('-inf'))

    # Perform softmax.
    attention_weights = F.softmax(similarities, dim=dim)

    # Rows of invalid queries are entirely -inf, which softmax turns into NaN; set them to zero.
    attention_weights = torch.nan_to_num(attention_weights, nan=0.0)

    return attention_weights


# A utility function we provide for you.
# Zero out the rows of any vector that belong to invalid (padded) elements.
def mask_vector(vector: Tensor, mask: Tensor) -> Tensor:
    """
    Return a copy of `vector` with every padded element set to zero.

    Input Shapes
    ------------
    vector: [B, N, D]
    mask: [B, N], True marks an invalid (padded) element.

    Output Shape
    ------------
    masked_vector: [B, N, D]

    """
    # Add a trailing feature axis so the [B, N] mask broadcasts across D.
    padding = mask.unsqueeze(-1)
    return vector.masked_fill(padding, 0.0)


## Basic Self-Attention

First we implement the basic dot-product self-attention layer.
We will not worry about multi-head attention for now.

Your assignment will be to implement this attention mechanism following the formulas presented in the slides or in the *Attention Is All You Need* paper. \
https://arxiv.org/pdf/1706.03762

In [None]:
class SelfAttention(nn.Module):
    """
    Single-head scaled dot-product self-attention over a padded set.

    Parameters
    ----------
    features: width D of the input vectors; also the width of the query / key / value projections.
    """
    def __init__(self, features: int):
        super().__init__()

        # Key dimension d_k used for the 1 / sqrt(d_k) scaling. It must come from the constructor
        # argument, not from the notebook-global `num_features` (which is the raw input width).
        self.num_features = features

        self.query = nn.Linear(features, features)
        self.key = nn.Linear(features, features)
        self.value = nn.Linear(features, features)

    def forward(self, x: Tensor, src_key_padding_mask: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Input shapes
        ------------
        x: [B, N, D]
        src_key_padding_mask: [B, N], True marks padded (invalid) elements.

        Output shapes
        ------------
        O: [B, N, D]
        A: [B, N, N]
        """

        # ----------------------------------------------------
        # Assignment 1.
        # Implement the SelfAttention module.
        # HINT: Torch can perform batch matrix multiplication using the @ operator.
        #       A: [B, I, J]
        #       B: [B, J, K]
        #       C: [B, I, K]
        #       C = A @ B 
        # ----------------------------------------------------

        # Compute query, key, value.
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # Compute attention similarities, scaled by sqrt(d_k).
        S = (Q @ K.transpose(1, 2)) / np.sqrt(self.num_features)

        # Compute attention weights
        A = masked_softmax(S, src_key_padding_mask)

        # Compute output.
        O = A @ V

        return O, A

## The Transformer Encoder Block

Now we combine the self-attention block we implemented above with the other components of a transformer.

This will form a basic building block that we can repeat over and over again.

Implement the internal logic of the transformer encoder following the slides or the original paper. \
https://arxiv.org/pdf/1706.03762

In [None]:
# Following the original paper, the feed forward part of the transformer uses two linear layers.
# First we expand the dimensions, apply a non-linearity, and then contract the dimensions back.
class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: Linear(features -> feed_forward_features), ReLU, Linear(back to features)."""

    def __init__(self, features: int, feed_forward_features: int):
        expand = nn.Linear(features, feed_forward_features)
        activation = nn.ReLU()
        contract = nn.Linear(feed_forward_features, features)
        # Positional registration keeps the state_dict keys "0.*" and "2.*".
        super().__init__(expand, activation, contract)

In [None]:
class TransformerEncoder(nn.Module):
    """
    One PRE-NORM transformer encoder block: a self-attention sub-layer and a feed-forward
    sub-layer, each applied as x + SubLayer(LayerNorm(x)).
    """
    def __init__(self, features: int, feed_forward_features: int):
        super().__init__()

        self.attention = SelfAttention(features)
        self.layer_norm1 = nn.LayerNorm(features)

        self.feed_forward = FeedForwardBlock(features, feed_forward_features)
        self.layer_norm2 = nn.LayerNorm(features)

    def forward(self, x: Tensor, src_key_padding_mask: Tensor) -> Tensor:
        """
        Input shapes
        ------------
        x: [B, N, D]
        src_key_padding_mask: [B, N], True marks padded (invalid) elements.

        Output shapes
        ------------
        O: [B, N, D], with padded elements set to zero.

        """

        # -------------------------------------------------------------------------------------------
        # Assignment 2.
        # Implement the TransformerBlock module.

        # NOTE: There are two versions of the transformer: PRE-NORM and POST-NORM.
        # The difference is whether to apply the layer-norm before attention / feed-forward or after.
        # Some interesting discussion on the difference: https://arxiv.org/pdf/2002.04745
        # Both are good for this tutorial, but PRE-NORM is generally better in my experience.
        # -------------------------------------------------------------------------------------------

        # Attention sub-layer. SelfAttention returns (O, A); only O feeds the residual stream.
        attended, _ = self.attention(self.layer_norm1(x), src_key_padding_mask)
        x = x + attended

        # Feed-forward sub-layer.
        x = x + self.feed_forward(self.layer_norm2(x))

        # Keep padded positions at zero after the residual additions.
        return mask_vector(x, src_key_padding_mask)

## Jet Embedding

Since this is a Set Classification problem, we will need some way to add an additional input / output responsible for summarizing the entire jet. \
This is very common in particle physics, as we typically want to summarize a set to extract an observable.

This will consist of:

1. Storing a learned vector to add as an extra input. Also make sure the mask is updated accordingly.
2. Extracting the learned vector after processing to get a summary of all of the inputs.

![Event Embedding](./EventEmbedding.PNG)

In [None]:
class AddJetEmbedding(nn.Module):
    """
    Prepend a learned "jet" summary vector to every set in the batch.

    The vector is a single parameter of shape [1, 1, features] shared by all sets, and it is
    always marked as valid in the returned mask.
    """
    def __init__(self, features: int):
        super().__init__()

        self.jet_embedding = nn.Parameter(torch.randn(1, 1, features))

    def forward(self, x: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Input shapes
        ------------
        x: [B, N, D]
        mask: [B, N], True marks padded (invalid) elements.

        Output shapes
        ------------
        x: [B, N + 1, D], jet vector at index 0
        mask: [B, N + 1], index 0 is False (valid)
        
        """
        # -------------------------------------------------------------------------------------------
        # Assignment 3.
        # Implement the AddJetEmbedding module.
        # -------------------------------------------------------------------------------------------
        
        B, N, D = x.shape

        # Broadcast the single learned vector across the batch and place it first.
        jet_embedding = self.jet_embedding.expand(B, 1, D)
        x = torch.cat([jet_embedding, x], dim=1)

        # The jet vector is always valid: prepend False (= not padded) to the negative mask.
        jet_mask = torch.zeros(B, 1, dtype=mask.dtype, device=mask.device)
        mask = torch.cat([jet_mask, mask], dim=1)

        return x, mask
    

class ExtractJetEmbedding(nn.Module):
    """Select the jet summary vector (index 0) and its mask entry from a batch of sets."""
    def forward(self, x: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Input shapes
        ------------
        x: [B, N, D]
        mask: [B, N]

        Output shapes
        ------------
        x: [B, D]
        mask: [B]
        
        """

        # -------------------------------------------------------------------------------------------
        # Assignment 3.
        # Implement the ExtractJetEmbedding module.
        # -------------------------------------------------------------------------------------------

        # Extract the jet embedding from the input.
        return x[:, 0], mask[:, 0]

# Classification Task

Now we define the entire network for the quark / gluon set classification task. \
We combine the blocks we defined above with simple linear embedding and output layers to create a complete network.

In [None]:
# Simple Linear embedding layer.
class Embedding(nn.Linear):
    """Linear projection of each element's input features to the hidden (attention) width."""
    def __init__(self, input_features: int, hidden_features: int):
        super().__init__(in_features=input_features, out_features=hidden_features)

# Simple Linear output layer for binary classification.
# This layer outputs logits!!! (apply a sigmoid to obtain probabilities)
class ClassifierOutput(nn.Linear):
    """Linear map from the hidden width to a single logit per input vector."""
    def __init__(self, input_features: int):
        super().__init__(in_features=input_features, out_features=1)

In [None]:
class TransformerClassifier(nn.Module):
    """
    Binary set classifier: embeds every constituent, prepends a learned jet vector, runs a
    stack of TransformerEncoder blocks, and reads a single logit off the jet vector.
    """
    def __init__(self, input_features: int, attention_features: int, feed_forward_features: int, num_transformer_layers: int):
        super().__init__()

        self.embedding = Embedding(input_features, attention_features)
        self.add_jet_embedding = AddJetEmbedding(attention_features)

        encoder_blocks = [
            TransformerEncoder(attention_features, feed_forward_features)
            for _ in range(num_transformer_layers)
        ]
        self.transformer_layers = nn.ModuleList(encoder_blocks)

        # If you couldn't finish the previous assignments, you can use the nn.TransformerEncoderLayer instead.
        # Also this would be the official torch version.
        # You can play around with the number of heads, norm_first, etc. after the initial run through.
        # ----------------------------------------------------------------------------------------------------
        # self.transformer_layers = nn.ModuleList([
        #     nn.TransformerEncoderLayer(
        #         d_model=attention_features, 
        #         nhead=1, 
        #         dim_feedforward=feed_forward_features, 
        #         dropout=0.0, 
        #         batch_first=True, 
        #         norm_first=True
        #     ) for _ in range(num_transformer_layers)
        # ])

        self.extract_jet_embedding = ExtractJetEmbedding()
        self.output = ClassifierOutput(attention_features)

    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        """
        Input shapes
        ------------
        x: [B, N, I]
        mask: [B, N], True marks padded (invalid) constituents.

        Output shapes
        ------------
        logits: [B]
        
        """
        # -------------------------------------------------------------------------------------------
        # Assignment 4.
        # Implement the TransformerClassifier module.
        # -------------------------------------------------------------------------------------------

        # Project inputs to the attention width and keep padded constituents at zero.
        hidden = mask_vector(self.embedding(x), mask)

        # Prepend the jet summary vector (the mask grows by one valid entry).
        hidden, mask = self.add_jet_embedding(hidden, mask)

        for encoder in self.transformer_layers:
            hidden = encoder(hidden, src_key_padding_mask=mask)

        # Only the jet vector is used for the set-level prediction.
        jet_vector, _ = self.extract_jet_embedding(hidden, mask)

        return self.output(jet_vector).squeeze(-1)

## Training Loop

A simple training loop on this example dataset.
There are two networks to choose from: a larger and a smaller network.

The larger network is able to almost match the results of arXiv:1810.05165 without tuning. \
Further performance improvement can be had by fine-tuning the architecture to the classification problem.

In [None]:
# Big Network on GPU ~ 20 minutes to train
# Small Network on GPU ~ 2 minutes to train

# Use the GPU when one is available; fall back to the CPU instead of crashing on `.cuda()`.
USE_CUDA = torch.cuda.is_available()
BIG_NETWORK = True

if BIG_NETWORK:
    ATTENTION_FEATURES = 128
    FEED_FORWARD_FEATURES = 256
    NUM_TRANSFORMER_LAYERS = 6
    NUM_EPOCHS = 40

else:
    ATTENTION_FEATURES = 64
    FEED_FORWARD_FEATURES = 128
    NUM_TRANSFORMER_LAYERS = 4
    NUM_EPOCHS = 10

network = TransformerClassifier(
    input_features=num_features,
    attention_features=ATTENTION_FEATURES,
    feed_forward_features=FEED_FORWARD_FEATURES,
    num_transformer_layers=NUM_TRANSFORMER_LAYERS
)

# Move the network to its device BEFORE constructing the optimizer, as the torch.optim docs recommend.
if USE_CUDA:
    network = network.cuda()

optimizer = torch.optim.Adam(
    network.parameters(), 
    lr=1e-3
)

train_dataloader = DataLoader(
    train_dataset, 
    batch_size=64, 
    shuffle=True
)

In [None]:
losses = []

network.train()
with tqdm(range(NUM_EPOCHS), position=0, desc='Training Epoch') as epoch_progress_bar:
    for epoch in epoch_progress_bar:
        with tqdm(train_dataloader, position=1, leave=False, desc='Training Batch') as batch_progress_bar:
            # Batch tensors get their own names so the dataset-level `X` / `y` loaded earlier
            # are not overwritten by the last batch (which would break re-running earlier cells).
            for X_batch, y_batch, mask_batch in batch_progress_bar:
                if USE_CUDA:
                    X_batch = X_batch.cuda()
                    y_batch = y_batch.cuda()
                    mask_batch = mask_batch.cuda()

                logits = network(X_batch, mask_batch)
                loss = F.binary_cross_entropy_with_logits(logits, y_batch)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                batch_progress_bar.set_postfix_str("Loss: {:.4f}".format(loss_value))
                losses.append(loss_value)

if BIG_NETWORK:
    torch.save(network.state_dict(), 'checkpoint.pth')

In [None]:
fig, ax = plt.subplots()
# Heavy Gaussian smoothing (sigma measured in batches) shows the trend of the noisy per-batch loss.
ax.plot(gaussian_filter1d(losses, sigma=100.0))
ax.set(xlabel="Batch Number", ylabel="Loss", title="Training Loss");

## Evaluate Network

In [None]:
# Skip to here if you don't want to wait
if BIG_NETWORK:
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine;
    # weights_only avoids unpickling arbitrary objects from the file.
    network.load_state_dict(torch.load(
        'checkpoint.pth',
        map_location="cuda" if USE_CUDA else "cpu",
        weights_only=True,
    ))

network.eval()

test_predictions = []

test_dataloader = DataLoader(
    test_dataset, 
    batch_size=64
)

with torch.no_grad():
    # Batch names avoid overwriting the dataset-level `X` / `y`; labels are not needed here.
    for X_batch, y_batch, mask_batch in tqdm(test_dataloader):
        if USE_CUDA:
            X_batch = X_batch.cuda()
            mask_batch = mask_batch.cuda()

        logits = network(X_batch, mask_batch)
        probabilities = torch.sigmoid(logits)
        test_predictions.append(probabilities.cpu().numpy())

test_predictions = np.concatenate(test_predictions)

print(metrics.classification_report(y_test, test_predictions > 0.5))

In [None]:
# Naming the thresholds (instead of `_`) avoids clobbering IPython's last-output variable.
fpr, tpr, thresholds = metrics.roc_curve(y_test, test_predictions)
auc = metrics.roc_auc_score(y_test, test_predictions)

fig, ax = plt.subplots()
ax.plot(fpr, tpr, label="AUC: {:.3f}".format(auc))
# Diagonal reference line: the ROC curve of a random classifier.
ax.plot([0, 1], [0, 1], linestyle='--')
ax.legend()
ax.set(xlabel="False Positive Rate", ylabel="True Positive Rate", title="ROC Curve");

## Examine Attention

Add torch hooks to capture the attention weights from each attention layer.
This will let us examine what the network is paying attention to.

In [None]:
import matplotlib.colors as colors

def plot_attention(A):
    """
    Draw a square attention matrix as a heatmap on the current axes, with a colorbar.

    The matrix is flipped vertically, so row 0 appears at the bottom (y labels are reversed to
    match). Major ticks sit between cells with blank labels (presumably so the whitegrid lines
    separate the cells); minor ticks sit at the cell centers and carry the element index.
    """
    size = A.shape[0]
    centers = np.arange(size)
    edges = centers - 0.5
    blank_labels = [""] * size
    index_labels = [f"{index}" for index in centers]

    plt.imshow(A[::-1], cmap='viridis', aspect='auto')
    plt.xlabel("Constituent")
    plt.ylabel("Constituent")

    plt.xticks(edges, blank_labels)
    plt.yticks(edges, blank_labels)

    plt.xticks(centers, index_labels, minor=True)
    plt.yticks(centers, index_labels[::-1], minor=True)

    plt.colorbar(label="Attention Weight")

In [None]:
attentions = []

def attention_hook(module, input, output):
    """Forward hook: record the attention weights returned as the second output of an attention module."""
    weights = output[1]
    # nn.MultiheadAttention may return None in place of the weights (e.g. when called with
    # need_weights=False); there is nothing to record in that case.
    if weights is not None:
        attentions.append(weights)
    # Returning None leaves the module's output unchanged.

# Detach hooks left over from a previous run of this cell. Otherwise re-running it registers
# duplicate hooks and every attention matrix is recorded more than once.
for hook in globals().get("hooks", []):
    hook.remove()
hooks = []

for name, module in network.named_modules():
    if isinstance(module, (SelfAttention, nn.MultiheadAttention)):
        hooks.append(module.register_forward_hook(attention_hook))

Pick an example jet and let's look at the attention weights for this jet.

In [None]:
# Extract an example jet.
INDEX = 8

x, y, mask = test_dataset[INDEX:INDEX + 1]
y = y.numpy()

# Remove the padding constituents, since this is a one-element batch (no padding is needed).
# Boolean selection keeps exactly the valid constituents, wherever they sit in the array.
valid = ~mask[0]
num_constituents = int(valid.sum())
x = x[:, valid]
mask = mask[:, valid]

# Put the inputs on whatever device the network lives on (works with or without a GPU).
device = next(network.parameters()).device
x = x.to(device)
mask = mask.to(device)

num_constituents

In [None]:
# Run the network, keeping track of the attention weights for each transformer layer.
attentions.clear()
with torch.no_grad():
    network(x, mask)

# Each recorded tensor holds a batch of one example; keep that example as a CPU numpy array.
attentions = [layer_attention[0].cpu().numpy() for layer_attention in attentions]

### Plot individual attention matrices for every layer and every input

First we look at the individual attention weights for every layer, sequentially.

These are the raw normalized similarity matrices $A$ captured by the forward hooks when we ran the network on the example jet above (not during training).

In [None]:
# One figure per transformer layer, in the order the layers ran.
for layer_number, A in enumerate(attentions, start=1):
    plt.figure(figsize=(9, 6))
    plot_attention(A)
    plt.title(f"Attention Matrix for Layer {layer_number}")
    plt.xticks(rotation=-90, minor=True)

### Plot a combined attention matrix

We can combine all of the attention matrices together by simply multiplying them together.

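If $A_i$ is the attention matrix for layer $i$ of $N$ layers, then the total attention is the product, with later layers applied on the left:
$$
A = A_N A_{N-1} \cdots A_2 A_1
$$
This ignores the residual connections around each attention layer, so it is only an approximate picture of how information flows through the network.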

In [None]:
# Compose the layers: total = A_N @ ... @ A_2 @ A_1, so each later layer multiplies on the left.
# The identity is sized from the recorded matrices themselves, so this also works for networks
# without the extra jet vector.
A = np.eye(attentions[0].shape[0])
for attention in attentions:
    A = attention @ A

plot_attention(A)
plt.title("Total Attention Matrix")
plt.xticks(rotation=-90, minor=True);

### Focus on the Jet Vector

Since we are only performing a jet-level classification, only the output of the special jet vector is fed to the classifier.
We can therefore look at just the first row of the combined attention matrix (the jet vector's attention) to get a better look at which constituents matter.

In [None]:
# Row 0 of the combined matrix belongs to the prepended jet vector.
positions = np.arange(A.shape[0])
jet_row = A[0]
plt.bar(positions, jet_row)
plt.xlabel("Constituent")
plt.ylabel("Attention Weight")
plt.title("Jet Vector Attention Weights")

# Multi Classification Task

As a challenge, let's say that we want to predict something about every element of a set.

We will load in the SPANet dataset from here. https://github.com/Alexanders101/SPANet/blob/master/docs/TTBar.md

- This dataset consists of a collection of jets. 
- Jets will be represented as collections of $(p_T, \eta, \phi, m)$ four-momentum vectors.
- Targets will be the parton labels originating from a full-hadronic $t \bar{t}$ decay.
- We will also limit the dataset to only events where all 6 partons are assigned to a jet.

We will not implement the full SPANet algorithm here. We focus on a more basic element-wise approach as an example. This has key limitations:
1. This approach will not know about or take into account how many of each parton should exist.
2. This approach will know nothing about symmetry or symmetric assignments.

This will serve as an introduction to using attention for set-wise classification.

In [None]:
TRAIN_FILE = "/global/cfs/cdirs/ntrain1/attention/full_hadronic_ttbar/training.h5"
TRAIN_SIZE = 10_000_000

TEST_FILE = "/global/cfs/cdirs/ntrain1/attention/full_hadronic_ttbar/testing.h5"
TEST_SIZE = 1_000_000

def load_dataset(filepath, size):
    """
    Load up to `size` events from a SPANet-style HDF5 file as a TensorDataset of (x, y, mask).

    Returns
    -------
    TensorDataset of
        x:    [B, N, 4] float, normalized (pt, eta, phi, mass) per jet.
        y:    [B, N] long per-jet label: 0 = no parton / padding,
              1-3 = (b, q1, q2) of t1, 4-6 = (b, q1, q2) of t2.
        mask: [B, N] bool NEGATIVE mask, True marks a padded (invalid) jet.
    Only events in which all six partons are matched to a jet are kept.
    """
    # (top, parton) target groups, listed in the order of their integer labels 1..6.
    parton_targets = [
        ("t1", "b"), ("t1", "q1"), ("t1", "q2"),
        ("t2", "b"), ("t2", "q1"), ("t2", "q2"),
    ]

    with h5py.File(filepath, 'r') as file:
        source = file["INPUTS"]["Source"]

        # Extract jet kinematics
        x = np.stack((
            source["pt"][:size],
            source["eta"][:size],
            source["phi"][:size],
            source["mass"][:size]
        ), axis=-1)

        # Inverted so that True marks a padded jet (the negative-mask convention used throughout).
        mask = ~source["MASK"][:size]

        # Assign per-jet labels.
        # Also keep track of how many partons are even reconstructable.
        # For this example we only want to focus on events with all 6 partons assigned to jets.
        num_events, num_jets = x.shape[0], x.shape[1]
        y = np.zeros((num_events, num_jets), dtype=np.int64)
        num_targets = np.zeros(num_events, dtype=np.int64)

        for label, (top, parton) in enumerate(parton_targets, start=1):
            # Jet index of this parton per event (read once); negative means no matched jet.
            jet_index = file["TARGETS"][top][parton][:size]
            matched = jet_index >= 0

            # Only write labels for matched partons: a raw -1 index would silently
            # label the LAST jet slot of the event.
            y[np.flatnonzero(matched), jet_index[matched]] = label
            num_targets += matched

        y = np.where(mask, 0, y)

    # Full Events only (filtered first so the normalization below touches less data).
    full_events = num_targets == len(parton_targets)
    x = x[full_events]
    y = y[full_events]
    mask = mask[full_events]

    # Hardcoded normalization from training dataset for simplicity
    x = x - np.array([6.9162e+01, 0, 0,  8.3870e+00])
    x = x / np.array([50.5310,  1.4054,  1.8138,  6.5079])

    # Convert to torch
    x = torch.from_numpy(x).float()
    y = torch.from_numpy(y).long()
    mask = torch.from_numpy(mask).bool()

    return TensorDataset(x, y, mask)

train_dataset = load_dataset(TRAIN_FILE, TRAIN_SIZE)
test_dataset = load_dataset(TEST_FILE, TEST_SIZE)

In [None]:
# Simple Linear multi-output layer for multi-class classification.
# This layer outputs logits!!! (apply a softmax to obtain class probabilities)
class MultiClassifierOutput(nn.Linear):
    """Linear map from the hidden width to one logit per class for every input vector."""
    def __init__(self, input_features: int, num_classes: int):
        super().__init__(in_features=input_features, out_features=num_classes)

class RemoveEventEmbedding(nn.Module):
    """Drop the prepended event vector (index 0) from both the features and the mask."""
    def forward(self, x, mask):
        keep = slice(1, None)
        return x[:, keep], mask[:, keep]

class TransformerMultiClassifier(nn.Module):
    """
    Per-element classifier: embeds each set element, prepends a learned event vector as
    shared context, runs the transformer stack, drops the event vector again, and emits
    class logits for every element.
    """
    def __init__(self, input_features: int, num_classes: int, attention_features: int, feed_forward_features: int, num_transformer_layers: int):
        super().__init__()

        self.embedding = Embedding(input_features, attention_features)
        self.add_event_embedding = AddJetEmbedding(attention_features)

        encoder_blocks = [
            TransformerEncoder(attention_features, feed_forward_features)
            for _ in range(num_transformer_layers)
        ]
        self.transformer_layers = nn.ModuleList(encoder_blocks)

        self.remove_event_embedding = RemoveEventEmbedding()
        self.output = MultiClassifierOutput(attention_features, num_classes)

    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        """
        Input shapes
        ------------
        x: [B, N, I]
        mask: [B, N], True marks padded (invalid) elements.

        Output shapes
        ------------
        logits: [B, N, num_classes] (padded elements are included; filter them with the mask)
        
        """
        # -------------------------------------------------------------------------------------------
        # Assignment 6.
        # Implement the TransformerMultiClassifier module.
        # 
        # You may optionally also use an event summary vector in this network as well.
        # Although we don't use the event vector for classification, there has been work 
        # showing that the extra vector allows the network to learn contextual information
        # related to all objects.
        # -------------------------------------------------------------------------------------------

        # Project inputs to the attention width and keep padded elements at zero.
        hidden = mask_vector(self.embedding(x), mask)

        # Prepend the event summary vector; it only provides context.
        hidden, extended_mask = self.add_event_embedding(hidden, mask)

        for encoder in self.transformer_layers:
            hidden = encoder(hidden, src_key_padding_mask=extended_mask)

        # Drop the event vector again and classify every remaining element.
        hidden, _ = self.remove_event_embedding(hidden, extended_mask)
        return self.output(hidden)

In [None]:
# Small Network on GPU ~ 20 minutes to train

# Use the GPU when one is available; fall back to the CPU instead of crashing on `.cuda()`.
USE_CUDA = torch.cuda.is_available()

ATTENTION_FEATURES = 64
FEED_FORWARD_FEATURES = 128
NUM_TRANSFORMER_LAYERS = 6
NUM_EPOCHS = 10

# This time we feed in each jet's normalized four-momentum (pt, eta, phi, mass) built by load_dataset.
# The output is a per-jet parton label: 0 = no parton, plus 6 parton classes -> 7 classes.
network = TransformerMultiClassifier(
    input_features=4,
    num_classes=7,
    attention_features=ATTENTION_FEATURES,
    feed_forward_features=FEED_FORWARD_FEATURES,
    num_transformer_layers=NUM_TRANSFORMER_LAYERS
)

# Move the network to its device BEFORE constructing the optimizer, as the torch.optim docs recommend.
if USE_CUDA:
    network = network.cuda()

optimizer = torch.optim.Adam(
    network.parameters(), 
    lr=1e-3
)

train_dataloader = DataLoader(
    train_dataset, 
    batch_size=64, 
    shuffle=True
)

In [None]:
losses = []

network.train()
with tqdm(range(NUM_EPOCHS), position=0, desc='Training Epoch') as epoch_progress_bar:
    for epoch in epoch_progress_bar:
        with tqdm(train_dataloader, position=1, leave=False, desc='Training Batch') as batch_progress_bar:
            # Batch tensors get their own names so dataset-level variables such as `X` / `y`
            # are not overwritten by the last batch.
            for X_batch, y_batch, mask_batch in batch_progress_bar:
                if USE_CUDA:
                    X_batch = X_batch.cuda()
                    y_batch = y_batch.cuda()
                    mask_batch = mask_batch.cuda()

                # Ask the network for per-jet parton logits: [B, N, num_classes].
                logits = network(X_batch, mask_batch)

                # Only compute the loss for valid (non-padded) jets.
                valid = ~mask_batch
                loss = F.cross_entropy(logits[valid], y_batch[valid])

                # Perform backpropagation.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                batch_progress_bar.set_postfix_str("Loss: {:.4f}".format(loss_value))
                losses.append(loss_value)

torch.save(network.state_dict(), 'multi_checkpoint.pth')

In [None]:
fig, ax = plt.subplots()
# Heavy Gaussian smoothing (sigma measured in batches) shows the trend of the noisy per-batch loss.
ax.plot(gaussian_filter1d(losses, sigma=100.0))
ax.set(xlabel="Batch Number", ylabel="Loss", title="Training Loss");

In [None]:
# Skip to here if you don't want to wait
# map_location lets a GPU-saved checkpoint load on a CPU-only machine;
# weights_only avoids unpickling arbitrary objects from the file.
network.load_state_dict(torch.load(
    'multi_checkpoint.pth',
    map_location="cuda" if USE_CUDA else "cpu",
    weights_only=True,
))
network.eval()

test_predictions = []
test_targets = []

test_dataloader = DataLoader(
    test_dataset, 
    batch_size=64
)

with torch.no_grad():
    # Batch names avoid overwriting dataset-level variables such as `X` / `y`.
    for X_batch, y_batch, mask_batch in tqdm(test_dataloader):
        if USE_CUDA:
            X_batch = X_batch.cuda()
            y_batch = y_batch.cuda()
            mask_batch = mask_batch.cuda()

        # Ask the network for per-jet parton logits.
        logits = network(X_batch, mask_batch)

        # Only evaluate valid (non-padded) jets.
        valid = ~mask_batch
        masked_probabilities = F.softmax(logits[valid], dim=-1)
        test_predictions.append(masked_probabilities.cpu().numpy())
        test_targets.append(y_batch[valid].cpu().numpy())

test_predictions = np.concatenate(test_predictions)
test_targets = np.concatenate(test_targets)

print(metrics.classification_report(test_targets, test_predictions.argmax(-1)))

In [None]:
attentions = []

def attention_hook(module, input, output):
    """Forward hook: record the attention weights returned as the second output of an attention module."""
    weights = output[1]
    # nn.MultiheadAttention may return None in place of the weights (e.g. when called with
    # need_weights=False); there is nothing to record in that case.
    if weights is not None:
        attentions.append(weights)
    # Returning None leaves the module's output unchanged.

# Detach hooks left over from a previous run of a hook-registration cell. Otherwise re-running it
# registers duplicate hooks and every attention matrix is recorded more than once.
for hook in globals().get("hooks", []):
    hook.remove()
hooks = []

for name, module in network.named_modules():
    if isinstance(module, (SelfAttention, nn.MultiheadAttention)):
        hooks.append(module.register_forward_hook(attention_hook))

In [None]:
# Extract an example event.
INDEX = 2

x, y, mask = test_dataset[INDEX:INDEX + 1]

# Remove the padding jets, since this is a one-element batch (no padding is needed).
# Boolean selection keeps exactly the valid jets, wherever they sit in the array.
valid = ~mask[0]
num_vectors = int(valid.sum())
x = x[:, valid]
mask = mask[:, valid]
y = y[0, valid].numpy()

# Put the inputs on whatever device the network lives on (works with or without a GPU).
device = next(network.parameters()).device
x = x.to(device)
mask = mask.to(device)

# Run the network, keeping track of the attention weights for each transformer layer.
with torch.no_grad():
    attentions.clear()

    # Per-jet class logits for the single event in the batch.
    p = network(x, mask)[0].cpu().numpy()
    attentions = [attention[0].cpu().numpy() for attention in attentions]
    
print("True      : ", y)
print("Predicted : ", p.argmax(-1))

In [None]:
# Axis label for each class index; the subscript shows which top quark (t1 / t2) the parton came from.
PARTON_NAMES = ["$\\emptyset$", "$b_1$", "$q_1$", "$q_1$", "$b_2$", "$q_2$", "$q_2$"]

for i, A in enumerate(attentions):
    fig, ax = plt.subplots(figsize=(9, 7))
    plot_attention(A)
    plt.title(f"Attention Matrix for Layer {i + 1}")

    tick_names = [PARTON_NAMES[int(label)] for label in y]
    # A matrix larger than the number of jets includes the prepended event vector first.
    if A.shape[0] > num_vectors:
        tick_names = ["E"] + tick_names

    # plot_attention draws row 0 at the bottom, so the y labels run in reverse.
    ax.set_yticklabels(tick_names[::-1], minor=True)
    ax.set_xticklabels(tick_names, minor=True)

    plt.tight_layout()

# SPANet

The previous network showed a simple solution to the jet-parton assignment problem. It didn't perform very well because it didn't take into account symmetries.

If you're ready to try out the full symmetric attention network, you can follow the guide here: https://github.com/Alexanders101/SPANet/blob/master/docs/TTBar.md



1. Open a terminal window in jupyterlab.
2. Run `source setup_spanet.sh` to clone and load environment.
3. Follow the guide as normal. Ignore the first two sections on setting up the environment and getting the data. All of that is done for you.