Skip to content

Commit

Permalink
[Feature] Add ProfilerHook (#768)
Browse files Browse the repository at this point in the history
* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* [Feature] Add profiler hook functionality

* Apply suggestions from code review

* Update mmengine/hooks/profiler_hook.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
  • Loading branch information
BayMaxBHL and zhouzaida committed Dec 27, 2022
1 parent c382f8a commit 16589ce
Show file tree
Hide file tree
Showing 5 changed files with 440 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/en/api/hooks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ mmengine.hooks
IterTimerHook
SyncBuffersHook
EmptyCacheHook
ProfilerHook
1 change: 1 addition & 0 deletions docs/zh_cn/api/hooks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ mmengine.hooks
IterTimerHook
SyncBuffersHook
EmptyCacheHook
ProfilerHook
3 changes: 2 additions & 1 deletion mmengine/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from .logger_hook import LoggerHook
from .naive_visualization_hook import NaiveVisualizationHook
from .param_scheduler_hook import ParamSchedulerHook
from .profiler_hook import ProfilerHook
from .runtime_info_hook import RuntimeInfoHook
from .sampler_seed_hook import DistSamplerSeedHook
from .sync_buffer_hook import SyncBuffersHook

__all__ = [
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook', 'LoggerHook',
'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook'
'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook', 'ProfilerHook'
]
232 changes: 232 additions & 0 deletions mmengine/hooks/profiler_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Callable, Optional, Union

import torch

from mmengine.dist import master_only
from mmengine.hooks import Hook
from mmengine.registry import HOOKS


def check_kineto() -> bool: # noqa
kineto_exist = False
try:
if torch.autograd.kineto_available():
kineto_exist = True
except AttributeError:
warnings.warn('NO KINETO')
return kineto_exist


@HOOKS.register_module()
class ProfilerHook(Hook):
"""A hook to analyze performance during training and inference.
PyTorch Profiler is a tool that allows the collection of the performance
metrics during the training. More details on Profiler can be found at
`official docs <https://pytorch.org/docs/stable/profiler.html
#torch.profiler.profile>`_
Args:
by_epoch (bool): Profile performance by epoch or by iteration.
Defaults to True.
profile_times (int): The period (epoch/iter) recorded by the profiler.
Defaults to 1. For example, profile_iters=10 and by_epoch=False,
indicate that 0-10 iterations are recorded.
activity_with_cpu (bool): Activities to be used in the analysis (CPU)
activity_with_cuda (bool): Activities to be used in the analysis (CUDA)
schedule (dict, optional): Key-word arguments passed to
`torch.profile.schedule <https://pytorch.org/docs/stable/
profiler.html#torch.profiler.schedule>`_.
Defaults to None, which means profiling without a schedule
on_trace_ready (callable, dict, optional): Either a handler or a dict
of generating handler. Defaults to None, which means profiling
without an on_trace_ready.The Callable type needs to construct its
own function that can handle 'torch.autograd.profiler.profile'.
Two officially recommended ways are provided, namely terminal
display or tensorboard display. The terminal display content can be
adjusted through 'EventList.table()'
from 'torch.autograd.profiler_util.py'.
If using tensorboard, save to '{work_dir}/tf_tracing_logs'
by default.
record_shapes (bool): Save information about operator's input shapes.
Defaults to False.
profile_memory (bool): Track tensor memory allocation/deallocation.
Defaults to False.
with_stack (bool): Record source information (file and line number)
for the ops. Defaults to False.
with_flops (bool): Use formula to estimate the FLOPS of specific
operators (matrix multiplication and 2D convolution).
Defaults to False.
json_trace_path (str, optional): Exports the collected trace in Chrome
JSON format. Chrome use 'chrome://tracing' view json file.
Defaults to None, which means profiling does not store json files.
Examples:
>>> # tensorboard trace
>>> trace_config = dict(type='tb_trace')
>>> profiler_hook_cfg = dict(on_trace_ready=trace_config)
"""
priority = 'VERY_LOW'

def __init__(self,
*,
by_epoch: bool = True,
profile_times: int = 1,
activity_with_cpu: bool = True,
activity_with_cuda: bool = False,
schedule: Optional[dict] = None,
on_trace_ready: Union[Callable, dict, None] = None,
record_shapes: bool = False,
profile_memory: bool = False,
with_stack: bool = False,
with_flops: bool = False,
json_trace_path: Optional[str] = None) -> None:

try:
from torch import profiler
except ImportError:
raise ImportError('please upgrade torch above 1.8.1')
if not check_kineto():
raise ImportError('Due to Kineto support issues, please upgrade '
'pytorch above 1.8.1(windows users above 1.9.1)')

assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean.'
self.by_epoch = by_epoch

if profile_times < 1:
raise ValueError('profile_iters should be greater than 0, '
f'but got {profile_times}')
if by_epoch and profile_times > 1:
raise ValueError(
f'Profiler will profile 0-{profile_times} epochs.\n'
'Since profiler will slow down the training, it is recommended'
' to train 1 epoch with ProfilerHook and adjust your setting '
'according to the profiler summary.\n'
'During normal training(epoch > 1), '
'you may disable the ProfilerHook.')
self.profile_times = profile_times

assert isinstance(activity_with_cpu, bool), \
'``activity_with_cpu`` should be a boolean.'
assert isinstance(activity_with_cuda, bool), \
'``activity_with_cuda`` should be a boolean.'
self.activities = []
if activity_with_cpu:
self.activities.append(profiler.ProfilerActivity.CPU)
if activity_with_cuda:
self.activities.append(profiler.ProfilerActivity.CUDA)

if schedule is not None:
assert isinstance(schedule, dict), '``schedule`` should be a dict.'
self.schedule = profiler.schedule(**schedule)
else:
self.schedule = None

self.on_trace_ready = on_trace_ready
self.record_shapes = record_shapes
self.profile_memory = profile_memory
self.with_stack = with_stack
self.with_flops = with_flops

self.json_trace_path = json_trace_path

@master_only
def before_run(self, runner):
"""Initialize the profiler.
Through the runner parameter, the validity of the parameter is further
determined.
"""
max_times = runner.max_epochs if self.by_epoch else runner.max_iters
if max_times < self.profile_times:
raise ValueError(
f'``profile_times`` should not be greater than {max_times}')

on_trace_ready = self._parse_trace_config(runner)

self.profiler = torch.profiler.profile( # noqa
activities=self.activities,
schedule=self.schedule,
on_trace_ready=on_trace_ready,
record_shapes=self.record_shapes,
profile_memory=self.profile_memory,
with_stack=self.with_stack,
with_flops=self.with_flops)

self.profiler.__enter__()
runner.logger.info('profiler is profiling...')

def _parse_trace_config(self, runner):
"""Used to parse the parameter 'on_trace_ready'."""
if self.on_trace_ready is None:
_on_trace_ready = None
elif callable(self.on_trace_ready):
_on_trace_ready = self.on_trace_ready
elif isinstance(self.on_trace_ready, dict):
trace_cfg = self.on_trace_ready.copy()
trace_type = trace_cfg.pop('type')

# Build a log printing handle
if trace_type == 'log_trace':

def _log_handler(_profile):
print(_profile.key_averages().table(**trace_cfg))

_on_trace_ready = _log_handler

elif trace_type == 'tb_trace': # tensorboard_trace handler
try:
import torch_tb_profiler # noqa: F401
except ImportError:
raise ImportError(
'please run ``pip install torch-tb-profiler``')

if 'dir_name' not in trace_cfg:
trace_cfg['dir_name'] = osp.join(runner.log_dir,
'tf_tracing_logs')
elif not osp.isabs(trace_cfg['dir_name']):
trace_cfg['dir_name'] = osp.join(runner.log_dir,
trace_cfg['dir_name'])
runner.logger.info('trace_files of ProfilerHook will be '
f'saved to {trace_cfg["dir_name"]}.')

if self.json_trace_path is not None:
runner.logger.warn(
'When using tensorboard_trace, it is recommended to '
'save json files by setting ``worker_name`` instead of'
' setting ``json_trace_path``')
_on_trace_ready = torch.profiler.tensorboard_trace_handler(
**trace_cfg)
else:
raise ValueError('trace_type should be "log_trace" or '
f'"tb_trace", but got {trace_type}')
else:
raise ValueError(
'``on_trace_ready`` should be a handler, or dict, or None, '
f'but got {self.on_trace_ready}')
return _on_trace_ready

@master_only
def after_train_epoch(self, runner):
"""Determine if the content is exported."""
if self.by_epoch and runner.epoch == self.profile_times - 1:
self._export_chrome_trace(runner)

@master_only
def after_train_iter(self, runner, batch_idx, data_batch, outputs):
"""Update the content according to the schedule, and determine if the
content is exported."""
if self.schedule is None:
self.profiler.step()
if not self.by_epoch and runner.iter == self.profile_times - 1:
self._export_chrome_trace(runner)

def _export_chrome_trace(self, runner):
"""Exporting content."""
runner.logger.info('profiler may take a few minutes...')
self.profiler.__exit__(None, None, None)
if self.json_trace_path is not None:
self.profiler.export_chrome_trace(self.json_trace_path)

0 comments on commit 16589ce

Please sign in to comment.