# Training

# Import packages

In [46]:
from pathlib import Path

from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import pandas as pd
from sklearn.model_selection import KFold
import csv
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import CSVLogger


from chemprop import data, featurizers, models, nn

# Change data inputs here

In [2]:
# Project root: the parent of the current working directory.
chemprop_dir = Path.cwd().parent
# Path to your data .csv file (here: the fold-0 training split).
input_path = chemprop_dir.joinpath("data", "folds", "train_smiles_fold_0.csv")
# Number of workers for the dataloader; 0 means the main process loads the data.
num_workers = 0
# Name of the column containing SMILES strings.
smiles_column = "full_smiles"
# Names of the columns containing targets.
target_columns = ["rejection"]

## Load data

In [3]:
# Read the training table from the CSV configured in `input_path`.
df_input = pd.read_csv(input_path)
# Display the frame. The rendered output also shows an "Unnamed: 0" column, which
# looks like an index saved into the CSV (TODO confirm); the cells shown in this
# notebook only read `smiles_column` and `target_columns`.
df_input

Unnamed: 0,full_smiles,rejection
0,CO.N#CC1=CC=C(N)C=C1,0.260000
1,CCCCCCC.COC1=CC=C(OC)C=C1,-0.033365
2,CC(OCC)=O.OC(CC1=C(C=CC=C1)NC2=C(C=CC=C2Cl)Cl)=O,0.143540
3,CO.C(CBr)Br,0.111000
4,CO.Cc1ccc(I)cc1,0.380000
...,...,...
8497,CO.C1=CC=C(C(=C1)C(=O)O)F,0.650000
8498,CC#N.CC(C1=CC=CC=C1)CC(C2=CC=CC=C2)CC(C3=CC=CC...,0.843000
8499,CC1=CC=CC=C1.CC(OCC(C)OCC(C)O)COC(C)COC(C)COC(...,0.928900
8500,O.CN(C)N=O,0.605100


## Get SMILES and targets

In [4]:
# One SMILES entry per row (single column label -> 1-D values).
smis = df_input[smiles_column].values
# `target_columns` is a list, so this selects a sub-frame and yields 2-D values.
ys = df_input[target_columns].values

## Get molecule datapoints

In [5]:
# Build one MoleculeDatapoint per (SMILES, target-row) pair.
all_data = [
    data.MoleculeDatapoint.from_smi(smiles, targets)
    for smiles, targets in zip(smis, ys)
]

## Perform data splitting for training, validation, and testing

In [6]:
# available split types
# Wrapping the keys in `list(...)` makes the notebook display them as a plain list.
list(data.SplitType.keys())

['SCAFFOLD_BALANCED',
 'RANDOM_WITH_REPEATED_SMILES',
 'RANDOM',
 'KENNARD_STONE',
 'KMEANS']

Chemprop's `make_split_indices` function will always return a two- (if no validation) or three-length tuple.
Each member is a list of length `num_replicates`.
The inner lists then contain the actual indices for splitting.

The type signature for this return type is `tuple[list[list[int]], ...]`.

In [7]:
# RDKit Mol objects are used for structure-based splits.
mols = [datapoint.mol for datapoint in all_data]

# Random split with (0.8, 0.1, 0.1) sizes; the returned tuple is unpacked into
# three separate index lists (train, validation, test).
split_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1))
train_indices, val_indices, test_indices = split_indices

train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)

The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)


Chemprop's splitting function implements our preferred method of data splitting, which is random replication.
It's also possible to use your own custom cross-validation splitter, such as one of those implemented in scikit-learn, as long as you convert its output into the same `tuple[list[list[int]], ...]` format, for example like this:

In [9]:
# 5-fold cross-validation from scikit-learn, reshaped into the per-replicate
# index lists that `split_data_by_indices` takes.
k_splits = KFold(n_splits=5)
k_train_indices, k_val_indices, k_test_indices = [], [], []
for fold_train_idx, fold_test_idx in k_splits.split(mols):
    k_train_indices.append(fold_train_idx)
    k_val_indices.append([])  # KFold yields no validation part; kept as an empty placeholder
    k_test_indices.append(fold_test_idx)

# `None` is passed for the validation indices, so the middle return value is discarded.
k_train_data, _, k_test_data = data.split_data_by_indices(
    all_data, k_train_indices, None, k_test_indices
)

## Get MoleculeDataset
Recall that the data is in a list equal in length to the number of replicates, so we select the zero index of the list to get the first replicate.

In [32]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# Pick replicate 0 of every split.
train_split, val_split, test_split = train_data[0], val_data[0], test_data[0]

# Normalize the training targets and keep the scaler that comes back ...
train_dset = data.MoleculeDataset(train_split, featurizer)
scaler = train_dset.normalize_targets()

# ... then apply that same scaler to the validation targets.
val_dset = data.MoleculeDataset(val_split, featurizer)
val_dset.normalize_targets(scaler)

# NOTE: the test targets are intentionally left un-normalized.
test_dset = data.MoleculeDataset(test_split, featurizer)

## Get DataLoader

In [33]:
# The training loader leaves `shuffle` at build_dataloader's default;
# the evaluation loaders explicitly disable shuffling.
eval_loader_options = {"num_workers": num_workers, "shuffle": False}

train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader = data.build_dataloader(val_dset, **eval_loader_options)
test_loader = data.build_dataloader(test_dset, **eval_loader_options)

# Change Message-Passing Neural Network (MPNN) inputs here

## Message Passing and Aggregation
A `Message passing` block operates on the molecular graphs, passing messages between neighbours to learn node-level hidden representations.

Options are `mp = nn.BondMessagePassing()` or `mp = nn.AtomMessagePassing()`

In [34]:
# Message-passing block, constructed with its default arguments.
mp = nn.BondMessagePassing()
agg = nn.MeanAggregation() # Mean aggregation as in the CLI

## Feed-Forward Network (FFN)

An `FFN` takes the aggregated representations and makes target predictions.

Available options can be found in `nn.PredictorRegistry`.

For regression:
- `ffn = nn.RegressionFFN()`
- `ffn = nn.MveFFN()`
- `ffn = nn.EvidentialFFN()`

For classification:
- `ffn = nn.BinaryClassificationFFN()`
- `ffn = nn.BinaryDirichletFFN()`
- `ffn = nn.MulticlassClassificationFFN()`
- `ffn = nn.MulticlassDirichletFFN()`

For spectral:
- `ffn = nn.SpectralFFN()` # will be available in future version

In [35]:
# Show every registered predictor key together with the class it maps to.
print(nn.PredictorRegistry)

ClassRegistry {
    'regression': <class 'chemprop.nn.predictors.RegressionFFN'>,
    'regression-mve': <class 'chemprop.nn.predictors.MveFFN'>,
    'regression-evidential': <class 'chemprop.nn.predictors.EvidentialFFN'>,
    'regression-quantile': <class 'chemprop.nn.predictors.QuantileFFN'>,
    'classification': <class 'chemprop.nn.predictors.BinaryClassificationFFN'>,
    'classification-dirichlet': <class 'chemprop.nn.predictors.BinaryDirichletFFN'>,
    'multiclass': <class 'chemprop.nn.predictors.MulticlassClassificationFFN'>,
    'multiclass-dirichlet': <class 'chemprop.nn.predictors.MulticlassDirichletFFN'>,
    'spectral': <class 'chemprop.nn.predictors.SpectralFFN'>
}


In [36]:
# Build an un-scaling transform from `scaler` (returned by `train_dset.normalize_targets()`),
# to be attached to the predictor as its output transform.
output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)

In [37]:
# Regression head; its `output_transform` is the un-scaling transform built from the target scaler.
ffn = nn.RegressionFFN(output_transform=output_transform)

## Batch Norm
A `Batch Norm` normalizes the outputs of the aggregation by re-centering and re-scaling.

Set whether to use batch norm:

In [38]:
# Flag passed to `models.MPNN` as its batch-norm argument.
batch_norm = True

## Metrics
`Metrics` are the ways to evaluate the performance of model predictions.

Available options can be found in `nn.metrics.MetricRegistry`:

In [39]:
# Show every registered metric key together with the class it maps to.
print(nn.metrics.MetricRegistry)

ClassRegistry {
    'mse': <class 'chemprop.nn.metrics.MSE'>,
    'mae': <class 'chemprop.nn.metrics.MAE'>,
    'rmse': <class 'chemprop.nn.metrics.RMSE'>,
    'bounded-mse': <class 'chemprop.nn.metrics.BoundedMSE'>,
    'bounded-mae': <class 'chemprop.nn.metrics.BoundedMAE'>,
    'bounded-rmse': <class 'chemprop.nn.metrics.BoundedRMSE'>,
    'r2': <class 'chemprop.nn.metrics.R2Score'>,
    'binary-mcc': <class 'chemprop.nn.metrics.BinaryMCCMetric'>,
    'multiclass-mcc': <class 'chemprop.nn.metrics.MulticlassMCCMetric'>,
    'roc': <class 'chemprop.nn.metrics.BinaryAUROC'>,
    'prc': <class 'chemprop.nn.metrics.BinaryAUPRC'>,
    'accuracy': <class 'chemprop.nn.metrics.BinaryAccuracy'>,
    'f1': <class 'chemprop.nn.metrics.BinaryF1Score'>
}


In [40]:
# Metrics handed to `models.MPNN`: RMSE first, then MAE (the order matters, see the note below).
metric_list = [nn.metrics.RMSE(), nn.metrics.MAE()] # Only the first metric is used for training and early stopping

## Constructs MPNN

In [None]:
# Assemble the MPNN from the message-passing, aggregation and predictor blocks.
# `X_d_transform` is deliberately left at its default: it is meant for extra
# datapoint descriptors, none of which are supplied in this notebook, and `scaler`
# is the *target* scaler, which is already wired into `ffn` via `output_transform`.
mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list)
mpnn

MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0]])
    (output_transform): UnscaleTransform()
  )
  (X_d_transform): Identity()
  (metrics): ModuleList(
    (0): RMSE(task_weights=[[1.0]])
    (1): MAE(task_weigh

# Set up trainer

In [51]:
# Checkpoint callback: monitor the validation loss and keep the checkpoint with the lowest value.
checkpointing = ModelCheckpoint(
    dirpath="checkpoints",  # directory where model checkpoints are saved
    filename="best-{epoch}-{val_loss:.2f}",  # filename pattern, including epoch and validation loss
    monitor="val_loss",  # metric used to select the best checkpoint
    mode="min",  # lower validation loss is better
    save_last=True,  # also keep the most recent checkpoint, even if it is not the best
)

# Metrics are written by the CSV logger under "logs/my_exp/version_<N>/metrics.csv".
logger = CSVLogger(save_dir="logs", name="my_exp")

trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=20,  # number of epochs to train for
    logger=logger,
    callbacks=[checkpointing],  # use the checkpoint callback configured above
    enable_checkpointing=True,  # checkpoints are saved in the `checkpoints` folder
    enable_progress_bar=True,
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores


# Start training

In [52]:
# Train `mpnn` on `train_loader`, validating on `val_loader`, for up to the trainer's `max_epochs`.
trainer.fit(mpnn, train_loader, val_loader)

Loading `train_dataloader` to estimate number of stepping batches.


`Trainer.fit` stopped: `max_epochs=20` reached.


# Test results

Note: the targets in `test_loader` were not scaled, so the reported metrics are in the original target units.

In [53]:
# Evaluate on the held-out test set. No model is passed here; the log output below
# shows Lightning restoring the saved "best-..." checkpoint before testing.
results = trainer.test(dataloaders=test_loader, weights_only=False)

Restoring states from the checkpoint path at /Users/rossom/Desktop/Projects/nf10k/notebooks/checkpoints/best-epoch=2-val_loss=0.29.ckpt
/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/core/saving.py:365: Skipping 'metrics' parameter because it is not possible to safely dump to YAML.
Loaded model weights from the checkpoint at /Users/rossom/Desktop/Projects/nf10k/notebooks/checkpoints/best-epoch=2-val_loss=0.29.ckpt


/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.


## Validation results

Note: the `val_loader` targets were scaled, so we need un-scaled validation targets to compute metrics in the original units.

In [54]:
# Build a separate, un-normalized validation dataset instead of calling `val_dset.reset()`.
# `val_loader` wraps `val_dset`, so resetting it in place would silently switch that loader
# to raw targets for any later fit/validate run (hidden notebook state). A freshly
# constructed MoleculeDataset is not normalized, exactly like `test_dset`.
# see: https://chemprop.readthedocs.io/en/latest/autoapi/chemprop/data/datasets/index.html#chemprop.data.datasets.MoleculeDataset
val_dset_unscaled = data.MoleculeDataset(val_data[0], featurizer)
val_loader_unscaled = data.build_dataloader(val_dset_unscaled, num_workers=num_workers, shuffle=False)
# NOTE: this rebinds `results`, replacing the test-set results from the previous cell.
results = trainer.test(dataloaders=val_loader_unscaled, weights_only=False)

Restoring states from the checkpoint path at /Users/rossom/Desktop/Projects/nf10k/notebooks/checkpoints/best-epoch=2-val_loss=0.29.ckpt
Loaded model weights from the checkpoint at /Users/rossom/Desktop/Projects/nf10k/notebooks/checkpoints/best-epoch=2-val_loss=0.29.ckpt
