
Merge pull request #7063 from kshitij12345/add-losses
[chainerx] add basic loss functions
hvy committed Jun 24, 2019
2 parents 51709e0 + 04b4b07 commit 3c567d3
Showing 8 changed files with 370 additions and 1 deletion.
12 changes: 11 additions & 1 deletion chainerx/__init__.pyi
@@ -473,6 +473,9 @@ class ndarray:
def abs(x: ndarray) -> ndarray: ...


def absolute_error(x1: ndarray, x2: ndarray) -> ndarray: ...


def add(x1: tp.Any, x2: tp.Any) -> ndarray: ...


@@ -717,8 +720,12 @@ def greater(x1: ndarray, x2: ndarray) -> ndarray: ...

def greater_equal(x1: ndarray, x2: ndarray) -> ndarray: ...

def gaussian_kl_divergence(mean: ndarray, ln_var: ndarray, reduce: tp.Optional[str]="sum") -> ndarray: ...

def hstack(arrays: tp.List[ndarray]) -> ndarray: ...

def huber_loss(x: ndarray, t: ndarray, delta: float, reduce: tp.Optional[str]="sum_along_second_axis") -> ndarray: ...

def identity(
    n: int,
    dtype: tp.Optional[tp.Any]=None,
@@ -858,7 +865,10 @@ def split(
def square(x: ndarray) -> ndarray: ...


def squared_difference(x1: tp.Any, x2: tp.Any) -> ndarray: ...
def squared_error(x1: ndarray, x2: ndarray) -> ndarray: ...


def squared_difference(x1: ndarray, x2: ndarray) -> ndarray: ...


def sqrt(x: ndarray) -> ndarray: ...
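The stubs above define the user-facing signatures. A minimal usage sketch of the element-wise errors (the array values are illustrative assumptions; calling ``mean()`` on the output to get the mean error follows the docstrings added below):

import chainerx as chx

x = chx.array([1.0, 2.0, 3.0])
t = chx.array([1.5, 1.5, 3.5])

chx.absolute_error(x, t)        # -> [0.5, 0.5, 0.5]
chx.squared_error(x, t)         # -> [0.25, 0.25, 0.25]
chx.squared_error(x, t).mean()  # mean squared error, 0.25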
139 changes: 139 additions & 0 deletions chainerx/_docs/routines.py
@@ -726,6 +726,145 @@ def _docs_logic():
""")


def _docs_loss():
    _docs.set_doc(
        chainerx.absolute_error,
        """Element-wise absolute error function.

Computes the element-wise absolute error :math:`L` between two inputs
:math:`x_1` and :math:`x_2`, defined as follows.

.. math::
    L = |x_1 - x_2|

Args:
    x1 (~chainerx.ndarray): Input variable.
    x2 (~chainerx.ndarray): Input variable.

Returns:
    :class:`~chainerx.ndarray`: A variable holding an array representing
    the absolute error of the two inputs.

.. seealso:: :func:`chainer.functions.absolute_error`
""")

    _docs.set_doc(
        chainerx.squared_error,
        """Element-wise squared error function.

Computes the squared error between two variables:

.. math::
    (x_1 - x_2)^2

where the operation is performed elementwise. Note that the error is
not scaled by 1/2. The mean squared error can be obtained by calling
``mean()`` on the output array.

Args:
    x1 (~chainerx.ndarray): Input variable.
    x2 (~chainerx.ndarray): Input variable.

Returns:
    :class:`~chainerx.ndarray`: A variable holding an array representing
    the squared error of the two inputs.

.. seealso:: :func:`chainer.functions.squared_error`
""")

    _docs.set_doc(
        chainerx.huber_loss,
        """Computes the Huber loss.

The Huber loss is similar to the :func:`mean_squared_error` but is less
sensitive to outliers in the data. It is defined as

.. math::
    L_{\\delta}(a) = \\left \\{ \\begin{array}{cc}
    \\frac{1}{2} a^2 & {\\rm if~|a| \\leq \\delta} \\\\
    \\delta (|a| - \\frac{1}{2} \\delta) & {\\rm otherwise,}
    \\end{array} \\right.

where :math:`a = x - t` is the difference between the input :math:`x`
and the target :math:`t`.

The loss is a variable whose value depends on the value of
the option ``reduce``. If it is ``'no'``, it holds the elementwise
loss values. If it is ``'sum_along_second_axis'``, loss values are
summed up along the second axis (i.e. ``axis=1``).

See: `Huber loss - Wikipedia <https://en.wikipedia.org/wiki/Huber_loss>`_.

Args:
    x (~chainerx.ndarray): Input variable.
        The shape of ``x`` should be (:math:`N`, :math:`K`, ...) if
        ``reduce='sum_along_second_axis'``.
    t (~chainerx.ndarray): Target variable for regression.
        The shape of ``t`` should be (:math:`N`, :math:`K`, ...) if
        ``reduce='sum_along_second_axis'``.
    delta (float): Constant variable for the Huber loss function,
        as used in its definition.
    reduce (str): Reduction option. Its value must be either
        ``'sum_along_second_axis'`` or ``'no'``. Otherwise,
        :class:`ValueError` is raised.

Returns:
    :class:`~chainerx.ndarray`:
    A variable holding the Huber loss :math:`L_{\\delta}`.
    If ``reduce`` is ``'no'``, the output variable holds an array
    whose shape is the same as that of (hence both of) the input
    variables. If it is ``'sum_along_second_axis'``, the shape is
    the same as that of the input variables, except that the second
    axis is removed.

.. seealso:: :func:`chainer.functions.huber_loss`
""")

    _docs.set_doc(
        chainerx.gaussian_kl_divergence,
        """Computes the KL-divergence of Gaussian variables from the standard one.

Given two variables ``mean`` representing :math:`\\mu` and ``ln_var``
representing :math:`\\log(\\sigma^2)`, this function calculates
the element-wise KL-divergence between the given multi-dimensional
Gaussian :math:`N(\\mu, S)` and the standard Gaussian :math:`N(0, I)`

.. math::
    D_{\\mathbf{KL}}(N(\\mu, S) \\| N(0, I)),

where :math:`S` is a diagonal matrix such that :math:`S_{ii} = \\sigma_i^2`
and :math:`I` is an identity matrix.

The output is a variable whose value depends on the value of
the option ``reduce``. If it is ``'no'``, it holds the elementwise
loss values. If it is ``'sum'`` or ``'mean'``, the loss values are
summed up or averaged, respectively.

Args:
    mean (~chainerx.ndarray):
        A variable representing the mean of the given
        Gaussian distribution, :math:`\\mu`.
    ln_var (~chainerx.ndarray):
        A variable representing the logarithm of the variance
        of the given Gaussian distribution, :math:`\\log(\\sigma^2)`.
    reduce (str): Reduction option. Its value must be either
        ``'sum'``, ``'mean'`` or ``'no'``. Otherwise, :class:`ValueError`
        is raised.

Returns:
    :class:`~chainerx.ndarray`:
    A variable representing the KL-divergence between the
    given Gaussian distribution and the standard Gaussian.
    If ``reduce`` is ``'no'``, the output variable holds an array
    whose shape is the same as that of (hence both of) the input
    variables. If it is ``'sum'`` or ``'mean'``, the output variable
    holds a scalar value.

.. seealso:: :func:`chainer.functions.gaussian_kl_divergence`
""")


def _docs_manipulation():
_docs.set_doc(
chainerx.reshape,
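As a quick numerical check of the closed form documented above, a NumPy sketch (NumPy and the sample values are assumptions used only for illustration):

import numpy as np

mean = np.array([0.0, 1.0])
ln_var = np.array([0.0, 0.0])  # log(sigma^2) = 0, i.e. unit variance

# D_KL(N(mu, sigma^2) || N(0, 1)) = (mu^2 + sigma^2 - log(sigma^2) - 1) / 2
kl = 0.5 * (np.square(mean) + np.exp(ln_var) - ln_var - 1)
print(kl)  # [0.  0.5] -- zero exactly when the input distribution is N(0, 1)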
26 changes: 26 additions & 0 deletions chainerx_cc/chainerx/python/routines.cc
@@ -28,6 +28,7 @@
#include "chainerx/routines/indexing.h"
#include "chainerx/routines/linalg.h"
#include "chainerx/routines/logic.h"
#include "chainerx/routines/loss.h"
#include "chainerx/routines/manipulation.h"
#include "chainerx/routines/misc.h"
#include "chainerx/routines/normalization.h"
@@ -1008,13 +1009,38 @@ void InitChainerxPooling(pybind11::module& m) {
"pad_mode"_a = "ignore");
}

void InitChainerxLoss(pybind11::module& m) {
    m.def("absolute_error",
          [](const ArrayBodyPtr& x1, const ArrayBodyPtr& x2) { return MoveArrayBody(AbsoluteError(Array{x1}, Array{x2})); },
          "x1"_a,
          "x2"_a);
    m.def("squared_error",
          [](const ArrayBodyPtr& x1, const ArrayBodyPtr& x2) { return MoveArrayBody(SquaredError(Array{x1}, Array{x2})); },
          "x1"_a,
          "x2"_a);
    m.def("gaussian_kl_divergence",
          [](const ArrayBodyPtr& mean, const ArrayBodyPtr& ln_var) {
              return MoveArrayBody(GaussianKLDivergence(Array{mean}, Array{ln_var}));
          },
          "mean"_a,
          "ln_var"_a);
    m.def("huber_loss",
          [](const ArrayBodyPtr& x1, const ArrayBodyPtr& x2, Scalar delta) {
              return MoveArrayBody(HuberLoss(Array{x1}, Array{x2}, delta));
          },
          "x1"_a,
          "x2"_a,
          "delta"_a);
}

} // namespace

void InitChainerxRoutines(pybind11::module& m) {
    InitChainerxCreation(m);
    InitChainerxIndexing(m);
    InitChainerxLinalg(m);
    InitChainerxLogic(m);
    InitChainerxLoss(m);
    InitChainerxManipulation(m);
    InitChainerxActivation(m);
    InitChainerxArithmetic(m);
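With InitChainerxLoss registered, the functions are callable from Python using the positional signatures shown above; a hedged sketch for huber_loss (sample values are assumptions):

import chainerx as chx

a = chx.array([3.0, -1.0])
b = chx.array([1.0, 1.0])

# |a - b| = [2, 2] exceeds delta = 1.0, so the linear branch applies:
# delta * (|a - b| - 0.5 * delta) = [1.5, 1.5]
chx.huber_loss(a, b, 1.0)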
2 changes: 2 additions & 0 deletions chainerx_cc/chainerx/routines/CMakeLists.txt
@@ -9,6 +9,7 @@ add_library(chainerx_routines STATIC
indexing.cc
linalg.cc
logic.cc
loss.cc
manipulation.cc
misc.cc
normalization.cc
@@ -32,6 +33,7 @@ install(FILES
indexing.h
linalg.h
logic.h
loss.h
manipulation.h
misc.h
normalization.h
27 changes: 27 additions & 0 deletions chainerx_cc/chainerx/routines/loss.cc
@@ -0,0 +1,27 @@
#include "chainerx/routines/loss.h"

#include "chainerx/array.h"
#include "chainerx/routines/creation.h"
#include "chainerx/routines/explog.h"
#include "chainerx/routines/indexing.h"
#include "chainerx/routines/misc.h"
#include "chainerx/scalar.h"

namespace chainerx {

Array AbsoluteError(const Array& x1, const Array& x2) { return Absolute(x1 - x2); }

Array SquaredError(const Array& x1, const Array& x2) { return SquaredDifference(x1, x2); }

Array GaussianKLDivergence(const Array& mean, const Array& ln_var) { return (Square(mean) + Exp(ln_var) - ln_var - 1) * 0.5; }

Array HuberLoss(const Array& x1, const Array& x2, Scalar delta) {
    Array a = x1 - x2;
    Array abs_a = Absolute(a);
    Array delta_array = chainerx::FullLike(a, delta, a.device());

    // TODO(kshitij12345): use Array < Scalar when implemented.
    return Where(abs_a < delta_array, 0.5 * Square(a), delta * (abs_a - Scalar{0.5} * delta));
}

} // namespace chainerx
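The Where-based branch selection in HuberLoss can be cross-checked against a few lines of NumPy (a reference sketch of the same piecewise rule, not part of this diff):

import numpy as np

def huber_loss_ref(x1, x2, delta):
    # Same piecewise rule as HuberLoss above: quadratic inside the
    # delta band, linear outside it.
    a = x1 - x2
    abs_a = np.abs(a)
    return np.where(abs_a < delta, 0.5 * np.square(a), delta * (abs_a - 0.5 * delta))

huber_loss_ref(np.array([3.0, 1.2]), np.array([1.0, 1.0]), 1.0)
# -> [1.5, 0.02]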
16 changes: 16 additions & 0 deletions chainerx_cc/chainerx/routines/loss.h
@@ -0,0 +1,16 @@
#pragma once

#include "chainerx/array.h"
#include "chainerx/scalar.h"

namespace chainerx {

Array AbsoluteError(const Array& x1, const Array& x2);

Array SquaredError(const Array& x1, const Array& x2);

Array GaussianKLDivergence(const Array& mean, const Array& ln_var);

Array HuberLoss(const Array& x1, const Array& x2, Scalar delta);

} // namespace chainerx
12 changes: 12 additions & 0 deletions docs/source/chainerx/reference/routines.rst
@@ -123,6 +123,18 @@ Logic functions
chainerx.equal
chainerx.not_equal

Loss functions
---------------

.. autosummary::
   :toctree: generated/
   :nosignatures:

   chainerx.absolute_error
   chainerx.squared_error
   chainerx.huber_loss
   chainerx.gaussian_kl_divergence

Mathematical functions
----------------------
