In [None]:
import typing

import openff.interchange
import openff.toolkit
import pandas
import torch
from rdkit import Chem

import smee
from utils import load_dimers, mol_to_image, plot_energies, compute_dimer_vdw_energy


In [None]:
# Run configuration: which dimer set to fit against and where the outputs of this fit are written.
dataset = "DESS66x8"
results_folder = "DES66x8-fittings/results/"
fit_name = "sage-2.1-opt"

# Starting force field; the fitted LJ values are written back into this object later on.
openff_ff = openff.toolkit.ForceField("openff-2.1.0.offxml")

# load_dimers returns the tensor representation of the force field together with the dimers.
tensor_ff, dimers = load_dimers(forcefield=openff_ff, dataset_name=dataset)

In [None]:
import copy

# Hyper-parameters of the fit, gathered here so they can be tuned in one place.
N_EPOCHS = 1000
LEARNING_RATE = 0.001
REPORT_EVERY = 5
# Tiny offset added to the first parameter column (reported as epsilon further down the
# notebook) so it is never exactly zero, which otherwise produced NaN gradients.
EPSILON_OFFSET = 1.0e-10

# Keep the unmodified vdW potential for the before/after comparison and fit a deep copy of it.
vdw_potential_initial = tensor_ff.potentials_by_type["vdW"]
vdw_potential = copy.deepcopy(vdw_potential_initial)

# Optimise the square roots of the parameters: squaring them again before every energy
# evaluation guarantees that the values handed to the potential are never negative.
vdw_parameters_sqrt = vdw_potential.parameters.detach().clone().sqrt()
vdw_parameters_sqrt.requires_grad = True

# NOTE: alternative set-up for also fitting v-site parameters (kept for reference, not active):
# fitting_tensor_ff = copy.deepcopy(tensor_ff)
# fitting_tensor_ff.v_sites.parameters.requires_grad = True
# optimizer = torch.optim.Adam([vdw_parameters_sqrt, fitting_tensor_ff.v_sites.parameters], lr=0.001)
optimizer = torch.optim.Adam([vdw_parameters_sqrt], lr=LEARNING_RATE)

# The fitting targets (sum of each dimer's ana2b_v_ana and ana2b_v_d3 terms) do not depend
# on the parameters being optimised, so build them once instead of once per epoch.
reference_energies = [dimer.ana2b_v_ana + dimer.ana2b_v_d3 for dimer in dimers]

for epoch in range(N_EPOCHS):
    # prevent negative values and bug where gradient is NaN
    vdw_potential.parameters = vdw_parameters_sqrt**2
    vdw_potential.parameters[:, 0] += EPSILON_OFFSET

    # Sum of squared errors over every dimer in the data set.
    loss = torch.zeros(1)

    for dimer, ana2b_d3 in zip(dimers, reference_energies):
        lj_energies = compute_dimer_vdw_energy(dimer=dimer, potential=vdw_potential, ff=tensor_ff)

        loss = loss + torch.sum((ana2b_d3 - lj_energies) ** 2)

    loss.backward()
    # hide gradients of other v-site parameters
    # fitting_tensor_ff.v_sites.parameters.grad[:, 1:] = 0.0
    # make sure the normal H LJ has no gradient only needed when fitting contracted sites
    # vdw_parameters_sqrt.grad[4, 0] = 0.0
    optimizer.step()
    optimizer.zero_grad()

    # Report periodically and always on the last epoch; the loss is a sum of squared energies.
    if epoch % REPORT_EVERY == 0 or epoch == N_EPOCHS - 1:
        print(f"Epoch {epoch}: loss={loss.item()} (kcal / mol)^2")

# The potential was last rebuilt *before* the final optimizer step, so refresh it from the
# optimised variables. The result is detached so later cells work with plain values rather
# than tensors that are still attached to the autograd graph.
with torch.no_grad():
    vdw_potential.parameters = vdw_parameters_sqrt.detach() ** 2
    vdw_potential.parameters[:, 0] += EPSILON_OFFSET

In [None]:
import numpy as np
import os

# Build one table row per dimer (structure image, energy plot, per-dimer RMSE) and collect
# the point-wise errors of the initial and the optimised LJ parameters for an overall RMSE.
rows = []
energy_diff_initial = []
energy_diff_final = []


for dimer in dimers:

    ccsd_energies = dimer.ccsd
    # Fitting target: sum of the dimer's ana2b_v_ana and ana2b_v_d3 terms.
    ana2b_d3 = dimer.ana2b_v_ana + dimer.ana2b_v_d3

    lj_energies_initial = compute_dimer_vdw_energy(dimer=dimer, potential=vdw_potential_initial, ff=tensor_ff)
    lj_energies_opt = compute_dimer_vdw_energy(dimer=dimer, potential=vdw_potential, ff=tensor_ff)

    # Convert to NumPy once and reuse the error arrays for both the per-dimer and overall RMSE.
    reference = ana2b_d3.numpy()
    # detach() because the optimised energies may still be attached to the autograd graph.
    opt_lj = lj_energies_opt.detach().numpy()
    initial_error = lj_energies_initial.numpy() - reference
    final_error = opt_lj - reference

    energy_diff_initial.extend(initial_error)
    energy_diff_final.extend(final_error)

    rows.append(
        {
            "Dimer": mol_to_image(dimer.smiles_a, dimer.smiles_b),
            "Group": dimer.group_id,
            "Energy ": plot_energies(
                dimer.distances,
                {
                    "Sage LJ": lj_energies_initial,
                    "CCSD": ccsd_energies,
                    "ANA2B+D3": ana2b_d3,
                    "Opt LJ": lj_energies_opt.detach(),
                },
            ),
            "Sage LJ RMSE [kcal / mol]": np.sqrt(np.mean(initial_error**2)),
            "Opt LJ RMSE [kcal / mol]": np.sqrt(np.mean(final_error**2)),
        }
    )

# Final summary row: RMSE over every point of every dimer (the remaining columns stay empty).
rows.append(
    {
        "Sage LJ RMSE [kcal / mol]": np.sqrt(np.mean(np.square(energy_diff_initial))),
        "Opt LJ RMSE [kcal / mol]": np.sqrt(np.mean(np.square(energy_diff_final))),
    }
)

# exist_ok=True keeps the cell re-runnable once the output folder has been created.
output_dir = os.path.join(results_folder, fit_name)
os.makedirs(output_dir, exist_ok=True)
# escape=False so that the values produced by mol_to_image / plot_energies are written verbatim.
pandas.DataFrame(rows).to_html(os.path.join(output_dir, "sage-2.1-opt-lj.html"), escape=False, index=False)

In [None]:
# Copy the fitted LJ values from the tensor potential back onto the OpenFF vdW handler.
from openff.units import unit

vdw_handler = openff_ff.get_parameter_handler("vdW")
fitted_values = vdw_potential.parameters.detach().numpy()

for row_index, fitted_row in enumerate(fitted_values):
    smirks = vdw_potential.parameter_keys[row_index].id
    # Entries whose id contains "EP" are skipped (presumably virtual-site terms that have no
    # counterpart in the vdW handler - TODO confirm).
    if "EP" in smirks:
        continue

    # Each row holds (epsilon, sigma); attach the units the handler is assigned in.
    epsilon, sigma = fitted_row
    handler_parameter = vdw_handler[smirks]
    handler_parameter.epsilon = epsilon * unit.kilocalorie / unit.mole
    handler_parameter.sigma = sigma * unit.angstrom

In [None]:
# Serialise the force field into the folder of this fit.
opt_ff_path = os.path.join(results_folder, fit_name, "opt_ff.offxml")
openff_ff.to_file(opt_ff_path)

In [None]:
# Tabulate the initial and final epsilon / sigma of every vdW parameter and save the table as CSV.
# TODO: also include a plot of the parameter change?
initial_values = vdw_potential_initial.parameters.numpy()
final_values = vdw_potential.parameters.detach().numpy()

parameter_changes = [
    {
        "smirks": parameter_key.id,
        "epsilon initial": initial_values[row_index][0],
        "epsilon final": final_values[row_index][0],
        "sigma initial": initial_values[row_index][1],
        "sigma final": final_values[row_index][1],
    }
    for row_index, parameter_key in enumerate(vdw_potential.parameter_keys)
]

changes_csv_path = os.path.join(results_folder, fit_name, "opt_ff_changes.csv")
pandas.DataFrame(parameter_changes).to_csv(changes_csv_path, index=False)