In [1]:
import jax
import jax.numpy as jnp
import jraph
import numpy as np
import time

from dataclasses import dataclass
from pprint import pprint

from mlff.mdx.potential.mlff_potential_sparse import load_model_from_workdir
from mlff.utils import jraph_utils
from mlff.data import AseDataLoaderSparse
from mlff.data import NpzDataLoaderSparse

from so3lr import So3lr



In [2]:
dataloader = NpzDataLoaderSparse(
    'data/ethanol.npz',
)

data, stats = dataloader.load(
    cutoff=4.5,
    cutoff_lr=np.inf,
    calculate_neighbors_lr=True,
    pick_idx=np.arange(50) # take the first 50 frames from the data set
)

MLFF:root:Load data from data/ethanol.npz.
MLFF:root:Calculate short range neighbors within cutoff=4.5 Ang.
MLFF:root:Calculate long-range neighbors within cutoff_lr=inf Ang.
  np.ceil(bin_size * nbins_c / face_dist_c).astype(int)
2000it [00:00, 44849.75it/s]
MLFF:root:... done!


In [3]:
# The data loader takes the input data file and converts each structure into a `jraph.GraphsTuple`.

pprint(data[0])

GraphsTuple(nodes={'positions': array([[-0.3712776 , -0.45252128, -0.19895536],
       [-0.72997075,  0.68413099,  0.79622273],
       [ 0.98474324, -0.24509398, -0.47843776],
       [-0.59848449, -1.4733098 ,  0.19563715],
       [-1.00308552, -0.30490059, -1.10558217],
       [-0.52242124,  1.73814903,  0.17984159],
       [-1.6873976 ,  0.53773135,  1.23071945],
       [ 0.08736885,  0.90074002,  1.4408355 ],
       [ 1.214412  , -0.26932519, -1.46412348]]), 'atomic_numbers': array([6, 6, 8, 1, 1, 1, 1, 1, 1]), 'forces': array([[-26.65845641,  23.24921427,   2.1633537 ],
       [ 33.23849668,  45.80338958, -77.83582996],
       [ 43.15698382,  -6.54653358, -44.62892596],
       [  6.9142078 ,   7.25861413,   0.43256605],
       [  3.35740248,  -5.6187004 ,   5.12580848],
       [-34.62580734, -52.22708068,  23.59973159],
       [-23.99129633,   9.05796712,  15.11034046],
       [ 15.82917769, -17.15955234,  35.86599296],
       [-17.10727453,  -3.75687692,  40.18842496]]), 'hirshfel

In [4]:
# Create a fresh iterator.

batch_size = 1

# Batch the graphs.
batched_graphs = jraph.dynamically_batch(
    data,
    n_node=stats['max_num_of_nodes'] * batch_size + 1,
    n_edge=stats['max_num_of_edges'] * batch_size + 1,
    n_graph=batch_size + 1,
    n_pairs= stats['max_num_of_nodes'] * (stats['max_num_of_nodes'] - 1) * batch_size + 1
)


so3lr_calc = So3lr(
    calculate_forces=True,
    lr_cutoff=100
)

total_compute_time = 0.0
for graph_batch in batched_graphs:
    # Transform the batched graph to inputs dict.
    inputs = jraph_utils.graph_to_batch_fn(
        graph_batch
    )
    
    # Add theory_mask field
    inputs['theory_mask'] = jnp.eye(16)[0].reshape(1, -1)

    # Measure he model computation
    start = time.perf_counter()
    output = so3lr_calc(inputs)
    end = time.perf_counter()
    total_compute_time += (end - start)

print(f'Pure computation time without jit = {total_compute_time:.4g}')

Pure computation time without jit = 48.96


In [5]:
# Create a fresh iterator for JIT test
batched_graphs_jit = jraph.dynamically_batch(
    data,
    n_node=stats['max_num_of_nodes'] * batch_size + 1,
    n_edge=stats['max_num_of_edges'] * batch_size + 1,
    n_graph=batch_size + 1,
    n_pairs= stats['max_num_of_nodes'] * (stats['max_num_of_nodes'] - 1) * batch_size + 1
)

so3lr_calc_jit = jax.jit(so3lr_calc)

i = 0
total_compute_time = 0.0
compile_time = 0.0

for graph_batch in batched_graphs_jit:
    # Transform the batched graph to inputs dict.
    inputs = jraph_utils.graph_to_batch_fn(
        graph_batch
    )
    
    # Add theory_mask field that is required by the MLFF model
    # The NPZ data doesn't include this field, so we need to add it manually
    inputs['theory_mask'] = jnp.eye(16)[0].reshape(1, -1)  # Use first theory level for all graphs
    
    # Measure ONLY the model computation
    start = time.perf_counter()
    output = jax.block_until_ready(so3lr_calc_jit(inputs))
    end = time.perf_counter()
    
    if i == 0:
        compile_time = end - start
        pure_compute_time = 0.0  # Don't add first computation to pure time - avoid floating point errors
    else:
        pure_compute_time += (end - start)  # Add only pure computation

    
    
    graph_batch.nodes['forces_true'] = output['forces']
    graph_batch.nodes['energy_true'] = output['energy']
    graph_batch.nodes['dipole_vec_true'] = output['dipole_vec']
    graph_batch.nodes['hirshfeld_ratios_true'] = output['hirshfeld_ratios']

    # predicted.append(jraph.unbatch(graph_batch))
    
    i += 1

    
total_compute_time = compile_time + pure_compute_time

print('Total number of iterations = ', i)
print(f'Compilation time plus first computation (first call) = {compile_time:.4g}')
print(f'Pure computation time from second computation onwards (iterations 2+) = {pure_compute_time:.4g}')
print(f'Average time per computation (including compilation overhead) = {total_compute_time/i:.4g}')
if i > 1: 
    print(f'Average time per computation (excluding compilation) = {pure_compute_time/(i-1):.4g}')
print(f'Total wall clock time = {total_compute_time:.4g}')

Total number of iterations =  50
Compilation time plus first computation (first call) = 7.325
Pure computation time from second computation onwards (iterations 2+) = 0.03676
Average time per computation (including compilation overhead) = 0.1472
Average time per computation (excluding compilation) = 0.0007501
Total wall clock time = 7.362


In [6]:
# You can see the last entry has zeros, due to padding graph. See next section for explanation.
pprint(output)

{'dipole_vec': Array([[-0.15462711, -0.28720653, -0.02340287],
       [ 0.        ,  0.        ,  0.        ]], dtype=float32),
 'energy': Array([-14.743898,   0.      ], dtype=float32),
 'forces': Array([[-3.2873504 , -2.3319964 ,  2.048338  ],
       [ 2.5215158 ,  1.5717448 , -0.6058162 ],
       [-0.37934327, -1.2226087 , -0.47387624],
       [ 0.10440782,  1.1608028 ,  0.9021926 ],
       [ 0.7856201 ,  1.6373655 , -0.95621043],
       [ 0.11453546,  0.68705595,  1.147562  ],
       [ 0.6586375 , -0.9075333 ,  0.16016062],
       [-0.7316438 , -1.1171225 ,  0.02287863],
       [ 0.21362092,  0.522292  , -2.2452292 ],
       [ 0.        ,  0.        ,  0.        ]], dtype=float32),
 'hirshfeld_ratios': Array([0.79526365, 0.7683944 , 0.9239218 , 0.63038445, 0.5819373 ,
       0.5990365 , 0.5886431 , 0.5959441 , 0.6013007 , 0.        ],      dtype=float32)}


# Why jraph and batching hustle?

You might wonder, why going through all the hustle with `jraph.dynamically_batch` and not just simply define an `inputs` dict for each structure `SO3LR` should be evaluated on. As illustrated above, `jax.jit` takes a long time during the first call, as it creates the computational graph and optimizes it via XLA afterwards. This is great, since we have seen that this can drastically speed up calculations even on CPU. There is a caveat though: `jax.jit` assumes static shapes of the input arrays. As an example, consider some function and its jitted counterpart

```python
fn = lambda A, B: do stuff ..
fn_jit = jax.jit(fn)
```

Every time the shapes of `A` and `B` change, `jax.jit` will trigger a recompile, introducing a drastic computational overhead.


When using `SO3LR` for structures of different size, this would for example correspond to

```python
inputs1['positions'].shape = (N1, 3)
inputs2['positions'].shape = (N2, 3)
inputs2['positions'].shape = (N2, 3)
```

where `N1`, `N2` and `N3` are the number of atoms of the first and second structure. Therefore, when calling `so3lr_jit` for each input there is triggered a recompile. To avoid such issues one can pad all arrays in `inputs` that depend on the input structure to some pre-defined value. Given some `iterator` over `jraph.GraphTuples` this is exactly what `jraph.dynamically_batch` is doing via `n_node`, `n_edge`, `n_graph` and `n_pair`. Afterwards, we transform the batched graph to the input using `jraph_utils.graph_to_inputs`. Note, that there is always one padding graph, which is used during the computation to dump values that depend to padded nodes, edges, and so on. For example, the code below will pad `batch_size = 2` structures per batch. Using the `stats` collected by the data loader, we can choose the remaining values based on the `batch_size`.

In [7]:
batch_size = 2

batched_graphs = jraph.dynamically_batch(
    data,
    n_node=stats['max_num_of_nodes'] * batch_size + 1,
    n_edge=stats['max_num_of_edges'] * batch_size + 1,
    n_graph=batch_size + 1,
    n_pairs= stats['max_num_of_nodes'] * (stats['max_num_of_nodes'] - 1) * batch_size + 1
)

# The arrays are numpy arrays at this point so they are on CPU and not on GPU.
jax.tree.map(lambda x: print(type(x)), next(batched_graphs))

<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>


GraphsTuple(nodes={'atomic_numbers': None, 'forces': None, 'hirshfeld_ratios': None, 'positions': None}, edges={}, receivers=None, senders=None, globals={'dipole_vec': None, 'energy': None, 'num_unpaired_electrons': None, 'stress': None, 'total_charge': None}, n_node=None, n_edge=None, n_pairs=None, idx_i_lr=None, idx_j_lr=None)

In [8]:
pprint(jraph_utils.batch_info_fn(next(batched_graphs)))

{'batch_segments': Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2], dtype=int32),
 'graph_mask': Array([ True,  True, False], dtype=bool),
 'node_mask': Array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
       False], dtype=bool),
 'num_of_non_padded_graphs': Array(2, dtype=int32)}
