# Finding alternatives to DGLLife for Solubility prediction

In [1]:
%reload_ext autoreload
%autoreload 2

import rdkit
import dgym as dg

# load all data
path = '../../dgym-data'

deck = dg.MoleculeCollection.load(
    f'{path}/DSi-Poised_Library_annotated.sdf',
    reactant_names=['reagsmi1', 'reagsmi2', 'reagsmi3']
)

reactions = dg.ReactionCollection.from_json(
    path = f'{path}/All_Rxns_rxn_library_sorted.json',
    smarts_col = 'reaction_string',
    classes_col = 'functional_groups'
)

building_blocks = dg.datasets.disk_loader(f'{path}/Enamine_Building_Blocks_Stock_262336cmpd_20230630.sdf')
fingerprints = dg.datasets.fingerprints(f'{path}/Enamine_Building_Blocks_Stock_262336cmpd_20230630_atoms.fpb')

import torch
import pyarrow.parquet as pq
table = pq.read_table('../../dgym-data/sizes.parquet')[0]
sizes = torch.tensor(table.to_numpy())

Using AqSolDB.

In [2]:
from tdc.single_pred import ADME
data = ADME(name = 'Solubility_AqSolDB')
split = data.get_split()

Found local copy...
Loading...
Done!


In [38]:
import cloudpickle as cp
from urllib.request import urlopen

with request.urlopen(
    'https://github.com/NilavoBoral/Therapeutics-Data-Commons/raw/main/best_model_series.pkl'
) as f:
    loaded_pickle_object = cp.load(f)

In [53]:
# d = loaded_pickle_object[2]
# d.fit(X)

In [72]:
%%time
import numpy as np
from sklearn.preprocessing import normalize
from scikit_mol.descriptors import MolecularDescriptorTransformer

desc_list = [
    'ExactMolWt', 'FpDensityMorgan1', 'FpDensityMorgan2', 'FpDensityMorgan3',
    'HeavyAtomMolWt', 'MaxAbsPartialCharge', 'MaxAbsPartialCharge', 'MinAbsPartialCharge',
    'MinPartialCharge', 'MolWt', 'NumRadicalElectrons', 'NumValenceElectrons',
    'MolLogP', 'FractionCSP3', 'HeavyAtomCount', 'NHOHCount', 'NOCount', 'NumAliphaticCarbocycles',
    'NumAliphaticHeterocycles', 'NumAliphaticRings', 'NumAromaticCarbocycles', 'NumAromaticHeterocycles',
    'NumAromaticRings', 'NumHAcceptors', 'NumHDonors', 'NumHeteroatoms', 'NumRotatableBonds',
    'NumSaturatedCarbocycles', 'NumSaturatedHeterocycles', 'NumSaturatedRings', 'RingCount',
]

transformer = MolecularDescriptorTransformer(
    desc_list, parallel=True)

X = transformer.transform([
    rdkit.Chem.MolFromSmiles(d)
    for d in split['train'].Drug
])

X = normalize(np.nan_to_num(X))
y = split['train']['Y']

CPU times: user 842 ms, sys: 237 ms, total: 1.08 s
Wall time: 1.46 s


In [78]:
from catboost import CatBoostRegressor

cb = CatBoostRegressor()
cb.fit(X, y)

Learning rate set to 0.055666
0:	learn: 2.3130387	total: 3.3ms	remaining: 3.3s
1:	learn: 2.2430673	total: 4.56ms	remaining: 2.27s
2:	learn: 2.1765152	total: 5.97ms	remaining: 1.99s
3:	learn: 2.1143769	total: 7.54ms	remaining: 1.88s
4:	learn: 2.0597623	total: 9.02ms	remaining: 1.79s
5:	learn: 2.0059922	total: 10.4ms	remaining: 1.72s
6:	learn: 1.9576766	total: 11.7ms	remaining: 1.66s
7:	learn: 1.9124469	total: 13.1ms	remaining: 1.62s
8:	learn: 1.8702389	total: 28.7ms	remaining: 3.16s
9:	learn: 1.8309049	total: 30.4ms	remaining: 3.01s
10:	learn: 1.7942109	total: 33.2ms	remaining: 2.98s
11:	learn: 1.7567925	total: 35.2ms	remaining: 2.9s
12:	learn: 1.7219722	total: 37.2ms	remaining: 2.82s
13:	learn: 1.6930987	total: 39ms	remaining: 2.75s
14:	learn: 1.6630898	total: 41.5ms	remaining: 2.72s
15:	learn: 1.6357107	total: 43.9ms	remaining: 2.7s
16:	learn: 1.6104465	total: 46.6ms	remaining: 2.69s
17:	learn: 1.5872915	total: 56.7ms	remaining: 3.1s
18:	learn: 1.5637680	total: 58.7ms	remaining: 3.03s

<catboost.core.CatBoostRegressor at 0x7f9a938988d0>

In [75]:
from scipy.stats import pearsonr

pearsonr(cb.predict(X), y)

PearsonRResult(statistic=0.955394360247308, pvalue=0.0)

In [77]:
from sklearn.metrics import mean_absolute_error

mean_absolute_error(y, cb.predict(X))

0.5270075350068456

Now testing.

In [88]:
X_test = transformer.transform([
    rdkit.Chem.MolFromSmiles(d)
    for d in split['test'].Drug
])

X_test = normalize(np.nan_to_num(X_test))
y_test = split['test']['Y']
y_pred = cb.predict(X_test)

In [90]:
pearsonr(y_test, y_pred)

PearsonRResult(statistic=0.8856841570973168, pvalue=0.0)

In [91]:
mean_absolute_error(y_test, y_pred)

0.7643665682693718