From b182c70e0ca8061401a1cbdd0c7fca5f32c20484 Mon Sep 17 00:00:00 2001 From: niboshi Date: Thu, 8 Feb 2018 00:06:32 +0900 Subject: [PATCH] Avoid keeping reference to function nodes in CupyMemoryProfileHook --- chainer/function_hooks/cupy_memory_profile.py | 22 +++++++++---------- .../test_cupy_memory_profile.py | 14 ++++++------ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/chainer/function_hooks/cupy_memory_profile.py b/chainer/function_hooks/cupy_memory_profile.py index 90698102525a..53bb87b3dde5 100644 --- a/chainer/function_hooks/cupy_memory_profile.py +++ b/chainer/function_hooks/cupy_memory_profile.py @@ -40,10 +40,10 @@ class CupyMemoryProfileHook(function_hook.FunctionHook): pool acquired from GPU device on the function call, and *Occurrence* is the number of calls. Attributes: - call_history: List of measurement results. It consists of the function - that calls this hook, the memory bytes the function used from cupy - memory pool, and the memory bytes the cupy memory pool acquired - from GPU device on the function call. + call_history: List of measurement results. It consists of the name of + the function that calls this hook, the memory bytes the function + used from cupy memory pool, and the memory bytes the cupy memory + pool acquired from GPU device on the function call. """ name = 'CupyMemoryProfileHook' @@ -83,7 +83,8 @@ def _postprocess(self, function): used_bytes = end_used_bytes - start_used_bytes acquired_bytes = end_acquired_bytes - start_acquired_bytes depth = len(self._running_stack) - self.call_history.append((function, used_bytes, acquired_bytes, depth)) + self.call_history.append( + (function._impl_name, used_bytes, acquired_bytes, depth)) if depth == 0: self._total_used_bytes += used_bytes self._total_acquired_bytes += acquired_bytes @@ -112,12 +113,11 @@ def summary(self): """ # TODO(sonots): PROBLEM: takes count of nested functions duplicately summary = collections.OrderedDict() - for func, used_bytes, acquired_bytes, depth in self.call_history: - function_name = func._impl_name - if function_name not in summary: - summary[function_name] = {'used_bytes': 0, - 'acquired_bytes': 0, 'occurrence': 0} - record = summary[function_name] + for func_name, used_bytes, acquired_bytes, depth in self.call_history: + if func_name not in summary: + summary[func_name] = {'used_bytes': 0, + 'acquired_bytes': 0, 'occurrence': 0} + record = summary[func_name] record['used_bytes'] += used_bytes record['acquired_bytes'] += acquired_bytes record['occurrence'] += 1 diff --git a/tests/chainer_tests/function_hooks_tests/test_cupy_memory_profile.py b/tests/chainer_tests/function_hooks_tests/test_cupy_memory_profile.py index bceade1e13aa..3fe41587f215 100644 --- a/tests/chainer_tests/function_hooks_tests/test_cupy_memory_profile.py +++ b/tests/chainer_tests/function_hooks_tests/test_cupy_memory_profile.py @@ -14,8 +14,8 @@ def check_history(self, t, function_type, used_bytes_type, acquired_bytes_type): - func = getattr(t[0], 'function', t[0]) - self.assertIsInstance(func, function_type) + func_name = t[0] + assert func_name == function_type.__name__ self.assertIsInstance(t[1], used_bytes_type) self.assertIsInstance(t[2], acquired_bytes_type) @@ -65,7 +65,7 @@ def check_backward(self, x, gy): # It includes forward of + that accumulates gradients to W and b self.assertEqual(3, len(self.h.call_history)) for entry in self.h.call_history: - if entry[0].label == '_ + _': + if entry[0] == 'Add': continue check_history(self, entry, basic_math.Mul, int, int) @@ -126,10 +126,10 @@ def test_reentrant(self): history = {f: (u, a, d) for (f, u, a, d) in self.h.call_history} self.assertEqual(len(history), 2) - self.assertIn(f, history) - self.assertIn(g, history) - f_used_bytes, f_acquired_bytes, f_depth = history[f] - g_used_bytes, g_acquired_bytes, g_depth = history[g] + self.assertIn(f._impl_name, history) + self.assertIn(g._impl_name, history) + f_used_bytes, f_acquired_bytes, f_depth = history[f._impl_name] + g_used_bytes, g_acquired_bytes, g_depth = history[g._impl_name] self.assertEqual(f_depth, 0) self.assertEqual(g_depth, 1) self.assertGreater(f_used_bytes, g_used_bytes)