
Commit e6b34ef

awaelchli and Borda authored
[WIP] Reduction when batch size < num gpus (Lightning-AI#1609)
* reduce if <= num_gpus
* add test with explanation
* chlog
* fix changelog

Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
1 parent fafe5d6 commit e6b34ef

File tree

3 files changed: +51 −4 lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added transfer learning example (for a binary classification task in computer vision) ([#1564](https://github.com/PyTorchLightning/pytorch-lightning/pull/1564))
 
 ### Changed
+
+- Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))
 
 ### Deprecated
 

pytorch_lightning/trainer/logging.py

Lines changed: 4 additions & 4 deletions
@@ -196,8 +196,8 @@ def reduce_distributed_output(self, output, num_gpus):
         elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0:
             pass
 
-        # reduce only metrics that have the same number of gpus
-        elif output[k].size(0) == num_gpus:
-            reduced = torch.mean(output[k])
-            output[k] = reduced
+        # do not reduce metrics that have batch size > num gpus
+        elif output[k].size(0) <= num_gpus:
+            output[k] = torch.mean(output[k])
+
     return output
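For context on the fix above: under DataParallel, each participating GPU contributes one entry per logged metric, so a full batch yields a tensor of length num_gpus, while a smaller final batch yields fewer entries. The old `== num_gpus` check skipped those shorter tensors, leaving unreduced vectors in the output. A minimal standalone sketch of the patched condition (not part of the commit; the helper name is hypothetical):

import torch

def reduce_metric(value, num_gpus):
    # scalars need no reduction; vectors with at most num_gpus entries
    # (one per participating GPU) are averaged down to a scalar
    if isinstance(value, torch.Tensor) and value.dim() == 0:
        return value
    if value.size(0) <= num_gpus:
        return torch.mean(value)
    return value

full_batch = torch.tensor([0.9, 1.1, 1.0])  # 3 GPUs each returned a loss
last_batch = torch.tensor([0.8, 1.2])       # final batch reached only 2 GPUs

print(reduce_metric(full_batch, num_gpus=3))  # tensor(1.) -- reduced, as before
print(reduce_metric(last_batch, num_gpus=3))  # tensor(1.) -- reduced only with this fix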

tests/trainer/test_dataloaders.py

Lines changed: 45 additions & 0 deletions
@@ -2,6 +2,8 @@
 
 import pytest
 import torch
+from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.dataset import Subset
 
 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
@@ -482,3 +484,46 @@ class CustomDummyObj:
     assert isinstance(result, torch.utils.data.DataLoader)
     assert isinstance(result, CustomDataLoader)
     assert hasattr(result, 'dummy_kwarg')
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
+def test_batch_size_smaller_than_num_gpus():
+    # we need at least 3 gpus for this test
+    num_gpus = 3
+    batch_size = 3
+
+    class CurrentTestModel(
+        LightTrainDataloader,
+        TestModelBase,
+    ):
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.c_d1_bn = torch.nn.ReLU()
+
+        def train_dataloader(self):
+            dataloader = super().train_dataloader()
+            # construct a dataset with a size that is not divisible by num_gpus
+            # therefore the last batch will have a size < num_gpus
+            size = num_gpus * batch_size + (num_gpus - 1)
+            dataset = Subset(dataloader.dataset, range(size))
+            dataloader = DataLoader(
+                dataset,
+                batch_size=self.hparams.batch_size,
+                drop_last=False,
+            )
+            return dataloader
+
+    hparams = tutils.get_default_hparams()
+    hparams.batch_size = batch_size
+    model = CurrentTestModel(hparams)
+
+    trainer = Trainer(
+        max_epochs=1,
+        gpus=num_gpus,
+    )
+
+    # we expect the reduction for the metrics also to happen on the last batch
+    # where we will get fewer metrics than gpus
+    result = trainer.fit(model)
+    assert 1 == result
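To spell out the arithmetic this test relies on (a standalone sketch, not part of the commit): with num_gpus = 3 and batch_size = 3, the subset holds 3 * 3 + 2 = 11 samples, so the loader yields batches of sizes 3, 3, 3, 2. DataParallel splits the final 2-sample batch across only 2 of the 3 GPUs, producing metric tensors of size 2 that the old `== num_gpus` check left unreduced.

# sketch: batch sizes produced by the test's dataloader
num_gpus, batch_size = 3, 3
size = num_gpus * batch_size + (num_gpus - 1)  # 11 samples
batches = [min(batch_size, size - i) for i in range(0, size, batch_size)]
print(batches)  # [3, 3, 3, 2] -- the last batch is smaller than num_gpus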
