diff --git a/vissl/config/defaults.yaml b/vissl/config/defaults.yaml index b8f1f719a..a2b082872 100644 --- a/vissl/config/defaults.yaml +++ b/vissl/config/defaults.yaml @@ -83,6 +83,21 @@ config: # valid for some systems. LOG_GPU_STATS: True + # ----------------------------------------------------------------------------------- # + # HOOKS + # ----------------------------------------------------------------------------------- # + HOOKS: + # ----------------------------------------------------------------------------------- # + # torch.cuda.memory_summary() + # ----------------------------------------------------------------------------------- # + MEMORY_SUMMARY: + # set this to true if you want to print memory summary. useful for profiling + # memory consumption of model + PRINT_MEMORY_SUMMARY: False + # at what iteration number should the memory summary be printed. usually + # set to 1 for very large models + LOG_ITERATION_NUM: 0 + # ----------------------------------------------------------------------------------- # # DATA # ----------------------------------------------------------------------------------- # diff --git a/vissl/hooks/__init__.py b/vissl/hooks/__init__.py index f7dd50f21..5bdd278a9 100644 --- a/vissl/hooks/__init__.py +++ b/vissl/hooks/__init__.py @@ -7,6 +7,7 @@ from vissl.hooks.deepclusterv2_hooks import ClusterMemoryHook, InitMemoryHook # noqa from vissl.hooks.log_hooks import ( # noqa LogGpuStatsHook, + LogGpuMemoryHook, LogLossLrEtaHook, LogLossMetricsCheckpointHook, LogPerfTimeMetricsHook, @@ -100,6 +101,8 @@ def default_hook_generator(cfg: AttrDict) -> List[ClassyHook]: hooks.extend([SSLModelComplexityHook()]) if cfg.LOG_GPU_STATS: hooks.extend([LogGpuStatsHook()]) + if cfg.HOOKS.MEMORY_SUMMARY.PRINT_MEMORY_SUMMARY: + hooks.extend([LogGpuMemoryHook(cfg.HOOKS.MEMORY_SUMMARY.LOG_ITERATION_NUM)]) if cfg.TENSORBOARD_SETUP.USE_TENSORBOARD: assert is_tensorboard_available(), "Tensorboard must be installed to use it." 
class LogGpuMemoryHook(ClassyHook):
    """
    Hook that logs ``torch.cuda.memory_summary()`` at several stages of a
    single, chosen training iteration.

    The summary is emitted from four hook points (phase start, forward,
    backward, update). It is only logged when ``is_primary()`` is true, the
    task device type is ``"cuda"`` and ``task.local_iteration_num`` equals
    the iteration number given at construction time.

    Args:
        log_iteration_num (int): the value of ``task.local_iteration_num``
            at which the memory summary should be logged.
    """

    # Hook points that this hook does not act on.
    on_start = ClassyHook._noop
    on_loss_and_meter = ClassyHook._noop
    on_step = ClassyHook._noop
    on_phase_end = ClassyHook._noop
    on_end = ClassyHook._noop

    def __init__(
        self,
        log_iteration_num: int = 1,
    ) -> None:
        super().__init__()
        self.log_iteration_num = log_iteration_num

    def on_phase_start(self, task: "tasks.ClassyTask") -> None:
        """
        Log the memory summary right before the training epoch begins.
        """
        self._print_memory_summary(task, "on_phase_start")

    def on_forward(self, task: "tasks.ClassyTask") -> None:
        """
        Log the memory summary once the model forward pass has finished.
        """
        self._print_memory_summary(task, "on_forward")

    def on_backward(self, task: "tasks.ClassyTask") -> None:
        """
        Log the memory summary right after model.backward() has finished.
        """
        self._print_memory_summary(task, "on_backward")

    def on_update(self, task: "tasks.ClassyTask") -> None:
        """
        Log the memory summary right after the model params are updated.
        """
        self._print_memory_summary(task, "on_update")

    def _print_memory_summary(self, task: "tasks.ClassyTask", stage_name: str) -> None:
        """
        Log the CUDA memory summary, tagged with ``stage_name``, if all the
        logging conditions hold; otherwise do nothing.
        """
        # Checks are performed in the same order as a short-circuiting
        # `and` chain: primary check, then device type, then iteration.
        if not is_primary():
            return
        if task.device.type != "cuda":
            return
        if task.local_iteration_num != self.log_iteration_num:
            return

        summary = torch.cuda.memory_summary()
        logging.info(
            f"========= Memory Summary at {stage_name} =======\n{summary}\n"
        )
@@ -92,8 +147,8 @@ def on_update(self, task: "tasks.ClassyTask") -> None: monitoring the stats (optionally) for every N iterations to get better idea about the batch time and training eta. - Set the btime_freq input using cfg.PERF_STAT_FREQUENCY=N ensuring that - cfg.MONITOR_PERF_STATS = True. + Set the btime_freq input using cfg.HOOKS.PERF_STATS.PERF_STAT_FREQUENCY=N + ensuring that cfg.HOOKS.PERF_STATS.MONITOR_PERF_STATS = True. """ phase_type = "train" if task.train else "test" if is_primary() and phase_type == "train":