In [4]:
import jax
import jax.numpy as jnp
from jax.config import config
from gdml_jax.util.datasets import load_md17, get_symmetries
from gdml_jax.models import GDMLPredict, GDMLPredictEnergy
from gdml_jax.solve import solve_closed
from gdml_jax import losses


# Enable double precision. The flag is set through `jax.config`, which is the
# handle that works across JAX versions (recent releases removed the separate
# `jax.config` submodule that `from jax.config import config` relies on).
jax.config.update("jax_enable_x64", True)

DATA_DIR = "/tmp/md17"
MOLECULE = "ethanol"
N_TRAIN = 200
N_TEST = 200
REG = jnp.float64(1e-10)
# Fraction of the N_TRAIN samples kept for fitting; the rest is held out for validation.
VALIDATION_SPLIT = 0.8

# Data loading.
trainset, testset, meta = load_md17(MOLECULE, N_TRAIN, N_TEST, DATA_DIR)
train_x, train_e, train_f = trainset

# Split train data into train and validation part for hyperparameter search.
# Plain Python arithmetic is enough here: int() truncates, which equals floor
# for the non-negative product of the two constants.
split = int(VALIDATION_SPLIT * N_TRAIN)
assert 0 < split < N_TRAIN, "VALIDATION_SPLIT must leave both the train and validation parts non-empty"
train_x, val_x = train_x[:split], train_x[split:]
train_e, val_e = train_e[:split], train_e[split:]
train_f, val_f = train_f[:split], train_f[split:]


def loss_from_kernel(basekernel, kernel_kwargs, reg=REG):
    """Fit a model for `basekernel` on the training split and print its errors.

    Parameters are obtained with `solve_closed` from the module-level
    `train_x` / `train_f`; force and energy predictors are then built and
    their mean absolute errors are printed for the training split and for the
    held-out validation split.

    Args:
        basekernel: kernel callable forwarded to `solve_closed`, `GDMLPredict`
            and `GDMLPredictEnergy`.
        kernel_kwargs: keyword arguments for the kernel, forwarded to
            `solve_closed`.
        reg: regularization value forwarded to `solve_closed` (defaults to REG).

    Returns:
        None. The results are only printed.
    """
    params = solve_closed(
        basekernel,
        train_x,
        train_f,
        reg=reg,
        kernel_kwargs=kernel_kwargs,
    )
    force_fn = GDMLPredict(basekernel, train_x)
    # The energy predictor additionally takes train_e and params, which it
    # needs to estimate the integration constant.
    energy_fn = GDMLPredictEnergy(basekernel, train_x, train_e, params)

    reports = (
        ("train error:", train_x, train_e, train_f),
        ("validation error:", val_x, val_e, val_f),
    )
    for header, xs, energies, forces in reports:
        print(header)
        print(f"forces MAE (component-wise): {losses.mae(forces, force_fn(params, xs))}")
        print(f"energy MAE:                  {losses.mae(energies, energy_fn(xs))}")

In [5]:
"""Here we define a custom GDML-style kernel from scratch with a pairwise descriptor."""

def pairwise_descriptor(x, power):
    """Describe a geometry by its inverse pairwise distances raised to `power`.

    Args:
        x: 2-D array with one row of coordinates per atom.
        power: exponent applied to the inverse distances.

    Returns:
        1-D array with one entry ``r_ij ** -power`` per unordered atom pair
        ``i < j``, i.e. ``n * (n - 1) / 2`` entries for ``n`` atoms.
    """
    n_atoms, _ = x.shape
    # Strict upper triangle: every unordered pair exactly once, no self-pairs.
    first, second = jnp.triu_indices(n_atoms, k=1)
    displacements = x[first] - x[second]
    distances = jnp.linalg.norm(displacements, axis=1)
    n_pairs = n_atoms * (n_atoms - 1) // 2
    assert distances.shape == (n_pairs,)
    return jnp.power(distances, -power)  # 1 / r_ij ** power

def rbf(x1, x2, lengthscale):
    """Gaussian (RBF) kernel: ``exp(-0.5 * sum((x1 - x2) ** 2) / lengthscale ** 2)``.

    The squared differences are summed over all elements, so the result is a
    single scalar value.
    """
    squared_distance = jnp.sum(jnp.square(x1 - x2))
    squared_lengthscale = jnp.square(lengthscale)
    return jnp.exp(-jnp.float64(0.5) * squared_distance / squared_lengthscale)

def basekernel(x1, x2, lengthscale, power):
    """RBF kernel between the pairwise descriptors of two geometries.

    Both inputs are mapped through `pairwise_descriptor` with the same
    `power`, and the resulting descriptors are compared with `rbf` using
    `lengthscale`.
    """
    descriptor_1, descriptor_2 = (pairwise_descriptor(x, power) for x in (x1, x2))
    return rbf(descriptor_1, descriptor_2, lengthscale)

# Hyperparameters forwarded to `basekernel` as keyword arguments.
kernel_kwargs = dict(
    lengthscale=jnp.float64(5.0),
    power=jnp.float64(1.0),
)
loss_from_kernel(basekernel, kernel_kwargs)

train error:
forces MAE (component-wise): 0.14117780028317145
energy MAE:                  0.2323495020311384
validation error (n_valid):
forces MAE (component-wise): 2.190765784053275
energy MAE:                  0.4103110136129544


In [6]:
"""Here we additionally take selected permutation symmetries into account in the style of sGDML."""

# Atom-index permutations for MOLECULE; `symmetrized_basekernel` below applies
# each one to reorder the atoms of its second argument (`x2[perm]`).
perms = get_symmetries(MOLECULE)  # precomputed from https://github.com/stefanch/sgdml

def symmetrized_basekernel(x1, x2, lengthscale, power):
    """Average the descriptor RBF kernel over the permutations in `perms`.

    The descriptor of `x1` is computed once. For every permutation in the
    module-level `perms`, the atoms of `x2` are reordered, described with
    `pairwise_descriptor` and compared against it with `rbf`. The mean of
    those kernel values is returned.
    """
    reference = pairwise_descriptor(x1, power)

    def kernel_for(perm):
        permuted = pairwise_descriptor(x2[perm], power)
        return rbf(reference, permuted, lengthscale)

    # vmap evaluates the kernel for all permutations at once.
    per_permutation = jax.vmap(kernel_for)(perms)
    return jnp.mean(per_permutation)

# Same hyperparameters as for the unsymmetrized kernel, so the printed errors
# of the two runs can be compared directly.
kernel_kwargs = dict(
    lengthscale=jnp.float64(5.0),
    power=jnp.float64(1.0),
)
loss_from_kernel(symmetrized_basekernel, kernel_kwargs)

train error:
forces MAE (component-wise): 0.44560376328630064
energy MAE:                  0.20943909990255635
validation error (n_valid):
forces MAE (component-wise): 0.769969214088659
energy MAE:                  0.15866935875383206
