
Pass kwargs through to DataLoader in MolGraphDataLoader #808

Merged
7 commits merged into chemprop:v2/dev on Apr 19, 2024

Conversation

KnathanM
Contributor

Currently trainer.predict uses a batch size of 1 when given a MolGraphDataLoader. Simple code to reproduce:

from lightning import pytorch as pl
from chemprop import nn
from chemprop.models import MPNN

mp = nn.BondMessagePassing()
agg = nn.MeanAggregation()
ffn = nn.RegressionFFN()
mpnn = MPNN(mp, agg, ffn)
trainer = pl.Trainer()

import torch
from chemprop.data import MoleculeDatapoint, MoleculeDataset, MolGraphDataLoader
from chemprop.featurizers import SimpleMoleculeMolGraphFeaturizer
fake_data = [MoleculeDatapoint.from_smi('C', torch.tensor([1])), MoleculeDatapoint.from_smi('C', torch.tensor([2]))]
featurizer = SimpleMoleculeMolGraphFeaturizer()
fake_dset = MoleculeDataset(fake_data, featurizer)
fake_loader = MolGraphDataLoader(fake_dset, shuffle=False)

trainer.test(mpnn, fake_loader)
trainer.predict(mpnn, fake_loader)

Also add print(targets) and print(Y) calls to MPNN's _evaluate_batch and predict_step, like this:

    def _evaluate_batch(self, batch) -> list[Tensor]:
        bmg, V_d, X_d, targets, _, lt_mask, gt_mask = batch
        print(targets)
        mask = targets.isfinite()
        targets = targets.nan_to_num(nan=0.0)
        preds = self(bmg, V_d, X_d)

        return [
            metric(preds, targets, mask, None, None, lt_mask, gt_mask)
            for metric in self.metrics[:-1]
        ]

    def predict_step(self, batch: TrainingBatch, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
        """Return the predictions of the input batch

        Parameters
        ----------
        batch : TrainingBatch
            the input batch

        Returns
        -------
        Tensor
            a tensor of varying shape depending on the task type:

            * regression/binary classification: ``n x (t * s)``, where ``n`` is the number of input
            molecules/reactions, ``t`` is the number of tasks, and ``s`` is the number of targets
            per task. The final dimension is flattened, so that the targets for each task are
            grouped. I.e., the first ``t`` elements are the first target for each task, the second
            ``t`` elements the second target, etc.
            * multiclass classification: ``n x t x c``, where ``c`` is the number of classes
        """
        bmg, X_vd, X_d, Y, *_ = batch
        print(Y)

        return self(bmg, X_vd, X_d)

You'll see trainer.test prints

tensor([[1.],
        [2.]])

while trainer.predict prints

tensor([[1.]])
tensor([[2.]])

The batch size of 1 makes prediction slower, so we want to fix it.

The problem has to do with batch_sampler. In MolGraphDataLoader we don't accept batch_sampler as an argument (except via **kwargs, which will be important), and we don't give a batch sampler to torch's DataLoader in super().__init__(). This means PyTorch will automatically create a SequentialSampler for us (when shuffle=False, as for a test loader; see here).
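
For reference, this is roughly what torch does on its own when no batch_sampler is supplied (a simplified sketch of the DataLoader internals, not the exact source):

from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler

def default_batch_sampler(dataset, batch_size, shuffle=False, drop_last=False):
    # mimics (approximately) how DataLoader builds its batch sampler when
    # batch_sampler is None: pick a plain sampler, then wrap it in BatchSampler
    sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
    return BatchSampler(sampler, batch_size, drop_last)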

During prediction, lightning reconstructs the data loader for its own purposes (I think for distributed prediction). It turns the SequentialSampler from PyTorch into an _IndexBatchSamplerWrapper (which I think keeps track of the original batch size) and hands that back to our data loader class to build a new loader. But MolGraphDataLoader doesn't take batch_sampler, so it winds up in **kwargs, which isn't used.

For PyTorch data loaders, the "batch_sampler option is mutually exclusive with batch_size, shuffle, sampler, and drop_last" (link), so lightning also turns these off by setting sampler to None, shuffle to False, drop_last to False, and, importantly, batch_size to 1. MolGraphDataLoader does take the batch_size argument, so that gets passed through, while the batch_sampler does not.
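
You can see that rule in isolation with plain torch (nothing chemprop-specific):

from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

data = list(range(4))
bs = BatchSampler(SequentialSampler(data), batch_size=2, drop_last=False)

DataLoader(data, batch_sampler=bs)                # fine: batch_sampler alone
DataLoader(data, batch_size=2, batch_sampler=bs)  # ValueError: mutually exclusive with batch_size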

The fix, then, is to pass **kwargs through to super().__init__(). But special consideration is needed because sampler, collate_fn, and drop_last are created by MolGraphDataLoader, so values for them in kwargs would be duplicates. I added a check to pop them out of kwargs if they are given. Alternatively, we could simply drop them from kwargs, since MolGraphDataLoader recreates them anyway.
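
To make that concrete, here's a toy version of the pattern (illustrative only, with a stand-in collate function and simplified defaults, not the exact merged code):

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

def collate_batch(batch):  # stand-in for chemprop's real MolGraph collate function
    return batch

class MolGraphDataLoader(DataLoader):
    def __init__(self, dataset, batch_size=50, shuffle=True, **kwargs):
        # sampler, collate_fn, and drop_last are built by this class, so discard
        # any duplicates a caller (or lightning's re-instantiation) passes in
        for key in ("sampler", "collate_fn", "drop_last"):
            kwargs.pop(key, None)

        if "batch_sampler" in kwargs:
            # batch_sampler is mutually exclusive with batch_size, shuffle,
            # sampler, and drop_last, so let it control batching entirely
            super().__init__(dataset, collate_fn=collate_batch, **kwargs)
        else:
            sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
            super().__init__(
                dataset, batch_size, sampler=sampler,
                collate_fn=collate_batch, drop_last=False, **kwargs,
            )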

Side note: MolGraphDataLoader gave us a bit of hassle here, so one question is whether we want to remove it entirely. My vote is to keep it, because it handles collate_fn and drop_last (needed when training with a batch norm layer), neither of which I would expect a basic user to know they need to handle.

These changes made trainer.predict on my dataset of ~45,000 molecules go from taking 2 minutes to 7 seconds, the difference between using a batch size of 1 and a batch size of 50.

@davidegraff
Contributor

re: keep or remove the MolGraphDataloader

I think we would make the code and our lives simpler by removing it. If you look at the definition, you’ll notice that it’s just a wrapper around a call to a regular dataloader. That is, the initializer takes in some arguments, does some stuff, and then passes them to the parent class’s initializer. The MolGraphDataloader doesn’t actually override any methods of the parent class, so this is functionally the same as just using a specific parametrization of a vanilla dataloader.

I made this switch a long time ago in my personal fork. It’s a simple change, removes some of the complexity in the package, and enhances flexibility. Users won’t need to reason about what a MolGraphDataloader is anymore and you don’t need to add complexity to a class that just passes it through to a different class.

The change, as I made it, is to:

  1. remove the MolGraphDataloader class entirely
  2. make its initializer method into a function (e.g., build_dataloader) that returns a regular torch dataloader (see the sketch below)
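
A rough sketch of that shape (the name build_dataloader follows the suggestion above; the sampler logic is simplified, and the collate_batch import path is an assumption about where chemprop keeps its collate function):

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from chemprop.data import collate_batch  # assumed import path for the MolGraph collate function

def build_dataloader(dataset, batch_size=50, shuffle=True, **kwargs):
    # return a plain torch DataLoader configured for MolGraph batches,
    # forwarding any extra kwargs (e.g. num_workers) untouched
    sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
    return DataLoader(dataset, batch_size, sampler=sampler,
                      collate_fn=collate_batch, **kwargs)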

@KnathanM KnathanM added this to the v2.1.0 milestone Apr 18, 2024
@KnathanM
Contributor Author

If I understand right, your suggestion is to replace the MolGraphDataLoader class with a build_dataloader helper function. I like this idea: it achieves the goal of helping the user build a working dataloader without requiring a dedicated class. We want to release v2.0 tomorrow morning, so we don't have time to do that tonight, but we plan to make this change shortly thereafter and include it in v2.1. This PR would then be a temporary fix that we may still include in v2.0 if we have time to review it tonight.

@davidegraff
Contributor

davidegraff commented Apr 18, 2024

if you want the fastest possible fix, change:

class MolGraphDataloader(DataLoader):
    def __init__(self, *args, **kwargs):
        # do something with *args and **kwargs
        super().__init__(*args, **kwargs)

to:

def MolGraphDataloader(*args, **kwargs):
    # do something with *args and **kwargs
    return DataLoader(*args, **kwargs)

that is:

  1. delete the class MolGraphDataLoader(DataLoader): statement
  2. de-indent the whole __init__() method and rename it to MolGraphDataLoader
  3. change super().__init__(...) to return DataLoader(...)

@davidegraff
Contributor

For now, it'll work, but in 2.1 or whatever you can change the function name to build_dataloader to make it clear that it's a function that returns an object rather than an actual class itself. The change above is just so that other parts of the code don't have to change.

@KnathanM KnathanM merged commit 4bba677 into chemprop:v2/dev Apr 19, 2024
11 checks passed