# Training Regression - Multicomponent

In [1]:
import pandas as pd
from lightning import pytorch as pl
from pathlib import Path

from chemprop import data
from chemprop import featurizers
from chemprop import models
from chemprop import nn
from chemprop.nn import metrics
from chemprop.models import multi


# Load data

## Change your data inputs here

In [2]:
chemprop_dir = Path.cwd().parent
input_path = chemprop_dir / "tests/data/regression/mol+mol.csv" # path to your data .csv file containing SMILES strings and target values
smiles_columns = ['smiles', 'solvent'] # name of the column containing SMILES strings
target_columns = ['peakwavs_max'] # list of names of the columns containing targets

## Read data

In [3]:
df_input = pd.read_csv(input_path)
df_input

Unnamed: 0,smiles,solvent,peakwavs_max
0,CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C...,ClCCl,642.0
1,C(=C/c1cnccn1)\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c...,ClCCl,420.0
2,CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]...,O,544.0
3,c1ccc2[nH]ccc2c1,O,290.0
4,CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c...,ClC(Cl)Cl,736.0
...,...,...,...
95,COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)...,C1CCOC1,359.0
96,COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc...,C1CCCCC1,386.0
97,CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=O,CCO,425.0
98,Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)...,c1ccccc1,324.0


## Get SMILES and targets

In [4]:
smiss = df_input.loc[:, smiles_columns].values
ys = df_input.loc[:, target_columns].values

In [5]:
# Take a look at the first 5 SMILES strings and targets
smiss[:5], ys[:5]

(array([['CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2CCCC)C(=O)N(CCCC)C1=S',
         'ClCCl'],
        ['C(=C/c1cnccn1)\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3cnccn3)cc2)cc1',
         'ClCCl'],
        ['CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+](C)C)cc-3oc2c1',
         'O'],
        ['c1ccc2[nH]ccc2c1', 'O'],
        ['CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5ccccc5c4C3(C)C)CCCC1=C2c1ccccc1C(=O)O',
         'ClC(Cl)Cl']], dtype=object),
 array([[642.],
        [420.],
        [544.],
        [290.],
        [736.]]))

## Make molecule datapoints
Create a list of lists containing the molecule datapoints for each components. The target is stored in the 0th component.

In [6]:
all_data = [[data.MoleculeDatapoint.from_smi(smis[0], y) for smis, y in zip(smiss, ys)]]
all_data += [[data.MoleculeDatapoint.from_smi(smis[i]) for smis in smiss] for i in range(1, len(smiles_columns))]


# Split data

## Change your data splitting inputs here

In [7]:
split_key_molecule_index = 0 # key molecule used for splitting
split = 'random' # type of split
sizes = (0.8, 0.1, 0.1) # sizes of train, validation, and test sets

In [8]:
# available split types
list(data.SplitType.keys())

['CV_NO_VAL',
 'CV',
 'SCAFFOLD_BALANCED',
 'RANDOM_WITH_REPEATED_SMILES',
 'RANDOM',
 'KENNARD_STONE',
 'KMEANS']

## Split data based on key molecule

In [9]:
train_data, val_data, test_data = data.split_component(all_data, split=split, sizes=sizes, key_index=split_key_molecule_index)


# Get MoleculeDataset for each components

In [10]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

train_datasets = [data.MoleculeDataset(train_data[i], featurizer) for i in range(len(smiles_columns))]
val_datasets = [data.MoleculeDataset(val_data[i], featurizer) for i in range(len(smiles_columns))]
test_datasets = [data.MoleculeDataset(test_data[i], featurizer) for i in range(len(smiles_columns))]

# Construct multicomponent dataset and scale the targets

In [11]:
train_mcdset = data.MulticomponentDataset(train_datasets)
scaler = train_mcdset.normalize_targets()
val_mcdset = data.MulticomponentDataset(val_datasets)
val_mcdset.normalize_targets(scaler)
test_mcdset = data.MulticomponentDataset(test_datasets)
test_mcdset.normalize_targets(scaler)


# Construct data loader

In [12]:
train_loader = data.MolGraphDataLoader(train_mcdset)
val_loader = data.MolGraphDataLoader(val_mcdset, shuffle=False)
test_loader = data.MolGraphDataLoader(test_mcdset, shuffle=False)

# Construct multicomponent MPNN

## MulticomponentMessagePassing
- `blocks`: a list of message passing block used for each components
- `n_components`: number of components

In [13]:
mcmp = nn.MulticomponentMessagePassing(
    blocks=[nn.BondMessagePassing() for i in range(len(smiles_columns))],
    n_components=len(smiles_columns),
)

## Aggregation

In [14]:
agg = nn.MeanAggregation()

## RegressionFFN

In [15]:
ffn = nn.RegressionFFN(
    loc=scaler.mean_, # pass in the mean of the training targets
    scale=scaler.scale_, # pass in the scale of the training targets
    input_dim=mcmp.output_dim,
)

## Metrics

In [16]:
metric_list = [metrics.RMSEMetric(), metrics.MAEMetric()] # Only the first metric is used for training and early stopping

## MulticomponentMPNN

In [17]:
mcmpnn = multi.MulticomponentMPNN(
    mcmp,
    agg,
    ffn,
    metrics=metric_list,
)

mcmpnn

MulticomponentMPNN(
  (message_passing): MulticomponentMessagePassing(
    (blocks): ModuleList(
      (0-1): 2 x BondMessagePassing(
        (W_i): Linear(in_features=147, out_features=300, bias=False)
        (W_h): Linear(in_features=300, out_features=300, bias=False)
        (W_o): Linear(in_features=433, out_features=300, bias=True)
        (dropout): Dropout(p=0, inplace=False)
        (tau): ReLU()
      )
    )
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Linear(in_features=600, out_features=300, bias=True)
      (1): ReLU()
      (2): Dropout(p=0, inplace=False)
      (3): Linear(in_features=300, out_features=1, bias=True)
    )
  )
)

# Set up trainer

In [18]:
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=20, # number of epochs to train for
)

/home/gridsan/adoner/mambaforge/envs/chemprop/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/gridsan/adoner/mambaforge/envs/chemprop/lib/py ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


# Start training

In [19]:
trainer.fit(mcmpnn, train_loader, val_loader)

Loading `train_dataloader` to estimate number of stepping batches.
/home/gridsan/adoner/mambaforge/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.

  | Name            | Type                         | Params
-----------------------------------------------------------------
0 | message_passing | MulticomponentMessagePassing | 528 K 
1 | agg             | MeanAggregation              | 0     
2 | bn              | BatchNorm1d                  | 1.2 K 
3 | predictor       | RegressionFFN                | 180 K 
  | other params    | n/a                          | 1     
-----------------------------------------------------------------
710 K     Trainable params
1         Non-trainable params
710 K     Total params
2.842     Total estima

Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

/home/gridsan/adoner/mambaforge/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.


Epoch 19: 100%|██████████| 2/2 [00:04<00:00,  0.49it/s, train/loss=0.0605, val_loss=443.0]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 2/2 [00:04<00:00,  0.48it/s, train/loss=0.0605, val_loss=443.0]


# Test results

In [20]:
results = trainer.test(mcmpnn, test_loader)


/home/gridsan/adoner/mambaforge/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 19.80it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test/mae            440.77487709353835
        test/rmse            441.3444630218094
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
