In [1]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from functools import partial

from jax.config import config
config.update("jax_enable_x64", True)  # enable double precision (very important)

# domain-specific representations & kernels for molecules
import sys
sys.path.append("../md17/experiments")
sys.path.append("../md17")
from fchl import FCHL19Representation
from gdml_jax.util.datasets import load_md17, get_symmetries
from gdml_jax.kernels import GDMLKernel, sGDMLKernel, GlobalSymmetryKernel

# general operator kernel regression framework
import opgp
from opgp import build_solve
from opgp import build_predict_scalar



In [2]:
def rbf(x1, x2, lengthscale=1.0):
    """Squared-exponential (RBF) kernel between two arrays.

    Computes exp(-||x1 - x2||^2 / (2 * lengthscale^2)), where the squared norm
    is the sum of squared differences over *all* elements of the inputs.
    """
    sq_dist = jnp.sum((x1 - x2) ** 2)
    return jnp.exp(-0.5 * sq_dist / lengthscale**2)

def sFCHL19(z, perms, kappa=rbf):
    """Build a scalar molecular kernel: sGDML-style symmetrization over FCHL19 features.

    Unlike sGDML, which uses inverse pairwise distances as descriptors, the
    molecules are first mapped to FCHL19 representations (built from `z`).
    The base kernel `kappa` is then applied under the permutation symmetries
    in `perms`. The intent is to pass only a few physically plausible
    permutations, which keeps the cost down.
    """
    # NOTE(review): `is_atomwise=True` is forwarded to GlobalSymmetryKernel
    # unchanged; its exact meaning is defined in gdml_jax.
    representation = FCHL19Representation(z)
    return GlobalSymmetryKernel(representation, kappa, perms, is_atomwise=True)

def mae(a, b):
    """Mean absolute error between two arrays of identical shape.

    Raises:
        ValueError: if `a` and `b` have different shapes. Without this check
            `a - b` would broadcast silently (e.g. shape (n,) against (n, 1)
            yields an (n, n) difference matrix) and return a meaningless error.
    """
    a_shape, b_shape = jnp.shape(a), jnp.shape(b)
    if a_shape != b_shape:
        raise ValueError(f"mae: shape mismatch {a_shape} vs {b_shape}")
    return jnp.mean(jnp.abs(a - b))

def evaluate_errors(f, operators, x, observations):
    """Mean absolute error of `f` as seen through each operator.

    For every key in `operators`, the operator is applied to `f`, the resulting
    function is vmapped over the inputs `x[key]`, and its predictions are
    compared with `observations[key]` via `mae`.

    Returns:
        dict mapping each operator key to its scalar MAE, in `operators` order.
    """
    errors = {}
    for key, operator in operators.items():
        batched = jax.vmap(operator(f))
        errors[key] = mae(batched(x[key]), observations[key])
    return errors

def negative(f):
    """Return a function computing the elementwise negation of `f`'s output.

    In this notebook it turns `jax.grad(f)` into `-grad`, matching the force
    convention F(x) = -grad U(x). All positional and keyword arguments are
    forwarded. The output may be any JAX pytree (e.g. the tuple returned by
    `jax.grad(..., argnums=(0, 1))`); every leaf is negated.
    """
    def negated(*args, **kwargs):
        return jax.tree_util.tree_map(lambda leaf: -leaf, f(*args, **kwargs))
    return negated

In [3]:
datadir = "/tmp/md17_train"
molecule = "ethanol"
n_train = 100
n_test = 2000

trainset, testset, meta = load_md17(molecule, n_train, n_test, datadir)
train_x, train_e, train_y = trainset
shape = meta["shape"]
z = meta["z"]
perms = get_symmetries(molecule)
# basekernel = partial(GDMLKernel(shape), lengthscale=10.0)
basekernel = partial(sGDMLKernel(shape, perms=perms), lengthscale=10.0)
# basekernel = partial(sFCHL19(z, perms), lengthscale=10.0) # needs 100GB memory (or batching as implemented in gdml_jax, but not yet in opgp, TODO(niklas))

In [4]:
def fit(k, operators, x, observations, solver="cholesky"):
    """Fit a scalar function to observations taken through linear operators.

    Args:
        k: scalar base kernel.
        operators: dict mapping a key to an operator (function -> function),
            e.g. ``jax.grad``.
        x: dict with the same keys, holding the inputs for each operator.
        observations: dict with the same keys, holding the observed values of
            ``operator(f)`` at ``x[key]``.
        solver: name of the linear solver forwarded to ``build_solve``.
            Defaults to ``"cholesky"``, the previously hard-coded choice.

    Returns:
        The fitted scalar predictor built by ``build_predict_scalar``.
    """
    solve = build_solve(k, operators, solver=solver)
    # presumably the representer weights of the operator GP — see opgp
    alphas = solve(x, observations)
    return build_predict_scalar(k, operators, x, alphas)

operators = {"grad": jax.grad}
x = {"grad": train_x}
observations  = {"grad": train_y}
f_negative_energy = fit(basekernel, operators, x, observations)

# train error
evaluate_errors(f_negative_energy, operators, x, observations)

{'grad': DeviceArray(1.10090854, dtype=float64)}

In [5]:
test_x, test_e, test_y = testset

# test error (forces-only model)
evaluate_errors(
    f_negative_energy,
    operators,
    x={"grad": test_x},
    observations={"grad": test_y},
)

{'grad': DeviceArray(1.49639915, dtype=float64)}

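## Include both energies and forces

Now we model the potential energy $U$ itself rather than $-U$. Energies are therefore observed through the identity operator, and forces through $-\nabla$, following the sign convention $F(x) = -\nabla U(x)$. (Above, fitting forces with the plain gradient produced $-U$.) Energies are centered by their training-set mean before fitting.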

In [6]:
# Fit U directly. Energies are centred by the training mean, presumably because
# the GP prior has zero mean (TODO confirm in opgp). Forces enter through -grad.
train_e_mean = train_e.mean()
operators = {
    "id": lambda f: f,
    "-grad": lambda f: negative(jax.grad(f)),
}
x = {key: train_x for key in operators}
observations = {"id": train_e - train_e_mean, "-grad": train_y}
f_energy = fit(basekernel, operators, x, observations)

# training error
evaluate_errors(f_energy, operators, x, observations)

{'id': DeviceArray(0.39844113, dtype=float64),
 '-grad': DeviceArray(1.10075538, dtype=float64)}

In [7]:
# test error; test energies are shifted by the *training* mean used during fitting
test_inputs = {"id": test_x, "-grad": test_x}
test_observations = {"id": test_e - train_e_mean, "-grad": test_y}
evaluate_errors(f_energy, operators, test_inputs, test_observations)

{'id': DeviceArray(0.49860371, dtype=float64),
 '-grad': DeviceArray(1.49568216, dtype=float64)}