In [1]:
from lightning import pytorch as pl
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from chemprop import data, models, nn

### Multitask model

Step 1: Make datapoints

In [2]:
chemprop_dir = Path.cwd().parent
input_path = chemprop_dir / "tests" / "data" / "regression" / "mol_multitask.csv"
smiles_column = 'smiles' 
target_columns = ["mu","alpha","homo","lumo","gap","r2","zpve","cv","u0","u298","h298","g298"] 

df_input = pd.read_csv(input_path)
smis = df_input.loc[:, smiles_column].values
ys = df_input.loc[:, target_columns].values

datapoints = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]

Step 2: Split data and make datasets

In [3]:
split_indices = data.make_split_indices(datapoints)
train_data, val_data, test_data = data.split_data_by_indices(datapoints, *split_indices)


train_dset = data.MoleculeDataset(train_data)
val_dset = data.MoleculeDataset(val_data)
test_dset = data.MoleculeDataset(test_data)

Step 3: Scale targets and make dataloaders

In [4]:
output_scaler = train_dset.normalize_targets()
val_dset.normalize_targets(output_scaler)

train_loader = data.build_dataloader(train_dset)
val_loader = data.build_dataloader(val_dset)
test_loader = data.build_dataloader(test_dset)

Step 4: Define the model

In [5]:
output_transform = nn.transforms.UnscaleTransform.from_standard_scaler(output_scaler)

ffn = nn.RegressionFFN(n_tasks = len(target_columns), output_transform=output_transform)
chemprop_model = models.MPNN(nn.BondMessagePassing(), nn.MeanAggregation(), ffn)

Step 5: Set up the trainer

In [6]:
trainer = pl.Trainer(logger=False, enable_checkpointing=False, max_epochs=1)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Step 6: Train the model

In [7]:
trainer.fit(chemprop_model, train_loader, val_loader)

Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.

  | Name            | Type               | Params
-------------------------------------------------------
0 | message_passing | BondMessagePassing | 227 K 
1 | agg             | MeanAggregation    | 0     
2 | bn              | BatchNorm1d        | 600   
3 | predictor       | RegressionFFN      | 93.9 K
4 | X_d_transform   | Identity           | 0     
-------------------------------------------------------
322 K     Trainable params
0         Non-trainable params
322 K     Total params
1.289     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:492: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 7/7 [00:09<00:00,  0.77it/s, train_loss=0.607, val_loss=1.160]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 7/7 [00:09<00:00,  0.77it/s, train_loss=0.607, val_loss=1.160]


Step 7: Use the model to make predictions

In [8]:
preds = trainer.predict(chemprop_model, test_loader)
preds = torch.concat(preds, axis=1)

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:492: Your `predict_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]

Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  2.12it/s]


In [9]:
preds.shape

torch.Size([51, 12])