In this notebook we will attempt to classify a set of neutrino interactions as CC $\nu_\mu$, CC $\nu_e$ or NC $\nu$ events using a transformer encoder.

In [None]:
import torch
import torchvision

# Report the installed library versions so that results can be reproduced
print("torch version:", torch.__version__)
print("torchvision version:", torchvision.__version__)

Let's load the dataset. This is a sample of 30,000 images from a simple LArTPC simulation using GENIE input neutrino events, containing equal numbers of CC $\nu_\mu$, CC $\nu_e$ and NC $\nu$ interactions. The cell below downloads the archive and unpacks the `.png` images into the `images` directory.

In [None]:
import os

# Load the neutrino dataset: the archive is only downloaded and unpacked when
# images/images.tgz is not already present, so re-running this cell is cheap
if not os.path.isfile('images/images.tgz'):
  !mkdir images
  # NOTE(review): --no-check-certificate disables TLS certificate verification
  # for this download
  !wget --no-check-certificate 'https://www.hep.phy.cam.ac.uk/~lwhitehead/genie_neutrino_images.tgz' -O images/images.tgz
  !tar -xzf images/images.tgz -C images/

# Work out the number of classes from the directory structure
# (every sub-directory of root_dir is counted as one class)
root_dir = 'images/'
dir_contents = os.listdir(root_dir)
num_classes = sum(os.path.isdir(os.path.join(root_dir, item)) for item in dir_contents)

print('Dataset consists of', num_classes, 'classes')

# NOTE(review): class_names is hard-coded with three entries, and the listing
# below assumes the class directories are named 0, 1, 2, ... - confirm this
# matches the contents of the unpacked archive
class_names = ['CC numu', 'CC nue', 'NC']
for c in range(num_classes):
  print('Number of',class_names[c],'images:')
  !ls -1 images/$c/*.png | wc -l

We need to manipulate the input images a bit to get them into the preferred format. We also downsample them by a factor of two for convenience here (to save time when training the networks).

In [None]:
import numpy as np

# We need to define a transform to resize and scale the images when loaded
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((112, 112)),   # reduce size (they are 224 x 224)
    torchvision.transforms.ToTensor(),           # convert to tensor [0,1]
    # Keep only channel index 2 (the w view) and restore a channel axis of
    # size one, so each image becomes a single-channel tensor
    torchvision.transforms.Lambda(lambda x: x[2].unsqueeze(0)) # extract the w view
])

# Now we can use a torchvision dataset to load these images
# (the class label of each image is taken from the name of its sub-directory)
dataset = torchvision.datasets.ImageFolder(root="images/", transform=transform)
print("Dataset classes:", dataset.classes)       # list of class names (sorted by folder name)

# Now we need to divide this into train and validation dataloader objects
# The fixed seed makes the shuffled order, and therefore the split, reproducible
np.random.seed(24601)
indices = np.arange(len(dataset))
np.random.shuffle(indices)

# Define split points: the first 70% of the shuffled indices are used for
# training and the remaining 30% for validation
train_idx, val_idx, = np.split(indices, [int(0.7*len(indices))])
print("Using", len(train_idx), "images for training and", len(val_idx), "for validation")

# Create samplers that each draw (in random order) only from their own indices
train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
val_sampler = torch.utils.data.SubsetRandomSampler(val_idx)

# Create dataloaders: both wrap the same dataset, the samplers keep the
# training and validation images separate
batch_size = 64
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler, num_workers=2)
val_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=val_sampler, num_workers=2)

It is always a good idea to visualise your data to make sure it looks how you expect it to. With image-based inputs it is especially easy to do this!

In [None]:
import matplotlib.pyplot as plt

# One panel per example event, drawn side by side
fig, axes = plt.subplots(1, 3)

# Each dataset entry is an (image, label) pair; the image has a single
# channel, so [0] selects the 2D array to draw.
# NOTE(review): indices 1, 10001 and 20001 presumably land in three different
# classes (equal class sizes, ordered by folder) - the printed labels confirm it
numu_event = dataset[1]
print('True class:', numu_event[1])
axes[0].imshow(numu_event[0][0])

nue_event = dataset[10001]
print('True class:', nue_event[1])
axes[1].imshow(nue_event[0][0])

nc_event = dataset[20001]
print('True class:', nc_event[1])
axes[2].imshow(nc_event[0][0])

Now we can define our transformer encoder. This is a bit more complicated than before because we need to divide our images up into patches so that they can be input to a sequential-style model. Follow the comments in the code below, but as a brief overview, the new network layers of interest here are:
* `torch.nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, ..., batch_first)`
* `torch.nn.TransformerEncoder(encoder_layer, num_layers)`
* `torch.nn.AdaptiveAvgPool1d(output_size)`

In [None]:
# Now let's try making a transformer encoder to do the same job. We'll need to
# break the image up into patches in order to encode it for the transformer
# Side length (in pixels) of each square patch
patch_dim = 14
# Calculate the sequence length (= number of patches)
# Hint: the images are 112 x 112 after the resize transform and each patch
# covers patch_dim x patch_dim pixels
seq_length = ...
print('Number of patches =', seq_length)
# Set the model depth to 64 (the number of features used to encode each patch)
model_depth = 64

# torch.nn.Sequential can only chain modules, so wrap the tensor transpose
# in a tiny module of its own. We use it to move the patch dimension to the
# position that the encoder expects (and back again afterwards)
class Transpose(torch.nn.Module):
    """Parameter-free layer that swaps two dimensions of its input tensor."""

    def __init__(self, dim1, dim2):
        """Remember the pair of dimensions (dim1, dim2) to be exchanged."""
        super().__init__()
        self.dim1 = dim1
        self.dim2 = dim2

    def forward(self, x):
        """Return x with dimensions dim1 and dim2 swapped."""
        return torch.transpose(x, self.dim1, self.dim2)

# We have to define the transformer encoder layer too. We set the batch_first
# argument to True so that the layer takes (batch, sequence, feature) tensors;
# by default the PyTorch implementation expects (sequence, batch, feature),
# meaning we'd need to perform an extra transpose otherwise.
# We need to set the model depth and the number of attention heads, let's use
# two for the latter. A typical choice for the feed-forward network dimension
# is four times the model depth
encoder_layer = torch.nn.TransformerEncoderLayer(
    ..., # Set the model depth (d_model)
    ..., # Set the number of attention heads (nhead)
    ..., # Set the feed-forward network dimension (dim_feedforward)
    batch_first=True
)

# Now for the sequential model itself. Note that a lot of this is specific
# to a Vision Transformer (Encoder) as opposed to what you would use for
# generic sequence information
transformer_model = torch.nn.Sequential(
    # This first convolution is a neat trick to do the patching and encoding
    # for us in a single layer: with kernel_size == stride == patch_dim each
    # output position sees exactly one non-overlapping patch, giving a tensor
    # of shape (batch, model_depth, patches_y, patches_x)
    torch.nn.Conv2d(1, model_depth, kernel_size=patch_dim, stride=patch_dim),
    # Flatten the tensor from dimension 2 onwards, merging the two patch-grid
    # dimensions into a single patch axis: (batch, model_depth, seq_length)
    ...,
    # Transpose dimensions 1 and 2 for the encoder, which (with
    # batch_first=True) expects (batch, seq_length, model_depth)
    Transpose(1,2),
    # Create the encoder from the encoder layer, and choose two such layers
    ...,
    # Transpose back to the original dimension order
    Transpose(2,1),
    # Adaptive average pooling layer with an output size of 1, i.e. average
    # over the patch axis
    ...,   # (batch, model_depth, 1)
    # Flatten the tensor from dimension 1 this time
    ...,             # (batch, model_depth)
    # Final linear layer to predict the three classes
    ...
    # Expected softmax activation is implicit in the loss function
)

# Sanity check: push a single image through the network and confirm that
# the output has the shape we expect. Dataset entries are (image, label)
# pairs, and the model wants a leading batch axis, hence unsqueeze(0)
test_image = dataset[0][0].unsqueeze(0)
print(test_image.shape)
print(transformer_model(test_image).shape)

# Count only the parameters that will actually be updated during training
trainable_params = [p for p in transformer_model.parameters() if p.requires_grad]
n_params = sum(p.numel() for p in trainable_params)
print('Number of trainable parameters =', n_params)

In [None]:
# Use a GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Using device', device)
# Move the model onto the chosen device (the input batches are moved to the
# same device inside the training loop)
transformer_model.to(device)

In [None]:
# Prepare for training
learning_rate = 0.0001
# Cross-entropy loss applied to the raw network outputs (the softmax is
# implicit in the loss, so the model has no final activation)
transformer_loss_fn = torch.nn.CrossEntropyLoss()
# AdamW optimiser over all of the model parameters, with weight decay
transformer_optimiser = torch.optim.AdamW(transformer_model.parameters(), lr=learning_rate, weight_decay=1e-2)

In [None]:
# Training loop: each epoch makes one pass over the training set, updating
# the weights, followed by one pass over the validation set without updates
n_epochs = 10

for epoch in range(n_epochs):
  # ---- Training phase ----
  transformer_model.train()
  running_loss = 0.0
  for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)

    # Start every step from clean gradients
    transformer_optimiser.zero_grad()

    # Forward pass
    outputs = transformer_model(images)
    loss = transformer_loss_fn(outputs, labels)

    # Backward pass and optimisation
    loss.backward()
    transformer_optimiser.step()

    running_loss += loss.item()

  # ---- Validation phase (evaluation mode, no gradients tracked) ----
  transformer_model.eval()
  running_val_loss = 0.0
  with torch.no_grad():
    for images, labels in val_loader:
      images, labels = images.to(device), labels.to(device)

      # Make the predictions
      outputs = transformer_model(images)
      loss = transformer_loss_fn(outputs, labels)
      running_val_loss += loss.item()

  # Report the per-batch average of each loss for this epoch
  print("Epoch", epoch, "training loss:", running_loss/len(train_loader), "validation loss:", running_val_loss/len(val_loader))

In [None]:
# Make a list of incorrect classifications
def get_incorrect_classifications(model, dataloader):
  """Collect every misclassified image produced by `model` on `dataloader`.

  The model is evaluated in evaluation mode without gradient tracking, on
  whichever device its parameters live on. Its previous training/evaluation
  mode is restored before returning. The overall accuracy, computed from the
  number of samples actually seen in `dataloader`, is printed.

  Args:
    model: network mapping a batch of images (batch, channels, height, width)
      to one score per class for each image.
    dataloader: iterable yielding (images, labels) batches.

  Returns:
    A list with one [image, prediction, truth] entry per misclassified
    sample, where image is a NumPy array in (height, width, channels) order
    ready for plotting, and prediction / truth are the class indices.
  """
  # Run the inputs on the model's own device rather than relying on a global
  first_param = next(model.parameters(), None)
  model_device = first_param.device if first_param is not None else torch.device('cpu')

  incorrect_indices = []
  n_total = 0

  # Evaluation mode disables dropout so that the predictions are deterministic
  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      for (images, labels) in dataloader:
        images = images.to(model_device)
        predictions = model(images).cpu().numpy()
        n_total += len(labels)

        for image, scores, label in zip(images, predictions, labels):
          prediction = np.argmax(scores)
          truth = label.cpu().numpy()
          if prediction != truth:
            # Channels-first -> channels-last for plotting with imshow
            image = image.cpu().numpy().transpose([1,2,0])
            incorrect_indices.append([image, prediction, truth])
  finally:
    # Leave the model in the mode the caller had it in
    model.train(was_training)

  # Use the number of samples actually evaluated, so that the accuracy is
  # right for any dataloader (and guard against an empty one)
  accuracy = 1 - len(incorrect_indices)/n_total if n_total else float('nan')
  print('Accuracy =', accuracy)
  return incorrect_indices

# Draw one of the misclassified images from a failures list.
# Change the value of index to look at different failures
def draw_event(incorrect_indices, index):
  """Plot failure number `index` and print its predicted and true class.

  Args:
    incorrect_indices: list of [image, prediction, truth] entries, as
      returned by get_incorrect_classifications.
    index: position in that list of the failure to draw.
  """
  image, predicted, truth = incorrect_indices[index]
  # Keep the pixel values inside the [0, 1] range before drawing
  image_to_plot = np.clip(image, 0.0, 1.0)
  _, ax = plt.subplots(1, 1)
  print('Incorrect classification for image', index,
        ': predicted =', predicted,
        'with true =', truth)
  ax.imshow(image_to_plot)

In [None]:
# Evaluate the trained transformer on the validation set, printing its
# accuracy and keeping the misclassified images for inspection
incorrect_transformer_indices = get_incorrect_classifications(transformer_model, val_loader)

In [None]:
# Draw one of the failures; change the second argument to look at others
draw_event(incorrect_transformer_indices, 2)