In [1]:
%load_ext autoreload
%autoreload 2

## Advanced features
This section covers some of the more advanced features of the `molfeat` package. In what follows, `transformer` refers to a subclass of `MoleculeTransformer`.

### Processing

All molecule transformers in `molfeat` support pre-processing and post-processing of a list of molecules. This makes it possible, for example, to implement new transformers without changing a large part of the code. We demonstrate this with a toy transformer that computes the RDKit 3D descriptors (`desc3d`), which require conformers as input, and then applies a power transform to make the data more Gaussian-like.

In [2]:
import datamol as dm
import numpy as np
from functools import partial
from sklearn.preprocessing import PowerTransformer
from molfeat.trans.base import MoleculeTransformer
from loguru import logger

# reproducibility
CONF_PARAMS = dict(n_confs=5, clear_existing=True, minimize_energy=True, random_seed=12)


class MyCustomTransformer(MoleculeTransformer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ptrans = PowerTransformer()

    def preprocess(self, inputs, *args, **kwargs):
        # During preprocessing we need to compute the conformation of the molecule
        inputs = [dm.to_mol(x) for x in inputs]
        outputs = dm.parallelized(partial(dm.conformers.generate, **CONF_PARAMS), inputs, n_jobs=4)
        logger.info("In preprocess: computed the conformers")
        return outputs

    def postprocess(self, inputs, *args, **kwargs):
        # during postprocess we standardize the data
        outputs = self.ptrans.fit_transform(inputs)
        logger.info("In postprocess: normalized the data")
        return outputs

Using backend: pytorch


In [3]:
data = dm.data.freesolv().sample(n=100)

Let's first show that we cannot featurize molecules that have no conformers by using `MoleculeTransformer` directly:

In [14]:
trans = MoleculeTransformer(featurizer="desc3d", verbose=False)
trans.transform(data["smiles"], ignore_errors=False)

ValueError: Cannot transform molecule at index 0

Now with our new class:

In [None]:
trans = MyCustomTransformer(featurizer="desc3d")
X = trans.preprocess(data["smiles"])
X = trans.transform(X)
X = trans.postprocess(X)
X

2021-07-16 17:29:54.933 | INFO     | __main__:preprocess:20 - In preprocess: computed the conformers
  loglike = -n_samples / 2 * np.log(x_trans.var())
2021-07-16 17:29:55.837 | INFO     | __main__:postprocess:26 - In postprocess: normalized the data


array([[ 1.27336055,  0.9874605 ,  1.99189802, ..., -1.34248133,
        -1.38425571, -1.16710531],
       [-0.60835615, -0.3324303 , -0.23798693, ..., -0.64279521,
        -0.54665693, -0.59015871],
       [ 0.63046753,  0.75814346, -0.34787693, ...,  0.45775594,
         0.54136146,  0.61144396],
       ...,
       [ 1.32519232,  0.99850081,  1.1748335 , ...,  0.04142212,
        -0.11412962, -0.53027834],
       [ 0.53299287,  0.71157622, -0.08109695, ..., -0.02670614,
         0.11096503,  0.07506185],
       [ 1.36885527,  1.00718807,  2.09435387, ..., -1.51994557,
        -1.26664259, -1.22082345]])

### Callbacks
As you may have noticed, the above process is tedious and requires defining a new class. Moreover, if the same molecule is passed again, its conformers are regenerated.

To address this, transformers in `molfeat` support a `callbacks` option that is used to define the behaviour of `preprocess`, `postprocess` and even `get_collate_fn`.

The base callback class is available at `molfeat.utils.callbacks.FeatCallback`. You can also chain a list of callbacks through `molfeat.utils.callbacks.FeatCallbackList`.

By default, these callbacks do nothing. For convenience, however, a `ConformerCallback` is also provided. It is implemented efficiently and caches the conformers of previously seen molecules on disk.
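
To make the idea concrete, here is a minimal, library-free sketch of how a transformer could delegate its pre- and post-processing to a list of callbacks. Every name in it (`ToyCallback`, `StripCallback`, `ScaleCallback`, `ToyTransformer`, `on_preprocess`, `on_postprocess`) is hypothetical and chosen for illustration only; it does not reproduce the actual `FeatCallback` interface.

```python
class ToyCallback:
    """Hypothetical base callback: every hook is a no-op by default."""

    def on_preprocess(self, inputs):
        return inputs

    def on_postprocess(self, features):
        return features


class StripCallback(ToyCallback):
    """Cleans the raw input strings before featurization."""

    def on_preprocess(self, inputs):
        return [x.strip() for x in inputs]


class ScaleCallback(ToyCallback):
    """Rescales the computed features after featurization."""

    def __init__(self, factor):
        self.factor = factor

    def on_postprocess(self, features):
        return [[v * self.factor for v in row] for row in features]


class ToyTransformer:
    """Hypothetical transformer that runs its callbacks around featurization."""

    def __init__(self, callbacks=None):
        self.callbacks = list(callbacks or [])

    def _featurize(self, item):
        # Toy features: string length and number of "O" characters
        return [len(item), item.count("O")]

    def __call__(self, inputs):
        for cb in self.callbacks:
            inputs = cb.on_preprocess(inputs)
        features = [self._featurize(x) for x in inputs]
        for cb in self.callbacks:
            features = cb.on_postprocess(features)
        return features


toy = ToyTransformer(callbacks=[StripCallback(), ScaleCallback(0.5)])
toy_features = toy([" CCO ", "c1ccccc1"])
print(toy_features)
```

The point of the design is that new behaviour is added by composing small callback objects rather than by subclassing the transformer.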

In [6]:
from molfeat.utils.callbacks import ConformerCallback

cb = ConformerCallback(n_jobs=4, **CONF_PARAMS)

In [None]:
# let's clear the initial cache (requires `cb` to be defined first)
!rm -rf {cb.tmp_dir}/molfeat/utils/callbacks/

In [7]:
%%timeit -n 1 -r 1
trans = MoleculeTransformer(featurizer="desc3d", callbacks=cb, dtype=float)
X2 = trans(data["smiles"], ignore_errors=False, enforce_dtype=True)

2.77 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


Computations are cached to disk, and you can get a quick overview of the cache content. Note that the cache key combines the `inchikey` of the input molecule with the `parameters` of the conformer generation, which prevents erroneous cache hits.
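
As a rough illustration of why both pieces belong in the key, the sketch below derives a key from a molecule identifier and a parameter dictionary. The helper `conformer_cache_key` and its hashing scheme are assumptions made for this example; the actual key derivation used by `ConformerCallback` may differ.

```python
import hashlib
import json


def conformer_cache_key(inchikey, params):
    """Hypothetical helper: hash a molecule identifier together with the
    conformer generation parameters."""
    # sort_keys makes the key independent of the parameter ordering
    payload = json.dumps({"inchikey": inchikey, "params": params}, sort_keys=True)
    return hashlib.md5(payload.encode("utf-8")).hexdigest()


ETHANOL_KEY = "LFQSCWFLJHTTHZ-UHFFFAOYSA-N"  # InChIKey of ethanol

key_a = conformer_cache_key(ETHANOL_KEY, {"n_confs": 5, "random_seed": 12})
key_b = conformer_cache_key(ETHANOL_KEY, {"random_seed": 12, "n_confs": 5})
key_c = conformer_cache_key(ETHANOL_KEY, {"n_confs": 10, "random_seed": 12})

print(key_a == key_b)  # same molecule, same parameters
print(key_a == key_c)  # same molecule, different parameters
```

With such a scheme, changing `n_confs` for the same molecule yields a different key, so conformers generated under other settings are never returned by mistake.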

In [13]:
!tree {cb.tmp_dir}/molfeat | head -n 20

/Users/manu/Library/Caches/molfeat/conformers/molfeat
└── utils
    └── callbacks
        └── _generate_conformer
            ├── 00e04915ad95b437d6bca4e335c4bb62
            │   ├── metadata.json
            │   └── output.pkl
            ├── 0183520c8848a5f315f4a333cc0143ce
            │   ├── metadata.json
            │   └── output.pkl
            ├── 02c6467cf645a7b590f7f28923b0e5ed
            │   ├── metadata.json
            │   └── output.pkl
            ├── 0325f0027466d614fc79ddaaede073ed
            │   ├── metadata.json
            │   └── output.pkl
            ├── 036b771e817fd50fae26fefbd2fcd8da
            │   ├── metadata.json
            │   └── output.pkl
            ├── 0374f115da4cf1ee88e7c41e99e543c6


Let's compute again to make sure the cache has kicked in and that we get faster results.

In [9]:
%%timeit -n 1 -r 3
trans = MoleculeTransformer(featurizer="desc3d", callbacks=cb, dtype=float)
X2 = trans(data["smiles"], ignore_errors=False, enforce_dtype=True)

121 ms ± 4.98 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)


As a sanity check, we compare the output to the calculator output:

In [11]:
trans = MoleculeTransformer(featurizer="desc3d", callbacks=cb, dtype=float)
X2 = trans.transform(data["smiles"], ignore_errors=False, enforce_dtype=True)
X2[-1]

array([ 8.36296440e-01,  9.98321117e-01,  1.36344887e-01,  5.79219138e-02,
        9.42348490e-01,  6.91150588e+00,  1.12445303e+02,  1.19324543e+02,
        1.45956249e+00,  4.55100499e-04,  1.71447673e-02,  7.22000000e-01,
        8.02000000e-01,  6.05000000e-01,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  8.77000000e-01,  1.06800000e+00,  1.07300000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  5.89000000e-01,
        5.73000000e-01,  3.09000000e-01,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  8.75000000e-01,  1.06400000e+00,  1.06500000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  4.67000000e-01,
        3.64000000e-01,  

### Caching feature computation

Feature calculation can be cached natively to speed up repeated featurization of the same dataset. `molfeat` provides a wrapper class for that purpose.

The internal cache can be either an instance of `DataCache` or a dict mapping each molecule to its features.

Note that it's possible to first compute the information in the cache, then use the pre-filled cache to instantiate a new transformer.

In [23]:
from molfeat.trans.base import PrecomputedMolTransformer
from molfeat.utils.cache import DataCache
import time

trans = MoleculeTransformer("desc2d")
cache = DataCache(name="desc2d_transformer")

precomp = PrecomputedMolTransformer(cache=cache, featurizer=trans)
molecules = data["smiles"].values

t1 = time.time()
out1 = precomp.transform(molecules)
elapsed1 = time.time() - t1
print("Elapsed time: {:.3f} s".format(elapsed1))

# time the second pass separately; it should be served from the cache
t2 = time.time()
out2 = precomp.transform(molecules)
elapsed2 = time.time() - t2
print("Elapsed time after caching: {:.3f} s".format(elapsed2))
assert np.all(out1 == out2)

Elapsed time: 0.727 s
Elapsed time after caching: 0.025 s
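
The caching behaviour shown above, including the reuse of a pre-filled cache by a new transformer, can be sketched in plain Python with a dict as the cache. `ToyPrecomputed` and `toy_featurizer` are hypothetical stand-ins written for this illustration; they are not the `PrecomputedMolTransformer` implementation.

```python
class ToyPrecomputed:
    """Hypothetical wrapper that looks features up in a dict before computing."""

    def __init__(self, featurizer, cache=None):
        self.featurizer = featurizer
        self.cache = {} if cache is None else cache
        self.n_computed = 0  # number of real featurizer calls

    def transform(self, molecules):
        out = []
        for mol in molecules:
            if mol not in self.cache:
                self.cache[mol] = self.featurizer(mol)
                self.n_computed += 1
            out.append(self.cache[mol])
        return out


def toy_featurizer(smiles):
    # Toy features: string length and number of "O" characters
    return [len(smiles), smiles.count("O")]


first = ToyPrecomputed(toy_featurizer)
out_a = first.transform(["CCO", "CC(=O)O", "CCO"])

# A new wrapper built on the pre-filled cache does not recompute anything
second = ToyPrecomputed(toy_featurizer, cache=first.cache)
out_b = second.transform(["CCO", "CC(=O)O"])

print(first.n_computed, second.n_computed)
```

The duplicate `"CCO"` in the first call is computed only once, and the second wrapper performs no computation at all.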


In [19]:
out1

[array([ 5.13194444e+00,  9.24444444e-01,  5.13194444e+00,  9.24444444e-01,
         4.91869888e-01,  1.02177000e+02,  8.80650000e+01,  1.02104465e+02,
         4.40000000e+01,  0.00000000e+00,  4.62900700e-02, -3.81474012e-01,
         3.81474012e-01,  4.62900700e-02,  1.00000000e+00,  1.42857143e+00,
         1.71428571e+00,  1.64727595e+01,  1.04907047e+01,  1.77425267e+00,
        -1.94896098e+00,  1.81010376e+00, -1.91756393e+00,  4.24744837e+00,
         1.35272616e-01,  2.44747289e+00,  2.33609904e+01,  5.53553391e+00,
         5.23667542e+00,  5.23667542e+00,  3.41421356e+00,  2.99156383e+00,
         2.99156383e+00,  1.61237244e+00,  1.61237244e+00,  6.96923425e-01,
         6.96923425e-01,  3.90737207e-01,  3.90737207e-01, -4.00000000e-02,
         3.55096099e+01,  6.96000000e+00,  5.96000000e+00,  5.96000000e+00,
         4.56775068e+01,  4.73686295e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.0

### Transformer concatenation

Another interesting feature offered by `molfeat` is the ability to concatenate multiple featurizers. Feature concatenation has some limitations, the main one being the inability to set the parameters of all transformers in a single call. It is therefore not compatible with the scikit-learn `GridSearchCV` API, and you will need to handle parameter updates of the concatenated featurizers yourself.
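
One way to handle such updates yourself is to address each featurizer by its position in the concatenation. The following library-free sketch shows the pattern; `ToyCountFeaturizer`, `ToyConcat` and `update_params` are hypothetical names for this illustration and are not part of the `FeatConcat` API.

```python
class ToyCountFeaturizer:
    """Hypothetical featurizer counting each character listed in `symbols`."""

    def __init__(self, symbols):
        self.symbols = symbols

    def set_params(self, **params):
        for name, value in params.items():
            if not hasattr(self, name):
                raise ValueError(f"Unknown parameter: {name}")
            setattr(self, name, value)
        return self

    def __call__(self, smiles):
        return [smiles.count(s) for s in self.symbols]


class ToyConcat:
    """Hypothetical concatenation of several featurizers."""

    def __init__(self, featurizers):
        self.featurizers = list(featurizers)

    def update_params(self, per_featurizer):
        # Apply {index: {param: value}} updates, one featurizer at a time
        for index, params in per_featurizer.items():
            self.featurizers[index].set_params(**params)
        return self

    def __call__(self, smiles):
        row = []
        for feat in self.featurizers:
            row.extend(feat(smiles))
        return row


concat = ToyConcat([ToyCountFeaturizer("CO"), ToyCountFeaturizer("N")])
before = concat("CC(=O)N")

# Only the second featurizer is reconfigured
concat.update_params({1: {"symbols": "N="}})
after = concat("CC(=O)N")

print(before, after)
```

Because the update is keyed by position, a parameter name shared by several featurizers stays unambiguous.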


In [12]:
## To make this interesting, let's combine some descriptors that could give NaN results that we might want to remove

In [13]:
from molfeat.trans.concat import FeatConcat
from molfeat.trans.fp import FPVecTransformer, FPVecFilteredTransformer
from molfeat.trans.base import BaseFeaturizer

desc = FPVecTransformer("desc3D", callbacks=cb)
ecfp = FPVecTransformer("ecfp:4")
mord = FPVecFilteredTransformer("mordred", del_invariant=True, ignore_3D=True)

In [14]:
# use a dataframe as the datatype
cat_fp = FeatConcat([desc, ecfp], dtype="pandas")
cat_fp.append(mord)

In [15]:
# print initial columns
len(cat_fp.columns)

4252

In [16]:
cat_fp.fit(data["smiles"])

[FPVecTransformer(kind="desc3D", length=2000, dtype=np.float32),
 FPVecTransformer(kind="ecfp:4", length=2000, dtype=np.float32),
 FPVecFilteredTransformer (kind="mordred", length=2000, occ_threshold=0, del_invariant=True, dtype=np.float32)]

In [17]:
# print columns after fitting
len(cat_fp.columns)

3404

In [18]:
df, idx = cat_fp(data["smiles"], ignore_errors=True, enforce_dtype=True)

In [19]:
df

|  | CalcAsphericity | CalcEccentricity | CalcInertialShapeFactor | CalcNPR1 | CalcNPR2 | CalcPMI1 | CalcPMI2 | CalcPMI3 | CalcRadiusOfGyration | CalcSpherocityIndex | ... | SRW09 | SRW10 | TSRW10 | MW | AMW | WPath | WPol | Zagreb1 | Zagreb2 | mZagreb2 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 0.811071 | 0.997717 | 0.086153 | 0.067538 | 0.932462 | 10.823307 | 149.432492 | 160.255798 | 1.390225 | 3.734071e-18 | ... | 0.0 | 4.174387 | 17.310771 | 83.953355 | 16.790671 | 4.0 | 0.0 | 6.0 | 4.0 | 1.000000 |
| 1 | 0.339734 | 0.947944 | 0.008826 | 0.318438 | 0.716364 | 81.165741 | 182.591842 | 254.887112 | 1.716004 | 5.546652e-02 | ... | 0.0 | 8.124151 | 33.544698 | 100.088815 | 5.267832 | 42.0 | 5.0 | 30.0 | 31.0 | 1.666667 |
| 2 | 0.644469 | 0.990454 | 0.007675 | 0.137845 | 0.945576 | 123.196831 | 845.096157 | 893.736562 | 2.615380 | 1.288769e-01 | ... | 0.0 | 8.906935 | 40.567492 | 150.104465 | 6.004179 | 162.0 | 13.0 | 50.0 | 55.0 | 2.611111 |
| 3 | 0.807545 | 0.997624 | 0.011258 | 0.068898 | 0.931123 | 82.709180 | 1117.766940 | 1200.450706 | 2.970339 | 3.110447e-05 | ... | 0.0 | 7.933438 | 36.894490 | 146.057909 | 7.302895 | 151.0 | 7.0 | 38.0 | 36.0 | 2.416667 |
| 4 | 0.760002 | 0.996152 | 0.012188 | 0.087638 | 0.933331 | 76.580622 | 815.569036 | 873.826546 | 2.758088 | 3.157529e-02 | ... | 0.0 | 7.609367 | 34.745525 | 130.099380 | 5.656495 | 114.0 | 6.0 | 32.0 | 30.0 | 2.333333 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 95 | 0.415542 | 0.965014 | 0.012094 | 0.262197 | 0.940776 | 77.785722 | 279.099207 | 296.669007 | 1.971456 | 2.764084e-01 | ... | 0.0 | 7.655864 | 32.211905 | 100.125201 | 4.353270 | 48.0 | 4.0 | 26.0 | 24.0 | 1.666667 |
| 96 | 0.499001 | 0.977626 | 0.002435 | 0.210349 | 0.814517 | 334.468670 | 1295.135895 | 1590.066114 | 2.909762 | 3.092651e-02 | ... | 0.0 | 9.496496 | 46.853672 | 201.078979 | 7.733807 | 362.0 | 21.0 | 74.0 | 85.0 | 3.472222 |
| 97 | 0.824747 | 0.998055 | 0.032237 | 0.062334 | 0.958168 | 29.723052 | 456.888388 | 476.835478 | 2.365731 | 3.014104e-02 | ... | 0.0 | 6.900731 | 30.257210 | 101.120449 | 4.596384 | 56.0 | 4.0 | 22.0 | 20.0 | 2.000000 |
| 98 | 0.619703 | 0.988923 | 0.010568 | 0.148433 | 0.903536 | 85.497555 | 520.438738 | 576.001906 | 2.296174 | 8.252447e-02 | ... | 0.0 | 8.379998 | 36.722228 | 122.073165 | 6.424903 | 94.0 | 8.0 | 38.0 | 40.0 | 2.250000 |


### Collate function

Most molecule transformers also provide their own collate function for PyTorch data loaders.

In this example, we use the default collate function of an adjacency graph transformer and show how the features are collated depending on the provided arguments.

In [20]:
import torch
from molfeat.trans.graph.adj import AdjGraphTransformer

In [25]:
trans = AdjGraphTransformer(explicit_hydrogens=False, self_loop=True, dtype=torch.float)
graphs, ids = trans(data["smiles"], ignore_errors=True)

In [67]:
graphs[0]

(tensor([[1., 1., 1.],
         [1., 1., 0.],
         [1., 0., 1.]]),
 tensor([[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
         [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       

In [68]:
from torch.utils.data import DataLoader

By default, batching pads each graph to a maximum size. When the maximum size is not provided, it is determined by the **largest molecule in the batch**.

In [72]:
trans.get_collate_fn()([graphs[i] for i in range(5)])

ValueError: not enough values to unpack (expected 3, got 2)

In [80]:
loader = DataLoader(graphs, batch_size=32, collate_fn=trans.get_collate_fn())
for i, batch in enumerate(loader):
    g_batch, f_batch = batch
    if i > 1:
        break
    print(i, g_batch.shape, f_batch.shape)

0 torch.Size([32, 19, 19]) torch.Size([32, 19, 82])
1 torch.Size([32, 24, 24]) torch.Size([32, 24, 82])
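
The shapes above follow from padding every graph in a batch to the size of its largest member. The sketch below reproduces that logic on nested Python lists rather than tensors; `pad_adjacency`, `pad_features` and `pad_collate` are hypothetical helpers for illustration, not the function returned by `get_collate_fn`.

```python
def pad_adjacency(adj, size):
    """Zero-pad a square adjacency matrix (nested lists) to size x size."""
    n = len(adj)
    return [
        [adj[i][j] if i < n and j < n else 0.0 for j in range(size)]
        for i in range(size)
    ]


def pad_features(feats, size):
    """Append zero rows so that the node feature matrix has `size` rows."""
    width = len(feats[0])
    padding = [[0.0] * width for _ in range(size - len(feats))]
    return [list(row) for row in feats] + padding


def pad_collate(batch, max_size=None):
    """Pad every (adjacency, features) pair to a common number of nodes."""
    # Without an explicit maximum, use the largest graph in the batch
    size = max_size or max(len(adj) for adj, _ in batch)
    adjs = [pad_adjacency(adj, size) for adj, _ in batch]
    feats = [pad_features(f, size) for _, f in batch]
    return adjs, feats


# Two toy graphs with 2 and 3 nodes, each node carrying 2 features
g2 = ([[1.0, 1.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
g3 = (
    [[1.0, 1.0, 0.0], [1.0, 1.0, 1.0], [0.0, 1.0, 1.0]],
    [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]],
)

padded_adjs, padded_feats = pad_collate([g2, g3])
print(len(padded_adjs), len(padded_adjs[0]), len(padded_feats[0]))
```

This is why the node dimension differs between batches in the output above: each batch is padded only up to its own largest molecule.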


If we instead choose to pack the graphs, each batch becomes a single giant supergraph:

In [81]:
loader = DataLoader(graphs, batch_size=64, collate_fn=trans.get_collate_fn(pack=True))
for i, batch in enumerate(loader):
    g_batch, f_batch = batch
    if i > 1:
        break
    print(i, g_batch.shape, f_batch.shape)

0 torch.Size([582, 582]) torch.Size([582, 82])
1 torch.Size([310, 310]) torch.Size([310, 82])
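
Packing can be pictured as placing each adjacency matrix on the diagonal of one larger matrix and stacking the node features, so the packed size is the total number of nodes in the batch. The sketch below uses nested Python lists and assumes this block-diagonal layout; `pack_collate` is a hypothetical helper, not the function returned by `get_collate_fn(pack=True)`.

```python
def pack_collate(batch):
    """Merge (adjacency, features) pairs into one block-diagonal supergraph."""
    total = sum(len(adj) for adj, _ in batch)
    packed_adj = [[0.0] * total for _ in range(total)]
    packed_feats = []
    offset = 0
    for adj, feats in batch:
        n = len(adj)
        # Copy this graph's adjacency into its own diagonal block
        for i in range(n):
            for j in range(n):
                packed_adj[offset + i][offset + j] = adj[i][j]
        packed_feats.extend(list(row) for row in feats)
        offset += n
    return packed_adj, packed_feats


# Two toy graphs with 2 and 3 nodes, each node carrying 2 features
small = ([[1.0, 1.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
large = (
    [[1.0, 1.0, 0.0], [1.0, 1.0, 1.0], [0.0, 1.0, 1.0]],
    [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]],
)

packed_adj, packed_feats = pack_collate([small, large])
print(len(packed_adj), len(packed_adj[0]), len(packed_feats))
```

No entry connects nodes from different molecules, so the supergraph is a set of disconnected components and needs no padding.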
