# Train a model

## Installation

```bash
pip install graph-pes | tail -n 1
```

```text
Successfully installed graph-pes-0.0.1
```


We should now have access to the ``graph-pes-train`` command. We can check this by running:

```bash
graph-pes-train -h
```

```text
usage: graph-pes-train [-h] [args ...]

Train a GraphPES model using PyTorch Lightning.

positional arguments:
  args        Config files and command line specifications. Config files
              should be YAML (.yaml/.yml) files. Command line specifications
              should be in the form nested^key=value. Final config is built up
              from these items in a left to right manner, with later items
              taking precedence over earlier ones in the case of conflicts.

options:
  -h, --help  show this help message and exit

Copyright 2023-24, John Gardner
```
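The help text says the final config is assembled left to right from YAML files and `nested^key=value` specifications, with later items winning on conflicts. The sketch below illustrates that merge rule using plain dictionaries in place of loaded YAML files. It is not the `graph-pes` implementation: the helper names (`parse_value`, `parse_override`, `deep_merge`, `build_config`), the int/float/str value conversion, and the example keys are all hypothetical.

```python
from typing import Any, Dict, List, Union

Config = Dict[str, Any]


def parse_value(text: str) -> Any:
    """Best-effort conversion: try int, then float, else keep the string (assumption)."""
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def parse_override(spec: str) -> Config:
    """Turn 'a^b^c=value' into {'a': {'b': {'c': value}}}."""
    key_path, raw_value = spec.split("=", 1)
    keys = key_path.split("^")
    nested: Config = {keys[-1]: parse_value(raw_value)}
    for key in reversed(keys[:-1]):
        nested = {key: nested}
    return nested


def deep_merge(base: Config, update: Config) -> Config:
    """Return a new dict in which values from `update` win over `base`."""
    merged = dict(base)
    for key, value in update.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)  # recurse into sub-sections
        else:
            merged[key] = value  # the later item takes precedence
    return merged


def build_config(items: List[Union[Config, str]]) -> Config:
    """Fold items left to right: dicts stand in for loaded YAML files,
    strings are 'nested^key=value' overrides."""
    config: Config = {}
    for item in items:
        update = parse_override(item) if isinstance(item, str) else item
        config = deep_merge(config, update)
    return config


# hypothetical keys, chosen only to illustrate the merge order
from_yaml = {"training": {"max_epochs": 100, "batch_size": 32}, "seed": 42}
final = build_config([from_yaml, "training^max_epochs=10"])
print(final)
# {'training': {'max_epochs': 10, 'batch_size': 32}, 'seed': 42}
```

Swapping the order, so that the override comes before the file, would let the file's value of 100 win instead: only position decides precedence.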


## Data definition

We use [load-atoms](https://jla-gardner.github.io/load-atoms/) to download the QM7 dataset, randomly split it into training (80%), validation (10%) and test (10%) sets, and write each set to an XYZ file:

```python
import ase.io
from load_atoms import load_dataset

# download QM7 and split it 80/10/10 at random
structures = load_dataset("QM7")
train, val, test = structures.random_split([0.8, 0.1, 0.1])

# write each split to its own XYZ file
ase.io.write("train.xyz", train)
ase.io.write("val.xyz", val)
ase.io.write("test.xyz", test)
```

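`random_split` returns disjoint subsets in the requested proportions. As a rough illustration of the idea, here is a minimal index-based version. It is not `load-atoms`' own implementation, whose rounding and seeding may differ; the `seed` argument and the rule of giving any rounding remainder to the last split are assumptions of this sketch.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def random_split(items: Sequence[T], fractions: Sequence[float], seed: int = 42) -> List[List[T]]:
    """Shuffle `items` and cut them into disjoint chunks sized by `fractions`.

    Each chunk gets int(fraction * n) items; any leftover from rounding
    down is added to the last chunk (an assumption of this sketch).
    """
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")
    n = len(items)
    order = list(range(n))
    random.Random(seed).shuffle(order)  # reproducible shuffle of the indices

    sizes = [int(f * n) for f in fractions]
    sizes[-1] += n - sum(sizes)

    chunks: List[List[T]] = []
    start = 0
    for size in sizes:
        chunks.append([items[i] for i in order[start:start + size]])
        start += size
    return chunks


train_ids, val_ids, test_ids = random_split(list(range(1000)), [0.8, 0.1, 0.1])
print(len(train_ids), len(val_ids), len(test_ids))  # 800 100 100
```

Shuffling indices rather than the items themselves leaves the input untouched, and a fixed seed makes the same split reproducible across runs.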

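## Configuration

Training is controlled by a YAML config file, `quickstart-config.yaml`, which we pass to `graph-pes-train` below. As the `-h` output above explains, any value in this file can also be overridden by a `nested^key=value` specification placed after the file on the command line.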

## Let's train

```bash
export LOAD_ATOMS_VERBOSE=0  # disable load-atoms progress bars
graph-pes-train quickstart-config.yaml
```

```text
Seed set to 42

[graph-pes INFO]: Set logging level to INFO
[graph-pes INFO]: Started training at 2024-10-14 15:03:08.163

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
wandb: Currently logged in as: jla-gardner. Use `wandb login --relogin` to force relogin
wandb: wandb version 0.18.3 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade
wandb: Tracking run with wandb version 0.17.1
wandb: Run data is saved locally in graph-pes-results/wandb/run-20241014_150309-quickstart-run-2
wandb: Run `wandb offline` to turn off syncing.
wandb: Resuming run quickstart-run-2
wandb: ⭐️ View project at https://wandb.ai/jla-gardner/graph-pes-demo
wandb: 🚀 View run at https://wandb.ai/jla-gardner/graph-pes-demo/runs/quickstart-run-2

[graph-pes INFO]: Logging to graph-pes-results/quickstart-run-2/logs/rank-0.log
[graph-pes INFO]: Starting training on rank 0.
[graph-pes INFO]: Preparing data
[graph-pes INFO]: Caching neighbour lists for 1000 structures with cutoff 3.0
[graph-pes INFO]: Caching neighbour lists for 716 structures with cutoff 3.0
[graph-pes INFO]: Setting up datasets

[graph-pes INFO]: Pre-fitting the model on 1,000 samples
[graph-pes INFO]: 
Model:
PaiNN(
  (z_embedding): PerElementEmbedding(
    dim=32,
    elements=['H', 'C', 'N', 'O', 'S']
  )
  (interactions): UniformModuleList(
    (0-2): 3 x Interaction(
      (filter_generator): HaddamardProduct(
        (components): ModuleList(
          (0): Sequential(
            (0): Bessel(n_features=20, cutoff=3.0, trainable=True)
            (1): Linear(in_features=20, out_features=96, bias=True)
          )
          (1): PolynomialEnvelope(cutoff=3.0, p=6)
        )
      )
      (Phi): MLP(32 → 32 → 96, activation=SiLU())
    )
  )
  (updates): UniformModuleList(
    (0-2): 3 x Update(
      (U): VectorLinear(
        (_linear): Linear(in_features=32, out_features=32, bias=False)
      )
      (V): VectorLinear(
        (_linear): Linear(in_features=32, out_features=32, bias=False)
      )
      (mlp): MLP(64 → 32 → 96, activation=SiLU())
    )
  )
  (read_out): MLP(32 → 32 → 1, activation=SiLU())
)

/opt/miniconda3/envs/graph-pes/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/opt/miniconda3/envs/graph-pes/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.

   epoch   valid/loss/total   valid/loss/per_atom_energy_mae_component   timer/its_per_s/train   timer/its_per_s/valid
       0            4.08356                                    4.08356                 0.94697                 5.03226
Error while terminating subprocess (pid=40564): 

bash: line 3: 40566 Killed: 9               graph-pes-train quickstart-config.yaml
```
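In the table above, `valid/loss/total` and `valid/loss/per_atom_energy_mae_component` are identical, which is consistent with the per-atom energy MAE being the only loss component in this run. A common definition of a per-atom energy MAE is the mean over structures of |E_pred − E_true| / N_atoms. The sketch below assumes that definition, which may differ in detail from how `graph-pes` computes it; the function name and the energies are made up for illustration.

```python
from typing import Sequence


def per_atom_energy_mae(
    predicted: Sequence[float], true: Sequence[float], n_atoms: Sequence[int]
) -> float:
    """Mean over structures of |E_pred - E_true| / N_atoms (assumed definition)."""
    if not (len(predicted) == len(true) == len(n_atoms)):
        raise ValueError("inputs must have the same length")
    # normalise each structure's absolute energy error by its number of atoms
    errors = [abs(p - t) / n for p, t, n in zip(predicted, true, n_atoms)]
    return sum(errors) / len(errors)


# made-up energies for three structures with 2, 4 and 5 atoms
mae = per_atom_energy_mae([-10.0, -20.0, -31.0], [-12.0, -20.0, -30.0], [2, 4, 5])
print(mae)  # approximately 0.4
```

Dividing by the atom count before averaging stops large molecules, whose total energy errors tend to be larger in absolute terms, from dominating the metric.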
