# Import Packages

In [1]:
import pandas as pd
import numpy as np
import torch
from lightning import pytorch as pl

from chemprop import data, featurizers, models

## Change model input here

In [2]:
checkpoint_path = '../tests/data/example_model_v2_spectra.ckpt' # path to the checkpoint file. 
# If the checkpoint file is generated using the training notebook, it will be in the `checkpoints` folder with name similar to `checkpoints/epoch=19-step=180.ckpt`.

In [3]:
mpnn = models.MPNN.load_from_checkpoint(checkpoint_path)
mpnn

MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=147, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=433, out_features=300, bias=True)
    (dropout): Dropout(p=0, inplace=False)
    (tau): ReLU()
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): SpectralFFN(
    (ffn): MLP(
      (0): Linear(in_features=300, out_features=300, bias=True)
      (1): ReLU()
      (2): Dropout(p=0, inplace=False)
      (3): Linear(in_features=300, out_features=1, bias=True)
      (spectral_activation): Softplus(beta=1, threshold=20)
    )
  )
)

# Change predict input here

In [4]:
test_path = 'tests/data/spectra/test_smiles.csv'
smiles_column = 'smiles'

## Load data

In [5]:
df_test = pd.read_csv(test_path)
df_test

Unnamed: 0,smiles
0,O=C(O)c1ccco1
1,O=C(O)c1ccco1
2,CCOP(=O)(OCC)C(O)[C@@H]1OC(C)(C)O[C@H]1[C@@H]1...
3,CCOP(=O)(OCC)C(O)[C@@H]1OC(C)(C)O[C@H]1[C@@H]1...
4,c1ccc(C2=NOC(c3ccccc3)C2)cc1
...,...
195,CCC(C)CCC(C)CC
196,CCCCC(C)(C)CCC
197,CCCCC(C)(CC)CC
198,CCCCCCCC(C)C


## Get smiles

In [6]:
smis = df_test[smiles_column]
smis[:5]

0                                        O=C(O)c1ccco1
1                                        O=C(O)c1ccco1
2    CCOP(=O)(OCC)C(O)[C@@H]1OC(C)(C)O[C@H]1[C@@H]1...
3    CCOP(=O)(OCC)C(O)[C@@H]1OC(C)(C)O[C@H]1[C@@H]1...
4                         c1ccc(C2=NOC(c3ccccc3)C2)cc1
Name: smiles, dtype: object

## Get molecule datapoints

In [7]:
test_data = [data.MoleculeDatapoint.from_smi(smi) for smi in smis]

## Get molecule datasets

In [8]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
test_dset = data.MoleculeDataset(test_data, featurizer=featurizer)
test_loader = data.MolGraphDataLoader(test_dset, shuffle=False)

# Set up trainer

In [9]:
with torch.inference_mode():
    trainer = pl.Trainer(
        logger=None,
        enable_progress_bar=True,
        accelerator="cpu",
        devices=1
    )
    test_preds = trainer.predict(mpnn, test_loader)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/li

Predicting DataLoader 0: 100%|██████████| 200/200 [00:00<00:00, 739.59it/s]


In [10]:
test_preds = np.concatenate(test_preds, axis=0)
df_test['pred'] = test_preds
df_test

Unnamed: 0,smiles,pred
0,O=C(O)c1ccco1,
1,O=C(O)c1ccco1,
2,CCOP(=O)(OCC)C(O)[C@@H]1OC(C)(C)O[C@H]1[C@@H]1...,
3,CCOP(=O)(OCC)C(O)[C@@H]1OC(C)(C)O[C@H]1[C@@H]1...,
4,c1ccc(C2=NOC(c3ccccc3)C2)cc1,
...,...,...
195,CCC(C)CCC(C)CC,
196,CCCCC(C)(C)CCC,
197,CCCCC(C)(CC)CC,
198,CCCCCCCC(C)C,
