Pass kwargs through to DataLoader in MolGraphDataLoader #808
Conversation
re: keep or remove the `MolGraphDataLoader`

I think we would make the code and our lives simpler by removing it. If you look at the definition, you'll notice that it's just a wrapper around a call to a regular dataloader. That is, the initializer takes in some arguments, does some stuff, and then passes them to the parent class's initializer.

I made this switch a long time ago in my personal fork. It's a simple change, removes some of the complexity in the package, and enhances flexibility. Users won't need to reason about what a `MolGraphDataLoader` is versus a plain `DataLoader`. The change, as I made it, is to …
If I understand right, your suggestion is to replace the `MolGraphDataLoader` class with a plain function that builds and returns a regular `DataLoader`?
if you want the fastest possible fix, change:

```python
class MolGraphDataLoader(DataLoader):
    def __init__(self, *args, **kwargs):
        # do something with *args and **kwargs
        super().__init__(*args, **kwargs)
```

to:

```python
def MolGraphDataLoader(*args, **kwargs):
    # do something with *args and **kwargs
    return DataLoader(*args, **kwargs)
```

that is, replace the class with a function that just builds and returns a regular `DataLoader`.
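Spelled out slightly more (still just a sketch: `collate_molgraphs` below is a hypothetical stand-in for whatever collate function the package actually uses), the function version would set the MolGraph-specific default and hand everything else, `batch_sampler` included, straight to torch:

```python
from torch.utils.data import DataLoader


def collate_molgraphs(batch):
    """Placeholder for the package's real MolGraph collate function."""
    return batch


def mol_graph_dataloader(dataset, **kwargs):
    # Apply the MolGraph-specific default, then forward everything else,
    # batch_sampler included, untouched to torch's DataLoader.
    kwargs.setdefault("collate_fn", collate_molgraphs)
    return DataLoader(dataset, **kwargs)
```

Because nothing is intercepted, lightning can rebuild the resulting loader with whatever sampler or batch sampler it needs.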
For now, it'll work, but in 2.1 or whatever you can change the function name to …
Currently `trainer.predict` uses a batch size of 1 when given a `MolGraphDataLoader`. Simple way to reproduce: run both `trainer.test` and `trainer.predict` with a `MolGraphDataLoader`, and add `print(targets)` / `print(Y)` to the MPNN. You'll see that `trainer.test` prints whole batches of targets while `trainer.predict` prints them one at a time. The batch size of 1 makes prediction slower, so we want to fix it.
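For reference, a generic stand-in for that kind of check (this is not the original reproduction code: it uses a toy Lightning module and a plain torch `DataLoader` rather than the package's own classes, and assumes the `lightning.pytorch` namespace):

```python
import torch
import lightning.pytorch as pl  # assumed namespace; older installs use pytorch_lightning
from torch.utils.data import DataLoader, TensorDataset


class TinyModel(pl.LightningModule):
    """Toy stand-in for the MPNN, used only to print the effective batch size."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def test_step(self, batch, batch_idx):
        X, Y = batch
        print("test batch size:", len(Y))

    def predict_step(self, batch, batch_idx):
        X, Y = batch
        print("predict batch size:", len(Y))
        return self.linear(X)


dataset = TensorDataset(torch.randn(200, 1), torch.randn(200, 1))
loader = DataLoader(dataset, batch_size=50)  # a plain DataLoader forwards batch_sampler fine

trainer = pl.Trainer(logger=False, enable_checkpointing=False)
trainer.test(TinyModel(), dataloaders=loader)     # prints 50 per batch
trainer.predict(TinyModel(), dataloaders=loader)  # also 50 here; with the old
                                                  # MolGraphDataLoader it drops to 1
```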
The problem has to do with `batch_sampler`. In `MolGraphDataLoader` we don't accept `batch_sampler` as an argument (except in `**kwargs`, which will be important), and we don't give a batch sampler to pytorch's `DataLoader` in the `super().__init__()`. This means pytorch will automatically make us a `SequentialSampler` (when `shuffle=False`, as for a test loader; see here).

During prediction, lightning reconstructs the data loader for us for its own purposes (I think for distributed predicting). It turns the `SequentialSampler` from pytorch into an `_IndexBatchSamplerWrapper` (which I think keeps track of the original batch size) and gives that back to our data loader class to make a new loader. But `MolGraphDataLoader` doesn't take `batch_sampler`, so it winds up in `**kwargs`, which isn't used. For pytorch data loaders, the "batch_sampler option is mutually exclusive with batch_size, shuffle, sampler, and drop_last" (link), so lightning also turns these off by setting `sampler` to None, `shuffle` to False, `drop_last` to False, and, importantly, `batch_size` to 1. `MolGraphDataLoader` does take the `batch_size` argument, so that gets passed through while the `batch_sampler` does not.
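To illustrate the mutual-exclusivity rule quoted above with plain torch (this snippet is only for illustration and is not part of the PR):

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(10.0))

# When a batch_sampler is supplied, it alone decides the batching; DataLoader must
# be left at its defaults (batch_size=1, shuffle=False, sampler=None, drop_last=False).
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)
loader = DataLoader(dataset, batch_sampler=batch_sampler)
print([len(batch[0]) for batch in loader])  # [4, 4, 2]

# Supplying batch_size alongside batch_sampler raises the error quoted above,
# which is why lightning resets batch_size to 1 before rebuilding the loader.
try:
    DataLoader(dataset, batch_size=4, batch_sampler=batch_sampler)
except ValueError as err:
    print(err)
```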
The fix, then, is to pass `**kwargs` through to the `super().__init__()`. But special consideration is needed, as `sampler`, `collate_fn`, and `drop_last` are created by `MolGraphDataLoader`, so the values in `kwargs` would be duplicates. I added a check to pop them out of `kwargs` if they are given. Alternatively, we could just delete them from `kwargs` since they are recreated by `MolGraphDataLoader`.
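A minimal sketch of that approach, assuming simplified arguments and a placeholder collate function (the real class accepts more options than shown here):

```python
from torch.utils.data import DataLoader


def collate_molgraphs(batch):
    """Placeholder for the package's real MolGraph collate function."""
    return batch


class MolGraphDataLoader(DataLoader):
    def __init__(self, dataset, batch_size=50, shuffle=False, **kwargs):
        # These are built by MolGraphDataLoader itself, so drop any duplicates that a
        # caller (or lightning, when it rebuilds the loader) passed in via kwargs.
        for key in ("sampler", "collate_fn", "drop_last"):
            kwargs.pop(key, None)

        if "batch_sampler" in kwargs:
            # batch_sampler (e.g. lightning's _IndexBatchSamplerWrapper) is mutually
            # exclusive with batch_size/shuffle/sampler/drop_last, so let it control
            # the batching entirely.
            super().__init__(dataset, collate_fn=collate_molgraphs, **kwargs)
        else:
            super().__init__(
                dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                collate_fn=collate_molgraphs,
                drop_last=shuffle,  # e.g. protect a batch norm layer during training
                **kwargs,
            )
```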
Side note: `MolGraphDataLoader` gave us a bit of hassle here, so a question is whether we want to remove it entirely. My vote is to keep it, because it handles `collate_fn` and `drop_last` (for training with a batch norm layer), both of which I wouldn't expect a basic user to know they need to handle.

These changes made `trainer.predict` on my dataset of ~45,000 molecules go from taking 2 minutes to 7 seconds, the difference between using a batch size of 1 and a batch size of 50.