In [1]:
from chemprop.data.datasets import MoleculeDataset, ReactionDataset, MulticomponentDataset

### Datasets

To make a dataset you first need a list of datapoints. See the datapoint notebook for more details.

In [2]:
import numpy as np
from chemprop.data import MoleculeDatapoint, ReactionDatapoint

# Random targets: one row (a single target value) per datapoint
ys = np.random.rand(2, 1)

# Molecule datapoints: each SMILES string is paired with the target row at the same index
smis = ["C", "CC"]
mol_datapoints = [MoleculeDatapoint.from_smi(smi, ys[idx]) for idx, smi in enumerate(smis)]

# Reaction datapoints: built from atom-mapped reaction SMILES, reusing the same target rows
rxn_smis = [
    "[H:2][O:1][H:3]>>[H:2][O:1].[H:3]",
    "[H:2][S:1][H:3]>>[H:2][S:1].[H:3]",
]
rxn_datapoints = [
    ReactionDatapoint.from_smi(rxn_smi, ys[idx], keep_h=True)
    for idx, rxn_smi in enumerate(rxn_smis)
]

In [3]:
# Wrap the molecule datapoints in a dataset; no featurizer is given, so the default is used
MoleculeDataset(data=mol_datapoints)

MoleculeDataset(data=[MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7f3002bd51c0>, y=array([0.14280278]), weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='C', V_f=None, E_f=None, V_d=None), MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7f3002bd52a0>, y=array([0.86053701]), weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='CC', V_f=None, E_f=None, V_d=None)], featurizer=SimpleMoleculeMolGraphFeaturizer(atom_featurizer=<chemprop.featurizers.atom.MultiHotAtomFeaturizer object at 0x7f30843ec3d0>, bond_featurizer=<chemprop.featurizers.bond.MultiHotBondFeaturizer object at 0x7f30843e6350>))

In [4]:
# Wrap the reaction datapoints in a dataset; no featurizer is given, so the default is used
ReactionDataset(data=rxn_datapoints)

ReactionDataset(data=[ReactionDatapoint(rct=<rdkit.Chem.rdchem.Mol object at 0x7f3002bd5380>, pdt=<rdkit.Chem.rdchem.Mol object at 0x7f3002bd53f0>, y=array([0.14280278]), weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='[H:2][O:1][H:3]>>[H:2][O:1].[H:3]'), ReactionDatapoint(rct=<rdkit.Chem.rdchem.Mol object at 0x7f3002bd54d0>, pdt=<rdkit.Chem.rdchem.Mol object at 0x7f3002bd5540>, y=array([0.86053701]), weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='[H:2][S:1][H:3]>>[H:2][S:1].[H:3]')], featurizer=CondensedGraphOfReactionFeaturizer(atom_featurizer=<chemprop.featurizers.atom.MultiHotAtomFeaturizer object at 0x7f3002bdf190>, bond_featurizer=<chemprop.featurizers.bond.MultiHotBondFeaturizer object at 0x7f3002bdf150>))

### Dataset properties

The properties of datapoints are collated in a dataset.

In [5]:
# Build the dataset, then show two collated properties: the targets and the datapoint names
dataset = MoleculeDataset(data=mol_datapoints)
print(dataset.Y, dataset.names, sep="\n")

[[0.14280278]
 [0.86053701]]
['C', 'CC']


Datasets return a `Datum` when indexed. A `Datum` contains a `MolGraph` (see the molgraph featurizer notebook), the extra atom- and datapoint-level descriptors, the target(s), the weight, and the masks for bounded loss functions.

In [6]:
# Index the dataset to get the `Datum` for the first datapoint
dataset[0]

Datum(mg=MolGraph(V=array([[0.     , 0.     , 0.     , 0.     , 0.     , 1.     , 0.     ,
        0.     , 0.     , 0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.     , 0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.     , 0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.     , 0.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.     , 0.     , 0.     , 0.     , 0.     , 0.     ,
        1.     , 0.     , 0.     , 0.     , 0.     , 0.     , 0.     ,
        1.     , 0.     , 1.     , 0.     , 0.     , 0.     , 0.     ,
        0.     , 0.     , 0.     , 0.     , 1.     , 0.     , 0.     ,
        0.     , 0.     , 0.     , 1.     , 0.     , 0.     , 0.     ,
        0.     , 0.12011]], dtype=float32), E=array([], shape=(0, 14), dtype=float64), edge_index=array([], shape=(2, 0), dtype=int64), rev_edge_index=array([], dtype=int64)), V_d=None, x_d=None, y=array([0.14280278]), weight=1.0, lt_mask=None, gt_mask=None)

### Caching

The `MolGraph`s are generated as needed by default. For small to medium datasets (exact sizes not yet benchmarked), it is more efficient to generate and cache the molgraphs when the dataset is created.

If the cache needs to be recreated, set the cache to True again. To clear the cache, set it to False. 

Note: we recommend scaling additional atom and bond features (see the scaling notebook for details) before setting the cache. Scaling them after caching requires the cache to be recreated; this is done automatically, but it repeats the featurization work.

In [7]:
# MolGraph caching is controlled by assigning to the dataset's `cache` attribute.
dataset.cache = True  # Generate the molgraphs and cache them
dataset.cache = True  # Recreate the cache
dataset.cache = False  # Clear the cache

# Scaling inputs after the cache exists: the cache is rebuilt for you.
# NOTE(review): the datapoints created earlier in this notebook show E_f=None, so there are
# presumably no extra bond features to scale in this example — confirm this call is a no-op here.
dataset.cache = True  # Cache created with unscaled extra bond features
dataset.normalize_inputs(key="E_f")  # Cache recreated automatically with scaled extra bond features

### Datasets with custom featurizers

`SimpleMoleculeMolGraphFeaturizer` and `CondensedGraphOfReactionFeaturizer` (`CGRFeaturizer`) are the default featurizers for `MoleculeDataset` and `ReactionDataset`, respectively. Customized featurizers can also be given when making a dataset.

In [8]:
from chemprop.featurizers import SimpleMoleculeMolGraphFeaturizer, MultiHotAtomFeaturizer

# Swap the default atom featurizer for the one built by `MultiHotAtomFeaturizer.v1()`
v1_atom_featurizer = MultiHotAtomFeaturizer.v1()
mol_featurizer = SimpleMoleculeMolGraphFeaturizer(atom_featurizer=v1_atom_featurizer)

# Pass the customized featurizer explicitly when building the dataset
MoleculeDataset(data=mol_datapoints, featurizer=mol_featurizer)

MoleculeDataset(data=[MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7f3002bd51c0>, y=array([0.14280278]), weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='C', V_f=None, E_f=None, V_d=None), MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7f3002bd52a0>, y=array([0.86053701]), weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='CC', V_f=None, E_f=None, V_d=None)], featurizer=SimpleMoleculeMolGraphFeaturizer(atom_featurizer=<chemprop.featurizers.atom.MultiHotAtomFeaturizer object at 0x7f3002b77550>, bond_featurizer=<chemprop.featurizers.bond.MultiHotBondFeaturizer object at 0x7f3002bfc410>))

### Multicomponent datasets

`MulticomponentDataset` is for datasets whose target values depend on multiple components. It is composed of parallel `MoleculeDataset`s and `ReactionDataset`s.

In [9]:
# Component datasets that will be combined in parallel
mol_dataset = MoleculeDataset(data=mol_datapoints)
rxn_dataset = ReactionDataset(data=rxn_datapoints)

# e.g. reaction in solvent
reaction_in_solvent_components = [mol_dataset, rxn_dataset]
multi_dataset = MulticomponentDataset(datasets=reaction_in_solvent_components)

# e.g. solubility
solubility_components = [mol_dataset, mol_dataset]
MulticomponentDataset(datasets=solubility_components)

<chemprop.data.datasets.MulticomponentDataset at 0x7f3002bffb90>

`MulticomponentDataset`s return a list of `Datum`s when indexed and collate the properties of their datasets. Note that datapoint-level properties, like target values and extra datapoint descriptors, are only retrieved from the first dataset in `datasets`.

In [10]:
# Index the multicomponent dataset to get a list of `Datum`s, one per component dataset
multi_dataset[0]

[Datum(mg=MolGraph(V=array([[0.     , 0.     , 0.     , 0.     , 0.     , 1.     , 0.     ,
         0.     , 0.     , 0.     , 0.     , 0.     , 0.     , 0.     ,
         0.     , 0.     , 0.     , 0.     , 0.     , 0.     , 0.     ,
         0.     , 0.     , 0.     , 0.     , 0.     , 0.     , 0.     ,
         0.     , 0.     , 0.     , 0.     , 0.     , 0.     , 0.     ,
         0.     , 0.     , 0.     , 0.     , 0.     , 0.     , 0.     ,
         1.     , 0.     , 0.     , 0.     , 0.     , 0.     , 0.     ,
         1.     , 0.     , 1.     , 0.     , 0.     , 0.     , 0.     ,
         0.     , 0.     , 0.     , 0.     , 1.     , 0.     , 0.     ,
         0.     , 0.     , 0.     , 1.     , 0.     , 0.     , 0.     ,
         0.     , 0.12011]], dtype=float32), E=array([], shape=(0, 14), dtype=float64), edge_index=array([], shape=(2, 0), dtype=int64), rev_edge_index=array([], dtype=int64)), V_d=None, x_d=None, y=array([0.14280278]), weight=1.0, lt_mask=None, gt_mask=None),

In [11]:
# Target values held by the first component dataset
multi_dataset.datasets[0].Y

array([[0.14280278],
       [0.86053701]])

In [12]:
# Collated SMILES: one tuple per datapoint, with one entry per component dataset
multi_dataset.smiles

[('C', ('[O:1]([H:2])[H:3]', '[H:3].[O:1][H:2]')),
 ('CC', ('[S:1]([H:2])[H:3]', '[H:3].[S:1][H:2]'))]