# Training Regression - Multicomponent

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/training_regression_multicomponent.ipynb)

In [None]:
# Install chemprop from GitHub if running in Google Colab
import os

# Only attempt an install when the COLAB_RELEASE_TAG environment variable is set.
if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop  # already importable: skip the install
    except ImportError:
        # Clone the repository, install it from source, and finish inside
        # `examples/` (a later cell locates the data via Path.cwd().parent).
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        %cd examples

In [1]:
import pandas as pd
from lightning import pytorch as pl
from pathlib import Path

from chemprop import data, featurizers, models, nn
from chemprop.nn import metrics
from chemprop.models import multi


# Load data

## Change your data inputs here

In [2]:
# Assumes the notebook is run from the repository's `examples/` directory, so that
# the parent of the working directory is the chemprop repository root.
chemprop_dir = Path.cwd().parent
input_path = chemprop_dir / "tests" / "data" / "regression" / "mol+mol" / "mol+mol.csv" # path to your data .csv file containing SMILES strings and target values
smiles_columns = ['smiles', 'solvent'] # list of names of the columns containing SMILES strings (one column per component)
target_columns = ['peakwavs_max'] # list of names of the columns containing targets

## Read data

In [3]:
# Load the CSV into a DataFrame and display it for a quick visual check
df_input = pd.read_csv(input_path)
df_input

Unnamed: 0,smiles,solvent,peakwavs_max
0,CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C...,ClCCl,642.0
1,C(=C/c1cnccn1)\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c...,ClCCl,420.0
2,CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]...,O,544.0
3,c1ccc2[nH]ccc2c1,O,290.0
4,CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c...,ClC(Cl)Cl,736.0
...,...,...,...
95,COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)...,C1CCOC1,359.0
96,COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc...,C1CCCCC1,386.0
97,CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=O,CCO,425.0
98,Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)...,c1ccccc1,324.0


## Get SMILES and targets

In [4]:
# Pull the SMILES columns and the target columns out as plain NumPy arrays:
# one row per datapoint, one column per entry of the respective column list.
smiss = df_input[smiles_columns].to_numpy()
ys = df_input[target_columns].to_numpy()

In [24]:
# Look up every row of df_input whose 'smiles' entry equals one specific molecule
# (the same SMILES can appear several times, paired with different solvents).
query_smiles = '[O-]c1c(-c2ccccc2)cc(-[n+]2c(-c3ccccc3)cc(-c3ccccc3)cc2-c2ccccc2)cc1-c1ccccc1'
is_query = df_input['smiles'] == query_smiles

molecule_row = df_input.loc[is_query]
molecule_row

Unnamed: 0,smiles,solvent,peakwavs_max
29,[O-]c1c(-c2ccccc2)cc(-[n+]2c(-c3ccccc3)cc(-c3c...,ClC(Cl)Cl,731.227622
74,[O-]c1c(-c2ccccc2)cc(-[n+]2c(-c3ccccc3)cc(-c3c...,CCCCCCCCCC,922.290323
86,[O-]c1c(-c2ccccc2)cc(-[n+]2c(-c3ccccc3)cc(-c3c...,CC(C)(C)c1cccc(C(C)(C)C)n1,840.911765


In [5]:
# Take a look at the first 5 SMILES strings and targets
# (each row of `smiss` holds one SMILES per entry of `smiles_columns`)
smiss[:5], ys[:5]

(array([['CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2CCCC)C(=O)N(CCCC)C1=S',
         'ClCCl'],
        ['C(=C/c1cnccn1)\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3cnccn3)cc2)cc1',
         'ClCCl'],
        ['CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+](C)C)cc-3oc2c1',
         'O'],
        ['c1ccc2[nH]ccc2c1', 'O'],
        ['CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5ccccc5c4C3(C)C)CCCC1=C2c1ccccc1C(=O)O',
         'ClC(Cl)Cl']], dtype=object),
 array([[642.],
        [420.],
        [544.],
        [290.],
        [736.]]))

## Make molecule datapoints
Create a list of lists containing the molecule datapoints for each component. The target is stored in the 0th component.

In [7]:
# Component 0: one datapoint per row, built from the first SMILES column and
# carrying that row's target values.
target_component = [
    data.MoleculeDatapoint.from_smi(row_smis[0], row_y) for row_smis, row_y in zip(smiss, ys)
]
all_data = [target_component]

# Remaining components: datapoints built from the SMILES alone (no targets).
for component_idx in range(1, len(smiles_columns)):
    all_data.append(
        [data.MoleculeDatapoint.from_smi(row_smis[component_idx]) for row_smis in smiss]
    )


In [18]:
# Number of components: `all_data` holds one inner list of datapoints per SMILES column
len(all_data)

2

# Split data

## Perform data splitting for training, validation, and testing

In [8]:
component_to_split_by = 0 # index of the component to use for structure based splits
# Molecule objects of the chosen component; these are what the splitter receives.
mols = [d.mol for d in all_data[component_to_split_by]]
# "random" split with sizes (0.8, 0.1, 0.1), unpacked as train / validation / test indices.
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1))
# NOTE(review): this cell's recorded output warns that the return type of
# make_split_indices changed in v2.1, and later cells index the results below
# with a leading [0] - confirm against help(make_split_indices) for your version.
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)

The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)


In [19]:
# Length of the outermost level of the split result (later cells index it with [0])
len(train_data)

1

In [25]:
# Inspect component 0 of the training data (the component whose datapoints carry the targets).
# NOTE(review): the leading [0] presumably selects the first (and only) split replicate - confirm.
train_data[0][0]

[MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7b9a4a48b8b0>, y=array([840.9117647]), weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='[O-]c1c(-c2ccccc2)cc(-[n+]2c(-c3ccccc3)cc(-c3ccccc3)cc2-c2ccccc2)cc1-c1ccccc1', V_f=None, E_f=None, V_d=None),
 MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7b9a4a70e880>, y=array([544.]), weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+](C)C)cc-3oc2c1', V_f=None, E_f=None, V_d=None),
 MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7b9a4a489d90>, y=array([366.2]), weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='CCCCCCC(CCCCCC)N1C(=O)c2cccc3c(-c4ccc(OC)cc4)ccc(c23)C1=O', V_f=None, E_f=None, V_d=None),
 MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7b9a4a48af10>, y=array([461.]), weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='CCCCCCCCOc1ccc(C#Cc2ccc(C#Cc3ccc(OCCCCCCCC)c(OCCCCCCCC)c

In [21]:
# First training datapoint of component 1, i.e. built from the second SMILES column
# without a target value.
train_data[0][1][0]

MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7b9a4a4790e0>, y=None, weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='CC(C)(C)c1cccc(C(C)(C)C)n1', V_f=None, E_f=None, V_d=None)

# Get MoleculeDataset for each component

In [20]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()


def make_component_datasets(split):
    """Wrap each component of ``split[0]`` in a ``MoleculeDataset`` using the shared featurizer."""
    return [
        data.MoleculeDataset(split[0][component_idx], featurizer)
        for component_idx in range(len(smiles_columns))
    ]


# One list of per-component datasets for each of the three splits.
train_datasets = make_component_datasets(train_data)
val_datasets = make_component_datasets(val_data)
test_datasets = make_component_datasets(test_data)

In [None]:
# Display the list of per-component training datasets
train_datasets

# Construct multicomponent dataset and scale the targets

In [9]:
train_mcdset = data.MulticomponentDataset(train_datasets)
# Normalize the training targets and keep the returned scaler for reuse below.
scaler = train_mcdset.normalize_targets()
val_mcdset = data.MulticomponentDataset(val_datasets)
# Apply the scaler obtained from the training set to the validation targets.
val_mcdset.normalize_targets(scaler)
# The test targets are not normalized here.
# NOTE(review): presumably because the model's output transform (built from this
# scaler in a later cell) maps predictions back to the original units - confirm.
test_mcdset = data.MulticomponentDataset(test_datasets)


# Construct data loader

In [10]:
# One loader per split; shuffling is explicitly turned off for validation and test.
train_loader = data.build_dataloader(train_mcdset)
val_loader = data.build_dataloader(val_mcdset, shuffle=False)
test_loader = data.build_dataloader(test_mcdset, shuffle=False)

# Construct multicomponent MPNN

## MulticomponentMessagePassing
- `blocks`: a list of message passing blocks, one for each component
- `n_components`: number of components

In [11]:
# One separate BondMessagePassing instance per SMILES column (component).
n_components = len(smiles_columns)
component_blocks = [nn.BondMessagePassing() for _ in range(n_components)]

mcmp = nn.MulticomponentMessagePassing(blocks=component_blocks, n_components=n_components)

## Aggregation

In [12]:
# Aggregation module; passed to MulticomponentMPNN as its aggregation component below
agg = nn.MeanAggregation()

## RegressionFFN

In [13]:
# Build the output transform from the scaler returned when normalizing the training targets
output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)

In [14]:
# Regression head: its input width is taken from the multicomponent message passing
# module, and the output transform built from the target scaler is attached to it.
ffn = nn.RegressionFFN(
    input_dim=mcmp.output_dim,
    output_transform=output_transform,
)

## Metrics

In [15]:
# Metrics handed to MulticomponentMPNN through its `metrics` argument
metric_list = [metrics.RMSE(), metrics.MAE()] # Only the first metric is used for training and early stopping

## MulticomponentMPNN

In [16]:
# Assemble the full model from the message passing, aggregation and predictor modules
mcmpnn = multi.MulticomponentMPNN(
    mcmp,
    agg,
    ffn,
    metrics=metric_list,
)

# Display the module structure
mcmpnn

MulticomponentMPNN(
  (message_passing): MulticomponentMessagePassing(
    (blocks): ModuleList(
      (0-1): 2 x BondMessagePassing(
        (W_i): Linear(in_features=86, out_features=300, bias=False)
        (W_h): Linear(in_features=300, out_features=300, bias=False)
        (W_o): Linear(in_features=372, out_features=300, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (tau): ReLU()
        (V_d_transform): Identity()
        (graph_transform): Identity()
      )
    )
  )
  (agg): MeanAggregation()
  (bn): Identity()
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=600, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0]])
    (output_transform): UnscaleTransform()
  )
  (X_d_transform): Identity()
  (metrics): ModuleList(


# Set up trainer

In [17]:
# Lightning trainer configuration: no logger, checkpointing and progress bar enabled,
# accelerator chosen automatically, a single device.
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=20, # number of epochs to train for
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


# Start training

In [18]:
# Train the model on the training loader, validating on the validation loader
trainer.fit(mcmpnn, train_loader, val_loader)

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/knathan/chemprop/examples/checkpoints exists and is not empty.
Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.

  | Name            | Type                         | Params | Mode 
-------------------------------------------------------------------------
0 | message_passing | MulticomponentMessagePassing | 455 K  | train
1 | agg             | MeanAggregation              | 0      | train
2 | bn              | Identity                     | 0      | train
3 | predictor       | RegressionFFN            

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 19: 100%|██████████| 2/2 [00:00<00:00,  4.50it/s, train_loss_step=0.422, val_loss=0.301, train_loss_epoch=0.659]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 2/2 [00:00<00:00,  4.10it/s, train_loss_step=0.422, val_loss=0.301, train_loss_epoch=0.659]


# Test results

In [19]:
# Evaluate the trained model on the test loader and keep the returned results
results = trainer.test(mcmpnn, test_loader)


/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 28.64it/s]
