<a href="https://colab.research.google.com/github/collvey/jaxani/blob/main/Jaxani_Energy_Calculation_Validation_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pre-steps

In [2]:
# Clone the jaxani repository into ./jaxani (relative to the current working directory).
!git clone https://github.com/collvey/jaxani.git

Cloning into 'jaxani'...
remote: Enumerating objects: 203, done.[K
remote: Counting objects: 100% (203/203), done.[K
remote: Compressing objects: 100% (137/137), done.[K
remote: Total 203 (delta 107), reused 147 (delta 57), pack-reused 0[K
Receiving objects: 100% (203/203), 70.96 KiB | 2.73 MiB/s, done.
Resolving deltas: 100% (107/107), done.


In [3]:
import os
import sys

# Resolve the clone relative to the working directory (where `git clone`
# created ./jaxani) instead of hard-coding Colab's /content, so the notebook
# also works when it is run from another location.
JAXANI_REPO_DIR = os.path.abspath('jaxani')

# Guard the insert so re-running this cell does not stack duplicate entries.
if JAXANI_REPO_DIR not in sys.path:
    sys.path.insert(0, JAXANI_REPO_DIR)

In [4]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install -r ./jaxani/test_requirements.txt

Collecting lark (from -r ./jaxani/test_requirements.txt (line 2))
  Downloading lark-1.1.9-py3-none-any.whl (111 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m111.7/111.7 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: lark
Successfully installed lark-1.1.9


# Example Usage

In [5]:
import jax.numpy as jnp
import unittest
import os

from flax.training import train_state, checkpoints
from jaxani.constants import Constants
from jaxani.aev import AEVComputer
from jaxani.nn import SpeciesConverter
from jaxani.utils import load_sae
from jaxani.model import rebuild_model_ensemble
from test_util.generate_test_checkpoint import generate_test_checkpoint
from neurochem.parse_resources import parse_neurochem_resources

# Directory passed to `checkpoints.restore_checkpoint` as `ckpt_dir`.
CKPT_DIR = './jaxani/test/test_ckpts'
# File-name prefix of the checkpoint; the step-0 file is `<prefix>0`.
CKPT_PREFIX = 'test_ensemble_'

def jax_energy_from_restored_state(test_species, test_coordinates):
    """Compute the `jaxani` energy of the given structures from a restored checkpoint.

    Steps: parse the NeuroChem resources listed in 'ani-2x_8x.info', convert
    the species, compute AEVs, restore the ensemble parameters from
    ``CKPT_DIR`` (generating the test checkpoint first if its step-0 file is
    missing), apply the rebuilt ensemble, and add the energy shifter's
    ``sae`` term.

    Args:
        test_species: Nested sequence of periodic-table indices, e.g.
            ``[[6, 1, 7, 8, 1]]``.
        test_coordinates: Nested sequence of per-atom ``[x, y, z]``
            coordinates matching ``test_species``.
            NOTE(review): presumably Angstrom, as in TorchANI -- confirm.

    Returns:
        The ensemble output energy plus ``sae(species)``. For the single
        structure used in this notebook it prints as a length-1 array.
    """
    species_in = jnp.array(test_species)
    coords_in = jnp.array(test_coordinates)

    # Only the constants file and the SAE file are needed from the info file.
    const_file, sae_file, _, _ = parse_neurochem_resources('ani-2x_8x.info')

    consts = Constants(const_file)
    aev_computer = AEVComputer(**consts)
    species_converter = SpeciesConverter(consts.species)
    energy_shifter, _ = load_sae(sae_file, return_dict=True)

    # Periodic-table indices -> the model's internal species ordering.
    species_idx, coords = species_converter((species_in, coords_in))
    species_idx, aevs = aev_computer.forward((species_idx, coords))

    # Make sure a checkpoint exists before restoring the ensemble from it.
    first_ckpt = os.path.join(CKPT_DIR, f'{CKPT_PREFIX}0')
    if not os.path.exists(first_ckpt):
        generate_test_checkpoint()
    state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=None, prefix=CKPT_PREFIX)
    params = state['params']
    ensemble = rebuild_model_ensemble(params)

    # Ensemble energy first, then the atomic (SAE) shift on top of it.
    _, ensemble_energy = ensemble.apply(params, (species_idx, aevs))
    return ensemble_energy + energy_shifter.sae(species_idx)

# Demo input: a single 5-atom structure. In a notebook `__name__` is
# '__main__', so this block runs when the cell is executed.
if __name__ == '__main__':
    # Periodic-table indices of the atoms: C, H, N, O, H.
    test_species = [[6, 1, 7, 8, 1]]
    # One [x, y, z] row per atom, in the same order as test_species.
    test_coordinates = [[
        [0.03192167, 0.00638559, 0.01301679],
        [-0.83140486, 0.39370209, -0.26395324],
        [-0.66518241, -0.84461308, 0.20759389],
        [0.45554739, 0.54289633, 0.81170881],
        [0.66091919, -0.16799635, -0.91037834]]]
    energy = jax_energy_from_restored_state(test_species, test_coordinates)
    print(energy)

Downloading ANI model parameters ...
[-168.81503569]


The calculated energy from `jaxani` is -168.81503569.

# Performance

In manual testing on a local device, the energy calculation takes 2.04 seconds before optimization.

In [None]:
# %%timeit
# energy = jax_energy_from_restored_state(test_species, test_coordinates)

In Colab, `%%timeit` shows that the energy calculation, including loading the model, takes 4.5 s ± 691 ms.

# Validation

We use the `torchani` module to check that it produces the same energy for the same test input.

In [6]:
# %pip installs into the running kernel's environment; the version is pinned
# to the one recorded in this cell's output so the validation is reproducible.
%pip install torchani==2.2.4

Collecting torchani
  Downloading torchani-2.2.4-py3-none-any.whl (10.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m48.4 MB/s[0m eta [36m0:00:00[0m
Collecting lark-parser (from torchani)
  Downloading lark_parser-0.12.0-py2.py3-none-any.whl (103 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m103.5/103.5 kB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: lark-parser, torchani
Successfully installed lark-parser-0.12.0 torchani-2.2.4


In [7]:
# -*- coding: utf-8 -*-
"""
Computing Energy and Force Using Models Inside Model Zoo
========================================================

TorchANI has a model zoo trained by NeuroChem. These models are shipped with
TorchANI and can be used directly.
"""

###############################################################################
# To begin with, let's first import the modules we will use:
import torch
import torchani

###############################################################################
# Let's now manually specify the device we want TorchANI to run on:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

###############################################################################
# Let's now load the built-in ANI-2x model (the same model family as the
# 'ani-2x_8x.info' resources used by the jaxani cell). The built-in model
# contains 8 networks trained with different initialization. Predicting the
# energy and force using the average of the 8 models outperforms using a single
# model, so it is always recommended to use an ensemble, unless the speed of
# computation is an issue in your application.
#
# The ``periodic_table_index`` argument tells TorchANI to use the element index
# in the periodic table to index species. If not specified, you need to use
# 0, 1, 2, 3, ... to index species.
model = torchani.models.ANI2x(periodic_table_index=True).to(device)

###############################################################################
# Now let's define the coordinates and species. If you just want to compute the
# energy and force for a single structure like in this example, the coordinate
# tensor needs shape ``(1, Na, 3)`` and the species tensor shape ``(1, Na)``,
# where ``Na`` is the number of atoms in the molecule. The preceding ``1`` in
# the shape is there to support batch processing like in training. If you have
# ``N`` different structures to compute, then make it ``N``.
#
# .. note:: The coordinates are in Angstrom, and the energies you get are in Hartree
#
# Alternative methane (CH4) input, left disabled:
# coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
#                              [-0.83140486, 0.39370209, -0.26395324],
#                              [-0.66518241, -0.84461308, 0.20759389],
#                              [0.45554739, 0.54289633, 0.81170881],
#                              [0.66091919, -0.16799635, -0.91037834]]],
#                            requires_grad=True, device=device)
# # In periodic table, C = 6 and H = 1
# species = torch.tensor([[6, 1, 1, 1, 1]], device=device)

# Active input: the same coordinates and species ([[6, 1, 7, 8, 1]]) as the
# jaxani demo cell, so the two energies can be compared directly.
coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
                             [-0.83140486, 0.39370209, -0.26395324],
                             [-0.66518241, -0.84461308, 0.20759389],
                             [0.45554739, 0.54289633, 0.81170881],
                             [0.66091919, -0.16799635, -0.91037834]]],
                           requires_grad=True, device=device)
species = torch.tensor([[6, 1, 7, 8, 1]], device=device)

###############################################################################
# Now let's compute energy and force. The force is the negative gradient of
# the energy with respect to the coordinates; it is computed here but not
# printed (its print statement below is commented out).
energy = model((species, coordinates)).energies
derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
force = -derivative

###############################################################################
# And print to see the result:
print('Energy:', energy.item())
# print('Force:', force.squeeze())



/usr/local/lib/python3.10/dist-packages/torchani/resources/
Downloading ANI model parameters ...
Energy: -168.8150356803993


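The energy calculated by `torchani` in the output above is -168.8150356803993, which agrees with the `jaxani` result (-168.81503569) to within about 1e-8.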