diff --git a/chainer/__init__.py b/chainer/__init__.py index 767511fcbf22..5c8fea253830 100644 --- a/chainer/__init__.py +++ b/chainer/__init__.py @@ -56,6 +56,7 @@ from chainer.serializer import Deserializer # NOQA from chainer.serializer import Serializer # NOQA from chainer.variable import Parameter # NOQA +from chainer.variable import to_variable # NOQA from chainer.variable import Variable # NOQA diff --git a/chainer/function_node.py b/chainer/function_node.py index 06a3759dbb7c..4a29d36cc832 100644 --- a/chainer/function_node.py +++ b/chainer/function_node.py @@ -193,9 +193,7 @@ def apply(self, inputs): A tuple of output :class:`Variable` objects. """ - input_vars = [x if isinstance(x, variable.Variable) - else variable.Variable(x, requires_grad=False) - for x in inputs] + input_vars = [chainer.to_variable(x) for x in inputs] in_data = tuple([x.data for x in input_vars]) requires_grad = any([x.requires_grad for x in input_vars]) diff --git a/chainer/functions/connection/linear.py b/chainer/functions/connection/linear.py index 4c6e40837947..ebbe8a61c4c3 100644 --- a/chainer/functions/connection/linear.py +++ b/chainer/functions/connection/linear.py @@ -100,3 +100,4 @@ def linear(x, W, b=None): return LinearFunction()(x, W) else: return LinearFunction()(x, W, b) +0 \ No newline at end of file diff --git a/chainer/variable.py b/chainer/variable.py index 680b810d3aa1..1c7768357f04 100644 --- a/chainer/variable.py +++ b/chainer/variable.py @@ -1202,6 +1202,36 @@ def update(self): self.update_rule.update(self) +def to_variable(obj): + """Converts an array or a variable into :class:`~chainer.Variable`. + + This is a convenient function to get a :class:`~chainer.Variable` object + transparently from a raw array or a variable. + + Note that this function should only be used for type consistency (i.e., to + enforce the return value of an API having type :class:`~chainer.Varialbe`). 
+ The :attr:`~chainer.Variable.requires_grad` flag is kept as is;
self.check_to_variable_from_array(cuda.cupy.empty(1, np.float32)) + + def test_to_variable_from_variable(self): + x = chainer.Variable(np.array(1, np.float32)) + y = chainer.to_variable(x) + self.assertIs(x, y) + self.assertTrue(y.requires_grad) + + testing.run_module(__name__, __file__)