In [29]:
from pathlib import Path

from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import pandas as pd
import numpy as np
import torch
from chemprop import data, featurizers, models, nn, utils

In [30]:
# Define paths
chemprop_dir = Path.cwd().parent
input_path = chemprop_dir / "training" / "data" / "train_smiles.csv" 
descriptors_path = chemprop_dir / "training" / "data" / "descriptors.csv"
num_workers = 0 
smiles_column = 'full_smiles' 
target_columns = ['rejection'] 

# Load data
df_input = pd.read_csv(input_path)
print(df_input.head())
smis = df_input.loc[:, smiles_column].values
ys = df_input.loc[:, target_columns].values

# Extract additional descriptors
df_descriptors = pd.read_csv(descriptors_path)
extra_mol_descriptors = np.array(df_descriptors.values)
mols = [utils.make_mol(smi, keep_h=False, add_h=False) for smi in smis]

# Define datapoints
datapoints = [
    data.MoleculeDatapoint(mol, y, x_d=X_d)
    for mol, y, X_d in zip(
        mols,
        ys,
        extra_mol_descriptors,
    )
]

# Split data
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1), num_replicates=3)  # unpack the tuple into three separate lists
train_data, val_data, test_data = data.split_data_by_indices(
    datapoints, train_indices, val_indices, test_indices
)

# Define data loaders
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

train_dset = data.MoleculeDataset(train_data[0], featurizer)
scaler = train_dset.normalize_targets()
extra_mol_descriptors_scaler = train_dset.normalize_inputs("X_d")

val_dset = data.MoleculeDataset(val_data[0], featurizer)
val_dset.normalize_targets(scaler)
val_dset.normalize_inputs("X_d", extra_mol_descriptors_scaler)

test_dset = data.MoleculeDataset(test_data[0], featurizer)

train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)

                                         full_smiles  rejection
0                               CO.N#CC1=CC=C(N)C=C1   0.260000
1                          CCCCCCC.COC1=CC=C(OC)C=C1  -0.033365
2   CC(OCC)=O.OC(CC1=C(C=CC=C1)NC2=C(C=CC=C2Cl)Cl)=O   0.143540
3  O.CC1(C(N2C(S1)C(C2=O)NC(=O)C(C3=CC=CC=C3)N)C(...   0.926400
4                                        CO.C(CBr)Br   0.111000


In [31]:
# Define the model
mp = nn.BondMessagePassing()
agg = nn.MeanAggregation()
ffn_input_dim = mp.output_dim + extra_mol_descriptors.shape[1]
output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)
ffn = nn.RegressionFFN(n_layers=2, input_dim=ffn_input_dim, output_transform=output_transform, dropout=0.5)
batch_norm = True
print(nn.metrics.MetricRegistry)
metric_list = [nn.metrics.RMSE(), nn.metrics.R2Score()] # Only the first metric is used for training and early stopping

X_d_transform = nn.ScaleTransform.from_standard_scaler(extra_mol_descriptors_scaler)

ensemble = []
n_models = 3
for _ in range(n_models):
    ensemble.append(models.MPNN(mp, agg, ffn, metric_list, X_d_transform=X_d_transform))

ClassRegistry {
    'mse': <class 'chemprop.nn.metrics.MSE'>,
    'mae': <class 'chemprop.nn.metrics.MAE'>,
    'rmse': <class 'chemprop.nn.metrics.RMSE'>,
    'bounded-mse': <class 'chemprop.nn.metrics.BoundedMSE'>,
    'bounded-mae': <class 'chemprop.nn.metrics.BoundedMAE'>,
    'bounded-rmse': <class 'chemprop.nn.metrics.BoundedRMSE'>,
    'r2': <class 'chemprop.nn.metrics.R2Score'>,
    'binary-mcc': <class 'chemprop.nn.metrics.BinaryMCCMetric'>,
    'multiclass-mcc': <class 'chemprop.nn.metrics.MulticlassMCCMetric'>,
    'roc': <class 'chemprop.nn.metrics.BinaryAUROC'>,
    'prc': <class 'chemprop.nn.metrics.BinaryAUPRC'>,
    'accuracy': <class 'chemprop.nn.metrics.BinaryAccuracy'>,
    'f1': <class 'chemprop.nn.metrics.BinaryF1Score'>
}


In [32]:
# Train
# Configure model checkpointing
checkpointing = ModelCheckpoint(
    "model/checkpoints",  # Directory where model checkpoints will be saved
    "best-{epoch}-{val_loss:.2f}",  # Filename format for checkpoints, including epoch and validation loss
    "val_loss",  # Metric used to select the best checkpoint (based on validation loss)
    mode="min",  # Save the checkpoint with the lowest validation loss (minimization objective)
    save_last=True,  # Always save the most recent checkpoint, even if it's not the best
)

trainers = []
for model in ensemble:
    trainer = pl.Trainer(
        logger=True,
        enable_checkpointing=True, # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
        enable_progress_bar=True,
        accelerator="auto",
        devices=1,
        max_epochs=1, # number of epochs to train for
        callbacks=[checkpointing], # Use the configured checkpoint callback
    )
    trainers.append(trainer)

for trainer, model in zip(trainers, ensemble):
    trainer.fit(model, train_loader, val_loader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /Users/rossom/Desktop/Projects/nf10k/project_notebooks/training/model/checkpoints exists and is not empty.
Loading `train_dataloader` to estimate number of stepping batches.
/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.


/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/core/saving.py:365: Skipping 'batch_norm' parameter because it is not possible to safely dump to YAML.
`Trainer.fit` stopped: `max_epochs=1` reached.


Loading `train_dataloader` to estimate number of stepping batches.


`Trainer.fit` stopped: `max_epochs=1` reached.


Loading `train_dataloader` to estimate number of stepping batches.


`Trainer.fit` stopped: `max_epochs=1` reached.


In [None]:
# Make predictions with the ensemble
predictions = []
for trainer, model in zip(trainers, ensemble):
    predictions.append(torch.concat(trainer.predict(model, test_loader)))

/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.


In [34]:
stacked_predictions = torch.stack(predictions) # [M, N, T]

# Calculate uncertainty metrics
mean_prediction = stacked_predictions.mean(dim=0)  # Mean across models
variance = stacked_predictions.var(dim=0)  # Variance across models (aleatoric + epistemic)
std_dev = stacked_predictions.std(dim=0)  # Standard deviation

# Total uncertainty (variance of ensemble predictions)
total_uncertainty = variance.squeeze()

# Coefficient of variation (normalized uncertainty)
coefficient_of_variation = (std_dev / (mean_prediction.abs() + 1e-8)).squeeze()

# Convert to numpy for analysis
mean_prediction_np = mean_prediction.numpy()
total_uncertainty_np = total_uncertainty.numpy()
std_dev_np = std_dev.numpy()
cv_np = coefficient_of_variation.numpy()

# Display uncertainty statistics
print(f"Mean prediction: {mean_prediction_np.mean():.4f}")
print(f"Mean uncertainty (std): {std_dev_np.mean():.4f}")
print(f"Mean coefficient of variation: {cv_np.mean():.4f}")

# Create uncertainty dataframe
uncertainty_df = pd.DataFrame({
    'mean_prediction': mean_prediction_np.flatten(),
    'std_dev': std_dev_np.flatten(),
    'variance': total_uncertainty_np.flatten(),
    'coefficient_of_variation': cv_np.flatten()
})

uncertainty_df.head()

Mean prediction: 1.1549
Mean uncertainty (std): 0.5864
Mean coefficient of variation: 0.5397


Unnamed: 0,mean_prediction,std_dev,variance,coefficient_of_variation
0,1.339892,0.570404,0.32536,0.425709
1,0.914156,0.457258,0.209085,0.500197
2,1.092786,0.577866,0.333929,0.5288
3,0.589933,0.188326,0.035467,0.319232
4,0.85264,0.495163,0.245187,0.580741


In [35]:
print(len(datapoints))

9920
