### Dataloaders

In [1]:
from chemprop.data.dataloader import build_dataloader

This is an example [dataset](./datasets.ipynb) to load.

In [2]:
import numpy as np
from chemprop.data import MoleculeDatapoint, MoleculeDataset

smis = ["C" * i for i in range(1, 4)]
ys = np.random.rand(len(smis), 1)
dataset = MoleculeDataset([MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)])

### Torch dataloaders

Chemprop uses native `torch.utils.data.DataLoader`s to batch data as input to a model. `build_dataloader` is a helper function that constructs one.

In [3]:
dataloader = build_dataloader(dataset)

`build_dataloader` changes the defaults of `DataLoader` to use a batch size of 64 and to turn on shuffling. It also automatically selects the correct collate function for the dataset (single-component vs. multi-component).

In [4]:
from torch.utils.data import DataLoader
from chemprop.data.collate import collate_batch, collate_multicomponent

dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=True, collate_fn=collate_batch)

### Collate function

The collate function takes an iterable of dataset outputs and batches them together. Iterating through batches is done automatically during training by the Lightning `Trainer`.

In [5]:
collate_batch([dataset[0], dataset[1]])

TrainingBatch(bmg=<chemprop.data.collate.BatchMolGraph object at 0x7f29512d50d0>, V_d=None, X_d=None, Y=tensor([[0.9638],
        [0.5570]]), w=tensor([[1.],
        [1.]]), lt_mask=None, gt_mask=None)
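
To make the batching step concrete, here is a minimal plain-Python sketch of what a collate function does: it regroups a sequence of per-datapoint records into per-field containers. `ToyDatum`, `ToyBatch`, and `toy_collate` are hypothetical names used only for illustration; this is not how `collate_batch` is implemented, and the real function returns tensors and a `BatchMolGraph`.

```python
from typing import NamedTuple, Optional, Sequence


class ToyDatum(NamedTuple):
    """Hypothetical stand-in for one dataset output."""
    graph: str            # placeholder for a featurized molecular graph
    y: Optional[list]     # targets, or None when the datapoint is unlabeled
    weight: float         # per-datapoint loss weight


class ToyBatch(NamedTuple):
    """Hypothetical stand-in for a collated batch."""
    graphs: list
    Y: Optional[list]
    w: list


def toy_collate(items: Sequence[ToyDatum]) -> ToyBatch:
    """Regroup per-datapoint records into per-field batch containers."""
    graphs = [item.graph for item in items]
    # Assumption of this sketch: targets are batched only if every item has them.
    Y = None if any(item.y is None for item in items) else [item.y for item in items]
    # One row per datapoint, mirroring the column shape of `w` shown above.
    w = [[item.weight] for item in items]
    return ToyBatch(graphs, Y, w)


batch = toy_collate([ToyDatum("C", [0.96], 1.0), ToyDatum("CC", [0.56], 1.0)])
print(batch)
```

Each field of the result has one entry per input datapoint, which is the same layout as the `Y` and `w` fields of the `TrainingBatch` above.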

### Shuffling

Shuffling the data helps improve model training, so `build_dataloader` has `shuffle=True` as the default. Shuffling should be turned off for validation and test dataloaders. Lightning gives a warning if a dataloader with shuffling is used during prediction.

In [6]:
train_loader = build_dataloader(dataset)
val_loader = build_dataloader(dataset, shuffle=False)
test_loader = build_dataloader(dataset, shuffle=False)

In [7]:
from lightning import pytorch as pl
from chemprop import models, nn

trainer = pl.Trainer(logger=False, enable_checkpointing=False, max_epochs=1)
chemprop_model = models.MPNN(nn.BondMessagePassing(), nn.MeanAggregation(), nn.RegressionFFN())

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [8]:
preds = trainer.predict(chemprop_model, dataloader)

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:492: Your `predict_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 18.53it/s]


In [9]:
preds = trainer.predict(chemprop_model, test_loader)

Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 56.38it/s]


### Parallel data loading

As datapoints are sampled from the dataset, the `MolGraph` data structures are generated on the fly, which requires featurization of the molecular graphs. Giving the dataloader multiple workers can increase dataloading speed by preparing the datapoints in parallel. Note that this is not compatible with Windows (the process hangs) and some versions of macOS.

[Caching](./dataloaders.ipynb) the `MolGraph`s in the dataset before making the dataloader can also speed up sequential dataloading (`num_workers=0`).

In [10]:
build_dataloader(dataset, num_workers=8)

dataset.cache = True
build_dataloader(dataset)

<torch.utils.data.dataloader.DataLoader at 0x7f2918dfca50>
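
The benefit of caching can be illustrated with a small plain-Python sketch that counts how often a featurization step runs over two passes through the data. `ToyDataset` and its attributes are hypothetical, and this toy caches lazily on first access; Chemprop's own caching strategy may differ.

```python
class ToyDataset:
    """Hypothetical dataset that featurizes items on access and can cache results."""

    def __init__(self, smis, cache=False):
        self.smis = list(smis)
        self.cache = cache
        self._store = {}              # idx -> featurized item
        self.n_featurizations = 0     # counts calls to the expensive step

    def _featurize(self, smi):
        # Placeholder for an expensive molecular graph featurization.
        self.n_featurizations += 1
        return {"n_atoms": len(smi)}

    def __len__(self):
        return len(self.smis)

    def __getitem__(self, idx):
        if not self.cache:
            return self._featurize(self.smis[idx])
        if idx not in self._store:
            self._store[idx] = self._featurize(self.smis[idx])
        return self._store[idx]


uncached = ToyDataset(["C", "CC", "CCC"], cache=False)
cached = ToyDataset(["C", "CC", "CCC"], cache=True)

# Simulate two epochs of sequential loading.
for _ in range(2):
    for i in range(len(uncached)):
        uncached[i]
        cached[i]

print(uncached.n_featurizations, cached.n_featurizations)  # prints: 6 3
```

Without caching, the work is repeated every epoch; with caching, it is done once per datapoint at the cost of holding the featurized items in memory.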

### Drop last batch

`build_dataloader` drops the last batch if it contains a single datapoint, because batch normalization (the default) requires at least two datapoints. If you do not want to drop the last datapoint, you can adjust the batch size or, if you aren't using batch normalization, build the dataloader manually.

In [11]:
dataloader = build_dataloader(dataset, batch_size=2)



In [12]:
dataloader = build_dataloader(dataset, batch_size=3)
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=collate_batch)
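
The rule described above can be sketched in plain Python by computing the batch sizes a loader would yield. `batch_sizes` and its `drop_singleton` flag are hypothetical helpers for illustration, not part of Chemprop, and the sketch assumes the rule is simply "drop a final batch of exactly one datapoint".

```python
def batch_sizes(n_items: int, batch_size: int, drop_singleton: bool = True) -> list:
    """Return the sizes of the batches produced for `n_items` datapoints.

    A trailing batch of exactly one datapoint is dropped when
    `drop_singleton` is True (the behavior described in the text above).
    """
    full, remainder = divmod(n_items, batch_size)
    sizes = [batch_size] * full
    if remainder and not (drop_singleton and remainder == 1):
        sizes.append(remainder)
    return sizes


print(batch_sizes(3, batch_size=2))                        # [2]
print(batch_sizes(3, batch_size=3))                        # [3]
print(batch_sizes(3, batch_size=2, drop_singleton=False))  # [2, 1]
```

For the three-molecule dataset used here, a batch size of 2 leaves a single-datapoint remainder that is dropped, while a batch size of 3 or a manually built loader keeps every datapoint.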

### Samplers

The default sampler for a `torch.utils.data.DataLoader` is a `torch.utils.data.sampler.SequentialSampler` for `shuffle=False`, or a `torch.utils.data.sampler.RandomSampler` if `shuffle=True`.

`build_dataloader` can be given a seed to make a `chemprop.data.samplers.SeededSampler` for reproducibility. Chemprop also offers `chemprop.data.samplers.ClassBalanceSampler` to equally sample positive and negative classes for binary classification tasks.

In [13]:
build_dataloader(dataset, seed=0)

<torch.utils.data.dataloader.DataLoader at 0x7f2918dac6d0>
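
The idea behind seeded sampling can be sketched with the standard library alone. `ToySeededSampler` is a hypothetical class for illustration and is not Chemprop's `SeededSampler`; it assumes a single random generator that is seeded once and reused across epochs.

```python
import random


class ToySeededSampler:
    """Hypothetical sampler yielding a reproducible shuffled index order."""

    def __init__(self, n_items: int, seed: int):
        self.idxs = list(range(n_items))
        # The generator is created once, so successive epochs differ from each
        # other but the whole sequence of epochs is reproducible from the seed.
        self.rng = random.Random(seed)

    def __iter__(self):
        self.rng.shuffle(self.idxs)
        return iter(list(self.idxs))

    def __len__(self):
        return len(self.idxs)


a = ToySeededSampler(10, seed=0)
b = ToySeededSampler(10, seed=0)

# Same seed, same order in the first epoch and again in the second.
print(list(a) == list(b))
print(list(a) == list(b))
```

Two samplers built with the same seed visit the datapoints in the same order epoch by epoch, which is what makes a training run repeatable.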

In [14]:
smis = ["C" * i for i in range(1, 11)]
ys = np.random.randint(low=0, high=2, size=(len(smis), 1))
dataset = MoleculeDataset([MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)])

dataloader = build_dataloader(dataset, class_balance=True)

_, _, _, Y, *_ = next(iter(dataloader))
print(Y)

tensor([[1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [0.]])
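
The alternating pattern in the output above can be reproduced with a short plain-Python sketch of class-balanced sampling. `balanced_order` is a hypothetical helper, not Chemprop's sampler; it assumes one simple strategy, namely shuffling each class separately, interleaving them, and truncating the majority class.

```python
import random
from itertools import chain


def balanced_order(labels, seed=None):
    """Return indices that alternate between positive and negative labels."""
    rng = random.Random(seed)
    pos = [i for i, y in enumerate(labels) if y == 1]
    neg = [i for i, y in enumerate(labels) if y == 0]
    rng.shuffle(pos)
    rng.shuffle(neg)
    # zip stops at the shorter list, so extra majority-class items are left out.
    return list(chain.from_iterable(zip(pos, neg)))


labels = [1, 1, 1, 0, 0, 1, 1, 0, 1, 1]  # 7 positives, 3 negatives
order = balanced_order(labels, seed=0)
print([labels[i] for i in order])  # [1, 0, 1, 0, 1, 0]
```

With this strategy every batch sees both classes equally often, at the cost of skipping some majority-class datapoints in each pass.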
