# Transformers Tutorial: Part III

In [None]:
from dataloader import load_data
import torch
import torch.nn as nn
import utils
from network import PET2

### Let's load the training and validation sets, which contain top-quark jets (signal) and QCD jets (background)

In [None]:
input_folder = '/global/cfs/cdirs/trn016/transformer'  # directory holding the 'top' dataset files
# Training split, capped via num_evt (presumably the number of events to read).
train_data = load_data(
    'top',
    input_folder,
    batch=256,
    dataset_type='train',
    num_evt=100_000,
)
# Validation split: same batch size, no num_evt cap passed.
val_data = load_data('top', input_folder, batch=256, dataset_type='val')

In [None]:
# Report how many batches each data loader yields.
print(f"Loading {len(train_data)} batches of events for training and {len(val_data)} for validation")

### Now let's create the PET model

In [None]:
# Hyperparameters passed as keyword arguments to the PET2 constructor.
config = dict(
    input_dim=4,  # per-particle features: delta eta, delta phi, log(pT), log(E)
    hidden_size=128,
    num_transformers=8,  # number of transformer blocks used
    num_transformers_head=2,  # number of transformer blocks in the task-specific block
    num_heads=8,  # number of heads for multi-head attention
    K=10,  # number of neighbors considered for the kNN
)

In [None]:
# Build the network from the config dict. Remember the four inputs per particle
# are delta eta, delta phi, log(pT) and log(E).
model = PET2(**config)

### Next we create a Trainer object that will train the model. First, let's set the optimizer, the learning rate, and the number of training epochs

In [None]:
optimizer = torch.optim.Adam  # optimizer class (not an instance), handed to utils.Trainer
lr, epochs = 5e-4, 10         # learning rate and number of training epochs
# Early-stopping patience: consecutive epochs without validation-loss improvement
# before training stops.
# NOTE(review): `patience` is never passed to utils.Trainer in this notebook, so it
# presumably has no effect here. Confirm whether Trainer accepts a patience argument.
patience = 10

In [None]:
# Bundle the data loaders, model, learning rate and optimizer class into a Trainer.
trainer = utils.Trainer(train_data, val_data, model, lr, optimizer)

### Let's train the model!

In [None]:
# Train from scratch (no pre-trained weights have been loaded yet).
trainer.train(epochs)

### Now let's evaluate the model

In [None]:
# Held-out test split, read in batches of 128.
test_data = load_data('top', input_folder, batch=128, dataset_type='test')
# Model outputs and the corresponding true labels for the test split.
predictions, labels = trainer.evaluate(test_data)

In [None]:
# Report metrics for the model trained from scratch, so they can be compared with
# the fine-tuned model evaluated at the end of the notebook.
utils.print_metrics(predictions, labels)

### Now let's load the pre-trained weights and fine-tune the model starting from them

In [None]:
# Start from a freshly built model instead of the one trained from scratch above.
# The checkpoint may not cover every layer (see the messages printed below). Any
# uncovered layer would otherwise keep the weights learned in the from-scratch run,
# which would contaminate the fine-tuning comparison.
model = PET2(**config)
utils.restore_checkpoint(model, input_folder, 'best_model_pretrain_s.pt')
# These messages are all fine: they relate to model layers that are not relevant
# for classification tasks.

In [None]:
# Fine-tuning setup: same optimizer class, with a learning rate 10x smaller than the
# from-scratch run (presumably to avoid washing out the pre-trained weights).
optimizer = torch.optim.Adam
lr, epochs = 5e-5, 10
# Consecutive epochs without validation-loss improvement before stopping.
# NOTE(review): `patience` is not passed to utils.Trainer, so it presumably has no
# effect. Confirm whether Trainer accepts it.
patience = 10
trainer = utils.Trainer(train_data, val_data, model, lr, optimizer)

In [None]:
# Fine-tune, starting from the restored pre-trained weights.
trainer.train(epochs)

In [None]:
# Re-evaluate on the same test split, now with the fine-tuned model.
predictions, labels = trainer.evaluate(test_data)

In [None]:
# Metrics for the fine-tuned model.
utils.print_metrics(predictions, labels)

### Try changing the hyperparameters of the model to see if you can improve the results!