# Transformer Encoder for Energy Reconstruction in Telescope Array Experiment

This notebook demonstrates how to build a native PyTorch transformer-based neural network for analyzing cosmic ray data from the Telescope Array experiment. 
The model processes data from activated detectors for each cosmic-ray-induced event and predicts the energy of the primary particle.

# 1. Import required libraries

We will use:

- PyTorch: For creating and training neural networks
- NumPy: For numerical operations and data handling
- h5py: For reading HDF5 files containing our experimental data

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader

import numpy as np
import h5py as h5

# 2. Data Generation and Preprocessing

In the Telescope Array experiment, cosmic ray events are detected by an array of surface detectors spread across a large area.
We use events that have passed through the reconstruction procedure and satisfy both the *composition* and *spectrum* cuts. 

Full information on an event is given by the set of all triggered detectors.  
Each detector is characterized by 6 features:
- Its x, y, z coordinates (spatial position)
- Integral registered charge (energy deposited)
- Arrival time of the shower plane front (obtained from the reconstruction procedure)
- Difference between the plane-front arrival time and the actual activation time (helps analyze wavefront curvature)

The number of activated detectors varies from event to event. 
To pass events through the network in batches, they must be brought to a uniform "length" as follows:
- For each batch, we find the maximum number of triggered detectors (`max_event_length`)
- Events with fewer detectors are padded with "auxiliary detectors" (zeros)
- We add a mask channel (value 1 for real detectors, 0 for padding) to allow the network to distinguish real data from padding

This results in input tensors with shape (batch_size, max_event_length, 7) where the last dimension includes the 6 detector features plus the mask. The neural network will be designed to ignore these auxiliary detectors.
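The padding scheme can be sketched in NumPy as follows. This is an illustration only (the hypothetical helper `pad_events` and the random toy events are not part of the notebook); `DatasetGenerator.__iter__` does the same thing with PyTorch tensors. Boolean-mask assignment fills the real positions in row-major order, which matches the order in which detectors are concatenated.

```python
import numpy as np

def pad_events(events, n_features=6):
    """Pad a list of (n_i, n_features) arrays into (batch, max_len, n_features + 1).

    The extra last channel is the mask: 1 for real detectors, 0 for padding.
    """
    lengths = np.array([len(ev) for ev in events])
    max_len = lengths.max()
    padded = np.zeros((len(events), max_len, n_features + 1), dtype=np.float32)
    # Position j of event i is real if j < lengths[i]
    real = np.arange(max_len)[None, :] < lengths[:, None]
    # Concatenate all detectors and append the mask channel (all ones for real data)
    flat = np.concatenate(events, axis=0)
    flat = np.concatenate([flat, np.ones((flat.shape[0], 1), dtype=np.float32)], axis=1)
    # Scatter the real detectors into their padded slots
    padded[real] = flat
    return padded

rng = np.random.default_rng(0)
events = [rng.normal(size=(n, 6)).astype(np.float32) for n in (3, 5, 2)]
batch = pad_events(events)
print(batch.shape)  # (3, 5, 7)
```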

Our neural network will predict the logarithm (base 10) of the primary particle energy. The ground truth values are extracted from the simulation data.

## Some technical remarks

The datasets have an option to augment data with noise. This helps to avoid overfitting and makes the network's predictions more robust.
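Noise parameters are given in physical units, while the network sees normalized data. Assuming, as the `norm_param` entries and the `Q_mean_noise = mean / std` term in the code suggest, that each feature is stored standardized as `q = (Q - mean) / std`, additive noise `n` becomes `n / std`, and a relative charge fluctuation `Q -> Q * (1 + eps)` becomes `q -> q + eps * (q + mean / std)`, which is the formula used in `mult_gauss`. The sketch below checks both identities numerically; `mu` and `sigma` are made-up constants.

```python
import numpy as np

rng = np.random.default_rng(1)
mu, sigma = 2.5, 1.7                    # hypothetical normalization constants for the charge
Q = rng.uniform(0.5, 10.0, size=5)      # charges in physical units
eps = rng.normal(scale=0.1, size=5)     # relative noise, e.g. mult_gauss_std = 0.1
n = rng.normal(scale=0.3, size=5)       # additive noise in physical units

q = (Q - mu) / sigma                    # normalized charge

# Multiplicative noise: physical route vs. the normalized formula from mult_gauss
mult_physical = (Q * (1 + eps) - mu) / sigma
mult_normalized = q + eps * (q + mu / sigma)

# Additive noise: physical route vs. noise rescaled by the feature std
add_physical = (Q + n - mu) / sigma
add_normalized = q + n / sigma

print(np.allclose(mult_physical, mult_normalized), np.allclose(add_physical, add_normalized))
```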

For convenience, the train dataset is made infinite via self-looping.

Detector data is stored in a two-dimensional array `dt_params` with shape `(total_detectors, 6)`, where `total_detectors` is the total number of detectors activated in all events (the data of all detectors is concatenated into a single array).
An external indexing array `ev_starts` is used to extract the data of a given event: for the i-th event, the corresponding data is `data_i = dt_params[start:stop]`, where `start=ev_starts[i]` and `stop=ev_starts[i+1]`.
In particular, event length array can be obtained as `np.diff(ev_starts)`.
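A toy version of this indexing scheme, with small in-memory NumPy arrays standing in for the HDF5 datasets (the values are made up; each row is filled with its row index for readability):

```python
import numpy as np

# 3 events with 2, 4 and 1 detectors; 6 features per detector
dt_params = np.arange(7, dtype=np.float32)[:, None].repeat(6, axis=1)
ev_starts = np.array([0, 2, 6, 7])

def get_event(i):
    # Detectors of the i-th event occupy rows ev_starts[i] .. ev_starts[i+1]-1
    start, stop = ev_starts[i], ev_starts[i + 1]
    return dt_params[start:stop]

event_lengths = np.diff(ev_starts)   # array([2, 4, 1])
num_events = len(ev_starts) - 1      # the same count DatasetGenerator computes

# A batch of events 1..2 needs ev_starts[1:3+1] and one contiguous detector slice
batch_idxs = ev_starts[1:3 + 1]
batch_detectors = dt_params[batch_idxs[0]:batch_idxs[-1]]
```

Reading one contiguous slice per batch keeps HDF5 access sequential, which is why `DatasetGenerator.step` receives `ev_idxs[0]` and `ev_idxs[-1]`.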

In [2]:
# Default padding values for sequences shorter than the maximum length
dense_def_vals = torch.tensor([[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]], dtype=torch.float32)

class DatasetGenerator(IterableDataset):
    """
    Generates batches of detector data from HDF5 files for training and testing.
    
    - Handles variable-length sequences via padding
    - Supports data augmentation with Gaussian noise (additive and multiplicative)
    - Creates an infinite dataset for training via self-loop
    - Properly initializes workers so each of them reads its own part of the data
    """
    def __init__(self, file, regime, batch_size, return_reminder,
                 apply_add_gauss, gauss_stds,
                 apply_mult_gauss, mult_gauss_std
                 ):
        """
        Initialize the dataset generator.
        
        Parameters:
        - file: Path to the HDF5 file containing the training data
        - regime: 'train' or 'test' mode
        - batch_size: Number of events per batch
        - return_reminder: Whether to return the last incomplete batch
        - apply_add_gauss: Apply additive Gaussian noise for data augmentation
        - gauss_stds: Standard deviations for the additive noise in physical units
        - apply_mult_gauss: Apply multiplicative Gaussian noise
        - mult_gauss_std: Standard deviation for multiplicative noise
        """
        self.file = file
        self.regime = regime
        self.batch_size = batch_size
        self.apply_add_gauss = apply_add_gauss
        self.apply_mult_gauss = apply_mult_gauss
        self.g_mult_stds = mult_gauss_std

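        # Get normalization parameters (one value per detector feature) and dataset size from the HDF5 file
        with h5.File(self.file,'r') as hf:
            means = hf['norm_param/dt_params/mean'][:]
            stds = hf['norm_param/dt_params/std'][:]
            self.num = hf[self.regime+'/ev_starts'].shape[0]-1
        
        # Configure noise parameters (converted from physical to normalized units)
        if apply_add_gauss:
            # Each feature's noise is rescaled by that feature's own std
            self.guass_add_stds = np.asarray(gauss_stds) / stds
        if apply_mult_gauss:
            # Charge is feature 3; this offset is needed to apply multiplicative noise to normalized data
            self.Q_mean_noise = means[3] / stds[3]
            self.n_fraction = mult_gauss_std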

        # Determine the stop index (end of dataset or last complete batch)
        batch_num = self.num // self.batch_size
        self.stop = self.num if return_reminder else self.batch_size * batch_num

    # Add Gaussian noise to the data for augmentation.
    def add_gauss(self, data, std):
        noise = np.random.normal(scale=std, size=data.shape)
        data += noise
        return data

    # Apply multiplicative Gaussian noise to charge values.
    def mult_gauss(self, Qs):
        noises = np.random.normal(scale=self.n_fraction, size=Qs.shape)
        return Qs + noises * (Qs + self.Q_mean_noise)

    # Read one batch of events from the HDF5 file and apply augmentation
    def step(self, hf, start_ev, stop_ev, start_det, stop_det):
        """
        Process a batch of events from the HDF5 file.
        - start_ev, stop_ev: Start and stop indices for events
        - start_det, stop_det: Start and stop indices for detectors

        Returns:
        - dt_params: Detector parameters for events (as 2D array)
        - energy_labels: True energy values (log10 scale)
        """
        # Read detector parameters for events
        dt_params = hf[self.regime+'/dt_params'][start_det:stop_det]
        # Extract energy labels (log10 of primary particle energy)
        energy_labels = np.log10(hf[self.regime+'/mc_params/'][start_ev:stop_ev,3:4])
        # Apply data augmentation if enabled
        if self.apply_add_gauss:
            dt_params = self.add_gauss(dt_params, self.guass_add_stds)
        if self.apply_mult_gauss:
            dt_params[...,3] = self.mult_gauss(dt_params[...,3])
        
        return dt_params, energy_labels

    def __iter__(self):
        """
        An iterator that yields batches of data.
    
        For training data, this creates an infinite dataset via self-loop.
        For test data, this iterates once through the dataset.
        
        Yields:
        - padded: Padded detector data with mask [batch_size, max_seq_len, 7]
        - labels: Energy labels [batch_size, 1]
        """
        
        # Split the work between DataLoader workers (if any)
        worker_info = torch.utils.data.get_worker_info()
        # Determine the range of events this worker should process
        if worker_info is None:  # single-process data loading
            worker_start = 0
            worker_end = self.stop
        else:  # in a worker process
            # Split whole batches between workers so that worker ranges never overlap
            num_batches = int(np.ceil(self.stop / self.batch_size))
            batches_per_worker = int(np.ceil(num_batches / worker_info.num_workers))
            worker_start = min(worker_info.id * batches_per_worker * self.batch_size, self.stop)
            worker_end = min(worker_start + batches_per_worker * self.batch_size, self.stop)
        
        # Open the HDF5 file within __iter__ so that each worker gets its own handle.
        with h5.File(self.file, 'r') as hf:
            start_ev = worker_start
            
            # The loop body never runs for a worker that received an empty range
            while worker_start < worker_end:
                # Check if we've reached the end of this worker's range
                if start_ev >= worker_end:
                    if self.regime == 'train':
                        # For training, make an infinite dataset by looping back to the start
                        start_ev = worker_start
                    else:
                        # For testing, stop after a single pass
                        break
                # Clip the last batch to the worker's range (incomplete only if return_reminder=True)
                stop_ev = min(start_ev + self.batch_size, worker_end)
                
                # Read detector indices for events
                ev_idxs = hf[self.regime+'/ev_starts'][start_ev:stop_ev+1]
                # Get detector parameters and energy labels
                dt_params, labels = self.step(hf, start_ev, stop_ev, ev_idxs[0], ev_idxs[-1])

                # Make regular tensors
                # Calculate the number of detectors per event
                raw_lens = np.diff(ev_idxs).astype(np.int64)
                max_len = raw_lens.max() # Maximum sequence length in this batch

                # Convert the actual data to torch tensors.
                data = torch.from_numpy(dt_params)      # shape: (total_dets, 6)
                labels = torch.from_numpy(labels)       # shape: (total_evs, 1)
                
                # Create mask: 1 for real detectors, 0 for padding
                mask = torch.ones((data.shape[0], 1), dtype=torch.float32) # shape: (total_dets, 1)
                # Concatenate detector features and mask
                data = torch.cat([data, mask], dim=-1)
                
                # Preallocate padded tensor: shape (batch_size, max_len, 7), filled with default values.
                padded = torch.tile(dense_def_vals, (labels.shape[0], max_len, 1))
                
                # Create a boolean mask with shape (batch_size, max_len).
                # For each event, positions [0, raw_lens[i]) are True.
                mask = np.arange(max_len)[None, :] < raw_lens[:, None]  # shape: (batch_size, max_len)
                mask_tensor = torch.from_numpy(mask)
                
                # Fill the padded tensor with actual data
                padded[mask_tensor] = data

                # Move to next batch
                start_ev += self.batch_size

                yield padded.float(), labels.float() # also convert to float32

def make_datasets(file, batch_size,
                 apply_add_gauss, gauss_stds,
                 apply_mult_gauss, mult_gauss_std):
    """
    Create train and test datasets.
    
    Parameters:
    - file: Path to the HDF5 file
    - batch_size: Number of events per batch
    - apply_add_gauss: Whether to apply additive Gaussian noise (for augmentation)
    - gauss_stds: Standard deviations for additive noise
    - apply_mult_gauss: Whether to apply multiplicative Gaussian noise
    - mult_gauss_std: Standard deviation for multiplicative noise
    
    Returns:
    - train_dataset: DataLoader for training
    - test_dataset: DataLoader for testing
    """
    # Create generators for train and test datasets
    train_generator = DatasetGenerator(file, 'train', batch_size, False,
                 apply_add_gauss, gauss_stds,
                 apply_mult_gauss, mult_gauss_std)
    test_generator = DatasetGenerator(file, 'test', batch_size, False,
                 False, None,
                 False, None)

    # Create DataLoader objects
    train_dataset = DataLoader(train_generator, batch_size=None, shuffle=False, pin_memory=True, num_workers=0, prefetch_factor=None)
    test_dataset = DataLoader(test_generator, batch_size=None, shuffle=False, pin_memory=True, num_workers=0, prefetch_factor=None)

    return train_dataset, test_dataset

## Take a look at data

Below we initialize a dataset generator and take a look at one batch.
Each event is padded to the maximal "length" in the batch with auxiliary detectors.

In [3]:
generator_config = {
  'file' : '/home3/ivkhar/TA/data/normed/composition_spectrum/taml_0325_energy.h5', # path to training file
  'batch_size' : 4, # batch size
  'apply_add_gauss' : False, # flag for additive data augmentation
  'gauss_stds' : [0., 0., 0., 0., 0. , 0.], # noise parameters
  'apply_mult_gauss' : False, # flag for multiplicative augmentation of registered charges
  'mult_gauss_std' : 0.0 # noise parameters
}

train_generator = DatasetGenerator(regime='train', return_reminder=False, **generator_config)

# Take the first batch from the generator
data, label = next(iter(train_generator))

In [4]:
print(data.shape, label.shape)
print(data[:2])

torch.Size([4, 34, 7]) torch.Size([4, 1])
tensor([[[ 0.5414, -1.5941, -1.2051, -0.4940, -0.5042, -0.3481,  1.0000],
         [ 0.5461, -0.8243, -1.1074, -0.4972, -0.2328, -0.7657,  1.0000],
         [-0.1986, -0.7951, -1.0037, -0.4665, -0.0465, -0.6553,  1.0000],
         [ 0.5435, -0.0488, -1.0852,  0.0252,  0.0443, -0.8896,  1.0000],
         [ 0.5435, -0.0488, -1.0852,  0.0252,  0.0443,  0.3297,  1.0000],
         [ 0.5442,  0.7246, -0.9282, -0.4936,  0.3160, -0.7222,  1.0000],
         [-0.9999, -0.0454, -1.0370, -0.4302,  0.4144, -0.9027,  1.0000],
         [-0.2268,  0.7237, -0.8708, -0.3854,  0.4993, -0.7661,  1.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.000

# 3. Metrics

When training neural networks, we need metrics to track performance. 
Here we implement a simple metric that tracks the average value of the loss function.
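`MeanMetrics` averages per-batch loss values. The plain-Python stand-in below (hypothetical `RunningMean`, mirroring `update_state` and `result`) shows that this equals the per-event mean only when all batches have the same size. Both datasets here drop the incomplete last batch (`return_reminder=False`), so the reported value is the per-event mean.

```python
class RunningMean:
    """Plain-Python analogue of MeanMetrics, for illustration only."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.value = 0.0
        self.steps = 0

    def update_state(self, value):
        # Accumulate one per-batch value
        self.value += value
        self.steps += 1

    def result(self):
        # Small epsilon avoids division by zero before the first update
        return self.value / (self.steps + 1e-9)

# Squared errors of 3 events split into unequal batches: [1, 1] and [4]
batches = [[1.0, 1.0], [4.0]]
metric = RunningMean()
for batch in batches:
    metric.update_state(sum(batch) / len(batch))   # per-batch MSE

per_event_mean = sum(sum(b) for b in batches) / sum(len(b) for b in batches)
print(metric.result(), per_event_mean)   # mean of batch means (2.5) vs. per-event mean (2.0)
```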

In [5]:
class MeanMetrics(nn.Module):

    # Initialize metric
    def __init__(self, name):
        super().__init__()
        self.register_buffer('value', torch.tensor(0.0))
        self.register_buffer('steps', torch.tensor(0.0))
        self.name = name

    # Define how to reset metric
    def reset(self):
        self.value.zero_()
        self.steps.zero_()

    # Define how to update state
    def update_state(self, value):
        self.value.add_(value.detach())
        self.steps.add_(1.)

    # Define the resulting metric value
    def result(self):
        return self.value / (self.steps + 1e-9) # Add small epsilon to avoid division by zero

# 4. Neural network architecture

For a minimal implementation, we need to define: 
- Embedding layer (a simple linear layer). It increases the dimensionality of the data to match the transformer dimensionality.
- Aggregation layer. It gathers data from all detectors into a single feature vector.
- Prediction layer (MLP). Combination of linear layers that predicts the target value - log10(E).

Further we will combine them with PyTorch implementation of the transformer architecture.

## 4.1 Embedding layer

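Our initial data has 6 features, while the transformer works in a higher-dimensional space (`d_model`).
We need to embed our data into this higher-dimensional space.

This can be done via a single linear layer.
To preserve the physical information at initialization, we initialize its weight as an identity matrix (the first 6 embedding components copy the input features, the remaining ones start at zero) and let the network optimize it.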
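A NumPy check of what the identity initialization does. `nn.Linear(dim_in, dim_out)` stores its weight with shape `(dim_out, dim_in)` and computes `x @ W.T`; `nn.init.eye_` fills a rectangular matrix with ones on the main diagonal, which `np.eye(dim_out, dim_in)` reproduces here. The sizes match the model configuration used later (6 features, `d_model = 128`).

```python
import numpy as np

dim_in, dim_out = 6, 128
# Rectangular identity: ones on the main diagonal, zeros elsewhere
W = np.eye(dim_out, dim_in)
x = np.random.default_rng(2).normal(size=(4, dim_in))   # 4 detectors, 6 features each
y = x @ W.T   # what a bias-free linear layer computes

print(y.shape)                  # (4, 128)
print(np.allclose(y[:, :6], x))  # True: the first 6 components are the input features
```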

### On positional encoding

Positional encoding is needed when token encodings do not carry positional information themselves (for example, words in natural language processing).
In our case, the nodes (detectors) already carry both spatial and temporal information in their features, which makes positional encoding redundant.
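Without positional encoding, self-attention treats the detectors as an unordered set: permuting the input rows permutes the outputs in the same way (equivariance), and after averaging the result does not depend on the order in which detectors are listed. The single-head NumPy sketch below, with random weights, demonstrates this; it is a simplification of PyTorch's multi-head layer (no output projection, feed-forward block or normalization).

```python
import numpy as np

def self_attention(x, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention without positional encoding."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ v

rng = np.random.default_rng(3)
d = 8
x = rng.normal(size=(5, d))          # 5 detectors, d features each
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))

perm = rng.permutation(5)
out = self_attention(x, Wq, Wk, Wv)
out_perm = self_attention(x[perm], Wq, Wk, Wv)

print(np.allclose(out[perm], out_perm))                      # equivariance
print(np.allclose(out.mean(axis=0), out_perm.mean(axis=0)))  # invariance after averaging
```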

In [6]:
class EmbeddingLayer(nn.Module):

    def __init__(self, dim_in, dim_out):
        """
        - dim_in: Input dimension (number of detector features)
        - dim_out: Hidden layer dimension (transformer dimensionality)
        """
        super().__init__()
        self.dense = nn.Linear(dim_in, dim_out, bias=False)
        # Initialize matrix as identity
        nn.init.eye_(self.dense.weight)

    def forward(self, x):
        # Linearly map to transformer dimensionality 
        return self.dense(x)

## 4.2 Aggregation Layer

To infer the energy of the primary particle, we need to aggregate data from all detectors. 
We use a masked average for this purpose: auxiliary detectors are excluded from both the sum and the count. 

Another approach is to introduce a classification token.
We will not use this more advanced technique here.
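The masked average in NumPy, on a toy batch of two events. The padding row of the first event is deliberately filled with large values to show that it does not affect the result; this matters because the transformer output at padding positions is not guaranteed to be zero.

```python
import numpy as np

def masked_mean(x, mask):
    """x: (batch, seq_len, d_model); mask: (batch, seq_len, 1), 1 = real, 0 = padding."""
    return (x * mask).sum(axis=1) / mask.sum(axis=1)

x = np.array([[[1.0, 2.0], [3.0, 4.0], [99.0, 99.0]],   # 2 real detectors + 1 padding
              [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]])   # 3 real detectors
mask = np.array([[[1.0], [1.0], [0.0]],
                 [[1.0], [1.0], [1.0]]])
print(masked_mean(x, mask))
# [[2. 3.]
#  [7. 8.]]
```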

In [7]:
class AggregateLayer(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x, mask):
        """ 
        Return the average over real detectors, taking the mask into account
        Parameters:
        - x: data tensor [batch_size, seq_len, d_model]
        - mask: mask for auxiliary detectors [batch_size, seq_len, 1] (1 = real, 0 = padding)
        """
        return torch.sum(x*mask, dim=1) / torch.sum(mask, dim=1)

## 4.3 Prediction Layer

The Prediction Layer is the final component of our network that takes the aggregated detector information and transforms it into the energy prediction.
After the transformer encoder processes and contextualizes all detector activations, we need a way to combine this information and map it to our target variable: the log10 of the primary particle's energy.

In [8]:
class PredictLayer(nn.Module):
    """
    Final prediction layer that processes aggregated detector data.
    
    This layer implements a multi-layer perceptron (MLP) that maps from
    the transformer's high-dimensional representation space to the 
    energy prediction (log10 scale). It consists of:
    
    1. A reduction layer that maps from the model dimension to a smaller dimension
    2. Optional intermediate layers for additional capacity
    3. A final output layer that produces the energy prediction
    
    Each layer is followed by a Leaky ReLU activation, except the final output
    which produces a raw scalar value.
    """

    def __init__(self, dim_in, dim_middle, dim_out, num_middle_layers):
        """
        - dim_in: Input dimension (from transformer encoder, typically d_model)
        - dim_middle: Hidden layer dimension (smaller than dim_in for efficiency)
        - dim_out: Output dimension (1 for energy prediction)
        - num_middle_layers: Number of hidden layers between reduction and output
        """
        super().__init__()
        # Initial dimension reduction
        # This compresses the high-dimensional detector representation
        # to a more manageable size for the regression task
        self.reduce = nn.Linear(dim_in, dim_middle)
        # Hidden layers for additional modeling capacity
        # Each layer maintains the same dimension (dim_middle)
        self.pre_layers = nn.ModuleList([
            nn.Linear(dim_middle, dim_middle) for _ in range(num_middle_layers)
        ])
        # Final output layer that produces the energy prediction
        # Maps from the hidden dimension to a single scalar output
        self.out = nn.Linear(dim_middle, dim_out)
        # Leaky ReLU activation function
        self.activation = F.leaky_relu

    def forward(self, x):
        """
        Parameters:
        - x: Input tensor [batch_size, dim_in]
          This contains aggregated information from all detectors for each event
        
        Returns:
        - Predictions [batch_size, dim_out]
          Log10 of the predicted energy for each event
        """
        # Initial dimension reduction with activation
        x = self.reduce(x)
        x = self.activation(x)
        # Apply each hidden layer with activation
        for layer in self.pre_layers:
            x = layer(x)
            x = self.activation(x)
        # Final prediction layer (no activation)
        return self.out(x)

## 4.4 Complete Encoder Model

To define an encoder we need to combine layers:
1. Embedding layer to increase dimensionality of the data.
2. PyTorch Transformer Encoder layers. They process the data and extract features of interest.
3. Aggregation layer to aggregate information from all detectors into a single feature vector.
4. Prediction layer to predict the log10 of the energy.

We use a mask to prevent detectors from attending to auxiliary ones. 
PyTorch's `nn.TransformerEncoder` accepts a boolean `src_key_padding_mask` for this purpose. The mask is created as follows:
- Start with the mask channel of the input, of shape [batch_size, seq_len, 1], in which 1 indicates a real detector and 0 indicates padding
- Convert it to a boolean tensor of shape [batch_size, seq_len] with `mask[:, :, 0] == 0`, so that `True` marks an auxiliary detector (PyTorch's convention for positions to ignore)
- Pass it as `src_key_padding_mask`. PyTorch broadcasts it over queries and attention heads internally, so auxiliary detectors are excluded as keys for every detector in every head and no explicit [seq_len, seq_len] attention matrix is needed.
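A NumPy sketch of what a boolean key padding mask does inside attention: scores for masked keys are set to minus infinity before the softmax, so those keys receive exactly zero weight, while each row still sums to one. This is a simplified model of PyTorch's behaviour, not its implementation; the scores are random. Queries at padding positions still produce outputs, but they are discarded by the masked average.

```python
import numpy as np

def attention_weights(scores, key_padding_mask):
    """scores: (batch, L, L); key_padding_mask: (batch, L) bool, True = padding (ignored)."""
    # Broadcast the padding mask over the query dimension and block those keys
    scores = np.where(key_padding_mask[:, None, :], -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)  # finite max: every event has a real detector
    w = np.exp(scores)
    return w / w.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(4)
mask_channel = np.array([[1, 1, 1, 0],    # 3 real detectors + 1 padding
                         [1, 1, 0, 0]])   # 2 real detectors + 2 padding
key_padding_mask = mask_channel == 0      # True = auxiliary detector
w = attention_weights(rng.normal(size=(2, 4, 4)), key_padding_mask)

on_padding_keys = w[np.broadcast_to(key_padding_mask[:, None, :], w.shape)]
print(np.allclose(on_padding_keys, 0), np.allclose(w.sum(axis=-1), 1))
```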

In [15]:
class Encoder(nn.Module):
    
    def __init__(self, num_layers, num_heads, d_model, d_ff, head_dim, input_dim, dropout,
                    dim_middle_pred, dim_out_pred, num_middle_layers_pred):
        """
        Parameters:
        - num_layers: Number of stacked encoder layers (depth of the model)
        - num_heads: Number of attention heads in each encoder layer
          (allows the model to focus on different aspects of the data)
        - d_model: Model dimension - internal representation size for detector features
        - d_ff: Feed-forward hidden dimension (typically 4x d_model)
        - head_dim: Dimension of each attention head (d_model / num_heads); kept for reference only,
          since nn.TransformerEncoderLayer derives it from d_model and num_heads
        - input_dim: Input feature dimension (6 detector features + 1 mask)
        - dropout: Dropout rate for regularization
        - dim_middle_pred: Hidden dimension in the prediction layer
        - dim_out_pred: Output dimension (1 for energy prediction)
        - num_middle_layers_pred: Number of hidden layers in prediction network
        """
        super().__init__()
        self.num_heads = num_heads
        # Embedding layer: Projects the 6 detector features into the model dimension
        # We use input_dim-1 because the last dimension is the mask
        self.embedding_layer = EmbeddingLayer(input_dim-1, d_model)

        # This is how to initialize a PyTorch Transformer layer
        # PyTorch native transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            #norm_first=True
        )       
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_layers
        )
               
        # Aggregation layer: aggregates data from all detectors.
        # The result is used to infer log10 energy of the primary particle.
        self.aggr_layer = AggregateLayer()
        # Prediction layer: Takes the aggregated detector information and
        # predicts the energy of the primary particle
        self.predict_layer = PredictLayer(d_model, dim_middle_pred, dim_out_pred, num_middle_layers_pred)

    def compile(self, optim_kwargs, scheduler_kwargs):
        """
        Configure the model for training by defining loss function, optimizer,
        learning rate scheduler, and evaluation metrics.
        
        Parameters:
        - optim_kwargs: Optimizer parameters
        - scheduler_kwargs: Learning rate scheduler parameters
        """
        # Define loss function - Mean Squared Error for regression task
        self.loss = nn.functional.mse_loss
        # Configure Adam optimizer
        self.optimizer = torch.optim.Adam(self.parameters(), **optim_kwargs)
        # Learning rate scheduler that reduces LR when performance plateaus
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, **scheduler_kwargs)
        # Define metrics to track during training
        metric_names = ['mse_logE_loss']
        self.metrics = [ MeanMetrics(name=mn) for mn in metric_names ]

    def forward(self, inputs):
        """  
        The data flow follows these steps:
        1. Split data into detector features and mask
        2. Create attention mask to prevent attending to padding
        3. Embed detector features into higher-dimensional space
        4. Process through multiple encoder layers
        5. Aggregate information from all detectors
        6. Generate energy prediction
        
        Parameters:
        - inputs: Input tensor with shape [batch_size, seq_len, input_dim]
          where seq_len is the maximum number of detectors in the batch
          and the last feature is the mask
        
        Returns:
        - Energy predictions [batch_size, 1] in log10 scale
        """
        # Split input into data features and mask
        # x: [batch_size, seq_len, 6] - detector features
        # mask: [batch_size, seq_len, 1] - binary mask (1 = real detector, 0 = padding)
        x, mask = inputs[:, :, :-1], inputs[:, :, -1:]

        # Cast the mask to the form required by PyTorch: a boolean key padding mask
        # of shape [batch_size, seq_len] where True marks auxiliary detectors.
        # PyTorch broadcasts it over queries and heads internally, so no explicit
        # [batch_size*num_heads, seq_len, seq_len] attention mask is needed.
        padding_mask = (mask[:, :, 0] == 0)
        
        # Embed input features into the model's higher-dimensional space
        # This projection preserves physical meaning while allowing more expressive representations
        x = self.embedding_layer(x)
        # Process through the stack of encoder layers
        # Each layer refines the detector representations;
        # the padding mask prevents the transformer from attending to auxiliary detectors
        x = self.transformer_encoder(src=x, src_key_padding_mask=padding_mask)
        # Aggregate data to a single vector (average over real detectors only)
        aggr = self.aggr_layer(x, mask)
        # Generate energy prediction from the aggregated event representation
        preds = self.predict_layer(aggr)
        return preds

    def update_metrics(self, metric_updates):
        """
        Update tracking metrics with new values.
        
        Parameters:
        - metric_updates: List of new metric values from the most recent step
        """
        for m_update, m_tracker in zip(metric_updates, self.metrics):
            m_tracker.update_state(m_update)

    def train_step(self, data, labels):
        """
        Perform a single training step with backpropagation.
        
        This method:
        1. Sets the model to training mode
        2. Performs forward pass
        3. Calculates loss
        4. Computes gradients via backpropagation
        5. Updates model parameters
        6. Updates metrics
        
        Parameters:
        - data: Input detector data [batch_size, seq_len, input_dim] (features + mask)
        - labels: True energy values [batch_size, 1] in log10 scale
        
        Returns:
        - Dictionary of metrics including loss and learning rate
        """
        # Set model to training mode (enables dropout, batch norm updates, etc.)
        self.train()
        # Zero gradients from previous step
        # This is necessary because PyTorch accumulates gradients
        self.optimizer.zero_grad()
        # Forward pass: Generate predictions
        preds = self.forward(data)
        # Calculate loss between predictions and true values
        loss = self.loss(preds, labels)
        # Backward pass: Compute gradient of loss with respect to parameters
        loss.backward()
        # Update weights using the optimizer
        self.optimizer.step()
        # Update tracking metrics
        self.update_metrics([loss])
        # Return metrics dictionary for logging
        return {**{m_tracker.name: m_tracker.result() for m_tracker in self.metrics},
               "learning_rate": self.optimizer.param_groups[0]['lr']}

    def test_step(self, data, labels):
        """
        Perform a single validation/test step without parameter updates.
        """
        # Set model to evaluation mode (disables dropout, freezes batch norm, etc.)
        self.eval()
        # Disable gradient calculation
        with torch.no_grad():
            # Forward pass: Generate predictions
            preds = self.forward(data)
            # Calculate loss between predictions and true values
            loss = self.loss(preds, labels)
        # Update tracking metrics
        self.update_metrics([loss])
        # Return metrics dictionary for logging
        return {**{m_tracker.name: m_tracker.result() for m_tracker in self.metrics},
               "learning_rate": self.optimizer.param_groups[0]['lr']}

# 5. Preparing for Training

## 5.1 Set various configurations

### 5.1.1 Set configuration of datasets for training

In [16]:
generator_config = {
  'file' : '/home3/ivkhar/TA/data/normed/composition_spectrum/taml_0325_energy.h5', # path to training file
  'batch_size' : 128, # batch size
  'apply_add_gauss' : False, # flag for additive data augmentation
  'gauss_stds' : [0., 0., 0., 0., 0. , 0.], # noise parameters
  'apply_mult_gauss' : False, # flag for multiplicative augmentation of registered charges
  'mult_gauss_std' : 0.0 # noise parameters
}

### 5.1.2 Set model configuration

In [17]:
# Neural network architecture parameters
nn_arch_params = {
  'num_layers': 5,       # Number of transformer encoder layers
  'num_heads': 4,        # Number of attention heads per layer
  'd_model': 128,        # Model dimension
  'd_ff': 512,           # Feed-forward hidden dimension (4 * d_model)
  'head_dim': 32,        # Dimension of each attention head (d_model / num_heads; informational)
  'input_dim': 7,        # Input features dimension (6 + 1 mask)
  'dropout': 0.,         # Dropout rate
  'dim_middle_pred': 32, # Prediction hidden dimension
  'dim_out_pred': 1,     # Output dimension (1 for energy)
  'num_middle_layers_pred': 1  # Number of prediction hidden layers
}

optimizer_params ={
  'lr': 0.0005  # Learning rate
  }

scheduler_params = {
  'factor': 0.25,    # Factor to reduce learning rate by
  'patience': 4      # Number of epochs with no improvement before reducing LR
  }
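To see what `factor` and `patience` mean in practice, here is a simplified plain-Python re-implementation of the `ReduceLROnPlateau` rule in `'min'` mode with its default relative threshold (`1e-4`). It is an illustration, not PyTorch's code: cooldown and `min_lr` are ignored. With `patience=4`, the learning rate is reduced after the fifth consecutive epoch without improvement.

```python
import math

def simulate_reduce_on_plateau(losses, lr, factor=0.25, patience=4, threshold=1e-4):
    """Simplified ReduceLROnPlateau(mode='min', threshold_mode='rel'); returns lr after each epoch."""
    best = float('inf')
    bad_epochs = 0
    lrs = []
    for loss in losses:
        if loss < best * (1 - threshold):   # relative improvement
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:           # more than `patience` bad epochs in a row
            lr *= factor
            bad_epochs = 0
        lrs.append(lr)
    return lrs

# A completely flat validation loss for 7 epochs
lrs = simulate_reduce_on_plateau([1.0] * 7, lr=0.0005)
```

Here the first 5 epochs keep `lr = 0.0005` and epochs 6 and 7 use `0.000125`.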

### 5.1.3 Set Training Configuration (including Early Stopping)

In [18]:
# Training parameters
model_name = 'taml_test'            # Name for saving the model
patience = 8                        # Early stopping patience
train_steps_per_epoch = 1000         # Number of batches per epoch
test_steps_per_epoch = 500          # Number of test batches per epoch
min_delta = 1.e-4                   # Minimum improvement for early stopping
num_epochs = 10                     # Maximum number of epochs; increase for real training
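The early-stopping rule applied in the training loop can be checked offline. The function below is a standalone plain-Python sketch of that rule (hypothetical name `early_stopping`), fed with the ten test losses printed by the training run at the end of this notebook. With `patience = 8` it does not trigger within 10 epochs; with a hypothetical `patience = 2` it would stop at epoch 10. In both cases the best epoch is 8.

```python
def early_stopping(val_losses, patience=8, min_delta=1e-4):
    """Return (epoch at which early stopping triggers or None, best epoch), 1-based."""
    best_loss, best_epoch, wait = float('inf'), None, 0
    for epoch, val_loss in enumerate(val_losses, start=1):
        if val_loss < best_loss - min_delta:
            # Improvement larger than min_delta: remember it and reset the counter
            best_loss, best_epoch, wait = val_loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch

test_losses = [0.008752821944653988, 0.009067647159099579, 0.007554568350315094,
               0.007981672883033752, 0.007173273712396622, 0.006401329301297665,
               0.006596859078854322, 0.006017283536493778, 0.007102299015969038,
               0.006346041336655617]
print(early_stopping(test_losses))               # (None, 8)
print(early_stopping(test_losses, patience=2))   # (10, 8)
```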

## 5.2 Define training loop

In [19]:
def train_model(nn_arch_builder):

    # Create datasets
    train_dataset, test_dataset = make_datasets(**generator_config)

    # Create model and move to GPU
    model = nn_arch_builder(**nn_arch_params)
    model.to('cuda')
    # Compile the model
    model.compile(optimizer_params, scheduler_params)

    # Create infinite training data iterator
    train_iter = iter(train_dataset)
    
    # Initialize early stopping variables
    best_loss = 1.e9
    wait = 0  # Counter for patience
    
    # Move metrics to GPU
    for metric in model.metrics:
        metric.to('cuda')
    
    # Training loop
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}")
    
        ### TRAINING PHASE
        # Reset metrics
        for metric in model.metrics:
            metric.reset()
    
        # Train for specified number of steps
        for step in range(train_steps_per_epoch):
            # Get next batch from infinite iterator
            batch = next(train_iter)
            # Move data to GPU
            data, labels = [b.to('cuda') for b in batch]
            # Perform training step
            all_metrics = model.train_step(data, labels)
        # Print training metrics
        print(f"Train loss: {all_metrics['mse_logE_loss'].item()}")
    
        ### VALIDATION PHASE
        # Reset metrics
        for metric in model.metrics:
            metric.reset()
    
        # Validate on test dataset
        for i, batch in enumerate(test_dataset):
            if i >= test_steps_per_epoch:
                break
            data, labels = [b.to('cuda') for b in batch]
            all_metrics = model.test_step(data, labels)
            
        print(f"Test loss: {all_metrics['mse_logE_loss'].item()}")
    
        # Early stopping check
        val_loss = all_metrics['mse_logE_loss']
        # Early stopping (this minimal loop does not checkpoint the best model)
        if val_loss < best_loss - min_delta:
            # We have improvement
            best_loss = val_loss
            wait = 0
        else:
            # No improvement
            wait += 1
            if wait >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs")
                break
    
        # Update learning rate based on validation loss
        model.scheduler.step(val_loss)

## 5.3 Train the model

In [20]:
train_model(Encoder)

Epoch 1
Train loss: 0.07115709036588669
Test loss: 0.008752821944653988
Epoch 2
Train loss: 0.00928259827196598
Test loss: 0.009067647159099579
Epoch 3
Train loss: 0.008210983127355576
Test loss: 0.007554568350315094
Epoch 4
Train loss: 0.007794237229973078
Test loss: 0.007981672883033752
Epoch 5
Train loss: 0.007148415315896273
Test loss: 0.007173273712396622
Epoch 6
Train loss: 0.006880099885165691
Test loss: 0.006401329301297665
Epoch 7
Train loss: 0.006728038191795349
Test loss: 0.006596859078854322
Epoch 8
Train loss: 0.006346757989376783
Test loss: 0.006017283536493778
Epoch 9
Train loss: 0.006508741527795792
Test loss: 0.007102299015969038
Epoch 10
Train loss: 0.006196258589625359
Test loss: 0.006346041336655617
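Since the loss is the MSE of log10(E), its square root is the RMS error of log10(E), and `10 ** rms` is the corresponding typical multiplicative error on the energy. For the final test loss above this gives roughly 0.08 in log10(E), i.e. a factor of about 1.2 (around 20%). This is only a rough one-number summary under the assumption that errors are roughly symmetric in log10(E); a proper resolution study would look at the full error distribution, for example binned in energy.

```python
import math

final_test_mse = 0.006346041336655617   # last "Test loss" printed above
rms_log10 = math.sqrt(final_test_mse)   # RMS error of log10(E)
energy_factor = 10 ** rms_log10         # typical multiplicative error on E
print(f"RMS error in log10(E): {rms_log10:.3f}")
print(f"Typical energy misestimate factor: {energy_factor:.2f}")
```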
