### Potential example
An example notebook showing how to build, train, and validate a high-dimensional neural network potential (HDNNP) with `jaxip`.

In [None]:
# !gpustat

In [None]:
import os

# Configure JAX through environment variables; this cell must run before `import jax`
# (imports cell below) so that JAX picks these settings up.
os.environ["JAX_ENABLE_X64"] = "1"  # enable 64-bit (double precision) arrays
# `JAX_PLATFORMS` is the current backend selector; `JAX_PLATFORM_NAME` is its deprecated
# predecessor, kept so that older JAX versions are also pinned to the CPU.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["JAX_PLATFORM_NAME"] = "cpu"
# Uncomment to stop XLA from preallocating most of the GPU memory up front:
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

### Imports

In [None]:
import logging
import numpy as np
from pathlib import Path
import matplotlib.pylab as plt
import seaborn as sns
import random
from tqdm import tqdm
from collections import defaultdict
import jax.numpy as jnp
import jax

import jaxip
from jaxip import logger
from jaxip.types import dtype as default_dtype
from jaxip.datasets import RunnerStructureDataset
from jaxip.potentials import NeuralNetworkPotential
from jaxip.logger import LoggingContextManager

In [None]:
# Optional diagnostics / configuration (uncomment as needed):
#   print(jaxip.__doc__)
#   print(f"version: {jaxip.__version__}")
#   jaxip.logger.set_logging_level(logging.DEBUG)
#   default_dtype.FLOATX = jnp.float64

# Show the dtype configured in `jaxip.types.dtype.FLOATX`.
floatx_dtype = default_dtype.FLOATX.dtype
print(f"default dtype: {floatx_dtype}")

## Dataset

In [None]:
# Directory (relative to the working directory) holding the `input.data` and
# `input.nn` files read below.
base_dir = Path("GRN")

In [None]:
# Load every structure from the dataset file. `persist=True` is forwarded to jaxip
# (NOTE(review): presumably keeps parsed structures cached — confirm in the jaxip docs).
# Alternative with a transform (`ToStructure` is not imported in this notebook):
#   RunnerStructureDataset(base_dir / "input.data", transform=ToStructure(r_cutoff=3.0), persist=True)
data_path = base_dir / "input.data"
structures = RunnerStructureDataset(data_path, persist=True)
print("Total number of structures:", len(structures))
structures

In [None]:
# Work on a small random subset so the example trains quickly.
# Sample *without* replacement (`random.choices` can pick the same structure twice)
# using a seeded generator, so re-running the notebook selects the same subset.
# NOTE: this rebinds `structures`; re-run the dataset cell first if you want to
# resample from the full dataset rather than from the current subset.
N_SAMPLES = 10
SEED = 42
rng = random.Random(SEED)
indices = rng.sample(range(len(structures)), k=min(N_SAMPLES, len(structures)))
structures = [structures[i] for i in indices]

In [None]:
# energies = jnp.asarray([x.total_energy for x in structures]).reshape(-1)
# print("Energy difference:", max(energies) - min(energies))
# sns.histplot(energies);

In [None]:
# Use the first sampled structure as the running example and display it.
(s := structures[0])

In [None]:
# with LoggingContextManager(level=logging.DEBUG):
# structures[0].to_dict()

In [None]:
# from ase.visualize import view
# atoms = s.to_ase_atoms()
# atoms
# view(atoms, viewer="x3d", repeat=2)

In [None]:
# from ase.io import read, write
# write("atoms.png", atoms * (2, 2, 1), rotation='30z,-80x')
# write("atoms.xyz", atoms * (2, 2, 1))
# ![atoms](atoms.png)

In [None]:
# from jaxip.structure import Structure
# sp = Structure.create_from_ase(atoms)
# view(sp.to_ase_atoms(), viewer="x3d", repeat=3)

## Potential

In [None]:
# Build the potential from `input.nn` in the dataset directory and display it.
potential_file = base_dir / "input.nn"
nnp = NeuralNetworkPotential.create_from(potential_file)
nnp

In [None]:
# from jaxip.potentials import NeuralNetworkPotential
# from jaxip.potentials import NeuralNetworkPotentialSettings as Settings
    
# settings = Settings(**nnp.settings.dict())
# nnp2 = NeuralNetworkPotential(settings)    
    
# nnp.settings.to_json('h2o.json')
# settings = Settings.from_json('h2o.json')
# nnp2 = NeuralNetworkPotential(settings)

# nnp2

##### Extrapolation warnings

In [None]:
# nnp.set_extrapolation_warnings(100)

##### Fit scaler

In [None]:
# Fit the potential's scaler on the sampled `structures` before training.
# NOTE(review): presumably this fits descriptor scaling statistics — confirm in the jaxip docs.
nnp.fit_scaler(structures)

In [None]:
time nnp(s)

In [None]:
time force = nnp.compute_force(s)

### Training

In [None]:
# Train on the sampled structures. `h` is a dict-like training history: it is
# indexed by 'epoch' and by one or more keys containing 'loss' (used below).
h = nnp.fit_model(structures, epochs=10, batch_size=1)

# Plot every recorded loss curve against the epoch.
for sub in h:
    if 'loss' in sub:
        plt.plot(h['epoch'], h[sub], label=sub)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.title("Training history")
plt.legend()
plt.show()

### Validation

##### Energy

In [None]:
# Compare predicted total energies with the reference values of the sampled structures.
print(f"{len(structures)=}")
true_energy = [s.total_energy for s in structures]
pred_energy = [nnp(s) for s in structures]
ii = range(len(structures))

# Parity plot: points on the red y = x line are perfect predictions.
plt.scatter(true_energy, pred_energy, label='NNP')
plt.plot(true_energy, true_energy, 'r', label="REF")
plt.xlabel("true energy")
plt.ylabel("pred energy")
plt.legend()
plt.show()

# Per-structure comparison; reference energies are labelled REF, predictions NNP.
# NOTE(review): only drawn when the potential includes Ne — presumably specific to
# one example dataset; confirm whether this guard is still wanted.
if "Ne" in nnp.elements:
    plt.plot(ii, true_energy, '.-', label="REF")
    plt.plot(ii, pred_energy, '.-', label="NNP")
    plt.xlabel("structure index")
    plt.ylabel("energy")
    plt.legend()
    plt.show()

##### Force

In [None]:
# Parity plots of predicted vs. reference forces, one figure per element.
# (`jnp`, `defaultdict` and `plt` come from the imports cell at the top.)
true_forces = defaultdict(list)
pred_forces = defaultdict(list)

print(f"{len(structures)=}")
for structure in structures:
    # Both results are indexed by element symbol (see the inner loop).
    true_forces_per_structure = structure.get_forces()
    pred_forces_per_structure = nnp.compute_force(structure)

    for element in nnp.elements:
        true_forces[element].append(true_forces_per_structure[element])
        pred_forces[element].append(pred_forces_per_structure[element])

dim = 0  # Cartesian component to plot: 0 -> x, 1 -> y, 2 -> z
to_axis = dict(enumerate('xyz'))
for element in nnp.elements:
    # Stack the per-structure arrays (presumably (n_atoms, 3) each) into one array per element.
    true_forces[element] = jnp.concatenate(true_forces[element], axis=0)
    pred_forces[element] = jnp.concatenate(pred_forces[element], axis=0)

    true_component = true_forces[element][:, dim]
    pred_component = pred_forces[element][:, dim]
    plt.scatter(true_component, pred_component, label='NNP')
    plt.plot(true_component, true_component, 'r', label='REF')  # y = x reference
    plt.title(f"Element: {element}")

    label = f"force [{to_axis[dim]}]"
    plt.ylabel("pred " + label)
    plt.xlabel("true " + label)
    plt.legend()
    plt.show()
    