# Quick start

The scripts below illustrate how to use the different modules of JAXIP, a package built on [JAX](https://jax.readthedocs.io/).

In [None]:
import os

# JAX reads these variables when it is first imported, so this cell must run
# before any ``import jax``.
os.environ["JAX_ENABLE_X64"] = "1"  # enable 64-bit (double-precision) arrays
# Force the CPU backend. ``JAX_PLATFORMS`` is the current variable;
# ``JAX_PLATFORM_NAME`` is its deprecated predecessor, kept for older JAX releases.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["JAX_PLATFORM_NAME"] = "cpu"

In [None]:
from jaxip.types import dtype as default_dtype
import jax.numpy as jnp
import jax
# Uncomment to set jaxip's default float dtype explicitly:
# default_dtype.FLOATX = jnp.float64
print(f"default_dtype.FLOATX={default_dtype.FLOATX!r}")

In [None]:
import numpy as np
from pathlib import Path
import matplotlib.pylab as plt
import seaborn as sns
from tqdm import tqdm

## Dataset

In [None]:
# Directory that holds the RuNNer-format dataset file (input.data).
base_dir = Path(".") / "H2O_2"

In [None]:
from jaxip.datasets import RunnerStructureDataset
# Load every structure from the dataset file; ``persist=True`` is passed
# through to the dataset as in the original setup.
structures = RunnerStructureDataset(base_dir / "input.data", persist=True)
print("Total number of structures:", len(structures))
structures

In [None]:
# Keep only a small subset of structures so the rest of the walkthrough runs
# quickly; ``min`` avoids an IndexError when the dataset is smaller than that.
num_structures = min(10, len(structures))
structures = [structures[i] for i in range(num_structures)]

##### Data loader

In [None]:
# from torch.utils.data import DataLoader

##### Split structures into training and validation sets

In [None]:
# import torch
# validation_split = 0.032
# nsamples = len(structures)
# split = int(np.floor(validation_split * nsamples))
# train_structures, valid_structures = torch.utils.data.random_split(structures, lengths=[nsamples-split, split])
# structures = valid_structures

## Structure

In [None]:
# Use the first structure as the running example for the cells below.
s = structures[0]
s

In [None]:
# Optional: convert the structure to an ASE ``Atoms`` object to visualize it
# or write it out as a VASP POSCAR file.
# from ase.visualize import view
# from ase.io.vasp import write_vasp
# atoms = s.to_ase_atoms()
# atoms
# view(atoms)
# write_vasp('POSCAR', atoms)

##### Compare two structures

In [None]:
from jaxip.utils.compare import compare
# Compare the first two structures (see ``jaxip.utils.compare`` for what is reported).
compare(structures[0], structures[1])

##### Calculate distances between atoms

In [None]:
# Distances measured from atom 0 (presumably to every atom in ``s`` — confirm
# in ``calculate_distance``); the second return value is discarded.
dis, _ = s.calculate_distance(atom_index=0)
dis.shape

In [None]:
# Histogram of the distances, with their mean marked by a red vertical line.
sns.displot(dis, bins=20)
plt.axvline(dis.mean(), color='r')

##### Add/remove per-atom energy offset

In [None]:
# structure = structures[0]
# atom_energy = {'O': 2.4, 'H': 1.2}

# structure.add_energy_offset(atom_energy)
# structure.total_energy

## Descriptor
Atomic environment descriptors (atom-centered symmetry functions, ACSF).

In [None]:
from jaxip.descriptors.acsf import ACSF, G2, G3, G9, CutoffFunction

In [None]:
# acsf = ACSF('Ne')
# cfn = CutoffFunction(3.0, cutoff_type='tanhu')
# acsf.add(G2(cfn, eta=1.00, r_shift=0.00), "Ne")
# acsf.add(G2(cfn, eta=1.00, r_shift=0.25), "Ne")
# acsf.add(G2(cfn, eta=1.00, r_shift=0.50), "Ne")
# acsf.add(G2(cfn, eta=1.00, r_shift=0.75), "Ne")
# acsf.add(G2(cfn, eta=1.00, r_shift=1.00), "Ne")

# acsf

In [None]:
# Build an ACSF descriptor for central element 'O' from G2, G3 and G9
# symmetry functions that all share a single cutoff function (cutoff 12.0).
acsf = ACSF('O')

cfn = CutoffFunction(12.0)
g2_1 = G2(cfn, 0.0, 0.001)
g2_2 = G2(cfn, 0.0, 0.01)
g3_1 = G3(cfn, 0.2, 1.0, 1.0, 0.0)
# Must be a G9 instance: constructing it with G3 silently duplicated g3_1.
g9_1 = G9(cfn, 0.2, 1.0, 1.0, 0.0)

# Register each function together with the neighbor element(s) it acts on.
acsf.add(g2_1, 'H')
acsf.add(g2_2, 'H')
acsf.add(g3_1, 'H', 'H')
acsf.add(g3_1, 'H', 'O')
acsf.add(g9_1, 'H', 'O')
acsf

In [None]:
# Evaluate the descriptor on structure ``s``; presumably one row per atom (confirm).
val = acsf(s)
val[1]
# val.shape

In [None]:
# Distribution of descriptor column 2 across all rows of ``val``.
sns.displot(val[:, 2], bins=20)

#### Gradient

In [None]:
# Time one descriptor-gradient evaluation (IPython ``%time`` magic).
# The integer arguments are passed straight to ``ACSF.grad``; see its signature.
%time acsf.grad(s, 0, 3)

In [None]:
# Uncomment to time several gradient evaluations at once:
# %time [acsf.grad(structures[4], i, 0) for i in range(4)]

## Scaler
Descriptor scaler

In [None]:
from jaxip.descriptors.scaler import DescriptorScaler

In [None]:
scaler = DescriptorScaler(scale_type='scale_center')
# ``fit`` is called once per structure; presumably it accumulates statistics
# across calls rather than resetting them (confirm in DescriptorScaler.fit).

for structure in tqdm(structures):
    x = acsf(structure)
    scaler.fit(x)

scaler

In [None]:
# Apply the fitted scaler to each structure's descriptor values and stack the
# per-structure results row-wise into a single array.
scaled_x = jnp.concatenate(
    [scaler(acsf(structure)) for structure in tqdm(structures)],
    axis=0,
)
scaled_x.shape

In [None]:
# Distribution of scaled descriptor column 5: red line = its mean, black line = 0.
# With 'scale_center' the mean is presumably expected near 0 (confirm).
sx = scaled_x[:, 5]
sns.displot(sx, bins=30)
plt.axvline(sx.mean(), color='r', lw=3);
plt.axvline(0, color='k');

## Model

In [None]:
from jaxip.models.nn import NeuralNetworkModel
from jaxip.models.initializer import UniformInitializer
from flax import linen as nn

In [None]:
nn = NeuralNetworkModel(
    hidden_layers=((8, 't'), (8, 't')),
    kernel_initializer=UniformInitializer(weights_range=(-1, 1)),
    # param_dtype=jnp.float64,
)

In [None]:
rng = jax.random.PRNGKey(2022) # PRNG Key
x = jnp.ones(shape=(8, acsf.num_symmetry_functions)) # Dummy Input
params = nn.init(rng, x) # Initialize the parameters
jax.tree_map(lambda x: x.shape, params) # Check the parameters

In [None]:
eng = nn.apply(params, scaled_x[:, :])

In [None]:
sns.displot(eng, bins=30);