# MLIP fitting workflow

This notebook walks through a lightweight workflow for fitting an MLIP (machine-learning interatomic potential), showcasing a wide range of wfl functions and the ways they can be used.

The main steps are: 

1. Create some molecules
2. Run GFN2-xTB MD
3. Filter by force components
4. Calculate global SOAP descriptor
5. Perform CUR decomposition to select diverse-ish training and testing sets
6. Fit a GAP potential
7. Evaluate structures with GAP
8. Plot atomization energy and force component correlations.


## Imports 

In addition to standard packages or wfl dependencies, we make use of three external packages: 


- quip and quippy, which provide an interface for fitting and evaluating GAP.
  
  Documentation: https://pypi.org/project/quippy-ase/ 
  
  Installation: `pip install quippy-ase`


- GFN2-xTB: a semi-empirical method designed for molecular systems, used as a reference method. 

  Documentation: 
  - https://xtb-docs.readthedocs.io/en/latest/contents.html
  - https://xtb-python.readthedocs.io/en/latest/

  Installation: `conda install -c conda-forge xtb-python`

- RDKit: a chemoinformatics package that wfl uses to convert 2D SMILES strings (e.g. "CCO" for ethanol) into 3D `Atoms` objects. 

  Documentation: https://rdkit.org/

  Installation: `conda install -c conda-forge rdkit`


In [1]:
import numpy as np

from ase.io import read, write
from ase import Atoms

from xtb.ase.calculator import XTB

from quippy.potential import Potential

from wfl.configset import ConfigSet, OutputSpec
from wfl.generate import md
import wfl.descriptors.quippy
import wfl.select.by_descriptor
import wfl.fit.gap.simple
from wfl.calculators import generic
from wfl.autoparallelize.autoparainfo import AutoparaInfo
from wfl.autoparallelize.remoteinfo import RemoteInfo
from wfl.generate import smiles
from wfl.utils.configs import atomization_energy
from wfl.fit import error
import wfl.map

from expyre.resources import Resources

In [2]:
# set random seed, so that MD runs, etc are reproducible and we can check for RMSEs. 
# this cell is hidden from tutorials. 
# np.random.seed(20230301)
import os
os.environ["WFL_DETERMINISTIC_HACK"] = "True"

## Reference calculator 

The calculator object given to `autoparallelize`-wrapped functions needs to be picklable, so that it can be sent to the parallel Python subprocesses started with `multiprocessing.pool`. Calculators that can't be pickled need to be given to workflow functions as

`(Initializer, [args], {kwargs})`

e.g. xtb would normally be called with 

`xtb_calc = XTB(method="GFN2-xTB")`

but instead in wfl scripts we define it as

In [3]:
xtb_calc = (XTB, [], {"method": "GFN2-xTB"})
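
To make the idea concrete, here is a minimal sketch of how such a tuple survives pickling and is turned into a calculator inside a worker. `build_calculator` is a hypothetical helper, not a wfl function, and `SimpleNamespace` only stands in for `XTB` so that the example runs without xtb installed.

```python
import pickle
from types import SimpleNamespace


def build_calculator(calc_spec):
    """Instantiate a calculator from an (Initializer, args, kwargs) tuple."""
    initializer, args, kwargs = calc_spec
    return initializer(*args, **kwargs)


# SimpleNamespace is only a stand-in for a real calculator class such as XTB.
calc_spec = (SimpleNamespace, [], {"method": "GFN2-xTB"})

# The tuple holds a class reference and plain Python data, so it can be
# pickled even when an instantiated calculator could not.
restored_spec = pickle.loads(pickle.dumps(calc_spec))

# Each worker can then build its own calculator instance.
calc = build_calculator(restored_spec)
print(calc.method)
```

The design choice is to ship the recipe for a calculator rather than the calculator itself, so construction happens inside each subprocess.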

## Prepare isolated atoms 

GAP requires reference (`e0`) energies for fitting. We construct `Atoms` objects with a single atom, evaluate them with the reference GFN2-xTB method and store them in a file to later combine them with the training set.
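
The same isolated-atom energies also define the atomization energy evaluated later in this notebook: the total energy minus the sum of the single-atom energies. The sketch below uses made-up numbers and a hypothetical helper; it is not the wfl `atomization_energy` function.

```python
from collections import Counter


def atomization_energy_per_atom(symbols, total_energy, e0):
    """Total energy minus the summed isolated-atom energies, per atom."""
    counts = Counter(symbols)
    reference = sum(e0[element] * n for element, n in counts.items())
    return (total_energy - reference) / len(symbols)


# Illustrative values only; these are NOT GFN2-xTB results.
e0 = {"H": -10.0, "C": -50.0}
methane = ["C", "H", "H", "H", "H"]

e_at = atomization_energy_per_atom(methane, total_energy=-92.0, e0=e0)
print(e_at)
```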

In [4]:
isolated_at_fname='isolated_atoms.xtb.xyz'
isolated_atoms = [Atoms(element, positions=[(0, 0, 0)], cell=[50, 50, 50]) for element in ["H", "C"]]
inputs = ConfigSet(isolated_atoms)
outputs=OutputSpec(isolated_at_fname, tags={"config_type": "isolated_atom"})

# calculate reference energy
isolated_atoms = generic.run(
    inputs=inputs,
    outputs=outputs,
    calculator=xtb_calc,
    properties=["energy"],
    output_prefix="xtb_")

## Generate initial structures

We build this example on a small number of hydrocarbon molecules. Their connectivity is represented as SMILES strings, and we use RDKit to convert them into reasonable 3D geometries from which to start the molecular dynamics simulation.

In [5]:
all_smiles = [
    'CC1=CCC=CC(C)=C1C(C)C', 
    'CC1(c2ccc(CC3CC=CC3)cc2)CC1', 
    'C#CC[C@@H](CCC=C(C)C)C1CC1', 
    'Cc1ccccc1CCCC1=CCCCC1', 
    'C=CC1=CC[C@@H]2C[C@H]1C2(C)C', 
    'C1=CCC(Cc2ccc(CC3CC3)cc2)C1', 
    'C1=CC(c2ccccc2)=CCC1', 
    'C/C=C/CCCC[C@H](C)C(C)(C)C', 
    'C=C[C@@H]1C/C=C/CCCCCCCC1', 
    'C[C@H](CC(C)(C)C)[C@@H](C)C(C)(C)C', 
    'CC/C=C\\C[C@@H](C)c1cccc(C)c1C', 
    'C=C1CC2c3ccccc3C1c1ccccc12']

outputs = OutputSpec("1.ch.rdkit.xyz")
smiles_configs = smiles.run(all_smiles, outputs=outputs)

In [None]:
# RDKit doesn't consistently generate the same initial structures.
# For testing purposes we load structures from file.
# This cell is hidden from docs.
from pathlib import Path

# fn = Path(wfl.__file__).parent.resolve() / "../tests/assets/daisy_chain_mlip_fitting.1.ch.rdkit.xyz"
# smiles_configs = ConfigSet(fn)
# orig_pythonpath = os.environ["PYTHONPATH"]
os.environ["PYTHONPATH"] = str(Path(wfl.__file__).parent)


## Run Molecular Dynamics simulation

We run the MD at 300 K with an NVT Berendsen thermostat to collect a pool of structures from which we will select diverse structures for the training set. 


While a diverse training set leads to better model extrapolation, structures too dissimilar to the region of interest are fitted at the expense of accuracy elsewhere. One way to spot structures far from equilibrium is to check for high force components. To do so, we attach a selector function that excludes unreasonable structures from those returned by the MD script.

In [7]:
outputs = OutputSpec("2.ch.rdkit.md.traj.xyz")

md_params = {
    "steps": 80,
    "dt": 0.5,  # fs
    "temperature": 300,  # K
    "temperature_tau": 500,  
    "results_prefix": "xtb_",
    "traj_step_interval": 5}

remote_info = {
    "sys_name" : "github", 
    "job_name" : "md", 
    "resources" : { 
        "max_time" : "15m",
        "num_cores" : 2,
        "partitions" : "standard"}, 
    "check_interval": 5,
    "num_inputs_per_queued_job" :20,
    "pre_cmds": ['echo "-------------------------------------"', "which wfl"],
    "post_cmds": ['echo "_______________________________"', "which wfl"]
}
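
To give a feel for what `dt` and `temperature_tau` control, the sketch below evaluates the textbook Berendsen velocity-scaling factor, `sqrt(1 + (dt / tau) * (T0 / T - 1))`. It assumes `temperature_tau` is given in the same time unit as `dt` (fs); the function name is hypothetical and this is not the code wfl or ASE runs.

```python
import math


def berendsen_scaling(current_temp, target_temp, dt, tau):
    """Factor by which velocities are rescaled in one Berendsen step."""
    return math.sqrt(1.0 + (dt / tau) * (target_temp / current_temp - 1.0))


# A system at 250 K coupled to a 300 K bath is heated very gently,
# because dt / tau = 0.5 / 500 is small.
lam_cold = berendsen_scaling(250.0, 300.0, dt=0.5, tau=500.0)

# At the target temperature the velocities are left unchanged.
lam_on_target = berendsen_scaling(300.0, 300.0, dt=0.5, tau=500.0)

print(lam_cold, lam_on_target)
```

A larger `tau` means weaker coupling to the bath, so the temperature relaxes more slowly towards the target.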


In [8]:
# set to None for github testing purposes
# This cell is hidden from being rendered in the docs. 
# remote_info = None


In [9]:
def reasonable_forces(at):
    # Keep a frame only if every atom's force magnitude is below 8 eV/Å.
    forces = at.get_forces()
    return np.all(np.linalg.norm(forces, axis=1) < 8)

md_sample = md.sample(
    inputs=smiles_configs, 
    outputs=outputs,
    calculator=xtb_calc,
    traj_select_during_func = reasonable_forces, 
    autopara_info = AutoparaInfo(
        remote_info=remote_info),
    **md_params
    )
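
Note that `reasonable_forces` thresholds the norm of each atom's force vector, which is slightly stricter than thresholding individual Cartesian components. The sketch below contrasts the two criteria on a hand-made force array; both helper names are hypothetical and the numbers are illustrative.

```python
import numpy as np


def max_force_norm_ok(forces, threshold=8.0):
    """True if every per-atom force magnitude is below the threshold."""
    return bool(np.all(np.linalg.norm(forces, axis=1) < threshold))


def max_force_component_ok(forces, threshold=8.0):
    """True if every Cartesian force component is below the threshold."""
    return bool(np.all(np.abs(forces) < threshold))


# Second atom: each component is 5, but the magnitude is about 8.66.
forces = np.array([[0.1, -0.2, 0.3],
                   [5.0, 5.0, 5.0]])

print(max_force_norm_ok(forces), max_force_component_ok(forces))
```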


## Calculate SOAP descriptor

In [10]:
outputs = OutputSpec("3.ch.rdkit.md.traj.local_soap.xyz")

descriptor_key = "SOAP"
# Descriptor string, just as it would go into quip.
# dictionary can have a descriptor per species, e.g. 
# descriptor = {
#   "H": "soap ...",
#   "C": "soap ..."}
# `None` for dictionary keys just means that the same descriptor is used 
# for all elements. 
descriptor =   {
        None: "soap l_max=3 n_max=6 cutoff=4 delta=1 covariance_type=dot_product zeta=4 atom_gaussian_width=0.3"
    }

# this function isn't parallelised here, but can be
# by setting WFL_NUM_PYTHON_SUBPROCESSES or
# WFL_EXPYRE_INFO
md_soap_local = wfl.descriptors.quippy.calc(
    inputs=md_sample,
    outputs=outputs,
    descs=descriptor,
    key=descriptor_key,
    per_atom=True)

def get_average_soap(at, descriptor_key):
    at_desc = at.arrays.pop(descriptor_key)
    at_desc = np.sum(at_desc, axis=0)
    at_desc /= np.linalg.norm(at_desc)
    at.info[descriptor_key] = at_desc 
    return at

md_soap_global = wfl.map.run(
    inputs = md_soap_local,
    outputs = OutputSpec(), 
    map_func = get_average_soap, 
    args = [descriptor_key])
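
The reduction in `get_average_soap` can be checked on a toy array: the per-atom vectors are summed and the result is scaled to unit length, so that a dot-product kernel between two structures lies in a fixed range. This is a plain NumPy sketch with made-up two-component "descriptors", not real SOAP vectors.

```python
import numpy as np


def global_descriptor(per_atom_descriptors):
    """Sum per-atom descriptor rows and normalise the result to unit length."""
    summed = np.sum(per_atom_descriptors, axis=0)
    return summed / np.linalg.norm(summed)


# Three "atoms", each with a two-component descriptor (illustrative only).
per_atom = np.array([[1.0, 0.0],
                     [0.0, 1.0],
                     [1.0, 1.0]])

g = global_descriptor(per_atom)

# Self-similarity under a dot-product kernel with zeta=4.
self_similarity = float(g @ g) ** 4
print(g, self_similarity)
```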

## Sub-select with CUR

Select diverse structures for training and testing sets with CUR. 
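
As a rough illustration of the idea behind CUR selection, the sketch below ranks configurations by their leverage scores in the leading eigenvectors of a dot-product kernel and keeps the highest-scoring ones. It is a simplified, deterministic variant written for this tutorial; the function name, the `rank` argument and the random test data are assumptions, and wfl's own implementation may differ (for example by sampling stochastically).

```python
import numpy as np


def cur_select(descriptors, num, zeta=1, rank=None):
    """Return indices of `num` rows chosen by their leverage scores."""
    # Dot-product kernel between (normalised) global descriptors.
    kernel = (descriptors @ descriptors.T) ** zeta
    rank = num if rank is None else rank
    # eigh returns eigenvalues in ascending order; keep the largest `rank`.
    _, eigvecs = np.linalg.eigh(kernel)
    leading = eigvecs[:, -rank:]
    # Leverage score of a configuration: its weight in the leading subspace.
    leverage = np.sum(leading ** 2, axis=1) / rank
    # Deterministic variant: take the highest-scoring configurations.
    return np.argsort(leverage)[::-1][:num]


# Random unit vectors stand in for global SOAP descriptors.
rng = np.random.default_rng(0)
descriptors = rng.normal(size=(20, 5))
descriptors /= np.linalg.norm(descriptors, axis=1, keepdims=True)

selected = cur_select(descriptors, num=4, zeta=4, rank=3)
print(sorted(selected.tolist()))
```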

In [11]:
outputs = OutputSpec("4.ch.rdkit.md.traj.soap.cur_selection.xyz")
cur_selection = wfl.select.by_descriptor.CUR_conf_global(
    inputs=md_soap_global,
    outputs=outputs,
    num=100,                    # target number of structures to pick
    at_descs_info_key="SOAP")
    

train_fname = "5.1.train.xyz"
test_fname = "5.2.test.xyz"
gap_fname='gap.xml'

def process(at):
    at.cell = [50, 50, 50]
    # For now, SOAP descriptor in atoms.info cannot be parsed by the xyz reader
    del at.info["SOAP"]
    return at


processed_cur_selection = wfl.map.run(
    cur_selection,
    OutputSpec(),
    map_func = process)



Label and save training and testing sets

In [12]:
train_inputs = ConfigSet(list(processed_cur_selection)[0::2])
test_inputs = ConfigSet(list(processed_cur_selection)[1::2])

OutputSpec(train_fname, tags={"config_type": "train"}).write(train_inputs)
OutputSpec(test_fname, tags={"config_type": "test"}).write(test_inputs)
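
The `[0::2]` and `[1::2]` slices split the selection into two interleaved, non-overlapping halves. A tiny example with placeholder names instead of `Atoms` objects:

```python
# Placeholder labels stand in for the selected configurations.
configs = [f"config_{i}" for i in range(7)]

train = configs[0::2]  # items at even positions: 0, 2, 4, 6
test = configs[1::2]   # items at odd positions: 1, 3, 5

print(train)
print(test)
```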

## Fit GAP

The gap parameter dictionary is almost directly converted to a command for `gap_fit`. 

In [13]:

gap_params = {
    "gap_file": gap_fname,
    "energy_parameter_name": "xtb_energy", 
    "force_parameter_name": "xtb_forces", 
    "default_sigma": [0.001, 0.01, 0.0, 0.0], 
    "config_type_kernel_regularisation": {"isolated_atom":[0.0001,0.0001,0.0,0.0]},
    "_gap": [{
            "soap": True,
            "l_max": 3,
            "n_max": 6, 
            "cutoff": 3,
            "delta": 0.1,
            "covariance_type": "dot_product",
            "zeta": 4, 
            "n_sparse":20, 
            "sparse_method": "cur_points", 
            "atom_gaussian_width":0.3,
            "cutoff_transition_width": 0.5},
         {
            "soap": True,
            "l_max": 3,
            "n_max": 6, 
            "cutoff": 6,
            "delta": 0.1,
            "covariance_type": "dot_product",
            "zeta": 4, 
            "n_sparse":20, 
            "sparse_method": "cur_points", 
            "atom_gaussian_width":0.6,
            "cutoff_transition_width": 1},
        {
            "distance_2b": True,
            "cutoff": 7, 
            "covariance_type": "ard_se",
            "delta": 1,
            "theta_uniform": 1.0,
            "sparse_method": "uniform", 
            "n_sparse": 10 
        }
    ]
}

remote_info = {
    "sys_name" : "github", 
    "job_name" : "gap-fit", 
    "resources" : { 
        "max_time" : "15m",
        "num_cores" : 2,
        "partitions" : "standard"}, 
    "check_interval": 5, 
}


In [14]:
# set to None for github testing purposes
# This cell is hidden from being rendered in the docs. 
# remote_info = None
gap_params["rnd_seed"] = 20230301

In [15]:
train_configs = ConfigSet([train_fname, isolated_at_fname])
wfl.fit.gap.simple.run_gap_fit(
    fitting_configs=train_configs,
    fitting_dict=gap_params,
    stdout_file='gap_fit.out',
    skip_if_present=True,
    remote_info=remote_info)

fitting command:
 gap_fit gap_file=gap.xml energy_parameter_name=xtb_energy force_parameter_name=xtb_forces default_sigma={0.001 0.01 0.0 0.0} config_type_kernel_regularisation=isolated_atom:0.0001:0.0001:0.0:0.0 rnd_seed=20230301 atoms_filename=/tmp/_GAP_fitting_configs.f8y7ibmk.xyz gap={ soap=T l_max=3 n_max=6 cutoff=3 delta=0.1 covariance_type=dot_product zeta=4 n_sparse=20 sparse_method=cur_points atom_gaussian_width=0.3 cutoff_transition_width=0.5 : soap=T l_max=3 n_max=6 cutoff=6 delta=0.1 covariance_type=dot_product zeta=4 n_sparse=20 sparse_method=cur_points atom_gaussian_width=0.6 cutoff_transition_width=1 : distance_2b=T cutoff=7 covariance_type=ard_se delta=1 theta_uniform=1.0 sparse_method=uniform n_sparse=10 } 2>&1 > gap_fit.out 


'gap.xml'
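
The fitting command printed above suggests how the dictionary is flattened: lists become `{a b c}`, `True` becomes `T`, the regularisation dictionary becomes colon-separated fields, and the `_gap` list becomes descriptor strings joined by ` : `. The sketch below mimics that formatting for a reduced parameter set. Both functions are hypothetical and written only from the printed command; this is not wfl's converter.

```python
def format_value(value):
    """Format a single value the way it appears in the printed command."""
    if isinstance(value, bool):
        return "T" if value else "F"
    if isinstance(value, (list, tuple)):
        return "{" + " ".join(str(v) for v in value) + "}"
    return str(value)


def dict_to_gap_fit_command(params):
    """Flatten a parameter dictionary into a gap_fit-style command string."""
    parts = []
    for key, value in params.items():
        if key == "_gap":
            # One descriptor string per dictionary, joined by " : ".
            descs = [" ".join(f"{k}={format_value(v)}" for k, v in d.items())
                     for d in value]
            parts.append("gap={ " + " : ".join(descs) + " }")
        elif isinstance(value, dict):
            # e.g. config_type_kernel_regularisation: name:sigma:sigma:...
            entries = [":".join([name] + [str(s) for s in sigmas])
                       for name, sigmas in value.items()]
            parts.append(f"{key}=" + ":".join(entries))
        else:
            parts.append(f"{key}={format_value(value)}")
    return "gap_fit " + " ".join(parts)


params = {
    "gap_file": "gap.xml",
    "default_sigma": [0.001, 0.01, 0.0, 0.0],
    "config_type_kernel_regularisation": {"isolated_atom": [0.0001, 0.0001, 0.0, 0.0]},
    "_gap": [{"distance_2b": True, "cutoff": 7, "n_sparse": 10}],
}

command = dict_to_gap_fit_command(params)
print(command)
```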

## Evaluate structures with GAP

In [16]:
train_fn_with_gap = "6.1.train.gap.xyz"
test_fn_with_gap = "6.2.test.gap.xyz"
isolated_at_fn_with_gap = isolated_at_fname.replace('.xyz', '.gap.xyz')

inputs = ConfigSet([train_fname, test_fname, isolated_at_fname])
outputs = OutputSpec([train_fn_with_gap, test_fn_with_gap, isolated_at_fn_with_gap])

gap_calc = (Potential, [], {"param_filename":"gap.xml"})

resources = Resources(
    max_time = "15m",
    num_cores = 2,
    partitions = "standard")

remote_info = RemoteInfo(
    sys_name = "github",
    job_name = "gap-eval",
    resources = resources,
    check_interval=10, 
    input_files=["gap.xml*"])


In [17]:
# Set remote_info to None so that this can run on GitHub CI. 
# This cell is hidden from being rendered in the docs. 
# remote_info = None

In [18]:

gap_calc_autopara_info = AutoparaInfo(
    remote_info=remote_info)

generic.run(
    inputs=inputs,
    outputs=outputs,
    calculator=gap_calc,
    properties=["energy", "forces"],
    output_prefix="gap_",
    autopara_info=gap_calc_autopara_info,
    )


<wfl.configset.ConfigSet at 0x7ff96fe5c5b0>

## Evaluate error & plot correlation

wfl has simple convenience functions for comparing the fitted model's performance with the reference method. Here we calculate atomization energies, evaluate the RMSE and draw parity plots for atomization energy per atom and for force components.
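
For reference, the two error measures reported below can be written in a few lines of NumPy. The sketch assumes that the force error treats every Cartesian component as a separate data point, as the `forces/comp` label suggests; the helper name and the numbers are illustrative, not output of `error.calc`.

```python
import numpy as np


def rmse_and_mae(reference, predicted):
    """Root-mean-square and mean absolute error between two arrays."""
    diffs = np.asarray(predicted) - np.asarray(reference)
    rmse = float(np.sqrt(np.mean(diffs ** 2)))
    mae = float(np.mean(np.abs(diffs)))
    return rmse, mae


# Made-up forces for two atoms (eV/Å), reference vs. fitted model.
ref_forces = np.array([[0.1, -0.2, 0.3], [0.0, 0.5, -0.4]])
gap_forces = np.array([[0.2, -0.2, 0.1], [0.1, 0.4, -0.4]])

# Flatten so that each force component counts as one data point.
rmse, mae = rmse_and_mae(ref_forces.ravel(), gap_forces.ravel())
print(rmse, mae)
```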

In [20]:
# calculate atomization energies
for fn in [train_fn_with_gap, test_fn_with_gap]:
    configset = ConfigSet([fn, isolated_at_fn_with_gap])
    for prop_prefix in ["xtb_", "gap_"]:
        configset = atomization_energy(
            inputs=configset, 
            outputs=OutputSpec([fn, isolated_at_fn_with_gap], overwrite=True), 
            prop_prefix=prop_prefix) 

# calculate errors
inputs = ConfigSet([train_fn_with_gap, test_fn_with_gap])
errors, diffs, parity = error.calc(
    inputs=inputs, 
    calc_property_prefix='gap_',
    ref_property_prefix='xtb_',
    config_properties=["atomization_energy/atom"],
    atom_properties=["forces/comp"])

print(error.errors_dumps(errors))

# plot parity and error plots
error.value_error_scatter(
    all_errors = errors, 
    all_diffs=diffs,
    all_parity=parity,
    output="gap_rmses.png",
    ref_property_prefix="xtb_",
    calc_property_prefix="gap_"
)


      atomization_energy/a atomization_energy/a    F/c    F/c
                    meV/at                    #  meV/Å      #
test                 17.34                   50 642.83   5070
train                14.97                   50 621.38   5106
_ALL_                16.20                  100 632.16  10176


In [21]:
from pathlib import Path
from pytest import approx

# just check that all the files are there, so all steps completed (successfully). 

expected_files = [
    "isolated_atoms.xtb.xyz",
    "1.ch.rdkit.xyz",
    "2.ch.rdkit.md.traj.xyz",
    "3.ch.rdkit.md.traj.local_soap.xyz",
    "4.ch.rdkit.md.traj.soap.cur_selection.xyz",
    "5.1.train.xyz",
    "5.2.test.xyz",
    "gap.xml",
    "6.1.train.gap.xyz",
    "6.2.test.gap.xyz",
    "gap_rmses.png"
]

for file in expected_files:
    print(file)
    assert Path(file).exists()


ref_errors = {
    'atomization_energy/atom': {
        'train': {'RMSE': 0.023849550801771195, 'MAE': 0.019971738986477784, 'count': 50}, 
        '_ALL_': {'RMSE': 0.02450676656293013, 'MAE': 0.020064544549953763, 'count': 100}, 
        'test': {'RMSE': 0.02514681175206184, 'MAE': 0.02015735011342974, 'count': 50}}, 
    'forces/comp': {
        'train': {'RMSE': 0.7436094643663678, 'MAE': 0.5152898115037744, 'count': 5034}, 
        '_ALL_': {'RMSE': 0.7235685620391954, 'MAE': 0.5077753348085652, 'count': 10134}, 
        'test': {'RMSE': 0.7032271391631981, 'MAE': 0.500358104282353, 'count': 5100}}}

print(errors)

for prop_key, prop_dict in ref_errors.items():
    for config_key, config_dict in prop_dict.items():
        for measure_key, val in config_dict.items():
            if measure_key == "count":
                continue
            pred_val = errors[prop_key][config_key][measure_key]
            assert val == approx(pred_val, abs=3e-1), f'Mismatch for ref prop {prop_key} config {config_key} measure {measure_key} ref {val} != actual {pred_val}'

print("WFL_DETERMINISTIC_HACK" in os.environ)

isolated_atoms.xtb.xyz
1.ch.rdkit.xyz
2.ch.rdkit.md.traj.xyz
3.ch.rdkit.md.traj.local_soap.xyz
4.ch.rdkit.md.traj.soap.cur_selection.xyz
5.1.train.xyz
5.2.test.xyz
gap.xml
6.1.train.gap.xyz
6.2.test.gap.xyz
gap_rmses.png
{'atomization_energy/atom': {'train': {'RMSE': 0.014974721709943664, 'MAE': 0.012622912654433502, 'count': 50}, '_ALL_': {'RMSE': 0.016200667521092393, 'MAE': 0.012965381227920191, 'count': 100}, 'test': {'RMSE': 0.01734015472732864, 'MAE': 0.013307849801406882, 'count': 50}}, 'forces/comp': {'train': {'RMSE': 0.6213796984853067, 'MAE': 0.4465803716686252, 'count': 5106}, '_ALL_': {'RMSE': 0.6321592191859141, 'MAE': 0.4509735702456762, 'count': 10176}, 'test': {'RMSE': 0.6428326175648255, 'MAE': 0.45539796313214986, 'count': 5070}}}
True