In [1]:
import torch

In [2]:
from train_particle_transformer import ParticleTransformer

# Load the trained model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ParticleTransformer().to(device)

# Load the saved model state
checkpoint = torch.load('best_particle_transformer.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Set model to evaluation mode
model.eval()

Loaded length stats: {'diff_mean': 7.4578, 'diff_std': 4.97939947784871}


ParticleTransformer(
  (length_predictor): LengthPredictor(
    (particle_embedding): Linear(in_features=4, out_features=256, bias=False)
    (sequence_layers): ModuleList(
      (0-2): 3 x Sequential(
        (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (1): Linear(in_features=256, out_features=256, bias=False)
        (2): ReLU()
        (3): Dropout(p=0.1, inplace=False)
      )
    )
    (attention): Sequential(
      (0): Linear(in_features=256, out_features=1, bias=False)
      (1): Softmax(dim=1)
    )
    (length_head): Sequential(
      (0): Linear(in_features=256, out_features=128, bias=False)
      (1): ReLU()
      (2): Dropout(p=0.1, inplace=False)
      (3): Linear(in_features=128, out_features=1, bias=False)
    )
  )
  (particle_embedding): Linear(in_features=4, out_features=256, bias=False)
  (sequence_layers): ModuleList(
    (0-2): 3 x Sequential(
      (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (1): Linear(in_features=256,

In [6]:
import energyflow as ef
from train_particle_transformer import ParticleDataset
from torch.utils.data import DataLoader
import numpy as np

# Load test data
print("Loading test data...")
data = ef.zjets_delphes.load(
    "Herwig",
    pad=False, 
    cache_dir="../data",
    source="zenodo",
    which="all"
)

# Create test dataset
test_sim = data["sim_particles"]
test_gen = data["gen_particles"]
test_dataset = ParticleDataset(test_sim, test_gen)

# Create dataloader
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False
)

# Get a single batch
batch = next(iter(test_loader))
sim_features = batch["sim_features"].to(device, dtype=torch.float32)

Loading test data...


In [11]:

# Generate sequence
with torch.no_grad():
    predicted_seq = model.generate(sim_features)

# Convert to numpy for analysis
predicted_seq = predicted_seq.cpu().numpy()
gen_seq = batch["gen_features"].numpy()[:,:batch["gen_length"][0]]

print("Predicted sequence shape:", predicted_seq.shape)
print("Generated sequence shape:", gen_seq.shape)

# Print first 5 particles of each sequence
print("\nSimulated sequence (first 5):")
print(sim_features[0].cpu().numpy()[:5])

print("\nGenerated sequence (first 5):")
print(gen_seq[0][:min(5, batch["gen_length"][0])])

print("\nPredicted sequence (first 5):") 
print(predicted_seq[0][:5])

Predicted sequence shape: (1, 26, 4)
Generated sequence shape: (1, 25, 4)

Simulated sequence (first 5):
[[ 0.11039242 -0.10999276  0.333286    0.1       ]
 [ 0.02987454 -0.11790374  0.15472978  0.2       ]
 [ 0.01788264 -0.14657624  0.19709224  0.2       ]
 [ 0.01106268 -0.20399679  0.25202745  0.1       ]
 [ 0.2119501  -0.01510048 -0.03257734  0.2       ]]

Generator sequence (first 5):
[[ 0.22400826  0.00229552 -0.07039601  0.9       ]
 [ 0.0056566  -0.22012563 -0.01429803  0.2       ]
 [ 0.00815064  0.20607793 -0.07526214  1.        ]
 [ 0.03117605 -0.10038943  0.11691111  0.2       ]
 [ 0.01304285  0.10844737 -0.1567741   0.2       ]]

Predicted sequence (first 5):
[[ 0.          0.          0.          0.        ]
 [-0.00075159 -0.00068024 -0.00205913  0.00864768]
 [-0.0012172   0.00240981  0.01172171  0.04643966]
 [ 0.00589175 -0.00944618  0.03549908  0.10920764]
 [ 0.01533815 -0.02036647  0.04667786  0.13387969]]
