In [1]:
import pickle
import pandas as pd
import numpy as np
import tensorflow as tf
from pathlib import Path
from dataclasses import dataclass, field
from mbeml.constants import LigandFeatures, TargetProperty
from mbeml.featurization import data_prep
from mbeml.metrics import mean_absolute_error, r2_score, mean_negative_log_likelihood

In [2]:
# Root directories, relative to the notebook's working directory, for the CSV
# data sets and for the stored models.
data_dir = Path("../../data/")
model_dir = Path("../../models/")

# One DataFrame per split, keyed by the split name used throughout the notebook.
data_sets = {
    split: pd.read_csv(data_dir / csv_name)
    for split, csv_name in (
        ("train", "training_data.csv"),
        ("validation", "validation_data.csv"),
        ("composition_test", "composition_test_data.csv"),
        ("ligand_test", "ligand_test_data.csv"),
    )
}

In [3]:
def _zeros_per_data_set():
    """Return a fresh dict with one (n_rows, 1) zero array per entry of ``data_sets``."""
    return {key: np.zeros([len(df), 1]) for key, df in data_sets.items()}


@dataclass
class Experiment:
    """A stored model together with its predictions on every data set."""

    # Identifier of the experiment; also used to locate the model on disk.
    name: str
    # Feature set passed to ``data_prep`` for this experiment.
    features: LigandFeatures
    # Property the model predicts; defaults to ``TargetProperty.SSE``.
    target: TargetProperty = TargetProperty.SSE
    # True for neural-network models, False otherwise.
    is_nn: bool = False
    # Per-data-set predictions, initialised to zeros of shape (n_rows, 1).
    predictions: dict = field(default_factory=_zeros_per_data_set)
    # Per-data-set uncertainties, initialised to zeros of shape (n_rows, 1).
    uncertainties: dict = field(default_factory=_zeros_per_data_set)

In [4]:
# Each row is (name, features, is_nn). The two-body and three-body experiments
# share the same feature set and differ only in the model stored under their name.
experiments = [
    Experiment(name=name, features=features, is_nn=is_nn)
    for name, features, is_nn in (
        ("krr_standard_racs", LigandFeatures.STANDARD_RACS, False),
        ("krr_two_body", LigandFeatures.LIGAND_RACS, False),
        ("krr_three_body", LigandFeatures.LIGAND_RACS, False),
        ("nn_standard_racs", LigandFeatures.STANDARD_RACS, True),
        ("nn_two_body", LigandFeatures.LIGAND_RACS, True),
        ("nn_three_body", LigandFeatures.LIGAND_RACS, True),
    )
]

In [5]:
# Evaluate all experiments on the four data sets.
# Each model is deserialized once per experiment and then reused for every
# data set; loading it inside the inner loop would repeat identical work.
for experiment in experiments:
    target_dir = model_dir / experiment.target.name.lower()
    if experiment.is_nn:
        # Neural networks are stored under a path named after the experiment.
        model = tf.keras.models.load_model(target_dir / experiment.name)
    else:
        # NOTE: pickle.load can execute arbitrary code while deserializing;
        # only load model files that come from a trusted source.
        with open(target_dir / f"{experiment.name}.pkl", "rb") as fin:
            model = pickle.load(fin)

    for df_name, data_set in data_sets.items():
        X, y = data_prep(
            data_set, experiment.features, experiment.target, experiment.is_nn
        )
        if experiment.is_nn:
            y_mean, y_std = model.predict(X, verbose=0)
        else:
            y_mean, y_std = model.predict(X, return_std=True)
            # Store column vectors (n, 1), the same layout as the zero
            # placeholders created by Experiment.
            y_mean = y_mean.reshape(-1, 1)
            y_std = y_std.reshape(-1, 1)
        experiment.predictions[df_name] = y_mean
        experiment.uncertainties[df_name] = y_std

2024-02-09 10:00:59.973842: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


In [6]:
def evaluate_metric(metric, requires_uncertainty=False):
    """Tabulate ``metric`` for every experiment on every data set.

    Parameters
    ----------
    metric : callable
        Called as ``metric(y_true, y_pred)`` or, when ``requires_uncertainty``
        is true, as ``metric(y_true, y_pred, y_std)``.
    requires_uncertainty : bool, default False
        Whether the stored uncertainties are passed as a third argument.

    Returns
    -------
    pandas.DataFrame
        One row per experiment name and one column per key of ``data_sets``.
    """

    def score(experiment, key, data_set):
        # Reference values come from the column named after the experiment's target.
        arguments = [
            data_set[experiment.target.full_name()],
            experiment.predictions[key],
        ]
        if requires_uncertainty:
            arguments.append(experiment.uncertainties[key])
        return metric(*arguments)

    table = {
        experiment.name: {
            key: score(experiment, key, data_set)
            for key, data_set in data_sets.items()
        }
        for experiment in experiments
    }
    return pd.DataFrame.from_dict(table, orient="index")

In [7]:
# Mean absolute error of every experiment on every data set, rounded to 2 decimals.
evaluate_metric(mean_absolute_error).round(2)

Unnamed: 0,train,validation,composition_test,ligand_test
krr_standard_racs,0.87,3.67,4.1,4.84
krr_two_body,2.63,3.96,3.12,5.05
krr_three_body,1.02,3.3,2.75,4.93
nn_standard_racs,2.33,3.51,4.16,4.73
nn_two_body,3.16,3.73,3.61,4.14
nn_three_body,2.78,3.48,3.41,4.0


In [8]:
# r2_score of every experiment on every data set, rounded to 3 decimals.
evaluate_metric(r2_score).round(3)

Unnamed: 0,train,validation,composition_test,ligand_test
krr_standard_racs,0.998,0.959,0.947,0.921
krr_two_body,0.979,0.952,0.967,0.903
krr_three_body,0.996,0.962,0.976,0.907
nn_standard_racs,0.985,0.964,0.954,0.924
nn_two_body,0.972,0.958,0.964,0.934
nn_three_body,0.979,0.963,0.961,0.943


In [9]:
# Mean negative log likelihood of every experiment on every data set; this
# metric is also given the stored uncertainties. Rounded to 2 decimals.
evaluate_metric(mean_negative_log_likelihood, requires_uncertainty=True).round(2)

Unnamed: 0,train,validation,composition_test,ligand_test
krr_standard_racs,2.84,3.84,3.8,4.55
krr_two_body,3.77,4.34,4.4,5.82
krr_three_body,3.14,4.41,4.6,5.97
nn_standard_racs,17.83,29.6,27.64,16.1
nn_two_body,24.85,31.37,22.82,15.36
nn_three_body,8.16,9.63,7.97,5.18


## Metal dependence

In [10]:
# Metal cores written as "<metal symbol><oxidation state>", e.g. "fe3".
cores = ["cr3", "cr2", "mn3", "mn2", "fe3", "fe2", "co3", "co2"]


def metal_dependence(key: str, metric=mean_absolute_error):
    """Evaluate ``metric`` per metal core for every experiment on one data set.

    Parameters
    ----------
    key : str
        Key of the data set in ``data_sets`` (and of the stored predictions).
    metric : callable, default ``mean_absolute_error``
        Called as ``metric(y_true, y_pred)``.

    Returns
    -------
    pandas.DataFrame
        One row per experiment name; one column per entry of ``cores`` plus an
        ``"all"`` column holding the metric over the complete data set.
    """
    data_set = data_sets[key]
    # The row selection for a core depends only on the data set, so build each
    # mask once instead of once per experiment. The last character is the
    # oxidation state and everything before it is the metal symbol. Converting
    # to a NumPy array makes both the column and the prediction array use
    # plain positional boolean indexing.
    core_masks = {
        core: (
            (data_set["metal"] == core[:-1]) & (data_set["ox"] == int(core[-1]))
        ).to_numpy()
        for core in cores
    }

    results = {}
    for experiment in experiments:
        targets = data_set[experiment.target.full_name()]
        predictions = experiment.predictions[key]
        result_row = {
            core: metric(targets[mask], predictions[mask])
            for core, mask in core_masks.items()
        }
        result_row["all"] = metric(targets, predictions)
        results[experiment.name] = result_row
    return pd.DataFrame.from_dict(results, orient="index")

In [11]:
# Per-core values of the default metric (mean absolute error) on the "train" data set.
metal_dependence("train").round(2)

Unnamed: 0,cr3,cr2,mn3,mn2,fe3,fe2,co3,co2,all
krr_standard_racs,0.71,0.92,1.01,0.62,0.86,0.91,1.16,0.96,0.87
krr_two_body,1.58,2.91,3.16,2.45,2.8,3.09,3.36,2.24,2.63
krr_three_body,0.84,1.08,1.21,0.78,1.09,0.97,1.15,1.17,1.02
nn_standard_racs,1.38,2.58,3.07,1.96,2.42,2.39,2.69,2.72,2.33
nn_two_body,1.86,3.5,3.78,2.88,3.35,3.66,4.18,2.79,3.16
nn_three_body,1.85,3.25,3.62,2.72,2.81,2.88,3.17,2.58,2.78


In [12]:
# Per-core values of the default metric (mean absolute error) on the "validation" data set.
metal_dependence("validation").round(2)

Unnamed: 0,cr3,cr2,mn3,mn2,fe3,fe2,co3,co2,all
krr_standard_racs,3.29,3.59,4.55,3.07,3.22,3.83,5.04,3.74,3.67
krr_two_body,2.13,4.96,3.9,3.7,3.98,5.03,4.31,3.85,3.96
krr_three_body,2.07,3.65,2.8,3.0,3.47,3.9,3.83,3.54,3.3
nn_standard_racs,2.35,4.05,4.04,3.35,3.51,3.43,4.6,3.55,3.51
nn_two_body,2.17,4.31,4.08,3.49,3.76,4.28,4.72,3.6,3.73
nn_three_body,2.12,4.27,4.32,3.07,3.46,3.7,4.56,3.33,3.48


In [13]:
# Per-core r2_score on the "validation" data set, rounded to 3 decimals.
metal_dependence("validation", r2_score).round(3)

Unnamed: 0,cr3,cr2,mn3,mn2,fe3,fe2,co3,co2,all
krr_standard_racs,-0.483,0.857,0.549,0.865,0.927,0.916,0.89,0.746,0.959
krr_two_body,0.199,0.754,0.661,0.798,0.866,0.869,0.929,0.747,0.952
krr_three_body,0.276,0.835,0.777,0.848,0.894,0.901,0.934,0.766,0.962
nn_standard_racs,0.241,0.828,0.68,0.846,0.905,0.941,0.911,0.765,0.964
nn_two_body,0.274,0.77,0.667,0.809,0.888,0.911,0.913,0.775,0.958
nn_three_body,0.359,0.796,0.623,0.856,0.902,0.93,0.92,0.786,0.963


In [14]:
# Per-core values of the default metric (mean absolute error) on the "composition_test" data set.
metal_dependence("composition_test").round(2)

Unnamed: 0,cr3,cr2,mn3,mn2,fe3,fe2,co3,co2,all
krr_standard_racs,3.47,2.21,4.7,2.11,2.87,5.47,7.67,4.3,4.1
krr_two_body,1.36,3.08,5.78,1.48,1.48,3.56,5.95,2.24,3.12
krr_three_body,1.41,2.66,5.66,1.76,2.98,1.57,3.63,2.29,2.75
nn_standard_racs,2.57,3.76,7.3,2.85,2.87,4.5,4.72,4.7,4.16
nn_two_body,2.36,2.82,6.42,3.48,2.28,3.07,4.45,3.99,3.61
nn_three_body,2.17,4.39,8.45,1.99,1.98,1.72,3.25,3.35,3.41


In [15]:
# Per-core r2_score on the "composition_test" data set, rounded to 2 decimals.
metal_dependence("composition_test", r2_score).round(2)

Unnamed: 0,cr3,cr2,mn3,mn2,fe3,fe2,co3,co2,all
krr_standard_racs,-8.51,0.85,0.16,0.96,0.93,0.83,0.77,0.71,0.95
krr_two_body,-0.66,0.77,-0.16,0.98,0.98,0.93,0.86,0.92,0.97
krr_three_body,-0.66,0.79,-0.15,0.97,0.95,0.99,0.94,0.91,0.98
nn_standard_racs,-3.48,0.66,-0.92,0.94,0.94,0.92,0.91,0.68,0.95
nn_two_body,-2.94,0.79,-0.41,0.92,0.96,0.95,0.92,0.78,0.96
nn_three_body,-2.45,0.57,-1.56,0.97,0.97,0.98,0.96,0.83,0.96


In [16]:
# Per-core values of the default metric (mean absolute error) on the "ligand_test" data set.
metal_dependence("ligand_test").round(2)

Unnamed: 0,cr3,cr2,mn3,mn2,fe3,fe2,co3,co2,all
krr_standard_racs,5.21,4.29,5.85,4.83,4.1,4.74,7.72,4.07,4.84
krr_two_body,3.74,3.35,4.13,6.68,5.15,6.95,8.62,3.31,5.05
krr_three_body,4.41,3.27,3.54,6.55,4.67,6.61,7.72,3.42,4.93
nn_standard_racs,2.54,3.46,5.37,5.0,6.54,4.98,9.32,3.97,4.73
nn_two_body,1.18,2.76,5.15,5.32,5.21,4.91,8.32,3.33,4.14
nn_three_body,1.26,3.06,3.79,5.17,5.17,4.88,7.83,3.14,4.0


In [17]:
# Per-core r2_score on the "ligand_test" data set, rounded to 2 decimals.
metal_dependence("ligand_test", r2_score).round(2)

Unnamed: 0,cr3,cr2,mn3,mn2,fe3,fe2,co3,co2,all
krr_standard_racs,-11.19,0.36,-1.52,0.72,0.74,0.82,0.19,0.67,0.92
krr_two_body,-7.2,0.65,-0.14,0.39,0.64,0.65,0.03,0.74,0.9
krr_three_body,-9.5,0.64,0.06,0.39,0.67,0.67,0.15,0.78,0.91
nn_standard_racs,-2.0,0.55,-0.97,0.69,0.49,0.82,0.07,0.66,0.92
nn_two_body,0.1,0.74,-1.13,0.61,0.63,0.83,0.27,0.72,0.93
nn_three_body,-0.15,0.73,-0.24,0.67,0.66,0.85,0.4,0.78,0.94


## Ligand dependence on the ligand test set

In [18]:
# Ligand identifiers used to group rows of the "ligand_test" data set; each
# entry is matched as a literal substring of the "name" column.
# NOTE(review): "thioazole" may be a misspelling of "thiazole" — confirm against
# the names used in the data before changing it, since the string is matched
# verbatim.
test_ligands = [
    "4H-pyran",
    "[OH]-[CH]=[CH]-[OH]",
    "bifuran",
    "pyridine-N-oxide",
    "acrylamide",
    "dmf",
    "thiophene",
    "thiane",
    "4H-thiopyran",
    "oxazoline",
    "thioazole",
    "[NH]=[CH]-[OH]",
    "[PH]=[CH]-[OH]",
    "[NH2]-[NH]-[NH]-[NH2]",
    "1H-tetrazole",
    "1H-triazole",
    "thioformaldehyde",
    "[NH2]-[O]-[O]-[NH2]",
    "bidiazine",
    "[PH2]-[CH2]-[OH]",
    "[PH2]-[NH]-[NH]-[PH2]",
]

In [19]:
def ligand_dependence(metric):
    """Evaluate ``metric`` per test ligand for every experiment.

    Rows of the ``"ligand_test"`` data set are grouped by whether their
    ``"name"`` contains the ligand string (literal match, no regex).

    Parameters
    ----------
    metric : callable
        Called as ``metric(y_true, y_pred)`` on the rows of one ligand.

    Returns
    -------
    pandas.DataFrame
        One row per entry of ``test_ligands`` with a ``"count"`` column (number
        of matching rows) followed by one column per experiment name.
    """
    key = "ligand_test"
    data_set = data_sets[key]

    def row_for(ligand):
        # Select every row whose name mentions this ligand.
        mask = data_set["name"].str.contains(ligand, regex=False)
        subset = data_set[mask]
        row = {"count": np.count_nonzero(mask)}
        row.update(
            (
                experiment.name,
                metric(
                    subset[experiment.target.full_name()],
                    experiment.predictions[key][mask],
                ),
            )
            for experiment in experiments
        )
        return row

    per_ligand = {ligand: row_for(ligand) for ligand in test_ligands}
    return pd.DataFrame.from_dict(per_ligand, orient="index")

In [20]:
# Mean absolute error per test ligand for every experiment, rounded to 2 decimals.
ligand_dependence(mean_absolute_error).round(2)

Unnamed: 0,count,krr_standard_racs,krr_two_body,krr_three_body,nn_standard_racs,nn_two_body,nn_three_body
4H-pyran,3,4.4,2.5,3.33,4.14,2.46,4.08
[OH]-[CH]=[CH]-[OH],4,7.78,1.77,2.97,3.83,3.02,3.85
bifuran,4,4.18,5.82,5.75,3.79,4.71,4.57
pyridine-N-oxide,7,3.27,4.01,3.84,3.98,3.19,3.59
acrylamide,7,7.02,6.83,6.83,2.56,4.23,3.28
dmf,8,2.79,4.37,3.89,2.72,1.76,1.85
thiophene,5,4.58,5.44,5.61,2.92,3.07,2.68
thiane,5,6.16,9.66,10.2,10.26,11.21,10.29
4H-thiopyran,5,4.11,4.26,4.37,6.77,3.81,4.23
oxazoline,7,3.84,4.85,4.92,5.44,5.17,4.38


In [21]:
# r2_score per test ligand for every experiment, rounded to 2 decimals.
ligand_dependence(r2_score).round(2)

Unnamed: 0,count,krr_standard_racs,krr_two_body,krr_three_body,nn_standard_racs,nn_two_body,nn_three_body
4H-pyran,3,0.79,0.95,0.91,0.81,0.95,0.86
[OH]-[CH]=[CH]-[OH],4,0.56,0.96,0.9,0.88,0.93,0.89
bifuran,4,0.87,0.8,0.79,0.93,0.88,0.89
pyridine-N-oxide,7,0.91,0.87,0.87,0.84,0.88,0.86
acrylamide,7,0.83,0.8,0.8,0.97,0.87,0.93
dmf,8,0.96,0.93,0.95,0.97,0.99,0.99
thiophene,5,0.89,0.78,0.77,0.94,0.95,0.96
thiane,5,0.72,0.37,0.26,0.32,0.18,0.33
4H-thiopyran,5,0.89,0.89,0.88,0.78,0.91,0.91
oxazoline,7,0.96,0.93,0.93,0.91,0.93,0.95
