# Training a TrajCast Model

This notebook is a short guide to training a *TrajCast* model for paracetamol. Training usually involves the following steps:
- Generate a dataset from raw MD trajectory data (not covered here).
- Prepare and choose the appropriate hyperparameters in a configuration file.
- Train the *TrajCast* model.
- Track the training progress.

> **_Note_:** 
> This code is still under development. This notebook is intended to provide help and guidance in getting started. Future updates will focus on making the code more efficient and user-friendly.

## Download the Dataset

Rather than generating our dataset from scratch by computing displacement and velocity vectors at a specified time interval, we skip this step here and download a subset of the paracetamol dataset used in our preprint, which is available on Hugging Face. This subset comprises 500 configurations for training, 125 for validation, and 100 for testing. **These datasets are only 5% of the size used in the experiments in our manuscript.** You are, of course, free to use the full datasets from the original experiments instead.

> **_Note_:** 
> You can preprocess your own dataset using `compute_additional_fields` within the `ASETrajectory` class. We will add this to the notebook later. 

In [None]:
from huggingface_hub import hf_hub_download

for dataset in ["train", "val", "test"]:
    hf_hub_download(
        repo_id="ibm-research/trajcast.datasets-arxiv2025",
        repo_type="dataset",
        revision="main",
        filename=f"example/{dataset}.extxyz",
        local_dir="../data",
    )
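To confirm that the download worked, you can count the configurations in each file. The sketch below uses only the standard library and assumes the usual extended XYZ layout (an atom-count line, an info line, then one line per atom). `count_extxyz_frames` is a hypothetical helper, not part of *TrajCast*. `ase.io.read(path, index=":")` would give the same count but loads every frame into memory.

```python
from pathlib import Path


def count_extxyz_frames(path):
    """Count configurations in an (extended) XYZ file.

    Each frame starts with a line holding the number of atoms N,
    followed by one info/comment line and N atom lines.
    """
    n_frames = 0
    with open(path) as handle:
        while True:
            header = handle.readline()
            if not header:
                break  # end of file
            if not header.strip():
                continue  # tolerate blank lines between frames
            n_atoms = int(header.split()[0])
            # Skip the info line plus one line per atom
            for _ in range(n_atoms + 1):
                handle.readline()
            n_frames += 1
    return n_frames


# Expected sizes of the example subsets described above
expected = {"train": 500, "val": 125, "test": 100}
for name, n_expected in expected.items():
    path = Path("../data/example") / f"{name}.extxyz"
    if path.exists():
        print(f"{name}: {count_extxyz_frames(path)} frames (expected {n_expected})")
```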

## Preparing the Configuration Dictionary

With the data in place, we can now prepare for training. To do so, we create a configuration dictionary with all necessary hyperparameters and settings. *TrajCast* organizes this dictionary into three categories: model, data, and training.

> **_Note_:** 
> Here, we build a **small model** that is likely not very accurate, allowing for training on a CPU, as this notebook is primarily intended to help users get started with the code. However, feel free to adjust the parameters to match those in the paper for more accurate models.

### Model Hyperparameters

In [None]:
model_dict = {
    "precision": 32,  # Floating-point precision: 32 for single, 64 for double
    "num_chem_elements": 4,  # Number of distinct chemical species the model is meant to describe
    "edge_cutoff": 4.0,  # Distance cutoff for message passing
    "num_edge_rbf": 8,  # Number of radial basis functions for edges
    "num_edge_poly_cutoff": 6,  # Exponent p of the polynomial cutoff function
    "vel_max": 0.11,  # Cutoff for velocities, obtained from the Maxwell-Boltzmann distribution
    "num_vel_rbf": 8,  # Number of radial basis functions for velocities
    "max_rotation_order": 1,  # Rotation order up to which we will resolve features and spherical harmonics
    "num_hidden_channels": 16,  # Number of features per irrep
    "num_mp_layers": 3,  # Number of message passing layers
    "edge_mlp_kwargs": {  # Settings for MLP producing weights for edge tensor product
        "n_neurons": [
            16,
            16,
            16,
        ],  # Number of neurons per layer, here 3 hidden layers with 16 neurons each
        "activation": "silu",  # Chosen activation function
    },
    "vel_mlp_kwargs": {  # Settings for MLP producing weights for velocity tensor product
        "n_neurons": [
            16,
            16,
            16,
        ],  # Number of neurons per layer, here 3 hidden layers with 16 neurons each
        "activation": "silu",  # Chosen activation function
    },
    "nl_gate_kwargs": {  # Settings for non-linear gates at the end of the update
        "activation_scalars": {  # Activation functions for scalar features
            "o": "tanh",  # Odd features
            "e": "silu",  # Even features
        },
        "activation_gates": {  # Activation functions for the gate scalars that scale L>0 features after the non-linearity
            "e": "silu"  # We only use even features for gating
        },
    },
    "conserve_ang_mom": True,  # Whether to conserve angular momentum
    "o3_backend": "e3nn",  # Either "e3nn" or "cueq"; we recommend the latter, particularly for large systems
    "net_lin_mom": [
        0.0,
        0.0,
        0.0,
    ],  # Net linear momentum you'd expect (usually the 0 vector if momentum was properly zeroed in equilibrium MD)
    "net_ang_mom": [0.0, 0.0, 0.0],  # Same as above but for angular momentum
}
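`edge_cutoff`, `num_edge_rbf`, and `num_edge_poly_cutoff` control how interatomic distances are embedded. As an illustration only, the sketch below assumes the Bessel radial basis and polynomial cutoff envelope popularised by DimeNet and NequIP, with p set by `num_edge_poly_cutoff`. We have not checked that *TrajCast* implements exactly these formulas, and all function names are hypothetical.

```python
import math


def polynomial_cutoff(r, r_cut, p=6):
    """Polynomial envelope: 1 at r = 0, smoothly reaching 0 at r = r_cut."""
    x = r / r_cut
    if x >= 1.0:
        return 0.0  # no contribution beyond the cutoff
    return (
        1.0
        - (p + 1) * (p + 2) / 2 * x**p
        + p * (p + 2) * x ** (p + 1)
        - p * (p + 1) / 2 * x ** (p + 2)
    )


def bessel_rbf(r, r_cut, num_basis=8):
    """Bessel basis sqrt(2 / r_cut) * sin(n * pi * r / r_cut) / r for n = 1..num_basis (r > 0)."""
    prefactor = math.sqrt(2.0 / r_cut)
    return [
        prefactor * math.sin(n * math.pi * r / r_cut) / r
        for n in range(1, num_basis + 1)
    ]


def edge_embedding(r, r_cut=4.0, num_basis=8, p=6):
    """Radial features for one edge: the Bessel basis damped by the envelope."""
    envelope = polynomial_cutoff(r, r_cut, p)
    return [value * envelope for value in bessel_rbf(r, r_cut, num_basis)]


# Values mirroring edge_cutoff=4.0, num_edge_rbf=8, num_edge_poly_cutoff=6
features = edge_embedding(1.5, r_cut=4.0, num_basis=8, p=6)
print(len(features), polynomial_cutoff(2.0, 4.0))  # 8 0.85546875
```

The envelope and its first derivatives vanish at the cutoff, so an atom drifting across `edge_cutoff` changes the features continuously instead of abruptly.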

### Dataset for Training

> **_Note_:** 
> This is the training data only. Currently, validation data is passed as part of the training dictionary. 

In [None]:
data_dict = {
    "root": ".",  # Directory where the data lies
    "name": "paracetamol_training",  # Name the processed dataset should have
    "cutoff_radius": 4.0,  # Cutoff for defining edges between nodes; should match the model's edge_cutoff (will be fixed later)
    "files": [
        "../data/example/train.extxyz"
    ],  # Files with the data, can be multiple ones
    "rename": True,  # Appends a precision-dependent tag to the processed filename
    "atom_type_mapper": {  # Maps atomic numbers to the model's internal type indices; optional but less error-prone
        1: 0,  # H -> 0
        6: 1,  # C -> 1
        7: 2,  # N -> 2
        8: 3,  # O -> 3
    },
}
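Several comments note that the cutoff and the atom-type mapping must currently be kept consistent by hand between the model, training, and validation settings. A small hypothetical helper such as `check_config_consistency` (not part of *TrajCast*) can catch mismatches before the trainer is built. It assumes the mapped indices should be exactly `0, ..., num_chem_elements - 1`.

```python
def check_config_consistency(model_cfg, *data_cfgs):
    """Return human-readable problems found across dataset configs (empty list if none)."""
    problems = []
    expected_types = list(range(model_cfg["num_chem_elements"]))
    for i, data_cfg in enumerate(data_cfgs):
        # Edge cutoff used to build graphs must match the model's cutoff
        if data_cfg["cutoff_radius"] != model_cfg["edge_cutoff"]:
            problems.append(
                f"dataset {i}: cutoff_radius={data_cfg['cutoff_radius']} "
                f"!= edge_cutoff={model_cfg['edge_cutoff']}"
            )
        # Mapped type indices should cover 0..num_chem_elements-1 exactly
        mapped = sorted(set(data_cfg["atom_type_mapper"].values()))
        if mapped != expected_types:
            problems.append(f"dataset {i}: mapper targets {mapped}, expected {expected_types}")
        # All datasets should use the same mapping as the first one
        if data_cfg["atom_type_mapper"] != data_cfgs[0]["atom_type_mapper"]:
            problems.append(f"dataset {i}: atom_type_mapper differs from dataset 0")
    return problems


# Minimal stand-ins for the dictionaries in this notebook
example_model = {"edge_cutoff": 4.0, "num_chem_elements": 4}
example_data = {"cutoff_radius": 4.0, "atom_type_mapper": {1: 0, 6: 1, 7: 2, 8: 3}}
print(check_config_consistency(example_model, example_data))  # []
```

Once the training settings below are defined, the same check could be run as `check_config_consistency(model_dict, data_dict, train_dict["tensorboard_settings"]["loss_validation"]["data"])`.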

### Training Settings


In [None]:
train_dict = {
    "seed": 1705,  # Random seed for initialisation
    "model_type": "EfficientTrajCastModel",  # The type of model you want to train, do not change!
    "device": "cpu",  # Device on which to run training; set to cpu here in case no GPU is available
    "restart_latest": False,  # Whether to restart from the latest checkpoint
    "target_field": "target",  # Name of the field where the model will save its prediction
    "reference_fields": [  # Fields holding the reference (ground-truth) labels
        "displacements",
        "update_velocities",
    ],
    "batch_size": 10,  # How many configurations are contained in one batch, here 10
    "max_grad_norm": 0.5,  # Gradient clipping
    "num_epochs": 10,  # Maximum number of epochs; usually much larger, here just 10
    "criterion": {  # Loss-function setup; will be updated and simplified later
        "loss_type": {"main_loss": "mse"},
        "learnable_weights": False,
    },
    "optimizer": "adam",  # Which optimizer to use
    "optimizer_settings": {  # Settings for chosen optimizer
        "lr": 0.01,  # Learning rate
        "amsgrad": True,  # Whether AMSGrad is turned on
    },
    "scheduler": ["ReduceLROnPlateau"],  # Which schedulers to use; can be multiple
    "scheduler_settings": {  # Settings dictionary for each scheduler
        "ReduceLROnPlateau": {"factor": 0.8, "patience": 25, "min_lr": 0.0001}
    },
    "chained_scheduler_hp": {  # Interaction between schedulers; outdated and will be removed soon
        "milestones": [
            10000000
        ],  # When using a single scheduler, make sure this exceeds the total number of updates
        "per_epoch": True,  # Adjust LR per epoch rather than per batch
        "monitor_lr_scheduler": False,  # Debugging only; the LR is tracked anyway
    },
    "tensorboard_settings": {  # Validation and tracking of weights, loss, and LR happen in TensorBoard
        "loss": True,  # Track loss
        "lr": True,  # Track lr
        "loss_validation": {  # Compute validation loss based on a validation set defined in "data"
            "data": {
                "root": ".",  # Directory where the data lies
                "name": "paracetamol_validation",  # Name the processed dataset should have
                "cutoff_radius": 4.0,  # Cutoff for defining edges between nodes; should match the model and training set (will be fixed later)
                "files": [
                    "../data/example/val.extxyz"
                ],  # Files with the data, can be multiple ones
                "rename": True,  # Appends a precision-dependent tag to the processed filename
                "atom_type_mapper": {  # Maps atomic numbers to the model's internal type indices; should match the training set
                    1: 0,  # H -> 0
                    6: 1,  # C -> 1
                    7: 2,  # N -> 2
                    8: 3,  # O -> 3
                },
            }
        },
    },
}
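Assuming *TrajCast* forwards `scheduler_settings` to PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau` and, since `per_epoch` is `True`, steps it once per epoch on a monitored loss, the learning rate cannot drop within this 10-epoch run: a reduction needs more than `patience=25` epochs without improvement. The sketch below is a stand-alone re-implementation of the documented plateau rule (mode `"min"`, relative threshold `1e-4`, no cooldown) for intuition only; it is neither *TrajCast* nor PyTorch code, and `simulate_reduce_on_plateau` is hypothetical.

```python
def simulate_reduce_on_plateau(losses, lr=0.01, factor=0.8, patience=25,
                               min_lr=1e-4, threshold=1e-4):
    """Replay the ReduceLROnPlateau rule on per-epoch losses; return the LR after each epoch."""
    best = float("inf")
    num_bad_epochs = 0
    history = []
    for loss in losses:
        # "Better" means a relative improvement larger than `threshold`
        if loss < best * (1.0 - threshold):
            best = loss
            num_bad_epochs = 0
        else:
            num_bad_epochs += 1
        # Reduce once the plateau lasts longer than `patience` epochs
        if num_bad_epochs > patience:
            lr = max(lr * factor, min_lr)
            num_bad_epochs = 0
        history.append(lr)
    return history


# A loss that never improves after the first epoch
history = simulate_reduce_on_plateau([1.0] * 60)
print(history[9], history[26], history[-1])
```

With a constant loss, the first reduction happens after epoch 27 and the next after epoch 53, so `patience` should be chosen relative to the planned `num_epochs`.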

Assemble the dictionaries and build the trainer.

In [None]:
from trajcast.model.training import Trainer

config = {}
config["model"] = model_dict
config["data"] = data_dict
config["training"] = train_dict

trainer = Trainer(config)

Once the `Trainer` class is initialised, we can easily access its attributes, such as `dataset` and `model`. Now we will train the model for the specified 10 epochs.

## Training

Training a *TrajCast* model is simple: it only requires calling the `train` method of the `Trainer` class. This immediately starts training and creates directories for logging, checkpointing, and tracking the training progress.

In [None]:
trainer.train()

Let's have a look at the directories and files produced:
- `logs`: Logs of all model and training parameters, as well as the training and validation loss for each epoch.
- `checkpoints`: State dictionaries of the best model and all relevant information from the latest epoch.
- `tb_log`: TensorBoard logs, including validation metrics, to track and easily visualize the training progress.
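To see exactly what was written, a hypothetical helper (`summarize_run_outputs`, not part of *TrajCast*) can walk the three directories listed above. It assumes they are created in the current working directory; adjust `root` if your run writes elsewhere.

```python
from pathlib import Path


def summarize_run_outputs(root=".", subdirs=("logs", "checkpoints", "tb_log")):
    """Map each output directory to a sorted list of (relative path, size in bytes), or None if absent."""
    summary = {}
    for name in subdirs:
        directory = Path(root) / name
        if not directory.is_dir():
            summary[name] = None  # directory not created (yet)
            continue
        summary[name] = sorted(
            (str(p.relative_to(directory)), p.stat().st_size)
            for p in directory.rglob("*")
            if p.is_file()
        )
    return summary


for name, files in summarize_run_outputs().items():
    print(f"{name}/:", "missing" if files is None else f"{len(files)} files")
    for rel_path, size in files or []:
        print(f"    {rel_path} ({size} bytes)")
```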

Let us now extract and visualize the training and validation loss:

In [None]:
from tensorboard.backend.event_processing import event_accumulator
import matplotlib.pyplot as plt

metrics = ["loss_training", "loss_validation"]
fig, ax = plt.subplots(1, 1)
for metric in metrics:
    event = event_accumulator.EventAccumulator(f"tb_log/{metric}")
    event.Reload()
    data = event.Scalars("loss")
    epochs = [i.step for i in data]
    loss = [i.value for i in data]
    ax.plot(epochs, loss, label=metric)

ax.set_ylabel("Loss")
ax.set_xlabel("Epoch")
ax.legend()

> **_Note_:** 
> You can also plot the MAE for displacements and velocities by replacing the paths.

Once the model training has finished, ideally with converged learning curves, we can deploy our model for inference and forecasting. For more information on this please refer to this [example notebook](../inference/forecasting.ipynb).
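Whether the learning curves have converged is ultimately a judgment call. One simple heuristic, sketched here as a hypothetical helper rather than a *TrajCast* feature, is to check whether the best loss over the last few epochs improved on the earlier best by less than a relative tolerance. You could apply it to the `loss` values extracted for `loss_validation` in the plotting cell above; `window` and `rel_tol` are arbitrary choices.

```python
def has_plateaued(losses, window=10, rel_tol=0.01):
    """True if the best loss in the last `window` epochs beats the earlier best by less than `rel_tol`."""
    if len(losses) <= window:
        return False  # not enough epochs to judge
    best_before = min(losses[:-window])
    best_recent = min(losses[-window:])
    return best_recent > best_before * (1.0 - rel_tol)


# Illustrative curve: decreasing for 20 epochs, then flat for 15
example_losses = [1.0 / (epoch + 1) for epoch in range(20)] + [0.05] * 15
print(has_plateaued(example_losses))  # True
```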

The last step is to save our configuration dictionary to a YAML file so that the state dictionaries can be loaded later:


In [None]:
trainer.dump_config_to_yaml("config_example.yaml")
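To reuse the configuration later, for example to rebuild a `Trainer`, the YAML file can be read back with PyYAML. This sketch assumes `dump_config_to_yaml` writes plain YAML mirroring the `model`/`data`/`training` dictionary; `load_trajcast_config` is a hypothetical helper. If the file contains Python-specific tags, `yaml.safe_load` will refuse it and a different loader would be needed.

```python
from pathlib import Path

import yaml  # PyYAML


def load_trajcast_config(path):
    """Load a YAML config and check that it has the model/data/training sections."""
    with open(path) as handle:
        config = yaml.safe_load(handle)
    if not isinstance(config, dict):
        raise TypeError("expected a mapping at the top level of the YAML file")
    missing = {"model", "data", "training"} - set(config)
    if missing:
        raise KeyError(f"config is missing sections: {sorted(missing)}")
    return config


config_path = Path("config_example.yaml")
if config_path.exists():
    reloaded = load_trajcast_config(config_path)
    print(sorted(reloaded))
```

Passing the result to `Trainer(reloaded)` should then reproduce the setup, provided the YAML round trip preserves all values (integer keys such as those in `atom_type_mapper` survive `safe_dump`/`safe_load`).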