# Using RIGR featurizer in Chemprop

This example notebook shows how to use the RIGR featurizer to train Chemprop models and run inference with them. The RIGR featurizer maps all resonance forms of a molecule or radical to a single representation.
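The toy sketch below shows in plain Python what "a single representation for all resonance forms" means. It hand-encodes two resonance structures of acetate with the same atom ordering. It then compares a feature set that records bond orders and per-atom formal charges with one that keeps only quantities unchanged by moving electrons: element, heavy-atom degree, connectivity and total charge. The invariant features here are a hypothetical stand-in chosen for illustration. They are not the feature set the RIGR featurizer actually uses.

```python
from collections import Counter

# Hypothetical illustration (NOT the RIGR feature set): two resonance forms of
# acetate with the same atom ordering: C0, C1 (carboxyl C), O2, O3.

# Form A: CC(=O)[O-]  -> C1=O2 double bond, negative charge on O3
form_a = {
    "elements": ["C", "C", "O", "O"],
    "charges": [0, 0, 0, -1],
    "bonds": {(0, 1): 1, (1, 2): 2, (1, 3): 1},
}
# Form B: CC([O-])=O  -> C1=O3 double bond, negative charge on O2
form_b = {
    "elements": ["C", "C", "O", "O"],
    "charges": [0, 0, -1, 0],
    "bonds": {(0, 1): 1, (1, 2): 1, (1, 3): 2},
}


def resonance_dependent_features(form):
    """Per-atom (element, formal charge) and per-bond order: differs between forms."""
    atoms = list(zip(form["elements"], form["charges"]))
    bonds = sorted(form["bonds"].items())
    return atoms, bonds


def resonance_invariant_features(form):
    """Per-atom (element, heavy-atom degree), bond connectivity and total charge.

    None of these depend on where the double bond or the charge is drawn.
    """
    degree = Counter()
    for i, j in form["bonds"]:
        degree[i] += 1
        degree[j] += 1
    atoms = [(el, degree[idx]) for idx, el in enumerate(form["elements"])]
    bonds = sorted(form["bonds"])  # connectivity only, bond order dropped
    total_charge = sum(form["charges"])
    return atoms, bonds, total_charge


print(resonance_dependent_features(form_a) == resonance_dependent_features(form_b))  # False
print(resonance_invariant_features(form_a) == resonance_invariant_features(form_b))  # True
```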

## Creating the environment

This notebook uses the chemprop environment from the [`rigr_flag`](https://github.com/akshatzalte/chemprop/tree/rigr_flag) branch. The steps for setting up the environment are:

1. Get the chemprop source code by cloning it with git. Start in the local directory where you want the `chemprop` folder to be created.

    ```bash
    cd YourFolderPath
    git clone https://github.com/akshatzalte/chemprop.git
    ```
2. Navigate to the `chemprop` directory and switch to the `rigr_flag` branch, which provides RIGR as a flag (`rigr=True`) on the atom and bond featurizers.

    ```bash
    cd chemprop
    git checkout rigr_flag
    ```
3. Create and activate the environment from the `environment.yml` file.

    ```bash
    conda env create -f environment.yml --name=YourEnvName
    conda activate YourEnvName
    ```


## Imports

```python
import numpy as np
import pandas as pd
from pathlib import Path
from lightning import pytorch as pl
from typing import Sequence
from rdkit.Chem.rdchem import Atom, Bond, Mol
from rdkit import Chem

from chemprop import data, featurizers, models, nn
from chemprop.featurizers.atom import MultiHotAtomFeaturizer
from chemprop.featurizers.bond import MultiHotBondFeaturizer
from chemprop.featurizers.molecule import ChargeFeaurizer, MultiplicityFeaurizer
from chemprop.utils import make_mol
```

## Load data

```python
chemprop_dir = Path("/home/hwpang/Projects/chemprop_v2_dev/chemprop")  # path to your local chemprop clone
input_path = chemprop_dir / "tests" / "data" / "regression" / "mol" / "mol.csv"  # path to your data .csv file
num_workers = 0  # number of dataloader workers; 0 means data is loaded in the main process
smiles_column = 'smiles'  # name of the column containing SMILES strings
target_columns = ['lipo']  # list of names of the columns containing targets
```

```python
df_input = pd.read_csv(input_path)
df_input
```

|     | smiles | lipo |
|-----|--------|------|
| 0   | `Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14` | 3.54 |
| 1   | `COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...` | -1.18 |
| 2   | `COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl` | 3.69 |
| 3   | `OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...` | 3.37 |
| 4   | `Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...` | 3.10 |
| ... | ... | ... |
| 95  | `CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...` | 2.20 |
| 96  | `CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...` | 2.04 |
| 97  | `CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...` | 4.49 |
| 98  | `COc1ccc(Cc2c(N)n[nH]c2N)cc1` | 0.20 |


```python
smis = df_input.loc[:, smiles_column].values
ys = df_input.loc[:, target_columns].values
```

## Featurization and dataset creation

```python
# Build RDKit molecules with explicit hydrogens
mols = [make_mol(smi, add_h=True, keep_h=True) for smi in smis]

# Molecule-level extra features (x_d) from the charge and multiplicity featurizers
charge_featurizer = ChargeFeaurizer()
multiplicity_featurizer = MultiplicityFeaurizer()
charge_feats = [charge_featurizer(mol) for mol in mols]
mult_feats = [multiplicity_featurizer(mol) for mol in mols]
x_ds = [np.hstack([charge_feat, mult_feat]) for charge_feat, mult_feat in zip(charge_feats, mult_feats)]

all_data = [data.MoleculeDatapoint(mol, name=smi, y=y, x_d=x_d) for mol, smi, y, x_d in zip(mols, smis, ys, x_ds)]
# Random 80/10/10 train/validation/test split
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1))
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)
```
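Each datapoint carries a small vector of extra molecule-level features, `x_d`. The model printout further down shows the FFN input as 302 = 300 + 2, so each `x_d` has two entries: one from the charge featurizer and one from the multiplicity featurizer. Net charge and spin multiplicity are natural companions to a resonance-invariant graph because neither changes between resonance forms. The sketch below shows how such values *could* be derived from per-atom data. In RDKit, that data comes from `Atom.GetFormalCharge()` and `Atom.GetNumRadicalElectrons()`. It is an assumption for illustration: how `ChargeFeaurizer` and `MultiplicityFeaurizer` in the `rigr_flag` branch actually compute and encode their values is not confirmed here, and the helper functions are hypothetical.

```python
import numpy as np

# Hypothetical per-heavy-atom data for the methyl radical (CH3*):
# formal charge and number of radical (unpaired) electrons on each atom.
formal_charges = [0]
radical_electrons = [1]


def total_charge(charges):
    """Net molecular charge: the sum of formal charges (same for every resonance form)."""
    return float(sum(charges))


def spin_multiplicity(n_radical_electrons):
    """2S + 1, assuming all unpaired electrons are parallel (high-spin)."""
    n_unpaired = sum(n_radical_electrons)
    return float(n_unpaired + 1)


charge_feat = np.array([total_charge(formal_charges)])
mult_feat = np.array([spin_multiplicity(radical_electrons)])
x_d = np.hstack([charge_feat, mult_feat])  # same concatenation as in the cell above
print(x_d)  # [0. 2.]
```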

```python
# Atom and bond featurizers with RIGR enabled
rigr_atom_featurizer = MultiHotAtomFeaturizer.v2(rigr=True)
rigr_bond_featurizer = MultiHotBondFeaturizer(rigr=True)
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer(
    atom_featurizer=rigr_atom_featurizer, bond_featurizer=rigr_bond_featurizer
)

train_dset = data.MoleculeDataset(train_data, featurizer)
scaler = train_dset.normalize_targets()  # fit the target scaler on the training set only

val_dset = data.MoleculeDataset(val_data, featurizer)
val_dset.normalize_targets(scaler)  # reuse the training-set scaler

# Test targets are left unscaled; the model's UnscaleTransform maps predictions back to original units
test_dset = data.MoleculeDataset(test_data, featurizer)
```
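Target normalization fits its statistics on the training split only and reuses them for validation. Because the model is later built with `nn.UnscaleTransform.from_standard_scaler(scaler)`, this sketch assumes standard (z-score) scaling with a population standard deviation. The toy targets are a few `lipo` values from the table above, and `unscale` is a hypothetical helper that mirrors the inverse transform applied to predictions. It is not Chemprop's code.

```python
import numpy as np

# A few `lipo` values from the table above, shaped (n_samples, n_tasks) like `ys`.
y_train = np.array([[3.54], [-1.18], [3.69], [3.37]])
y_val = np.array([[3.10], [2.20]])

# "Fit" on the training split only: per-task mean and population std (ddof=0).
mean = y_train.mean(axis=0)
scale = y_train.std(axis=0)

z_train = (y_train - mean) / scale  # targets the model is trained against
z_val = (y_val - mean) / scale      # validation reuses the *training* statistics


def unscale(z):
    """Inverse transform: maps scaled values back to original target units."""
    return z * scale + mean


print(z_train.ravel())
print(unscale(z_val).ravel())  # recovers y_val
```

Reusing the training statistics keeps validation and test data out of the fitted scaler. Unscaling model outputs means the test targets never need to be transformed.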


## Dataloader

```python
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)
```

## Model

```python
mp = nn.BondMessagePassing(
    d_v=featurizer.atom_fdim,
    d_e=featurizer.bond_fdim,
)
agg = nn.MeanAggregation()
output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)  # maps scaled predictions back to original units
ffn = nn.RegressionFFN(
    input_dim=mp.output_dim + train_dset.d_xd,  # learned fingerprint + extra molecule features (x_d)
    output_transform=output_transform,
)
batch_norm = True
# The first metric is logged as `val_loss` and used for checkpointing/early stopping;
# the training loss itself is the FFN criterion (MSELoss in the printout below)
metric_list = [nn.metrics.RMSEMetric(), nn.metrics.MAEMetric()]
mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list)
mpnn
```

```
MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=66, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=352, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=302, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): MSELoss(task_weights=[[1.0]])
    (output_transform): UnscaleTransform()
  )
  (X_d_transform): Identity()
)
```
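The layer shapes in this printout can be checked by hand. The sketch assumes, consistently with the printed shapes, that `W_i` takes concatenated atom and bond features (`d_v + d_e`). It likewise assumes that `W_o` takes atom features concatenated with the hidden state (`d_v + d_h`), and that the FFN takes the 300-dimensional fingerprint plus `x_d`. Under those assumptions it infers the atom and bond feature sizes of this RIGR featurizer. It also reproduces the parameter counts listed when training starts (215 K, 600, 91.2 K, 307 K total). `linear_params` is a hypothetical helper.

```python
def linear_params(n_in, n_out, bias=True):
    """Trainable parameters in a torch.nn.Linear(n_in, n_out, bias=bias) layer."""
    return n_in * n_out + (n_out if bias else 0)


d_h = 300  # hidden size, from W_h

# Infer feature sizes from the printed in_features (assumed concatenation layout).
d_v = 352 - d_h   # W_o in_features = d_v + d_h  -> atom feature size
d_e = 66 - d_v    # W_i in_features = d_v + d_e  -> bond feature size
d_xd = 302 - d_h  # FFN in_features = d_h + d_xd -> extra molecule features

mp_params = (
    linear_params(d_v + d_e, d_h, bias=False)  # W_i
    + linear_params(d_h, d_h, bias=False)      # W_h
    + linear_params(d_v + d_h, d_h)            # W_o
)
bn_params = 2 * d_h  # BatchNorm1d affine weight + bias (running stats are buffers)
ffn_params = linear_params(d_h + d_xd, d_h) + linear_params(d_h, 1)

print(d_v, d_e, d_xd)                        # 52 14 2
print(mp_params, bn_params, ffn_params)      # 215700 600 91201
print(mp_params + bn_params + ffn_params)    # 307501
```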

## Trainer

```python
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,  # set to `False` to skip saving checkpoints; when enabled, they are saved in the `checkpoints` folder
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=20,  # number of epochs to train for
)
```

```
/home/hwpang/miniforge3/envs/rigr_env/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/hwpang/miniforge3/envs/rigr_env/lib/python3.12 ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
```


## Start training

```python
trainer.fit(mpnn, train_loader, val_loader)
```

```
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.
/home/hwpang/miniforge3/envs/rigr_env/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.

  | Name            | Type               | Params | Mode 
---------------------------------------------------------------
0 | message_passing | BondMessagePassing | 215 K  | train
1 | agg             | MeanAggregation    | 0      | train
2 | bn              | BatchNorm1d        | 600    | train
3 | predictor       | RegressionFFN      | 91.2 K | train
4 | X_d_transform   | Identity           | 0      | train
---------------------------------------------------------------
307 K     Trainable params
0         Non-trainable params
307 K     Total params
1.

/home/hwpang/miniforge3/envs/rigr_env/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.

Epoch 19: 100%|██████████| 2/2 [00:00<00:00, 31.67it/s, train_loss=0.0709, val_loss=0.902]

`Trainer.fit` stopped: `max_epochs=20` reached.
```




## Test results

```python
results = trainer.test(mpnn, test_loader)
```


```
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/hwpang/miniforge3/envs/rigr_env/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.

Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 174.36it/s]
```
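`trainer.test` evaluates the metrics in `metric_list` (RMSE and MAE) on the test split and returns a list with one dictionary of logged values per test dataloader. For reference, the two metrics reduce to the plain-Python computation below. It runs on hypothetical numbers and is a definition sketch, not Chemprop's metric implementation.

```python
import math


def rmse(y_true, y_pred):
    """Root-mean-squared error over paired targets and predictions."""
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def mae(y_true, y_pred):
    """Mean absolute error over paired targets and predictions."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical targets and predictions.
y_true = [3.0, 1.0, 2.0]
y_pred = [2.0, 1.0, 4.0]
print(rmse(y_true, y_pred))  # sqrt((1 + 0 + 4) / 3) ≈ 1.291
print(mae(y_true, y_pred))   # (1 + 0 + 2) / 3 = 1.0
```

RMSE squares each error before averaging, so it penalizes the single large miss (2.0) more heavily than MAE does.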
