In [1]:
import os
from concurrent.futures import ProcessPoolExecutor

In [2]:
import numpy as np
import matplotlib.pyplot as plt

In [3]:
from uf3.data import io
from uf3.data import geometry
from uf3.data import composition
from uf3.representation import bspline
from uf3.representation import process
from uf3.regression import least_squares
from uf3.forcefield import calculator
from uf3.forcefield import lammps
from uf3.util import parallel
from uf3.util import plotting



# $\text{UF}_{2,3}$ Demo: Elemental tungsten

**Data split**
- Training set: 1939 configurations (stratified 20% of the dataset)

- Holdout: 7754 configurations (remaining 80%)

**Inputs**
- ```w-14.xyz``` (30 MB)
- ```training_idx.txt``` (10 kB, included for reproducibility)

**Outputs**
- ```df_features_uf23.h5``` (650 MB)
- ```model_uf23.json``` (3 kB)

In [4]:
%%html
<style>
  /* Zero the left margin on every table; !important overrides competing rules. */
  table {margin-left: 0 !important;}
</style>

 Step         | Estimated Time 
:-------------|:--------------
Preprocessing | 10 seconds
Featurization | 2.5 core-hours (parallelizable)
Training      | 4 seconds
Prediction    | 3 seconds
Plotting      | 10 seconds

# User Parameters

```element_list (list)```: list of element symbols

```degree (int)```: truncation of many-body expansion. A value of 3 yields a two-and-three-body potential.

In [5]:
# Species to model, and the highest body order kept in the expansion
# (3 -> two- and three-body terms).
element_list = ["W"]
degree = 3

Initialize the ```ChemicalSystem``` and inspect interactions.

The elements involved in each interaction are sorted by electronegativity.

In [6]:
# Build the chemical system from the user parameters above.
chemical_system = composition.ChemicalSystem(
    element_list=element_list,
    degree=degree,
)
# Entry 2 of the interactions map lists the two-body (pair) interactions.
print(f"Pairs: {chemical_system.interactions_map[2]}")

Pairs: [('W', 'W')]


In [7]:
# Entry 3 of the interactions map lists the three-body (trio) interactions.
print(f"Trios: {chemical_system.interactions_map[3]}")

Trios: [('W', 'W', 'W')]


```r_min_map (dict)```: map of minimum pair distance per interaction (angstroms). 
    If unspecified, defaults to 1.0 for all interactions.
    
```r_max_map (dict)```: map of maximum pair distance per interaction (angstroms). 
    If unspecified, defaults to 6.0 angstroms for all interactions, which probably encompasses at least 2nd-nearest neighbors.
    
```resolution_map (dict)```: map of resolution (number of knot intervals) per interaction. 
    For the cubic basis, the number of basis functions equals three more than the number of knot intervals.
    This is, in turn, negated by ```trailing_trim```.
    If unspecified, defaults to 20 for all two-body interactions and 5 for three-body interactions.
    
```trailing_trim (int)```: number of trailing basis functions to trim, defaults to 3.
 - ```= 0```: hard cutoff at ```r_max```
 - ```= 1```: function goes to zero at ```r_max```
 - ```= 2```: first derivative goes to zero at ```r_max```
 - ```= 3```: second derivative goes to zero at ```r_max```

**Note: the demo's resolution and cutoffs (3.5-3.5-7.0Å, 5-5-10) are small to reduce runtime and filesize.**

**Results in the manuscript use (4.25-4.25-8.5Å, 10-10-20), requiring about 4 core-hours and 6 GB.**

In [8]:
# Per-interaction basis settings (distances in angstroms). The pair entry is
# a scalar; the three-body entry is a list of three values.
r_min_map = {
    ("W", "W"): 1.5,
    ("W", "W", "W"): [1.5] * 3,
}
r_max_map = {
    ("W", "W"): 5.5,
    ("W", "W", "W"): [3.5, 3.5, 7.0],
}
# Number of knot intervals per interaction.
resolution_map = {
    ("W", "W"): 25,
    ("W", "W", "W"): [5, 5, 10],
}
# Number of trailing basis functions to trim.
trailing_trim = 3

# Initialize basis

In [12]:
# Assemble the B-spline basis from the chemical system and the
# per-interaction cutoff / resolution settings defined above.
bspline_config = bspline.BSplineBasis(
    chemical_system,
    r_min_map=r_min_map,
    r_max_map=r_max_map,
    resolution_map=resolution_map,
    trailing_trim=trailing_trim,
)

```bspline_config.get_interaction_partitions()``` yields the number of coefficients for each n-body interaction (one-body terms, two-body terms, three-body terms, ...) as well as the starting index in the coefficient vector for each interaction.

In [13]:
# Element 0 of the partitions: number of coefficients for each interaction.
bspline_config.get_interaction_partitions()[0]

{'W': 1, ('W', 'W'): 28, ('W', 'W', 'W'): 144}

In [14]:
# Element 1 of the partitions: starting index of each interaction's block
# within the coefficient vector.
bspline_config.get_interaction_partitions()[1]

{'W': 0, ('W', 'W'): 1, ('W', 'W', 'W'): 29}

# Import potential

Start from here if you want to skip the long training part.

In [17]:
# Create a linear model on the basis and populate it from the saved model
# file, so the training step can be skipped.
model = least_squares.WeightedLinearModel(bspline_config)
model.load(filename="../tungsten_extxyz/model_uf23.json")

# Calculator wrapping the loaded model; attached to the geometry further down.
calc = calculator.UFCalculator(model)

In [1]:
# TODO: `geom` is consumed by the cell-scaling step further down but is not
# assigned anywhere in the visible notebook. Either restore the line below
# (it needs `df_data` to be loaded first) or construct a custom cell here.
# geom = df_data.iloc[3000]['geometry'].copy()  # 12-atom cell

# Build a JAX MD potential

In [33]:
import jax.numpy as jnp
from jax import jit, grad, vmap, value_and_grad

# Enable double precision. `from jax.config import config` was removed in
# recent JAX releases; accessing the config object through the top-level
# package works on both old and new versions.
import jax

config = jax.config  # keep the `config` name available for later cells
config.update("jax_enable_x64", True)

from jax_md import space, smap, energy, minimize, quantity, simulate, partition, util

from uf3.util import jax_utils
from uf3.jax.potentials import uf3_pair, uf3_neighbor

In [20]:
# Fetch the W-W pair and W-W-W trio spline objects held by the calculator.
ndspline2 = calc.pair_potentials[("W", "W")]
ndspline3 = calc.trio_potentials[("W", "W", "W")]

In [21]:
# Two-body spline: convert the coefficients to a JAX array and keep column 0
# of the trailing axis; take the first knot sequence.
coefficients2 = jnp.asarray(ndspline2.coefficients)[:, 0]
knots2 = ndspline2.knots[0]

print(coefficients2.shape)
print(knots2)



(28,)
[1.5  1.5  1.5  1.5  1.66 1.82 1.98 2.14 2.3  2.46 2.62 2.78 2.94 3.1
 3.26 3.42 3.58 3.74 3.9  4.06 4.22 4.38 4.54 4.7  4.86 5.02 5.18 5.34
 5.5  5.5  5.5  5.5 ]


In [22]:
# Three-body spline: convert the coefficients to a JAX array and keep index 0
# of the trailing axis; convert every knot sequence to a JAX array.
coefficients3 = jnp.asarray(ndspline3.coefficients)[:, :, :, 0]
knots3 = list(map(jnp.asarray, ndspline3.knots))

print(coefficients3.shape)
print(knots3)

(8, 8, 13)
[DeviceArray([1.5, 1.5, 1.5, 1.5, 1.9, 2.3, 2.7, 3.1, 3.5, 3.5, 3.5, 3.5],            dtype=float64), DeviceArray([1.5, 1.5, 1.5, 1.5, 1.9, 2.3, 2.7, 3.1, 3.5, 3.5, 3.5, 3.5],            dtype=float64), DeviceArray([1.5 , 1.5 , 1.5 , 1.5 , 2.05, 2.6 , 3.15, 3.7 , 4.25, 4.8 ,
             5.35, 5.9 , 6.45, 7.  , 7.  , 7.  , 7.  ], dtype=float64)]


JAX MD currently requires the cell to be at least twice as long as the longest potential cutoff in each dimension.
We therefore build an artificially larger cell. The total energy of the system scales by the same factor as the number of atoms, while the stress and the forces remain the same.

In [23]:
# Longest range considered: the larger of the last pair knot and the last
# knot of the first three-body knot sequence.
longest_range = max((knots2[-1], knots3[0][-1]))
# NOTE(review): `geom` is not assigned in any visible cell; it must exist
# before this line runs.
new_geom = jax_utils.scale_atoms(geom, longest_range)

In [25]:
# Reference values from the UF calculator attached to the enlarged cell.
new_geom.set_calculator(calc)
print(f"Energy: {new_geom.get_potential_energy()}")
print(f"Stresses (numerical): {new_geom.get_stress()}")
forces = new_geom.get_forces()
# Show only the first 13 rows of the force array.
print(f"Forces:\n {forces[:13]}")
print(f"Max force: {np.max(np.abs(forces))}")

Energy: -1511.356168483806
Stresses (numerical): [-0.27284373 -0.21269526 -0.19587426  0.00673861 -0.03251023 -0.06668803]
Forces:
 [[ 0.9247198   7.1954469  -0.51517995]
 [-3.09391715  0.34519563  0.45007935]
 [ 3.60074214  1.16019415  0.47379295]
 [ 4.3535802  -4.65271381  0.62942499]
 [-1.80424397 -2.67622904 -1.71355421]
 [-1.70798109  3.64077167  1.33918404]
 [-2.27916515 -3.19984066  1.08843362]
 [ 2.23753848  1.25635827 -1.37486839]
 [ 1.95888195 -0.86112301 -0.01602678]
 [-1.66730536 -6.8770263   0.13213979]
 [-4.3160293   3.18834981 -0.11026939]
 [ 1.79317945  1.48061639 -0.38315601]
 [ 0.9247198   7.1954469  -0.51517995]]
Max force: 7.1954469006795785


In [30]:
# Periodic box and displacement/shift functions in real (non-fractional)
# coordinates.
# NOTE(review): if `new_geom.cell` stores lattice vectors as rows while the
# space functions expect them as columns, a transpose would be needed for
# non-orthogonal cells — confirm against the JAX MD documentation.
box = jnp.asarray(new_geom.cell)
displacement, shift = space.periodic_general(box, fractional_coordinates=False)

# Atomic positions as a JAX array.
R = jnp.asarray(new_geom.positions)

# Derive the neighbor cutoff from the spline knots instead of repeating the
# literal 5.5, so it keeps following the basis if r_max_map is changed.
# Same expression as the one used to size the enlarged cell.
neighbor_cutoff = float(max(knots2[-1], knots3[0][-1]))

# Returns the neighbor-list function object and the energy function.
nf, ef = uf3_neighbor(
    displacement,
    box,
    cutoff=neighbor_cutoff,
    knots2=knots2,
    knots3=knots3,
)

In [31]:
# Build the initial neighbor structure for the current positions
# (presumably a JAX MD neighbor list — confirm against uf3_neighbor).
nbrs = nf.allocate(R)

In [32]:
# Evaluate the function returned by uf3_neighbor at positions R, using the
# neighbor structure and the two- and three-body coefficient arrays
# (presumably the total potential energy — confirm against uf3_neighbor).
ef(R, neighbor=nbrs, coefficients2=coefficients2, coefficients3=coefficients3)

-55.683579202529835


DeviceArray(-237.23930415, dtype=float64)