# Introduction

This notebook gives a slightly more comprehensive introduction to training and evaluating the model, using the 3D cluster data as an example. It covers training a local and an EFA-augmented model, as well as collecting test metrics on a held-out test data set. We also show how to load and use the pre-trained models from the checkpoints in `../pretrained`.

Other models, e.g. for cumulene, SN2, dimers and so on, can be trained in a similar fashion by simply configuring the `NpzTrainer` and the `EnergyModel` with the corresponding hyperparameters. This is done on a per-dataset basis in the `models` notebook.

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import optax
import numpy as np
import e3x
import pathlib
import wandb

from euclidean_fast_attention import EnergyModel
from euclidean_fast_attention import NpzTrainer

import orbax.checkpoint as ocp
import pathlib

def load_params(ckpt_dir):
    """Restore the model parameters from the latest checkpoint in `ckpt_dir`.

    Args:
        ckpt_dir: Path (string or path-like) of the checkpoint directory.
            `~` is expanded and the path is made absolute before use.

    Returns:
        The restored `'params'` item of the latest checkpoint step.

    Raises:
        FileNotFoundError: If `ckpt_dir` does not contain any checkpoint step.
    """
    ckpt_path = pathlib.Path(ckpt_dir).expanduser().absolute().resolve()

    manager = ocp.CheckpointManager(
        ckpt_path,
        item_names=('params', ),
        item_handlers={
            'params': ocp.StandardCheckpointHandler(),
        },
        # Checkpoint step directories are prefixed with "ckpt".
        options=ocp.CheckpointManagerOptions(step_prefix="ckpt"),
    )

    # `latest_step()` yields None when no checkpoint exists; fail with a clear
    # message instead of handing None to `restore`.
    latest_step = manager.latest_step()
    if latest_step is None:
        raise FileNotFoundError(f'No checkpoint found in {ckpt_path}.')

    restored = manager.restore(
        latest_step,
        args=None
    )

    return restored['params']

# Train the model from scratch

In [None]:
# Hyperparameters of the training run.
num_iterations = 2  # forwarded as `num_iterations` to the EnergyModel
batch_size = 32
use_efa = True  # True: EFA-augmented model, False: model without EFA

# Number of atoms in the cluster; data sets exist for 16 and 32 atoms only.
num_atoms = 32

# String used for the radius part of the run / checkpoint directory names.
radius_by_num_atoms = {32: '5', 16: '2point5'}
if num_atoms not in radius_by_num_atoms:
    raise ValueError('Only 3D cluster data sets for N=16 and N=32 exist.')
radius_string = radius_by_num_atoms[num_atoms]

# Suffix of the checkpoint directory name. It uses the same
# 'with_efa' / 'without_efa' naming as the pre-trained checkpoint directories
# that are loaded further down in this notebook.
efa_handle = 'with_efa' if use_efa else 'without_efa'

In [None]:
# Set to True to log the run with Weights & Biases (`wandb`).
use_wandb = False

# Defined for both branches: the NpzTrainer below always receives this value,
# regardless of whether wandb is enabled.
log_interval_steps = 10_000

if use_wandb:
    # If wandb is used we must initialize the run
    wandb.init(
        project='euclidean_fast_attention',
        group='3d_cluster',
        name=f'efa_{num_atoms=}_radius_{radius_string}_{num_iterations=}_{batch_size=}'
    )

# Checkpoint directory handed to training and to `load_params` afterwards.
# `expanduser` runs before `resolve` so that a leading `~` would be expanded
# before the path is made absolute.
ckpt_dir = pathlib.Path(
    f'3d_cluster_{num_atoms=}_radius_{radius_string}_{num_iterations=}_{batch_size=}_{efa_handle}'
).expanduser().resolve()

# NOTE(review): `data_path` is not defined in any cell visible in this notebook.
# It has to be set to the location of the 3D cluster data before this cell is
# run -- TODO confirm the expected location.

# Upper bounds for one batch, derived from the batch size and cluster size.
# The `+ 1` presumably reserves a slot for padding -- TODO confirm.
nodes_per_batch = int(batch_size * num_atoms + 1)
edges_per_batch = int(batch_size * num_atoms * min(num_atoms, 40) + 1)
graphs_per_batch = batch_size + 1

npztrainer = NpzTrainer(
    data_dir=data_path,
    num_train=1500,
    num_valid=500,
    num_epochs=5000,
    max_num_nodes=nodes_per_batch,
    max_num_edges=edges_per_batch,
    max_num_graphs=graphs_per_batch,
    energy_unit=1,
    length_unit=1,
    log_interval_steps=log_interval_steps,
    save_interval_steps=2000,
    use_wandb=use_wandb,
)

# Settings shared by both model variants.
model_kwargs = dict(
    num_features=128,
    mp_max_degree=1,
    num_iterations=num_iterations,
)

if use_efa is True:
    # EFA-augmented variant: additionally pass the `era_*` settings.
    model_kwargs.update(
        era_use_in_iterations=[0, 1],
        era_lebedev_num=50,
        era_tensor_integration=False,
        era_include_pseudotensors=False,
        era_max_degree=0,
        era_qk_num_features=16,
        era_v_num_features=32,
        era_max_frequency=jnp.pi,
        era_max_length=15,
        era_activation_fn=e3x.nn.gelu,
    )
else:
    # Variant without EFA: no iteration is selected for it.
    model_kwargs.update(era_use_in_iterations=None)

model = EnergyModel(**model_kwargs)

# Number of steps over which the learning rate decays: batches per epoch
# times the number of epochs. `max_num_graphs` was passed as `batch_size + 1`
# to the trainer, hence the `- 1`.
batches_per_epoch = npztrainer.num_train / (npztrainer.max_num_graphs - 1)
total_steps = int(batches_per_epoch * npztrainer.num_epochs)

# Learning rate starts at 1e-3 and is scaled by 1e-5 / 1e-3 over `total_steps`.
schedule = optax.exponential_decay(
    init_value=1e-3,
    transition_steps=total_steps,
    decay_rate=1e-5 / 1e-3,
)
optimizer = optax.adam(learning_rate=schedule)

# Train the model; checkpoints go to `ckpt_dir`, the return value is not used.
_ = npztrainer.run_training(model, optimizer, ckpt_dir=ckpt_dir)

# Reload the parameters from the latest checkpoint in `ckpt_dir`.
params = load_params(ckpt_dir)

# After training, the model can be evaluated as
metrics = npztrainer.run_testing(
    params=params,
    model=model,
    num_test=500,
    collect_predictions=False,
)

# Load the Model From Checkpoint

Alternatively, one can load the model parameters from the pre-trained checkpoints. For example, execute the code below for 
`use_efa = True` and `use_efa = False` to replicate the points in Fig. 3B.

In [None]:
# Hyperparameters that select which pre-trained checkpoint is loaded.
num_iterations = 2
batch_size = 32
use_efa = False

# Number of atoms in the cluster; data sets exist for 16 and 32 atoms only.
num_atoms = 32

# String used for the radius part of the checkpoint directory name.
radius_by_num_atoms = {32: '5', 16: '2point5'}
if num_atoms not in radius_by_num_atoms:
    raise ValueError('Only 3D cluster data sets for N=16 and N=32 exist.')
radius_string = radius_by_num_atoms[num_atoms]

# Suffix of the checkpoint directory name.
efa_handle = 'with_efa' if use_efa is True else 'without_efa'

In [None]:
# Directory of the pre-trained checkpoint that matches the hyperparameters above.
pretrained_root = pathlib.Path('../pretrained/3d_cluster/').resolve()
ckpt_name = f'3d_cluster_{num_atoms=}_radius_{radius_string}_{num_iterations=}_{batch_size=}_{efa_handle}'

params = load_params(ckpt_dir=pretrained_root / ckpt_name)

# NOTE(review): `data_path` is not defined in any cell visible in this notebook.
# It has to be set to the location of the 3D cluster data before this cell is
# run -- TODO confirm the expected location.

# Upper bounds for one batch, derived from the batch size and cluster size.
# The `+ 1` presumably reserves a slot for padding -- TODO confirm.
nodes_per_batch = int(batch_size * num_atoms + 1)
edges_per_batch = int(batch_size * num_atoms * min(num_atoms, 40) + 1)
graphs_per_batch = batch_size + 1

npztrainer = NpzTrainer(
    data_dir=data_path,
    num_train=1500,
    num_valid=500,
    num_epochs=5000,
    max_num_nodes=nodes_per_batch,
    max_num_edges=edges_per_batch,
    max_num_graphs=graphs_per_batch,
    energy_unit=1,
    length_unit=1,
    save_interval_steps=2000,
    use_wandb=False,
)

# Settings shared by both model variants.
model_kwargs = dict(
    num_features=128,
    mp_max_degree=1,
    num_iterations=num_iterations,
)

if use_efa is True:
    # EFA-augmented variant: additionally pass the `era_*` settings.
    model_kwargs.update(
        era_use_in_iterations=[0, 1],
        era_lebedev_num=50,
        era_tensor_integration=False,
        era_include_pseudotensors=False,
        era_max_degree=0,
        era_qk_num_features=16,
        era_v_num_features=32,
        era_max_frequency=jnp.pi,
        era_max_length=15,
        era_activation_fn=e3x.nn.gelu,
    )
else:
    # Variant without EFA: no iteration is selected for it.
    model_kwargs.update(era_use_in_iterations=None)

model = EnergyModel(**model_kwargs)

# Evaluate the restored parameters on the test split.
metrics = npztrainer.run_testing(
    params=params,
    model=model,
    num_test=500,
    collect_predictions=False,
)

In [None]:
# Report which configuration was evaluated together with its force error.
print(f'{efa_handle} and {num_iterations=}')
# The outer string uses double quotes: reusing the same quote type inside an
# f-string replacement field is a SyntaxError before Python 3.12.
print(f"Forces RMSE: {metrics[0]['forces_rmse']}")