# Predicting Regression - Reaction

# Import packages

In [1]:
import pandas as pd
import numpy as np
import torch
from lightning import pytorch as pl

from chemprop import data, featurizers, models

# Change model input here

In [2]:
# Path to the checkpoint file to load.
# A checkpoint produced by the training notebook is written to the `checkpoints`
# folder, with a name similar to `checkpoints/epoch=19-step=180.ckpt`.
checkpoint_path = "../tests/data/example_model_v2_reaction.ckpt"

## Load model

In [3]:
# Restore the MPNN model from the checkpoint file at `checkpoint_path`.
mpnn = models.MPNN.load_from_checkpoint(checkpoint_path)
# Bare trailing expression: the notebook shows the model's repr as the cell output.
mpnn

MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=193, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=465, out_features=300, bias=True)
    (dropout): Dropout(p=0, inplace=False)
    (tau): ReLU()
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Linear(in_features=300, out_features=300, bias=True)
      (1): ReLU()
      (2): Dropout(p=0, inplace=False)
      (3): Linear(in_features=300, out_features=1, bias=True)
    )
  )
)

# Change predict input here

In [4]:
# CSV file containing the reactions to run predictions on.
test_path = "../tests/data/regression/rxn.csv"
# Name of the column in that CSV which holds the reaction SMILES strings.
smiles_column = "smiles"

## Load SMILES

In [5]:
# Read the input table from `test_path`.
df_test = pd.read_csv(test_path)

# Extract the configured SMILES column as a NumPy array.
smis = df_test[smiles_column].values
# Preview the first five entries as the cell output.
smis[:5]

array(['[O:1]([C:2]([C:3]([C:4](=[O:5])[C:6]([O:7][H:15])([H:13])[H:14])([H:11])[H:12])([H:9])[H:10])[H:8]>>[C:3](=[C:4]=[O:5])([H:11])[H:12].[C:6]([O:7][H:15])([H:8])([H:13])[H:14].[O:1]=[C:2]([H:9])[H:10]',
       '[C:1]1([H:8])([H:9])[O:2][C@@:3]2([H:10])[C@@:4]3([H:11])[O:5][C@:6]1([H:12])[C@@:7]23[H:13]>>[C:1]1([H:8])([H:9])[O:2][C:3]([H:10])=[C:7]([H:13])[C@:6]1([O+:5]=[C-:4][H:11])[H:12]',
       '[C:1]([C@@:2]1([H:11])[C@@:3]2([H:12])[C:4]([H:13])([H:14])[C:5]([H:15])=[C:6]([H:16])[C@@:7]12[H:17])([H:8])([H:9])[H:10]>>[C:1]([C@@:2]1([H:11])[C:3]([H:12])([H:13])[C:4]([H:14])=[C:5]([H:15])[C:6]([H:16])=[C:7]1[H:17])([H:8])([H:9])[H:10]',
       '[C:1]([O:2][C:3]([C@@:4]([C:5]([H:14])([H:15])[H:16])([C:6]([O:7][H:19])([H:17])[H:18])[H:13])([H:11])[H:12])([H:8])([H:9])[H:10]>>[C-:1]([O+:2]=[C:3]([C@@:4]([C:5]([H:14])([H:15])[H:16])([C:6]([O:7][H:19])([H:17])[H:18])[H:13])[H:12])([H:8])[H:10].[H:9][H:11]',
       '[C:1]([C:2]#[C:3][C:4]([C:5](=[O:6])[H:12])([H:10])[H:11])([H:7])([H:

## Load datapoints

In [6]:
# Build one ReactionDatapoint per reaction SMILES string, preserving input order.
test_data = list(map(data.ReactionDatapoint.from_smi, smis))

## Define featurizer

In [7]:
# Construct the condensed-graph-of-reaction featurizer in "PROD_DIFF" mode.
featurizer = featurizers.CondensedGraphOfReactionFeaturizer(mode_="PROD_DIFF")
# The featurizer parameters used here at test time should match the ones used for training.

## Get dataset and dataloader

In [8]:
# Combine the reaction datapoints with the featurizer into a dataset.
test_dset = data.ReactionDataset(test_data, featurizer=featurizer)
# shuffle=False: presumably keeps the predictions in input order, which matters because
# a later cell assigns them positionally to `df_test` — verify against the loader docs.
test_loader = data.MolGraphDataLoader(test_dset, shuffle=False)

# Perform prediction

In [9]:
# Run the model over `test_loader` on a single CPU device and collect its predictions.
with torch.inference_mode():
    trainer = pl.Trainer(
        # `logger=None` does not turn logging off: Lightning falls back to its default
        # logger, which creates a `lightning_logs` folder. `False` disables logging,
        # which is what a prediction-only run wants.
        logger=False,
        enable_progress_bar=True,
        accelerator="cpu",
        devices=1,
    )
    # `predict` returns the per-batch outputs; they are concatenated in a later cell.
    test_preds = trainer.predict(mpnn, test_loader)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Missing logger folder: /home/kpg/chemprop-v2/chemprop/examples/lightning_logs
  rank_zero_warn(


Predicting DataLoader 0: 100%|██████████| 100/100 [00:00<00:00, 261.97it/s]


In [10]:
# NOTE: this cell rebinds `test_preds` (per-batch outputs -> single array), so re-run the
# prediction cell before re-running this one.
# Join the per-batch predictions along the first axis into a single array.
test_preds = np.concatenate(test_preds, axis=0)
# Map the predictions back through the model's output scaler
# (presumably returning them to the original target scale — confirm against training setup).
test_preds = mpnn.output_scaler.inverse_transform(test_preds)
df_test['preds'] = test_preds
# Select the SMILES column by its configured name rather than a hard-coded 'smiles',
# so the display still works when `smiles_column` is changed for a different input file.
df_test.loc[:, [smiles_column, 'preds']]

Unnamed: 0,smiles,preds
0,[O:1]([C:2]([C:3]([C:4](=[O:5])[C:6]([O:7][H:1...,17.128060
1,[C:1]1([H:8])([H:9])[O:2][C@@:3]2([H:10])[C@@:...,13.054828
2,[C:1]([C@@:2]1([H:11])[C@@:3]2([H:12])[C:4]([H...,13.137711
3,[C:1]([O:2][C:3]([C@@:4]([C:5]([H:14])([H:15])...,15.169869
4,[C:1]([C:2]#[C:3][C:4]([C:5](=[O:6])[H:12])([H...,17.142166
...,...,...
95,[C:1]([C:2]([C:3]([H:12])([H:13])[H:14])([C:4]...,16.036926
96,[O:1]=[C:2]([C@@:3]1([H:9])[C:4]([H:10])([H:11...,16.249517
97,[C:1]([C@@:2]1([H:11])[C@@:3]2([H:12])[C:4]([H...,16.102618
98,[C:1]1([H:8])([H:9])[C@@:2]2([H:10])[N:3]1[C:4...,18.164276
