# Getting started

These examples illustrate how to use [Pantea](https://pantea.readthedocs.io/).

In [1]:
# !gpustat

## Imports

In [2]:
# import os
# os.environ["JAX_PLATFORM_NAME"] = "cpu"  # disable GPU

In [3]:
# import logging
# from pantea.logger import set_logging_level
# set_logging_level(logging.INFO)

## Dataset

### RuNNer
Read input dataset in [RuNNer](https://www.uni-goettingen.de/de/560580.html) format.

In [4]:
from pantea.datasets import Dataset
# Open the RuNNer-format file "input.data" as a dataset of structures.
# NOTE(review): persist=False presumably disables keeping loaded structures in memory — confirm in the Pantea docs.
structures = Dataset.from_runner("input.data", persist=False)
print("Total number of structures:", len(structures))
# structures.preload()
# Last expression: show the dataset repr as the cell output.
structures

Total number of structures: 20


Dataset(datasource=RunnerDataSource(filename='input.data', dtype=float64), persist=False)

#### Split train and validation structures

In [5]:
# Disabled example: random train/validation split using torch.
# NOTE(review): enabling this also requires `import numpy as np` (np.floor is used below).
# import torch
# validation_split = 0.032
# nsamples = len(structures)
# split = int(np.floor(validation_split * nsamples))
# train_structures, valid_structures = torch.utils.data.random_split(structures, lengths=[nsamples-split, split])
# structures = valid_structures

## Structure

In [33]:
# Use the first structure of the dataset as the running example.
structure = structures[0]
structure

Structure(natoms=12, elements=('H', 'O'), dtype=float64)

In [34]:
from ase.visualize import view
# Convert the structure for use with ASE; `atoms` is what the ASE viewer/writer calls receive.
atoms = structure.to_ase()
# Interactive viewer call, left disabled.
# view(atoms, viewer='ngl') # ase, ngl

In [8]:
from ase.io.vasp import write_vasp
# Side effect: writes a file named "POSCAR" (relative path, so in the current working directory).
write_vasp('POSCAR', atoms)

### Compare between structures

In [9]:
# from pantea.utils.compare import compare
# compare(structures[0], structures[1])

### Calculate distance between atoms

In [10]:
from pantea.atoms import calculate_distances
# Compute the distances for `structure`; the result is indexed with two axes below.
distances = calculate_distances(structure)
# Show the first five entries of row 0 (presumably distances from atom 0, the leading 0.0 being its distance to itself — confirm).
distances[0, :5]

Array([0.        , 2.24720275, 3.85385356, 4.84207409, 7.55265933],      dtype=float64)

In [11]:
# Disabled example: histogram of all distances with the mean marked in red.
# NOTE(review): enabling this requires `import seaborn as sns` and `import matplotlib.pyplot as plt`.
# sns.displot(distances.flatten(), bins=20)
# plt.axvline(distances.mean(), color='r');

### Find neighboring atoms

In [12]:
from pantea.atoms import Neighbor

# Build the neighbor information of `structure` with a cutoff radius of 10.0.
neighbor = Neighbor.from_structure(structure, r_cutoff=10.0)
print(neighbor)
# Count the set entries of atom 0's mask with the array's own reduction instead of
# the Python builtin sum(), which iterates over the array element by element.
# Assumes `neighbor.masks[0]` is an array (NumPy/JAX) exposing `.sum()` — TODO confirm.
print("Number of neighbors for atom index 0:", neighbor.masks[0].sum())

Neighbor(r_cutoff=10.0)
Number of neighbors for atom index 0: 11


### Per-atom energy offset

In [13]:
# Fetch the first structure of the dataset again before modifying it.
structure = structures[0]
# Energy offset per atom, keyed by element symbol.
atom_energy = {'O': 2.4, 'H': 1.2}

# Called for its effect on `structure`; the return value is not used here.
structure.add_energy_offset(atom_energy)
# Display the structure's total energy after the offset has been applied.
structure.total_energy

Array(-12.93900273, dtype=float64)

## Descriptor

Atomic environment descriptor.

In [14]:
from pantea.descriptors import ACSF
from pantea.descriptors.acsf import G2, G3, G9, CutoffFunction, NeighborElements

In [15]:
# Define cutoff, radial, and angular symmetry functions
cfn = CutoffFunction.from_type("tanhu", r_cutoff=12.0)
g2_1 = G2(cfn, 0.0, 0.001)
g2_2 = G2(cfn, 0.0, 0.01)
g3_1 = G3(cfn, 0.2, 1.0, 1.0, 0.0)
# Second G3 instance with the same parameters as g3_1; it is paired with a
# different neighbor-element combination below. It was previously named `g9_1`,
# which was misleading because the object is constructed with G3.
# NOTE(review): if a G9 symmetry function was actually intended, construct it
# with G9(...) instead and re-run the notebook to refresh the recorded outputs.
g3_2 = G3(cfn, 0.2, 1.0, 1.0, 0.0)
# Create an ACSF descriptor for Oxygen atoms with multiple symmetry functions
acsf = ACSF(
    central_element='O',
    radial_symmetry_functions=(
        (g2_1, NeighborElements('H')),
        (g2_2, NeighborElements('H')),
    ),
    angular_symmetry_functions=(
        (g3_1, NeighborElements('H', 'H')),
        (g3_2, NeighborElements('H', 'O')),
    ),
)
acsf

AtomCenteredSymmetryFunction(central_element='O', num_symmetry_functions=4)

### Computing descriptor values

In [16]:
# Evaluate the descriptor on the structure; the displayed array has one column per symmetry function.
descriptor_values = acsf(structure) # only for O atoms
descriptor_values

Array([[1.14844196e+00, 9.88176531e-01, 1.47813356e-04, 4.32482880e-04],
       [1.08272759e+00, 9.35614402e-01, 2.52484053e-04, 7.64628911e-06],
       [1.21152837e+00, 1.01618940e+00, 3.80667373e-05, 3.74757210e-04],
       [1.01457901e+00, 8.42276766e-01, 4.21045778e-05, 1.53015035e-08]],      dtype=float64)

### Gradient

In [17]:
gradients = acsf.grad(structure) # gradient with respect to the atom positions
# Display only the first entry along the leading axis.
gradients[:1, ...]

Array([[[-0.02516027,  0.00975464,  0.08007459],
        [-0.01986903, -0.01250941,  0.06779186],
        [ 0.00015445, -0.00028165, -0.00015613],
        [-0.00093546,  0.00050747,  0.00044732]]], dtype=float64)

## Scaler

Descriptor scaler.

In [18]:
from pantea.descriptors import DescriptorScaler, ScalerParams
from tqdm import tqdm

### Fitting scaling parameters

In [19]:
scaler = DescriptorScaler.from_type('scale_center')

# Accumulate the scaling parameters over the dataset, one structure at a time:
# the first structure initializes them, every later one updates them.
# Note: `structure` and `descriptor_values` stay bound to the LAST structure
# after the loop, and the cells that follow rely on those values.
for index, structure in enumerate(tqdm(structures)):
    descriptor_values = acsf(structure)
    scaler_params = (
        scaler.fit(descriptor_values)
        if index == 0
        else scaler.partial_fit(scaler_params, descriptor_values)
    )

scaler

100%|██████████| 20/20 [00:00<00:00, 44.04it/s]


DescriptorScaler(transform='_scale_center', scale_range=(0.0, 1.0))

In [20]:
# `descriptor_values` here is the value left over from the last iteration of the
# scaler-fitting loop (i.e. the last structure in the dataset), not structures[0].
scaled_descriptor_values = scaler(scaler_params, descriptor_values)
scaled_descriptor_values

Array([[ 0.11808065,  0.16078358, -0.11087884,  0.32829958],
       [ 0.09382786,  0.1297158 ,  0.2480281 , -0.33211773],
       [ 0.27612704,  0.18111383, -0.24553955,  0.29089544],
       [-0.04685936, -0.06034315, -0.1896889 , -0.33306557]],      dtype=float64)

## Model

In [21]:
from pantea.models import NeuralNetworkModel
from pantea.models.nn import UniformInitializer
from flax import linen as nn
import jax
import jax.numpy as jnp

In [22]:
# Network with two hidden layers, each given as (8, 'tanh').
model = NeuralNetworkModel(hidden_layers=((8, 'tanh'), (8, 'tanh')))
model

NeuralNetworkModel(hidden_layers=((8, 'tanh'), (8, 'tanh')), dtype=float64)

In [23]:
# Fixed-seed PRNG key used for parameter initialization.
rng = jax.random.PRNGKey(2022)
# Dummy input of ones: 8 rows, one column per symmetry function of the descriptor.
inputs = jnp.ones((8, acsf.num_symmetry_functions))

# Initialize the model parameters from the key and the dummy input.
model_params = model.init(rng, inputs)
# Inspect the parameter tree by replacing every leaf with its shape.
jax.tree.map(lambda leaf: leaf.shape, model_params)

{'params': {'layers_0': {'bias': (8,), 'kernel': (4, 8)},
  'layers_2': {'bias': (8,), 'kernel': (8, 8)},
  'layers_4': {'bias': (1,), 'kernel': (8, 1)}}}

In [24]:
# Run the network on the scaled descriptor values using the freshly initialized parameters.
energies = model.apply(model_params, scaled_descriptor_values)
energies # per atom energies

Array([[-0.01620077],
       [-0.18288607],
       [ 0.0744489 ],
       [ 0.06749975]], dtype=float64)

### Atomic Potential

An atomic potential calculates the energy of a specific element in structures. It forms the basic building block of the final potential, which typically contains multiple elements. An atomic potential bundles all the necessary components, such as the descriptor, scaler, and model, in order to output the per-atom energy.

In [35]:
from pantea.potentials.nnp import AtomicPotential

In [26]:
# Bundle the descriptor, scaler, and model defined above into a single atomic potential.
atomic_potential = AtomicPotential(
    descriptor=acsf,
    scaler=scaler,
    model=model,
)

atomic_potential

AtomicPotential(
  descriptor=AtomCenteredSymmetryFunction(central_element='O', num_symmetry_functions=4),
  scaler=DescriptorScaler(transform='_scale_center', scale_range=(0.0, 1.0)),
  model=NeuralNetworkModel(hidden_layers=((8, 'tanh'), (8, 'tanh')), dtype=float64),
)

In [27]:
# Pass the inner "params" entry of the initialized model parameters together with the
# fitted scaler parameters. `structure` is still the last structure visited by the
# scaler-fitting loop, so the result matches the manual descriptor -> scaler -> model output.
energies =  atomic_potential.apply(model_params["params"], scaler_params, structure)
energies

Array([[-0.01620077],
       [-0.18288607],
       [ 0.0744489 ],
       [ 0.06749975]], dtype=float64)

## Neural Network Potential

An instance of a neural network potential (NNP), including the descriptor, scaler, and model for multiple elements, can be initialized directly from the input potential files.

In [28]:
from pantea.datasets import Dataset
from pantea.potentials import NeuralNetworkPotential
from ase.visualize import view
from pathlib import Path

### Read dataset

In [29]:
# Directory containing the input files of this example.
base_dir = Path(".")

# Atomic data
structures = Dataset.from_runner(base_dir / "input.data")

# Rebind `structure` to the first structure of the freshly read dataset.
structure = structures[0]
structure

Structure(natoms=12, elements=('H', 'O'), dtype=float64)

### Load potential parameters

In [30]:
# Potential
nnp = NeuralNetworkPotential.from_runner(Path(base_dir, "input.nn"))

# nnp.save()
nnp.load()

### Predictions

The first call includes a warm-up period because of lazy class loading and just-in-time (JIT) compilation.

In [31]:
# Calling the potential on a structure yields its predicted total energy (displayed as a scalar array).
total_energy = nnp(structure)
total_energy

Array(-33.97294888, dtype=float64)

In [32]:
# Compute the forces for the structure and display only the first five rows (three components each).
forces = nnp.compute_forces(structure)
forces[:5]

Array([[-0.00120265, -0.00401054,  0.02040878],
       [-0.08127553, -0.05667687,  0.0694481 ],
       [ 0.03140665,  0.09825961,  0.0487879 ],
       [ 0.08668581, -0.00724281,  0.03676635],
       [-0.02163175, -0.0238923 ,  0.01323156]], dtype=float64)