Skip to content

Commit

Permalink
New style flatten
Browse files Browse the repository at this point in the history
  • Loading branch information
okuta committed Aug 24, 2017
1 parent 7930acf commit f6543af
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 17 deletions.
1 change: 0 additions & 1 deletion chainer/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from chainer.functions.array.expand_dims import expand_dims # NOQA
from chainer.functions.array.expand_dims import ExpandDims # NOQA
from chainer.functions.array.flatten import flatten # NOQA
from chainer.functions.array.flatten import Flatten # NOQA
from chainer.functions.array.fliplr import fliplr # NOQA
from chainer.functions.array.fliplr import FlipLR # NOQA
from chainer.functions.array.flipud import flipud # NOQA
Expand Down
17 changes: 2 additions & 15 deletions chainer/functions/array/flatten.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,4 @@
from chainer import function


class Flatten(function.Function):

"""Flatten function."""

def forward(self, inputs):
self.retain_inputs(())
self._in_shape = inputs[0].shape
return inputs[0].ravel(),

def backward(self, inputs, grads):
return grads[0].reshape(self._in_shape),
from chainer.functions.array import reshape


def flatten(x):
Expand Down Expand Up @@ -50,4 +37,4 @@ def flatten(x):
array([0, 1, 2, 3, 4, 5, 6, 7])
"""
return Flatten()(x)
return reshape.reshape(x, (x.size,))
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_forward_gpu(self):

def check_backward(self, x_data, g_data):
gradient_check.check_backward(
functions.Flatten(), x_data, g_data, dtype=numpy.float64)
functions.flatten, x_data, g_data, dtype=numpy.float64)

def test_backward_cpu(self):
self.check_backward(self.x, self.g)
Expand Down

0 comments on commit f6543af

Please sign in to comment.