
[Fix] Fix ProfilerHook being unable to profile performance in DDP training (#…
HAOCHENYE committed May 26, 2023
1 parent 277b530 commit 5d4e721
Showing 2 changed files with 30 additions and 23 deletions.
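The diff below adds ``@master_only`` to the training-time callbacks and introduces a ``_closed`` flag. A plausible reading of the failure it fixes, sketched under that assumption (``ToyHook`` is illustrative, not mmengine code): ``before_run`` was already master-only, so in DDP the non-master ranks never created ``self.profiler``, and a callback without the decorator then touched the missing profiler on those ranks.

from mmengine.dist import master_only


class ToyHook:
    """Hedged sketch: every callback that touches the profiler must be
    master-only, because only rank 0 ever creates it."""

    @master_only
    def before_run(self, runner):
        # Executed on rank 0 only, so only rank 0 gets `self.profiler`.
        self.profiler = object()

    @master_only  # without this, non-master ranks raise AttributeError
    def after_train_iter(self, runner, batch_idx, data_batch, outputs):
        self.profiler  # safe: skipped entirely on non-master ranks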
43 changes: 28 additions & 15 deletions mmengine/hooks/profiler_hook.py
@@ -48,12 +48,16 @@ class ProfilerHook(Hook):
of generating handler. Defaults to None, which means profiling
without an on_trace_ready. The Callable type needs to construct its
own function that can handle 'torch.autograd.profiler.profile'.
Two officially recommended ways are provided, namely terminal
display or tensorboard display. The terminal display content can be
adjusted through 'EventList.table()'
from 'torch.autograd.profiler_util.py'.
If using tensorboard, save to '{work_dir}/tf_tracing_logs'
by default.
Two officially recommended ways are provided:
- ``on_trace_ready=dict(type='log_trace')``: Print the profiling result
in the terminal. See more details in the `PyTorch official tutorial`_.
The configurable arguments are the same as those of
``prof.key_averages().table``.
- ``on_trace_ready=dict(type='tb_trace')``: Profile the performance
with tensorboard. See more details in the tutorial
`profile with tensorboard`_.
record_shapes (bool): Save information about operators' input shapes.
Defaults to False.
profile_memory (bool): Track tensor memory allocation/deallocation.
@@ -67,11 +71,20 @@ class ProfilerHook(Hook):
JSON format. Chrome uses 'chrome://tracing' to view the json file.
Defaults to None, which means profiling does not store json files.
Warnings:
The profiler will be closed automatically after ``profile_times``
iterations. Please make sure the configured schedule does not close
the profiler before the iteration count reaches ``profile_times``.
Examples:
>>> # tensorboard trace
>>> trace_config = dict(type='tb_trace')
>>> profiler_hook_cfg = dict(on_trace_ready=trace_config)
"""
.. _PyTorch official tutorial: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html#using-profiler-to-analyze-execution-time
.. _profile with tensorboard: https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html#pytorch-profiler-with-tensorboard
""" # noqa: E501
priority = 'VERY_LOW'

def __init__(self,
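As a usage sketch of the two handlers documented above (argument values are illustrative; per the docstring, ``log_trace`` accepts the same arguments as ``prof.key_averages().table``):

custom_hooks = [
    # Print a profiling table to the terminal; `sort_by` is an example
    # of a key_averages().table() argument, not a default.
    dict(
        type='ProfilerHook',
        on_trace_ready=dict(type='log_trace', sort_by='self_cpu_time_total')),
    # Or write TensorBoard traces instead (`dir_name` is an example path):
    # dict(
    #     type='ProfilerHook',
    #     on_trace_ready=dict(type='tb_trace', dir_name='./tb_traces')),
]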
@@ -135,8 +148,8 @@ def __init__(self,
self.with_flops = with_flops

self.json_trace_path = json_trace_path
self._closed = False

@master_only
def before_run(self, runner):
"""Initialize the profiler.
@@ -212,23 +225,23 @@ def _log_handler(_profile):
f'but got {self.on_trace_ready}')
return _on_trace_ready

@master_only
def after_train_epoch(self, runner):
"""Determine if the content is exported."""
if self.by_epoch and runner.epoch == self.profile_times - 1:
# `after_train_epoch` will also be called in IterBasedTrainLoop.
# Here we check `self._closed` to avoid exiting twice.
if not self._closed:
self._export_chrome_trace(runner)

@master_only
def after_train_iter(self, runner, batch_idx, data_batch, outputs):
"""Update the content according to the schedule, and determine if the
content is exported."""
if self.schedule is None:
"""profiler will call `step` method if it is not closed."""
if not self._closed:
self.profiler.step()
if not self.by_epoch and runner.iter == self.profile_times - 1:
if runner.iter == self.profile_times - 1 and not self.by_epoch:
self._export_chrome_trace(runner)

def _export_chrome_trace(self, runner):
"""Exporting content."""
self._closed = True
runner.logger.info('profiler may take a few minutes...')
self.profiler.__exit__(None, None, None)
if self.json_trace_path is not None:
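The ``_closed`` flag above is an idempotent-close guard: both ``after_train_epoch`` (which, per the comment in the diff, IterBasedTrainLoop also calls) and ``after_train_iter`` can reach the export path, while ``profiler.__exit__`` must run exactly once. A minimal standalone sketch of the pattern (``TraceExporter`` is illustrative):

class TraceExporter:
    """Minimal sketch of the idempotent-close guard."""

    def __init__(self):
        self._closed = False

    def export(self):
        if self._closed:  # later callers become no-ops
            return
        self._closed = True
        print('exporting chrome trace exactly once')


exporter = TraceExporter()
exporter.export()  # exports
exporter.export()  # no-op: guarded by the flag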
10 changes: 2 additions & 8 deletions tests/test_hooks/test_profiler_hook.py
@@ -85,7 +85,7 @@ def test_parse_trace_config_tensorboard(self):
dict(
type='ProfilerHook',
on_trace_ready=dict(
type='tb_trace', dir_name='/home/baymax/RunTime/tb'))
type='tb_trace', dir_name=self.temp_dir.name))
]
runner = self.build_runner(self.epoch_based_cfg)
runner.train()
@@ -143,9 +143,6 @@ def test_after_train_iter(self):
runner.iter = 9

hook = ProfilerHook(by_epoch=False, profile_times=10, schedule=None)
hook.before_run(runner)
hook.profiler.__exit__(None, None, None)

hook.profiler = MagicMock()
hook.after_train_iter(runner, 1, 1, 1)
hook.profiler.__exit__.assert_called_once()
@@ -154,12 +151,9 @@
hook = ProfilerHook(
by_epoch=False,
schedule=dict(wait=1, warmup=1, active=3, repeat=1))
hook.before_run(runner)
hook.profiler.__exit__(None, None, None)

hook.profiler = MagicMock()
hook.after_train_iter(runner, 1, 1, 1)
hook.profiler.step.assert_not_called()
hook.profiler.step.assert_called_once()
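For reference, the ``schedule`` dict above mirrors the keyword arguments of ``torch.profiler.schedule``; under that reading, the profiler skips one step, warms up for one, records three active steps, and runs the cycle once:

from torch.profiler import schedule

# Equivalent schedule object (hedged reading of the dict above).
sched = schedule(wait=1, warmup=1, active=3, repeat=1)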

def test_with_runner(self):
self.epoch_based_cfg['custom_hooks'] = [
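The rewritten test swaps the real profiler for a ``MagicMock``, so the assertions inspect exactly which profiler methods the hook invokes without constructing a real torch profiler. A minimal sketch of the idiom:

from unittest.mock import MagicMock

profiler = MagicMock()
profiler.step()

# MagicMock records every call, including magic methods like __exit__.
profiler.step.assert_called_once()     # passes: step() ran once
profiler.__exit__.assert_not_called()  # passes: __exit__ never ran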
