In [1]:
from lightning import pytorch as pl
import numpy as np
from chemprop import data, models, nn

Step 1: Make datapoints

In [2]:
smis = ["C", "CC", "CCC", "CCCC", "CCCCC"]
ys = np.random.rand(len(smis), 1)
datapoints = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]

Step 2: Make a dataset and dataloader

In [3]:
dataset = data.MoleculeDataset(datapoints)
dataloader = data.build_dataloader(dataset)

Step 3: Define the model

In [4]:
chemprop_model = models.MPNN(nn.BondMessagePassing(), nn.MeanAggregation(), nn.RegressionFFN())

Step 4: Set up the trainer

In [5]:
trainer = pl.Trainer(logger=False, enable_checkpointing=False, max_epochs=1)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Step 5: Train the model

In [6]:
trainer.fit(chemprop_model, dataloader)

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.

  | Name            | Type               | Params
-------------------------------------------------------
0 | message_passing | BondMessagePassing | 227 K 
1 | agg             | MeanAggregation    | 0     
2 | bn              | BatchNorm1d        | 600   
3 | predictor       | RegressionFFN      | 90.6 K
4 | X_d_transform   | Identity           | 0     
------------------------------------------

Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  1.90it/s, train_loss=0.293]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  1.90it/s, train_loss=0.293]


Step 6: Use the model to make predictions

In [7]:
dataloader = data.build_dataloader(dataset, shuffle=False)
preds = trainer.predict(chemprop_model, dataloader)

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  5.97it/s]


In [8]:
preds

[tensor([[-0.0043],
         [-0.0097],
         [-0.0117],
         [-0.0121],
         [-0.0144]])]