# Introduction

Here we provide the `EnergyModel` and the `NpzTrainer` as initialized for the experiments in the paper. They can be used as drop-in replacements in `training_evaluation_and_loading_pretrained` to train models on other data sets. All data sets can be found in the corresponding Zenodo repository at https://doi.org/10.5281/zenodo.14750286.

In [8]:
import e3x
import jax.numpy as jnp

from euclidean_fast_attention import NpzTrainer
from euclidean_fast_attention import EnergyModel

# 3D Cluster

In [None]:
# Configuration of trainer and model for the 3D cluster experiment.
data_path = '$PATH_TO_DATA/3d_cluster/$DATA_FILE.npz'

# Selects which of the two EnergyModel configurations below is built.
use_efa = True

batch_size = 32
num_atoms = 32  # set to 16 instead, depending on $DATA_FILE
num_iterations = 2  # only read by the local model; can vary between the local models

npztrainer = NpzTrainer(
    data_dir=data_path,
    num_train=1500,
    num_valid=500,
    num_epochs=5000,
    save_interval_steps=2000,
    use_wandb=True,
    energy_unit=1,
    length_unit=1,
    # Size limits grow with the batch size; the edge limit budgets
    # num_atoms * min(num_atoms, 40) edges per structure.
    # NOTE(review): the extra +1 presumably reserves room for padding --
    # confirm against NpzTrainer.
    max_num_graphs=batch_size + 1,
    max_num_nodes=batch_size * num_atoms + 1,
    max_num_edges=batch_size * num_atoms * min(num_atoms, 40) + 1,
)

if use_efa is not True:
    # Local model: no iteration uses the era_* block.
    model = EnergyModel(
        num_features=128,
        mp_max_degree=1,
        num_iterations=num_iterations,
        era_use_in_iterations=None,
    )
else:
    # Model with the era_* block active in both iterations.
    model = EnergyModel(
        num_features=128,
        mp_max_degree=1,
        num_iterations=2,
        era_use_in_iterations=[0, 1],
        era_max_degree=0,
        era_max_length=15,
        era_max_frequency=jnp.pi,
        era_lebedev_num=50,
        era_qk_num_features=16,
        era_v_num_features=32,
        era_activation_fn=e3x.nn.gelu,
        era_tensor_integration=False,
        era_include_pseudotensors=False,
    )

# Charge-Dipole

In [None]:
# Configuration of trainer and model for the charge-dipole experiment.
data_path = '$PATH_TO_DATA/charge_dipole/charge_dipole_potential_train.npz'

# EFA is trained once with and once without tensor integration; only the
# model that incorporates directional information can solve the
# charge-dipole problem (Fig. 2B).
tensor_integration = True
# Selects which of the two EnergyModel configurations below is built.
use_efa = True

npztrainer = NpzTrainer(
    data_dir=data_path,
    num_train=2000,
    num_valid=500,
    num_epochs=3000,
    max_num_nodes=22,
    max_num_edges=22,
    max_num_graphs=11,
    energy_unit=1,
    length_unit=1,
    save_interval_steps=500,
    use_wandb=True
)

if use_efa is True:
    model_nl = EnergyModel(
        mp_max_degree=1,
        num_features=128,
        num_iterations=2,
        era_use_in_iterations=[0, 1],
        era_lebedev_num=50,
        era_include_pseudotensors=False,
        era_max_degree=0,
        era_qk_num_features=32,
        era_v_num_features=16,
        era_max_frequency=jnp.pi,
        era_max_length=10,
        atomic_dipole_embedding=True,
        era_activation_fn=e3x.nn.gelu,
        era_tensor_integration=tensor_integration,
        # The era_ti_* options are only set when tensor integration is
        # enabled; otherwise they are passed as None.
        era_ti_max_degree_sph=1 if tensor_integration is True else None,
        era_ti_max_degree=1 if tensor_integration is True else None,
        era_ti_parametrize_coupling_paths=True if tensor_integration is True else None
    )
else:
    # Local model: no iteration uses the era_* block.
    model_l = EnergyModel(
        mp_max_degree=1,
        num_features=128,
        num_iterations=2,
        era_use_in_iterations=None,
        atomic_dipole_embedding=True
    )


# SN2

In [10]:
# Configuration of trainer and model for the SN2 experiment.
data_path = '$PATH_TO_DATA/sn2/sn2_reactions_shifted.npz'

# Selects which of the two EnergyModel configurations below is built.
use_efa = True
batch_size = 32

rcut = 5  # only read by the local model; rcut can be 5 or 10 there

npztrainer = NpzTrainer(
    data_dir=data_path,
    num_train=400_000,
    num_valid=5_000,
    num_epochs=500,
    save_interval_steps=5000,
    subtract_energy_mean=False,
    model_seed=0,
    energy_unit=1.0,
    length_unit=1.0,
    # Size limits budget 6 nodes and 36 edges per structure.
    # NOTE(review): the extra +1 presumably reserves room for padding --
    # confirm against NpzTrainer.
    max_num_graphs=batch_size + 1,
    max_num_nodes=6 * batch_size + 1,
    max_num_edges=36 * batch_size + 1,
)

if use_efa is not True:
    # NOTE(review): this branch binds `model`, while the other branch binds
    # `model_nl` -- confirm which name downstream cells expect.
    model = EnergyModel(
        cutoff=rcut,
        num_features=128 + 34,  # local model has more features
        num_iterations=2,
    )
else:
    # Model with the era_* block active in both iterations.
    model_nl = EnergyModel(
        cutoff=5.0,
        num_features=128,
        num_iterations=2,
        era_use_in_iterations=[0, 1],
        era_max_degree=0,
        era_max_length=20,
        era_max_frequency=jnp.pi,
        era_lebedev_num=50,
        era_qk_num_features=16,
        era_v_num_features=32,
        era_activation_fn=e3x.nn.gelu,
        era_include_pseudotensors=False,
    )


# Cumulene

In [14]:
# Configuration of trainer and model for the cumulene experiment.
# NOTE(review): the file name is spelled 'cumuelene' -- verify that it
# matches the file shipped in the data repository.
data_path = '$PATH_TO_DATA/cumulene/cumuelene.npz'

# Selects which of the two EnergyModel configurations below is built.
use_efa = True

batch_size = 16
# Only read by the local model; can also be 4 or 5 there, see Fig. 5A.
num_iterations = 3

npztrainer = NpzTrainer(
    data_dir=data_path,
    num_train=2000,
    num_valid=500,
    num_epochs=2000,
    save_interval_steps=2000,
    use_wandb=True,
    model_seed=0,
    energy_unit=1,
    length_unit=1,
    # Size limits budget 13 nodes and 13 * 5 edges per structure.
    # NOTE(review): max_num_graphs has no "+ 1" here, unlike the node and
    # edge limits -- confirm this is intended.
    max_num_graphs=batch_size,
    max_num_nodes=13 * batch_size + 1,
    max_num_edges=13 * 5 * batch_size + 1,
)

if use_efa is not True:
    # Local model: no iteration uses the era_* block.
    model_l = EnergyModel(
        cutoff=4.0,
        mp_max_degree=2,
        num_features=84,
        num_iterations=num_iterations,
        era_use_in_iterations=None,
    )
else:
    # Model with the era_* block active in all three iterations.
    model_nl = EnergyModel(
        cutoff=4.0,
        mp_max_degree=2,
        num_features=64,
        num_iterations=3,
        era_use_in_iterations=[0, 1, 2],
        era_max_degree=1,
        era_max_length=15,
        era_max_frequency=jnp.pi,
        era_lebedev_num=50,
        era_qk_num_features=32,
        era_v_num_features=16,
        era_activation_fn=e3x.nn.gelu,
        era_tensor_integration=False,
        era_include_pseudotensors=False,
        atomic_dipole_embedding=False,
    )


# Dimers

In [18]:
# Configuration of trainer and model for the dimers experiment.
data_path = '$PATH_TO_DATA/dimers/DES370K_subsampled_shifted.npz'

# Selects which of the two EnergyModel configurations below is built.
use_efa = True
batch_size = 16

npztrainer = NpzTrainer(
    data_dir=data_path,
    num_train=4250,
    num_valid=250,
    num_epochs=6000,
    save_interval_steps=2000,
    model_seed=0,
    energy_unit=1,
    length_unit=1,
    energy_weight=1.0,
    forces_weight=1.0,
    # Node and edge limits are sized for (batch_size + 3) structures with
    # 16 nodes and 16 * 16 edges each.
    max_num_graphs=batch_size + 1,
    max_num_nodes=(batch_size + 3) * 16 + 1,
    max_num_edges=(batch_size + 3) * 16 * 16 + 1,
)

if use_efa is not True:
    # Local model: only cutoff, degree, feature and iteration settings are passed.
    model_l = EnergyModel(
        cutoff=4.0,
        mp_max_degree=2,
        num_features=84,
        num_iterations=3,
    )
else:
    # Model with the era_* block active in all three iterations.
    model_nl = EnergyModel(
        cutoff=4.0,
        mp_max_degree=2,
        num_features=64,
        num_iterations=3,
        era_use_in_iterations=[0, 1, 2],
        era_max_degree=1,
        era_max_length=15,
        era_max_frequency=jnp.pi,
        era_lebedev_num=50,
        era_qk_num_features=16,
        era_v_num_features=32,
        era_activation_fn=e3x.nn.gelu,
        era_tensor_integration=False,
        era_include_pseudotensors=False,
    )
