# Potential Training
An example notebook that shows how to train a high-dimensional neural network potential (HDNNP).

In [1]:
# !gpustat

In [2]:
from utils import set_env

# NOTE(review): presumably applies the settings stored in the given file to the
# process environment — confirm against `utils.set_env`.
env_file = '.env'
set_env(env_file)

In [3]:
import os

# JAX/XLA runtime switches. They are exported here, ahead of the `import jax`
# in the imports cell, so that JAX picks them up.
# os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ.update(
    {
        "JAX_ENABLE_X64": "1",  # turn on 64-bit (double precision) support in JAX
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",  # do not preallocate GPU memory up front
    }
)

## Imports

In [4]:
import logging
import numpy as np
from pathlib import Path
import matplotlib.pylab as plt
import seaborn as sns
import random
from tqdm import tqdm
from collections import defaultdict
import jax.numpy as jnp
import jax
import jaxip
from jaxip.types import dtype as default_dtype
from jaxip.datasets import RunnerDataset
from jaxip.potentials import NeuralNetworkPotential
from jaxip.logger import LoggingContextManager

In [5]:
# jaxip.logger.set_logging_level(logging.INFO)
# Select double precision as the jaxip default float type and echo the choice.
float_type = jnp.float64
default_dtype.FLOATX = float_type
print(f"default dtype: {default_dtype.FLOATX.dtype}")

default dtype: float64


## Dataset

In [6]:
# Directory that holds the files read below (`input.data` and `input.nn`).
# The path has no leading '/', so it is resolved relative to the notebook's
# current working directory.
base_dir = Path('home/H2O_2')

In [7]:
# Open the dataset file located in `base_dir` and report how many structures it contains.
input_data_path = Path(base_dir, "input.data")
structures = RunnerDataset(input_data_path, persist=True)
# structures = RunnerStructureDataset(Path(base_dir, "input.data"), transform=ToStructure(r_cutoff=3.0), persist=True)
print("Total number of structures:", len(structures))
# Bare expression: display the dataset repr as the cell output.
structures

Total number of structures: 1593


RunnerDataset(filename='home/H2O_2/input.data', transform=ToStructure())

In [8]:
# indices = random.choices(range(len(structures)), k=100)
# Keep only the leading structures for the rest of the notebook
# (presumably to keep the example quick — raise N_STRUCTURES to use more data).
# The count is clamped to the dataset size so a small dataset cannot raise IndexError.
N_STRUCTURES = 20
structures = [structures[i] for i in range(min(N_STRUCTURES, len(structures)))]

In [9]:
# energies = jnp.asarray([x.total_energy for x in structures]).reshape(-1)
# print("Energy difference:", max(energies) - min(energies))
# sns.histplot(energies);

In [10]:
# Take the first structure as the sample input; `s` is reused by the timing
# cells further down (`nnp(s)` and `nnp.compute_force(s)`).
s = structures[0]
# Bare expression: display the structure repr as the cell output.
s

Structure(natoms=192, elements=('H', 'O'), dtype=float64)

In [11]:
# with LoggingContextManager(level=logging.DEBUG):
# structures[0].to_dict()

In [12]:
# from ase.visualize import view
# atoms = s.to_ase_atoms()
# view(atoms, viewer="x3d", repeat=1)

In [13]:
# from ase.io import read, write
# write("atoms.png", atoms * (2, 2, 1), rotation='30z,-80x')
# write("atoms.xyz", atoms * (2, 2, 1))
# ![atoms](atoms.png)

In [14]:
# from jaxip.atoms import Structure
# sp = Structure.create_from_ase(atoms)
# view(sp.to_ase_atoms(), viewer="x3d", repeat=3)

## Potential

In [15]:
# Build the potential from the `input.nn` settings file stored in `base_dir`.
potential_file = Path(base_dir, "input.nn")
nnp = NeuralNetworkPotential.create_from_file(potential_file)
# Bare expression: display the potential repr as the cell output.
nnp

NeuralNetworkPotential(atomic_potential={'H': AtomicPotential(
  descriptor=ACSF(element='H', symmetry_functions=27, r_cutoff=12.0),
  scaler=Scaler(scale_type='center', scale_min=0.0, scale_max=1.0),
  model=NeuralNetworkModel(hidden_layers=((20, 'tanh'), (20, 'tanh')), param_dtype=float64),
), 'O': AtomicPotential(
  descriptor=ACSF(element='O', symmetry_functions=30, r_cutoff=12.0),
  scaler=Scaler(scale_type='center', scale_min=0.0, scale_max=1.0),
  model=NeuralNetworkModel(hidden_layers=((20, 'tanh'), (20, 'tanh')), param_dtype=float64),
)})

In [19]:
# nnp.load()
# nnp.load_scaler()

### Extrapolation warnings

In [20]:
# nnp.set_extrapolation_warnings(100)

### Fit scaler

In [None]:
# Fit the potential's scaler on the selected structures before the potential
# is evaluated or trained in the cells below.
# NOTE(review): presumably this computes the descriptor scaling statistics
# from `structures` — confirm against the jaxip API.
nnp.fit_scaler(structures)

In [22]:
time nnp(s)

CPU times: user 1.19 s, sys: 285 ms, total: 1.48 s
Wall time: 1.45 s


Array(-1103.34228721, dtype=float64)

In [None]:
time nnp.compute_force(s)

## Training

In [None]:
# Train the potential on the selected structures. The returned history `h` is
# indexed by name below: `h['epoch']` plus one entry per recorded quantity.
h = nnp.fit_model(structures)

# Plot every history entry whose name contains 'loss' against the epoch values.
loss_keys = [key for key in h if 'loss' in key]
for key in loss_keys:
    plt.plot(h['epoch'], h[key], label=key)
# Label the axes so the figure can be read on its own.
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend();

In [23]:
# nnp.save()

INFO Saving scaler parameters for element (Ne): scaling.010.data
INFO Saving model weights for element (Ne): weights.010.pkl


## Validation

### Energy

In [None]:
# Compare the reference total energy of every structure with the potential's prediction.
print(f"{len(structures)=}")
# The loop variable is named `structure` so it does not read like the
# notebook-level sample `s`.
true_energy = [structure.total_energy for structure in structures]
pred_energy = [nnp(structure) for structure in structures]

# Parity plot: predictions as points, the identity line (true vs. true) as reference.
plt.scatter(true_energy, pred_energy, label='NNP')
plt.plot(true_energy, true_energy, 'r', label="REF")
plt.xlabel("true energy")
plt.ylabel("pred energy")
plt.legend()
plt.show()

### Force

In [None]:
# Compare reference forces with the potential's predicted forces over all structures.
# `jnp` comes from the imports cell at the top of the notebook; it is not re-imported here.
print(f"{len(structures)=}")

# Gather the per-structure force arrays and join them along the first axis.
# NOTE(review): assumes each force array is (natoms, 3) — confirm against jaxip.
true_forces = jnp.concatenate(
    [structure.force for structure in structures], axis=0
)
pred_forces = jnp.concatenate(
    [nnp.compute_force(structure) for structure in structures], axis=0
)

# Column of the force arrays to plot, and its axis letter for the labels.
dim = 0
to_axis = dict(enumerate('xyz'))

# Parity plot: predictions as points, the identity line (true vs. true) as reference.
plt.scatter(true_forces[:, dim], pred_forces[:, dim], label='NNP')
plt.plot(true_forces[:, dim], true_forces[:, dim], 'r', label='REF')

label = f"force [{to_axis[dim]}]"
plt.ylabel("pred " + label)
plt.xlabel("true " + label)
plt.legend()
plt.show()