In [1]:
import pickle
import pandas as pd
import numpy as np
import tensorflow as tf
from pathlib import Path
from dataclasses import dataclass, field
from mbeml.constants import LigandFeatures, TargetProperty, unique_cores
from mbeml.featurization import data_prep
from mbeml.metrics import mean_absolute_error, r2_score, mean_negative_log_likelihood

In [2]:
# Folders holding the CSV splits and the saved models, relative to this notebook.
data_dir = Path("../../data/")
model_dir = Path("../../models/")

# One DataFrame per split, keyed by the split name used throughout the notebook.
data_sets = {
    split_name: pd.read_csv(data_dir / csv_name)
    for split_name, csv_name in (
        ("train", "training_data.csv"),
        ("validation", "validation_data.csv"),
        ("composition_test", "composition_test_data.csv"),
        ("ligand_test", "ligand_test_data.csv"),
    )
}

In [3]:
def _zero_arrays_per_split():
    """Build a fresh ``{split name: zeros((n_rows, 4))}`` dict from ``data_sets``."""
    # NOTE(review): the meaning of the 4 columns is not evident from this cell — confirm.
    return {key: np.zeros([len(df), 4]) for key, df in data_sets.items()}


@dataclass
class Experiment:
    """One trained model together with its per-split predictions.

    The ``predictions`` and ``uncertainties`` dictionaries are created per
    instance and start out as zero-filled arrays, one entry for every key of
    the module-level ``data_sets`` mapping.
    """

    # Identifier of the experiment; also used to locate the saved model on disk.
    name: str
    # Feature set handed to ``data_prep`` for this model.
    features: LigandFeatures
    # Property the model predicts.
    target: TargetProperty = TargetProperty.GAP
    # True for models loaded through Keras, False for pickled models.
    is_nn: bool = False
    # Split name -> predicted values.
    predictions: dict = field(default_factory=_zero_arrays_per_split)
    # Split name -> predicted uncertainties.
    uncertainties: dict = field(default_factory=_zero_arrays_per_split)

In [4]:
# Every model evaluated in this notebook, as (name, feature set, is neural network).
# The two_body and three_body variants use the same feature set here; they are
# told apart by their name, which selects the saved model that gets loaded.
experiments = [
    Experiment(name=model_name, features=feature_set, is_nn=uses_nn)
    for model_name, feature_set, uses_nn in (
        ("krr_standard_racs", LigandFeatures.STANDARD_RACS, False),
        ("krr_two_body", LigandFeatures.LIGAND_RACS, False),
        ("krr_three_body", LigandFeatures.LIGAND_RACS, False),
        ("nn_standard_racs", LigandFeatures.STANDARD_RACS, True),
        ("nn_two_body", LigandFeatures.LIGAND_RACS, True),
        ("nn_three_body", LigandFeatures.LIGAND_RACS, True),
    )
]

In [5]:
# Fill experiment.predictions / experiment.uncertainties for every data split.
for experiment in experiments:
    # The saved model depends only on the experiment, not on the data split,
    # so load it once here instead of re-reading it from disk for every split.
    target_model_dir = model_dir / experiment.target.name.lower()
    if experiment.is_nn:
        model = tf.keras.models.load_model(target_model_dir / experiment.name)
    else:
        # NOTE: pickle.load can execute arbitrary code; only load model files
        # that come from a trusted source.
        with open(target_model_dir / f"{experiment.name}.pkl", "rb") as fin:
            model = pickle.load(fin)

    for df_name, data_set in data_sets.items():
        # Only the features X are needed for prediction; y is not used in this cell.
        X, y = data_prep(
            data_set, experiment.features, experiment.target, experiment.is_nn
        )
        # Both model kinds yield a (mean, std) pair; only the predict call
        # signature differs between them.
        if experiment.is_nn:
            y_mean, y_std = model.predict(X, verbose=0)
        else:
            y_mean, y_std = model.predict(X, return_std=True)
        experiment.predictions[df_name] = y_mean
        experiment.uncertainties[df_name] = y_std

2024-03-27 14:52:59.012338: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


In [6]:
def evaluate_metric(metric, requires_uncertainty=False, transformation=None):
    """Score every experiment on every data split with ``metric``.

    Reads the module-level ``experiments`` list and ``data_sets`` mapping.

    Args:
        metric: Callable invoked as ``metric(y_true, y_pred)`` or, when
            ``requires_uncertainty`` is true, ``metric(y_true, y_pred, y_std)``.
        requires_uncertainty: Whether the stored uncertainties are passed to
            ``metric`` as a third argument.
        transformation: Optional callable applied to each argument before it
            is passed to ``metric``; ``None`` leaves the values unchanged.

    Returns:
        A DataFrame with one row per experiment name and one column per data
        split, holding the value returned by ``metric``.
    """

    def _identity(values):
        return values

    transform = _identity if transformation is None else transformation

    table = {}
    for exp in experiments:
        scores = {}
        for split_name, frame in data_sets.items():
            # Reference values first, then predictions, mirroring the metric signature.
            metric_args = [
                transform(frame[exp.target.full_name()].values),
                transform(exp.predictions[split_name]),
            ]
            if requires_uncertainty:
                metric_args.append(transform(exp.uncertainties[split_name]))
            scores[split_name] = metric(*metric_args)
        table[exp.name] = scores
    return pd.DataFrame.from_dict(table, orient="index")

In [7]:
# Mean absolute error of every experiment on every data split, rounded to 2 decimals.
evaluate_metric(metric=mean_absolute_error).round(decimals=2)

Unnamed: 0,train,validation,composition_test,ligand_test
krr_standard_racs,0.13,0.41,0.49,0.57
krr_two_body,0.31,0.44,0.25,0.63
krr_three_body,0.13,0.38,0.62,0.62
nn_standard_racs,0.34,0.43,0.43,0.69
nn_two_body,0.4,0.43,0.47,0.63
nn_three_body,0.32,0.38,0.96,0.62


In [8]:
# R^2 score of every experiment on every data split, rounded to 2 decimals.
evaluate_metric(metric=r2_score).round(decimals=2)

Unnamed: 0,train,validation,composition_test,ligand_test
krr_standard_racs,0.97,0.72,0.59,0.55
krr_two_body,0.83,0.69,0.87,0.49
krr_three_body,0.97,0.75,0.47,0.52
nn_standard_racs,0.85,0.72,0.74,0.26
nn_two_body,0.78,0.71,0.68,0.49
nn_three_body,0.86,0.77,-0.12,0.4
