# Potential training
An example notebook showing how to load a dataset, fit the descriptor scaler, and train and validate a high-dimensional neural network potential (HDNNP) with pantea.

## Imports

In [4]:
import pantea
from pantea.types import default_dtype
from pantea.datasets import Dataset
from pantea.potentials import NeuralNetworkPotential 
from pantea.logger import LoggingContextManager

import logging
import numpy as np
from pathlib import Path
import matplotlib.pylab as plt
import seaborn as sns
import random
from tqdm import tqdm
from collections import defaultdict
import jax.numpy as jnp
import jax

In [5]:
# Uncomment to (presumably) enable DEBUG-level logging from pantea:
# pantea.logger.set_logging_level(logging.DEBUG) 

# Set pantea's default floating-point dtype to float64.
# NOTE(review): JAX uses 32-bit floats unless x64 is enabled; pantea presumably
# enables it when FLOATX is set — the dataset/array outputs below show float64.
default_dtype.FLOATX = jnp.float64
print(f"default dtype: {default_dtype.FLOATX.dtype}")
# Report the first device JAX lists (the captured output shows a CUDA device).
print(f"default device: {jax.devices()[0]}")

default dtype: float64
default device: cuda:0


## Dataset

In [6]:
# Directory holding this example's RuNNer input files (input.data, input.nn).
base_dir = Path('.') / 'GRN'

In [7]:
# Load the reference structures from the RuNNer-format data file.
# NOTE(review): `persist=True` presumably keeps loaded structures cached in memory — confirm in pantea docs.
structures = Dataset.from_runner(base_dir / "input.data", persist=True)
print("Total number of structures:", len(structures))
structures

Total number of structures: 801


Dataset(datasource=RunnerDataSource(filename='GRN/input.data', dtype=float64), persist=True)

In [8]:
# indices = random.choices(range(len(structures)), k=100)
# structures = [structures[i] for i in range(len(structures))] # len(structures) 

In [9]:
# import torch
# validation_split = 0.10
# nsamples = len(structures)
# split = int(np.floor(validation_split * nsamples))
# train_structures, valid_structures = torch.utils.data.random_split(structures, lengths=[nsamples-split, split])
# structures = valid_structures

In [10]:
# Keep the first structure as the running example for the single-structure
# energy and force evaluations timed further down.
s = structures[0]
s

Structure(natoms=24, elements=('C',), dtype=float64)

In [11]:
# energies = jnp.asarray([x.total_energy for x in structures]).reshape(-1)
# print("Energy difference:", max(energies) - min(energies))
# sns.histplot(energies);

In [12]:
# with LoggingContextManager(level=logging.DEBUG):
# structures[0].to_dict()

In [13]:
# from ase.visualize import view
# atoms = s.to_ase()
# view(atoms, viewer="x3d", repeat=1)

In [14]:
# from ase.io import read, write
# write("atoms.png", atoms * (2, 2, 1), rotation='30z,-80x')
# write("atoms.xyz", atoms * (2, 2, 1))
# ![atoms](atoms.png)

In [15]:
# from pantea.atoms import Structure
# sp = Structure.from_ase(atoms)
# view(sp.to_ase(), viewer="x3d", repeat=3)

## Potential

In [16]:
# Build the potential (per-element descriptor, scaler and model, as the output
# below shows) from the "input.nn" settings file.
nnp = NeuralNetworkPotential.from_file(base_dir / "input.nn")
nnp

NeuralNetworkPotential(atomic_potential={'C': AtomicPotential(
  descriptor=ACSF(central_element='C', symmetry_functions=30),
  scaler=DescriptorScaler(scale_type='center', scale_min=0.0, scale_max=1.0),
  model=NeuralNetworkModel(hidden_layers=((15, 'tanh'), (15, 'tanh')), param_dtype=float64),
)})

In [17]:
# nnp.load()
# nnp.load_scaler()

### Extrapolation warnings

In [18]:
# nnp.set_extrapolation_warnings(100)

### Fit scaler

In [19]:
# Keep only a small subset so scaler fitting and training stay fast in this example.
# Clamp to the dataset size so a smaller dataset does not raise IndexError.
# NOTE: the validation cells below evaluate on this same subset, i.e. they report
# training error, not generalization error on held-out structures.
NUM_STRUCTURES = 5
structures = [structures[index] for index in range(min(NUM_STRUCTURES, len(structures)))]

In [20]:
# Fit the descriptor scaler on the selected structures.
# To see debug output, wrap the call in (presumably enables DEBUG logging):
# with LoggingContextManager(level=logging.DEBUG):
nnp.fit_scaler(structures)

Fitting descriptor scaler...


100%|██████████| 5/5 [00:06<00:00,  1.21s/it]

Done.





In [21]:
time nnp(s)

CPU times: user 5.12 s, sys: 175 ms, total: 5.3 s
Wall time: 4.5 s


Array(2.46366166, dtype=float64)

In [22]:
time nnp.compute_forces(s)

CPU times: user 28.4 s, sys: 706 ms, total: 29.1 s
Wall time: 26.6 s


Array([[-0.73818518, -1.11250519,  2.19987526],
       [-0.60619055,  1.33811234,  0.23699985],
       [ 0.39212853,  1.86935634, -1.38083394],
       [-1.24370451,  1.38747383, -0.16639707],
       [ 1.17766708, -0.06752427, -0.27530439],
       [-0.90934336, -1.97845911,  1.25125193],
       [-0.08749226, -1.3495124 ,  1.21127079],
       [ 0.6463172 , -0.81786491,  1.80367847],
       [-0.95089421, -1.12067475, -1.07121341],
       [-1.98676834,  2.32836088,  1.22984266],
       [-0.67455432,  0.96647708, -0.77369288],
       [ 0.62557632, -1.00776267, -1.20962975],
       [ 1.50403599,  0.24804631, -0.56286071],
       [ 0.83629443,  2.39212317, -0.65869627],
       [ 1.844709  ,  1.33191195, -1.44894777],
       [-0.52273629,  1.91610702,  0.86284399],
       [ 1.44972348,  0.02283899, -1.30880034],
       [-1.74267315,  1.41118732, -0.40835022],
       [ 0.93299649,  0.35661005, -2.28477157],
       [-1.45420376, -2.42909438, -0.36560638],
       [ 0.59659081, -2.43054665,  0.849

## Training

In [None]:
# Train the model. `fit_model` returns a training history that is indexed by
# metric name below (presumably a dict of per-epoch values — confirm).
h = nnp.fit_model(structures)

# Plot every recorded loss curve against the epoch number.
fig, ax = plt.subplots()
for metric in h:
    if 'loss' in metric:
        ax.plot(h['epoch'], h[metric], label=metric)
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax.set_title("Training history")
ax.legend();

In [None]:
# nnp.save()

## Validation

### Energy

In [None]:
# Compare reference total energies with NNP predictions for every structure.
print(f"{len(structures)=}")
# Flatten to 1-D numpy arrays so scalar and length-1 array energies are handled alike.
true_energy = np.asarray([structure.total_energy for structure in structures]).reshape(-1)
pred_energy = np.asarray([nnp(structure) for structure in structures]).reshape(-1)
rmse = np.sqrt(np.mean((pred_energy - true_energy) ** 2))

fig, ax = plt.subplots()
ax.scatter(true_energy, pred_energy, label=f"NNP (RMSE = {rmse:.4g})")
# Identity line y = x across the reference range (a clean diagonal, not a
# polyline through unsorted data points).
lims = [true_energy.min(), true_energy.max()]
ax.plot(lims, lims, 'r', label="REF")
ax.set_xlabel("true energy")
ax.set_ylabel("pred energy")
ax.legend()
plt.show()

### Force

In [None]:
# Compare reference forces with NNP predictions, one panel per Cartesian component.
# (`jnp` is already imported in the imports cell at the top of the notebook.)
true_forces = []
pred_forces = []
print(f"{len(structures)=}")
for structure in structures:
    true_forces.append(structure.force)
    # `compute_forces` is the force API exercised earlier in this notebook;
    # the previous `compute_force` call was presumably a typo for it.
    pred_forces.append(nnp.compute_forces(structure))

# Stack the per-structure force arrays along axis 0 (one row per atom).
true_forces = jnp.concatenate(true_forces, axis=0)
pred_forces = jnp.concatenate(pred_forces, axis=0)

fig, axes = plt.subplots(1, 3, figsize=(15, 4.5))
for dim, (axis_name, ax) in enumerate(zip('xyz', axes)):
    true_component = true_forces[:, dim]
    pred_component = pred_forces[:, dim]
    ax.scatter(true_component, pred_component, label='NNP')
    # Identity line y = x across the reference range of this component.
    lims = [float(true_component.min()), float(true_component.max())]
    ax.plot(lims, lims, 'r', label='REF')
    label = f"force [{axis_name}]"
    ax.set_xlabel("true " + label)
    ax.set_ylabel("pred " + label)
    ax.legend()
fig.tight_layout()
plt.show()