In [46]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# (e.g. the local `habnet` package imported below) are re-imported before each
# cell executes and edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [47]:
import os
from habnet.featuriser.featurise import Featuriser, MOL_TYPES
from habnet.featuriser.habnet_featurizer import AtomHAbNetFeaturizer
from chemprop import data, featurizers, nn
from chemprop.nn import metrics
from chemprop.models import multi
from chemprop.nn.metrics import MSE
from lightning import pytorch as pl

import shap

In [48]:
# Resolve both inputs from a single home-relative base directory rather than
# mixing an expanded "~" path with a hard-coded /home/<user> path, so the cell
# works for any user whose checkout lives at ~/code/chemprop_phd_customised.
PROCESSED_DIR = os.path.expanduser("~/code/chemprop_phd_customised/habnet/data/processed")

# NOTE(review): the first positional argument is presumably the directory of
# SDF files and `path` the CSV of regression targets — confirm against Featuriser.
feat_data = Featuriser(
    os.path.join(PROCESSED_DIR, "sdf_data"),
    path=os.path.join(PROCESSED_DIR, "target_data", "target_data.csv"),
    include_extra_features=True,
)
def leave_one_out_split(feat_data, n_leave_out: int = 1):
    """
    Hold out the last ``n_leave_out`` entries of both components of ``feat_data``.

    ``feat_data`` is indexed as ``[list_of_r1h, list_of_r2h]``. The two
    components must have the same length so that entry ``i`` of one lines up
    with entry ``i`` of the other.

    Args:
        feat_data: Indexable container whose items 0 and 1 are the two
            component sequences (each must support ``len`` and slicing).
        n_leave_out: Number of trailing entries to hold out from each
            component. Must satisfy ``1 <= n_leave_out <= len(component)``.

    Returns:
        A ``(feat_data_mod, leave_out)`` pair. ``feat_data_mod`` is a list of
        the two components with the last ``n_leave_out`` entries removed, and
        ``leave_out`` is a list of two lists holding the removed entries in
        their original order.

    Raises:
        ValueError: If the components differ in length, or ``n_leave_out`` is
            outside ``1..len(component)``.
    """
    components = (feat_data[0], feat_data[1])
    n = len(components[0])
    if len(components[1]) != n:
        raise ValueError(
            f"Components must have equal length, got {n} and {len(components[1])}"
        )
    # Reject 0 explicitly: ``seq[:-0]`` is empty and ``seq[-0:]`` is the whole
    # sequence, which would silently swap the roles of the two outputs.
    if not 1 <= n_leave_out <= n:
        raise ValueError(
            f"n_leave_out must be between 1 and {n}, got {n_leave_out}"
        )

    feat_data_mod = [component[:-n_leave_out] for component in components]
    # Slice the whole tail so every dropped entry ends up in the held-out set
    # (indexing with ``[-n_leave_out]`` would keep only one of them).
    leave_out = [list(component[-n_leave_out:]) for component in components]

    return feat_data_mod, leave_out

# Set aside the final entry of each component; everything else is used for the
# train/val/test split in the following cells.
feat_data_mod, leave_one_out = leave_one_out_split(feat_data, n_leave_out=1)


Reaction rxn_635 not found in target data
Reaction rxn_288 not found in target data
Reaction rxn_785 not found in target data


In [49]:
# The split is computed from the molecules of one component only (index 0);
# the resulting index sets are then handed to split_data_by_indices together
# with the full multi-component data.
component_to_split_by = 0
reference_component = feat_data_mod[component_to_split_by]
mols = [datapoint.mol for datapoint in reference_component]

split_sizes = (0.8, 0.1, 0.1)  # train / validation / test fractions
train_indices, val_indices, test_indices = data.make_split_indices(
    mols, "kennard_stone", split_sizes
)

split_result = data.split_data_by_indices(
    feat_data_mod, train_indices, val_indices, test_indices
)
train_data, val_data, test_data = split_result

The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)


In [50]:
# NOTE(review): extra_atom_fdim=10 presumably matches the width of the extra
# per-atom features produced with include_extra_features=True — confirm.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer(extra_atom_fdim=10)


def _component_datasets(split):
    """Build one MoleculeDataset per entry of MOL_TYPES from ``split[0]``.

    Index 0 selects the first item returned by ``split_data_by_indices``
    (presumably the single split replicate — TODO confirm against chemprop).
    """
    per_component = split[0]
    return [
        data.MoleculeDataset(per_component[component], featurizer)
        for component in range(len(MOL_TYPES))
    ]


train_datasets = _component_datasets(train_data)
val_datasets = _component_datasets(val_data)
test_datasets = _component_datasets(test_data)

In [51]:
# Size of the first component's dataset in each split: (train, val, test).
tuple(len(datasets[0]) for datasets in (train_datasets, val_datasets, test_datasets))

(1444, 180, 182)

In [52]:
# Wrap the per-component datasets of the training and validation splits.
train_mcdset = data.MulticomponentDataset(train_datasets)
val_mcdset = data.MulticomponentDataset(val_datasets)

# The scaler comes from the training targets and is reused for validation.
scaler = train_mcdset.normalize_targets()
val_mcdset.normalize_targets(scaler)

# The test targets are not normalised here; the same `scaler` is later used to
# build the model's output (unscale) transform.
test_mcdset = data.MulticomponentDataset(test_datasets)

In [64]:
# One batch size for every loader. The training loader keeps
# build_dataloader's default shuffle setting; validation and test loaders
# are explicitly unshuffled.
batch_size = 64
train_loader = data.build_dataloader(train_mcdset, batch_size=batch_size)
val_loader, test_loader = (
    data.build_dataloader(dataset, shuffle=False, batch_size=batch_size)
    for dataset in (val_mcdset, test_mcdset)
)
# Seed the global Python / NumPy / torch RNGs before any weights are created.
# Without this, every run of the notebook starts from different random state
# and the reported metrics cannot be reproduced.
SEED = 42
pl.seed_everything(SEED, workers=True)

# One independent (non-shared) bond message-passing block per molecule type.
n_components = len(MOL_TYPES)
mcmp = nn.MulticomponentMessagePassing(
    blocks=[
        nn.BondMessagePassing(depth=3, d_v=featurizer.atom_fdim, dropout=0.2)
        for _ in range(n_components)
    ],
    n_components=n_components,
    shared=False,
)
agg = nn.MeanAggregation()

# Predictions are mapped back to the original target scale using the scaler
# obtained from the training targets.
output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)
ffn = nn.RegressionFFN(
    input_dim=mcmp.output_dim,
    hidden_dim=300,
    n_tasks=5,  # NOTE(review): presumably the number of target columns — confirm
    dropout=0.2,
    output_transform=output_transform,
    # Optional per-task loss weighting, currently disabled:
    # criterion=metrics.MSE(task_weights=[0.1, 0.1, 0.1, 1, 1])
)

# Only the first is used for early stopping (the trainer callbacks monitor "val/rmse").
metric_list = [metrics.RMSE(), metrics.MAE(), metrics.R2Score()]
mcmpnn = multi.MulticomponentMPNN(
    mcmp,
    agg,
    ffn,
    metrics=metric_list,
    batch_norm=False,
    init_lr=1e-4,
    max_lr=1e-3,
    final_lr=1e-4,
)


# Stop when validation RMSE has not improved for 10 epochs and keep only the
# single best checkpoint according to the same metric.
early_stopping = pl.callbacks.EarlyStopping(monitor="val/rmse", patience=10, mode="min")
checkpointing = pl.callbacks.ModelCheckpoint(
    monitor="val/rmse",
    mode="min",
    save_top_k=1,
    # The logged metric key is "val/rmse"; a "{val_rmse}" placeholder matches no
    # logged metric and would be rendered as 0.00 in every filename. Metric-name
    # auto-insertion is disabled because the "/" in the key would otherwise
    # create an extra sub-directory, so the labels are written out by hand.
    filename="epoch={epoch:02d}-val_rmse={val/rmse:.2f}-BEST",
    auto_insert_metric_name=False,
)

trainer = pl.Trainer(
    max_epochs=100,
    # "auto" still selects the GPU when one is available, but lets the
    # notebook run on machines without one.
    accelerator="auto",
    devices=1,
    callbacks=[early_stopping, checkpointing],
)
trainer.fit(mcmpnn, train_loader, val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name            | Type                         | Params | Mode 
-------------------------------------------------------------------------
0 | message_passing | MulticomponentMessagePassing | 467 K  | train
1 | agg             | MeanAggregation              | 0      | train
2 | bn              | Identity                     | 0      | train
3 | predictor       | RegressionFFN                | 181 K  | train
4 | X_d_transform   | Identity                     | 0      | train
5 | metrics         | ModuleList                   | 0      | train
-------------------------------------------------------------------------
649 K     Trainable params
0         Non-trainable params
649 K     Total params
2.597     Total estimated model params size (MB)
36        Modules in

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [65]:
# Evaluate on the held-out test split; `blue` receives trainer.test's return value.
# NOTE(review): this evaluates the in-memory `mcmpnn` weights as they are after
# fit(), not the best checkpoint saved by ModelCheckpoint — confirm this is intended.
blue = trainer.test(model=mcmpnn, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/calvin/miniforge3/envs/habnet_rocm/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]