In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from openforcefield.topology import Molecule, Topology

from openff.system.stubs import ForceField

In [None]:
# Construct a single-molecule system from toolkit classes
# Build the molecule from the SMILES string "CCO" (ethanol) and generate a
# single conformer for it
mol = Molecule.from_smiles("CCO")
mol.generate_conformers(n_conformers=1)
# Wrap the molecule in a topology that contains just this one molecule
top = Topology.from_molecules([mol])
# Load the force field identified by the file name "openff-1.0.0.offxml"
parsley = ForceField("openff-1.0.0.offxml")

# Apply the force field to the topology, then pull out the "Bonds" handler;
# the cells below read their parameter matrices from this handler
off_sys = parsley.create_openff_system(topology=top)
bonds = off_sys.handlers["Bonds"]

In [None]:
# Transform parameters into matrix representations
#   p       - force field parameters of the bond handler
#   mapping - the handler's mapping (not used again in the bond cells below)
#   q       - system parameters of the bond handler
#   m       - parametrization matrix relating p to q (checked in a later cell)
p = bonds.get_force_field_parameters()
mapping = bonds.get_mapping()
q = bonds.get_system_parameters()
m = bonds.get_param_matrix()

In [None]:
# force field parameters, each row is something like [k (kcal/mol/A), length (A)]
# NOTE(review): a harmonic force constant normally carries energy/length**2
# units (e.g. kcal/mol/A**2) - confirm the unit of column 0 against the handler
p

In [None]:
# system parameters, a.k.a. force field parameters as they exist in a parametrized system;
# displayed here as the last expression of the cell
q

In [None]:
# m is the parametrization matrix: applying it to the flattened force field
# parameters p and folding the result back into two columns must reproduce
# the system parameters q
q_from_p = m.dot(p.flatten()).reshape((-1, 2))
assert np.allclose(q_from_p, q)

m

In [None]:
# save the initial values as independent copies: the optimisation loop further
# down updates `p` in place (`p -= ...`), which would silently overwrite a
# plain alias such as `p0 = p` and lose the starting point
q0 = q.copy()
p0 = p.copy()

# set learning rate (step size of each gradient-descent update)
a = 0.05

In [None]:
from copy import deepcopy

In [None]:
# let jax run with autodiff
# jax.vjp evaluates bonds.parametrize at p0 and returns (output, vjp_fn); the
# output is discarded and only the vector-Jacobian-product function is kept.
# Calling f_vjp_bonds(v) returns a tuple with one entry per differentiated
# input - here a single entry, the product of v with the Jacobian taken at p0.
_, f_vjp_bonds = jax.vjp(bonds.parametrize, jnp.asarray(p0))  #d/dp

In [None]:
# seed a local generator so that the fake targets (and with them every number
# and figure below) are reproducible under Restart & Run All
rng = np.random.default_rng(0)

# start the targets from a private copy of the initial force field parameters
p_target = deepcopy(p0)

# modify column 1 (the length-like column) of the force field targets to
# arbitrary values in [0.5, 1.5); this mimics some "true" values we wish to
# tune to, despite these values not being known in real-world fitting.
# One value is drawn per row, so the cell works for any number of bond types.
p_target[:, 1] = 0.5 + rng.random(p_target.shape[0])

# obtain the target _system_ parameters by dotting the parametrization
# matrix with the target force field values
q_target = m.dot(p_target.flatten()).reshape((-1, 2))


# create a dummy loss function via faking known target parameters;
# this should probably be something like the difference between
# computed and reference energies
def loss(q):
    """Return the element-wise residual between ``q`` and ``q_target``.

    Despite its name this is not a scalar loss. The residual is meant to be
    passed to a VJP function as the cotangent; the product with the Jacobian
    is then the gradient of ``0.5 * sum(residual ** 2)`` with respect to the
    force field parameters (exactly so if the parametrization is linear, as
    the matrix ``m`` suggests).

    ``q_target`` is looked up in the notebook's global namespace at call
    time, so rebinding it later changes what this function returns.

    Parameters
    ----------
    q : array-like
        System parameters with the same shape as ``q_target``.

    Returns
    -------
    array-like
        ``q - q_target``.
    """
    return q - q_target

In [None]:
# derivative of the loss function evaluated at the original system parameters;
# note that column 0 matches the target values, so the derivative is flat there.
# The result is a tuple; the optimisation loop below takes its first element.
f_vjp_bonds(loss(q0))  # dL/dp (!)

In [None]:
# number of gradient-descent steps; also used as the x-extent of the figure
n_steps = 100

# (matplotlib color, SMIRKS label) for each row of the force field parameters
bond_styles = [
    ("k", "[#6X4:1]-[#6X4:2]"),
    ("r", "[#6X4:1]-[#1:2]"),
    ("g", "[#6:1]-[#8:2]"),
    ("b", "[#8:1]-[#1:2]"),
]

# restart from the saved initial values so that re-running this cell is
# idempotent; the in-place update below must act on a private copy, otherwise
# it would also overwrite `p0`, which was bound to the same array as `p`
p = p0.copy()
q = q0

fig, ax = plt.subplots()

# label target values (dashed horizontal line per bond type)
for row, (color, smirks) in enumerate(bond_styles):
    ax.hlines(p_target[row, 1], 0, n_steps, color=color, ls="--", label=smirks)

for i in range(n_steps):
    # use jax to get the gradient: the residual is the cotangent fed to the VJP
    grad = f_vjp_bonds(loss(q))[0]
    # update force field parameters (in place, on the private copy)
    p -= a * grad
    # use the parametrization matrix to propagate new
    # force field parameters into new system parameters
    q = m.dot(p.flatten()).reshape((-1, 2))
    if i % 10 == 0:
        # report the sum of squared residuals and plot the current value of
        # column 1 for every bond type
        print(f"step {i}\tloss: {np.sum(loss(q) ** 2)}")
        for row, (color, smirks) in enumerate(bond_styles):
            ax.plot(i, p[row][1], f"{color}.")

ax.legend(loc=0)
ax.set_xlabel("iteration")
ax.set_ylabel("parameter value (bond length-ish)")
ax.set_xlim((0, n_steps))
# trailing semicolon suppresses the repr of the returned limits in the notebook
ax.set_ylim((0, 3));

In [None]:
# We can do everything all over again with angles, almost identically
angles = off_sys.handlers["Angles"]
q0 = angles.get_system_parameters()
p0 = angles.get_force_field_parameters()
mapping = angles.get_mapping()
m = angles.get_param_matrix()

# working values for the optimisation; `p` is updated in place by the loop in
# the next cell, so it must be a copy rather than an alias of `p0`
# (`q` is only ever rebound, never mutated, so sharing it with `q0` is safe)
q = q0
p = p0.copy()

# learning rate
a = 0.05

# seed a local generator so the fake targets are reproducible
rng = np.random.default_rng(0)

# fake "true" force field values: one integer in [100, 120) per row, written
# into column 1 of a private copy of the initial parameters
p_target = deepcopy(p0)
p_target[:, 1] = rng.integers(100, 120, size=p_target.shape[0])

# target system parameters implied by the target force field parameters
q_target = m.dot(p_target.flatten()).reshape((-1, 2))


def loss(q):
    """Return the element-wise residual ``q - q_target``.

    The residual is passed to the VJP function as the cotangent rather than
    being reduced to a scalar. ``q_target`` is read from the notebook's global
    namespace at call time, i.e. the angle targets defined just above.
    """
    return q - q_target


# keep only the vector-Jacobian-product function of the angle parametrization,
# evaluated at the initial force field parameters
_, f_vjp_angles = jax.vjp(angles.parametrize, jnp.asarray(p))

In [None]:
# number of gradient-descent steps; also used as the x-extent of the figure
n_steps = 100

# (matplotlib color, SMIRKS label) for each row of the force field parameters
angle_styles = [
    ("k", "[*:1]~[#6X4:2]-[*:3]"),
    ("r", "[*:1]-[#8:2]-[*:3]"),
    ("g", "[#1:1]-[#6X4:2]-[#1:3]"),
]

# restart from the saved initial values so that re-running this cell is
# idempotent; the in-place update below must act on a private copy, otherwise
# it would also overwrite `p0` whenever `p` is bound to the same array
p = p0.copy()
q = q0

fig, ax = plt.subplots()

# label target values (dashed horizontal line per angle type)
for row, (color, smirks) in enumerate(angle_styles):
    ax.hlines(p_target[row, 1], 0, n_steps, color=color, ls="--", label=smirks)

for i in range(n_steps):
    # use jax to get the gradient: the residual is the cotangent fed to the VJP
    grad = f_vjp_angles(loss(q))[0]
    # update force field parameters (in place, on the private copy)
    p -= a * grad
    # propagate the new force field parameters into new system parameters
    q = m.dot(p.flatten()).reshape((-1, 2))
    if i % 10 == 0:
        # report the sum of squared residuals and plot the current value of
        # column 1 for every angle type
        print(f"step {i}\tloss: {np.sum(loss(q) ** 2)}")
        for row, (color, smirks) in enumerate(angle_styles):
            ax.plot(i, p[row][1], f"{color}.")

ax.legend(loc=0)
ax.set_xlabel("iteration")
ax.set_ylabel("parameter value (angle-ish)")
ax.set_xlim((0, n_steps))
# trailing semicolon suppresses the repr of the returned limits in the notebook
ax.set_ylim((100, 120));

It is likely intractable to use a _single_ matrix that globally represents the entire set of force field parameters $\textit{P}$, because such a representation scales poorly with system size and complexity (consider the size of the corresponding system parameter matrix $\textit{Q}$ and of the mapping matrix $\textit{M}$ between them). This example therefore separates the force field into "handlers" and optimizes each one independently. How to co-optimize all handlers at once is still to be determined; it may involve passing data between handlers and/or some creative crafting of the loss function(s).