Skip to content

Commit

Permalink
Merge branch 'master' into _v2-doc-api-browsability
Browse files Browse the repository at this point in the history
  • Loading branch information
niboshi committed Jun 2, 2017
2 parents 8d6c359 + c6adfeb commit d4786ba
Show file tree
Hide file tree
Showing 105 changed files with 3,902 additions and 2,241 deletions.
39 changes: 1 addition & 38 deletions LICENSE
Expand Up @@ -17,41 +17,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.


######################################################################
# The CuPy is designed based on NumPy's API.
# CuPy's source code and documents contain the original NumPy ones.
######################################################################
Copyright (c) 2005-2016, NumPy Developers.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.

* Neither the name of the NumPy Developers nor the names of any
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
######################################################################
THE SOFTWARE.
160 changes: 40 additions & 120 deletions README.md
@@ -1,146 +1,73 @@
[![pypi](https://img.shields.io/pypi/v/chainer.svg)](https://pypi.python.org/pypi/chainer)
[![GitHub license](https://img.shields.io/github/license/pfnet/chainer.svg)](https://github.com/pfnet/chainer)
[![travis](https://img.shields.io/travis/pfnet/chainer/master.svg)](https://travis-ci.org/pfnet/chainer)
[![coveralls](https://img.shields.io/coveralls/pfnet/chainer.svg)](https://coveralls.io/github/pfnet/chainer)
[![GitHub license](https://img.shields.io/github/license/chainer/chainer.svg)](https://github.com/chainer/chainer)
[![travis](https://img.shields.io/travis/chainer/chainer/master.svg)](https://travis-ci.org/chainer/chainer)
[![coveralls](https://img.shields.io/coveralls/chainer/chainer.svg)](https://coveralls.io/github/chainer/chainer)
[![Read the Docs](https://readthedocs.org/projects/chainer/badge/?version=stable)](http://docs.chainer.org/en/stable/?badge=stable)

# Chainer: a neural network framework

## Informations
- [Official site](http://chainer.org/)
- [Official document](http://docs.chainer.org/)

Examples
- [Code examples](https://github.com/pfnet/chainer/tree/master/examples)
- [External examples](https://github.com/pfnet/chainer/wiki/External-examples)
- [Research projects using Chainer](https://github.com/pfnet/chainer/wiki/Research-projects-using-Chainer)

Social
- [Forum](https://groups.google.com/forum/#!forum/chainer)
- [Twitter](https://twitter.com/ChainerOfficial)
- [Join Slack](https://bit.ly/chainer-slack)
- [Forum (Japanese)](https://groups.google.com/forum/#!forum/chainer-jp)
- [Twitter(Japanese)](https://twitter.com/chainerjp)
- [Join Slack (Japanese)](https://bit.ly/chainer-jp-slack)

## Requirements

Chainer is tested on Ubuntu 14.04 and CentOS 7. We recommend them to use Chainer, though it may run on other systems as well.

Minimum requirements:
- Python 2.7.6+, 3.4.3+, 3.5.1+, 3.6.0+
- NumPy 1.9, 1.10, 1.11, 1.12
- Six 1.9

Requirements for some features:
- CUDA support
- CUDA 6.5, 7.0, 7.5, 8.0
- filelock
- g++ 4.8.4+
- cuDNN support
- cuDNN v2, v3, v4, v5, v5.1, v6
- Caffe model support
- Protocol Buffers (pip install protobuf)
- protobuf>=3.0.0 is required for Py3
- Image dataset support
- Pillow
- HDF5 serialization support
- h5py 2.5.0
- Testing utilities
- Mock
- Nose
# Chainer: a deep learning framework

## Installation

### Minimum installation
[**Website**](http://chainer.org/)
| [**Docs**](http://docs.chainer.org/en/stable/)
| [**Install Guide**](http://docs.chainer.org/en/stable/install.html)
| [**Tutorial**](http://docs.chainer.org/en/stable/tutorial/)
| **Examples** ([Official](https://github.com/chainer/chainer/blob/master/examples), [External](https://github.com/chainer/chainer/wiki/External-examples))
| **Forum** ([en](https://groups.google.com/forum/#!forum/chainer), [ja](https://groups.google.com/forum/#!forum/chainer-jp))
| **Slack** ([en](https://bit.ly/chainer-slack), [ja](https://bit.ly/chainer-jp-slack))
| **Twitter** ([en](https://twitter.com/ChainerOfficial), [ja](https://twitter.com/ChainerJP))

If you use old ``setuptools``, upgrade it:

```
pip install -U setuptools
```
*Chainer* is a Python-based deep learning framework aiming at flexibility.
It provides automatic differentiation APIs based on the **define-by-run** approach (a.k.a. dynamic computational graphs) as well as object-oriented high-level APIs to build and train neural networks.
It also supports CUDA/cuDNN using [CuPy](https://github.com/cupy/cupy) for high performance training and inference.
For more details of Chainer, see the documents and resources listed above and join the community in Forum, Slack, and Twitter.

Then, install Chainer via PyPI:
```
pip install chainer
```
## Stable version

You can also install Chainer from the source code:
```
python setup.py install
```
The stable version of current Chainer is separated in here: [v2](https://github.com/chainer/chainer/tree/v2).

## Installation

### Installation with CUDA
To install Chainer, use `pip`.

If you want to enable CUDA, first you have to install CUDA and set the environment variable `PATH` and `LD_LIBRARY_PATH` for CUDA executables and libraries.
For example, if you are using Ubuntu and CUDA is installed by the official distribution, then CUDA is installed at `/usr/local/cuda`.
In this case, you have to add the following lines to `.bashrc` or `.zshrc` (choose which you are using):
```
export PATH=/usr/local/cuda/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
```sh
$ pip install chainer
```

Chainer had `chainer-cuda-deps` module to enable CUDA in previous version.
Recent version (>=1.3) does not require this module.
So **you do not have to install** `chainer-cuda-deps`.
To enable CUDA support, [set up CUDA](http://docs.nvidia.com/cuda/index.html#installation-guides) and install [CuPy](https://github.com/cupy/cupy).

If you want to enable cuDNN, add a directory containing `cudnn.h` to `CFLAGS`, and add a directory containing `libcudnn.so` to `LDFLAGS` and `LD_LIBRARY_PATH`:
```
export CFLAGS=-I/path/to/cudnn/include
export LDFLAGS=-L/path/to/cudnn/lib
export LD_LIBRARY_PATH=/path/to/cudnn/lib:$LD_LIBRARY_PATH
```sh
$ pip install cupy
```
Do not forget to restart your terminal session (or `source` it) to enable these changes.
And then, reinstall Chainer.


### Image dataset support

If you want to use Image dataset (`chainer/datasets/ImageDataset`), please install Pillow manually.
Supported image format depends on your environment.

```
pip install pillow
```
[See the installation guide for more details.](http://docs.chainer.org/en/stable/install.html).


### HDF5 Support
## Docker image

If you want to use HDF5 serialization, please install h5py manually.
h5py requires libhdf5.
Anaconda distribution includes this package.
If you are using another Python distribution, use either of the following commands to install libhdf5 depending on your Linux environment:
We are providing the official Docker image.
This image supports [nvidia-docker](https://github.com/NVIDIA/nvidia-docker).
Login to the environment with the following command, and run the Python interpreter to use Chainer with CUDA and cuDNN support.

```
apt-get install libhdf5-dev
yum install hdf5-devel
$ nvidia-docker run -it chainer/chainer /bin/bash
```

And then, install h5py via PyPI.
You may need to install Cython for h5py.

```
pip install cython
pip install h5py
```
## Contribution

Any contributions to Chainer are welcome!
If you want to file an issue or send a pull request, [please follow the contribution guide](https://docs.chainer.org/contribution.html).

### Multi-GPU Support

Multi-GPU training is supported by MultiprocessParallelUpdater.
If you want to use MultiprocessParallelUpdater, please install [NCCL](https://github.com/NVIDIA/nccl) by following the installation guide.
## License

MIT License (see `LICENSE` file).

## Run with Docker

We provide the official Docker image.
Use [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) command to run Chainer image with GPU.
You can login to the environment with bash, and run the Python interpreter.
## More information

- [Release notes](https://github.com/chainer/chainer/releases)
- [Research projects using Chainer](https://github.com/chainer/chainer/wiki/Research-projects-using-Chainer)

```
$ nvidia-docker run -it chainer/chainer /bin/bash
```

## Reference

Expand All @@ -149,10 +76,3 @@ Chainer: a Next-Generation Open Source Framework for Deep Learning,
*Proceedings of Workshop on Machine Learning Systems(LearningSys) in
The Twenty-ninth Annual Conference on Neural Information Processing Systems (NIPS)*, (2015)
[URL](http://learningsys.org/papers/LearningSys_2015_paper_33.pdf), [BibTex](chainer_bibtex.txt)


[github](https://github.com/pfnet/chainer)

## License

MIT License (see `LICENSE` file).
6 changes: 5 additions & 1 deletion chainer/__init__.py
Expand Up @@ -50,6 +50,7 @@
from chainer.serializer import AbstractSerializer # NOQA
from chainer.serializer import Deserializer # NOQA
from chainer.serializer import Serializer # NOQA
from chainer.variable import Parameter # NOQA
from chainer.variable import Variable # NOQA


Expand Down Expand Up @@ -78,7 +79,10 @@ def get_function_hooks():


global_config.debug = bool(int(os.environ.get('CHAINER_DEBUG', '0')))
global_config.cudnn_deterministic = False
global_config.enable_backprop = True
global_config.keep_graph_on_report = bool(int(
os.environ.get('CHAINER_KEEP_GRAPH_ON_REPORT', '0')))
global_config.train = True
global_config.type_check = bool(int(os.environ.get('CHAINER_TYPE_CHECK', '1')))
global_config.use_cudnn = os.environ.get('CHAINER_USE_CUDNN', 'auto')
Expand All @@ -90,7 +94,7 @@ def get_function_hooks():
}


_cudnn_version = cuda.cudnn.cudnn.getVersion() if cuda.cudnn_enabled else 1
_cudnn_version = cuda.cudnn.cudnn.getVersion() if cuda.cudnn_enabled else -1


def should_use_cudnn(level, lowest_version=0):
Expand Down
11 changes: 9 additions & 2 deletions chainer/function.py
Expand Up @@ -146,6 +146,8 @@ class Function(object):
"""

rank = 0 # default value of the rank

def __call__(self, *inputs):
"""Applies forward propagation with chaining backward references.
Expand Down Expand Up @@ -233,7 +235,6 @@ def __call__(self, *inputs):
for index in output_indexes_to_retain:
ret[index].retain_data()
del self._output_indexes_to_retain
self.output_data = tuple([y.node.data for y in ret])

if len(ret) == 1:
return ret[0]
Expand Down Expand Up @@ -493,7 +494,7 @@ def retain_inputs(self, indexes):
"""
self._input_indexes_to_retain = indexes

def retain_outputs(self, indexes):
def retain_outputs(self, indexes, retain_after_backward=False):
"""Lets specified output variable nodes keep data arrays.
By calling this method from :meth:`forward`, the function can specify
Expand All @@ -516,8 +517,14 @@ def retain_outputs(self, indexes):
indexes (iterable of int): Indexes of input variables that the
function does not require for backprop.
retain_after_backward (bool): If ``True``, a reference to the
outputs will remain after the backprop of the function is over.
If ``False``, the reference will be deleted.
"""
self._output_indexes_to_retain = indexes
if retain_after_backward:
self._retain_after_backward = retain_after_backward


class FunctionHook(object):
Expand Down
18 changes: 16 additions & 2 deletions chainer/functions/activation/leaky_relu.py
Expand Up @@ -24,19 +24,33 @@ def check_type_forward(self, in_types):
def forward_cpu(self, x):
y = x[0].copy()
y[x[0] < 0] *= self.slope
if self.slope >= 0:
self.retain_inputs(())
self.retain_outputs((0,))
return y,

def forward_gpu(self, x):
y = _kern()(x[0], x[0], self.slope)
if self.slope >= 0:
self.retain_inputs(())
self.retain_outputs((0,))
return y,

def backward_cpu(self, x, gy):
gx = gy[0].copy()
gx[x[0] < 0] *= self.slope
if self.slope >= 0:
y = self.output_data
gx[y[0] < 0] *= self.slope
else:
gx[x[0] < 0] *= self.slope
return gx,

def backward_gpu(self, x, gy):
gx = _kern()(x[0], gy[0], self.slope)
if self.slope >= 0:
y = self.output_data
gx = _kern()(y[0], gy[0], self.slope)
else:
gx = _kern()(x[0], gy[0], self.slope)
return gx,


Expand Down
23 changes: 13 additions & 10 deletions chainer/functions/activation/log_softmax.py
Expand Up @@ -48,9 +48,6 @@ class LogSoftmax(function.Function):

"""Log-softmax activation function."""

def __init__(self):
self.y = None

def check_type_forward(self, in_types):
type_check.expect(in_types.size() == 1)
x_type, = in_types
Expand All @@ -61,25 +58,31 @@ def check_type_forward(self, in_types):
)

def forward(self, xs):
self.y = _log_softmax(xs[0])
return self.y,
y = _log_softmax(xs[0])
self._x_xp = cuda.get_array_module(*xs)
self._x_shape = xs[0].shape
self._x_dtype = xs[0].dtype
self.retain_inputs(())
self.retain_outputs((0,))
return y,

def backward(self, x, gy):
xp = cuda.get_array_module(*x)
y = self.output_data[0]
xp = self._x_xp
if xp is not numpy and chainer.should_use_cudnn('>=auto', 3000):
oz_dtype = 'd' if x[0].dtype == 'd' else 'f'
oz_dtype = 'd' if self._x_dtype == 'd' else 'f'
one = numpy.array(1, dtype=oz_dtype).ctypes
zero = numpy.array(0, dtype=oz_dtype).ctypes
handle = cudnn.get_handle()
gx = xp.empty_like(x[0])
gx = xp.empty(self._x_shape, dtype=self._x_dtype)
gx_cube = gx.reshape(gx.shape[:2] + (-1, 1))
desc = cudnn.create_tensor_descriptor(gx_cube)
libcudnn.softmaxBackward(
handle, _algorithm, _mode, one.data, desc.value,
self.y.data.ptr, desc.value, gy[0].data.ptr, zero.data,
y.data.ptr, desc.value, gy[0].data.ptr, zero.data,
desc.value, gx.data.ptr)
else:
gx = gy[0] - xp.exp(self.y) * gy[0].sum(axis=1, keepdims=True)
gx = gy[0] - xp.exp(y) * gy[0].sum(axis=1, keepdims=True)

return gx,

Expand Down

0 comments on commit d4786ba

Please sign in to comment.