# Import packages

In [1]:
import pandas as pd

from lightning import pytorch as pl
from sklearn.model_selection import train_test_split
import torch

from chemprop import data
from chemprop import featurizers
from chemprop import models
from chemprop import nn
from chemprop.nn import metrics

# Change data inputs here

In [2]:
# Path to your data .csv file
input_path = '../tests/data/regression.csv'
# Number of workers for the dataloader; 0 means the main process does the data loading
num_workers = 0
# Name of the column containing the SMILES strings
smiles_column = 'smiles'
# List of names of the columns containing the targets
target_columns = ['logSolubility']

## Load data

In [3]:
# Read the CSV at `input_path` into a DataFrame.
df_input = pd.read_csv(input_path)
# Bare last expression so the notebook renders the table.
df_input

Unnamed: 0,smiles,logSolubility
0,OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)...,-0.770
1,Cc1occc1C(=O)Nc2ccccc2,-3.300
2,CC(C)=CCCC(C)=CC(=O),-2.060
3,c1ccc2c(c1)ccc3c2ccc4c5ccccc5ccc43,-7.870
4,c1ccsc1,-1.330
...,...,...
495,Nc1cc(nc(N)n1=O)N2CCCCC2,-1.989
496,Nc2cccc3nc1ccccc1cc23,-4.220
497,c1ccc2cc3c4cccc5cccc(c3cc2c1)c45,-8.490
498,OC(c1ccc(Cl)cc1)(c2ccc(Cl)cc2)C(Cl)(Cl)Cl,-5.666


## Get SMILES and targets

In [4]:
# Pull the SMILES strings and the target values out of the DataFrame as NumPy arrays.
# `to_numpy()` is used instead of `.values`: pandas recommends it, and it returns an
# ndarray regardless of the column dtype.
smis = df_input.loc[:, smiles_column].to_numpy()  # 1-D: one SMILES string per row
ys = df_input.loc[:, target_columns].to_numpy()  # 2-D: one row per molecule, one column per target

In [5]:
# Show the first 5 SMILES strings
smis[:5]

array(['OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)C(O)C3O',
       'Cc1occc1C(=O)Nc2ccccc2', 'CC(C)=CCCC(C)=CC(=O)',
       'c1ccc2c(c1)ccc3c2ccc4c5ccccc5ccc43', 'c1ccsc1'], dtype=object)

In [6]:
# Show the first 5 targets (one row per molecule)
ys[:5]

array([[-0.77],
       [-3.3 ],
       [-2.06],
       [-7.87],
       [-1.33]])

## Get molecule datapoints

In [7]:
# Build one MoleculeDatapoint per (SMILES, target row) pair; `zip` pairs them positionally.
all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]

## Perform data splitting for training, validation, and testing

In [8]:
# Hold out 10% of the data, then split that hold-out in half:
# roughly 90% train / 5% validation / 5% test.
# A fixed `random_state` makes the split identical on every "Restart & Run All",
# so the metrics reported further down are comparable between runs.
split_seed = 0
train_data, val_test_data = train_test_split(all_data, test_size=0.1, random_state=split_seed)
val_data, test_data = train_test_split(val_test_data, test_size=0.5, random_state=split_seed)

## Get MoleculeDataset

In [9]:
# Featurizer passed to every MoleculeDataset below.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

train_dset = data.MoleculeDataset(train_data, featurizer)
# Called without arguments on the training set; the returned scaler is reused below
# and later supplies `mean_` / `scale_` to the regression FFN.
scaler = train_dset.normalize_targets()

val_dset = data.MoleculeDataset(val_data, featurizer)
# Reuse the training scaler for the validation targets instead of fitting a new one.
val_dset.normalize_targets(scaler)
test_dset = data.MoleculeDataset(test_data, featurizer)
# NOTE(review): the FFN is also given this scaler's mean/scale. If the model un-scales its
# predictions at evaluation time, test metrics would compare un-scaled predictions with
# these scaled targets — confirm against the chemprop version in use (the same question
# applies to the validation targets above).
test_dset.normalize_targets(scaler)


## Get DataLoader

In [10]:
# The training loader does not pass `shuffle`, so it uses MolGraphDataLoader's default;
# the validation and test loaders explicitly disable shuffling.
train_loader = data.MolGraphDataLoader(train_dset, num_workers=num_workers)
val_loader = data.MolGraphDataLoader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.MolGraphDataLoader(test_dset, num_workers=num_workers, shuffle=False)

# Change Message-Passing Neural Network (MPNN) inputs here

## Message Passing
A `Message passing` block operates on molecular graphs, using message passing to learn node-level hidden representations.

Options are `mp = nn.BondMessagePassing()` or `mp = nn.AtomMessagePassing()`

In [11]:
# Bond-based message passing with default arguments; `nn.AtomMessagePassing()` is the other option.
mp = nn.BondMessagePassing()

## Aggregation
An `Aggregation` is responsible for constructing a graph-level representation from the set of node-level representations after message passing.

Available options can be found in `nn.agg.AggregationRegistry`, including
- `agg = nn.MeanAggregation()`
- `agg = nn.SumAggregation()`
- `agg = nn.NormAggregation()`

In [12]:
# List the registered aggregation classes and their registry keys.
print(nn.agg.AggregationRegistry)

ClassRegistry {
    'mean': <class 'chemprop.nn.agg.MeanAggregation'>,
    'sum': <class 'chemprop.nn.agg.SumAggregation'>,
    'norm': <class 'chemprop.nn.agg.NormAggregation'>
}


In [13]:
# Mean aggregation; `nn.SumAggregation()` and `nn.NormAggregation()` are the alternatives.
agg = nn.MeanAggregation()

## Feed-Forward Network (FFN)

An `FFN` takes the aggregated representations and makes target predictions.

Available options can be found in `nn.PredictorRegistry`.

For regression:
- `ffn = nn.RegressionFFN()`
- `ffn = nn.MveFFN()`
- `ffn = nn.EvidentialFFN()`

For classification:
- `ffn = nn.BinaryClassificationFFN()`
- `ffn = nn.BinaryDirichletFFN()`
- `ffn = nn.MulticlassClassificationFFN()`
- `ffn = nn.MulticlassDirichletFFN()`

For spectral:
- `ffn = nn.SpectralFFN()` # will be available in a future version

In [14]:
# List the registered predictor (FFN) classes and their registry keys.
print(nn.PredictorRegistry)

ClassRegistry {
    'regression': <class 'chemprop.nn.predictors.RegressionFFN'>,
    'regression-mve': <class 'chemprop.nn.predictors.MveFFN'>,
    'regression-evidential': <class 'chemprop.nn.predictors.EvidentialFFN'>,
    'classification': <class 'chemprop.nn.predictors.BinaryClassificationFFN'>,
    'classification-dirichlet': <class 'chemprop.nn.predictors.BinaryDirichletFFN'>,
    'multiclass': <class 'chemprop.nn.predictors.MulticlassClassificationFFN'>,
    'multiclass-dirichlet': <class 'chemprop.nn.predictors.MulticlassDirichletFFN'>,
    'spectral': <class 'chemprop.nn.predictors.SpectralFFN'>
}


In [15]:
# Regression head; `loc` / `scale` hand the FFN the statistics held by `scaler`,
# which was returned when the training targets were normalized.
ffn = nn.RegressionFFN(
    loc=scaler.mean_, # pass in the mean of the training targets
    scale=scaler.scale_, # pass in the scale of the training targets
)

## Batch Norm
A `Batch Norm` normalizes the outputs of the aggregation by re-centering and re-scaling.

Set `batch_norm` below to choose whether to use batch norm.

In [16]:
# Flag passed to the MPNN constructor to enable batch norm.
batch_norm = True

## Metrics
`Metrics` are the ways to evaluate the performance of model predictions.

Available options can be found in `metrics.MetricRegistry`, listed below:

In [17]:
# List the registered metric classes and their registry keys.
print(metrics.MetricRegistry)

ClassRegistry {
    'mae': <class 'chemprop.nn.metrics.MAEMetric'>,
    'mse': <class 'chemprop.nn.metrics.MSEMetric'>,
    'rmse': <class 'chemprop.nn.metrics.RMSEMetric'>,
    'bounded-mae': <class 'chemprop.nn.metrics.BoundedMAEMetric'>,
    'bounded-mse': <class 'chemprop.nn.metrics.BoundedMSEMetric'>,
    'bounded-rmse': <class 'chemprop.nn.metrics.BoundedRMSEMetric'>,
    'r2': <class 'chemprop.nn.metrics.R2Metric'>,
    'roc': <class 'chemprop.nn.metrics.AUROCMetric'>,
    'prc': <class 'chemprop.nn.metrics.AUPRCMetric'>,
    'accuracy': <class 'chemprop.nn.metrics.AccuracyMetric'>,
    'f1': <class 'chemprop.nn.metrics.F1Metric'>,
    'bce': <class 'chemprop.nn.metrics.BCEMetric'>,
    'ce': <class 'chemprop.nn.metrics.CrossEntropyMetric'>,
    'binary-mcc': <class 'chemprop.nn.metrics.BinaryMCCMetric'>,
    'multiclass-mcc': <class 'chemprop.nn.metrics.MulticlassMCCMetric'>,
    'sid': <class 'chemprop.nn.metrics.SIDMetric'>,
    'wasserstein': <class 'chemprop.nn.metrics.Wass

In [18]:
# Metrics reported for the model, in order: RMSE first, then MAE.
metric_list = [metrics.RMSEMetric(), metrics.MAEMetric()] # Only the first metric is used for training and early stopping

## Construct the MPNN

In [19]:
# Assemble the model from the components chosen above: message passing, aggregation,
# FFN predictor, the batch-norm flag and the metric list (passed positionally, in that order).
mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list)

# Bare last expression so the notebook displays the module structure.
mpnn

MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=147, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=433, out_features=300, bias=True)
    (dropout): Dropout(p=0, inplace=False)
    (tau): ReLU()
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Linear(in_features=300, out_features=300, bias=True)
      (1): ReLU()
      (2): Dropout(p=0, inplace=False)
      (3): Linear(in_features=300, out_features=1, bias=True)
    )
  )
)

# Set up trainer

In [20]:
# Lightning trainer: no logger, no checkpoints, progress bar on, a single device with the
# accelerator choice left to Lightning ("auto").
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=False, # Use `True` if you want to save model checkpoints
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=20, # number of epochs to train for
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


# Start training

In [21]:
# Train `mpnn` on the training loader, validating against the validation loader.
# NOTE(review): the recorded output below reports `max_epochs=2`, whereas the Trainer cell
# sets max_epochs=20 — the saved output looks stale; re-run the notebook to refresh it.
trainer.fit(mpnn, train_loader, val_loader)
# Write the model's state dict to "model.pt" (a relative path, so it lands in the working directory).
torch.save(mpnn.state_dict(), "model.pt")

Loading `train_dataloader` to estimate number of stepping batches.
/Users/kevingreenman/miniconda3/envs/chemprop-v2/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
  gammas = (self.final_lrs / max_lrs) ** (1 / cooldown_steps)

  | Name            | Type               | Params
-------------------------------------------------------
0 | message_passing | BondMessagePassing | 264 K 
1 | agg             | MeanAggregation    | 0     
2 | bn              | BatchNorm1d        | 600   
3 | predictor       | RegressionFFN      | 90.6 K
  | other params    | n/a                | 1     
-------------------------------------------------------
355 K     Trainable params
1         Non-trainable params
355 K     Total params
1.422     Total estimated model params

                                                                           

/Users/kevingreenman/miniconda3/envs/chemprop-v2/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Epoch 1: 100%|██████████| 9/9 [00:00<00:00, 11.95it/s, train/loss=0.376, val_loss=3.340]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 9/9 [00:00<00:00, 11.93it/s, train/loss=0.376, val_loss=3.340]


# Test results

In [22]:
# Evaluate the model on the test loader; `results` keeps whatever `trainer.test` returns.
results = trainer.test(mpnn, test_loader)


/Users/kevingreenman/miniconda3/envs/chemprop-v2/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 55.16it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test/mae            3.2070827051509774
        test/rmse           3.2070827051509774
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
