# Using `PotentialTrainer` to train the model

- `aml.train.trainer.PotentialTrainer` wraps the full training process (model construction, dataset building, and the training loop)
- It takes many parameters, so using a configuration file (e.g. in YAML format) is recommended
  - `PotentialTrainer` can be initialized from a loaded config with `PotentialTrainer.from_config`
- See `config_schnet.yaml` for the config used in this example
- Run `trainer.train()` to start training
- If you use TensorBoard as the logger, run `tensorboard --logdir tensorboard` to monitor progress
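
To give a rough idea of what the model part of such a config contains, the sketch below parses a YAML fragment that mirrors the `Model info` dictionary printed by `trainer.train()` further down. This is an illustration only: the contents of `config_schnet.yaml` are not shown here, so where this fragment sits in the real file, and which other sections (dataset, optimizer, logger) the file needs, are assumptions.

```python
import yaml

# Hypothetical YAML fragment. The keys and values are copied from the
# "Model info" output of this notebook; the nesting inside the real
# config_schnet.yaml is an assumption.
config_text = """
compute_force: true
compute_stress: false
compute_hessian: false
energy_model:
  "@category": energy_model   # keys starting with "@" must be quoted in YAML
  "@name": schnet
  cutoff: 5.0
  hidden_channels: 128
  n_filters: 128
  n_interactions: 6
  n_rbf: 50
  rbf_type: gaussian
  species: [H, O]
  trainable_rbf: false
"""

model_config = yaml.safe_load(config_text)

# The parsed result is a plain nested dict, which is the kind of object
# passed to `PotentialTrainer.from_config` in the cell below.
print(model_config["energy_model"]["@name"])
print(model_config["energy_model"]["species"])
```

Keeping hyperparameters such as `cutoff` or `n_interactions` in a file like this makes it easy to version them alongside the experiment directory.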

In [1]:
import aml
import yaml

with open("config_schnet.yaml", "r") as f:
    config = yaml.full_load(f)

trainer = aml.train.PotentialTrainer.from_config(config)

In [2]:
trainer.train()

Training schnet_water...
Experiment directory: experiments/schnet_water
Building model...
Model info:
{'compute_force': True,
 'compute_hessian': False,
 'compute_stress': False,
 'energy_model': {'@category': 'energy_model',
                  '@name': 'schnet',
                  'cutoff': 5.0,
                  'hidden_channels': 128,
                  'n_filters': 128,
                  'n_interactions': 6,
                  'n_rbf': 50,
                  'rbf_type': 'gaussian',
                  'species': ['H', 'O'],
                  'trainable_rbf': False}}
Building datasets...


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]



In [2]:
# Load the trained model

model = aml.load_iap("model/schnet_water_best.ckpt").cuda()

In [3]:
# Example molecule: frame index 10 of water_test.xyz
import ase.io

atoms = ase.io.read("../data/water_test.xyz", "10")
energy_dft = atoms.get_potential_energy()
forces_dft = atoms.get_forces()

output = model.forward_atoms(atoms) # predict energy and forces for atoms
energy_pred = output["energy"].item()
forces_pred = output["force"].detach().cpu().numpy()

print("DFT energy: ", energy_dft)
print("Predicted energy: ", energy_pred)
print("DFT forces:\n", forces_dft)
print("Predicted forces:\n", forces_pred)

DFT energy:  -2079.6540811187565
Predicted energy:  -2079.65234375
DFT forces:
 [[ 0.0989196   0.43728087 -0.25122318]
 [-0.07733442 -0.27493436  0.1498364 ]
 [-0.02168371 -0.16173958  0.10132405]]
Predicted forces:
 [[ 0.09738747  0.43051213 -0.24776652]
 [-0.07365216 -0.27011785  0.14853127]
 [-0.02373531 -0.16039431  0.09923527]]
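
To put a number on the agreement, the printed values above can be compared directly. The sketch below hard-codes the numbers from this output so that it runs without the model; in the notebook you would use `energy_dft`, `energy_pred`, `forces_dft`, and `forces_pred` from the previous cell instead. The units are assumed to be ASE's defaults (eV and eV/Å), which depends on how the dataset was written.

```python
import numpy as np

# Values copied from the printed output above.
energy_dft = -2079.6540811187565
energy_pred = -2079.65234375
forces_dft = np.array([
    [0.0989196, 0.43728087, -0.25122318],
    [-0.07733442, -0.27493436, 0.1498364],
    [-0.02168371, -0.16173958, 0.10132405],
])
forces_pred = np.array([
    [0.09738747, 0.43051213, -0.24776652],
    [-0.07365216, -0.27011785, 0.14853127],
    [-0.02373531, -0.16039431, 0.09923527],
])

# Absolute energy error for this single structure.
energy_abs_error = abs(energy_pred - energy_dft)

# Component-wise force errors over all atoms and Cartesian directions.
force_error = forces_pred - forces_dft
force_mae = np.abs(force_error).mean()
force_rmse = np.sqrt((force_error ** 2).mean())
force_max_error = np.abs(force_error).max()

print(f"Energy abs. error: {energy_abs_error:.6f}")
print(f"Force MAE:         {force_mae:.6f}")
print(f"Force RMSE:        {force_rmse:.6f}")
print(f"Force max error:   {force_max_error:.6f}")
```

A single frame is only a sanity check; for a meaningful error estimate, the same metrics would be averaged over the whole test set.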


## (Optional) Compile the model to TorchScript

- A model compiled to TorchScript can be loaded with `torch.jit.load`, without the Python model definitions
- All models except `gemnet_t` can be compiled
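
The general save/load round trip behind this can be sketched with plain PyTorch. The example below uses a small stand-in network, not an `aml` model, and it is not meant to show what `aml.compile_iap` does internally (that is not shown in this notebook). It only illustrates that a scripted module can be serialized and restored with `torch.jit.save` and `torch.jit.load`.

```python
import io

import torch

torch.manual_seed(0)

# Toy stand-in for a potential (an assumption for illustration only).
toy_model = torch.nn.Sequential(
    torch.nn.Linear(3, 8),
    torch.nn.Softplus(),
    torch.nn.Linear(8, 1),
)

# Compile the module to TorchScript.
scripted = torch.jit.script(toy_model)

# Serialize to an in-memory buffer; a file path such as "toy.pt" works too.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)

# Restore the module from the serialized archive alone.
restored = torch.jit.load(buffer)

# The restored module reproduces the original outputs.
positions = torch.randn(3, 3)
same = torch.allclose(toy_model(positions), restored(positions))
print(same)
```

The serialized archive carries both the weights and the compiled code, which is why it can be reloaded by other tools that only have access to PyTorch.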

In [4]:
aml.compile_iap(model, "model/schnet_water.pt")

RecursiveScriptModule(
  original_name=SchNet
  (species_energy_scale): RecursiveScriptModule(original_name=PerSpeciesScaleShift)
  (representation): RecursiveScriptModule(
    original_name=SchNetRepresentation
    (embedding): RecursiveScriptModule(original_name=Embedding)
    (rbf): RecursiveScriptModule(original_name=GaussianRBF)
    (interactions): RecursiveScriptModule(
      original_name=ModuleList
      (0): RecursiveScriptModule(
        original_name=SchnetInteractionBlock
        (mlp): RecursiveScriptModule(
          original_name=Sequential
          (0): RecursiveScriptModule(original_name=Linear)
          (1): RecursiveScriptModule(original_name=ShiftedSoftplus)
          (2): RecursiveScriptModule(original_name=Linear)
        )
        (conv): RecursiveScriptModule(
          original_name=CFConvJittable_44529a
          (aggr_module): RecursiveScriptModule(original_name=SumAggregation)
          (lin1): RecursiveScriptModule(original_name=Linear)
          (lin2): 

In [9]:
# Load the compiled model with plain PyTorch
import torch
model = torch.jit.load("model/schnet_water.pt")
# `aml.load_iap` also works for compiled models
model = aml.load_iap("model/schnet_water.pt")