|
132 | 132 |
|
133 | 133 | from pytorch_lightning.core.lightning import LightningModule |
134 | 134 | from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType |
135 | | -from pytorch_lightning.core.step_result import Result, EvalResult |
136 | | -from pytorch_lightning.utilities.exceptions import MisconfigurationException |
| 135 | +from pytorch_lightning.core.step_result import EvalResult, Result |
137 | 136 | from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop |
138 | 137 |
|
139 | 138 | try: |
@@ -273,55 +272,19 @@ def _evaluate( |
273 | 272 | if batch_idx >= dl_max_batches: |
274 | 273 | break |
275 | 274 |
|
276 | | - # ----------------- |
277 | | - # eval_batch_start |
278 | | - # ----------------- |
| 275 | + # val loop hooks |
279 | 276 | self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx) |
280 | | - |
281 | | - # ----------------- |
282 | | - # RUN EVALUATION STEP |
283 | | - # ----------------- |
284 | | - args = self.build_args(test_mode, batch, batch_idx, dataloader_idx) |
285 | | - output = self.evaluation_loop.evaluation_step(args) |
286 | | - |
287 | | - # track batch size for weighted average |
288 | | - is_result_obj = isinstance(output, Result) |
289 | | - if is_result_obj: |
290 | | - output.track_batch_size(len(batch)) |
291 | | - |
292 | | - # allow only EvalResult when using structured results (from val_step) |
293 | | - if is_result_obj and not isinstance(output, EvalResult): |
294 | | - m = 'only EvalResults or dicts are allowed from validation_step' |
295 | | - raise MisconfigurationException(m) |
296 | | - |
297 | | - # ------------------ |
298 | | - # EVAL STEP END |
299 | | - # ------------------ |
| 277 | + output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx) |
300 | 278 | output = self.evaluation_loop.evaluation_step_end(output) |
301 | | - |
302 | | - # ------------------ |
303 | | - # Hook: on_eval_batch_end |
304 | | - # ------------------ |
305 | 279 | self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx) |
306 | 280 |
|
307 | | - # ---------------------- |
308 | | - # Post processing |
309 | | - # ---------------------- |
310 | | - # track outputs for collation |
311 | | - if output is not None: |
312 | | - |
313 | | - # Add step predictions to prediction collection to write later |
314 | | - do_write_predictions = is_result_obj and test_mode |
315 | | - if do_write_predictions: |
316 | | - self.evaluation_loop.predictions.add(output.pop('predictions', None)) |
| 281 | + # clean up |
| 282 | + self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx) |
| 283 | + self.evaluation_loop.log_metrics(output, batch_idx) |
317 | 284 |
|
| 285 | + if output is not None: |
318 | 286 | dl_outputs.append(output) |
319 | 287 |
|
320 | | - self.__eval_add_step_metrics(output, batch_idx) |
321 | | - |
322 | | - # track debug metrics |
323 | | - self.dev_debugger.track_eval_loss_history(test_mode, batch_idx, dataloader_idx, output) |
324 | | - |
325 | 288 | self.evaluation_loop.outputs.append(dl_outputs) |
326 | 289 |
|
327 | 290 | # --------------------- |
@@ -454,23 +417,6 @@ def __gather_epoch_end_eval_results(self, outputs): |
454 | 417 | eval_results = eval_results[0] |
455 | 418 | return eval_results |
456 | 419 |
|
457 | | - def __eval_add_step_metrics(self, output, batch_idx): |
458 | | - # track step level metrics |
459 | | - if isinstance(output, EvalResult) and not self.running_sanity_check: |
460 | | - step_log_metrics = output.batch_log_metrics |
461 | | - step_pbar_metrics = output.batch_pbar_metrics |
462 | | - |
463 | | - if len(step_log_metrics) > 0: |
464 | | - # make the metrics appear as a different line in the same graph |
465 | | - metrics_by_epoch = {} |
466 | | - for k, v in step_log_metrics.items(): |
467 | | - metrics_by_epoch[f'{k}/epoch_{self.current_epoch}'] = v |
468 | | - |
469 | | - self.log_metrics(metrics_by_epoch, {}, step=batch_idx) |
470 | | - |
471 | | - if len(step_pbar_metrics) > 0: |
472 | | - self.add_progress_bar_metrics(step_pbar_metrics) |
473 | | - |
474 | 420 | def __auto_reduce_result_objs(self, outputs): |
475 | 421 | # outputs has a list of results per dataloader |
476 | 422 | eval_results = [] |
@@ -588,12 +534,3 @@ def __log_evaluation_epoch_metrics(self, eval_results, test_mode): |
588 | 534 | print('-' * 80) |
589 | 535 |
|
590 | 536 | return eval_loop_results |
591 | | - |
592 | | - def build_args(self, test_mode, batch, batch_idx, dataloader_idx): |
593 | | - # make dataloader_idx arg in validation_step optional |
594 | | - args = [batch, batch_idx] |
595 | | - |
596 | | - if (test_mode and len(self.test_dataloaders) > 1) or (not test_mode and len(self.val_dataloaders) > 1): |
597 | | - args.append(dataloader_idx) |
598 | | - |
599 | | - return args |