# Transformers Tutorial: Part III

In [1]:
!pip install einops

Defaulting to user installation because normal site-packages is not writeable
[33mDEPRECATION: Loading egg at /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages/torchvision-0.21.0+7af6987-py3.12-linux-x86_64.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages/setuptools-75.8.0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages/pillow-11.1.0-py3.12-linux-x86_64.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible repl

In [2]:
from dataloader import load_data
import torch
import torch.nn as nn
import utils
from network import PET2

### Let's open the training and validation files containing examples for top quarks (signal) and QCD jets (background)

In [3]:
# Folder the tutorial reads its datasets from.
input_folder = '/global/cfs/cdirs/trn016/transformer'

# Training split, read in batches of 256; num_evt=100_000 is passed to limit
# how many events are loaded (presumably a cap -- confirm in dataloader.load_data).
train_data = load_data(
    'top',
    input_folder,
    batch=256,
    dataset_type='train',
    num_evt=100_000,
)

# Validation split with the same batch size; no num_evt is given here.
val_data = load_data(
    'top',
    input_folder,
    batch=256,
    dataset_type='val',
)

In [4]:
# Print the lengths of the two loaders, which the message reports as batch counts.
print (f"Loading {len(train_data)} batches of events for training and {len(val_data)} for validation")

Loading 390 batches of events for training and 1574 for validation


### Let's now load the PET Model

In [5]:
# Hyperparameters that are forwarded to the PET2 constructor as keyword arguments.
config = dict(
    input_dim=4,
    hidden_size=128,
    num_transformers=8,        # number of transformer blocks used
    num_transformers_head=2,   # number of transformer blocks used in the task-specific block
    num_heads=8,               # number of heads for multi-head attention
    K=10,                      # number of neighbors considered for the kNN
)

In [6]:
# Build the PET2 network, unpacking the config dictionary into keyword arguments.
model = PET2(**config) #remember the inputs are delta eta, delta phi, log(pT), log(E)

### Now we are going to create the training class that will train the model, but first, let's set up the learning rate and the optimizer

In [7]:
# Training schedule for the first (from-scratch) run.
epochs = 10
patience = 10  # Number of consecutive epochs to stop the training if the validation loss does not improve
# NOTE(review): patience is not passed to utils.Trainer in this notebook -- confirm how it is consumed.

# The optimizer is given as a class (not an instance), together with its learning rate.
lr = 5e-4
optimizer = torch.optim.Adam

In [8]:
# Hand both loaders, the model, the learning rate and the optimizer class to the trainer.
trainer = utils.Trainer(train_data,val_data,model,lr,optimizer)

### Let's train the model!

In [None]:
# Call the trainer's train method with the epoch count set above.
trainer.train(epochs)

Epoch 1: train loss=0.2764, validation loss=0.2403


### Now let's evaluate the model

In [None]:
# Load the test split; note the batch size here is 128, not the 256 used for train/val.
test_data = load_data('top', input_folder, batch=128, dataset_type='test')

# evaluate returns two values, unpacked here as the predictions and the labels.
predictions, labels = trainer.evaluate(test_data)

In [None]:
# Pass the predictions and labels to utils.print_metrics to report the metrics.
utils.print_metrics(predictions,labels)

### Now let's load the pre-trained weights

In [None]:
# Restore the checkpoint file 'best_model_pretrain_s.pt' from input_folder for the existing model object.
utils.restore_checkpoint(model,input_folder,'best_model_pretrain_s.pt')
# These messages are all fine and related to model layers that are not relevant for classification tasks

In [None]:
# Fine-tuning schedule: same number of epochs and patience as the first run.
epochs = 10
patience = 10  # Number of consecutive epochs to stop the training if the validation loss does not improve
# NOTE(review): patience is not passed to utils.Trainer below -- confirm how it is consumed.

# Same optimizer class as before, but with a smaller step size.
lr = 5e-5  # Notice the learning rate is much smaller than before
optimizer = torch.optim.Adam

# Build a new trainer around the model after restore_checkpoint was called on it.
trainer = utils.Trainer(train_data, val_data, model, lr, optimizer)

In [None]:
# Train again with the new trainer, for the epoch count set in the previous cell.
trainer.train(epochs)

### Because the pre-trained model already starts from useful weights, it overtrains more quickly

In [None]:
# Re-run the evaluation on the test_data loader created earlier; this overwrites predictions and labels.
predictions, labels = trainer.evaluate(test_data)

In [None]:
# Report the metrics for the new predictions, for comparison with the earlier run.
utils.print_metrics(predictions,labels)

### Try changing the hyperparameters of the model to see if you can improve the results!