Merge pull request janelia-flyem#82 from jni/speedup
Speed up RAG building and agglomeration:
- 30x speedup in RAG building
- 5x speedup in flat learning
- 6x speedup in agglomerative learning
- 7x speedup in test segmentation
- 30% reduction in RAM usage by the RAG.

Key changes:
- New fast RAG building method based on dilation and erosion of labels (a sketch of the idea follows this list).
- New row-expandable CSR format for the contingency matrix.
- New sparselol-based boundary extents.
- New fast testing framework.
- New benchmarking file.
- Don't update edges when they are unchanged by a merge.
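
The first item is the headline change. Below is a minimal sketch of the dilation/erosion idea, as an illustration only: this is not gala's actual implementation, and the function name and the 3-voxel neighborhood size are assumptions.

import numpy as np
from scipy import ndimage as ndi

def rag_edges(labels):
    # Grayscale erosion/dilation give the minimum/maximum label in each
    # voxel's neighborhood; where the two disagree, the neighborhood spans
    # two regions, and the (min, max) label pair is an adjacency for the RAG.
    lo = ndi.grey_erosion(labels, size=3)
    hi = ndi.grey_dilation(labels, size=3)
    boundary = lo != hi
    return set(zip(lo[boundary], hi[boundary]))

This simple version records only the (min, max) pair at each boundary voxel, so junctions where three or more labels meet would need extra passes; it is here only to show how whole-volume morphology can replace per-voxel iteration.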
jni committed on Jun 27, 2016
commit a55ceaa (2 parents: ebc25b7 + a1c5261)
Showing 22 changed files with 1,145 additions and 172 deletions.
.travis.yml (1 addition, 1 deletion)
@@ -6,7 +6,7 @@ virtualenv:
   system_site_packages: false
 install:
   # all installing is now handled by conda as it is faster and more robust
-  - wget http://repo.continuum.io/miniconda/Miniconda-3.4.2-Linux-x86_64.sh -O miniconda.sh;
+  - wget http://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh;
   - bash miniconda.sh -b -p $HOME/miniconda
   - export PATH="$HOME/miniconda/bin:$PATH"
   - hash -r
asv.conf.json (new file, 134 additions)
@@ -0,0 +1,134 @@
{
    // The version of the config file format. Do not change, unless
    // you know what you are doing.
    "version": 1,

    // The name of the project being benchmarked
    "project": "gala",

    // The project's homepage
    "project_url": "http://gala.readthedocs.io/",

    // The URL or local path of the source code repository for the
    // project being benchmarked
    "repo": ".",

    // List of branches to benchmark. If not provided, defaults to "master"
    // (for git) or "tip" (for mercurial).
    "branches": ["master", "speedup"], // for git
    // "branches": ["tip"], // for mercurial

    // The DVCS being used. If not set, it will be automatically
    // determined from "repo" by looking at the protocol in the URL
    // (if remote), or by looking for special directories, such as
    // ".git" (if local).
    // "dvcs": "git",

    // The tool to use to create environments. May be "conda",
    // "virtualenv" or other value depending on the plugins in use.
    // If missing or the empty string, the tool will be automatically
    // determined by looking for tools on the PATH environment
    // variable.
    "environment_type": "conda",

    // the base URL to show a commit for the project.
    "show_commit_url": "http://github.com/jni/gala/commit/",

    // The Pythons you'd like to test against. If not provided, defaults
    // to the current version of Python used to run `asv`.
    // "pythons": ["2.7", "3.3"],

    // The matrix of dependencies to test. Each key is the name of a
    // package (in PyPI) and the values are version numbers. An empty
    // list or empty string indicates to just test against the default
    // (latest) version. null indicates that the package is to not be
    // installed.
    //
    "matrix": {
        "environment_type": "conda",
        "python": "3.5",
        "numpy": [],
        "scipy": [],
        "pip+numpydoc": [],
        "networkx": [],
        "h5py": [],
        "cython": [],
        "pip+viridis": [],
        "pyzmq": [],
        "scikit-learn": [],
        "scikit-image": [],
        "pytest": [],
        "setuptools": [],
        "coverage": []
    },

    // Combinations of libraries/python versions can be excluded/included
    // from the set to test. Each entry is a dictionary containing additional
    // key-value pairs to include/exclude.
    //
    // An exclude entry excludes entries where all values match. The
    // values are regexps that should match the whole string.
    //
    // An include entry adds an environment. Only the packages listed
    // are installed. The 'python' key is required. The exclude rules
    // do not apply to includes.
    //
    // In addition to package names, the following keys are available:
    //
    // - python
    //     Python version, as in the *pythons* variable above.
    // - environment_type
    //     Environment type, as above.
    // - sys_platform
    //     Platform, as in sys.platform. Possible values for the common
    //     cases: 'linux2', 'win32', 'cygwin', 'darwin'.
    //
    // "exclude": [
    //     {"python": "3.2", "sys_platform": "win32"}, // skip py3.2 on windows
    //     {"environment_type": "conda", "six": null}, // don't run without six on conda
    // ],
    //
    // "include": [
    //     // additional env for python2.7
    //     {"python": "2.7", "numpy": "1.8"},
    //     // additional env if run on windows+conda
    //     {"platform": "win32", "environment_type": "conda", "python": "2.7", "libpython": ""},
    // ],

    // The directory (relative to the current directory) that benchmarks are
    // stored in. If not provided, defaults to "benchmarks"
    // "benchmark_dir": "benchmarks",

    // The directory (relative to the current directory) to cache the Python
    // environments in. If not provided, defaults to "env"
    "env_dir": ".asv/env",

    // The directory (relative to the current directory) that raw benchmark
    // results are stored in. If not provided, defaults to "results".
    "results_dir": ".asv/results",

    // The directory (relative to the current directory) that the html tree
    // should be written to. If not provided, defaults to "html".
    "html_dir": ".asv/html"

    // The number of characters to retain in the commit hashes.
    // "hash_length": 8,

    // `asv` will cache wheels of the recent builds in each
    // environment, making them faster to install next time. This is
    // number of builds to keep, per environment.
    // "wheel_cache_size": 0

    // The commits after which the regression search in `asv publish`
    // should start looking for regressions. Dictionary whose keys are
    // regexps matching to benchmark names, and values corresponding to
    // the commit (exclusive) after which to start looking for
    // regressions. The default is to start from the first commit
    // with results. If the commit is `null`, regression detection is
    // skipped for the matching benchmark.
    //
    // "regressions_first_commits": {
    //     "some_benchmark": "352cdf", // Consider regressions only after this commit
    //     "another_benchmark": null, // Skip regression detection altogether
    // }
}
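
For orientation: `asv` discovers benchmarks in `benchmark_dir` by importing its files and collecting functions whose names start with `time_` (timed) or `mem_`/`peakmem_` (memory), running a module-level `setup()` first if one exists. A hypothetical minimal benchmark file for this config follows; the toy data and names are illustrative, not part of this commit.

import numpy as np
from gala import agglo

ws = pr = None

def setup():
    # asv calls this before each benchmark in the module.
    global ws, pr
    rng = np.random.RandomState(0)
    ws = rng.randint(1, 9, size=(20, 20, 20))  # toy label volume
    pr = rng.random_sample((20, 20, 20))       # toy boundary probability map

def time_build_rag():
    # asv reports the time taken by this call.
    agglo.Rag(ws, pr)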
benchmarks/bench_gala.py (new file, 136 additions)
@@ -0,0 +1,136 @@
import os

from contextlib import contextmanager
from collections import OrderedDict

import numpy as np

from gala import imio, features, agglo, classify
from asv.extern.asizeof import asizeof


rundir = os.path.dirname(__file__)
## dd: the data directory
dd = os.path.abspath(os.path.join(rundir, '../tests/example-data'))


from time import process_time


@contextmanager
def timer():
    # Context manager measuring the CPU time spent in its ``with`` block.
    time = []
    t0 = process_time()
    yield time
    t1 = process_time()
    time.append(t1 - t0)  # available only after the block exits
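# Hypothetical usage (not part of this commit): the elapsed CPU time is
# read from the yielded list only after the ``with`` block has exited:
#
#     with timer() as t:
#         do_expensive_work()
#     print(t[0])  # CPU seconds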


em = features.default.paper_em()


def trdata():
    wstr = imio.read_h5_stack(os.path.join(dd, 'train-ws.lzf.h5'))
    prtr = imio.read_h5_stack(os.path.join(dd, 'train-p1.lzf.h5'))
    gttr = imio.read_h5_stack(os.path.join(dd, 'train-gt.lzf.h5'))
    return wstr, prtr, gttr


def tsdata():
    wsts = imio.read_h5_stack(os.path.join(dd, 'test-ws.lzf.h5'))
    prts = imio.read_h5_stack(os.path.join(dd, 'test-p1.lzf.h5'))
    gtts = imio.read_h5_stack(os.path.join(dd, 'test-gt.lzf.h5'))
    return wsts, prts, gtts


def trgraph():
    ws, pr, ts = trdata()
    g = agglo.Rag(ws, pr)
    return g


def tsgraph():
    ws, pr, ts = tsdata()
    g = agglo.Rag(ws, pr, feature_manager=em)
    return g


def trexamples():
    gt = imio.read_h5_stack(os.path.join(dd, 'train-gt.lzf.h5'))
    g = trgraph()
    (X, y, w, e), _ = g.learn_agglomerate(gt, em, min_num_epochs=5)
    y = y[:, 0]
    return X, y


def classifier():
    X, y = trexamples()
    rf = classify.DefaultRandomForest()
    rf.fit(X, y)
    return rf


def policy():
    rf = classify.DefaultRandomForest()
    cl = agglo.classifier_probability(em, rf)
    return cl


def tsgraph_queue():
    g = tsgraph()
    cl = policy()
    g.merge_priority_function = cl
    g.rebuild_merge_queue()
    return g

def bench_suite():
    times = OrderedDict()
    memory = OrderedDict()
    wstr, prtr, gttr = trdata()
    with timer() as t_build_rag:
        g = agglo.Rag(wstr, prtr)
    times['build RAG'] = t_build_rag[0]
    memory['base RAG'] = asizeof(g)
    with timer() as t_features:
        g.set_feature_manager(em)
    times['build feature caches'] = t_features[0]
    memory['feature caches'] = asizeof(g) - memory['base RAG']
    with timer() as t_flat:
        _ignore = g.learn_flat(gttr, em)
    times['learn flat'] = t_flat[0]
    with timer() as t_gala:
        (X, y, w, e), allepochs = g.learn_agglomerate(gttr, em,
                                                      min_num_epochs=5)
        y = y[:, 0]  # ignore rand-sign and vi-sign schemes
    memory['training data'] = asizeof((X, y, w, e))
    times['learn agglo'] = t_gala[0]
    with timer() as t_train_classifier:
        cl = classify.DefaultRandomForest()
        cl.fit(X, y)
    times['classifier training'] = t_train_classifier[0]
    memory['classifier training'] = asizeof(cl)
    policy = agglo.classifier_probability(em, cl)
    wsts, prts, gtts = tsdata()
    gtest = agglo.Rag(wsts, prts, merge_priority_function=policy,
                      feature_manager=em)
    with timer() as t_segment:
        gtest.agglomerate(np.inf)
    times['segment test volume'] = t_segment[0]
    memory['segment test volume'] = asizeof(gtest)
    return times, memory


def print_bench_results(times=None, memory=None):
    if times is not None:
        print('Timing results:')
        for key in times:
            print('--- ', key, times[key])
    if memory is not None:
        print('Memory results:')
        for key in memory:
            print('--- ', key, '%.3f MB' % (memory[key] / 1e6))


if __name__ == '__main__':
    times, memory = bench_suite()
    print_bench_results(times, memory)
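
The `__main__` block above runs the whole suite; the same can be driven from an interactive session, assuming the repository root is on `sys.path` (a usage sketch, not part of the commit):

from benchmarks.bench_gala import bench_suite, print_bench_results

times, memory = bench_suite()
print_bench_results(times, memory)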