Commit

cache the result of detect_gpus
korepwx committed Jul 23, 2018
1 parent f16efc9 commit 1e5db63
Showing 2 changed files with 25 additions and 2 deletions.
23 changes: 23 additions & 0 deletions tfsnippet/examples/utils/misc.py
@@ -21,6 +21,7 @@
'smart_apply',
'flatten',
'unflatten',
'cached',
]


@@ -329,3 +330,25 @@ def unflatten(x, static_front_shape, front_shape, name=None):
x.set_shape(tf.TensorShape(list(static_front_shape) +
list(static_back_shape)))
return x


def cached(method):
"""
Decorate `method`, to cache its result.
Args:
method: The method whose result should be cached.
Returns:
The decorated method.
"""
results = {}

@six.wraps(method)
def wrapper(*args, **kwargs):
        cache_key = (args, tuple((k, kwargs[k]) for k in sorted(kwargs)))
if cache_key not in results:
results[cache_key] = method(*args, **kwargs)
return results[cache_key]

return wrapper
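
For context, here is a minimal usage sketch of the decorator added above; `slow_square` and `call_count` are illustrative names, not part of this commit. Repeated calls with the same arguments return the memoized result instead of re-running the body:

from tfsnippet.examples.utils.misc import cached

call_count = [0]

@cached
def slow_square(x):
    # Stand-in for an expensive computation; count how often the body runs.
    call_count[0] += 1
    return x * x

print(slow_square(3))   # 9 -- body runs, result stored under key ((3,), ())
print(slow_square(3))   # 9 -- same key, served from the cache
print(call_count[0])    # 1 -- the body executed only once

Note that positional and keyword spellings of the same call produce different cache keys (slow_square(3) vs. slow_square(x=3)), since the key is built from the positional args and the sorted kwargs separately.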
4 changes: 2 additions & 2 deletions tfsnippet/examples/utils/multi_gpu.py
@@ -5,11 +5,12 @@
import six
import tensorflow as tf

-from .misc import is_dynamic_tensor
+from .misc import is_dynamic_tensor, cached

__all__ = ['detect_gpus', 'average_gradients', 'MultiGPU']


@cached
def detect_gpus():
"""
    Detect the GPU devices and their interconnection on the current machine.
@@ -134,7 +135,6 @@ def __init__(self, disable_prebuild=False):
supported by CPUs for the time being, thus the pre-building on
CPUs might need to be disabled.
"""

gpu_groups = detect_gpus()
if not gpu_groups:
self._main_device = '/device:CPU:0'
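
Since detect_gpus() takes no arguments, every call now maps to the single cache key ((), ()), so device detection runs at most once per process and each MultiGPU instance reuses the stored group list. A rough sketch of the intended behaviour (import path assumed from the file shown above):

from tfsnippet.examples.utils.multi_gpu import detect_gpus

groups = detect_gpus()        # performs the actual device detection
groups_again = detect_gpus()  # memoized: returns the very same object
assert groups_again is groups

Because the cached wrapper hands back the same stored object on every call, callers should treat the returned list as read-only.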
