# Import packages

In [2]:
import pandas as pd

from lightning import pytorch as pl
from sklearn.model_selection import train_test_split

from chemprop.v2 import data
from chemprop.v2 import featurizers
from chemprop.v2.models import modules, models, metrics

# Change data inputs here

In [3]:
# Path to the .csv file holding the data
input_path = "../tests/data/regression.csv"
# Number of dataloader workers; 0 means data is loaded in the main process
num_workers = 0
# Name of the column containing the SMILES strings
smiles_column = "smiles"
# Names of the columns containing the targets
target_columns = ["logSolubility"]

## Load data

In [4]:
# Read the CSV at `input_path` into a DataFrame.
df_input = pd.read_csv(input_path)
# Bare last expression so the notebook renders the frame as the cell output.
df_input

Unnamed: 0,smiles,logSolubility
0,OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)...,-0.770
1,Cc1occc1C(=O)Nc2ccccc2,-3.300
2,CC(C)=CCCC(C)=CC(=O),-2.060
3,c1ccc2c(c1)ccc3c2ccc4c5ccccc5ccc43,-7.870
4,c1ccsc1,-1.330
...,...,...
495,Nc1cc(nc(N)n1=O)N2CCCCC2,-1.989
496,Nc2cccc3nc1ccccc1cc23,-4.220
497,c1ccc2cc3c4cccc5cccc(c3cc2c1)c45,-8.490
498,OC(c1ccc(Cl)cc1)(c2ccc(Cl)cc2)C(Cl)(Cl)Cl,-5.666


## Get smiles and targets

In [5]:
# Pull the SMILES column and the target column(s) out of the frame as arrays.
# Selecting with a list of names keeps `ys` two-dimensional (one column per target).
smis = df_input[smiles_column].values
ys = df_input[target_columns].values

In [6]:
# Preview only a small slice rather than the whole array.
smis[:5] # show first 5 SMILES strings

array(['OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)C(O)C3O',
       'Cc1occc1C(=O)Nc2ccccc2', 'CC(C)=CCCC(C)=CC(=O)',
       'c1ccc2c(c1)ccc3c2ccc4c5ccccc5ccc43', 'c1ccsc1'], dtype=object)

In [7]:
# Preview only a small slice; each row holds the target value(s) for one molecule.
ys[:5] # show first 5 targets

array([[-0.77],
       [-3.3 ],
       [-2.06],
       [-7.87],
       [-1.33]])

## Get molecule datapoints

In [8]:
# Build one MoleculeDatapoint per (SMILES string, target row) pair.
all_data = [
    data.MoleculeDatapoint.from_smi(smiles, targets)
    for smiles, targets in zip(smis, ys)
]

## Perform data splitting for training, validation, and testing

In [9]:
# Split the datapoints 90% / 5% / 5% into train / validation / test:
# first hold out 10% of the data, then cut that held-out part in half.
# A fixed `random_state` makes the shuffled split reproducible across
# "Restart Kernel -> Run All"; change `split_seed` to draw a different split.
split_seed = 0
train_data, val_test_data = train_test_split(all_data, test_size=0.1, random_state=split_seed)
val_data, test_data = train_test_split(val_test_data, test_size=0.5, random_state=split_seed)

## Get MoleculeDataset

In [10]:
# One featurizer instance is shared by the train, validation, and test datasets.
featurizer = featurizers.MoleculeMolGraphFeaturizer()

train_dset = data.MoleculeDataset(train_data, featurizer)
# Called with no argument, normalize_targets() returns a scaler; its `mean_` and
# `scale_` are read later when the FFN is built.
# NOTE(review): presumably the scaler is fitted on the training targets only — confirm.
scaler = train_dset.normalize_targets()

# Validation and test datasets reuse the scaler obtained from the training set
# instead of creating their own.
val_dset = data.MoleculeDataset(val_data, featurizer)
val_dset.normalize_targets(scaler)
test_dset = data.MoleculeDataset(test_data, featurizer)
test_dset.normalize_targets(scaler)


## Get DataLoader

In [11]:
# One loader per split, all using the `num_workers` setting from the config cell.
# Shuffling is explicitly disabled for validation and test; the training loader
# leaves `shuffle` at the library default.
train_loader = data.MolGraphDataLoader(train_dset, num_workers=num_workers)
val_loader = data.MolGraphDataLoader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.MolGraphDataLoader(test_dset, num_workers=num_workers, shuffle=False)

# Change Message-Passing Neural Network (MPNN) inputs here

## Message Passing
A `Message Passing` block performs message passing over the molecular graph to learn node-level hidden representations.

Options are `mp = modules.BondMessageBlock()` or `mp = modules.AtomMessageBlock()`

In [12]:
# Message-passing block, constructed with its default arguments.
mp = modules.BondMessageBlock()

## Aggregation
An `Aggregation` is responsible for constructing a graph-level representation from the set of node-level representations after message passing.

Available options can be found in `modules.agg.AggregationRegistry`, including
- `agg = modules.MeanAggregation()`
- `agg = modules.SumAggregation()`
- `agg = modules.NormAggregation()`

In [13]:
# List the registered aggregation classes and their registry keys.
print(modules.agg.AggregationRegistry)

ClassRegistry {
    'mean': <class 'chemprop.v2.models.modules.agg.MeanAggregation'>,
    'sum': <class 'chemprop.v2.models.modules.agg.SumAggregation'>,
    'norm': <class 'chemprop.v2.models.modules.agg.NormAggregation'>
}


In [14]:
# Aggregation module, constructed with its default arguments.
agg = modules.MeanAggregation()

## Feed-Forward Network (FFN)

An `FFN` takes the aggregated representations and makes target predictions.

Available options can be found in `modules.ReadoutRegistry`.

For regression:
- `ffn = modules.RegressionFFN()`
- `ffn = modules.MveFFN()`
- `ffn = modules.EvidentialFFN()`

For classification:
- `ffn = modules.BinaryClassificationFFN()`
- `ffn = modules.BinaryDirichletFFN()`
- `ffn = modules.MulticlassClassificationFFN()`
- `ffn = modules.MulticlassDirichletFFN()`

For spectral:
- `ffn = modules.SpectralFFN()` # will be available in future version

In [15]:
# List the registered readout (FFN) classes and their registry keys.
print(modules.ReadoutRegistry)

ClassRegistry {
    'regression': <class 'chemprop.v2.models.modules.readout.RegressionFFN'>,
    'regression-mve': <class 'chemprop.v2.models.modules.readout.MveFFN'>,
    'regression-evidential': <class 'chemprop.v2.models.modules.readout.EvidentialFFN'>,
    'classification': <class 'chemprop.v2.models.modules.readout.BinaryClassificationFFN'>,
    'classification-dirichlet': <class 'chemprop.v2.models.modules.readout.BinaryDirichletFFN'>,
    'multiclass': <class 'chemprop.v2.models.modules.readout.MulticlassClassificationFFN'>,
    'multiclass-dirichlet': <class 'chemprop.v2.models.modules.readout.MulticlassDirichletFFN'>,
    'spectral': <class 'chemprop.v2.models.modules.readout.SpectralFFN'>
}


In [16]:
# Regression head; `loc` and `scale` are taken from the scaler returned by
# `train_dset.normalize_targets()`.
# NOTE(review): presumably used to map predictions back to the original target scale — confirm.
ffn = modules.RegressionFFN(
    loc=scaler.mean_, # pass in the mean of the training targets
    scale=scaler.scale_, # pass in the scale of the training targets
)

## Batch Norm
A `Batch Norm` normalizes the outputs of the aggregation by re-centering and re-scaling.

Set `batch_norm = True` to use batch norm, or `False` to skip it.

In [17]:
# Flag passed to the MPNN constructor to switch batch norm on.
batch_norm = True

## Metrics
`Metrics` are the ways to evaluate the performance of model predictions.

Available options can be found in `metrics.MetricRegistry`:

In [18]:
# List the registered metric classes and their registry keys.
print(metrics.MetricRegistry)

ClassRegistry {
    'mae': <class 'chemprop.v2.models.metrics.MAEMetric'>,
    'mse': <class 'chemprop.v2.models.metrics.MSEMetric'>,
    'rmse': <class 'chemprop.v2.models.metrics.RMSEMetric'>,
    'bounded-mae': <class 'chemprop.v2.models.metrics.BoundedMAEMetric'>,
    'bounded-mse': <class 'chemprop.v2.models.metrics.BoundedMSEMetric'>,
    'bounded-rmse': <class 'chemprop.v2.models.metrics.BoundedRMSEMetric'>,
    'r2': <class 'chemprop.v2.models.metrics.R2Metric'>,
    'roc': <class 'chemprop.v2.models.metrics.AUROCMetric'>,
    'prc': <class 'chemprop.v2.models.metrics.AUPRCMetric'>,
    'accuracy': <class 'chemprop.v2.models.metrics.AccuracyMetric'>,
    'f1': <class 'chemprop.v2.models.metrics.F1Metric'>,
    'bce': <class 'chemprop.v2.models.metrics.BCEMetric'>,
    'ce': <class 'chemprop.v2.models.metrics.CrossEntropyMetric'>,
    'mcc': <class 'chemprop.v2.models.metrics.MCCMetric'>,
    'sid': <class 'chemprop.v2.models.metrics.SIDMetric'>,
    'wasserstein': <class 'chemp

In [19]:
# Metrics to evaluate the model with: RMSE first, then MAE.
metric_list = [metrics.RMSEMetric(), metrics.MAEMetric()] # Only the first metric is used for training and early stopping

## Construct MPNN

In [20]:
# Assemble the model from the pieces chosen above: message passing, aggregation,
# FFN, the batch-norm flag, and the metric list (passed positionally in that order).
mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list)

# Bare last expression: display the module structure as the cell output.
mpnn

MPNN(
  (message_passing): BondMessageBlock(
    (W_i): Linear(in_features=147, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=433, out_features=300, bias=True)
    (dropout): Dropout(p=0, inplace=False)
    (tau): ReLU()
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (readout): RegressionFFN(
    (ffn): SimpleFFN(
      (ffn): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
        (1): ReLU()
        (2): Dropout(p=0, inplace=False)
        (3): Linear(in_features=300, out_features=1, bias=True)
      )
    )
  )
)

# Set up trainer

In [21]:
# Configure the Lightning trainer: no logger and no checkpoints, a visible progress
# bar, automatic accelerator selection on a single device, and a fixed epoch budget.
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=False, # Use `True` if you want to save model checkpoints
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=20, # number of epochs to train for
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


# Start training

In [22]:
# Train the model on `train_loader`, validating on `val_loader`.
trainer.fit(mpnn, train_loader, val_loader)

Loading `train_dataloader` to estimate number of stepping batches.
  rank_zero_warn(

  | Name            | Type             | Params
-----------------------------------------------------
0 | message_passing | BondMessageBlock | 264 K 
1 | agg             | MeanAggregation  | 0     
2 | bn              | BatchNorm1d      | 600   
3 | readout         | RegressionFFN    | 90.6 K
-----------------------------------------------------
355 K     Trainable params
3         Non-trainable params
355 K     Total params
1.422     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

  rank_zero_warn(


Epoch 19: 100%|██████████| 9/9 [00:01<00:00,  4.87it/s, train/loss=0.0736, val_loss=3.010]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 9/9 [00:01<00:00,  4.86it/s, train/loss=0.0736, val_loss=3.010]


# Test results

In [23]:
# Evaluate the trained model on the test loader and keep the returned value in `results`.
results = trainer.test(mpnn, test_loader)


  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 88.49it/s] 
