In [6]:
# Uncomment on first use. `%pip` installs into the running kernel's environment.
# The split code below indexes the per-partition results with [0], which matches the
# return type introduced in chemprop v2.1, hence the lower bound.
# %pip install "chemprop>=2.1"

In [189]:
import pandas as pd
from pathlib import Path
from lightning import pytorch as pl
from chemprop import data, featurizers, models, nn

from matplotlib import pyplot as plt

In [None]:
# Location of the labelled task CSV and the layout of its columns.
input_path = "./data/KIBA_selected/tasks_pulled_data/P11309_vs_non_binder.csv"
smiles_column = "Drug"  # column that holds the SMILES strings
target_columns = ["Y_binary"]  # binary activity label (either 0 or 1)

# Worker processes for the dataloaders; 0 keeps data loading in the main process.
num_workers = 0

In [233]:
# One seed for both the random split and the training-batch shuffling, so that a fresh
# "Restart & Run All" reproduces the same partitions and the same batch order.
data_seed = 0

# Read the task table, then pull out the SMILES strings and the label matrix.
df_input = pd.read_csv(input_path)
smis = df_input.loc[:, smiles_column].values
ys = df_input.loc[:, target_columns].values  # list selector -> 2-D array, one column per target
all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]

mols = [d.mol for d in all_data]  # RDKit Mol objects are used for structure-based splits
train_indices, val_indices, test_indices = data.make_split_indices(
    mols, "random", (0.7, 0.1, 0.2), seed=data_seed
)
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)

featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# The split helpers hand back a sequence per partition (see the v2.1 warning emitted by
# make_split_indices); element 0 is the one used here - presumably the single replicate.
train_dset = data.MoleculeDataset(train_data[0], featurizer)
val_dset = data.MoleculeDataset(val_data[0], featurizer)
test_dset = data.MoleculeDataset(test_data[0], featurizer)

# The training loader keeps chemprop's default shuffling; passing the seed makes that
# shuffle order repeatable. Validation and test loaders explicitly do not shuffle.
train_loader = data.build_dataloader(train_dset, num_workers=num_workers, seed=data_seed)
val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)

The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)
  warn(


Creating the model

In [234]:
# Seed the global RNGs before the first randomly initialised module is built, so that
# the initial weights (and anything later drawn from the global RNG) are reproducible
# on a fresh "Restart & Run All".
pl.seed_everything(0, workers=True)

# Bond-based message-passing block, constructed with the library defaults.
mp = nn.BondMessagePassing()

In [235]:
# Look up which metric class the registry associates with the alias 'accuracy'.
nn.metrics.MetricRegistry['accuracy']

chemprop.nn.metrics.BinaryAccuracy

In [236]:
from chemprop.nn.metrics import BCELoss, BinaryAccuracy

In [237]:
# Echo the directly imported name to confirm which chemprop class it refers to.
BinaryAccuracy

chemprop.nn.metrics.BinaryAccuracy

In [238]:
# Batch normalisation is switched off (it appears as `bn: Identity()` in the model repr).
batch_norm = False

# Mean aggregation is used as the readout step.
agg = nn.MeanAggregation()

# Binary-classification head with one output per entry in `target_columns`.
ffn = nn.BinaryClassificationFFN(n_tasks=len(target_columns))

# Track accuracy during validation/testing.
# Original author's note, kept verbatim: "want to change this _T_default_metric = BinaryAUROC"
# - presumably the head's default metric is BinaryAUROC; TODO confirm the intended metric.
metric_list = [nn.metrics.BinaryAccuracy()]

In [239]:
# Assemble message passing, aggregation and predictor into the model handed to the trainer.
mpnn = models.MPNN(
    message_passing=mp,
    agg=agg,
    predictor=ffn,
    batch_norm=batch_norm,
    metrics=metric_list,
)

mpnn  # last expression: show the module tree

MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): MeanAggregation()
  (bn): Identity()
  (predictor): BinaryClassificationFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): BCELoss(task_weights=[[1.0]])
    (output_transform): Identity()
  )
  (X_d_transform): Identity()
  (metrics): ModuleList(
    (0): BinaryAccuracy()
    (1): BCELoss(task_weights=[[1.0]])
  )
)

In [240]:
# Trainer settings gathered in one mapping, then unpacked into the Trainer.
trainer_kwargs = dict(
    logger=False,
    enable_checkpointing=True,  # checkpoints are saved in the `checkpoints` folder
    enable_progress_bar=True,
    accelerator="cpu",
    devices=1,
    max_epochs=20,  # number of epochs to train for
)
trainer = pl.Trainer(**trainer_kwargs)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/homebrew/anaconda3/envs/nco/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


In [241]:
# Fit the model on the training loader, using the validation loader for validation.
trainer.fit(model=mpnn, train_dataloaders=train_loader, val_dataloaders=val_loader)


/opt/homebrew/anaconda3/envs/nco/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/aygulminnegalieva/Music/nco/negative-class-optimization/notebooks/small_molecules/checkpoints exists and is not empty.
Loading `train_dataloader` to estimate number of stepping batches.
/opt/homebrew/anaconda3/envs/nco/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.

  | Name            | Type                    | Params | Mode 
--------------------------------------------------------------------
0 | message_passing | BondMessagePassing      | 227 K  | train
1 | agg             | MeanAggregation         | 0      | train
2 | bn              | Identity                | 0      | train
3 | predictor       | Bina

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/homebrew/anaconda3/envs/nco/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 19: 100%|██████████| 8/8 [00:00<00:00, 14.29it/s, train_loss_step=0.625, val_loss=0.660, train_loss_epoch=0.618]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 8/8 [00:00<00:00, 14.02it/s, train_loss_step=0.625, val_loss=0.660, train_loss_epoch=0.618]


In [242]:
# Score the trained model on the held-out test loader; the metrics come back as a list of dicts.
results = trainer.test(model=mpnn, dataloaders=test_loader)

/opt/homebrew/anaconda3/envs/nco/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 20.16it/s]


In [243]:
# Display the metrics returned by trainer.test for the in-distribution test split.
results

[{'test/accuracy': 0.6518518328666687}]

### Testing on out-of-distribution (OOD) data

In [None]:
# Out-of-distribution check: build a test loader from a different task file.
# Intermediates use ood_-prefixed names so that the configuration (`input_path`) and the
# in-distribution frame and splits created earlier are not silently overwritten.
ood_input_path = "./data/KIBA_selected/tasks_pulled_data/P11309_vs_weak.csv"
ood_df = pd.read_csv(ood_input_path)
ood_smis = ood_df.loc[:, smiles_column].values
ood_ys = ood_df.loc[:, target_columns].values
ood_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(ood_smis, ood_ys)]

ood_mols = [d.mol for d in ood_data]  # RDKit Mol objects are used for structure-based splits

# NOTE(review): as in the original cell, only the 20 % test partition of this file is
# evaluated. The model is not fitted on this file in this notebook, so scoring the whole
# file may be what is wanted - confirm. Also verify that molecules in this file do not
# overlap with the training partition of the first file (both file names start with P11309).
ood_split_indices = data.make_split_indices(ood_mols, "random", (0.7, 0.1, 0.2))
ood_test_data = data.split_data_by_indices(ood_data, *ood_split_indices)[2]

# Reuse the featurizer from the training data so both datasets are featurized identically.
# `test_dset` / `test_loader` are rebound on purpose: the evaluation cell that follows
# reads `test_loader`.
test_dset = data.MoleculeDataset(ood_test_data[0], featurizer)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)

The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)
  warn(


In [231]:
# Score the trained model on the loader built from the P11309_vs_weak file, and end the
# cell with the metrics so the OOD result is actually shown in the notebook.
results = trainer.test(mpnn, test_loader)
results

/opt/homebrew/anaconda3/envs/nco/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 19.72it/s]
