In [None]:
import datetime
import logging
import os
from pathlib import Path

import numpy as np
import pandas as pd
from astartes import train_test_split
from lightning import pytorch as pl
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
from sklearn.preprocessing import StandardScaler

from causal_chemprop_logS import SobolevMulticomponentMPNN, CustomMSEMetric
from chemprop import data as chemprop_data_utils
from chemprop import featurizers, nn
from chemprop.nn.transforms import ScaleTransform, UnscaleTransform

In [None]:
# --- Reproducibility and input location ---
RANDOM_SEED = 8675309
TRAINING_FPATH = Path("../data/aqueous.csv")        # TODO: Update path

# --- Tunable settings ---
# Despite its name this is a fraction: it is handed to DataFrame.sample(frac=...),
# and any falsy value (0, None) skips the down-sampling step entirely.
downsample_percent = 1.0
mpnn_hidden_size = 800  # hidden width (d_h) of each message-passing block
fnn_hidden_size = 200   # hidden width (hidden_dim) of the regression feed-forward head

# setup logging and output directories
# Each run gets its own directory, named after the current UTC epoch second.
_output_dir = Path(f"output/chemprop_{int(datetime.datetime.now(datetime.UTC).timestamp())}")
_output_dir.mkdir(parents=True, exist_ok=True)
logging.getLogger("pytorch_lightning").setLevel(logging.INFO)
seed_everything(RANDOM_SEED)

Data preparation

In [None]:
def _f(r):  # approximates gradients of logS wrt temperature using finite differences
    if len(r["scaled_logS"]) == 1:
        return [np.nan]
    sorted_idxs = np.argsort(r["scaled_temperature"])
    unsort_idxs = np.argsort(sorted_idxs)
    # mask out enormous (non-physical) values, negative values, and nan/inf
    grads = [
        i if (np.isfinite(i) and np.abs(i) < 1.0 and i > 0.0) else np.nan
        for i in np.gradient(
            [r["scaled_logS"][i] for i in sorted_idxs],
            [r["scaled_temperature"][i] for i in sorted_idxs],
        )
    ]
    return [grads[i] for i in unsort_idxs]

In [None]:
# load the training data
df = pd.read_csv(TRAINING_FPATH)
if downsample_percent:
    print(f"Down-sampling training data to {downsample_percent:.2%} size!")
    # Sample whole experiments (solute / solvent / source groups) rather than single rows,
    # so every experiment that is kept retains all of its measurements.
    # Only the row positions are aggregated - collecting every column into lists is not needed.
    _rows_per_experiment = (
        df.assign(original_index=np.arange(len(df)))
        .groupby(["solute_smiles", "solvent_smiles", "source"])["original_index"]
        .aggregate(list)
    )
    _kept_experiments = _rows_per_experiment.sample(frac=downsample_percent, replace=False, random_state=RANDOM_SEED)
    chosen_indexes = _kept_experiments.explode().to_numpy().astype(int)
    print(f"Actual downsample percentage is {len(chosen_indexes)/len(df):.2%}, count: {len(chosen_indexes)}!")
    # NOTE: rows now follow the sampled group order, not the CSV order, and are re-indexed from 0.
    # Index values computed below therefore refer to THIS frame, not to the file on disk.
    df = df.iloc[chosen_indexes].reset_index(drop=True)

# split the data s.t. model only sees a subset of the studies used to aggregate the training data
studies_train, studies_val = train_test_split(pd.unique(df["source"]), random_state=RANDOM_SEED, train_size=0.9, test_size=0.1)
# df carries a 0..n-1 index at this point, so these labels double as row positions
train_indexes = df.index[df["source"].isin(studies_train)].tolist()
val_indexes = df.index[df["source"].isin(studies_val)].tolist()
_total = len(df)
print(f"train: {len(train_indexes)} ({len(train_indexes)/_total:.0%}) validation: {len(val_indexes)} ({len(val_indexes)/_total:.0%})")

In [None]:
# manual re-scaling
# Both scalers are fitted on the training rows only and then applied to every row.
_train_rows = df.iloc[train_indexes]
target_scaler = StandardScaler().fit(_train_rows[["logS"]])
scaled_logs = target_scaler.transform(df[["logS"]]).ravel()
temperature_scaler = StandardScaler().fit(_train_rows[["temperature"]])
scaled_temperature = temperature_scaler.transform(df[["temperature"]]).ravel()

# calculate known temperature gradients
# ``source_index`` records each row's position so the per-experiment results
# can be put back into the original row order afterwards.
_scaled_columns = pd.DataFrame(
    {
        "source_index": np.arange(len(df)),
        "scaled_temperature": scaled_temperature,
        "scaled_logS": scaled_logs,
    }
)
_rows_with_scaled = pd.concat((df, _scaled_columns), axis=1)

# group the data by experiment: one row per experiment, each cell holding that experiment's values
_experiment_keys = ["source", "solvent_smiles", "solute_smiles"]
_gradient_inputs = ["scaled_logS", "scaled_temperature", "source_index"]
_per_experiment = _rows_with_scaled.groupby(_experiment_keys)[_gradient_inputs].aggregate(list)
_per_experiment["logSgradT"] = _per_experiment.apply(_f, axis=1)

# back to one row per measurement, restored to the order of ``df``
_per_measurement = _per_experiment.explode(["logSgradT", "source_index"]).sort_values(by="source_index")
tgrads = _per_measurement["logSgradT"].to_numpy(dtype=np.float32)

_mask = np.isnan(tgrads)
print(f"Masking {np.count_nonzero(_mask)} of {len(_mask)} gradients!")
print(f"{np.count_nonzero(tgrads > 0)} of {len(tgrads)} were positive!")

In [None]:
# One list of datapoints per mixture component, in the order [solute, solvent].
# The solute datapoints carry both targets (scaled logS and its temperature gradient)
# plus the raw temperature as the single extra descriptor x_d; solvents carry only the molecule.
_solute_datapoints = [
    chemprop_data_utils.MoleculeDatapoint.from_smi(smi, [log_s, log_s_grad_T], x_d=np.array([temperature]))
    for smi, log_s, log_s_grad_T, temperature in zip(df["solute_smiles"], scaled_logs, tgrads, df["temperature"])
]
_solvent_datapoints = [chemprop_data_utils.MoleculeDatapoint.from_smi(smi) for smi in df["solvent_smiles"]]
all_data = [_solute_datapoints, _solvent_datapoints]

# create datasets
train_data, val_data, _ = chemprop_data_utils.split_data_by_indices(all_data, train_indices=train_indexes, val_indices=val_indexes)
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
_n_components = len(all_data)


def _as_component_datasets(split):
    """Wrap each component's datapoints in ``split`` in a MoleculeDataset sharing ``featurizer``."""
    return [chemprop_data_utils.MoleculeDataset(split[k], featurizer) for k in range(_n_components)]


train_datasets = _as_component_datasets(train_data)
val_datasets = _as_component_datasets(val_data)
# the causal model is fitted on the training split, wrapped in its own dataset objects
causal_datasets = _as_component_datasets(train_data)

# Create mcdatasets
train_mcdset = chemprop_data_utils.MulticomponentDataset(train_datasets)
# Only the training set has X_d normalised here, one list entry per component.
# NOTE(review): presumably the validation/causal sets rely on the model's X_d_transform instead - confirm.
train_mcdset.normalize_inputs("X_d", [temperature_scaler, None])
train_mcdset.cache = True
val_mcdset = chemprop_data_utils.MulticomponentDataset(val_datasets)
val_mcdset.cache = True
causal_mcdset = chemprop_data_utils.MulticomponentDataset(causal_datasets)
causal_mcdset.cache = True

# Create loaders
_worker_options = dict(num_workers=1, persistent_workers=True)
train_loader = chemprop_data_utils.build_dataloader(train_mcdset, batch_size=256, shuffle=True, **_worker_options)
val_loader = chemprop_data_utils.build_dataloader(val_mcdset, batch_size=32, shuffle=False, **_worker_options)
causal_loader = chemprop_data_utils.build_dataloader(causal_mcdset, batch_size=32, shuffle=False, **_worker_options)

Training

In [None]:
# build Chemprop
# The modules are created in a fixed order (message passing -> aggregation -> FFN -> full model)
# because module initialisation draws from the globally seeded RNG.
_n_components = len(all_data)
_dropout = 0.50

# one message-passing block per mixture component
_mp_blocks = [nn.BondMessagePassing(depth=3, d_h=mpnn_hidden_size, dropout=_dropout) for _ in range(_n_components)]
mcmp = nn.MulticomponentMessagePassing(blocks=_mp_blocks, n_components=_n_components)
agg = nn.MeanAggregation()

# predictions are mapped back from the scaled target space with the fitted target scaler
output_transform = UnscaleTransform.from_standard_scaler(target_scaler)
ffn = nn.RegressionFFN(
    input_dim=mcmp.output_dim + 1,  # +1: temperature is the single extra descriptor (x_d)
    hidden_dim=fnn_hidden_size,
    dropout=_dropout,
    n_layers=4,
    criterion=CustomMSEMetric(),
    output_transform=output_transform,
)

X_d_transform = ScaleTransform.from_standard_scaler(temperature_scaler)
metric_list = [CustomMSEMetric()]

# learning-rate settings forwarded to the model as keyword arguments
_learning_rates = dict(init_lr=0.00001, max_lr=0.0001, final_lr=0.00001)
mcmpnn = SobolevMulticomponentMPNN(
    fnn_hidden_size + 1,  # +1 for solubility
    _output_dir,
    mcmp,
    agg,
    ffn,
    batch_norm=True,
    metrics=metric_list,
    X_d_transform=X_d_transform,
    **_learning_rates,
)

In [None]:
# configure trainer
tensorboard_logger = TensorBoardLogger(_output_dir, name="tensorboard_logs", default_hp_metric=False)

# Early stopping and best-checkpoint selection watch the same logged quantity (lower is better).
# NOTE(review): presumably this key is logged by SobolevMulticomponentMPNN during validation - confirm.
_monitored_metric = "val/logs_loss"
_early_stopping = EarlyStopping(monitor=_monitored_metric, mode="min", verbose=False, patience=10)
_best_checkpoint = ModelCheckpoint(
    monitor=_monitored_metric,
    dirpath=os.path.join(_output_dir, "checkpoints"),
    save_top_k=1,  # keep only the single best checkpoint
    mode="min",
)
callbacks = [_early_stopping, _best_checkpoint]

_trainer_options = dict(
    accelerator="cpu",
    max_epochs=200,
    logger=tensorboard_logger,
    log_every_n_steps=1,
    enable_checkpointing=True,
    check_val_every_n_epoch=1,
    callbacks=callbacks,
    num_sanity_val_steps=0,
    inference_mode=False,  # to enable sobolev
)
trainer = pl.Trainer(**_trainer_options)

In [None]:
# Fit MCMPNN and causal model
trainer.fit(mcmpnn, train_loader, val_loader)

# Replace the in-memory network by the best checkpoint recorded by the ModelCheckpoint callback.
ckpt_path = trainer.checkpoint_callback.best_model_path
_model_cls = type(mcmpnn)
mcmpnn = _model_cls.load_from_checkpoint(ckpt_path)

# NOTE(review): the same (training-split) loader is passed for both arguments - confirm this is intended.
mcmpnn.fit_causal(causal_loader, causal_loader)

Testing

In [None]:
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.manifold import TSNE
from lightning.pytorch import Trainer

from causal_chemprop_logS import SobolevMulticomponentMPNN
import evomol

In [None]:
def parity_plot(truth, prediction, title):
    """Draw and show a predicted-vs-actual ("parity") scatter plot for logS.

    The plot contains the identity line, two dashed guide lines offset by
    +/- 1 log unit, and a text box with the RMSE and the fraction of points
    whose absolute error is at most 0.7.

    Parameters
    ----------
    truth : array-like
        Reference logS values. Any shape; flattened to 1-D.
    prediction : array-like
        Predicted logS values. Any shape; flattened to 1-D. Must hold the
        same number of values as ``truth``.
    title : str
        Title drawn above the axes.

    Raises
    ------
    ValueError
        If the two inputs hold a different number of values, or are empty.
    """
    # Flatten both inputs first: subtracting an (N, 1) array from an (N,) array would
    # silently broadcast to (N, N) and corrupt the "within 0.7" fraction.
    truth = np.asarray(truth, dtype=float).reshape(-1)
    prediction = np.asarray(prediction, dtype=float).reshape(-1)
    if truth.size != prediction.size:
        raise ValueError(
            f"truth and prediction must hold the same number of values, got {truth.size} and {prediction.size}"
        )
    if truth.size == 0:
        raise ValueError("cannot draw a parity plot without any data points")

    abs_error = np.abs(truth - prediction)
    rmse = np.sqrt(mean_squared_error(truth, prediction))
    wn_07 = np.count_nonzero(abs_error <= 0.7) / abs_error.size
    stat_str = (f" - RMSE: {rmse:.4f}\n - W/n 0.7: {wn_07:.4f}\n")

    fig, ax = plt.subplots()
    ax.scatter(
        truth, prediction,
        alpha=0.5,
        s=25,
        edgecolors='black',
        facecolors='blue',
        marker='o'
    )
    ax.tick_params(axis='both', which='major', labelsize=12)
    ax.set_xlabel("Actual logS values", fontsize=14)
    ax.set_ylabel("Predicted logS values", fontsize=14)
    ax.set_title(title, fontsize=14)

    # square axes with half a log unit of padding around the data
    min_val = min(np.min(truth), np.min(prediction)) - 0.5
    max_val = max(np.max(truth), np.max(prediction)) + 0.5
    ax.plot([min_val, max_val], [min_val, max_val], color="black", linestyle="-")
    # dashed guides one log unit above and below the identity line
    ax.plot([min_val, max_val], [min_val + 1, max_val + 1], color="red", linestyle="--", alpha=0.3)
    ax.plot([min_val, max_val], [min_val - 1, max_val - 1], color="red", linestyle="--", alpha=0.3)
    ax.set_xlim(min_val, max_val)
    ax.set_ylim(min_val, max_val)

    # statistics box anchored in the top-left corner (data coordinates)
    ax.text(
        min_val, max_val - 0.1, stat_str,
        ha="left", va="top",
        fontsize=14,
        bbox=dict(facecolor='white', alpha=0.01, edgecolor='none')
    )
    plt.show()
    # release the figure so repeated calls (e.g. inside a loop) do not accumulate open figures
    plt.close(fig)


In [None]:
# Load MCMPNN and causal model
# These paths are hard-coded to a previous run directory ("output/chemprop_aqueous"),
# not to the timestamped output directory created earlier in this notebook.
# ``ckpt_path`` also replaces the best-checkpoint path stored under the same name after training.
ckpt_path = Path("output/chemprop_aqueous/checkpoints/epoch=16-step=170.ckpt")      # TODO: Update path
causal_pkl = Path("output/chemprop_aqueous/causal_model_histogram.pkl")             # TODO: Update path
# NOTE(review): ``causal_pkl`` is forwarded as a keyword argument; presumably the model
# class uses it to restore the fitted causal model - confirm against SobolevMulticomponentMPNN.
model = SobolevMulticomponentMPNN.load_from_checkpoint(
    ckpt_path, 
    causal_pkl=causal_pkl
)
print(model)

In [None]:
# Test _causal_inference
# The featurizer and the prediction Trainer do not depend on the file being scored: build them once.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
trainer = Trainer(logger=False)
for holdout_fpath in (
    Path("../data/aqueous.csv"),    # TODO: Update path
):
    df = pd.read_csv(holdout_fpath)
    # Keep only the rows that belong to the held-out (validation) studies.
    # ``val_indexes`` must NOT be used on a freshly read file: those values are positions in the
    # training frame after the down-sampling step re-ordered and re-indexed it, so they select
    # unrelated rows here (including rows the model was trained on).
    df = df[df["source"].isin(studies_val)].reset_index(drop=True)
    # one datapoint list per component: [solute (with temperature as x_d), solvent]; no targets needed
    test_datapoints = [
        [
            chemprop_data_utils.MoleculeDatapoint.from_smi(smi, None, x_d=np.array([temperature]))
            for smi, temperature in zip(df["solute_smiles"], df["temperature"])
        ],
        [chemprop_data_utils.MoleculeDatapoint.from_smi(smi) for smi in df["solvent_smiles"]],
    ]
    test_datasets = [chemprop_data_utils.MoleculeDataset(component, featurizer) for component in test_datapoints]
    test_mcdset = chemprop_data_utils.MulticomponentDataset(test_datasets)
    test_loader = chemprop_data_utils.build_dataloader(test_mcdset, shuffle=False)
    # predict() returns one array per batch; stitch them back into a single array
    predictions = np.concatenate(trainer.predict(model, test_loader), axis=0)
    parity_plot(df["logS"], predictions, holdout_fpath.stem)

In [None]:
# Test _causal_counterfactual
X, Y = [], []
# The featurizer and the prediction Trainer do not depend on the row being scored,
# so they are built once instead of once per row.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
trainer = Trainer(logger=False)
for holdout_fpath in (
    Path("../data/mif.csv"),        # TODO: Update path
):
    mif_df = pd.read_csv(holdout_fpath)
    for i in range(1, len(mif_df)):
        # Score row i together with row 0; only the prediction for row i (the last one) is kept.
        # NOTE(review): presumably row 0 is the reference molecule for the counterfactual - confirm.
        df = mif_df.iloc[[0, i]]
        test_datapoints = [
            [
                chemprop_data_utils.MoleculeDatapoint.from_smi(smi, None, x_d=np.array([temperature]))
                for smi, temperature in zip(df["solute_smiles"], df["temperature"])
            ],
            [chemprop_data_utils.MoleculeDatapoint.from_smi(smi) for smi in df["solvent_smiles"]],
        ]
        test_datasets = [chemprop_data_utils.MoleculeDataset(component, featurizer) for component in test_datapoints]
        test_mcdset = chemprop_data_utils.MulticomponentDataset(test_datasets)
        test_loader = chemprop_data_utils.build_dataloader(test_mcdset, shuffle=False)
        predictions = np.concatenate(trainer.predict(model, test_loader), axis=0)
        X.append(df["logS"].iloc[-1])
        Y.append(predictions[-1])
        print(df["solute_smiles"].iloc[-1], df["logS"].iloc[-1], predictions[-1])
    # one parity plot over every pair scored so far
    parity_plot(np.array(X).reshape(-1,1), np.array(Y).reshape(-1,1), "NNRTI")


# A single logS value pushed through the training target scaler.
# NOTE(review): not read by any later cell shown in this notebook - confirm where it is consumed.
seed_sol = target_scaler.transform([[-2.11742423766986]])

In [None]:
# Test _causal_generate

# Objective: a weighted sum of the listed terms; adjust the coefficients to change their ratio.
# NOTE(review): the original note spoke of a "weight ratio between solubility and plogp", but the
# listed terms are qed / plogp / norm_sascore and only the first has a non-zero weight - confirm
# which term is meant to carry solubility.
_objective = {
    "type": "linear_combination",
    "functions": ["qed", "plogp", "norm_sascore"],
    "coef": [1.0, 0.0, 0.0],
}

# NOTE(review): the loaded network is handed over under "model"; presumably a project-specific
# extension of the optimizer settings - confirm.
_optimization = {
    "max_steps": 50,
    "model": model,
}

# Constrain the action space to preserve the core
_action_space = {
    "atoms": "C,N,O,F",
    "change_bond_prevent_breaking_creating_bonds": True,
    "remove_group_only_remove_smallest_group": False,
}

# Seed molecules to perturb
_seed_smiles = [
    "C1=CC=C2C(=C1)C=CC(=N2)C3=CN(N=N3)C4=CC=C(C=C4)O",
    "C1=CC(=CC=C1N2C=C(N=N2)C3=NC4=C(C=C3)C=C(C=C4)OCCO)O",
    "C1=CC(=CC=C1N2C=C(N=N2)C3=NC4=C(C=C3)C=C(C=C4)OCCN)O",
    "COCCOC1=CC2=C(C=C1)N=C(C=C2)C3=CN(N=N3)C4=CC=C(C=C4)O",
    "C1=CC(=CC=C1N2C=C(N=N2)C3=NC4=C(C=C3)C=C(C=C4)OCCOCCN)O",
    "C1COCCN1CCOCCOC2=CC3=C(C=C2)N=C(C=C3)C4=CN(N=N4)C5=CC=C(C=C5)O",
    "C1=CC(=CC=C1N2C=C(N=N2)C3=NC4=C(C=C3)C=C(C=C4)OCC(=O)O)O",
    "C1=CC(=C(C=C1N2C=C(N=N2)C3=NC4=C(C=C3)C=C(C=C4)OCCOCCN)F)O",
    "C1COCCN1CCOCCOC2=CC3=C(C=C2)N=C(C=C3)C4=CN(N=N4)C5=CC(=C(C=C5)O)F",
    "C1=CC(=C(C=C1N2C=C(N=N2)C3=NC4=C(C=C3)C=C(C=C4)OCC(=O)O)F)O",
    "C1=CC(=C(C=C1N2C=C(N=N2)C3=NC4=C(C=C3)C=C(C=C4)OCCCC(=O)O)F)O",
    "C1=CC(=C(C=C1N2C=C(N=N2)C3=NC4=C(C=C3)C=C(C=C4)OCCOCC(=O)O)F)O",
    "OC(=O)COc1cc2nccc(c2cc1)c3cn(nn3)c4ccccc4",
    "COCCOc1ccc(cc1)Oc2ccnc3cc(-c4cnnn4c5ccc(F)cc5O)ccc23",
    "O=C(O)c1ccc(cc1)Oc2ccnc3cc(-c4cnnn4c5ccc(F)cc5O)ccc23",
]
_io = {
    "model_path": "output/neopoly",
    "smiles_list_init": _seed_smiles,
}

evomol.run_model({
    "obj_function": _objective,
    "optimization_parameters": _optimization,
    "action_space_parameters": _action_space,
    "io_parameters": _io,
})