Skip to content

Commit

Permalink
Merge c566cd5 into eaf1662
Browse files Browse the repository at this point in the history
  • Loading branch information
kmaehashi committed Apr 10, 2018
2 parents eaf1662 + c566cd5 commit e345496
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 0 deletions.
1 change: 1 addition & 0 deletions chainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

# import class and function
# These functions from backends.cuda are kept for backward compatibility
from chainer._runtime_info import print_runtime_info # NOQA
from chainer.backends.cuda import should_use_cudnn # NOQA
from chainer.backends.cuda import should_use_cudnn_tensor_core # NOQA
from chainer.configuration import config # NOQA
Expand Down
46 changes: 46 additions & 0 deletions chainer/_runtime_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import sys

import numpy
import six

import chainer
from chainer.backends import cuda


class _RuntimeInfo(object):

chainer_version = None
numpy_version = None
cuda_info = None

def __init__(self):
self.chainer_version = chainer.__version__
self.numpy_version = numpy.__version__
if cuda.available:
self.cuda_info = cuda.cupyx.get_runtime_info()
else:
self.cuda_info = None

def __str__(self):
s = six.StringIO()
s.write('''Chainer: {}\n'''.format(self.chainer_version))
s.write('''NumPy: {}\n'''.format(self.numpy_version))
if self.cuda_info is None:
s.write('''CuPy: Not Available\n''')
else:
s.write('''CuPy:\n''')
for line in str(self.cuda_info).splitlines():
s.write(''' {}\n'''.format(line))
return s.getvalue()


def get_runtime_info():
return _RuntimeInfo()


def print_runtime_info(out=None):
if out is None:
out = sys.stdout
out.write(str(get_runtime_info()))
if hasattr(out, 'flush'):
out.flush()
21 changes: 21 additions & 0 deletions tests/chainer_tests/test_runtime_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import unittest

import six

import chainer
from chainer import _runtime_info
from chainer import testing


class TestRuntimeInfo(unittest.TestCase):
def test_get_runtime_info(self):
info = _runtime_info.get_runtime_info()
assert chainer.__version__ in str(info)

def test_print_runtime_info(self):
out = six.StringIO()
_runtime_info.print_runtime_info(out)
assert out.getvalue() == str(_runtime_info.get_runtime_info())


testing.run_module(__name__, __file__)

0 comments on commit e345496

Please sign in to comment.