
Commit ccc923c

ref: refactored inner eval loop (Lightning-AI#3141)
* refactored dataloader process hook
1 parent f064d74 commit ccc923c

File tree: 2 files changed, +68 −74 lines


pytorch_lightning/trainer/evaluate_loop.py

Lines changed: 61 additions & 4 deletions
```diff
@@ -1,6 +1,7 @@
 import torch
 from pytorch_lightning.trainer.supporters import PredictionCollection
-from pytorch_lightning.core.step_result import EvalResult
+from pytorch_lightning.core.step_result import Result, EvalResult
+from pytorch_lightning.utilities.exceptions import MisconfigurationException


 class EvaluationLoop(object):
```
```diff
@@ -43,11 +44,38 @@ def on_evaluation_epoch_start(self, *args, **kwargs):
         else:
             self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)

-    def evaluation_step(self, *args, **kwargs):
+    def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
+        # make dataloader_idx arg in validation_step optional
+        args = [batch, batch_idx]
+
+        multiple_val_loaders = (not test_mode and len(self.trainer.val_dataloaders) > 1)
+        multiple_test_loaders = (test_mode and len(self.trainer.test_dataloaders) > 1)
+
+        if multiple_test_loaders or multiple_val_loaders:
+            args.append(dataloader_idx)
+
+        return args
+
+    def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx):
+        # configure args
+        args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
+
+        # run actual test step
         if self.testing:
-            output = self.trainer.accelerator_backend.test_step(*args, **kwargs)
+            output = self.trainer.accelerator_backend.test_step(args)
         else:
-            output = self.trainer.accelerator_backend.validation_step(*args, **kwargs)
+            output = self.trainer.accelerator_backend.validation_step(args)
+
+        # track batch size for weighted average
+        is_result_obj = isinstance(output, Result)
+        if is_result_obj:
+            output.track_batch_size(len(batch))
+
+        # allow only EvalResult when using structured results (from val_step)
+        if is_result_obj and not isinstance(output, EvalResult):
+            m = 'only EvalResults or dicts are allowed from validation_step'
+            raise MisconfigurationException(m)
+
         return output

     def evaluation_step_end(self, *args, **kwargs):
```
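As an aside, here is a standalone sketch (a toy rewrite for illustration, not part of the commit) that mirrors the `build_args` logic above, showing when `dataloader_idx` gets forwarded to `validation_step`/`test_step`:

```python
# Hypothetical standalone version of build_args, for illustration only;
# the loader counts are passed in directly instead of read off the trainer.
def build_args(test_mode, batch, batch_idx, dataloader_idx,
               num_val_loaders=1, num_test_loaders=1):
    # dataloader_idx is optional: only forward it when there are multiple loaders
    args = [batch, batch_idx]

    multiple_val_loaders = (not test_mode and num_val_loaders > 1)
    multiple_test_loaders = (test_mode and num_test_loaders > 1)

    if multiple_test_loaders or multiple_val_loaders:
        args.append(dataloader_idx)

    return args

# single val dataloader -> validation_step(batch, batch_idx)
assert len(build_args(False, 'batch', 0, 0, num_val_loaders=1)) == 2
# two test dataloaders  -> test_step(batch, batch_idx, dataloader_idx)
assert len(build_args(True, 'batch', 0, 1, num_test_loaders=2)) == 3
```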
```diff
@@ -69,8 +97,37 @@ def on_evaluation_batch_end(self, *args, **kwargs):
         else:
             self.trainer.call_hook('on_validation_batch_end', *args, **kwargs)

+    def evaluation_batch_end_cleanup(self, output, batch_idx, dataloader_idx):
+        # Add step predictions to prediction collection to write later
+        if output is not None:
+            do_write_predictions = isinstance(output, Result) and self.testing
+            if do_write_predictions:
+                self.predictions.add(output.pop('predictions', None))
+
+        # track debug metrics
+        self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output)
+
     def on_evaluation_epoch_end(self, *args, **kwargs):
         if self.testing:
             self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
         else:
             self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
+
+    def log_metrics(self, output, batch_idx):
+        if self.trainer.running_sanity_check:
+            return
+
+        if isinstance(output, EvalResult):
+            step_log_metrics = output.batch_log_metrics
+            step_pbar_metrics = output.batch_pbar_metrics
+
+            if len(step_log_metrics) > 0:
+                # make the metrics appear as a different line in the same graph
+                metrics_by_epoch = {}
+                for k, v in step_log_metrics.items():
+                    metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v
+
+                self.trainer.log_metrics(metrics_by_epoch, {}, step=batch_idx)
+
+            if len(step_pbar_metrics) > 0:
+                self.trainer.add_progress_bar_metrics(step_pbar_metrics)
```
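For illustration, a small sketch (hypothetical metric values, not from the commit) of the re-keying that `log_metrics` performs: suffixing each key with the current epoch makes every epoch render as its own line in the same graph:

```python
# Toy values standing in for EvalResult.batch_log_metrics
current_epoch = 3
step_log_metrics = {'val_loss': 0.25, 'val_acc': 0.91}

# same transformation as in log_metrics above
metrics_by_epoch = {f'{k}/epoch_{current_epoch}': v
                    for k, v in step_log_metrics.items()}

print(metrics_by_epoch)
# {'val_loss/epoch_3': 0.25, 'val_acc/epoch_3': 0.91}
```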

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 7 additions & 70 deletions
```diff
@@ -132,8 +132,7 @@

 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType
-from pytorch_lightning.core.step_result import Result, EvalResult
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.core.step_result import EvalResult, Result
 from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop

 try:
```
```diff
@@ -273,55 +272,19 @@ def _evaluate(
                 if batch_idx >= dl_max_batches:
                     break

-                # -----------------
-                # eval_batch_start
-                # -----------------
+                # val loop hooks
                 self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
-
-                # -----------------
-                # RUN EVALUATION STEP
-                # -----------------
-                args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
-                output = self.evaluation_loop.evaluation_step(args)
-
-                # track batch size for weighted average
-                is_result_obj = isinstance(output, Result)
-                if is_result_obj:
-                    output.track_batch_size(len(batch))
-
-                # allow only EvalResult when using structured results (from val_step)
-                if is_result_obj and not isinstance(output, EvalResult):
-                    m = 'only EvalResults or dicts are allowed from validation_step'
-                    raise MisconfigurationException(m)
-
-                # ------------------
-                # EVAL STEP END
-                # ------------------
+                output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
                 output = self.evaluation_loop.evaluation_step_end(output)
-
-                # ------------------
-                # Hook: on_eval_batch_end
-                # ------------------
                 self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)

-                # ----------------------
-                # Post processing
-                # ----------------------
-                # track outputs for collation
-                if output is not None:
-
-                    # Add step predictions to prediction collection to write later
-                    do_write_predictions = is_result_obj and test_mode
-                    if do_write_predictions:
-                        self.evaluation_loop.predictions.add(output.pop('predictions', None))
+                # clean up
+                self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
+                self.evaluation_loop.log_metrics(output, batch_idx)

+                if output is not None:
                     dl_outputs.append(output)

-                self.__eval_add_step_metrics(output, batch_idx)
-
-                # track debug metrics
-                self.dev_debugger.track_eval_loss_history(test_mode, batch_idx, dataloader_idx, output)
-
             self.evaluation_loop.outputs.append(dl_outputs)

             # ---------------------
```
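To make the streamlined control flow explicit, here is a self-contained toy trace (a stub class standing in for the real `EvaluationLoop`; the method bodies are assumptions for demonstration) of the per-batch hook order that `_evaluate` now delegates to:

```python
# Stub that only records call order; not the real EvaluationLoop.
class ToyEvaluationLoop:
    def on_evaluation_batch_start(self, *a):
        print('on_evaluation_batch_start')

    def evaluation_step(self, *a):
        print('evaluation_step')
        return {'loss': 0.1}

    def evaluation_step_end(self, output):
        print('evaluation_step_end')
        return output

    def on_evaluation_batch_end(self, *a):
        print('on_evaluation_batch_end')

    def evaluation_batch_end_cleanup(self, *a):
        print('evaluation_batch_end_cleanup')

    def log_metrics(self, *a):
        print('log_metrics')

# one pass of the inner loop, mirroring the refactored _evaluate
loop, batch, batch_idx, dataloader_idx = ToyEvaluationLoop(), 'batch', 0, 0
loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
output = loop.evaluation_step(False, batch, batch_idx, dataloader_idx)
output = loop.evaluation_step_end(output)
loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)
loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
loop.log_metrics(output, batch_idx)
```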
```diff
@@ -454,23 +417,6 @@ def __gather_epoch_end_eval_results(self, outputs):
             eval_results = eval_results[0]
         return eval_results

-    def __eval_add_step_metrics(self, output, batch_idx):
-        # track step level metrics
-        if isinstance(output, EvalResult) and not self.running_sanity_check:
-            step_log_metrics = output.batch_log_metrics
-            step_pbar_metrics = output.batch_pbar_metrics
-
-            if len(step_log_metrics) > 0:
-                # make the metrics appear as a different line in the same graph
-                metrics_by_epoch = {}
-                for k, v in step_log_metrics.items():
-                    metrics_by_epoch[f'{k}/epoch_{self.current_epoch}'] = v
-
-                self.log_metrics(metrics_by_epoch, {}, step=batch_idx)
-
-            if len(step_pbar_metrics) > 0:
-                self.add_progress_bar_metrics(step_pbar_metrics)
-
     def __auto_reduce_result_objs(self, outputs):
         # outputs has a list of results per dataloader
         eval_results = []
```
```diff
@@ -588,12 +534,3 @@ def __log_evaluation_epoch_metrics(self, eval_results, test_mode):
             print('-' * 80)

         return eval_loop_results
-
-    def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
-        # make dataloader_idx arg in validation_step optional
-        args = [batch, batch_idx]
-
-        if (test_mode and len(self.test_dataloaders) > 1) or (not test_mode and len(self.val_dataloaders) > 1):
-            args.append(dataloader_idx)
-
-        return args
```
