Skip to content

Commit

Permalink
Merge a0f8c69 into 2cb46ac
Browse files Browse the repository at this point in the history
  • Loading branch information
beam2d committed Aug 23, 2017
2 parents 2cb46ac + a0f8c69 commit 5313f86
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 3 deletions.
1 change: 1 addition & 0 deletions chainer/__init__.py
Expand Up @@ -56,6 +56,7 @@
from chainer.serializer import Deserializer # NOQA
from chainer.serializer import Serializer # NOQA
from chainer.variable import Parameter # NOQA
from chainer.variable import to_variable # NOQA
from chainer.variable import Variable # NOQA


Expand Down
4 changes: 1 addition & 3 deletions chainer/function_node.py
Expand Up @@ -193,9 +193,7 @@ def apply(self, inputs):
A tuple of output :class:`Variable` objects.
"""
input_vars = [x if isinstance(x, variable.Variable)
else variable.Variable(x, requires_grad=False)
for x in inputs]
input_vars = [chainer.to_variable(x) for x in inputs]
in_data = tuple([x.data for x in input_vars])
requires_grad = any([x.requires_grad for x in input_vars])

Expand Down
1 change: 1 addition & 0 deletions chainer/functions/connection/linear.py
Expand Up @@ -100,3 +100,4 @@ def linear(x, W, b=None):
return LinearFunction()(x, W)
else:
return LinearFunction()(x, W, b)
0
30 changes: 30 additions & 0 deletions chainer/variable.py
Expand Up @@ -1202,6 +1202,36 @@ def update(self):
self.update_rule.update(self)


def to_variable(obj):
"""Converts an array or a variable into :class:`~chainer.Variable`.
This is a convenient function to get a :class:`~chainer.Variable` object
transparently from a raw array or a variable.
Note that this function should only be used for type consistency (i.e., to
enforce the return value of an API having type :class:`~chainer.Varialbe`).
The :class:`~chainer.Variable.requires_grad` flag is kept as is; if ``obj``
is a raw array, the newly created variable has ``requires_grad = False``.
In order to make a variable w.r.t. which you want to compute the gradient,
you should use :class:`~chainer.Variable` directly.
Args:
obj (numpy.ndarray or cupy.ndarray or ~chainer.Variable): An array or
a variable that you want to convert to :class:`~chainer.Variable`.
Returns:
~chainer.Variable:
A variable converted from ``obj``. If ``obj`` is a raw array, this is a
new :class:`~chainer.Variable` object that wraps the array. If ``obj``
is already a :class:`~chainer.Variable` object, this function returns
``obj`` as is.
"""
if isinstance(obj, Variable):
return obj
return Variable(obj, requires_grad=False)


def _recover_parameter(data, name, grad, initializer, update_rule):
p = Parameter(initializer=initializer, name=name)
p.data = data
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/core/variable.rst
Expand Up @@ -6,5 +6,6 @@ Variable and Parameter
:nosignatures:

chainer.Variable
chainer.to_variable
chainer.Parameter
chainer.variable.VariableNode
22 changes: 22 additions & 0 deletions tests/chainer_tests/test_variable.py
Expand Up @@ -1392,4 +1392,26 @@ def test_raise_double_backprop_2(self):
x.grad_var.backward()


class TestToVariable(unittest.TestCase):

def check_to_variable_from_array(self, x):
y = chainer.to_variable(x)
self.assertIsInstance(y, chainer.Variable)
self.assertIs(y.data, x)
self.assertFalse(y.requires_grad)

def test_to_variable_from_numpy(self):
self.check_to_variable_from_array(np.empty(1, np.float32))

@attr.gpu
def test_to_variable_from_cupy(self):
self.check_to_variable_from_array(cuda.cupy.empty(1, np.float32))

def test_to_variable_from_variable(self):
x = chainer.Variable(np.array(1, np.float32))
y = chainer.to_variable(x)
self.assertIs(x, y)
self.assertTrue(y.requires_grad)


testing.run_module(__name__, __file__)

0 comments on commit 5313f86

Please sign in to comment.