diff --git a/chainer/__init__.py b/chainer/__init__.py index 767511fcbf22..affdd2ead88c 100644 --- a/chainer/__init__.py +++ b/chainer/__init__.py @@ -56,6 +56,7 @@ from chainer.serializer import Deserializer # NOQA from chainer.serializer import Serializer # NOQA from chainer.variable import Parameter # NOQA +from chainer.variable import as_variable # NOQA from chainer.variable import Variable # NOQA diff --git a/chainer/function_node.py b/chainer/function_node.py index 06a3759dbb7c..85993a471d13 100644 --- a/chainer/function_node.py +++ b/chainer/function_node.py @@ -193,9 +193,7 @@ def apply(self, inputs): A tuple of output :class:`Variable` objects. """ - input_vars = [x if isinstance(x, variable.Variable) - else variable.Variable(x, requires_grad=False) - for x in inputs] + input_vars = [chainer.as_variable(x) for x in inputs] in_data = tuple([x.data for x in input_vars]) requires_grad = any([x.requires_grad for x in input_vars]) diff --git a/chainer/variable.py b/chainer/variable.py index 11663f3b0fdb..19ddd83a05a7 100644 --- a/chainer/variable.py +++ b/chainer/variable.py @@ -1208,6 +1208,36 @@ def update(self): self.update_rule.update(self) +def as_variable(obj): + """Converts an array or a variable into :class:`~chainer.Variable`. + + This is a convenient function to get a :class:`~chainer.Variable` object + transparently from a raw array or a variable. + + Note that this function should only be used for type consistency (i.e., to + enforce the return value of an API having type :class:`~chainer.Varialbe`). + The :class:`~chainer.Variable.requires_grad` flag is kept as is; if ``obj`` + is a raw array, the newly created variable has ``requires_grad = False``. + In order to make a variable w.r.t. which you want to compute the gradient, + you should use :class:`~chainer.Variable` directly. + + Args: + obj (numpy.ndarray or cupy.ndarray or ~chainer.Variable): An array or + a variable that you want to convert to :class:`~chainer.Variable`. + + Returns: + ~chainer.Variable: + A variable converted from ``obj``. If ``obj`` is a raw array, this is a + new :class:`~chainer.Variable` object that wraps the array. If ``obj`` + is already a :class:`~chainer.Variable` object, this function returns + ``obj`` as is. + + """ + if isinstance(obj, Variable): + return obj + return Variable(obj, requires_grad=False) + + def _recover_parameter(data, name, grad, initializer, update_rule): p = Parameter(initializer=initializer, name=name) p.data = data diff --git a/docs/source/reference/core/variable.rst b/docs/source/reference/core/variable.rst index a6b4bb32a9ca..50007d6aac30 100644 --- a/docs/source/reference/core/variable.rst +++ b/docs/source/reference/core/variable.rst @@ -6,5 +6,6 @@ Variable and Parameter :nosignatures: chainer.Variable + chainer.as_variable chainer.Parameter chainer.variable.VariableNode diff --git a/tests/chainer_tests/test_variable.py b/tests/chainer_tests/test_variable.py index cff7c47d3888..0954a5c03cd5 100644 --- a/tests/chainer_tests/test_variable.py +++ b/tests/chainer_tests/test_variable.py @@ -1407,4 +1407,26 @@ def test_raise_double_backprop_2(self): x.grad_var.backward() +class TestAsVariable(unittest.TestCase): + + def check_to_variable_from_array(self, x): + y = chainer.as_variable(x) + self.assertIsInstance(y, chainer.Variable) + self.assertIs(y.data, x) + self.assertFalse(y.requires_grad) + + def test_to_variable_from_numpy(self): + self.check_to_variable_from_array(np.empty(1, np.float32)) + + @attr.gpu + def test_to_variable_from_cupy(self): + self.check_to_variable_from_array(cuda.cupy.empty(1, np.float32)) + + def test_to_variable_from_variable(self): + x = chainer.Variable(np.array(1, np.float32)) + y = chainer.as_variable(x) + self.assertIs(x, y) + self.assertTrue(y.requires_grad) + + testing.run_module(__name__, __file__)