In [1]:
import random
from openff.toolkit.topology import Molecule
import h5py
import numpy as onp
import jax
import jax.numpy as jnp
import jax_md
import espalomax as esp
from concurrent import futures

  PyTreeDef = type(jax.tree_structure(None))


In [16]:
# Project-local data loader over the files in data/.
# NOTE(review): partition="all" presumably selects every dataset split —
# confirm against run_compile.DataLoader.
from run_compile import DataLoader
dataloader = DataLoader(path="data/", partition="all")

In [17]:
# Parametrization model: a graph attention network as the representation,
# followed by Janossy pooling (presumably producing the per-term MM parameters
# that `esp.mm.get_energy` consumes later).
# The layer sizes must match the architecture the checkpoint restored in the
# next cell was trained with, otherwise `model.apply(params, ...)` will fail.
# NOTE(review): (128, 6) / (128, 2) are presumably hidden width and depth —
# confirm against espalomax.nn and consider lifting them into named constants.
model = esp.nn.Parametrization(
    representation=esp.nn.GraphAttentionNetwork(128, 6),
    janossy_pooling=esp.nn.JanossyPooling(128, 2),
)

In [18]:
# Load trained weights. With target=None flax returns the raw saved state
# (nested dicts) instead of restoring into a typed object, hence the
# dict-style lookup of "params".
from flax.training.checkpoints import restore_checkpoint
state = restore_checkpoint("__checkpoint", target=None)
params = state["params"]

In [19]:
# Inspect only the first batch from the loader.
# NOTE(review): g is the molecular graph; x and u are presumably coordinates
# and reference energies — confirm against DataLoader.
g, x, u = next(iter(dataloader))
# Predict force-field parameters from the graph, then evaluate the MM energy
# of x under them (u_hat is presumably the prediction to compare with u).
ff_params = model.apply(params, g)
u_hat = esp.mm.get_energy(ff_params, x)

In [22]:
# Convert the predicted espalomax parameters into jax_md's
# MMEnergyFnParameters container so they can be compared with / used like
# parameters built from an OpenMM system.
jaxmd_params = esp.mm.to_jaxmd_mm_energy_fn_parameters(ff_params)

In [23]:
# Display the converted model parameters.
# NOTE(review): magnitudes differ strongly from the OpenFF-derived values
# (e.g. bond length ~0.555 vs ~0.109, bond epsilon ~0.019 vs ~3.1e5) —
# check units/scaling conventions before comparing the two sets.
jaxmd_params

MMEnergyFnParameters(harmonic_bond_parameters=HarmonicBondParameters(particles=DeviceArray([[0, 1],
             [0, 2],
             [0, 3],
             [0, 4]], dtype=int32), epsilon=DeviceArray([0.01880366, 0.01880366, 0.01880366, 0.01880366], dtype=float32), length=DeviceArray([0.5550566, 0.5550566, 0.5550566, 0.5550566], dtype=float32)), harmonic_angle_parameters=HarmonicAngleParameters(particles=DeviceArray([[1, 0, 2],
             [1, 0, 3],
             [1, 0, 4],
             [2, 0, 3],
             [2, 0, 4],
             [3, 0, 4]], dtype=int32), epsilon=DeviceArray([0.05592059, 0.05592059, 0.05592059, 0.05592059, 0.05592053,
             0.05592053], dtype=float32), length=DeviceArray([1.079294 , 1.079294 , 1.079294 , 1.079294 , 1.0792929,
             1.0792929], dtype=float32)), periodic_torsion_parameters=PeriodicTorsionParameters(particles=DeviceArray([], dtype=int32), amplitude=DeviceArray([], dtype=float32), periodicity=DeviceArray([], dtype=int32), phase=DeviceArray

In [27]:
def parameters_from_molecule(
        molecule: Molecule,
        base_forcefield: str = "openff_unconstrained-2.0.0.offxml",
):
    """Get jax_md.mm.MMEnergyFnParameters from a single molecule.

    Every term (bonds, angles, torsions, nonbonded terms and exceptions) is
    parametrized by ``base_forcefield``. Partial charges are fixed to zero,
    so electrostatic interactions are effectively switched off.

    The input molecule is not modified: charges are assigned on a copy.

    Parameters
    ----------
    molecule : Molecule
        Input OpenFF molecule. Left untouched by this function.
    base_forcefield : str
        SMIRNOFF force field (``.offxml`` name or path) used to parametrize
        the molecule.

    Returns
    -------
    jax_md.mm.MMEnergyFnParameters
        Resulting parameters.
    """
    from openff.toolkit.typing.engines.smirnoff import ForceField
    from jax_md.mm_utils import parameters_from_openmm_system

    # assign_partial_charges mutates the molecule in place; work on a copy
    # so the caller's molecule keeps whatever charges it already had.
    molecule = Molecule(molecule)
    molecule.assign_partial_charges("zeros")

    forcefield = ForceField(base_forcefield)
    system = forcefield.create_openmm_system(
        molecule.to_topology(),
        # Use the zero charges assigned above instead of having the force
        # field's own charge handler compute them.
        charge_from_molecules=[molecule],
    )
    return parameters_from_openmm_system(system)

# Reference parameters for methane (SMILES "C") from the base OpenFF force
# field, for comparison with the model's predicted parameters above.
parameters_from_molecule(Molecule.from_smiles("C"))

MMEnergyFnParameters(harmonic_bond_parameters=HarmonicBondParameters(particles=DeviceArray([[0, 1],
             [0, 2],
             [0, 3],
             [0, 4]], dtype=int32), epsilon=DeviceArray([309655.08432241, 309655.08432241, 309655.08432241,
             309655.08432241], dtype=float64), length=DeviceArray([0.10938995, 0.10938995, 0.10938995, 0.10938995], dtype=float64)), harmonic_angle_parameters=HarmonicAngleParameters(particles=DeviceArray([[1, 0, 2],
             [1, 0, 3],
             [1, 0, 4],
             [2, 0, 3],
             [2, 0, 4],
             [3, 0, 4]], dtype=int32), epsilon=DeviceArray([408.16169048, 408.16169048, 408.16169048, 408.16169048,
             408.16169048, 408.16169048], dtype=float64), length=DeviceArray([2.01765472, 2.01765472, 2.01765472, 2.01765472, 2.01765472,
             2.01765472], dtype=float64)), periodic_torsion_parameters=PeriodicTorsionParameters(particles=DeviceArray([], dtype=int32), amplitude=DeviceArray([], dtype=float64), peri