# Active Learning

# Import Packages

In [39]:
import pandas as pd
from pathlib import Path

from lightning import pytorch as pl

from chemprop import data, featurizers, models, nn
import random

from typing import Tuple
from torch.utils.data import DataLoader

# Change data inputs

In [40]:
chemprop_dir = Path.cwd().parent
input_path = chemprop_dir / "tests" / "data" / "regression" / "mol" / "mol.csv" # path to your data .csv file
num_workers = 0 # number of workers for dataloader. 0 means using main process for data loading
smiles_column = 'smiles' # name of the column containing SMILES strings
target_columns = ['lipo'] # list of names of the columns containing targets

## Load data

In [41]:
df_input = pd.read_csv(input_path)
df_input

Unnamed: 0,smiles,lipo
0,Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14,3.54
1,COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...,-1.18
2,COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl,3.69
3,OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...,3.37
4,Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...,3.10
...,...,...
95,CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...,2.20
96,CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...,2.04
97,CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...,4.49
98,COc1ccc(Cc2c(N)n[nH]c2N)cc1,0.20


## Get SMILES and targets

In [42]:
smis = df_input.loc[:, smiles_column].values
ys = df_input.loc[:, target_columns].values

In [43]:
smis[:5] # show first 5 SMILES strings

array(['Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14',
       'COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)CCc3ccccc23',
       'COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl',
       'OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(Cl)sc4[nH]3',
       'Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)NCC#N)c1'],
      dtype=object)

In [44]:
ys[:5] # show first 5 targets

array([[ 3.54],
       [-1.18],
       [ 3.69],
       [ 3.37],
       [ 3.1 ]])

## Get molecule datapoints

In [45]:
all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]

## Data splitting for training/validation pool and testing

In [46]:
# available split types
list(data.SplitType.keys())

['CV_NO_VAL',
 'CV',
 'SCAFFOLD_BALANCED',
 'RANDOM_WITH_REPEATED_SMILES',
 'RANDOM',
 'KENNARD_STONE',
 'KMEANS']

In [47]:
mols = [d.mol for d in all_data]  # RDkit Mol objects are use for structure based splits
nontest_indices, _, test_indices = data.make_split_indices(mols, "random", (0.9, 0.0, 0.1))
nontest_data, _, test_data = data.split_data_by_indices(
    all_data, nontest_indices, None, test_indices
)

## Get featurizer

In [48]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# Change model inputs here

In [49]:
checkpoint_path = chemprop_dir / "tests" / "data" / "example_model_v2_regression_mol.ckpt"  # some path
mpnn = models.MPNN.load_from_file(checkpoint_path)

/Users/joelmanu/anaconda3/envs/chv2/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'graph_transform' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['graph_transform'])`.


# Change active learning parameters here

In [50]:
batch_size = len(nontest_data) // 10  # number of datapoints added to trainval pool each iteration.

train_sizes = list(range(batch_size, len(nontest_data) + 1, batch_size))  # size of trainval datapool for each interval.
train_sizes[-1] = len(nontest_data)  # include all data in last iteration.

priority_function = lambda x: random.random()
# a priority function can be defined to control the order in which datapoints are added to the training and validation pool.

In [51]:
trainval_data = []
remaining_data = nontest_data

## Getting dataloaders

In [52]:
def get_dataloaders() -> Tuple[DataLoader]:
    trainval_mols = [d.mol for d in trainval_data]  # RDkit Mol objects are use for structure based splits
    train_indices, _, val_indices = data.make_split_indices(trainval_mols, "random", (0.9, 0.0, 0.1))
    train_data, val_data, _ = data.split_data_by_indices(
        trainval_data, train_indices, val_indices, None
    )

    train_dset = data.MoleculeDataset(train_data, featurizer)
    scaler = train_dset.normalize_targets()

    val_dset = data.MoleculeDataset(val_data, featurizer)
    val_dset.normalize_targets(scaler)

    test_dset = data.MoleculeDataset(test_data, featurizer)

    train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
    val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)
    test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)
    return train_loader, val_loader, test_loader

# Train with active learning

## Set up trainer

In [53]:
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,  # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=20,  # number of epochs to train for
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


## Start training

In [54]:
for train_size in train_sizes:
    # sort new datapoints by priority using priority function
    priority_remaining_data = [(priority_function(d), d) for d in remaining_data]
    sorted_remaining_data = [d[1] for d in sorted(priority_remaining_data, reverse=True)] 
    new_size = train_size - len(trainval_data)

    new_data = sorted_remaining_data[:new_size]
    remaining_data = remaining_data[new_size:]

    trainval_data.extend(new_data)

    train_loader, val_loader, test_loader = get_dataloaders()
    trainer.fit(mpnn, train_loader, val_loader)
    results = trainer.test(mpnn, test_loader)

  warn(
Loading `train_dataloader` to estimate number of stepping batches.
/Users/joelmanu/anaconda3/envs/chv2/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.

  | Name            | Type               | Params
-------------------------------------------------------
0 | message_passing | BondMessagePassing | 227 K 
1 | agg             | MeanAggregation    | 0     
2 | bn              | BatchNorm1d        | 600   
3 | predictor       | RegressionFFN      | 90.6 K
4 | X_d_transform   | Identity           | 0     
-------------------------------------------------------
318 K     Trainable params
0         Non-trainable params
318 K     Total params
1.276     Total estimated model params size (MB)


                                                  

/Users/joelmanu/anaconda3/envs/chv2/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Users/joelmanu/anaconda3/envs/chv2/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:104: Total length of `DataLoader` across ranks is zero. Please make sure this was your intention.


Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s] 

NotImplementedError: The operator 'aten::scatter_reduce.two_out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.