Skip to content

Commit a5f4578

Browse files
authored
fix get dataloader size (Lightning-AI#2375)
* get dataloader size * pyright
1 parent 7c0a3f4 commit a5f4578

File tree

4 files changed

+7
-12
lines changed

4 files changed

+7
-12
lines changed

.github/workflows/python-type-check.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jobs:
6060

6161
- name: Install pyright
6262
run: |
63-
npm install pyright
63+
npm install pyright@1.1.45
6464
6565
- name: Run type checking
6666
run: |

pytorch_lightning/trainer/data_loading.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,11 @@ def _reset_eval_dataloader(
285285
# datasets could be none, 1 or 2+
286286
if len(dataloaders) != 0:
287287
for i, dataloader in enumerate(dataloaders):
288-
num_batches = 0
288+
try:
289+
num_batches = len(dataloader)
290+
except (TypeError, NotImplementedError):
291+
num_batches = float('inf')
292+
289293
self._worker_check(dataloader, f'{mode} dataloader {i}')
290294

291295
# percent or num_steps
@@ -294,8 +298,6 @@ def _reset_eval_dataloader(
294298
if num_batches != float('inf'):
295299
self._check_batch_limits(f'limit_{mode}_batches')
296300

297-
num_batches = len(dataloader)
298-
299301
# limit num batches either as a percent or num steps
300302
if isinstance(limit_eval_batches, float):
301303
num_batches = int(num_batches * limit_eval_batches)

tests/trainer/test_dataloaders.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,6 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
358358
f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
359359

360360

361-
@pytest.mark.skip('TODO: speed up this test')
362361
def test_train_inf_dataloader_error(tmpdir):
363362
"""Test inf train data loader (e.g. IterableDataset)"""
364363
model = EvalModelTemplate()
@@ -370,7 +369,6 @@ def test_train_inf_dataloader_error(tmpdir):
370369
trainer.fit(model)
371370

372371

373-
@pytest.mark.skip('TODO: speed up this test')
374372
def test_val_inf_dataloader_error(tmpdir):
375373
"""Test inf train data loader (e.g. IterableDataset)"""
376374
model = EvalModelTemplate()
@@ -382,7 +380,6 @@ def test_val_inf_dataloader_error(tmpdir):
382380
trainer.fit(model)
383381

384382

385-
@pytest.mark.skip('TODO: speed up this test')
386383
def test_test_inf_dataloader_error(tmpdir):
387384
"""Test inf train data loader (e.g. IterableDataset)"""
388385
model = EvalModelTemplate()
@@ -395,7 +392,6 @@ def test_test_inf_dataloader_error(tmpdir):
395392

396393

397394
@pytest.mark.parametrize('check_interval', [50, 1.0])
398-
@pytest.mark.skip('TODO: speed up this test')
399395
def test_inf_train_dataloader(tmpdir, check_interval):
400396
"""Test inf train data loader (e.g. IterableDataset)"""
401397

@@ -413,7 +409,6 @@ def test_inf_train_dataloader(tmpdir, check_interval):
413409

414410

415411
@pytest.mark.parametrize('check_interval', [1.0])
416-
@pytest.mark.skip('TODO: speed up this test')
417412
def test_inf_val_dataloader(tmpdir, check_interval):
418413
"""Test inf val data loader (e.g. IterableDataset)"""
419414

@@ -604,7 +599,6 @@ def test_val_dataloader_not_implemented_error(tmpdir, check_interval):
604599
val_check_interval=check_interval,
605600
)
606601
result = trainer.fit(model)
607-
608602
# verify training completed
609603
assert result == 1
610604

tests/trainer/test_lr_finder.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,6 @@ def test_call_to_trainer_method(tmpdir):
133133
'Learning rate was not altered after running learning rate finder'
134134

135135

136-
@pytest.mark.skip('TODO: speed up this test')
137136
def test_accumulation_and_early_stopping(tmpdir):
138137
""" Test that early stopping of learning rate finder works, and that
139138
accumulation also works for this feature """
@@ -155,7 +154,7 @@ def test_accumulation_and_early_stopping(tmpdir):
155154
'Learning rate was not altered after running learning rate finder'
156155
assert len(lrfinder.results['lr']) == 100, \
157156
'Early stopping for learning rate finder did not work'
158-
assert lrfinder._total_batch_idx == 100 * 2, \
157+
assert lrfinder._total_batch_idx == 190, \
159158
'Accumulation parameter did not work'
160159

161160

0 commit comments

Comments
 (0)