From 31334f993127a3c5665ae5a527a76f9e2db79830 Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Mon, 12 Mar 2018 15:04:01 +0900 Subject: [PATCH 01/15] add benchmark framework --- benchmarks/.gitignore | 5 + benchmarks/README.rst | 38 +++++ benchmarks/asv.conf.json | 149 ++++++++++++++++++++ benchmarks/benchmarks/__init__.py | 40 ++++++ benchmarks/benchmarks/utils/__init__.py | 7 + benchmarks/benchmarks/utils/backend.py | 180 ++++++++++++++++++++++++ benchmarks/benchmarks/utils/helper.py | 101 +++++++++++++ benchmarks/run.sh | 33 +++++ 8 files changed, 553 insertions(+) create mode 100644 benchmarks/.gitignore create mode 100644 benchmarks/README.rst create mode 100644 benchmarks/asv.conf.json create mode 100644 benchmarks/benchmarks/__init__.py create mode 100644 benchmarks/benchmarks/utils/__init__.py create mode 100644 benchmarks/benchmarks/utils/backend.py create mode 100644 benchmarks/benchmarks/utils/helper.py create mode 100755 benchmarks/run.sh diff --git a/benchmarks/.gitignore b/benchmarks/.gitignore new file mode 100644 index 000000000000..491fb7d04081 --- /dev/null +++ b/benchmarks/.gitignore @@ -0,0 +1,5 @@ +html/ +results/ +env/ +chainer/ +cupy/ diff --git a/benchmarks/README.rst b/benchmarks/README.rst new file mode 100644 index 000000000000..29ce1f46d33e --- /dev/null +++ b/benchmarks/README.rst @@ -0,0 +1,38 @@ +Chainer Benchmarks +================== + +Benchmarking Chainer with Airspeed Velocity. + +Note that CuPy earlier than v3.1.0 or v4.0.0b1 are not supported. + +Requirements +------------ + +* ``asv`` +* ``Cython`` (to build CuPy) + +Usage +----- + +.. code-block:: sh + + # Enable ccache for performance (optional). + export PATH="/usr/lib/ccache:${PATH}" + export NVCC="ccache nvcc" + + # Run benchmark against target commit-ish of Chainer and CuPy. + # Note that specified versions must be a compatible combination. 
+ ./run.sh master master + ./run.sh v4.0.0b4 v4.0.0b4 + + # Compare the benchmark results between two commits to see regression + # and/or performance improvements in command line. + alias git_commit='git show --format="%H"' + asv compare $(git_commit v4.0.0b4) $(git_commit master) + + # Convert the results into HTML. + # The result will be in `html` directory. + asv publish + + # Start the HTTP server to browse HTML. + asv preview diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json new file mode 100644 index 000000000000..06f53473e029 --- /dev/null +++ b/benchmarks/asv.conf.json @@ -0,0 +1,149 @@ +{ + // The version of the config file format. Do not change, unless + // you know what you are doing. + "version": 1, + + // The name of the project being benchmarked + "project": "chainer", + + // The project's homepage + "project_url": "https://chainer.org/", + + // The URL or local path of the source code repository for the + // project being benchmarked + "repo": "https://github.com/chainer/chainer.git", + + // List of branches to benchmark. If not provided, defaults to "master" + // (for git) or "default" (for mercurial). + "branches": ["master", "v3"], // for git + // "branches": ["default"], // for mercurial + + // The DVCS being used. If not set, it will be automatically + // determined from "repo" by looking at the protocol in the URL + // (if remote), or by looking for special directories, such as + // ".git" (if local). + // "dvcs": "git", + + // The tool to use to create environments. May be "conda", + // "virtualenv" or other value depending on the plugins in use. + // If missing or the empty string, the tool will be automatically + // determined by looking for tools on the PATH environment + // variable. + "environment_type": "virtualenv", + + // timeout in seconds for installing any dependencies in environment + // defaults to 10 min + //"install_timeout": 600, + + // the base URL to show a commit for the project. 
+ "show_commit_url": "https://github.com/chainer/chainer/commit/", + + // The Pythons you'd like to test against. If not provided, defaults + // to the current version of Python used to run `asv`. + // "pythons": ["2.7", "3.3"], + + // The list of conda channel names to be searched for benchmark + // dependency packages in the specified order + // "conda_channels": ["conda-forge", "defaults"] + + // The matrix of dependencies to test. Each key is the name of a + // package (in PyPI) and the values are version numbers. An empty + // list or empty string indicates to just test against the default + // (latest) version. null indicates that the package is to not be + // installed. If the package to be tested is only available from + // PyPi, and the 'environment_type' is conda, then you can preface + // the package name by 'pip+', and the package will be installed via + // pip (with all the conda available packages installed first, + // followed by the pip installed packages). + // + // "matrix": { + // "numpy": ["1.6", "1.7"], + // "six": ["", null], // test with and without six installed + // "pip+emcee": [""], // emcee is only available for install with pip. + // }, + "matrix": { + // Optional dependencies required for benchmark. + "ideep4py": [], + }, + + // Combinations of libraries/python versions can be excluded/included + // from the set to test. Each entry is a dictionary containing additional + // key-value pairs to include/exclude. + // + // An exclude entry excludes entries where all values match. The + // values are regexps that should match the whole string. + // + // An include entry adds an environment. Only the packages listed + // are installed. The 'python' key is required. The exclude rules + // do not apply to includes. + // + // In addition to package names, the following keys are available: + // + // - python + // Python version, as in the *pythons* variable above. + // - environment_type + // Environment type, as above. 
+ // - sys_platform + // Platform, as in sys.platform. Possible values for the common + // cases: 'linux2', 'win32', 'cygwin', 'darwin'. + // + // "exclude": [ + // {"python": "3.2", "sys_platform": "win32"}, // skip py3.2 on windows + // {"environment_type": "conda", "six": null}, // don't run without six on conda + // ], + // + // "include": [ + // // additional env for python2.7 + // {"python": "2.7", "numpy": "1.8"}, + // // additional env if run on windows+conda + // {"platform": "win32", "environment_type": "conda", "python": "2.7", "libpython": ""}, + // ], + + // The directory (relative to the current directory) that benchmarks are + // stored in. If not provided, defaults to "benchmarks" + // "benchmark_dir": "benchmarks", + + // The directory (relative to the current directory) to cache the Python + // environments in. If not provided, defaults to "env" + // "env_dir": "env", + + // The directory (relative to the current directory) that raw benchmark + // results are stored in. If not provided, defaults to "results". + // "results_dir": "results", + + // The directory (relative to the current directory) that the html tree + // should be written to. If not provided, defaults to "html". + // "html_dir": "html", + + // The number of characters to retain in the commit hashes. + // "hash_length": 8, + + // `asv` will cache wheels of the recent builds in each + // environment, making them faster to install next time. This is + // number of builds to keep, per environment. + // "wheel_cache_size": 0 + + // The commits after which the regression search in `asv publish` + // should start looking for regressions. Dictionary whose keys are + // regexps matching to benchmark names, and values corresponding to + // the commit (exclusive) after which to start looking for + // regressions. The default is to start from the first commit + // with results. If the commit is `null`, regression detection is + // skipped for the matching benchmark. 
+ // + // "regressions_first_commits": { + // "some_benchmark": "352cdf", // Consider regressions only after this commit + // "another_benchmark": null, // Skip regression detection altogether + // } + + // The thresholds for relative change in results, after which `asv + // publish` starts reporting regressions. Dictionary of the same + // form as in ``regressions_first_commits``, with values + // indicating the thresholds. If multiple entries match, the + // maximum is taken. If no entry matches, the default is 5%. + // + // "regressions_thresholds": { + // "some_benchmark": 0.01, // Threshold of 1% + // "another_benchmark": 0.5, // Threshold of 50% + // } +} diff --git a/benchmarks/benchmarks/__init__.py b/benchmarks/benchmarks/__init__.py new file mode 100644 index 000000000000..5b33179d2911 --- /dev/null +++ b/benchmarks/benchmarks/__init__.py @@ -0,0 +1,40 @@ +import inspect + +import chainer + + +# Ensure that CuPy and cuDNN are available. +assert chainer.cuda.available +assert chainer.cuda.cudnn_enabled + + +class BenchmarkBase(object): + """Base class for all benchmarks. + + See also: http://asv.readthedocs.io/en/v0.2.1/writing_benchmarks.html + """ + + def __init__(self, *args, **kwargs): + # Set pretty_name to ``Class.name`` instead of the default + # ``module.Class.name``. This is because it is often too + # verbose to display module name in result HTML. + # This is a workaround needed until ASV 0.3 release. 
+ members = inspect.getmembers( + self.__class__, + predicate=lambda x: inspect.ismethod(x) or inspect.isfunction(x)) + for (name, func) in members: + if hasattr(func, '__func__'): + # For Python 2 + func = func.__func__ + if name.startswith('time_'): + name = name[5:] + func.pretty_name = '{}.{}'.format(type(self).__name__, name) + + def setup(self, *args, **kwargs): + pass + + def setup_cache(self, *args, **kwargs): + pass + + def teardown(self, *args, **kwargs): + pass diff --git a/benchmarks/benchmarks/utils/__init__.py b/benchmarks/benchmarks/utils/__init__.py new file mode 100644 index 000000000000..b222ec6bb6f1 --- /dev/null +++ b/benchmarks/benchmarks/utils/__init__.py @@ -0,0 +1,7 @@ +from benchmarks.utils.backend import backends # NOQA +from benchmarks.utils.backend import is_backend_gpu # NOQA +from benchmarks.utils.backend import is_backend_ideep # NOQA +from benchmarks.utils.backend import have_ideep # NOQA + +from benchmarks.utils.helper import parameterize # NOQA +from benchmarks.utils.helper import sync # NOQA diff --git a/benchmarks/benchmarks/utils/backend.py b/benchmarks/benchmarks/utils/backend.py new file mode 100644 index 000000000000..3c036cb62644 --- /dev/null +++ b/benchmarks/benchmarks/utils/backend.py @@ -0,0 +1,180 @@ +from functools import wraps +import inspect +import os +import warnings + +import chainer +import cupy +import numpy + +from benchmarks.utils.helper import _is_func +from benchmarks.utils.helper import parameterize +from benchmarks.utils.helper import sync + + +_backend_modes = [ + # GPU (with use_cudnn == 'never') + 'gpu', + + # GPU (with use_cudnn == 'auto') + 'gpu-cudnn', + + # CPU (with use_ideep == 'never') + 'cpu', + + # CPU (with use_ideep == 'auto') + 'cpu-ideep', +] + + +def backends(*modes): + """Class decorator to parameterize the benchmark class with backends. + + This is a special form of :func:`parameterize` to parameterize the + backend variation. 
For all `time_*` functions and `setup` function + in the class, this decorator: + + * wraps the function to be called with the Chainer configuration + (`use_cudnn` and `use_ideep`) set to the current backend variation. + * wraps the function to perform CPU/GPU synchronization after the + benchmark, when the current backend variation uses GPU. The time + taken for synchronization is counted as elapsed time in the benchmark. + * injects the array module (`cupy` or `numpy` depending on the current + variation) as `self.xp` so that benchmark code can use it to work with + array modules with each backend. + * provides access to `is_backend_gpu()` and `is_backend_ideep()` methods + so that benchmark code can use it to change behavior depending on the + backend variation (e.g., `if is_backend_gpu(): model.to_gpu()`). + + This decorator adds a parameter axis with the name of `backend`. + + Note that `cpu-ideep` mode will automatically be skipped if the current + benchmark setup does not support it, e.g., when running benchmark + against older Chainer version that does not support iDeep. + + You cannot apply `parameterize` decorator to the class already decorated + by this decorator. If you want to use `parameterize` along with this + decorator, make `parameterize` the innermost (i.e., the closest to the + class declaration) decorator. + + Example of usage is as follows: + + >>> @backends('gpu', 'gpu-cudnn', 'cpu', 'cpu-ideep') + ... class ConvolutionBenchmark(object): + ... def time_benchmark(self): + ... ... + """ + + assert all([m in _backend_modes for m in modes]) + + def _wrap_class(klass): + assert isinstance(klass, type) + return _inject_backend_mode(klass, modes) + + return _wrap_class + + +def _inject_backend_mode(klass, modes): + klass = parameterize([('backend', modes)])(klass) + + # `setup` method is mandatory to inject backends to skip axis. 
+ if not hasattr(klass, 'setup'): + def _setup(self, *args, **kwargs): + pass + klass.setup = _setup + + members = inspect.getmembers(klass, predicate=_is_func) + + for (name, func) in members: + if not (name == 'setup' or name.startswith('time_')): + continue + + def _wrap_func(f): + @wraps(f) + def _wrapped_func(self, backend, *args, **kwargs): + _benchmark_backend_gpu = False + _benchmark_backend_ideep = False + xp = numpy + use_cudnn = 'never' + use_ideep = 'never' + + target = f + if backend.startswith('gpu'): + xp = cupy + _benchmark_backend_gpu = True + target = sync(target) + if 'cudnn' in backend: + use_cudnn = 'auto' + elif 'ideep' in backend: + if not have_ideep(): + # Raise in `setup` to skip this parameter axis. + warnings.warn('iDeep is unavailable') + raise NotImplementedError + use_ideep = 'auto' + _benchmark_backend_ideep = True + + with _BackendConfig({ + 'use_cudnn': use_cudnn, + 'use_ideep': use_ideep, + '_benchmark_backend_gpu': _benchmark_backend_gpu, + '_benchmark_backend_ideep': _benchmark_backend_ideep, + }): + + # Inject self.xp + assert not hasattr(self, 'xp') + setattr(self, 'xp', xp) + target(self, *args, **kwargs) + delattr(self, 'xp') + + return _wrapped_func + setattr(klass, name, _wrap_func(func)) + + return klass + + +class _BackendConfig(object): + """Context manager that changes multiple Chainer configurations.""" + + def __init__(self, params): + self._params = params + self._contexts = [] + + def __enter__(self): + self._contexts = [ + chainer.using_config(k, v) for (k, v) in self._params.items() + ] + for c in self._contexts: + c.__enter__() + return self + + def __exit__(self, typ, value, traceback): + for c in reversed(self._contexts): + c.__exit__(typ, value, traceback) + + +def is_backend_gpu(): + """Returns True if the current backend is GPU.""" + + return chainer.config._benchmark_backend_gpu + + +def is_backend_ideep(): + """Returns True if the current backend is iDeep.""" + + return 
chainer.config._benchmark_backend_ideep + + +def have_ideep(): + """Tests if iDeep can be used in the current benchmark configuration. + + If you intend to write benchmark for iDeep outside of `backends` decorator, + first make sure that iDeep is available using this function. + This makes it possible to run the same benchmark code over past versions of + Chainer (prior to iDeep support). + """ + + try: + import chainer.backends.intel64 + except ImportError: + return False + return chainer.backends.intel64.is_ideep_available() diff --git a/benchmarks/benchmarks/utils/helper.py b/benchmarks/benchmarks/utils/helper.py new file mode 100644 index 000000000000..acfc7a55cc1f --- /dev/null +++ b/benchmarks/benchmarks/utils/helper.py @@ -0,0 +1,101 @@ +from functools import wraps +import inspect + +import cupy + + +def _is_func(target): + return inspect.ismethod(target) or inspect.isfunction(target) + + +def sync(target): + """Decorator to perform CPU/GPU synchronization. + + This decorator can be applied to both classes and functions. + """ + + if isinstance(target, type): + klass = target + members = inspect.getmembers(klass, predicate=_is_func) + for (name, func) in members: + if not (name == 'setup' or name.startswith('time_')): + continue + setattr(klass, name, _synchronized_func(func)) + return klass + elif _is_func(target): + return _synchronized_func(target) + else: + raise TypeError('cannot apply decorator to {}'.format(target)) + + +def _synchronized_func(func): + @wraps(func) + def _wrap_func(*args, **kwargs): + event = cupy.cuda.stream.Event() + event.record() + event.synchronize() + func(*args, **kwargs) + event = cupy.cuda.stream.Event() + event.record() + event.synchronize() + return _wrap_func + + +def parameterize(args): + """Class decorator to parameterize the benchmark. + + Pass a list of (parameter name, values) pairs. Each parameter + value will be passed as the function argument when benchmark runs. + See the example below for the usage. 
+ + >>> @parameterize([ + ... ('batchsize', [32, 64, 128]), + ... ('n_gpus', [1, 2]), + ... ]) + ... class MyBenchmark(object): + ... def time_all(self, batchsize, n_gpus): + ... ... + + Parameters cannot be sparse due to the limitation of ASV. + """ + + def _wrap_class(klass): + """Wraps the given class. + + Internally, this function utilizes the parameterization feature of + ASV, i.e., set `params` and `param_names` attribute of the class. + `params` is a list of list of parameters, and `param_names` is a list + of parameter names. `params[i]` is a list of parameters for parameter + named `param_names[i]` where `i` is an index. + """ + + assert isinstance(klass, type) + + params = [arg[1] for arg in args] + param_names = [arg[0] for arg in args] + + orig_params = getattr(klass, 'params', []) + orig_param_names = getattr(klass, 'param_names', []) + + if 0 < len(orig_params): + # ASV allows specifying list of parameters (instead of list of + # list of parameters) if only one parameter axis is given. + if not isinstance(orig_params[0], (tuple, list)): + orig_params = [orig_params] + if len(orig_param_names) == 0: + orig_param_names = ['param'] + assert len(orig_param_names) == 1 + else: + assert len(orig_param_names) == 0 + + params += orig_params + param_names += orig_param_names + + assert len(params) == len(param_names) + + setattr(klass, 'params', params) + setattr(klass, 'param_names', param_names) + + return klass + + return _wrap_class diff --git a/benchmarks/run.sh b/benchmarks/run.sh new file mode 100755 index 000000000000..0ba482f788d2 --- /dev/null +++ b/benchmarks/run.sh @@ -0,0 +1,33 @@ +#!/bin/bash -uex + +function run_asv() { + CHAINER_COMMIT="${1}"; shift + CUPY_COMMIT="${1}"; shift + + # Clone CuPy. + if [ ! -d cupy ]; then + git clone https://github.com/cupy/cupy.git + fi + + # Build CuPy commit to use for benchmark. 
+ # Note that CuPy will be injected from current environment via `PYTHONPATH` + # instead of `matrix` in `asv.conf.json`, because Chainer and CuPy are + # tightly-coupled that we should manually pick which commit of CuPy to use. + # The version of the python command in outer world must match with the + # version used in the benchmark virtualenv. + pushd cupy + git remote update + git clean -fdx + git checkout "$(git show --format="%H" ${CUPY_COMMIT})" + python setup.py build_ext --inplace + export PYTHONPATH="${PWD}:${PYTHONPATH:-}" + popd + + # Run the benchmark. + # Uncomment the following lines to diagnose installation issues. + #export PIP_VERBOSE=True + #export PIP_LOG=pip.log + asv run --step 1 "$@" "${CHAINER_COMMIT}" +} + +run_asv "$@" From 53aab4cb8036a05bd05765bb42a4c1e912e122f1 Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Mon, 12 Mar 2018 15:04:23 +0900 Subject: [PATCH 02/15] import convnet-benchmark --- benchmarks/benchmarks/convnet/__init__.py | 0 benchmarks/benchmarks/convnet/benchmark.py | 98 +++++++++++++++++++ .../benchmarks/convnet/nets/__init__.py | 0 benchmarks/benchmarks/convnet/nets/alex.py | 29 ++++++ .../benchmarks/convnet/nets/googlenet.py | 73 ++++++++++++++ .../benchmarks/convnet/nets/overfeat.py | 29 ++++++ benchmarks/benchmarks/convnet/nets/vgga.py | 35 +++++++ 7 files changed, 264 insertions(+) create mode 100644 benchmarks/benchmarks/convnet/__init__.py create mode 100644 benchmarks/benchmarks/convnet/benchmark.py create mode 100644 benchmarks/benchmarks/convnet/nets/__init__.py create mode 100644 benchmarks/benchmarks/convnet/nets/alex.py create mode 100644 benchmarks/benchmarks/convnet/nets/googlenet.py create mode 100644 benchmarks/benchmarks/convnet/nets/overfeat.py create mode 100644 benchmarks/benchmarks/convnet/nets/vgga.py diff --git a/benchmarks/benchmarks/convnet/__init__.py b/benchmarks/benchmarks/convnet/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git 
a/benchmarks/benchmarks/convnet/benchmark.py b/benchmarks/benchmarks/convnet/benchmark.py new file mode 100644 index 000000000000..7dd2b3fc9068 --- /dev/null +++ b/benchmarks/benchmarks/convnet/benchmark.py @@ -0,0 +1,98 @@ +import chainer +from chainer import cuda +from chainer import optimizers + +import cupy +import numpy + +from benchmarks import BenchmarkBase +from benchmarks.utils import backends +from benchmarks.utils import is_backend_gpu +from benchmarks.utils import is_backend_ideep +from benchmarks.utils import parameterize + + +class _ConvnetBase(BenchmarkBase): + """ + Benchmark code from convnet-benchmark. + + https://github.com/soumith/convnet-benchmarks/tree/master/chainer + """ + + timeout = 600 + number = 1 + + def setup(self, arch, batchsize): + xp = self.xp + + if arch == 'alexnet': + from benchmarks.convnet.nets import alex + model = alex.Alex() + elif arch == 'googlenet': + from benchmarks.convnet.nets import googlenet + model = googlenet.GoogLeNet() + elif arch == 'vgga': + from benchmarks.convnet.nets import vgga + model = vgga.vgga() + elif arch == 'overfeat': + from benchmarks.convnet.nets import overfeat + model = overfeat.overfeat() + else: + raise ValueError('Invalid architecture name') + + if is_backend_gpu(): + model.to_gpu() + elif is_backend_ideep(): + model.to_intel64() + + # Setup optimizer + optimizer = optimizers.SGD(lr=0.01) + optimizer.setup(model) + + # Set cuDNN workspace size + workspace_size = int(1 * 2**30) + chainer.cuda.set_max_workspace_size(workspace_size) + + chainer.config.train = True + + x = xp.ndarray((batchsize, 3, model.insize, + model.insize), dtype=xp.float32) + x.fill(33333) + + if arch == 'googlenet': + out1, out2, out3 = model.forward(x) + out = out1 + out2 + out3 + else: + out = model.forward(x) + + out.zerograd() + out.grad.fill(3) + model.cleargrads() + + self._x = x + self._model = model + self._out = out + + def time_forward(self, arch, batchsize): + self._model.forward(self._x) + + def 
time_backward(self, arch, batchsize): + self._out.backward() + + +@backends('gpu', 'gpu-cudnn') +@parameterize([ + ('arch', ['vgga']), + ('batchsize', [32]), +]) +class ConvnetVGGA(_ConvnetBase): + pass + + +@backends('gpu', 'gpu-cudnn') +@parameterize([ + ('arch', ['alexnet', 'googlenet', 'overfeat']), + ('batchsize', [128]), +]) +class ConvnetOthers(_ConvnetBase): + pass diff --git a/benchmarks/benchmarks/convnet/nets/__init__.py b/benchmarks/benchmarks/convnet/nets/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/benchmarks/benchmarks/convnet/nets/alex.py b/benchmarks/benchmarks/convnet/nets/alex.py new file mode 100644 index 000000000000..8214b4db372c --- /dev/null +++ b/benchmarks/benchmarks/convnet/nets/alex.py @@ -0,0 +1,29 @@ +import chainer +import chainer.functions as F +import chainer.links as L + + +class Alex(chainer.Chain): + insize = 224 + + def __init__(self): + super(Alex, self).__init__() + with self.init_scope(): + self.conv1 = L.Convolution2D(3, 64, 11, stride=4, pad=2) + self.conv2 = L.Convolution2D(64, 192, 5, pad=2) + self.conv3 = L.Convolution2D(192, 384, 3, pad=1) + self.conv4 = L.Convolution2D(384, 256, 3, pad=1) + self.conv5 = L.Convolution2D(256, 256, 3, pad=1) + self.fc6 = L.Linear(256 * 6 * 6, 4096) + self.fc7 = L.Linear(4096, 4096) + self.fc8 = L.Linear(4096, 1000) + + def forward(self, x): + h = F.max_pooling_2d(F.relu(self.conv1(x)), 3, stride=2) + h = F.max_pooling_2d(F.relu(self.conv2(h)), 3, stride=2) + h = F.relu(self.conv3(h)) + h = F.relu(self.conv4(h)) + h = F.max_pooling_2d(F.relu(self.conv5(h)), 3, stride=2) + h = F.relu(self.fc6(h)) + h = F.relu(self.fc7(h)) + return self.fc8(h) diff --git a/benchmarks/benchmarks/convnet/nets/googlenet.py b/benchmarks/benchmarks/convnet/nets/googlenet.py new file mode 100644 index 000000000000..6c1c1fe3bff3 --- /dev/null +++ b/benchmarks/benchmarks/convnet/nets/googlenet.py @@ -0,0 +1,73 @@ +import chainer +import chainer.functions as F +import chainer.links 
as L + + +class GoogLeNet(chainer.Chain): + + insize = 224 + + def __init__(self): + super(GoogLeNet, self).__init__() + with self.init_scope(): + self.conv1 = L.Convolution2D(3, 64, 7, stride=2, pad=3) + self.conv2_reduce = L.Convolution2D(64, 64, 1) + self.conv2 = L.Convolution2D(64, 192, 3, stride=1, pad=1) + self.inc3a = L.Inception(192, 64, 96, 128, 16, 32, 32) + self.inc3b = L.Inception(256, 128, 128, 192, 32, 96, 64) + self.inc4a = L.Inception(480, 192, 96, 208, 16, 48, 64) + self.inc4b = L.Inception(512, 160, 112, 224, 24, 64, 64) + self.inc4c = L.Inception(512, 128, 128, 256, 24, 64, 64) + self.inc4d = L.Inception(512, 112, 144, 288, 32, 64, 64) + self.inc4e = L.Inception(528, 256, 160, 320, 32, 128, 128) + self.inc5a = L.Inception(832, 256, 160, 320, 32, 128, 128) + self.inc5b = L.Inception(832, 384, 192, 384, 48, 128, 128) + self.loss3_fc = L.Linear(1024, 1000) + + self.loss1_conv = L.Convolution2D(512, 128, 1) + self.loss1_fc1 = L.Linear(4 * 4 * 128, 1024) + self.loss1_fc2 = L.Linear(1024, 1000) + + self.loss2_conv = L.Convolution2D(528, 128, 1) + self.loss2_fc1 = L.Linear(4 * 4 * 128, 1024) + self.loss2_fc2 = L.Linear(1024, 1000) + + def forward(self, x): + h = F.relu(self.conv1(x)) + h = F.local_response_normalization( + F.max_pooling_2d(h, 3, stride=2), n=5) + + h = F.relu(self.conv2_reduce(h)) + h = F.relu(self.conv2(h)) + h = F.max_pooling_2d( + F.local_response_normalization(h, n=5), 3, stride=2) + + h = self.inc3a(h) + h = self.inc3b(h) + h = F.max_pooling_2d(h, 3, stride=2) + h = self.inc4a(h) + + if chainer.config.train: + out1 = F.average_pooling_2d(h, 5, stride=3) + out1 = F.relu(self.loss1_conv(out1)) + out1 = F.relu(self.loss1_fc1(out1)) + out1 = self.loss1_fc2(out1) + + h = self.inc4b(h) + h = self.inc4c(h) + h = self.inc4d(h) + + if chainer.config.train: + out2 = F.average_pooling_2d(h, 5, stride=3) + out2 = F.relu(self.loss2_conv(out2)) + out2 = F.relu(self.loss2_fc1(out2)) + out2 = self.loss2_fc2(out2) + + h = self.inc4e(h) + h = 
F.max_pooling_2d(h, 3, stride=2) + h = self.inc5a(h) + h = self.inc5b(h) + + h = F.dropout(F.average_pooling_2d(h, 7, stride=1), 0.4) + out3 = self.loss3_fc(h) + return out1, out2, out3 diff --git a/benchmarks/benchmarks/convnet/nets/overfeat.py b/benchmarks/benchmarks/convnet/nets/overfeat.py new file mode 100644 index 000000000000..ca36110f9028 --- /dev/null +++ b/benchmarks/benchmarks/convnet/nets/overfeat.py @@ -0,0 +1,29 @@ +import chainer +import chainer.functions as F +import chainer.links as L + + +class overfeat(chainer.Chain): + insize = 231 + + def __init__(self): + super(overfeat, self).__init__() + with self.init_scope(): + self.conv1 = L.Convolution2D( 3, 96, 11, stride=4) + self.conv2 = L.Convolution2D( 96, 256, 5, pad=0) + self.conv3 = L.Convolution2D( 256, 512, 3, pad=1) + self.conv4 = L.Convolution2D( 512, 1024, 3, pad=1) + self.conv5 = L.Convolution2D(1024, 1024, 3, pad=1) + self.fc6 = L.Linear(1024 * 6 * 6, 3072) + self.fc7 = L.Linear(3072, 4096) + self.fc8 = L.Linear(4096, 1000) + + def forward(self, x): + h = F.max_pooling_2d(F.relu(self.conv1(x)), 2, stride=2) + h = F.max_pooling_2d(F.relu(self.conv2(h)), 2, stride=2) + h = F.relu(self.conv3(h)) + h = F.relu(self.conv4(h)) + h = F.max_pooling_2d(F.relu(self.conv5(h)), 3, stride=2) + h = F.relu(self.fc6(h)) + h = F.relu(self.fc7(h)) + return self.fc8(h) diff --git a/benchmarks/benchmarks/convnet/nets/vgga.py b/benchmarks/benchmarks/convnet/nets/vgga.py new file mode 100644 index 000000000000..adbb34928c28 --- /dev/null +++ b/benchmarks/benchmarks/convnet/nets/vgga.py @@ -0,0 +1,35 @@ +import chainer +import chainer.functions as F +import chainer.links as L + + +class vgga(chainer.Chain): + insize = 224 + + def __init__(self): + super(vgga, self).__init__() + with self.init_scope(): + self.conv1 = L.Convolution2D( 3, 64, 3, stride=1, pad=1) + self.conv2 = L.Convolution2D( 64, 128, 3, stride=1, pad=1) + self.conv3 = L.Convolution2D(128, 256, 3, stride=1, pad=1) + self.conv4 = 
L.Convolution2D(256, 256, 3, stride=1, pad=1) + self.conv5 = L.Convolution2D(256, 512, 3, stride=1, pad=1) + self.conv6 = L.Convolution2D(512, 512, 3, stride=1, pad=1) + self.conv7 = L.Convolution2D(512, 512, 3, stride=1, pad=1) + self.conv8 = L.Convolution2D(512, 512, 3, stride=1, pad=1) + self.fc6 = L.Linear(512 * 7 * 7, 4096) + self.fc7 = L.Linear(4096, 4096) + self.fc8 = L.Linear(4096, 1000) + + def forward(self, x): + h = F.max_pooling_2d(F.relu(self.conv1(x)), 2, stride=2) + h = F.max_pooling_2d(F.relu(self.conv2(h)), 2, stride=2) + h = F.relu(self.conv3(h)) + h = F.max_pooling_2d(F.relu(self.conv4(h)), 2, stride=2) + h = F.relu(self.conv5(h)) + h = F.max_pooling_2d(F.relu(self.conv6(h)), 2, stride=2) + h = F.relu(self.conv7(h)) + h = F.max_pooling_2d(F.relu(self.conv8(h)), 2, stride=2) + h = F.relu(self.fc6(h)) + h = F.relu(self.fc7(h)) + return self.fc8(h) From b35c363cd294c8516d092d371b8363e0e6c95d22 Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Mon, 12 Mar 2018 15:05:16 +0900 Subject: [PATCH 03/15] add MLP benchmark --- benchmarks/benchmarks/mnist/__init__.py | 0 benchmarks/benchmarks/mnist/mnist.py | 67 +++++++++++++++++++++++++ 2 files changed, 67 insertions(+) create mode 100644 benchmarks/benchmarks/mnist/__init__.py create mode 100644 benchmarks/benchmarks/mnist/mnist.py diff --git a/benchmarks/benchmarks/mnist/__init__.py b/benchmarks/benchmarks/mnist/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/benchmarks/benchmarks/mnist/mnist.py b/benchmarks/benchmarks/mnist/mnist.py new file mode 100644 index 000000000000..bd8a30739929 --- /dev/null +++ b/benchmarks/benchmarks/mnist/mnist.py @@ -0,0 +1,67 @@ +import chainer +import chainer.functions as F +import chainer.links as L +from chainer import training +from chainer.training import extensions + +from benchmarks import BenchmarkBase +from benchmarks.utils import backends +from benchmarks.utils import is_backend_gpu +from benchmarks.utils import is_backend_ideep 
+from benchmarks.utils import parameterize + + +class Network(chainer.Chain): + + def __init__(self, n_units, n_out): + super(Network, self).__init__() + with self.init_scope(): + self.l1 = L.Linear(None, n_units) + self.l2 = L.Linear(None, n_units) + self.l3 = L.Linear(None, n_out) + + def __call__(self, x): + h1 = F.relu(self.l1(x)) + h2 = F.relu(self.l2(h1)) + return self.l3(h2) + + +class Application(object): + + def main(self, units, epoch, batchsize): + model = L.Classifier(Network(units, 10)) + + gpu = -1 + if is_backend_gpu(): + gpu = 0 + chainer.cuda.get_device_from_id(gpu).use() + model.to_gpu() + elif is_backend_ideep(): + model.to_intel64() + + optimizer = chainer.optimizers.MomentumSGD() + optimizer.setup(model) + + train, test = chainer.datasets.get_mnist() + train_iter = chainer.iterators.SerialIterator( + train, batchsize) + test_iter = chainer.iterators.SerialIterator( + test, batchsize, repeat=False, shuffle=False) + + updater = training.updater.StandardUpdater( + train_iter, optimizer, device=gpu) + trainer = training.Trainer(updater, (epoch, 'epoch')) + trainer.extend(extensions.Evaluator(test_iter, model, device=gpu)) + + trainer.run() + + +@backends('gpu', 'gpu-cudnn', 'cpu', 'cpu-ideep') +@parameterize([ + ('units', [10, 100, 150]), +]) +class MLP(BenchmarkBase): + timeout = 360 + + def time_overall(self, units): + Application().main(units=units, epoch=1, batchsize=100) From 803d3785f80196d552a92e927a7c1212ceae549f Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Mon, 12 Mar 2018 15:05:56 +0900 Subject: [PATCH 04/15] add framework for function tests --- benchmarks/benchmarks/functions/__init__.py | 68 +++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 benchmarks/benchmarks/functions/__init__.py diff --git a/benchmarks/benchmarks/functions/__init__.py b/benchmarks/benchmarks/functions/__init__.py new file mode 100644 index 000000000000..fe6e1183eece --- /dev/null +++ b/benchmarks/benchmarks/functions/__init__.py @@ 
-0,0 +1,68 @@ +import functools +import operator + +import numpy + +import chainer + +from benchmarks import BenchmarkBase + + +class FunctionBenchmark(BenchmarkBase): + + """The base class for benchmark of functions.""" + + # Call `time_*` methods only once as `backward()` has a side-effect. + number = 1 + + # Repeat the test 10 times instead of 3 (`timeit.default_repeat`). + repeat = 10 + + def setup_benchmark(self, function, inputs, grad_outputs=None): + """Performs setup of benchmark for functions. + + Call this in `setup` method of your benchmark class. + Note that this function performs forward computation. + """ + self.function = function + + # Prepare for forward. + self.forward_inputs = ( + [None if x is None else chainer.Variable(x) for x in inputs]) + + # Prepare for backward. + outputs = chainer.functions.identity(self.forward()) + if isinstance(outputs, (list, tuple)): + self.forward_outputs = outputs + else: + self.forward_outputs = outputs, + + if grad_outputs is not None: + assert len(grad_outputs) == len(self.forward_outputs) + for i in range(len(grad_outputs)): + self.forward_outputs[i].grad = grad_outputs[i] + + def forward(self): + """Runs forward computation.""" + return self.function(*self.forward_inputs) + + def backward(self): + """Runs backward computation.""" + self.forward_outputs[0].backward() + + +class UnaryMathFunctionBenchmark(FunctionBenchmark): + + """The base class for benchmark of unary element-wise math functions. 
+ + Unlike `FunctionBenchmark`, this class automatically generates inputs and + grads.""" + + def setup_benchmark( + self, function, shape=(1000, 1000), dtype=numpy.float32): + inputs = (self.xp.arange( + functools.reduce(operator.mul, shape), + dtype=dtype).reshape(shape) + 1,) + grad_outputs = (self.xp.array(inputs[0]),) + super(UnaryMathFunctionBenchmark, self).setup_benchmark( + function, inputs, grad_outputs) From 7d1ef8c60afdee434683d3f86585a35e03aa7cd4 Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Mon, 12 Mar 2018 15:06:05 +0900 Subject: [PATCH 05/15] add benchmarks for F.convolution_2d --- .../functions/connection/__init__.py | 0 .../functions/connection/convolution_2d.py | 35 +++++++++++++++++++ 2 files changed, 35 insertions(+) create mode 100644 benchmarks/benchmarks/functions/connection/__init__.py create mode 100644 benchmarks/benchmarks/functions/connection/convolution_2d.py diff --git a/benchmarks/benchmarks/functions/connection/__init__.py b/benchmarks/benchmarks/functions/connection/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/benchmarks/benchmarks/functions/connection/convolution_2d.py b/benchmarks/benchmarks/functions/connection/convolution_2d.py new file mode 100644 index 000000000000..54fbe571e4b2 --- /dev/null +++ b/benchmarks/benchmarks/functions/connection/convolution_2d.py @@ -0,0 +1,35 @@ +import chainer +import chainer.functions as F + +from benchmarks.functions import FunctionBenchmark +from benchmarks.utils import backends, config + + +@backends('gpu', 'gpu-cudnn', 'cpu', 'cpu-ideep') +class Convolution2D(FunctionBenchmark): + def setup(self): + xp = self.xp + + # Prepare test data. + batches = 128 + in_channels = 3 + out_channels = 64 + ih, iw = (128, 128) + kh, kw = (12, 12) + x = xp.random.uniform( + -1, 1, (batches, in_channels, ih, iw)).astype(xp.float32) + W = xp.random.normal( + 0, xp.sqrt(1. 
/ (kh * kw * in_channels)), + (out_channels, in_channels, kh, kw)).astype(xp.float32) + b = xp.random.uniform(-1, 1, out_channels).astype(xp.float32) + gy = xp.random.uniform( + -1, 1, (batches, out_channels, 117, 117)).astype(xp.float32) + + # Setup benchmark. + self.setup_benchmark(F.convolution_2d, (x, W, b), (gy,)) + + def time_forward(self): + self.forward() + + def time_backward(self): + self.backward() From 0c3d56a5bd636605588a1fd814e4bd3c0b9abaf1 Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Mon, 12 Mar 2018 15:06:21 +0900 Subject: [PATCH 06/15] add benchmarks for F.sqrt and F.rsqrt --- .../benchmarks/functions/math/__init__.py | 0 benchmarks/benchmarks/functions/math/sqrt.py | 30 +++++++++++++++++++ 2 files changed, 30 insertions(+) create mode 100644 benchmarks/benchmarks/functions/math/__init__.py create mode 100644 benchmarks/benchmarks/functions/math/sqrt.py diff --git a/benchmarks/benchmarks/functions/math/__init__.py b/benchmarks/benchmarks/functions/math/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/benchmarks/benchmarks/functions/math/sqrt.py b/benchmarks/benchmarks/functions/math/sqrt.py new file mode 100644 index 000000000000..97d67eaab166 --- /dev/null +++ b/benchmarks/benchmarks/functions/math/sqrt.py @@ -0,0 +1,30 @@ +import chainer.functions as F + +from benchmarks.utils import backends +from benchmarks.functions import UnaryMathFunctionBenchmark + + +@backends('gpu', 'cpu') +class SqrtFunc(UnaryMathFunctionBenchmark): + + def setup(self): + self.setup_benchmark(F.sqrt) + + def time_forward(self): + self.forward() + + def time_backward(self): + self.backward() + + +@backends('gpu', 'cpu') +class RsqrtFunc(UnaryMathFunctionBenchmark): + + def setup(self): + self.setup_benchmark(F.rsqrt) + + def time_forward(self): + self.forward() + + def time_backward(self): + self.backward() From a06c68ae1095aa38455c4c1dfcc013ed400ac1ce Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Mon, 12 Mar 2018 15:24:36 
+0900 Subject: [PATCH 07/15] flake8 --- benchmarks/benchmarks/convnet/benchmark.py | 9 ++------- benchmarks/benchmarks/convnet/nets/overfeat.py | 2 ++ benchmarks/benchmarks/convnet/nets/vgga.py | 2 ++ benchmarks/benchmarks/functions/__init__.py | 3 ++- .../benchmarks/functions/connection/convolution_2d.py | 3 +-- benchmarks/benchmarks/functions/math/sqrt.py | 2 +- benchmarks/benchmarks/utils/__init__.py | 2 +- benchmarks/benchmarks/utils/backend.py | 1 - 8 files changed, 11 insertions(+), 13 deletions(-) diff --git a/benchmarks/benchmarks/convnet/benchmark.py b/benchmarks/benchmarks/convnet/benchmark.py index 7dd2b3fc9068..f2c1b3012cb3 100644 --- a/benchmarks/benchmarks/convnet/benchmark.py +++ b/benchmarks/benchmarks/convnet/benchmark.py @@ -1,10 +1,6 @@ import chainer -from chainer import cuda from chainer import optimizers -import cupy -import numpy - from benchmarks import BenchmarkBase from benchmarks.utils import backends from benchmarks.utils import is_backend_gpu @@ -13,8 +9,7 @@ class _ConvnetBase(BenchmarkBase): - """ - Benchmark code from convnet-benchmark. + """Benchmark code from convnet-benchmark. 
https://github.com/soumith/convnet-benchmarks/tree/master/chainer """ @@ -56,7 +51,7 @@ def setup(self, arch, batchsize): chainer.config.train = True x = xp.ndarray((batchsize, 3, model.insize, - model.insize), dtype=xp.float32) + model.insize), dtype=xp.float32) x.fill(33333) if arch == 'googlenet': diff --git a/benchmarks/benchmarks/convnet/nets/overfeat.py b/benchmarks/benchmarks/convnet/nets/overfeat.py index ca36110f9028..757d14fe2f2e 100644 --- a/benchmarks/benchmarks/convnet/nets/overfeat.py +++ b/benchmarks/benchmarks/convnet/nets/overfeat.py @@ -1,3 +1,5 @@ +# flake8: noqa + import chainer import chainer.functions as F import chainer.links as L diff --git a/benchmarks/benchmarks/convnet/nets/vgga.py b/benchmarks/benchmarks/convnet/nets/vgga.py index adbb34928c28..9c043236f4cc 100644 --- a/benchmarks/benchmarks/convnet/nets/vgga.py +++ b/benchmarks/benchmarks/convnet/nets/vgga.py @@ -1,3 +1,5 @@ +# flake8: noqa + import chainer import chainer.functions as F import chainer.links as L diff --git a/benchmarks/benchmarks/functions/__init__.py b/benchmarks/benchmarks/functions/__init__.py index fe6e1183eece..45a4925c43af 100644 --- a/benchmarks/benchmarks/functions/__init__.py +++ b/benchmarks/benchmarks/functions/__init__.py @@ -56,7 +56,8 @@ class UnaryMathFunctionBenchmark(FunctionBenchmark): """The base class for benchmark of unary element-wise math functions. Unlike `FunctionBenchmark`, this class automatically generates inputs and - grads.""" + grads. 
+ """ def setup_benchmark( self, function, shape=(1000, 1000), dtype=numpy.float32): diff --git a/benchmarks/benchmarks/functions/connection/convolution_2d.py b/benchmarks/benchmarks/functions/connection/convolution_2d.py index 54fbe571e4b2..96187ce9beb6 100644 --- a/benchmarks/benchmarks/functions/connection/convolution_2d.py +++ b/benchmarks/benchmarks/functions/connection/convolution_2d.py @@ -1,8 +1,7 @@ -import chainer import chainer.functions as F from benchmarks.functions import FunctionBenchmark -from benchmarks.utils import backends, config +from benchmarks.utils import backends @backends('gpu', 'gpu-cudnn', 'cpu', 'cpu-ideep') diff --git a/benchmarks/benchmarks/functions/math/sqrt.py b/benchmarks/benchmarks/functions/math/sqrt.py index 97d67eaab166..3df5cd56f86b 100644 --- a/benchmarks/benchmarks/functions/math/sqrt.py +++ b/benchmarks/benchmarks/functions/math/sqrt.py @@ -1,7 +1,7 @@ import chainer.functions as F -from benchmarks.utils import backends from benchmarks.functions import UnaryMathFunctionBenchmark +from benchmarks.utils import backends @backends('gpu', 'cpu') diff --git a/benchmarks/benchmarks/utils/__init__.py b/benchmarks/benchmarks/utils/__init__.py index b222ec6bb6f1..1e8d798c69a1 100644 --- a/benchmarks/benchmarks/utils/__init__.py +++ b/benchmarks/benchmarks/utils/__init__.py @@ -1,7 +1,7 @@ from benchmarks.utils.backend import backends # NOQA +from benchmarks.utils.backend import have_ideep # NOQA from benchmarks.utils.backend import is_backend_gpu # NOQA from benchmarks.utils.backend import is_backend_ideep # NOQA -from benchmarks.utils.backend import have_ideep # NOQA from benchmarks.utils.helper import parameterize # NOQA from benchmarks.utils.helper import sync # NOQA diff --git a/benchmarks/benchmarks/utils/backend.py b/benchmarks/benchmarks/utils/backend.py index 3c036cb62644..aeeb77edb0b6 100644 --- a/benchmarks/benchmarks/utils/backend.py +++ b/benchmarks/benchmarks/utils/backend.py @@ -1,6 +1,5 @@ from functools import wraps 
import inspect -import os import warnings import chainer From 155a2663ffb3c753b3a372cfb318f81f674f451d Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Mon, 12 Mar 2018 18:14:36 +0900 Subject: [PATCH 08/15] autopep8 --- benchmarks/benchmarks/convnet/nets/overfeat.py | 12 +++++------- benchmarks/benchmarks/convnet/nets/vgga.py | 6 ++---- benchmarks/benchmarks/utils/backend.py | 10 +++++----- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/benchmarks/benchmarks/convnet/nets/overfeat.py b/benchmarks/benchmarks/convnet/nets/overfeat.py index 757d14fe2f2e..7941cf9492d4 100644 --- a/benchmarks/benchmarks/convnet/nets/overfeat.py +++ b/benchmarks/benchmarks/convnet/nets/overfeat.py @@ -1,5 +1,3 @@ -# flake8: noqa - import chainer import chainer.functions as F import chainer.links as L @@ -11,11 +9,11 @@ class overfeat(chainer.Chain): def __init__(self): super(overfeat, self).__init__() with self.init_scope(): - self.conv1 = L.Convolution2D( 3, 96, 11, stride=4) - self.conv2 = L.Convolution2D( 96, 256, 5, pad=0) - self.conv3 = L.Convolution2D( 256, 512, 3, pad=1) - self.conv4 = L.Convolution2D( 512, 1024, 3, pad=1) - self.conv5 = L.Convolution2D(1024, 1024, 3, pad=1) + self.conv1 = L.Convolution2D(3, 96, 11, stride=4) + self.conv2 = L.Convolution2D(96, 256, 5, pad=0) + self.conv3 = L.Convolution2D(256, 512, 3, pad=1) + self.conv4 = L.Convolution2D(512, 1024, 3, pad=1) + self.conv5 = L.Convolution2D(1024, 1024, 3, pad=1) self.fc6 = L.Linear(1024 * 6 * 6, 3072) self.fc7 = L.Linear(3072, 4096) self.fc8 = L.Linear(4096, 1000) diff --git a/benchmarks/benchmarks/convnet/nets/vgga.py b/benchmarks/benchmarks/convnet/nets/vgga.py index 9c043236f4cc..3e9388ba515c 100644 --- a/benchmarks/benchmarks/convnet/nets/vgga.py +++ b/benchmarks/benchmarks/convnet/nets/vgga.py @@ -1,5 +1,3 @@ -# flake8: noqa - import chainer import chainer.functions as F import chainer.links as L @@ -11,8 +9,8 @@ class vgga(chainer.Chain): def __init__(self): super(vgga, self).__init__() 
with self.init_scope(): - self.conv1 = L.Convolution2D( 3, 64, 3, stride=1, pad=1) - self.conv2 = L.Convolution2D( 64, 128, 3, stride=1, pad=1) + self.conv1 = L.Convolution2D(3, 64, 3, stride=1, pad=1) + self.conv2 = L.Convolution2D(64, 128, 3, stride=1, pad=1) self.conv3 = L.Convolution2D(128, 256, 3, stride=1, pad=1) self.conv4 = L.Convolution2D(256, 256, 3, stride=1, pad=1) self.conv5 = L.Convolution2D(256, 512, 3, stride=1, pad=1) diff --git a/benchmarks/benchmarks/utils/backend.py b/benchmarks/benchmarks/utils/backend.py index aeeb77edb0b6..63268b788db9 100644 --- a/benchmarks/benchmarks/utils/backend.py +++ b/benchmarks/benchmarks/utils/backend.py @@ -113,11 +113,11 @@ def _wrapped_func(self, backend, *args, **kwargs): _benchmark_backend_ideep = True with _BackendConfig({ - 'use_cudnn': use_cudnn, - 'use_ideep': use_ideep, - '_benchmark_backend_gpu': _benchmark_backend_gpu, - '_benchmark_backend_ideep': _benchmark_backend_ideep, - }): + 'use_cudnn': use_cudnn, + 'use_ideep': use_ideep, + '_benchmark_backend_gpu': _benchmark_backend_gpu, + '_benchmark_backend_ideep': _benchmark_backend_ideep, + }): # Inject self.xp assert not hasattr(self, 'xp') From eec254f7f5d0af6ff547da0064ec05e82ce1257a Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Thu, 15 Mar 2018 18:28:29 +0900 Subject: [PATCH 09/15] add utility to find appropriate CuPy commit --- benchmarks/README.rst | 2 + benchmarks/find_cupy_version.py | 91 +++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) create mode 100755 benchmarks/find_cupy_version.py diff --git a/benchmarks/README.rst b/benchmarks/README.rst index 29ce1f46d33e..9a68cef5b408 100644 --- a/benchmarks/README.rst +++ b/benchmarks/README.rst @@ -22,6 +22,8 @@ Usage # Run benchmark against target commit-ish of Chainer and CuPy. # Note that specified versions must be a compatible combination. + # You can use `find_cupy_version.py` helper tool to get appropriate CuPy + # version for the given Chainer version. 
./run.sh master master ./run.sh v4.0.0b4 v4.0.0b4 diff --git a/benchmarks/find_cupy_version.py b/benchmarks/find_cupy_version.py new file mode 100755 index 000000000000..40e693c3a352 --- /dev/null +++ b/benchmarks/find_cupy_version.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python + +import argparse +import re +import sys +import subprocess + + +def _git(*args): + return subprocess.check_output(('git',) + args) + + +def get_cupy_commit_for(chainer_commit, cupy_branch, chainer_dir, cupy_dir): + """Returns CuPy commit required for the given Chainer commit.""" + + # Retrieve a commit time of the given Chainer commit. + commit_time = _git( + '-C', chainer_dir, + 'show', '--format=%ct', chainer_commit, '--') + + # Retrieve a CuPy commit made just before the commit_time, i.e., HEAD + # at the time of commit_time. + cupy_commit = _git( + '-C', cupy_dir, + 'log', '--merges', '--first-parent', '--max-count', '1', + '--until', commit_time, '--format=%H', cupy_branch, '--') + + return cupy_commit + + +def get_cupy_release_for(chainer_version): + """Returns CuPy version required for the given Chainer version.""" + + m = re.search(r'^v(\d)\.(.+)$', chainer_version) + if m is None: + raise ValueError(chainer_version) + + chainer_major = int(m.group(1)) + chainer_rest = m.group(2) + if chainer_major <= 1: + raise ValueError('Chainer v1 or earlier is unsupported') + elif 2 <= chainer_major <= 3: + # Chainer vN requires CuPy v(N-1). + return 'v{}.{}'.format((chainer_major - 1), chainer_rest) + else: + # The same versioning as Chainer. + return chainer_version + + +def parse_args(args): + parser = argparse.ArgumentParser() + + # Find CuPy commit from Chainer commit. 
+ parser.add_argument( + '--commit', type=str, default=None, + help='Chainer commit') + parser.add_argument( + '--cupy-branch', type=str, + help='CuPy branch to find commit') + parser.add_argument( + '--chainer', type=str, default='chainer', + help='Chainer source tree (default: chainer)') + parser.add_argument( + '--cupy', type=str, default='cupy', + help='CuPy source tree (default: cupy)') + + # Find CuPy version (tag) from Chainer version. + parser.add_argument( + '--release', type=str, default=None, + help='Chainer release version') + + return parser.parse_args(args) + + +def main(args): + params = parse_args(args[1:]) + if params.commit is not None: + assert params.cupy_branch is not None + assert params.release is None + print(get_cupy_commit_for( + params.commit, params.cupy_branch, params.chainer, params.cupy)) + elif params.release is not None: + print(get_cupy_release_for(params.release)) + else: + print('either --commit nor --release must be specified') + return 1 + return 0 + + +if __name__ == '__main__': + sys.exit(main(sys.argv)) From 27c853af9ca33cc2d8a42f04becd815282797129 Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Thu, 15 Mar 2018 18:28:43 +0900 Subject: [PATCH 10/15] improve performance of CuPy build --- benchmarks/run.sh | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/benchmarks/run.sh b/benchmarks/run.sh index 0ba482f788d2..244835ee6fd2 100755 --- a/benchmarks/run.sh +++ b/benchmarks/run.sh @@ -17,9 +17,14 @@ function run_asv() { # version used in the benchmark virtualenv. pushd cupy git remote update - git clean -fdx git checkout "$(git show --format="%H" ${CUPY_COMMIT})" - python setup.py build_ext --inplace + + # First try without git clean to use build cache as much as possible. + # If failed, rebuild it after git clean. 
+ BUILD_COMMAND="python setup.py build_ext --inplace" + ${BUILD_COMMAND} || ( git clean -fdx && ${BUILD_COMMAND} ) + python -c 'import cupy; import cupy.cudnn' || ( git clean -fdx && ${BUILD_COMMAND} ) + export PYTHONPATH="${PWD}:${PYTHONPATH:-}" popd From 46557542d731acc66e1a78d2c9faeff9dcee35da Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Thu, 15 Mar 2018 18:37:45 +0900 Subject: [PATCH 11/15] use CuPy for environment check --- benchmarks/benchmarks/__init__.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/benchmarks/benchmarks/__init__.py b/benchmarks/benchmarks/__init__.py index 5b33179d2911..efada9bc9bab 100644 --- a/benchmarks/benchmarks/__init__.py +++ b/benchmarks/benchmarks/__init__.py @@ -1,11 +1,8 @@ import inspect -import chainer - - # Ensure that CuPy and cuDNN are available. -assert chainer.cuda.available -assert chainer.cuda.cudnn_enabled +import cupy # NOQA +import cupy.cudnn # NOQA class BenchmarkBase(object): From 1e52ff3da8c4af13428eb3961fd40d1ec7141da8 Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Fri, 16 Mar 2018 12:54:43 +0900 Subject: [PATCH 12/15] add CuPy dependencies --- benchmarks/asv.conf.json | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index 06f53473e029..270b5d5bb19b 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -62,6 +62,10 @@ // "pip+emcee": [""], // emcee is only available for install with pip. // }, "matrix": { + // CuPy dependencies. + "numpy": [], + "six": [], + "fastrlock": [], // Optional dependencies required for benchmark. 
"ideep4py": [], }, From c4658c851743435977a0dacd7e885428946b9b98 Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Fri, 16 Mar 2018 14:20:07 +0900 Subject: [PATCH 13/15] extend timeout for benchmarks --- benchmarks/benchmarks/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/benchmarks/benchmarks/__init__.py b/benchmarks/benchmarks/__init__.py index efada9bc9bab..a009da036b60 100644 --- a/benchmarks/benchmarks/__init__.py +++ b/benchmarks/benchmarks/__init__.py @@ -11,6 +11,9 @@ class BenchmarkBase(object): See also: http://asv.readthedocs.io/en/v0.2.1/writing_benchmarks.html """ + # Allow up to 10 minutes, instead of the default (60 seconds). + timeout = 600 + def __init__(self, *args, **kwargs): # Set pretty_name to ``.`` instead of the default # ``..``. This is because it is often too From f04c3b0e6c1a09f3b8e986637eb206e90cfcb30b Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Thu, 15 Mar 2018 19:45:36 +0900 Subject: [PATCH 14/15] add Dockerfile --- benchmarks/README.rst | 13 +++++++++++++ benchmarks/docker/Dockerfile | 7 +++++++ 2 files changed, 20 insertions(+) create mode 100644 benchmarks/docker/Dockerfile diff --git a/benchmarks/README.rst b/benchmarks/README.rst index 9a68cef5b408..adcbb5fbe0de 100644 --- a/benchmarks/README.rst +++ b/benchmarks/README.rst @@ -38,3 +38,16 @@ Usage # Start the HTTP server to browse HTML. asv preview + +Alternatively you can use Docker. + +.. code-block:: sh + + # Build docker image for benchmark. + docker build -t chainer-benchmark docker + + # Create a machine configuration file (`.asv-machine.json`) in this directory (first time only). + nvidia-docker run --rm -it -u ${UID}:${GID} -v ${PWD}:/benchmarks -w /benchmarks -e HOME=/benchmarks chainer-benchmark asv machine --machine $(hostname) + + # Run benchmark. 
+ nvidia-docker run --rm -it -u ${UID}:${GID} -v ${PWD}:/benchmarks -w /benchmarks -e HOME=/benchmarks chainer-benchmark ./run.sh master master --machine $(hostname) diff --git a/benchmarks/docker/Dockerfile b/benchmarks/docker/Dockerfile new file mode 100644 index 000000000000..c1e833df229c --- /dev/null +++ b/benchmarks/docker/Dockerfile @@ -0,0 +1,7 @@ +FROM chainer/chainer:latest + +RUN apt-get update -y && \ + apt-get install -y --no-install-recommends git && \ + rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* + +RUN pip install asv virtualenv Cython From 3ca056a1df4b820103f07752e8c911e92a3bba80 Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Fri, 16 Mar 2018 12:54:59 +0900 Subject: [PATCH 15/15] add .asv-machine.json to gitignore for docker run --- benchmarks/.gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/.gitignore b/benchmarks/.gitignore index 491fb7d04081..6894dcc3e502 100644 --- a/benchmarks/.gitignore +++ b/benchmarks/.gitignore @@ -3,3 +3,4 @@ results/ env/ chainer/ cupy/ +.asv-machine.json