Skip to content

Commit

Permalink
Avoid keeping reference to function nodes in CupyMemoryProfileHook
Browse files Browse the repository at this point in the history
  • Loading branch information
niboshi committed Feb 7, 2018
1 parent a3a1edb commit b182c70
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
22 changes: 11 additions & 11 deletions chainer/function_hooks/cupy_memory_profile.py
Expand Up @@ -40,10 +40,10 @@ class CupyMemoryProfileHook(function_hook.FunctionHook):
pool acquired from GPU device on the function call, and *Occurrence*
is the number of calls.
Attributes:
call_history: List of measurement results. It consists of the function
that calls this hook, the memory bytes the function used from cupy
memory pool, and the memory bytes the cupy memory pool acquired
from GPU device on the function call.
call_history: List of measurement results. It consists of the name of
the function that calls this hook, the memory bytes the function
used from cupy memory pool, and the memory bytes the cupy memory
pool acquired from GPU device on the function call.
"""

name = 'CupyMemoryProfileHook'
Expand Down Expand Up @@ -83,7 +83,8 @@ def _postprocess(self, function):
used_bytes = end_used_bytes - start_used_bytes
acquired_bytes = end_acquired_bytes - start_acquired_bytes
depth = len(self._running_stack)
self.call_history.append((function, used_bytes, acquired_bytes, depth))
self.call_history.append(
(function._impl_name, used_bytes, acquired_bytes, depth))
if depth == 0:
self._total_used_bytes += used_bytes
self._total_acquired_bytes += acquired_bytes
Expand Down Expand Up @@ -112,12 +113,11 @@ def summary(self):
"""
# TODO(sonots): PROBLEM: takes count of nested functions duplicately
summary = collections.OrderedDict()
for func, used_bytes, acquired_bytes, depth in self.call_history:
function_name = func._impl_name
if function_name not in summary:
summary[function_name] = {'used_bytes': 0,
'acquired_bytes': 0, 'occurrence': 0}
record = summary[function_name]
for func_name, used_bytes, acquired_bytes, depth in self.call_history:
if func_name not in summary:
summary[func_name] = {'used_bytes': 0,
'acquired_bytes': 0, 'occurrence': 0}
record = summary[func_name]
record['used_bytes'] += used_bytes
record['acquired_bytes'] += acquired_bytes
record['occurrence'] += 1
Expand Down
Expand Up @@ -14,8 +14,8 @@

def check_history(self, t, function_type, used_bytes_type,
acquired_bytes_type):
func = getattr(t[0], 'function', t[0])
self.assertIsInstance(func, function_type)
func_name = t[0]
assert func_name == function_type.__name__
self.assertIsInstance(t[1], used_bytes_type)
self.assertIsInstance(t[2], acquired_bytes_type)

Expand Down Expand Up @@ -65,7 +65,7 @@ def check_backward(self, x, gy):
# It includes forward of + that accumulates gradients to W and b
self.assertEqual(3, len(self.h.call_history))
for entry in self.h.call_history:
if entry[0].label == '_ + _':
if entry[0] == 'Add':
continue
check_history(self, entry,
basic_math.Mul, int, int)
Expand Down Expand Up @@ -126,10 +126,10 @@ def test_reentrant(self):

history = {f: (u, a, d) for (f, u, a, d) in self.h.call_history}
self.assertEqual(len(history), 2)
self.assertIn(f, history)
self.assertIn(g, history)
f_used_bytes, f_acquired_bytes, f_depth = history[f]
g_used_bytes, g_acquired_bytes, g_depth = history[g]
self.assertIn(f._impl_name, history)
self.assertIn(g._impl_name, history)
f_used_bytes, f_acquired_bytes, f_depth = history[f._impl_name]
g_used_bytes, g_acquired_bytes, g_depth = history[g._impl_name]
self.assertEqual(f_depth, 0)
self.assertEqual(g_depth, 1)
self.assertGreater(f_used_bytes, g_used_bytes)
Expand Down

0 comments on commit b182c70

Please sign in to comment.