# Train a model

## Installation

In [1]:
%%bash

pip install graph-pes | tail -n 1

Successfully installed graph-pes-0.0.1


We should now have access to the ``graph-pes-train`` command. We can check this by running:

In [1]:
%%bash

graph-pes-train -h

usage: graph-pes-train [-h] [args ...]

Train a GraphPES model using PyTorch Lightning.

positional arguments:
  args        Config files and command line specifications. Config files
              should be YAML (.yaml/.yml) files. Command line specifications
              should be in the form nested^key=value. Final config is built up
              from these items in a left to right manner, with later items
              taking precedence over earlier ones in the case of conflicts.

options:
  -h, --help  show this help message and exit

Copyright 2023-24, John Gardner
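
The help text above says that the final config is built up from left to right, with later items taking precedence, and that command line specifications take the form ``nested^key=value``. The sketch below illustrates that merging rule in plain Python. It is not the ``graph-pes`` implementation: the helper names ``parse_override``, ``deep_merge`` and ``build_config`` are hypothetical, and values are kept as raw strings rather than being parsed the way the real tool may parse them.

```python
def parse_override(spec):
    """Turn 'a^b^c=value' into {'a': {'b': {'c': 'value'}}}.

    Hypothetical helper: the value is kept as a raw string.
    """
    path, _, value = spec.partition("=")
    nested = value
    for key in reversed(path.split("^")):
        nested = {key: nested}
    return nested


def deep_merge(base, update):
    """Return a copy of `base` with `update` merged in; `update` wins on conflicts."""
    merged = dict(base)
    for key, value in update.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


def build_config(items):
    """Merge config dictionaries left to right, so later items take precedence."""
    config = {}
    for item in items:
        config = deep_merge(config, item)
    return config


# A stand-in for the contents of a YAML config file, followed by one override.
file_config = {"fitting": {"trainer_kwargs": {"max_epochs": 200, "accelerator": "auto"}}}
override = parse_override("fitting^trainer_kwargs^max_epochs=10")

config = build_config([file_config, override])
print(config["fitting"]["trainer_kwargs"])
# {'max_epochs': '10', 'accelerator': 'auto'}
```

The override replaces only the key it names; sibling keys such as ``accelerator`` survive the merge, which is the behaviour one would expect from a left to right build with later items taking precedence.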


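Next, we use ``load-atoms`` to load the QM7 dataset, split it 80/20 into training and test sets, and write each split to an ``.xyz`` file:

In [None]: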
import ase.io
from load_atoms import load_dataset

structures = load_dataset("QM7")
train, test = structures.random_split([0.8, 0.2])

ase.io.write("train.xyz", train)
ase.io.write("test.xyz", test)

## Configuration
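
The contents of ``quickstart-config.yaml`` are not shown in this section. As a rough sketch, the cell below writes a config file reconstructed from the configuration that ``graph-pes-train`` echoes in its log during the training run. Which of these keys were set explicitly in the original file, and which are defaults filled in by ``graph-pes``, is an assumption; only the keys and values themselves are taken from the log.

```python
from pathlib import Path

# Reconstructed from the configuration echoed in the training log.
# Assumption: the original file may have contained more or fewer keys,
# since the log also lists default values that are omitted here.
CONFIG_TEXT = """\
model:
   graph_pes.models.PaiNN:
      layers: 3
      cutoff: 3.0

data:
   graph_pes.data.load_atoms_dataset:
      id: QM7
      cutoff: 3.0
      n_train: 1000
      n_valid: 100

loss: graph_pes.training.loss.PerAtomEnergyLoss()

fitting:
   trainer_kwargs:
      max_epochs: 200
   loader_kwargs:
      batch_size: 16
   optimizer:
      graph_pes.training.opt.Optimizer:
         name: AdamW
         lr: 0.0001

general:
   run_id: quickstart-run

wandb:
   project: graph-pes-demo
"""

config_path = Path("quickstart-config.yaml")
config_path.write_text(CONFIG_TEXT)
print(f"Wrote {config_path}")
```

Each top-level key corresponds to one block of the logged configuration: the model architecture, the dataset and its split sizes, the loss, the fitting options, general run settings, and the Weights & Biases project.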

## Let's train

In [7]:
%%bash

export LOAD_ATOMS_VERBOSE=0  # disable load-atoms progress bars
graph-pes-train quickstart-config.yaml

Seed set to 42


[graph-pes INFO]: Set logging level to INFO
[graph-pes INFO]: Started training at 2024-10-14 14:38:52.397
[graph-pes INFO]: Output directory: graph-pes-results/quickstart-run
[graph-pes INFO]: 
Logging using WandbLogger(
  project="graph-pes-demo",
  id="quickstart-run",
  save_dir="graph-pes-results"
)



GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
wandb: Currently logged in as: jla-gardner. Use `wandb login --relogin` to force relogin
wandb: wandb version 0.18.3 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade
wandb: Tracking run with wandb version 0.17.1
wandb: Run data is saved locally in graph-pes-results/wandb/run-20241014_143853-quickstart-run
wandb: Run `wandb offline` to turn off syncing.
wandb: Resuming run quickstart-run
wandb: ⭐️ View project at https://wandb.ai/jla-gardner/graph-pes-demo
wandb: 🚀 View run at https://wandb.ai/jla-gardner/graph-pes-demo/runs/quickstart-run


[graph-pes INFO]: Logging to graph-pes-results/quickstart-run/logs/rank-0.log
[graph-pes INFO]: 
model:
   graph_pes.models.PaiNN:
      layers: 3
      cutoff: 3.0
data:
   graph_pes.data.load_atoms_dataset:
      id: QM7
      cutoff: 3.0
      n_train: 1000
      n_valid: 100
loss: graph_pes.training.loss.PerAtomEnergyLoss()
fitting:
   pre_fit_model: true
   max_n_pre_fit: 5000
   early_stopping_patience: null
   trainer_kwargs:
      max_epochs: 200
      accelerator: auto
      enable_model_summary: false
   loader_kwargs:
      num_workers: 0
      persistent_workers: false
      batch_size: 16
      pin_memory: false
   optimizer:
      graph_pes.training.opt.Optimizer:
         name: AdamW
         lr: 0.0001
   scheduler: null
   swa: null
general:
   seed: 42
   root_dir: graph-pes-results
   run_id: quickstart-run
   log_level: INFO
   progress: logged
wandb:
   project: graph-pes-demo

[graph-pes INFO]: 
FittingData(
  train=ASEDataset(1,000, labels=['energy']),
  valid=AS

You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/u/vld/jesu2890/miniconda3/envs/graphs/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=35` in the `DataLoader` to improve performance.
/u/vld/jesu2890/miniconda3/envs/graphs/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num

   epoch   valid/loss/total   valid/loss/per_atom_energy_mae_component   timer/its_per_s/train   timer/its_per_s/valid
       0            4.07476                                    4.07476                83.33334               207.14285
       1            2.65200                                    2.65200                83.33334               202.38097
       2            0.39732                                    0.39732                90.90909               221.42857
       3            0.28913                                    0.28913                90.90909               214.28572
       4            0.21064                                    0.21064                83.33334               207.14285
       5            0.16655                                    0.16655                90.90909               209.52380
       6            0.14174                                    0.14174                83.33334               207.14285
       7            0.11001                     