In this exercise we're going to use a transformer encoder to classify tracks as either $\mu^\pm$, $\pi^\pm$ or proton. Compared to the code we wrote in the previous exercise, this notebook is a bit more like how I would write code in reality. It uses classes for a custom dataloader and for the custom network architecture.

To begin with, we need to download the dataset

In [None]:
import os

# Load the track dataset:
# The archive is only downloaded and unpacked when the tarball is not already
# present in the working directory, so re-running this cell is cheap.
# NOTE(review): `--no-check-certificate` disables TLS certificate verification
# for the download - confirm this is still required for this host.
# NOTE(review): the check is on the tarball, not on the extracted files; if the
# `tracks` directory already exists, `mkdir` (without `-p`) reports an error.
if not os.path.isfile('larsoft_workshop_tracks.tgz'):
  !mkdir tracks
  !wget --no-check-certificate 'https://www.hep.phy.cam.ac.uk/~lwhitehead/larsoft_workshop_tracks.tgz' -O larsoft_workshop_tracks.tgz
  !tar -xzf larsoft_workshop_tracks.tgz -C tracks/

In this example we aren't looking at images, so we need to define our own dataset class. We create one by inheriting from the `torch.utils.data.Dataset` class, and we need to make sure to define the `__getitem__` and `__len__` functions.

In [None]:
import math
import pickle
import torch
import csv

# Useful class to hold all of the information that we have about our tracks
class InputTrack:
    """Plain container for everything stored about a single track.

    Attributes:
        x, y, z: Per-hit position values.
        q: Per-hit charge values.
        pdg: PDG code of the particle that made the track.
        trackid: Identifier of the track.
        n_child_trk: Number of track-like children.
        n_child_shw: Number of shower-like children.
        n_grandchild: Descendant count stored with the track
            (exact definition not visible here - confirm with the producer).
    """

    def __init__(self, x, y, z, q, pdg, trackid, n_child_trk, n_child_shw, n_grandchild):
        # Hit-level quantities
        self.x, self.y, self.z, self.q = x, y, z, q
        # Track-level quantities
        self.pdg, self.trackid = pdg, trackid
        # Hierarchy information
        self.n_child_trk, self.n_child_shw, self.n_grandchild = (
            n_child_trk, n_child_shw, n_grandchild)

class TrackDataset(torch.utils.data.Dataset):
    """Dataset of pickled tracks served as fixed-length hit sequences.

    Each item is ``((hits, auxillary), class_label)`` where ``hits`` has shape
    ``(sequence_length, 4)`` holding ``{x, y, z, q}`` per hit, ``auxillary``
    holds four track-level variables and ``class_label`` is 0 (muon),
    1 (pion) or 2 (proton).

    Args:
        file_path: Prefix that is prepended (plain string concatenation) to
            the list-file name and to every track file name.
        file_name: Text file listing one track file per row (first column).
        sequence_length: Number of hits in every returned sequence. Shorter
            tracks are padded at the start, longer tracks keep their last
            ``sequence_length`` hits.
        transform: Accepted for interface compatibility; currently unused.
    """

    # Sentinel written into every feature of a padded hit. The network's
    # padding mask looks for hits whose features ALL equal this value.
    PADDING_VALUE = -1e9
    # x, y and z are divided by this before being handed to the network
    POSITION_SCALE = 1000.
    # Label values: 0 = muon, 1 = pion, 2 = proton
    PDG_TO_LABEL = {13: 0, 211: 1, 2212: 2}

    def __init__(self, file_path, file_name, sequence_length=128, transform=None):
        self.file_path = file_path
        self.input_files = []
        self.get_input_files(file_name)
        self.sequence_length = sequence_length

    # Load the list of input file paths for the tracks
    def get_input_files(self, file_name):
        """Read the track file names (first column of each row) from the list file."""
        with open(self.file_path + file_name, 'r') as input_file_list:
            # Blank lines come back as empty rows; skip them rather than
            # failing with an IndexError
            self.input_files = [row[0] for row in csv.reader(input_file_list) if row]
        print('Found',len(self.input_files),'tracks')

    def convert_pdg_to_label(self, pdg):
        """Map a PDG code (sign ignored) to its class label.

        Raises:
            ValueError: If the code is not a muon, charged pion or proton.
                An exception is used instead of ``exit()`` so that a bad track
                is reported to the caller rather than killing the interpreter.
        """
        try:
            return self.PDG_TO_LABEL[abs(pdg)]
        except KeyError:
            raise ValueError(f'Track found with wrong pdg code: {pdg}') from None

    # Function to pad short sequences at the start with values -1e9
    def pad_sequence(self, sequence):
        """Left-pad a 1D tensor with the padding sentinel up to ``sequence_length``."""
        n_padding = self.sequence_length - sequence.size(0)
        return torch.nn.functional.pad(sequence, (n_padding, 0), value=self.PADDING_VALUE)

    def _fit_to_length(self, sequence, n_hits):
        """Pad (at the start) or crop (keeping the last hits) to ``sequence_length``."""
        if n_hits < self.sequence_length:
            return self.pad_sequence(sequence)
        if n_hits > self.sequence_length:
            return sequence[(n_hits - self.sequence_length):]
        return sequence

    # Get the track of interest
    def __getitem__(self, index):
        """Load track ``index`` and return ``((hits, auxillary), class_label)``."""
        # The pickle file location comes from the list of files.
        # NOTE(security): unpickling can execute arbitrary code, so only point
        # this dataset at track files from a trusted source.
        with open(self.file_path + self.input_files[index], 'rb') as f:
            track = pickle.load(f)

        # Convert the InputTrack information into torch tensors. Positions are
        # scaled BEFORE padding: scaling afterwards would turn the sentinel
        # into -1e6 for x, y and z, and the padded hits could then no longer
        # be recognised by a mask that requires every feature to equal -1e9.
        # (charge is already normalised in the input)
        x = torch.tensor(track.x, dtype=torch.float) / self.POSITION_SCALE
        y = torch.tensor(track.y, dtype=torch.float) / self.POSITION_SCALE
        z = torch.tensor(track.z, dtype=torch.float) / self.POSITION_SCALE
        q = torch.tensor(track.q, dtype=torch.float)

        n_hits = x.size(0)

        # Create the sequence of hits with shape (4, sequence_length)
        data = torch.stack([self._fit_to_length(feature, n_hits)
                            for feature in (x, y, z, q)])

        # Get the class label from the pdg code
        class_label = torch.tensor(self.convert_pdg_to_label(track.pdg), dtype=torch.long)
        # Get the additional variables (empirical scaling for n_hits)
        auxillary = torch.tensor([track.n_child_trk, track.n_child_shw,
                                  track.n_grandchild, math.log10(n_hits) / 5.0],
                                  dtype=torch.float32)

        # Transpose so that the hits are returned as (sequence_length, 4)
        return (data.transpose(0,1), auxillary), class_label

    def __len__(self):
        return len(self.input_files)

In [None]:
import numpy as np

def GetDatasets(batch_size, sequence_length):
    """Build the training, validation and test data loaders.

    The full track dataset is shuffled with a fixed seed and divided 70% /
    20% / 10%. All three loaders share the same underlying dataset and differ
    only in the subset of indices that their sampler draws from.

    Args:
        batch_size: Number of tracks per batch for every loader.
        sequence_length: Number of hits each track is padded / truncated to.

    Returns:
        Tuple ``(training_loader, val_loader, test_loader)``.
    """
    dataset = TrackDataset('tracks/','contained_track_files.txt', sequence_length)
    n_tracks = len(dataset)

    # Fixed seed so that the split is the same every time
    shuffled = np.arange(n_tracks)
    np.random.seed(42)
    np.random.shuffle(shuffled)

    # Cut the shuffled indices at 70% and 90%
    boundaries = [int(0.7 * n_tracks), int(0.9 * n_tracks)]
    subsets = np.split(shuffled, boundaries)
    print(*(len(subset) for subset in subsets))

    # One randomly sampling loader per subset, in train / val / test order
    return tuple(
        torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=torch.utils.data.SubsetRandomSampler(subset),
        )
        for subset in subsets
    )

# =============================================================================

def count_particle_types(data_loader):
    """Count how many tracks of each class a loader yields.

    Args:
        data_loader: Iterable of ``(inputs, labels)`` batches, where labels
            use 0 = muon, 1 = pion, 2 = proton.

    Returns:
        Tuple ``(n_muons, n_pions, n_protons)``.
    """
    counts = [0, 0, 0]
    for batch in data_loader:
        _, labels = batch
        for class_index in range(len(counts)):
            counts[class_index] += (labels == class_index).sum().item()

    return tuple(counts)

Now we just load our data into the three Dataloaders and we can check what breakdown of event types we have in each sample.

In [None]:
# Dataset split into three dataloaders
batch_size = 128 # Number of tracks processed in parallel
sequence_length = 128 # Number of hits per track allowed

# Load the datasets
# (exercise placeholder: replace `...` with the call that returns the three
# data loaders)
train_loader, val_loader, test_loader = ...

# Counting requires a full pass over the loader (every track file is read), so
# the training counts are computed once and reused for the print-out. They
# are also needed later to build the class weights for the loss.
n_train_muons, n_train_pions, n_train_protons = count_particle_types(train_loader)
print("Training sample breakdown:", (n_train_muons, n_train_pions, n_train_protons))
print("Validation sample breakdown:", count_particle_types(val_loader))
print("Test sample breakdown:", count_particle_types(test_loader))

We are now going to define our model. Previously we used the `torch.nn.Sequential` class to do this as it is straightforward. It is, however, also quite limited in what you can do. Since each track has two sets of inputs we need to write our own custom model that inherits from the `torch.nn.Module` class.

We'll need some different network layers in addition to those we used for the MLP and CNN:
* `torch.nn.Embedding(num_embeddings, embedding_dim)`
* `torch.nn.LayerNorm(normalized_shape)`
* `torch.nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, ..., batch_first)`
* `torch.nn.TransformerEncoder(encoder_layer, num_layers)`
* `torch.nn.AdaptiveAvgPool1d(output_size)`

Note that in our case we need to set `batch_first=True` for our encoder layers.

In [None]:
class TrackPIDNetwork(torch.nn.Module):
    """Transformer-encoder classifier for track particle identification.

    Exercise skeleton: every ``...`` below is a placeholder that still has to
    be replaced by the layer or operation described in the comment above it.

    The network takes two inputs per track: the sequence of hits, which goes
    through the encoder, and a vector of auxillary track-level variables,
    which is concatenated to the pooled encoder output just before the final
    classification layer.
    """
    def __init__(self, n_features, n_classes, sequence_length, model_depth, n_heads,
                 feed_forward_depth, n_encoder_layers, dropout, n_auxillary):
        """Create the layers of the network.

        Args:
            n_features: Number of values per hit.
            n_classes: Number of output classes.
            sequence_length: Number of hits per track (entries in the
                position look-up table).
            model_depth: Model depth of the transformer.
            n_heads: Number of attention heads.
            feed_forward_depth: Size of the feed-forward layer.
            n_encoder_layers: Number of stacked encoder layers.
            dropout: Dropout fraction used in the encoder.
            n_auxillary: Number of auxillary variables per track.
        """
        super(TrackPIDNetwork, self).__init__()

        # Input mapping uses a linear layer to expand from n_features to model_depth
        self.input_mapping = ...
        # We use a learned embedding for the position encoding. This is basically
        # a look-up table with sequence_length entries of size model_depth
        self.position_encoding = ...

        # Layer normalisation with shape equal to model_depth
        self.layer_norm = ...

        # The encoder itself. We have to define the encoder layers themselves
        # before the actual entire encoder. The encoder is effectively a stack
        # of the encoder layers
        self.encoder_layer = ...
        self.encoder = ...

        # Adaptive pooling layer to reduce the sequence to a dimension of 1
        self.pooling = ...

        # Flatten layer
        self.flatten = ...

        # The outputs for particle id (n_classes). Its input is the pooled
        # encoder output concatenated with the auxillary variables (see forward)
        self.classifier = ...

    def forward(self, x, auxillary):
        """Classify a batch of tracks.

        Args:
            x: Hit sequences; the indexing below treats dimension 0 as the
                batch and dimension 1 as the sequence, i.e. presumably
                (batch, sequence_length, n_features).
            auxillary: Track-level variables, concatenated along dim=1, i.e.
                presumably (batch, n_auxillary).

        Returns:
            The output of the final classification layer.
        """
        # Create the padding mask (from the raw input, before it is embedded)
        padding_mask = self.create_mask(x)

        # Embedding and position encoding
        x = ...
        # One position index per hit, repeated for every track in the batch:
        # shape (batch, sequence)
        positions = torch.arange(x.size(1), device=x.device).unsqueeze(0).repeat(x.size(0), 1)
        x = ...

        # Layer normalisation
        x = ...

        # Run the encoder, remembering we need to pass in the padding mask
        x = ...

        # Prepare for classification
        x = x.transpose(2,1) # Need to swap model_depth and sequence dimensions
        x = ... # After pooling the shape is now (batch, model_depth, 1)
        x = ... # Flatten to remove the last dimension

        # Concatenate with the auxillary variables
        x = torch.cat((x, auxillary), dim=1)
        # Final classification layer
        output = ...

        return output

    def create_mask(self, x):
        """Return a boolean mask that is True for padded hits.

        A hit counts as padding only when EVERY value along its last
        dimension is exactly -1e9.
        """
        # NOTE(review): this only matches if the dataset leaves all features
        # of a padded hit at exactly -1e9; any scaling applied to the padded
        # values after padding would stop the mask from firing - confirm
        # against the dataset's __getitem__.
        mask = (x == -1e9).all(dim=-1)
        return mask

This function performs one epoch of either training or validation, depending on the value of `is_training`. It is completely analogous to the code we wrote before, but just encapsulated in a function for simplicity.

In [None]:
def one_epoch(epoch, data_loader, pid_lossfn, optimiser, initial_lr, is_training=True):
    """Run one pass over ``data_loader`` and return the mean batch loss.

    Uses the notebook-level ``my_model`` and ``device`` and, when training
    beyond the warm-up epochs, ``cosine_scheduler``.

    Args:
        epoch: Zero-based epoch number; only used for the learning rate.
        data_loader: Yields ``((hits, auxillary), labels)`` batches.
        pid_lossfn: Loss function called as ``pid_lossfn(outputs, labels)``.
        optimiser: Optimiser to step; only used when training.
        initial_lr: Target learning rate reached at the end of the warm-up.
        is_training: True to train (with back propagation), False to only
            evaluate the loss.

    Returns:
        The sum of the batch losses divided by the number of batches.
    """
    n_warmup_epochs = 10

    if not is_training:
        my_model.eval()
    else:
        my_model.train()
        # With transformers it is often a good idea to "warm up": ramp
        # linearly to the target learning rate over the first epochs (so from
        # 1e-4 to 1e-3 for a target of 1e-3), then hand over to the cosine
        # annealing scheduler
        if epoch >= n_warmup_epochs:
            cosine_scheduler.step()
            # Report the learning rate of the first parameter group only
            for group in optimiser.param_groups[:1]:
                print("Learning rate cosine annealing:", group['lr'])
        else:
            lr = initial_lr * (epoch + 1) / n_warmup_epochs
            print("Learning rate warmup:", lr)
            for group in optimiser.param_groups:
                group['lr'] = lr

    total_loss = 0.0

    # Loop over all of the batches in the dataset
    for inputs, labels in data_loader:
        hits = inputs[0].to(device)
        auxillary = inputs[1].to(device)
        labels = labels.to(device)

        if is_training:
            # Forward pass followed by back propagation and a weight update
            batch_loss = pid_lossfn(my_model(hits, auxillary), labels)
            optimiser.zero_grad()
            batch_loss.backward()
            optimiser.step()
        else:
            # No gradients for validation: they cost time and memory and
            # are not needed
            with torch.no_grad():
                batch_loss = pid_lossfn(my_model(hits, auxillary), labels)

        total_loss += batch_loss.item()

    return total_loss / len(data_loader)

Now let's get an instance of the network after defining some important parameters defining the specific characteristics that we want.

In [None]:
n_features = 4 # This is {x, y, z, q} for each hit
n_classes = 3 # 0 = muon, 1 = pion, 2 = proton

model_depth = 64 # Model depth of the transformer
n_heads = 4 # Number of attention heads
feed_forward_depth = 256 # Size of the feed-forward layer
n_encoder_layers = 2 # Number of stacked encoders
dropout = 0.3 # Dropout fraction used in the encoder
n_auxillary = 4 # N track children, n shower children, n total descendants, n hits

# Use a GPU if we have one available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Using device', device)

# Create the network with the required parameters
# (exercise placeholder: replace `...` with the network instance)
my_model = ...

# Count the trainable parameters. A list (rather than a one-shot `filter`
# iterator) keeps `model_parameters` reusable, and `numel()` gives the element
# count of each tensor directly as a Python int.
model_parameters = [p for p in my_model.parameters() if p.requires_grad]
params = sum(p.numel() for p in model_parameters)
print("Total parameters",params)

In [None]:
import copy

# Define some weights to deal with the class imbalance of the dataset: muons
# keep a weight of 1 and the other classes are scaled by n_muons / n_class
pid_weights = torch.ones(3).to(device)
pid_weights[1] = n_train_muons / n_train_pions
pid_weights[2] = n_train_muons / n_train_protons
print(pid_weights)

# This time you need to use the `weight` argument with the cross entropy loss
# function to help with class imbalance!
pid_loss_fn = ...
initial_lr = 1e-3 # N.B. 1e-4 is often a good choice for transformers

# Move the model to the device before building the optimiser, so that the
# optimiser is created from the parameters that will actually be trained
my_model.to(device)

# Let's use the AdamW optimiser again
optimiser = torch.optim.AdamW(my_model.parameters(), lr=initial_lr)
# Including this as an example of something that varies the learning rate. The
# first transformer was trained using CosineAnnealing - it basically varies
# the learning rate in a periodic fashion
cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=25)

# Do the training and validation for n_epochs. Here I have set things up to keep
# track of the epoch with the best validation loss so that we can stop if things
# begin to get worse due to overtraining.
# N.B. with n_epochs equal to patience the early stop cannot trigger in
# practice, and all 10 epochs fall inside the learning rate warm-up of
# `one_epoch`; increase n_epochs to see early stopping and cosine annealing.
n_epochs = 10
patience = 10
training_losses = {}
validation_losses = {}
best_epoch_loss = 1e9
best_epoch = -1
best_epoch_state = None
n_epoch_no_improvement = 0
for e in range(0,n_epochs):
    # Training
    one_epoch_train_loss = one_epoch(e, train_loader, pid_loss_fn, optimiser, initial_lr, True)
    # Validation
    one_epoch_valid_loss = one_epoch(e, val_loader, pid_loss_fn, optimiser, initial_lr, False)
    # Print the loss and store
    print('Epoch', e, 'training loss =', one_epoch_train_loss, ' and validation loss', one_epoch_valid_loss)
    training_losses[e] = one_epoch_train_loss
    validation_losses[e] = one_epoch_valid_loss
    # Keep track of the best epoch
    if one_epoch_valid_loss < best_epoch_loss:
        best_epoch = e
        best_epoch_loss = one_epoch_valid_loss
        # state_dict() returns references to the live parameter tensors, which
        # keep changing as training continues. Take a deep copy so that this
        # really is a snapshot of the best epoch and not of the last one.
        best_epoch_state = copy.deepcopy(my_model.state_dict())
        n_epoch_no_improvement = 0
    else:
        n_epoch_no_improvement += 1

    # If we haven't improved for a while then stop training
    if n_epoch_no_improvement >= patience:
        print("No improvement in validation loss for", patience, "epochs, stopping training")
        break




Since we've used the validation loss to determine which epoch is the best, we use an independent sample, here called the test sample, to actually benchmark the performance of the network. I've written the loop here rather than using the `one_epoch` function because I also build a confusion matrix to show how well we have done.

In [None]:
from sklearn.metrics import confusion_matrix

# Restore the weights stored for the epoch with the best validation loss
print("Loading model state from best epoch (", best_epoch, ")")
my_model.load_state_dict(best_epoch_state)

# Run the test sample
test_loss = 0.0

# Make sure we are in evaluation mode without gradients
my_model.eval()
# Accumulated over all batches. Stored as float so that it can be normalised
# in place afterwards
test_confusion_matrix = np.zeros((3,3), dtype=float)
with torch.no_grad():
    for batch_no, (data, labels) in enumerate(test_loader):
        data = (data[0].to(device), data[1].to(device))
        labels = labels.to(device)

        outputs = my_model(data[0], data[1])
        batch_loss = pid_loss_fn(outputs, labels)
        test_loss += batch_loss.item()

        # The predicted class is the one with the largest output
        outputs_as_class_cpu = outputs.argmax(dim=1).cpu()
        labels_cpu = labels.cpu()

        # N.B. the predictions are passed as the first argument and the true
        # labels as the second, so in the resulting matrix the ROW is the
        # predicted class and the COLUMN is the true class. `labels` fixes the
        # matrix to 3x3 even if a class is missing from a batch.
        test_confusion_matrix += confusion_matrix(outputs_as_class_cpu.numpy(),
                                                  labels_cpu.numpy(),
                                                  labels=np.arange(3))

# Mean loss per batch
test_loss = test_loss / len(test_loader)

print(test_loss)
print(test_confusion_matrix)

This last bit of code just normalises the matrix per column so that you can see what fraction of muons, pions and protons were classified as each class. At least in HEP we'd describe the diagonal of the matrix as the efficiency.

In [None]:
# We can also normalise this matrix by column to see what percentage of each
# true class tracks are classified as each of the three classes. Each column
# (0 = muon, 1 = pion, 2 = proton) is divided by the number of tracks of that
# class in the test sample.
# N.B. the matrix is modified in place, so running this cell a second time
# would normalise it twice. Counting the test sample needs another full pass
# over the test loader.
n_test_muons, n_test_pions, n_test_protons = count_particle_types(test_loader)
test_confusion_matrix[:,0] /= n_test_muons
test_confusion_matrix[:,1] /= n_test_pions
test_confusion_matrix[:,2] /= n_test_protons
print(test_confusion_matrix)