In [None]:
import datetime
import logging
import os
from pathlib import Path

import numpy as np
import pandas as pd
from astartes import train_test_split
from chemprop import data as chemprop_data_utils
from chemprop import featurizers, nn
from lightning import pytorch as pl
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
from sklearn.preprocessing import StandardScaler
import torch

from causal_chemprop import SobolevMPNN, CustomMSEMetric
from chemprop.nn.transforms import ScaleTransform, UnscaleTransform

In [None]:
RANDOM_SEED = 8675309
TRAINING_FPATH = Path("../data/aurka.csv")          # TODO: Update path

# setup logging and output directories
# Every run writes into its own directory named after the current UNIX time (seconds).
# timezone.utc (rather than datetime.UTC) keeps this working on Python < 3.11.
_output_dir = Path(f"output/chemprop_{int(datetime.datetime.now(datetime.timezone.utc).timestamp())}")
_output_dir.mkdir(parents=True, exist_ok=True)
# The unified `lightning` package imported above logs under "lightning.pytorch";
# "pytorch_lightning" is only used by the standalone package, so configure both.
logging.getLogger("lightning.pytorch").setLevel(logging.INFO)
logging.getLogger("pytorch_lightning").setLevel(logging.INFO)
seed_everything(RANDOM_SEED)

# Fraction (not percent) of (SMI, MR_ID) groups kept by down-sampling; 1.0 keeps all of
# them but still runs the group/shuffle step, and a falsy value (0 / None) skips it.
downsample_percent = 1.0
# Hidden width of the bond message-passing encoder.
mpnn_hidden_size = 800
# Hidden width of the feed-forward regression head.
fnn_hidden_size = 200

Data preparation

In [None]:
# load the training data
df = pd.read_csv(TRAINING_FPATH)
# Raw-CSV row position of every row of `df`; the identity until down-sampling
# filters and re-orders the rows below.
source_rows = np.arange(len(df))
if downsample_percent:
    print(f"Down-sampling training data to {downsample_percent:.2%} size!")
    downsample_df = df.copy()
    downsample_df["original_index"] = np.arange(len(df))
    # Sample whole (SMI, MR_ID) groups so all rows of a compound within a study stay together.
    # NOTE: groupby drops rows with a missing SMI or MR_ID, and frac=1.0 still shuffles group order.
    downsample_df = downsample_df.groupby(["SMI", "MR_ID"]).aggregate(list)
    downsample_df = downsample_df.sample(frac=downsample_percent, replace=False, random_state=RANDOM_SEED)
    chosen_indexes = downsample_df.explode("original_index")["original_index"].to_numpy().flatten().astype(int)
    print(f"Actual downsample percentage is {len(chosen_indexes)/len(df):.2%}, count: {len(chosen_indexes)}!")
    df = df.iloc[chosen_indexes].reset_index(drop=True)
    source_rows = chosen_indexes

# train-test split by study (MR_ID): no study contributes rows to both training and held-out data
studies_train, studies_val = train_test_split(pd.unique(df["MR_ID"]), random_state=RANDOM_SEED, train_size=0.8, test_size=0.2)
train_indexes = df.index[df["MR_ID"].isin(studies_train)].tolist()
heldout_positions = df.index[df["MR_ID"].isin(studies_val)].tolist()
# Held-out rows are halved by position into validation and test; a study may straddle both halves.
val_indexes = heldout_positions[: len(heldout_positions) // 2]
test_positions = heldout_positions[len(heldout_positions) // 2:]
# train/val indexes are positions into `df` (used with .iloc and split_data_by_indices).
# test_indexes is used with `.loc` on a freshly re-read copy of the raw CSV, so it must be
# raw-CSV row labels; using positions in the re-ordered `df` would select the wrong rows.
test_indexes = source_rows[test_positions].tolist()
_total = len(df)
print(f"train: {len(train_indexes)} ({len(train_indexes)/_total:.0%}) validation: " f"{len(val_indexes)} ({len(val_indexes)/_total:.0%}) test: " f"{len(test_indexes)} ({len(test_indexes)/_total:.0%})")

In [None]:
# Standardise the target with statistics from the training rows only, then pair
# every SMILES with its scaled target as a chemprop datapoint.
target_scaler = StandardScaler().fit(df[["logS"]].iloc[train_indexes])
scaled_logs = target_scaler.transform(df[["logS"]]).ravel()
all_data = [
    chemprop_data_utils.MoleculeDatapoint.from_smi(smiles, [target])
    for smiles, target in zip(df["SMI"], scaled_logs)
]

# Positional split of the datapoints; the third (test) slot is not used here.
train_data, val_data, _ = chemprop_data_utils.split_data_by_indices(
    all_data, train_indices=train_indexes, val_indices=val_indexes
)
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
# The causal dataset is a separate dataset object over the same training rows.
train_datasets, val_datasets, causal_datasets = (
    chemprop_data_utils.MoleculeDataset(rows, featurizer)
    for rows in (train_data, val_data, train_data)
)

# Only the training loader shuffles; validation and causal loaders keep row order.
train_loader = chemprop_data_utils.build_dataloader(
    train_datasets, batch_size=64, shuffle=True, num_workers=1, persistent_workers=True
)
val_loader, causal_loader = (
    chemprop_data_utils.build_dataloader(
        dataset, batch_size=16, shuffle=False, num_workers=1, persistent_workers=True
    )
    for dataset in (val_datasets, causal_datasets)
)

Training

In [None]:
# Assemble the model: bond message passing -> mean readout -> regression FFN, wrapped
# in SobolevMPNN. Construction order is kept as-is because layer constructors may draw
# initial weights from the globally seeded RNG.
mcmp = nn.BondMessagePassing(depth=3, d_h=mpnn_hidden_size, dropout=0.50)
agg = nn.MeanAggregation()
# Un-scales FFN outputs with the scaler fitted on the training targets.
output_transform = UnscaleTransform.from_standard_scaler(target_scaler)
ffn = nn.RegressionFFN(
    n_layers=4,
    input_dim=mcmp.output_dim,
    hidden_dim=fnn_hidden_size,
    dropout=0.50,
    criterion=CustomMSEMetric(),
    output_transform=output_transform,
)
metric_list = [CustomMSEMetric()]
# One-cycle style learning-rate bounds: start 1e-5, peak 1e-4, end 1e-5.
mcmpnn = SobolevMPNN(
    fnn_hidden_size + 1,  # +1 for solubility (per original note; target column is logS)
    _output_dir,
    mcmp,
    agg,
    ffn,
    metrics=metric_list,
    batch_norm=True,
    init_lr=0.00001,
    max_lr=0.0001,
    final_lr=0.00001,
)

In [None]:
# Configure the Lightning trainer: TensorBoard logs under the run directory, plus
# early stopping and best-checkpoint saving, both keyed on the validation FFN loss.
tensorboard_logger = TensorBoardLogger(_output_dir, name="tensorboard_logs", default_hp_metric=False)
callbacks = [
    # Stop once val/logs_loss has not improved for 10 validation checks; per the original
    # note this deliberately tracks the FFN loss rather than the neopoly loss.
    EarlyStopping(monitor="val/logs_loss", mode="min", patience=10, verbose=False),
    # Keep only the single best checkpoint by the same metric.
    ModelCheckpoint(
        dirpath=str(_output_dir / "checkpoints"),
        monitor="val/logs_loss",
        mode="min",
        save_top_k=1,
    ),
]
trainer = pl.Trainer(
    callbacks=callbacks,
    logger=tensorboard_logger,
    max_epochs=200,
    check_val_every_n_epoch=1,
    log_every_n_steps=1,
    num_sanity_val_steps=0,
    enable_checkpointing=True,
    # Per the original note, required to enable the Sobolev term during evaluation.
    inference_mode=False,
)

In [None]:
# Fit MCMPNN and causal model
trainer.fit(mcmpnn, train_loader, val_loader)
# Reload the best checkpoint kept by ModelCheckpoint (lowest val/logs_loss) instead of
# continuing with the last epoch's weights.
ckpt_path = trainer.checkpoint_callback.best_model_path
if not ckpt_path:
    # best_model_path stays "" when no checkpoint was ever saved (e.g. the monitored
    # metric was never logged or never improved, as with NaN); fail with a clear message.
    raise RuntimeError(
        "ModelCheckpoint saved no checkpoint; check that 'val/logs_loss' is logged and finite."
    )
mcmpnn = SobolevMPNN.load_from_checkpoint(ckpt_path)
# NOTE(review): the same causal_loader is passed for both arguments — confirm this is intended.
mcmpnn.fit_causal(causal_loader, causal_loader)
torch.save(mcmpnn, _output_dir / "chemprop_model.pt")

Testing

In [None]:
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.manifold import TSNE
from lightning.pytorch import Trainer

from causal_chemprop import SobolevMPNN

In [None]:
def parity_plot(truth, prediction, title, quantity="IC50"):
    """Show a predicted-vs-actual parity scatter plot annotated with fit statistics.

    Both inputs are first converted to flat 1-D float arrays, so a pandas Series and a
    column-vector prediction array (shape ``(n, 1)``) are compared element by element
    instead of being broadcast against each other.

    Args:
        truth: Ground-truth values (any array-like).
        prediction: Predicted values with the same number of elements as ``truth``.
        title: Figure title, e.g. the holdout file's stem.
        quantity: Name of the plotted quantity used in the axis labels.

    Returns:
        dict with ``"pearson_r"``, ``"mse"``, ``"rmse"``, ``"mae"``, and ``"within_0.7"`` /
        ``"within_1.0"`` (fraction of points whose absolute error is at most 0.7 / 1.0).

    Raises:
        ValueError: if the inputs differ in size or contain fewer than two points.
    """
    truth = np.asarray(truth, dtype=float).ravel()
    prediction = np.asarray(prediction, dtype=float).ravel()
    if truth.shape != prediction.shape:
        raise ValueError(
            f"truth and prediction must have the same number of values, got {truth.size} and {prediction.size}"
        )
    if truth.size < 2:
        raise ValueError("parity_plot needs at least two points to compute Pearson's r")

    abs_error = np.abs(truth - prediction)
    r, _ = pearsonr(truth, prediction)
    mse = mean_squared_error(truth, prediction)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(truth, prediction)
    wn_07 = np.count_nonzero(abs_error <= 0.7) / prediction.size
    wn_1 = np.count_nonzero(abs_error <= 1.0) / prediction.size

    stat_str = f" - Pearson's r: {r:.4f}\n - RMSE: {rmse:.4f}"
    _, ax = plt.subplots()
    ax.scatter(
        truth, prediction,
        alpha=0.5,              # semi-transparent so overlapping points reveal density
        s=25,                   # slightly larger markers
        edgecolors="black",     # thin black outline around each marker
        facecolors="blue",      # fill color
        marker="o",
    )
    ax.tick_params(axis="both", which="major", labelsize=12)
    ax.set_xlabel(f"Actual {quantity} values", fontsize=14)
    ax.set_ylabel(f"Predicted {quantity} values", fontsize=14)
    ax.set_title(title, fontsize=14)

    # Square view padded by 0.5 on each side so the identity line spans all points.
    min_val = min(truth.min(), prediction.min()) - 0.5
    max_val = max(truth.max(), prediction.max()) + 0.5
    ax.plot([min_val, max_val], [min_val, max_val], color="black", linestyle="-")
    # Dashed guides one unit above and below perfect agreement.
    ax.plot([min_val, max_val], [min_val + 1, max_val + 1], color="red", linestyle="--", alpha=0.3)
    ax.plot([min_val, max_val], [min_val - 1, max_val - 1], color="red", linestyle="--", alpha=0.3)
    ax.set_xlim(min_val, max_val)
    ax.set_ylim(min_val, max_val)
    ax.text(
        min_val, max_val - 0.1, stat_str,
        ha="left", va="top",
        fontsize=14,
        bbox=dict(facecolor="white", alpha=0.01, edgecolor="none"),
    )
    plt.show()
    return {
        "pearson_r": float(r),
        "mse": float(mse),
        "rmse": float(rmse),
        "mae": float(mae),
        "within_0.7": wn_07,
        "within_1.0": wn_1,
    }

In [None]:
# Load a previously trained MCMPNN checkpoint together with its pickled causal model.
# NOTE(review): these hard-coded paths may point at a different run than the one trained
# above; make sure that run used the same data split before trusting the test metrics.
ckpt_path, causal_pkl = (
    Path("output/chemprop_aurka8020/checkpoints/epoch=61-step=682.ckpt"),   # TODO: Update path
    Path("output/chemprop_aurka8020/causal_model.pkl"),                     # TODO: Update path
)
model = SobolevMPNN.load_from_checkpoint(ckpt_path, causal_pkl=causal_pkl)
print(model)

In [None]:
# Test _causal_inference: predict on the held-out test rows and draw one parity plot per file.
# NOTE(review): `.loc[test_indexes]` selects by row label of the CSV exactly as loaded here
# (default RangeIndex), so test_indexes must be raw-CSV row numbers of this same file —
# confirm they were not taken as positions in a re-ordered/down-sampled frame.
holdout_featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
# Match the training Trainer's inference_mode=False: Lightning then runs predict under
# torch.no_grad() instead of torch.inference_mode(), so a gradient-based (Sobolev/causal)
# predict path can still re-enable autograd. Presumably required — TODO confirm.
predict_trainer = Trainer(logger=False, inference_mode=False)
for holdout_fpath in (
    Path("../data/aurka.csv"),          # TODO: Update path
):
    # Separate name so re-running this cell never clobbers the training frame `df`.
    holdout_df = pd.read_csv(holdout_fpath).loc[test_indexes]
    test_datapoints = [
        chemprop_data_utils.MoleculeDatapoint.from_smi(smi, None, x_d=None)
        for smi in holdout_df["SMI"]
    ]
    test_datasets = chemprop_data_utils.MoleculeDataset(test_datapoints, holdout_featurizer)
    test_loader = chemprop_data_utils.build_dataloader(test_datasets, shuffle=False)
    batch_predictions = predict_trainer.predict(model, test_loader)
    # Stack the per-batch outputs, move them to CPU (predict may run on an accelerator),
    # and flatten to a 1-D vector aligned with the truth values.
    predictions = (
        torch.cat([torch.as_tensor(batch).detach().cpu() for batch in batch_predictions], dim=0)
        .numpy()
        .ravel()
    )
    parity_plot(holdout_df["logS"].to_numpy(), predictions, holdout_fpath.stem)