# Potential training
An example notebook which shows how to train a high-dimensional neural network potential (HDNNP). 

In [1]:
# !gpustat

In [2]:
from utils import set_env
# set_env('.env')

In [3]:
import os
# JAX/XLA settings passed through the process environment. This cell sits ahead
# of the cell that imports jax, so the values are already present when jax loads.
os.environ.update(
    {
        "JAX_ENABLE_X64": "1",
        "JAX_PLATFORM_NAME": "cpu",
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    }
)

## Imports

In [4]:
import logging
import numpy as np
from pathlib import Path
import matplotlib.pylab as plt
import seaborn as sns
import random
from tqdm import tqdm
from collections import defaultdict
import jax.numpy as jnp
import jax
import jaxip
from jaxip.types import _dtype as default_dtype
from jaxip.datasets import RunnerDataset
from jaxip.potentials import NeuralNetworkPotential
from jaxip.logger import LoggingContextManager

In [24]:
# jaxip.logger.set_logging_level(logging.DEBUG) 
# Set jaxip's default float type to 64-bit, then echo the resulting dtype and
# the first device JAX reports so the reader can confirm the configuration.
default_dtype.FLOATX = jnp.float64
print(f"default dtype: {default_dtype.FLOATX.dtype}")
print(f"default device: {jax.devices()[0]}")

default dtype: float64
default device: TFRT_CPU_0


## Dataset

In [6]:
# Directory (relative to the notebook) from which "input.data" and "input.nn"
# are read in the cells below.
base_dir = Path('GRN')

In [7]:
# Open base_dir/input.data as a RunnerDataset (persist=True), print how many
# structures it contains, and show the dataset repr as the cell output.
structures = RunnerDataset(base_dir / "input.data", persist=True)
# structures = RunnerDataset(Path(base_dir, "input.data"), transform=ToStructure(r_cutoff=3.0), persist=True)
print("Total number of structures:", len(structures))
structures

Total number of structures: 801


RunnerDataset(filename='GRN/input.data', persist=True, dtype=float64)

In [8]:
# indices = random.choices(range(len(structures)), k=100)
# structures = [structures[i] for i in range(len(structures))] # len(structures) 

In [9]:
# import torch
# validation_split = 0.10
# nsamples = len(structures)
# split = int(np.floor(validation_split * nsamples))
# train_structures, valid_structures = torch.utils.data.random_split(structures, lengths=[nsamples-split, split])
# structures = valid_structures

In [10]:
# Keep the first structure as a sample; later cells evaluate the potential and
# its forces on `s`.
s = structures[0]
s

Structure(natoms=24, elements=('C',), dtype=float64)

In [11]:
# energies = jnp.asarray([x.total_energy for x in structures]).reshape(-1)
# print("Energy difference:", max(energies) - min(energies))
# sns.histplot(energies);

In [12]:
# with LoggingContextManager(level=logging.DEBUG):
# structures[0].to_dict()

In [13]:
# from ase.visualize import view
# atoms = s.to_ase()
# view(atoms, viewer="x3d", repeat=1)

In [14]:
# from ase.io import read, write
# write("atoms.png", atoms * (2, 2, 1), rotation='30z,-80x')
# write("atoms.xyz", atoms * (2, 2, 1))
# ![atoms](atoms.png)

In [15]:
# from jaxip.atoms import Structure
# sp = Structure.from_ase(atoms)
# view(sp.to_ase(), viewer="x3d", repeat=3)

## Potential

In [16]:
# Build the potential from the settings file base_dir/input.nn and display it.
nnp = NeuralNetworkPotential.from_file(base_dir / "input.nn")
nnp

NeuralNetworkPotential(atomic_potential={'C': AtomicPotential(
  descriptor=ACSF(element='C', symmetry_functions=30),
  scaler=Scaler(scale_type='center', scale_min=0.0, scale_max=1.0),
  model=NeuralNetworkModel(hidden_layers=((15, 'tanh'), (15, 'tanh')), param_dtype=float64),
)})

In [17]:
# nnp.load()
# nnp.load_scaler()

### Extrapolation warnings

In [18]:
# nnp.set_extrapolation_warnings(100)

### Fit scaler

In [19]:
# Keep only the first five structures. This rebinds `structures` from the
# dataset object to a plain list, so every later cell (scaler fit, training,
# validation) sees just these five.
structures = [structures[idx] for idx in range(5)]

In [20]:
# Fit the potential's descriptor scaler on the currently selected structures.
# with LoggingContextManager(level=logging.DEBUG):
nnp.fit_scaler(structures)

Fitting descriptor scaler...


100%|██████████| 5/5 [00:10<00:00,  2.17s/it]

Done.





In [25]:
time nnp(s)

CPU times: user 47.8 ms, sys: 34 µs, total: 47.9 ms
Wall time: 46.1 ms


Array(-6.82868052, dtype=float64)

In [26]:
time nnp.compute_forces(s)

CPU times: user 118 ms, sys: 3.99 ms, total: 122 ms
Wall time: 119 ms


Array([[-0.21749556, -0.22101345,  0.96314385],
       [-0.17348052,  0.29058047,  0.09322139],
       [ 0.12319294,  0.46653307, -0.58487702],
       [-0.36234939,  0.26546542, -0.04409426],
       [ 0.346702  , -0.07358062, -0.13121177],
       [-0.30885594, -0.39307608,  0.54907396],
       [-0.07475244, -0.35618296,  0.53893394],
       [ 0.23104194, -0.1023122 ,  0.74780426],
       [-0.24752292, -0.17490092, -0.50934252],
       [-0.59892267,  0.59693202,  0.55618236],
       [-0.12395165,  0.11799627, -0.29239672],
       [ 0.19501384, -0.24643927, -0.54155392],
       [ 0.46368609, -0.01020692, -0.2686671 ],
       [ 0.21723097,  0.53499804, -0.3288631 ],
       [ 0.48568427,  0.18817638, -0.66564165],
       [-0.09097604,  0.31544902,  0.31615664],
       [ 0.47305574,  0.0651856 , -0.57309809],
       [-0.52593941,  0.40127842, -0.21241424],
       [ 0.31030249,  0.23289228, -0.98879699],
       [-0.46335954, -0.63778809, -0.12070011],
       [ 0.22355744, -0.5229901 ,  0.406

## Training

In [None]:
# Train the model on the selected structures. The returned history `h` is
# indexed by string keys; every entry whose key contains 'loss' is plotted
# against h['epoch'].
h = nnp.fit_model(structures)

for key in h:
    if 'loss' not in key:
        continue  # skip non-loss entries such as 'epoch'
    plt.plot(h['epoch'], h[key], label=key)
plt.legend();

In [None]:
# nnp.save()

## Validation

### Energy

In [None]:
# Parity plot: predicted vs. reference total energy, one point per structure.
# The loop variable is named `structure` so it does not read like the sample
# structure `s` defined earlier in the notebook.
print(f"{len(structures)=}")
true_energy = [structure.total_energy for structure in structures]
pred_energy = [nnp(structure) for structure in structures]
plt.scatter(true_energy, pred_energy, label='NNP')
plt.plot(true_energy, true_energy, 'r', label="REF")  # y = x reference line
plt.xlabel("true energy")
plt.ylabel("pred energy")
plt.legend()
plt.show()

### Force

In [None]:
# Parity plot: predicted vs. reference force components along one axis,
# pooled over all atoms of all selected structures.
print(f"{len(structures)=}")
true_forces = []
pred_forces = []
for structure in structures:
    # NOTE(review): `structure.force` is assumed to hold the reference forces
    # with the same per-atom layout as the prediction -- confirm against the
    # Structure class.
    true_forces.append(structure.force)
    # Use `compute_forces`, the method name that runs successfully in the
    # timing cell above; `compute_force` does not match it.
    pred_forces.append(nnp.compute_forces(structure))

dim = 0  # column of the force arrays to plot (0, 1, 2 -> x, y, z)
to_axis = dict(enumerate('xyz'))
# Stack the per-structure arrays along the first axis into one array each.
true_forces = jnp.concatenate(true_forces, axis=0)
pred_forces = jnp.concatenate(pred_forces, axis=0)

plt.scatter(true_forces[:, dim], pred_forces[:, dim], label='NNP')
plt.plot(true_forces[:, dim], true_forces[:, dim], 'r', label='REF')  # y = x reference line

label = f"force [{to_axis[dim]}]"
plt.ylabel("pred " + label)
plt.xlabel("true " + label)
plt.legend()
plt.show()