Merge branch 'master' into master
Crissman committed Apr 18, 2019
2 parents 5ae5084 + aa38d54 commit 75846ea
Showing 184 changed files with 6,786 additions and 3,874 deletions.
7 changes: 6 additions & 1 deletion .github/ISSUE_TEMPLATE.md
@@ -7,7 +7,12 @@ please take a look at [our contribution guide](https://docs.chainer.org/en/stabl
Specifically, if it is a bug report, this information is very helpful:

* Conditions
<!-- If you're using Chainer 4.0+, you can also get this information by typing `python -c 'import chainer; chainer.print_runtime_info()'. -->
<!--
You can also get this information by typing the following:
```
python -c 'import chainer; chainer.print_runtime_info()'
```
-->
- Chainer version
- CuPy version
- OS/Platform
1 change: 1 addition & 0 deletions .gitignore
@@ -4,6 +4,7 @@
*.pyo
*.cpp
*.so
*.dylib
build
\#*\#
.\#*
11 changes: 6 additions & 5 deletions chainer/_backprop_utils.py
@@ -1,5 +1,6 @@
import os
import shutil
import sys
import traceback

import six
@@ -212,11 +213,11 @@ def iter_gxs(gxs):


def _get_columns():
try:
get_terminal_size = shutil.get_terminal_size
except AttributeError:
return os.getenv('COLUMNS', 80)
return get_terminal_size()[0]
# Returns the terminal column width.
if sys.version_info >= (3, 3):
cols, rows = shutil.get_terminal_size()
return cols
return int(os.getenv('COLUMNS', 80))


def _reraise_with_stack(func, e):
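
Note that the rewrite above also fixes a subtle type bug in the fallback: `os.getenv('COLUMNS', 80)` returns a *string* whenever the environment variable is set, so the old code could return either a `str` or an `int`. A minimal sketch of the difference:

```python
import os

os.environ['COLUMNS'] = '120'
print(os.getenv('COLUMNS', 80))       # '120' (str) -- what the old fallback returned
print(int(os.getenv('COLUMNS', 80)))  # 120 (int) -- what the new fallback returns
del os.environ['COLUMNS']
print(os.getenv('COLUMNS', 80))       # 80 (int) -- the default when the variable is unset
```
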
7 changes: 3 additions & 4 deletions chainer/backend.py
@@ -107,7 +107,7 @@ def get_device(device_spec):
* A string starts with ``'@cupy:'``.
(ex. ``'@cupy:0'``)
* A :class:`chainer.backends.cuda.Device` object.
* A :class:`cupy.cuda.Device` object.
* NumPy
@@ -180,8 +180,7 @@ def using_device(device_spec):


def get_array_module(*args):
"""Gets an appropriate one from :mod:`numpy`, :mod:`cupy`, or
:mod:`chainerx`.
"""Gets an appropriate NumPy-compatible module to process arguments
This function will return their data arrays' array module for
:class:`~chainer.Variable` arguments.
@@ -191,7 +190,7 @@ def get_array_module(*args):
used.
Returns:
module: :mod:`cupy`, :mod:`numpy`, or :mod:`chainerx` is returned based
module: :mod:`numpy`, :mod:`cupy`, or :mod:`chainerx` is returned based
on the types of the arguments.
"""
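
For context, a short usage sketch of `get_array_module` (the GPU case assumes CuPy is installed; the values here are only illustrative):

```python
import numpy as np
import chainer

x = chainer.Variable(np.ones((2, 3), dtype=np.float32))
xp = chainer.backend.get_array_module(x)  # numpy here; cupy for data on a GPU
y = xp.tanh(x.array)  # device-agnostic code: xp mirrors the NumPy API
```
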
12 changes: 1 addition & 11 deletions chainer/backends/cuda.py
@@ -417,17 +417,7 @@ def _array_to_gpu(array, device, stream):
# the array interface.
if array.device.backend.name == 'cuda':
# Convert to cupy.ndarray on the same device as source array
array = cupy.ndarray(
array.shape,
array.dtype,
cupy.cuda.MemoryPointer(
cupy.cuda.UnownedMemory(
array.data_ptr + array.offset,
array.data_size,
array,
array.device.index),
0),
strides=array.strides)
array = chainerx._to_cupy(array)
else:
array = chainerx.to_numpy(array)
elif isinstance(array, (numpy.number, numpy.bool_)):
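
The removed block above still illustrates the general zero-copy pattern that `chainerx._to_cupy` now encapsulates: wrapping externally owned device memory as a `cupy.ndarray`. A hedged sketch, with a hypothetical helper name:

```python
import cupy

def wrap_device_buffer(ptr, nbytes, shape, dtype, strides, owner):
    # Hypothetical helper: view an externally allocated GPU buffer without
    # copying. `owner` is retained so the source allocation stays alive.
    mem = cupy.cuda.UnownedMemory(ptr, nbytes, owner)
    memptr = cupy.cuda.MemoryPointer(mem, 0)
    return cupy.ndarray(shape, dtype, memptr, strides=strides)
```
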
2 changes: 1 addition & 1 deletion chainer/datasets/pickle_dataset.py
@@ -122,7 +122,7 @@ def get_example(self, index):
def open_pickle_dataset(path):
"""Opens a dataset stored in a given path.
This is a hepler function to open :class:`PickleDataset`. It opens a given
This is a helper function to open :class:`PickleDataset`. It opens a given
file in binary mode, and creates a :class:`PickleDataset` instance.
This method does not close the opened file. A user needs to call
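
A usage sketch of the helper (the file name is hypothetical, and the dataset is assumed to have been written with the matching pickle-dataset writer):

```python
from chainer.datasets import open_pickle_dataset

dataset = open_pickle_dataset('examples.pkl')  # hypothetical path
try:
    print(len(dataset), dataset[0])
finally:
    dataset.close()  # open_pickle_dataset does not close the file for you
```
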
26 changes: 22 additions & 4 deletions chainer/device_resident.py
@@ -63,7 +63,8 @@ def to_cpu(self):
visitor = _ToDeviceVisitor(
backend.CpuDevice(),
entry_method_info=('to_cpu', {}),
skip_between_cupy_devices=True)
skip_between_cupy_devices=True,
starting_device_resident=self)
self.__to_device(visitor)
return self

@@ -91,7 +92,8 @@ def to_gpu(
visitor = _ToDeviceVisitor(
device,
entry_method_info=('to_gpu', {'device': device.device}),
skip_between_cupy_devices=True)
skip_between_cupy_devices=True,
starting_device_resident=self)
self.__to_device(visitor)
return self

@@ -101,7 +103,8 @@ def to_intel64(self):
intel64.check_ideep_available()
visitor = _ToDeviceVisitor(
chainer.get_device(intel64.Intel64Device()),
entry_method_info=('to_intel64', {}))
entry_method_info=('to_intel64', {}),
starting_device_resident=self)
self.__to_device(visitor)
return self

@@ -205,7 +208,8 @@ class _ToDeviceVisitor(DeviceResidentsVisitor):

def __init__(
self, device, entry_method_info=None,
skip_between_cupy_devices=False):
skip_between_cupy_devices=False,
starting_device_resident=None):

assert isinstance(device, chainer.backend.Device)

@@ -219,15 +223,29 @@ def __init__(
assert len(entry_method_info) == 2
assert entry_method_info[0] in ('to_cpu', 'to_gpu', 'to_intel64')

# starting_device_resident is also part of the backward-compatibility
# workaround for overridden methods.
# It is the DeviceResident on which a to_xxx method was initially
# called, if any. It is used to avoid the infinite accept-visit loop
# that would otherwise occur when calling to_xxx methods.
assert (starting_device_resident is None
or isinstance(starting_device_resident, DeviceResident))

self._device = device
self._entry_method_info = entry_method_info
self._skip_between_cupy_devices = skip_between_cupy_devices
self._starting_device_resident = starting_device_resident

def visit_device_resident(self, device_resident):
device_resident._device = self._device

# Backward compatibility workaround for overridden methods
if device_resident._overridden_to_methods:
# Skip this device resident if the visitor was initially triggered
# from it.
if device_resident is self._starting_device_resident:
return

if self._entry_method_info is not None:
# Deprecated method is being called: e.g. to_cpu and to_gpu.
method_name, kwargs = self._entry_method_info
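
A minimal sketch of the guard this change adds, with hypothetical names rather than the library's actual classes: the visitor remembers the object whose `to_*` method created it, so re-dispatching to an overridden `to_*` method cannot re-enter the same visitor forever.

```python
class Visitor:
    # Hypothetical illustration of the accept-visit guard.
    def __init__(self, starting_resident=None):
        self._starting = starting_resident

    def visit(self, resident):
        if resident.has_overridden_to_method:
            if resident is self._starting:
                return  # its to_*() created this visitor; recursing would loop
            resident.to_cpu()  # re-dispatch so the user's override still runs
```
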
44 changes: 22 additions & 22 deletions chainer/distributions/independent.py
@@ -1,6 +1,3 @@
import functools
import operator

import numpy

from chainer.backend import cuda
@@ -10,6 +7,7 @@
from chainer.functions.array import transpose
from chainer.functions.math import sum as sum_mod
from chainer.functions.math import prod
from chainer.utils import array
from chainer.utils import cache


@@ -60,10 +58,11 @@ def batch_shape(self):
def event_shape(self):
return self.__event_shape

@property
@cache.cached_property
def covariance(self):
'''Returns the covariance of the distribution based on the original
i.i.d. distribution. By definition, the covariance of the new
""" The covariance of the independent distribution.
By definition, the covariance of the new
distribution becomes a block diagonal matrix. Let
:math:`\\Sigma_{\\mathbf{x}}` be the covariance matrix of the original
random variable :math:`\\mathbf{x} \\in \\mathbb{R}^d`, and
@@ -82,11 +81,13 @@ def covariance(self):
Note that this relationship holds only if the covariance matrix of the
original distribution is given analytically.
'''
num_repeat = functools.reduce(
operator.mul,
self.distribution.batch_shape[-self.reinterpreted_batch_ndims:], 1)
dim = functools.reduce(operator.mul, self.distribution.event_shape, 1)
Returns:
~chainer.Variable: The covariance of the distribution.
"""
num_repeat = array.size_of_shape(
self.distribution.batch_shape[-self.reinterpreted_batch_ndims:])
dim = array.size_of_shape(self.distribution.event_shape)
cov = repeat.repeat(
reshape.reshape(
self.distribution.covariance,
@@ -110,7 +111,9 @@ def cdf(self, x):
return self._reduce(prod.prod, self.distribution.cdf(x))

def icdf(self, x):
'''Cumulative distribution function for multivariate variable is not
"""The inverse cumulative distribution function for multivariate variable.
Cumulative distribution function for multivariate variable is not
invertible. This function always raises :class:`RuntimeError`.
Args:
@@ -119,7 +122,7 @@ def icdf(self, x):
Raises:
:class:`RuntimeError`
'''
"""

raise RuntimeError(
'Cumulative distribution function for multivariate variable '
@@ -176,8 +179,7 @@ def xp(self):
return self.distribution.xp

def _reduce(self, op, stat):
range_ = tuple(
(-1 - numpy.arange(self.reinterpreted_batch_ndims)).tolist())
range_ = tuple(range(-self.reinterpreted_batch_ndims, 0))
return op(stat, axis=range_)

def _get_default_reinterpreted_batch_ndims(self, distribution):
@@ -186,10 +188,9 @@ def _get_default_reinterpreted_batch_ndims(self, distribution):

@cache.cached_property
def _block_indicator(self):
num_repeat = functools.reduce(
operator.mul,
self.distribution.batch_shape[-self.reinterpreted_batch_ndims:], 1)
dim = functools.reduce(operator.mul, self.distribution.event_shape, 1)
num_repeat = array.size_of_shape(
self.distribution.batch_shape[-self.reinterpreted_batch_ndims:])
dim = array.size_of_shape(self.distribution.event_shape)
block_indicator = numpy.fromfunction(
lambda i, j: i // dim == j // dim,
(num_repeat * dim, num_repeat * dim)).astype(int)
@@ -200,8 +201,7 @@ def _block_indicator(self):

@distribution.register_kl(Independent, Independent)
def _kl_independent_independent(dist1, dist2):
'''Batched KL divergence :math:`\\mathrm{KL}(\\mathrm{dist1} ||
\\mathrm{dist2})` for Independent distributions.
"""Computes Kullback-Leibler divergence for independent distributions.
We can leverage the fact that
.. math::
@@ -223,7 +223,7 @@ def _kl_independent_independent(dist1, dist2):
Raises:
:class:`ValueError`: If the event space for ``dist1`` and ``dist2``,
or their underlying distributions don't match.
'''
"""

p = dist1.distribution
q = dist2.distribution
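
To illustrate `_block_indicator` above: for an event of size `dim` repeated `num_repeat` times, `numpy.fromfunction` marks which entries of the stacked covariance belong to the same block. For example:

```python
import numpy

dim, num_repeat = 2, 3
block_indicator = numpy.fromfunction(
    lambda i, j: i // dim == j // dim,
    (num_repeat * dim, num_repeat * dim)).astype(int)
print(block_indicator)
# [[1 1 0 0 0 0]
#  [1 1 0 0 0 0]
#  [0 0 1 1 0 0]
#  [0 0 1 1 0 0]
#  [0 0 0 0 1 1]
#  [0 0 0 0 1 1]]
```

The simplified `_reduce` is equivalent to the removed NumPy expression: for `reinterpreted_batch_ndims = 2`, `tuple(range(-2, 0))` gives `(-2, -1)`, the same reduction axes the old code produced as `(-1, -2)` (order is irrelevant to the reduction).
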
4 changes: 3 additions & 1 deletion chainer/function_node.py
@@ -1107,7 +1107,9 @@ def _backprop(outputs, inputs, grad_required, retain_grad, grads, loss_scale):

# Collect the gradients w.r.t. the outputs
ys = [y() for y in func.outputs] # access via weak ref
gys = tuple([grads.pop(y) for y in ys])
gys = tuple([grads.pop(y)
if y is not None and y.creator_node is not None else None
for y in ys])

for node, gy in six.moves.zip(ys, gys):
if node is not None:
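
For context, `func.outputs` holds weak references, so `y()` can already be `None` once an output node has been garbage-collected; the new guard additionally skips nodes that have no creator. A toy illustration of the weak-reference part:

```python
import weakref

class Node:
    pass

n = Node()
ref = weakref.ref(n)
print(ref())  # <__main__.Node object at ...> while n is alive
del n
print(ref())  # None once the node has been collected
```
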
2 changes: 1 addition & 1 deletion chainer/functions/connection/convolution_2d.py
@@ -480,7 +480,7 @@ def convolution_2d(x, W, b=None, stride=1, pad=0, cover_all=False, **kwargs):
If the ``cover_all`` option is ``True``, the filter will cover all the
spatial locations. So, if the last stride of the filter does not cover the
end of spatial locations, an addtional stride will be applied to the end
end of spatial locations, an additional stride will be applied to the end
part of spatial locations. In this case, the output size :math:`(h_O, w_O)`
is determined by the following equations:
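
A hedged sketch of the `cover_all` effect on the output size (this mirrors the standard convolution output-size formula; the exact equations are in the part of the docstring elided here):

```python
def outsize(size, k, s, p, cover_all=False):
    # With cover_all, one extra stride is allowed so that the filter
    # reaches the end of the input.
    if cover_all:
        return (size + p * 2 - k + s - 1) // s + 1
    return (size + p * 2 - k) // s + 1

print(outsize(8, k=3, s=2, p=0))                  # 3: windows at 0, 2, 4
print(outsize(8, k=3, s=2, p=0, cover_all=True))  # 4: an extra window covers the tail
```
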
2 changes: 1 addition & 1 deletion chainer/functions/connection/convolution_nd.py
@@ -399,7 +399,7 @@ def convolution_nd(x, W, b=None, stride=1, pad=0, cover_all=False,
If the ``cover_all`` option is ``True``, the filter will cover all the
spatial locations. So, if the last stride of the filter does not cover the
end of spatial locations, an addtional stride will be applied to the end
end of spatial locations, an additional stride will be applied to the end
part of spatial locations. In this case, the output size is determined by
the following equations:
18 changes: 18 additions & 0 deletions chainer/functions/math/maximum.py
@@ -76,10 +76,28 @@ def maximum(x1, x2):
Args:
x1 (:class:`~chainer.Variable` or :ref:`ndarray`):
Input variables to be compared.
A :math:`(s_1, s_2, ..., s_N)` -shaped float array.
x2 (:class:`~chainer.Variable` or :ref:`ndarray`):
Input variables to be compared.
A :math:`(s_1, s_2, ..., s_N)` -shaped float array.
Returns:
~chainer.Variable: Output variable.
.. admonition:: Example
>>> x1 = np.arange(6).astype(np.float32)
>>> x1
array([0., 1., 2., 3., 4., 5.], dtype=float32)
>>> x2 = np.array([5, 4, 3, 2, 1, 0]).astype(np.float32)
>>> x2
array([5., 4., 3., 2., 1., 0.], dtype=float32)
>>> y = F.maximum(x1, x2)
>>> y.shape
(6,)
>>> y.array
array([5., 4., 3., 3., 4., 5.], dtype=float32)
"""
return Maximum().apply((x1, x2))[0]
3 changes: 3 additions & 0 deletions chainer/functions/math/minmax.py
@@ -81,6 +81,9 @@ def _fwd(self, x, xp):

class Min(SelectorBase):

def forward_chainerx(self, x):
return chainerx.amin(x[0], axis=self.axis, keepdims=self.keepdims),

def _fwd(self, x, xp):
return xp.amin(x, axis=self.axis, keepdims=self.keepdims)

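
A small sketch of the ChainerX call that the new `forward_chainerx` delegates to (assuming a ChainerX build with the native backend; the printed repr is abbreviated):

```python
import chainerx

x = chainerx.array([[3.0, 1.0], [2.0, 4.0]])
print(chainerx.amin(x, axis=1, keepdims=True))
# array([[1.], [2.]], shape=(2, 1), ...)
```
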
16 changes: 16 additions & 0 deletions chainer/functions/math/square.py
@@ -33,8 +33,24 @@ def square(x):
Args:
x (:class:`~chainer.Variable` or :ref:`ndarray`): Input variable.
A :math:`(s_1, s_2, ..., s_N)` -shaped float array.
Returns:
~chainer.Variable: Output variable.
A :math:`(s_1, s_2, ..., s_N)` -shaped float array.
.. admonition:: Example
>>> x = np.arange(6).reshape(2,3).astype(np.float32)
>>> x
array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32)
>>> y = F.square(x)
>>> y.shape
(2, 3)
>>> y.array
array([[ 0., 1., 4.],
[ 9., 16., 25.]], dtype=float32)
"""
return Square().apply((x,))[0]
