# Training

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/training.ipynb)

# Importing packages

In [None]:
from pathlib import Path

from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import pandas as pd
import numpy as np

from chemprop import data, featurizers, models, nn, utils

# Data Inputs

In [None]:
# Resolve the data files relative to the parent of the current working directory.
chemprop_dir = Path.cwd().parent
input_path = chemprop_dir.joinpath("data", "train_smiles.csv")
descriptors_path = chemprop_dir.joinpath("data", "descriptors.csv")

num_workers = 0  # worker count forwarded to every `build_dataloader` call
smiles_column = "full_smiles"  # CSV column holding the SMILES strings
target_columns = ["rejection"]  # CSV column(s) holding the target value(s)

## Load data

In [13]:
# Read the training table from `input_path` (set in the "Data Inputs" cell).
df_input = pd.read_csv(input_path)
df_input  # bare last expression so the notebook renders the frame

Unnamed: 0,full_smiles,rejection
0,CO.N#CC1=CC=C(N)C=C1,0.260000
1,CCCCCCC.COC1=CC=C(OC)C=C1,-0.033365
2,CC(OCC)=O.OC(CC1=C(C=CC=C1)NC2=C(C=CC=C2Cl)Cl)=O,0.143540
3,O.CC1(C(N2C(S1)C(C2=O)NC(=O)C(C3=CC=CC=C3)N)C(...,0.926400
4,CO.C(CBr)Br,0.111000
...,...,...
9915,CO.C1=CC=C(C(=C1)C(=O)O)F,0.650000
9916,CC#N.CC(C1=CC=CC=C1)CC(C2=CC=CC=C2)CC(C3=CC=CC...,0.843000
9917,CC1=CC=CC=C1.CC(OCC(C)OCC(C)O)COC(C)COC(C)COC(...,0.928900
9918,O.CN(C)N=O,0.605100


## Get SMILES and targets

In [14]:
# SMILES column as a 1-D array; target column(s) as a 2-D array (one column per target).
smis = df_input[smiles_column].to_numpy()
ys = df_input[target_columns].to_numpy()

In [15]:
smis[:5]  # preview the first 5 SMILES strings

array(['CO.N#CC1=CC=C(N)C=C1', 'CCCCCCC.COC1=CC=C(OC)C=C1',
       'CC(OCC)=O.OC(CC1=C(C=CC=C1)NC2=C(C=CC=C2Cl)Cl)=O',
       'O.CC1(C(N2C(S1)C(C2=O)NC(=O)C(C3=CC=CC=C3)N)C(=O)O)C',
       'CO.C(CBr)Br'], dtype=object)

In [16]:
ys[:5]  # preview the first 5 target rows (one inner array per molecule)

array([[ 0.26      ],
       [-0.03336461],
       [ 0.14354009],
       [ 0.9264    ],
       [ 0.111     ]])

## Molecule Extra Descriptors

In [17]:
# Load the extra-descriptor table and take an independent NumPy copy of its values.
# NOTE(review): assumes rows are in the same order as the SMILES file — confirm.
df_descriptors = pd.read_csv(descriptors_path)
extra_mol_descriptors = df_descriptors.to_numpy(copy=True)

## Get molecule datapoints

In [18]:
# Parse every SMILES string into a molecule object with chemprop's helper,
# with both hydrogen-related flags (`keep_h`, `add_h`) switched off.
mols = [utils.make_mol(smi, keep_h=False, add_h=False) for smi in smis]

In [19]:
# Pair each molecule with its target row and its extra-descriptor row.
# `strict=True` (Python 3.10+) raises ValueError when the three inputs differ
# in length, instead of silently dropping the trailing rows of the longer ones
# (e.g. when the descriptor CSV and the SMILES CSV are out of sync).
datapoints = [
    data.MoleculeDatapoint(mol, y, x_d=X_d)
    for mol, y, X_d in zip(mols, ys, extra_mol_descriptors, strict=True)
]

In [20]:
datapoints[:2]  # preview the first 2 datapoints (mol, y, weight, masks, x_d)

[MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x32c1d0820>, y=array([0.26]), weight=1.0, gt_mask=None, lt_mask=None, x_d=array([150.   ,  -1.   ,  59.   ,  10.   ,  22.55 ,  32.04 ,   0.505,
          0.55 ,   0.792,   1.6  ,  33.   ,  14.5  ,  -0.82 ,   7.4  ,
          6.   ,  10.9  ,   0.88 ,  22.   ,   7.   ,   0.   ,   0.   ,
          1.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,
          0.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,
          0.   ,   1.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,
          0.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,
          0.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,
          0.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,
          0.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,
          0.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,
          0.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,
          0.   ,  

## Data splitting for training, validation, and testing

Chemprop's `make_split_indices` function will always return a two- (if no validation) or three-length tuple.
Each member is a list of length `num_replicates`.
The inner lists then contain the actual indices for splitting.

The type signature for this return type is `tuple[list[list[int]], ...]`.

In [21]:
# Random split with (0.8, 0.1, 0.1) train/validation/test fractions, repeated
# for 3 replicates; the returned tuple is unpacked into three lists.
train_indices, val_indices, test_indices = data.make_split_indices(
    mols,
    "random",
    (0.8, 0.1, 0.1),
    num_replicates=3,
)

# Select the datapoints for each subset using the index lists from above.
train_data, val_data, test_data = data.split_data_by_indices(
    datapoints,
    train_indices,
    val_indices,
    test_indices,
)

## Get MoleculeDataset
Recall that the data is in a list equal in length to the number of replicates, so we select the zero index of the list to get the first replicate.

In [22]:
# Featurizer shared by all three datasets below.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# Training set (replicate 0): normalize the targets and the extra descriptors,
# keeping the scaler object that each call returns.
train_dset = data.MoleculeDataset(train_data[0], featurizer)
scaler = train_dset.normalize_targets()
extra_mol_descriptors_scaler = train_dset.normalize_inputs("X_d")

# Validation set: reuse the scalers obtained from the training set.
val_dset = data.MoleculeDataset(val_data[0], featurizer)
val_dset.normalize_targets(scaler)
val_dset.normalize_inputs("X_d", extra_mol_descriptors_scaler)

# The test set is deliberately left unnormalized in this cell.
# NOTE(review): presumably the model's `output_transform` / `X_d_transform`
# (built from these same scalers further down) handle scaling at test time — confirm.
test_dset = data.MoleculeDataset(test_data[0], featurizer)

## Get DataLoader

In [23]:
# One loader per subset, all using the `num_workers` value from the inputs cell.
# Validation and test loaders turn shuffling off explicitly; the training
# loader keeps `build_dataloader`'s default.
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)

# Change Message-Passing Neural Network (MPNN) inputs here

## Message Passing
The `Message passing` block operates on the molecular graphs, passing messages between neighbors to learn node-level hidden representations.

Options are `mp = nn.BondMessagePassing()` or `mp = nn.AtomMessagePassing()`

In [24]:
mp = nn.BondMessagePassing()  # constructor defaults; nn.AtomMessagePassing() is the alternative

## Aggregation
An `Aggregation` is responsible for constructing a graph-level representation from the set of node-level representations after message passing.

Available options can be found in `nn.agg.AggregationRegistry`, including
- `agg = nn.MeanAggregation()`
- `agg = nn.SumAggregation()`
- `agg = nn.NormAggregation()`

In [25]:
agg = nn.MeanAggregation()  # alternatives: nn.SumAggregation(), nn.NormAggregation()

## Feed-Forward Network (FFN)

An `FFN` takes the aggregated representations and makes target predictions.

Available options can be found in `nn.PredictorRegistry`.

For regression:
- `ffn = nn.RegressionFFN()`
- `ffn = nn.MveFFN()`
- `ffn = nn.EvidentialFFN()`

For classification:
- `ffn = nn.BinaryClassificationFFN()`
- `ffn = nn.BinaryDirichletFFN()`
- `ffn = nn.MulticlassClassificationFFN()`
- `ffn = nn.MulticlassDirichletFFN()`

For spectral:
- `ffn = nn.SpectralFFN()` # will be available in future version

In [26]:
# FFN input width = message-passing output width + number of extra descriptor columns.
ffn_input_dim = mp.output_dim + extra_mol_descriptors.shape[1]
# Transform built from the target scaler.
# NOTE(review): presumably maps predictions back to the original target scale — confirm.
output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)

In [27]:
# Regression head sized to the combined input width, with the target
# output transform attached.
ffn = nn.RegressionFFN(
    n_layers=2,
    input_dim=ffn_input_dim,
    output_transform=output_transform,
    dropout=0.5,
)

## Batch Norm
A `Batch Norm` normalizes the outputs of the aggregation by re-centering and re-scaling.

Whether to use batch norm

In [28]:
batch_norm = True  # flag passed to models.MPNN when the model is constructed below

## Metrics
`Metrics` are the ways to evaluate the performance of model predictions.

In [29]:
print(nn.metrics.MetricRegistry)  # list the registered metric aliases and their classes

ClassRegistry {
    'mse': <class 'chemprop.nn.metrics.MSE'>,
    'mae': <class 'chemprop.nn.metrics.MAE'>,
    'rmse': <class 'chemprop.nn.metrics.RMSE'>,
    'bounded-mse': <class 'chemprop.nn.metrics.BoundedMSE'>,
    'bounded-mae': <class 'chemprop.nn.metrics.BoundedMAE'>,
    'bounded-rmse': <class 'chemprop.nn.metrics.BoundedRMSE'>,
    'r2': <class 'chemprop.nn.metrics.R2Score'>,
    'binary-mcc': <class 'chemprop.nn.metrics.BinaryMCCMetric'>,
    'multiclass-mcc': <class 'chemprop.nn.metrics.MulticlassMCCMetric'>,
    'roc': <class 'chemprop.nn.metrics.BinaryAUROC'>,
    'prc': <class 'chemprop.nn.metrics.BinaryAUPRC'>,
    'accuracy': <class 'chemprop.nn.metrics.BinaryAccuracy'>,
    'f1': <class 'chemprop.nn.metrics.BinaryF1Score'>
}


In [30]:
# Only the first metric is used for training and early stopping.
# NOTE(review): the checkpoint callback in this notebook monitors `val_loss` —
# confirm how that logged value relates to the entries of this list.
metric_list = [nn.metrics.RMSE(), nn.metrics.R2Score()]

## Construct the MPNN

In [31]:
# Wrap the extra-descriptor scaler as a transform that is handed to the model.
X_d_transform = nn.ScaleTransform.from_standard_scaler(extra_mol_descriptors_scaler)

# Assemble the model from the components configured in the cells above.
mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list, X_d_transform=X_d_transform)
mpnn  # bare last expression so the notebook prints the module tree

MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=406, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.5, inplace=False)
        (2): Linear(in_features=300, out_features=300, bias=True)
      )
      (2): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.5, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0]])
 

# Set up trainer

In [32]:
# Checkpoint callback: track `val_loss` and keep the checkpoint with the lowest value.
checkpointing = ModelCheckpoint(
    dirpath="model/checkpoints",  # directory that receives the checkpoint files
    filename="best-{epoch}-{val_loss:.2f}",  # name template: epoch number and validation loss
    monitor="val_loss",  # quantity compared across checkpoints
    mode="min",  # a lower monitored value is better
    save_last=True,  # also keep the most recent checkpoint, even when it is not the best
)

# Trainer configuration; the checkpoint callback above is registered via `callbacks`.
trainer = pl.Trainer(
    logger=True,
    enable_checkpointing=True,  # save checkpoints; files go to the callback's `model/checkpoints` directory
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=1,  # number of epochs to train for
    callbacks=[checkpointing],
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


# Start training

In [33]:
trainer.fit(mpnn, train_loader, val_loader)  # train on `train_loader`, validating with `val_loader`

Loading `train_dataloader` to estimate number of stepping batches.
/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.


/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/core/saving.py:365: Skipping 'metrics' parameter because it is not possible to safely dump to YAML.
`Trainer.fit` stopped: `max_epochs=1` reached.


# Test results

In [None]:
# Evaluate on the test loader. No model is passed explicitly; the captured
# output of this cell shows the best checkpoint being restored for the run.
results = trainer.test(dataloaders=test_loader, weights_only=False)  # weights_only=False is only required with pytorch lightning version 2.6.0 or newer

Restoring states from the checkpoint path at /Users/rossom/Desktop/Projects/nf10k/project_notebooks/training/model/checkpoints/best-epoch=0-val_loss=0.82.ckpt
Loaded model weights from the checkpoint at /Users/rossom/Desktop/Projects/nf10k/project_notebooks/training/model/checkpoints/best-epoch=0-val_loss=0.82.ckpt
/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.
