In [6]:
import pickle as pkl
import warnings
from typing import Union

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import scipy.stats
import seaborn as sns
import sklearn
import torch
from emle.models import EMLE
from emle.train._utils import pad_to_max

# Unit conversions used throughout the notebook.
KJ_PER_MOL_TO_KCAL_PER_MOL = 1.0 / 4.184
HARTREE_TO_KJ_MOL = 2625.5
# Original (misspelled) name kept as an alias because later cells refer to it.
HARTEE_TO_KJ_MOL = HARTREE_TO_KJ_MOL

# NOTE(review): this silences every warning for the whole kernel session, which
# can also hide numerical warnings from the statistics below; consider narrowing
# it to specific categories/modules.
warnings.filterwarnings("ignore")

In [7]:
def bootstrap_statistic(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    dy_true: Union[np.ndarray, None] = None,
    dy_pred: Union[np.ndarray, None] = None,
    ci: float = 0.95,
    statistic: str = "RMSE",
    nbootstrap: int = 500,
    include_true_uncertainty: bool = False,
    include_pred_uncertainty: bool = False,
    rng: Union[np.random.Generator, None] = None,
) -> dict:
    """Bootstrap a comparison statistic between reference and predicted values.

    The pairs ``(y_true[i], y_pred[i])`` are resampled with replacement
    ``nbootstrap`` times. Optionally, each resampled value is perturbed by
    Gaussian noise with per-point standard deviations ``dy_true`` / ``dy_pred``.

    Parameters
    ----------
    y_true, y_pred : array-like, shape (N,)
        Reference and predicted values. They must have the same non-zero length.
    dy_true, dy_pred : array-like, shape (N,), optional
        Per-point standard deviations, used only when the matching
        ``include_*_uncertainty`` flag is set. Default: zeros.
    ci : float
        Width of the central percentile interval reported as ``low``/``high``.
    statistic : str
        One of "RMSE", "MUE", "MSE" (signed mean of ``y_pred - y_true``), "R2",
        "rho" (Pearson r) or "KTAU" (Kendall tau).
    nbootstrap : int
        Number of bootstrap replicates (>= 1).
    include_true_uncertainty, include_pred_uncertainty : bool
        Whether to add Gaussian noise to the resampled true / predicted values.
    rng : numpy.random.Generator, optional
        Random source for resampling and noise. Defaults to NumPy's legacy
        global RNG, so ``np.random.seed`` still controls reproducibility.

    Returns
    -------
    dict
        ``mle``: the statistic on the full data.
        ``mean`` and ``stderr``: mean and standard deviation of the replicates.
        ``low`` and ``high``: percentile bounds of the ``ci`` interval.

    Raises
    ------
    ValueError
        If the statistic is unknown, the inputs are empty or of different
        lengths, or ``nbootstrap < 1``.
    """

    def compute_statistic(y_true_sample, y_pred_sample, statistic):
        residual = y_pred_sample - y_true_sample
        if statistic == "RMSE":
            return np.sqrt(np.mean(residual**2))
        if statistic == "MUE":
            return np.mean(np.abs(residual))
        if statistic == "MSE":
            # Signed mean error (bias), not the mean *squared* error.
            return np.mean(residual)
        if statistic == "R2":
            _, _, r_value, _, _ = scipy.stats.linregress(y_true_sample, y_pred_sample)
            return r_value**2
        if statistic == "rho":
            return scipy.stats.pearsonr(y_true_sample, y_pred_sample)[0]
        if statistic == "KTAU":
            return scipy.stats.kendalltau(y_true_sample, y_pred_sample)[0]
        raise ValueError(f"unknown statistic '{statistic}'")

    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    n = len(y_true)
    if n == 0 or len(y_pred) != n:
        raise ValueError(
            f"y_true and y_pred must be non-empty and of equal length "
            f"(got {len(y_true)} and {len(y_pred)})"
        )
    if nbootstrap < 1:
        raise ValueError(f"nbootstrap must be at least 1 (got {nbootstrap})")

    dy_true = np.zeros_like(y_true) if dy_true is None else np.asarray(dy_true)
    dy_pred = np.zeros_like(y_pred) if dy_pred is None else np.asarray(dy_pred)
    # np.random (module) and np.random.Generator both provide choice/normal.
    sampler = np.random if rng is None else rng

    # Computed first so an unknown `statistic` fails before any resampling.
    mle = compute_statistic(y_true, y_pred, statistic)

    replicates = np.zeros(nbootstrap, dtype=np.float64)
    for k in range(nbootstrap):
        idx = sampler.choice(n, size=n, replace=True)
        y_true_sample = y_true[idx]
        y_pred_sample = y_pred[idx]
        if include_true_uncertainty:
            y_true_sample = y_true_sample + sampler.normal(0, dy_true[idx])
        if include_pred_uncertainty:
            y_pred_sample = y_pred_sample + sampler.normal(0, dy_pred[idx])
        replicates[k] = compute_statistic(y_true_sample, y_pred_sample, statistic)

    ordered = np.sort(replicates)
    low_frac = (1.0 - ci) / 2.0
    high_frac = 1.0 - low_frac
    # Clamp both indices. For ci == 1.0 (or float rounding in
    # nbootstrap * high_frac), ceil() would otherwise index one past the last
    # replicate.
    low_idx = min(max(int(np.floor(nbootstrap * low_frac)), 0), nbootstrap - 1)
    high_idx = min(max(int(np.ceil(nbootstrap * high_frac)), 0), nbootstrap - 1)

    return {
        "mle": mle,
        "mean": np.mean(replicates),
        "stderr": np.std(replicates),
        "low": ordered[low_idx],
        "high": ordered[high_idx],
    }


def calculate_rmse(predicted, reference):
    """Return the root-mean-square deviation between two same-shaped arrays.

    Every element contributes, including any padding entries. Filter the
    inputs beforehand if those should be excluded.
    """
    residual = predicted - reference
    mean_square = np.mean(np.square(residual))
    return np.sqrt(mean_square)


# Locations of the EMLE models to compare and of the held-out testing sets.
source_path = "../../../"
models_dict = {
    "General Model": source_path + "emle_models/emle_qm7_aev.mat",
    "Bespoke Model": source_path + "emle_models/ligand_bespoke_iter2.mat",
    "Patched Model": source_path + "emle_models/ligand_patched_species_iter2.mat",
}

testing_data_paths = [
    source_path + "data/testing_datasets/testing_data_iter1.pkl",
    source_path + "data/testing_datasets/testing_data_iter2.pkl",
    source_path + "data/testing_datasets/testing_data_iter3.pkl",
]

# Per-molecule lists, concatenated across all testing files.
xyz_qm, xyz_mm, z, charges_mm = [], [], [], []
e_static_ref, e_ind_ref = [], []

s, q_core, q_val, alpha = [], [], [], []

for path in testing_data_paths:
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted,
    # project-generated files here.
    with open(path, "rb") as f:
        data = pkl.load(f)
    xyz_qm += data["xyz_qm"]
    xyz_mm += data["xyz_mm"]
    z += data["z"]
    charges_mm += data["charges_mm"]
    e_static_ref += data["e_static"]
    e_ind_ref += data["e_ind"]

    # For MBIS/atomic properties (optional keys in each file)
    s += data.get("s", [])
    q_core += data.get("q_core", [])
    q_val += data.get("q_val", [])
    alpha += data.get("alpha", [])

# If one file lacks an optional key, its molecules would silently shift the
# reference properties out of alignment with `z`. Stop here instead.
for prop_name, prop_values in (
    ("s", s),
    ("q_core", q_core),
    ("q_val", q_val),
    ("alpha", alpha),
):
    if len(prop_values) != len(z):
        raise ValueError(
            f"reference property '{prop_name}' has {len(prop_values)} entries "
            f"but {len(z)} molecules were loaded; every testing file must provide it"
        )

# Pad and convert to tensors
xyz_qm = pad_to_max(xyz_qm)
xyz_mm = pad_to_max(xyz_mm)
z = pad_to_max(z)
charges_mm = pad_to_max(charges_mm)
e_static_ref = torch.tensor(e_static_ref)
e_ind_ref = torch.tensor(e_ind_ref)

s_ref = pad_to_max(s)
q_core_ref = pad_to_max(q_core)
q_val_ref = pad_to_max(q_val)
alpha_ref = pad_to_max(alpha)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The padded inputs stay on the CPU because later cells call `.numpy()` on them
# (e.g. `z`). Inference uses device-resident copies, so the model and its
# inputs live on the same device.
z_dev = z.to(device)
charges_mm_dev = charges_mm.to(device)
xyz_qm_dev = xyz_qm.to(device)
xyz_mm_dev = xyz_mm.to(device)
# Third argument of the EMLE base forward; appears to be the per-molecule total
# QM charge (all zero here). TODO confirm against EMLEBase.forward.
q_total_dev = torch.zeros(len(xyz_qm), device=device)

emle_energy_dict = {}
emle_mbis_dict = {}

for model_name, model_path in models_dict.items():
    emle_model = EMLE(model=model_path, alpha_mode="species", device=device).double()
    e_static_emle, e_ind_emle = emle_model.forward(
        z_dev, charges_mm_dev, xyz_qm_dev, xyz_mm_dev
    )
    # `.cpu()` is required before `.numpy()` whenever `device` is CUDA.
    e_static_emle = e_static_emle.detach().cpu().numpy()
    e_ind_emle = e_ind_emle.detach().cpu().numpy()
    emle_energy_dict[model_name] = {
        "e_static_emle": e_static_emle * HARTEE_TO_KJ_MOL,
        "e_ind_emle": e_ind_emle * HARTEE_TO_KJ_MOL,
    }

    # MBIS/atomic properties
    emle_base = emle_model._emle_base
    s_emle, q_core_emle, q_val_emle, A_thole_emle = emle_base.forward(
        z_dev, xyz_qm_dev, q_total_dev
    )
    # Padded atom slots are marked by z <= 0.
    mask = z_dev > 0
    n_atoms = mask.shape[1]
    # Expand the (n_mol, n_atoms, n_atoms) real-atom-pair mask to the 3x3
    # Cartesian blocks of the (n_mol, 3*n_atoms, 3*n_atoms) A_thole matrix.
    mask_mat = (
        (mask[:, :, None] * mask[:, None, :])
        .repeat_interleave(3, dim=1)
        .repeat_interleave(3, dim=2)
    )
    A_thole_inv = torch.where(mask_mat, torch.linalg.inv(A_thole_emle), 0.0)
    # Sum all atom-pair 3x3 blocks of the inverse, giving one 3x3 tensor per molecule.
    alpha_emle = (
        torch.sum(A_thole_inv.reshape((-1, n_atoms, 3, n_atoms, 3)), dim=(1, 3))
        .detach()
        .cpu()
        .numpy()
    )

    emle_mbis_dict[model_name] = {
        "s": s_emle.detach().cpu().numpy(),
        "q_core": q_core_emle.detach().cpu().numpy(),
        "q_val": q_val_emle.detach().cpu().numpy(),
        "alpha": alpha_emle,
        "z": z.detach().cpu().numpy(),
    }

In [8]:
# Total (static + induced) interaction energies converted to kcal/mol. The
# reference is assumed to be stored in kJ/mol (it gets the same factor as the
# EMLE values, which were converted from Hartree to kJ/mol). TODO confirm.
reference = (e_static_ref.numpy() + e_ind_ref.numpy()) * KJ_PER_MOL_TO_KCAL_PER_MOL
predicted = {
    name: (d["e_static_emle"] + d["e_ind_emle"]) * KJ_PER_MOL_TO_KCAL_PER_MOL
    for name, d in emle_energy_dict.items()
}

# NOTE(review): the following cell reassigns every name set here (including
# `df` and `rmse_values`) with a per-element breakdown. This pooled ("All")
# view only survives if that cell is not run.
data = []
rmse_values = {}

s_ref_flat = s_ref.detach().numpy().flatten()
q_ref_flat = (q_val_ref.detach().numpy() + q_core_ref.detach().numpy()).flatten()
alpha_ref_flat = alpha_ref.detach().numpy().flatten()

for model_name, mbis in emle_mbis_dict.items():
    s_pred = mbis["s"].flatten()
    q_pred = mbis["q_val"].flatten() + mbis["q_core"].flatten()
    alpha_pred = mbis["alpha"].flatten()

    # Padded atom slots (z <= 0, the same criterion as the inference mask) are
    # not real atoms. Drop them so they neither dilute the per-atom RMSE nor add
    # spurious entries to the pooled distributions. Polarizability is one 3x3
    # tensor per molecule, so it has no padding to remove.
    is_atom = mbis["z"].flatten() > 0

    rmse_values[model_name] = {
        "S": calculate_rmse(s_pred[is_atom], s_ref_flat[is_atom]),
        "Charge": calculate_rmse(q_pred[is_atom], q_ref_flat[is_atom]),
        "Polarizability": calculate_rmse(alpha_pred, alpha_ref_flat),
    }

    for type_, diffs in [
        ("S vs Sref", (s_pred - s_ref_flat)[is_atom]),
        ("Charges", (q_pred - q_ref_flat)[is_atom]),
        ("Polarizability", alpha_pred - alpha_ref_flat),
    ]:
        data.extend(
            {"Model": model_name, "Atom": "All", "Difference": v, "Type": type_}
            for v in diffs
        )

df = pd.DataFrame(data)

In [9]:
# Total (static + induced) interaction energies in kcal/mol; the reference is
# assumed to be stored in kJ/mol. TODO confirm.
reference = (e_static_ref.numpy() + e_ind_ref.numpy()) * KJ_PER_MOL_TO_KCAL_PER_MOL
predicted = {
    name: (d["e_static_emle"] + d["e_ind_emle"]) * KJ_PER_MOL_TO_KCAL_PER_MOL
    for name, d in emle_energy_dict.items()
}

s_ref_flat = s_ref.detach().numpy().flatten()
q_ref_flat = (q_val_ref.detach().numpy() + q_core_ref.detach().numpy()).flatten()
alpha_ref_flat = alpha_ref.detach().numpy().flatten()
z_flat = z.detach().numpy().flatten()  # atomic numbers; padded slots are <= 0

# Elements shown as separate violins. Atoms of any other element are still
# counted in the per-atom RMSE but get no violin.
atomic_masks = {
    "H": (z_flat == 1),
    "C": (z_flat == 6),
    "N": (z_flat == 7),
    "O": (z_flat == 8),
    "S": (z_flat == 16),
}
# Real (non-padding) atoms: the same `z > 0` criterion used at inference.
is_atom = z_flat > 0

data = []
rmse_values = {}

for model_name, mbis in emle_mbis_dict.items():
    s_pred = mbis["s"].flatten()
    q_pred = mbis["q_val"].flatten() + mbis["q_core"].flatten()
    alpha_pred = mbis["alpha"].flatten()

    # The masks index the flattened (molecule, atom) grid. Fail loudly rather
    # than silently truncating if any per-atom array is not aligned with it.
    for type_, ref, pred in [
        ("S vs Sref", s_ref_flat, s_pred),
        ("Charges", q_ref_flat, q_pred),
    ]:
        if ref.shape != z_flat.shape or pred.shape != z_flat.shape:
            raise ValueError(
                f"{type_}: expected {z_flat.shape} per-atom values, got "
                f"reference {ref.shape} and prediction {pred.shape}"
            )

    # Per-atom RMSEs exclude padding slots, which would otherwise dilute the error.
    rmse_values[model_name] = {
        "S": calculate_rmse(s_pred[is_atom], s_ref_flat[is_atom]),
        "Charge": calculate_rmse(q_pred[is_atom], q_ref_flat[is_atom]),
        "Polarizability": calculate_rmse(alpha_pred, alpha_ref_flat),
    }

    for type_, ref, pred in [
        ("S vs Sref", s_ref_flat, s_pred),
        ("Charges", q_ref_flat, q_pred),
    ]:
        residual = pred - ref
        for atom, atom_mask in atomic_masks.items():
            data.extend(
                {"Model": model_name, "Atom": atom, "Difference": d, "Type": type_}
                for d in residual[atom_mask]
            )

    # The EMLE polarizability is a per-molecule 3x3 tensor, so it is not split
    # by element.
    data.extend(
        {"Model": model_name, "Atom": "", "Difference": d, "Type": "Polarizability"}
        for d in alpha_pred - alpha_ref_flat
    )

df = pd.DataFrame(data)

In [None]:
sns.set_theme(style="whitegrid", palette="colorblind", context="paper", font_scale=1.5)
fig = plt.figure(figsize=(16, 12))
gs = gridspec.GridSpec(2, 3, height_ratios=[1, 0.7])

# Shared axis range [kcal/mol] of the parity plots and their y = x guide line.
E_LIMITS = (-75, 10)
# Seed for the bootstrap resampling, so the confidence intervals written on the
# figure are identical on every run of this cell.
BOOTSTRAP_SEED = 0

# Statistics printed on each parity plot, with their LaTeX labels.
metrics = ["R2", "RMSE", "MUE", "MSE", "rho", "KTAU"]
statistic_type = "mle"
statistic_name = {
    "RMSE": r"RMSE",
    "MUE": r"MUE",
    "MSE": r"MSE",
    "R2": r"R$^2$",
    "rho": r"$\rho$",
    "KTAU": r"$\tau$",
}

x = np.linspace(*E_LIMITS, 100)
y = x
np.random.seed(BOOTSTRAP_SEED)
for i, (model_name, pred_vals) in enumerate(predicted.items()):
    ax = fig.add_subplot(gs[0, i])
    ax.scatter(reference, pred_vals, alpha=0.7, label=model_name)
    ax.plot(x, y, "k--", label="Reference")
    ax.set_xlabel("$E^{QM/MM}_{int}$ [kcal.mol$^{-1}$]")
    ax.set_ylabel("$E^{EMLE}_{int}$ [kcal.mol$^{-1}$]")
    ax.set_xlim(*E_LIMITS)
    ax.set_ylim(*E_LIMITS)
    ax.set_aspect("equal")

    statistics = {}
    statistics_string = ""
    for statistic in metrics:
        # Named `stat`, not `s`: `s` already holds the raw MBIS widths loaded
        # from the testing files.
        stat = bootstrap_statistic(reference, pred_vals, statistic=statistic)
        statistics[statistic] = stat
        statistics_string += (
            f"{statistic_name[statistic]}: {stat[statistic_type]:.2f} "
            f"[{stat['low']:.2f}, {stat['high']:.2f}]\n"
        )

    ax.text(
        0.02,
        0.98,
        f"{model_name}\n{statistics_string}",
        transform=ax.transAxes,
        ha="left",
        va="top",
    )

# Bottom row: one violin panel per property, using the rows of `df` whose
# "Type" matches, grouped by the "Atom" column and coloured by model.
# Each tuple: (df Type value, panel title, y-axis label, rmse_values key).
plot_configs = [
    ("S vs Sref", "Valence widths ($s$)", "$s^{EMLE}-s^{MBIS}$ [$a_0$]", "S"),
    ("Charges", "Charges (QEq)", r"$q^{EMLE}-q^{MBIS}$ [$e$]", "Charge"),
    (
        "Polarizability",
        r"Molecular polarizability ($\alpha$)",
        r"$\alpha^{EMLE}_{mol}-\alpha^{B3LYP}_{mol}$ [$a_0^3$]",
        "Polarizability",
    ),
]

for col, (type_, title, ylabel, rmse_key) in enumerate(plot_configs):
    ax = fig.add_subplot(gs[1, col])
    subset = df.loc[df["Type"] == type_]
    sns.violinplot(
        data=subset,
        x="Atom",
        y="Difference",
        hue="Model",
        palette="colorblind",
        inner="quart",
        split=False,
        ax=ax,
    )
    ax.set(title=title, ylabel=ylabel, xlabel="")

    # Append each model's RMSE for this property to its legend entry.
    handles, labels = ax.get_legend_handles_labels()
    legend_labels = [
        f"{model} ({rmse_values[model][rmse_key]:.2e})" for model in labels
    ]
    ax.legend(handles=handles, labels=legend_labels, loc="upper center")


# Tag every panel with a bold letter in the order the axes were added to the
# figure (the parity plots first, then the violin plots), then save the figure.
axes = fig.get_axes()
labels = list("ABCDEF")
panel_tag_style = {"fontsize": 32, "fontweight": "bold", "va": "bottom", "ha": "right"}
for ax, label in zip(axes, labels):
    ax.text(-0.1, 1.05, label, transform=ax.transAxes, **panel_tag_style)

plt.tight_layout()
plt.subplots_adjust(hspace=0.00)
plt.savefig("fig3_emle_vs_qmmm.pdf", dpi=300, bbox_inches="tight")
plt.show()