In [4]:
import jax
import jax.numpy as jnp
import jraph
import numpy as np
import time

from dataclasses import dataclass
from pprint import pprint

from mlff.mdx.potential.mlff_potential_sparse import load_model_from_workdir
from mlff.utils import jraph_utils
from mlff.data import AseDataLoaderSparse
from mlff.data import NpzDataLoaderSparse

from so3lr import So3lr

In [5]:
dataloader = NpzDataLoaderSparse(
    'data/ethanol.npz',
)

data, stats = dataloader.load(
    cutoff=5.,
    cutoff_lr=np.inf,
    calculate_neighbors_lr=True,
    pick_idx=np.arange(50) # take the first 50 frames from the data set
)

MLFF:root:Load data from data/ethanol.npz.
MLFF:root:Calculate short range neighbors within cutoff=5.0 Ang.
MLFF:root:Calculate long-range neighbors within cutoff_lr=inf Ang.
  np.ceil(bin_size * nbins_c / face_dist_c).astype(int)

MLFF:root:... done!
55it/s]

In [3]:
# The data loader takes the input data file and converts each structure into a `jraph.GraphsTuple`.

pprint(data[0])

GraphsTuple(nodes={'positions': array([[-0.3712776 , -0.45252128, -0.19895536],
       [-0.72997075,  0.68413099,  0.79622273],
       [ 0.98474324, -0.24509398, -0.47843776],
       [-0.59848449, -1.4733098 ,  0.19563715],
       [-1.00308552, -0.30490059, -1.10558217],
       [-0.52242124,  1.73814903,  0.17984159],
       [-1.6873976 ,  0.53773135,  1.23071945],
       [ 0.08736885,  0.90074002,  1.4408355 ],
       [ 1.214412  , -0.26932519, -1.46412348]]), 'atomic_numbers': array([6, 6, 8, 1, 1, 1, 1, 1, 1]), 'forces': array([[-26.65845641,  23.24921427,   2.1633537 ],
       [ 33.23849668,  45.80338958, -77.83582996],
       [ 43.15698382,  -6.54653358, -44.62892596],
       [  6.9142078 ,   7.25861413,   0.43256605],
       [  3.35740248,  -5.6187004 ,   5.12580848],
       [-34.62580734, -52.22708068,  23.59973159],
       [-23.99129633,   9.05796712,  15.11034046],
       [ 15.82917769, -17.15955234,  35.86599296],
       [-17.10727453,  -3.75687692,  40.18842496]]), 'hirshfel

In [4]:
# Create a fresh iterator.

batch_size = 1

# Batch the graphs.
batched_graphs = jraph.dynamically_batch(
    data,
    n_node=stats['max_num_of_nodes'] * batch_size + 1,
    n_edge=stats['max_num_of_edges'] * batch_size + 1,
    n_graph=batch_size + 1,
    n_pairs= stats['max_num_of_nodes'] * (stats['max_num_of_nodes'] - 1) * batch_size + 1
)


so3lr_calc = So3lr(
    calculate_forces=True
)

i = 0

start = time.time()
for graph_batch in batched_graphs:
    # Transform the batched graph to inputs dict.
    inputs = jraph_utils.graph_to_batch_fn(
        graph_batch
    )

    output = so3lr_calc(inputs)

end = time.time()

print('Total time =', end - start)

Total time = 59.07566785812378


In [6]:
# Create a fresh iterator.

batch_size = 1

# Batch the graphs. We discuss below why this is neccessary when running with jax.jit.
batched_graphs = jraph.dynamically_batch(
    data,
    n_node=stats['max_num_of_nodes'] * batch_size + 1,
    n_edge=stats['max_num_of_edges'] * batch_size + 1,
    n_graph=batch_size + 1,
    n_pairs= stats['max_num_of_nodes'] * (stats['max_num_of_nodes'] - 1) * batch_size + 1
)

so3lr_calc_jit = jax.jit(so3lr_calc)

i = 0
start = time.time()

predicted = []

for graph_batch in batched_graphs:
    # Transform the batched graph to inputs dict.
    inputs = jraph_utils.graph_to_batch_fn(
        graph_batch
    )
    
    if i == 0:
        start_compile = time.time()
    
    output = jax.block_until_ready(so3lr_calc_jit(inputs))

    if i == 0:
        end_compile = time.time()
    
    graph_batch.nodes['forces_true'] = outputs['forces']
    graph_batch.nodes['energy_true'] = outputs['energy']
    graph_batch.nodes['dipole_vec_true'] = outputs['dipole_vec']
    graph_batch.nodes['hirshfeld_ratios_true'] = outputs['hirshfeld_ratios']

    predicted.append(jraph.unbatch(graph_batch))
    
    i += 1
    

end = time.time()

print('Total number of iterations = ', i)
print('Total time = ', end - start)
print('Time for JIT compile =', end_compile - start_compile)

Total number of iterations =  50
Total time =  6.100495100021362
Time for JIT compile = 5.946508169174194


In [12]:
# You can see the last entry has zeros, due to padding graph. See next section for explanation.
pprint(output)

{'dipole_vec': Array([[-0.1538189 , -0.2870038 , -0.02246577],
       [ 0.        ,  0.        ,  0.        ]], dtype=float32),
 'energy': Array([-14.574012,   0.      ], dtype=float32),
 'forces': Array([[-3.282136  , -2.2960505 ,  2.179946  ],
       [ 2.5132644 ,  1.5762587 , -0.6037756 ],
       [-0.3309193 , -1.0987675 , -0.62489426],
       [ 0.10972196,  1.1264346 ,  0.9013539 ],
       [ 0.80140257,  1.6329639 , -0.98639023],
       [ 0.11060164,  0.6977358 ,  1.1602308 ],
       [ 0.6589254 , -0.90906775,  0.14454728],
       [-0.73322606, -1.1292245 ,  0.02703559],
       [ 0.1523651 ,  0.39971742, -2.1980531 ],
       [ 0.        ,  0.        ,  0.        ]], dtype=float32),
 'hirshfeld_ratios': Array([0.79468775, 0.76813877, 0.92335534, 0.62973064, 0.58136225,
       0.59849924, 0.58792037, 0.59552044, 0.5994324 , 0.        ],      dtype=float32)}


# Why jraph and batching hustle?

You might wonder, why going through all the hustle with `jraph.dynamically_batch` and not just simply define an `inputs` dict for each structure `SO3LR` should be evaluated on. As illustrated above, `jax.jit` takes a long time during the first call, as it creates the computational graph and optimizes it via XLA afterwards. This is great, since we have seen that this can drastically speed up calculations even on CPU. There is a caveat though: `jax.jit` assumes static shapes of the input arrays. As an example, consider some function and its jitted counterpart

```python
fn = lambda A, B: do stuff ..
fn_jit = jax.jit(fn)
```

Every time the shapes of `A` and `B` change, `jax.jit` will trigger a recompile, introducing a drastic computational overhead.


When using `SO3LR` for structures of different size, this would for example correspond to

```python
inputs1['positions'].shape = (N1, 3)
inputs2['positions'].shape = (N2, 3)
inputs2['positions'].shape = (N2, 3)
```

where `N1`, `N2` and `N3` are the number of atoms of the first and second structure. Therefore, when calling `so3lr_jit` for each input there is triggered a recompile. To avoid such issues one can pad all arrays in `inputs` that depend on the input structure to some pre-defined value. Given some `iterator` over `jraph.GraphTuples` this is exactly what `jraph.dynamically_batch` is doing via `n_node`, `n_edge`, `n_graph` and `n_pair`. Afterwards, we transform the batched graph to the input using `jraph_utils.graph_to_inputs`. Note, that there is always one padding graph, which is used during the computation to dump values that depend to padded nodes, edges, and so on. For example, the code below will pad `batch_size = 2` structures per batch. Using the `stats` collected by the data loader, we can choose the remaining values based on the `batch_size`.

In [9]:
batch_size = 2

batched_graphs = jraph.dynamically_batch(
    data,
    n_node=stats['max_num_of_nodes'] * batch_size + 1,
    n_edge=stats['max_num_of_edges'] * batch_size + 1,
    n_graph=batch_size + 1,
    n_pairs= stats['max_num_of_nodes'] * (stats['max_num_of_nodes'] - 1) * batch_size + 1
)

# The arrays are numpy arrays at this point so they are on CPU and not on GPU.
jax.tree.map(lambda x: print(type(x)), next(batched_graphs))

<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>


GraphsTuple(nodes={'atomic_numbers': None, 'forces': None, 'hirshfeld_ratios': None, 'positions': None}, edges={}, receivers=None, senders=None, globals={'dipole_vec': None, 'energy': None, 'num_unpaired_electrons': None, 'stress': None, 'total_charge': None}, n_node=None, n_edge=None, n_pairs=None, idx_i_lr=None, idx_j_lr=None)

In [10]:
pprint(jraph_utils.batch_info_fn(next(batched_graphs)))

{'batch_segments': Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2], dtype=int32),
 'graph_mask': Array([ True,  True, False], dtype=bool),
 'node_mask': Array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
       False], dtype=bool),
 'num_of_non_padded_graphs': Array(2, dtype=int32)}
