Skip to content

Commit

Permalink
Merge 01d8041 into 9453976
Browse files Browse the repository at this point in the history
  • Loading branch information
niboshi committed Aug 29, 2017
2 parents 9453976 + 01d8041 commit 0af10bb
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 24 deletions.
5 changes: 3 additions & 2 deletions chainer/configuration.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import print_function
import contextlib
import sys
import threading

from chainer import utils


class GlobalConfig(object):

Expand Down Expand Up @@ -99,7 +100,7 @@ def _print_attrs(obj, keys, file):
'''


@contextlib.contextmanager
@utils.contextmanager
def using_config(name, value, config=config):
"""using_config(name, value, config=chainer.config)
Expand Down
4 changes: 2 additions & 2 deletions chainer/link.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import collections
import contextlib
import copy
import warnings

Expand All @@ -8,6 +7,7 @@

from chainer import cuda
from chainer import initializers
from chainer import utils
from chainer import variable


Expand Down Expand Up @@ -159,7 +159,7 @@ def within_init_scope(self):
"""
return getattr(self, '_within_init_scope', False)

@contextlib.contextmanager
@utils.contextmanager
def init_scope(self):
"""Creates an initialization scope.
Expand Down
6 changes: 3 additions & 3 deletions chainer/reporter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import collections
import contextlib
import copy

import numpy
import six

from chainer import configuration
from chainer import cuda
from chainer import utils
from chainer import variable


Expand Down Expand Up @@ -77,7 +77,7 @@ def __exit__(self, exc_type, exc_value, traceback):
"""Recovers the previous reporter object to the current."""
_reporters.pop()

@contextlib.contextmanager
@utils.contextmanager
def scope(self, observation):
"""Creates a scope to report observed values to ``observation``.
Expand Down Expand Up @@ -225,7 +225,7 @@ def __call__(self, x, y):
current.report(values, observer)


@contextlib.contextmanager
@utils.contextmanager
def report_scope(observation):
"""Returns a report scope with the current reporter.
Expand Down
5 changes: 3 additions & 2 deletions chainer/testing/helper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import contextlib
import pkg_resources
import unittest
import warnings

from chainer import utils


def with_requires(*requirements):
"""Run a test case only when given requirements are satisfied.
Expand Down Expand Up @@ -33,7 +34,7 @@ def with_requires(*requirements):
return unittest.skipIf(skip, msg)


@contextlib.contextmanager
@utils.contextmanager
def assert_warns(expected):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
Expand Down
23 changes: 23 additions & 0 deletions chainer/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import contextlib
import inspect

import numpy

from chainer.utils import walker_alias # NOQA
Expand Down Expand Up @@ -32,3 +35,23 @@ def force_type(dtype, value):
return value.astype(dtype, copy=False)
else:
return value


def contextmanager(func):
"""A decorator used to define a factory function for ``with`` statement.
This does exactlyl the same thing as ``@contextlib.contextmanager``, but
with workaround for the issue that it does not transfer source file name
and line number at which the original function was defined.
"""

wrapper = contextlib.contextmanager(func)

sourcefile = inspect.getsourcefile(func)
_, linenumber = inspect.getsourcelines(func)

# Note: these attributes are used in docs/source/conf.py.
if sourcefile is not None:
wrapper.__chainer_wrapped_sourcefile__ = sourcefile
wrapper.__chainer_wrapped_linenumber__ = linenumber
return wrapper
4 changes: 2 additions & 2 deletions chainer/utils/type_check.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import contextlib
import operator
import sys
import threading

import numpy

from chainer import cuda
from chainer import utils


_thread_local = threading.local()


@contextlib.contextmanager
@utils.contextmanager
def get_function_check_context(f):
default = getattr(_thread_local, 'current_function', None)
_thread_local.current_function = f
Expand Down
41 changes: 28 additions & 13 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,31 @@ def _get_source_relative_path(source_abs_path):
return os.path.relpath(source_abs_path, _find_source_root(source_abs_path))


def _get_sourcefile_and_linenumber(obj):
# Check to see `obj` has attributes that are injected by
# chainer.utils.contextmanager.
if hasattr(obj, '__chainer_wrapped_sourcefile__'):
filename = obj.__chainer_wrapped_sourcefile__
linenum = obj.__chainer_wrapped_linenumber__
return filename, linenum

# Get the source file name and line number at which obj is defined.
try:
filename = inspect.getsourcefile(obj)
except TypeError:
# obj is not a module, class, function, ..etc.
return None, None

# inspect can return None for cython objects
if filename is None:
return None, None

# Get the source line number
_, linenum = inspect.getsourcelines(obj)

return filename, linenum


def linkcode_resolve(domain, info):
if domain != 'py' or not info['module']:
return None
Expand All @@ -408,21 +433,11 @@ def linkcode_resolve(domain, info):
if not (mod.__name__ == 'chainer' or mod.__name__.startswith('chainer.')):
return None

# Get the source file name and line number at which obj is defined.
try:
filename = inspect.getsourcefile(obj)
except TypeError:
# obj is not a module, class, function, ..etc.
# Retrieve source file name and line number
filename, linenum = _get_sourcefile_and_linenumber(obj)
if filename is None or linenum is None:
return None

# inspect can return None for cython objects
if filename is None:
return None

# Get the source line number
_, linenum = inspect.getsourcelines(obj)
assert isinstance(linenum, six.integer_types)

filename = os.path.realpath(filename)
relpath = _get_source_relative_path(filename)

Expand Down

0 comments on commit 0af10bb

Please sign in to comment.