# Fitting Workshop

## Preparing the dataset

In [None]:
from qcportal import PortalClient

# Open a client on the public MolSSI QCArchive server.
QCARCHIVE_URL = "https://api.qcarchive.molssi.org"
client = PortalClient(QCARCHIVE_URL)

# Fetch the listing of every dataset on the server (newer qcportal API); the
# next cell reads the "dataset_name" and "dataset_type" fields of each entry.
datasets = client.list_datasets()

In [None]:
# Show (name, type) for every dataset whose name contains both "openff" and
# "lipid", compared case-insensitively.
[
    (info["dataset_name"], info["dataset_type"])
    for info in datasets
    if all(term in info["dataset_name"].lower() for term in ("openff", "lipid"))
]

In [None]:
from openff.qcsubmit.results import OptimizationResultCollection
from qcportal import PortalClient
from qcportal.client import SinglepointDriver

client = PortalClient("https://api.qcarchive.molssi.org")  # TODO: Caching?

# Collect the results of the dataset(s) below for the spec named "default".
source_datasets = ["OpenFF Protein PDB 4-mers v4.0"]
results = OptimizationResultCollection.from_server(
    client=client,
    datasets=source_datasets,
    spec_name="default",
)

In [None]:
from qcportal.client import SinglepointDriver

# Convert to a basic result collection covering every SinglepointDriver value,
# then fetch the records asking for their "molecule", "identifiers" and
# "properties" fields. Each item is a (record, molecule) pair.
basic_results = results.to_basic_result_collection(list(SinglepointDriver))
record_molecule_pairs = basic_results.to_records(
    include=["molecule", "identifiers", "properties"],
)
records, molecules = zip(*record_molecule_pairs)

In [None]:
import descent.targets.energy
import numpy
import torch
import tqdm
from openff.units import Quantity, unit

# Group the records by mapped SMILES, collecting one coordinate set, energy and
# force array per record, converted to angstrom / kcal/mol units.
data = {}
for record in tqdm.tqdm(records):
    mapped_smiles = (
        record.molecule.identifiers.canonical_isomeric_explicit_hydrogen_mapped_smiles
    )
    properties = record.properties

    # bohr -> angstrom
    coords_angstrom = Quantity(record.molecule.geometry, "bohr").m_as("angstrom")
    # hartree (per molecule) -> kcal/mol
    energy_kcal = (
        Quantity(properties["return_energy"], "hartree") * unit.avogadro_constant
    ).m_as("kcal/mol")
    # hartree/bohr (per molecule) -> kcal/mol/angstrom; flat gradient -> (n, 3)
    gradient_kcal = (
        Quantity(
            numpy.array(properties["scf total gradient"]).reshape((-1, 3)),
            "hartree/bohr",
        )
        * unit.avogadro_constant
    ).m_as("kcal/mol/angstrom")

    if mapped_smiles not in data:
        data[mapped_smiles] = descent.targets.energy.Entry(
            smiles=mapped_smiles,
            coords=[],
            energy=[],
            forces=[],
        )
    entry = data[mapped_smiles]
    entry["coords"].append(coords_angstrom)
    entry["energy"].append(energy_kcal)
    # Forces are stored as the negated gradient.
    entry["forces"].append(-gradient_kcal)

In [None]:
from openff.toolkit import Molecule

# Stack each molecule's per-record lists into (n_confs, n_atoms, 3) arrays.
# The reshape raises if the number of values does not match
# n_confs * n_atoms * 3, which catches inconsistent records early.
for entry in data.values():
    n_confs = len(entry["energy"])
    n_atoms = Molecule.from_mapped_smiles(entry["smiles"]).n_atoms
    entry["coords"] = numpy.asarray(entry["coords"]).reshape(n_confs, n_atoms, 3)
    # Forces must come from entry["forces"]; building them from entry["coords"]
    # would train against the coordinates instead of the reference forces.
    entry["forces"] = numpy.asarray(entry["forces"]).reshape(n_confs, n_atoms, 3)

ds = descent.targets.energy.create_dataset(data.values())

In [None]:
# TODO: Filter the dataset?
# TODO: Check coverage?

## Parametrization

In [None]:
from openff.toolkit import ForceField, Molecule

# Starting point for the fit: the force field shipped as the file below.
initial_ff_path = "openff_unconstrained-2.3.0.offxml"
initial_ff = ForceField(initial_ff_path)

In [None]:
# TODO: Parallelize?

# Parametrize each dataset entry with the initial force field, in dataset order.
interchanges = []
for smiles in tqdm.tqdm(ds["smiles"]):
    topology = Molecule.from_mapped_smiles(smiles).to_topology()
    interchanges.append(initial_ff.create_interchange(topology))

In [None]:
import smee.converters

device = "cpu"

tensor_ff, tensor_tops = smee.converters.convert_interchange(interchanges)

# One tensor topology is expected per interchange, i.e. per dataset entry.
assert len(tensor_tops) == len(interchanges) == len(ds)

# Move everything to the compute device, indexing topologies by SMILES.
tensor_ff = tensor_ff.to(device)
tensor_tops_by_smiles = dict(
    zip(ds["smiles"], (top.to(device) for top in tensor_tops))
)

## Fitting

In [None]:
import descent.train

# Trainable-parameter configuration keyed by smee potential type. For each
# type, `cols` selects which parameter columns are trained; `scales` and
# `limits` are per-column settings interpreted by descent.train. The limits
# look like [lower, upper] pairs with None meaning "no bound", and the angle
# upper bound (pi) suggests angles are in radians — confirm against the
# descent documentation.
parameters: dict[str, descent.train.ParameterConfig] = {
    "Bonds": descent.train.ParameterConfig(
        cols=["k", "length"],
        scales={"k": 1.0e-2, "length": 1.0},
        limits={"k": [0.0, None], "length": [0.0, None]},
    ),
    "Angles": descent.train.ParameterConfig(
        cols=["k", "angle"],
        scales={"k": 1.0e-2, "angle": 1.0e-2},
        limits={"k": [0.0, None], "angle": [0.0, 3.141592653589793]},
    ),
    "ProperTorsions": descent.train.ParameterConfig(
        cols=["k"],
        scales={"k": 1.0e1},
        limits={"k": [0.0, None]},
    ),
}
# No handler-level attributes are configured for training.
attributes: dict[str, descent.train.AttributeConfig] = {}

In [None]:
# Wrap the tensor force field so the configured parameters/attributes can be
# read out as a single tensor (to_values) and written back (to_force_field).
trainable = descent.train.Trainable(
    parameters=parameters,
    attributes=attributes,
    force_field=tensor_ff,
)

In [None]:
import tensorboardX
import torch

# Import the notebook progress bar through its module. `from tqdm.notebook
# import tqdm` would rebind the name `tqdm` (the module earlier cells call as
# `tqdm.tqdm(...)`) and break those cells if they are re-run.
import tqdm.notebook

# Relative weights of the energy and force terms in the loss.
energy_weight = 1.0
force_weight = 1.0
n_epochs = 10
learning_rate = 1.0 / 15000
batch_size = len(ds)
directory = "tensorboard_logs"
train_data = ds
n_train = len(train_data)

train_dataloader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    # Keep each batch as a plain list of per-entry samples.
    collate_fn=lambda samples: samples,
    # Pinned memory only helps host-to-GPU copies; requesting it without an
    # accelerator merely produces a warning.
    pin_memory=torch.device(device).type == "cuda",
)
# The optimizer needs a leaf tensor. `.to(device)` returns a non-leaf copy when
# the device actually changes, so detach and re-enable gradients explicitly.
trainable_parameters = (
    trainable.to_values().to(device).detach().requires_grad_(True)
)

print("Start training...")
with tensorboardX.SummaryWriter(str(directory)) as writer:
    optimizer = torch.optim.Adam(
        [trainable_parameters],
        lr=learning_rate,
        amsgrad=True,
    )

    epoch_tqdm = tqdm.notebook.tqdm(
        range(n_epochs), desc="epochs", dynamic_ncols=True
    )
    for i in epoch_tqdm:
        ff = trainable.to_force_field(trainable_parameters)
        # Epoch losses accumulate as means over the whole training set: each
        # batch adds its squared-error sum divided by the dataset size.
        epoch_loss = torch.zeros(size=(1,), device=device)
        energy_loss = torch.zeros(size=(1,), device=device)
        force_loss = torch.zeros(size=(1,), device=device)
        grad = None

        batch_tqdm = tqdm.notebook.tqdm(
            leave=False,
            desc="computing loss",
            total=n_train,
            unit="tops",
            dynamic_ncols=True,
        )
        for cpu_batch in train_dataloader:
            # Copy the batch to device (the SMILES string stays on the host)
            batch = [
                {k: v if k == "smiles" else v.to(device) for k, v in sample.items()}
                for sample in cpu_batch
            ]
            # Compute reference and predicted energies and forces
            e_ref, e_pred, f_ref, f_pred = descent.targets.energy.predict(
                batch,  # type: ignore
                ff,
                tensor_tops_by_smiles,
                "mean",
            )
            # Squared-error sums normalised by the full dataset size, so that
            # summing over batches gives the per-entry mean over the whole
            # dataset regardless of how it is split into batches.
            batch_loss_energy = ((e_pred - e_ref) ** 2).sum() / n_train
            batch_loss_force = ((f_pred - f_ref) ** 2).sum() / n_train

            # Weighted sum of the energy and force L2 terms
            batch_loss = (
                energy_weight * batch_loss_energy + force_weight * batch_loss_force
            )

            # First-order gradient of this batch's contribution; no
            # higher-order graph is needed because it is only summed.
            (batch_grad,) = torch.autograd.grad(batch_loss, trainable_parameters)
            grad = batch_grad if grad is None else grad + batch_grad

            # Accumulate epoch losses (already dataset means, see above)
            epoch_loss += batch_loss.detach()
            energy_loss += batch_loss_energy.detach()
            force_loss += batch_loss_force.detach()

            # Update the progress bar
            batch_tqdm.update(len(batch))
        batch_tqdm.close()

        # Write results to logs
        epoch_tqdm.set_description(
            f"loss: {epoch_loss.detach().item()}, epochs",
        )

        writer.add_scalar("loss", epoch_loss.detach().item(), i)
        writer.add_scalar("loss_energy", energy_loss.detach().item(), i)
        writer.add_scalar("loss_forces", force_loss.detach().item(), i)

        writer.add_scalar("rmse_energy", energy_loss.detach().sqrt().item(), i)
        writer.add_scalar("rmse_forces", force_loss.detach().sqrt().item(), i)
        writer.flush()

        # Perform the optimization step
        trainable_parameters.grad = grad
        optimizer.step()
        optimizer.zero_grad()

In [None]:
# Show the TensorBoard dashboard for the logs the training cell wrote to `directory`.
%load_ext tensorboard
%tensorboard --logdir $directory

In [None]:
# Build the optimized tensor force field from the final parameter values.
# The values are detached because this force field is only exported, so no
# autograd graph needs to be recorded.
optimized_tensor_ff = trainable.to_force_field(trainable_parameters.detach())

In [None]:
from collections.abc import Collection

from openff.toolkit.typing.engines.smirnoff.parameters import ParameterHandler

# Parameter keys whose associated handler is in this set are skipped when
# optimized values are written back into the SMIRNOFF force field.
HANDLERS_WITHOUT_XML_PARAMETERS = {"NAGLChargesHandler", "ToolkitAM1BCCHandler"}


def update_parameters(
    handler: ParameterHandler,
    potential: smee.TensorPotential,
    config: descent.train.ParameterConfig | None,
):
    """Copy parameter values from ``potential`` onto ``handler``'s parameters, in place.

    Args:
        handler: SMIRNOFF parameter handler whose parameters are updated.
        potential: smee potential providing keys, column names, units and values.
        config: If given, only columns listed in ``config.cols`` are copied;
            if ``None``, every column is copied.

    Keys whose ``associated_handler`` is in ``HANDLERS_WITHOUT_XML_PARAMETERS``
    are skipped. A value that cannot be set is reported with a printed message
    (including the error) instead of aborting the whole export.
    """
    for key, values in zip(
        potential.parameter_keys,
        potential.parameters,
        strict=True,
    ):
        if key.associated_handler in HANDLERS_WITHOUT_XML_PARAMETERS:
            continue
        parameter = handler[key.id]
        # `col_unit` rather than `unit`, to avoid shadowing openff.units.unit.
        for name, col_unit, value in zip(
            potential.parameter_cols,
            potential.parameter_units,
            values,
            strict=True,
        ):
            if config is not None and name not in config.cols:
                continue
            # Keys with a (0-based) multiplicity index map to 1-based suffixed
            # attribute names, e.g. "k" with mult 0 -> "k1".
            name = name if key.mult is None else f"{name}{key.mult + 1}"
            try:
                # Store a plain float magnitude, not a (possibly grad-tracking)
                # tensor element, on the SMIRNOFF parameter.
                setattr(parameter, name, float(value) * col_unit)
            except Exception as err:  # best effort: report and keep going
                print(
                    f"    COULD NOT UPDATE {key.id=} name={name!r} "
                    f"unit={col_unit!r} {value=} {key.mult=}: {err!r}"
                )


def update_attributes(
    handler: ParameterHandler,
    potential: smee.TensorPotential,
    config: descent.train.AttributeConfig | None,
):
    """Copy attribute values from ``potential`` onto ``handler``, in place.

    Args:
        handler: SMIRNOFF parameter handler whose attributes are updated.
        potential: smee potential providing attribute names, values and units;
            any of these may be ``None``, meaning "no attributes".
        config: If given, only attributes listed in ``config.cols`` are copied;
            if ``None``, every attribute is copied.

    Raises:
        ValueError: If names, values and units have different lengths.
    """
    # Explicit `is None` checks: `potential.attributes` may be a tensor, whose
    # truthiness is ambiguous, so `x or []` must not be used here.
    names = [] if potential.attribute_cols is None else potential.attribute_cols
    values = [] if potential.attributes is None else potential.attributes
    units = [] if potential.attribute_units is None else potential.attribute_units

    # `attr_unit` rather than `unit`, to avoid shadowing openff.units.unit.
    for name, value, attr_unit in zip(names, values, units, strict=True):
        if config is not None and name not in config.cols:
            continue
        # Store a plain float magnitude rather than a tensor element,
        # consistent with update_parameters.
        setattr(handler, name, float(value) * attr_unit)


def write_smirnoff(
    initial_ff: ForceField,
    optimized_tensor_ff: smee.TensorForceField,
    parameters: None | dict[str, descent.train.ParameterConfig] = None,
    attributes: None | dict[str, descent.train.AttributeConfig] = None,
):
    """Return a copy of ``initial_ff`` updated with values from ``optimized_tensor_ff``.

    Args:
        initial_ff: SMIRNOFF force field to start from; it is not modified.
        optimized_tensor_ff: smee force field holding the optimized values.
        parameters: Per-potential-type parameter configs. Only types present
            here get their parameters updated; ``None`` updates every type.
        attributes: Per-potential-type attribute configs. Only types present
            here get their attributes updated; ``None`` updates every type.

    Returns:
        A new ``ForceField`` with the selected parameters/attributes replaced.
    """
    # Round-trip through the OFFXML string to get an independent copy.
    optimized_smirnoff_ff = ForceField(initial_ff.to_string())
    for potential in optimized_tensor_ff.potentials:
        print(potential.type)
        update_params = parameters is None or potential.type in parameters
        update_attrs = attributes is None or potential.type in attributes
        if not (update_params or update_attrs):
            continue
        # Look the handler up only when it will actually be modified: indexing
        # a ForceField with a tag it does not contain may register a new, empty
        # handler that would then end up in the written force field.
        handler = optimized_smirnoff_ff[potential.type]
        if update_params:
            print("  updating parameters")
            update_parameters(
                handler,
                potential,
                None if parameters is None else parameters[potential.type],
            )
        if update_attrs:
            print("  updating attributes")
            update_attributes(
                handler,
                potential,
                None if attributes is None else attributes[potential.type],
            )
    return optimized_smirnoff_ff


# Export the optimized values into a SMIRNOFF force field, updating only the
# potential types configured for training.
optimized_smirnoff_ff = write_smirnoff(
    initial_ff=initial_ff,
    optimized_tensor_ff=optimized_tensor_ff,
    parameters=parameters,
    attributes=attributes,
)

In [None]:
import difflib

# Show only the lines that changed between the initial and optimized OFFXML.
# A real diff stays aligned when lines are added or removed, whereas pairing
# lines with zip() misaligns after the first insertion and silently drops
# any trailing lines.
for diff_line in difflib.unified_diff(
    initial_ff.to_string().splitlines(),
    optimized_smirnoff_ff.to_string().splitlines(),
    fromfile="initial",
    tofile="optimized",
    lineterm="",
    n=0,
):
    print(diff_line)

## Benchmarks

- Visual inspection of structures
- YAMMBS? It is not designed to be stable, user-facing software, so present it as "our internal benchmark software".