Skip to content

Commit

Permalink
Merge pull request #6807 from takagi/dtype-mean-absolute-error
Browse files Browse the repository at this point in the history
Use intermediate dtype in `F.mean_absolute_error` for FP16
  • Loading branch information
mergify[bot] committed Nov 19, 2019
2 parents 467be22 + a8874a3 commit e94ac64
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
22 changes: 18 additions & 4 deletions chainer/functions/loss/mean_absolute_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@
from chainer.utils import type_check


def _get_intermediate_dtype(dtype):
# Returns the dtype for intermediate calculation.
# For float16 input, float32 is used.
# Otherwise the same dtype as the parameter is used.
if dtype == numpy.float16:
return numpy.float32
return dtype


class MeanAbsoluteError(function_node.FunctionNode):

"""Mean absolute error function."""
Expand All @@ -21,14 +30,19 @@ def check_type_forward(self, in_types):
def forward_cpu(self, inputs):
x0, x1 = inputs
self.diff = x0 - x1
diff = self.diff.ravel()
return numpy.array(abs(diff).sum() / diff.size, dtype=diff.dtype),
orig_dtype = self.diff.dtype
dtype = _get_intermediate_dtype(orig_dtype)
diff = self.diff.ravel().astype(dtype, copy=False)
return numpy.array(abs(diff).sum() / diff.size, dtype=orig_dtype),

def forward_gpu(self, inputs):
x0, x1 = inputs
self.diff = x0 - x1
diff = self.diff.ravel()
return abs(diff).sum() / diff.dtype.type(diff.size),
orig_dtype = self.diff.dtype
dtype = _get_intermediate_dtype(orig_dtype)
diff = self.diff.ravel().astype(dtype, copy=False)
return (abs(diff).sum() / diff.dtype.type(diff.size)).astype(
orig_dtype, copy=False),

def backward(self, indexes, grad_outputs):
gy, = grad_outputs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import numpy

import chainer
from chainer.backends import cuda
from chainer import functions
from chainer import testing
from chainer import utils
from chainer.testing import attr
from chainer.utils import type_check


Expand Down Expand Up @@ -83,4 +85,23 @@ def test_invalid_dtype2(self):
functions.mean_absolute_error(x0, x1)


# See chainer#6702.
class TestMeanAbsoluteErrorFP16Overflow(unittest.TestCase):

def check_fp16_overflow(self, xp):
x0 = chainer.Variable(xp.full(
shape=(64, 1, 16, 16), fill_value=2, dtype=xp.float16))
x1 = chainer.Variable(xp.full(
shape=(64, 1, 16, 16), fill_value=-2, dtype=xp.float16))
loss = functions.mean_absolute_error(x0, x1)
self.assertFalse(xp.isinf(loss.array))

def test_fp16_overflow_cpu(self):
self.check_fp16_overflow(numpy)

@attr.gpu
def test_fp16_overflow_gpu(self):
self.check_fp16_overflow(cuda.cupy)


testing.run_module(__name__, __file__)

0 comments on commit e94ac64

Please sign in to comment.