forked from janelia-flyem/gala
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request janelia-flyem#82 from jni/speedup
Speed up RAG building and agglomeration: - 30x speedup in RAG building - 5x speedup in flat learning - 6x speedup in agglomerative learning - 7x speedup in test segmentation - And 30% reduction in RAM usage by the RAG. Key changes: - New fast RAG building method based on dilation and erosion of labels. - New row-expandable CSR format for the contingency matrix. - New sparselol-based boundary extents. - New fast testing framework. - New benchmarking file. - Don't update edges when they are unchanged by a merge.
- Loading branch information
Showing
22 changed files
with
1,145 additions
and
172 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
{ | ||
// The version of the config file format. Do not change, unless | ||
// you know what you are doing. | ||
"version": 1, | ||
|
||
// The name of the project being benchmarked | ||
"project": "gala", | ||
|
||
// The project's homepage | ||
"project_url": "http://gala.readthedocs.io/", | ||
|
||
// The URL or local path of the source code repository for the | ||
// project being benchmarked | ||
"repo": ".", | ||
|
||
// List of branches to benchmark. If not provided, defaults to "master" | ||
// (for git) or "tip" (for mercurial). | ||
"branches": ["master", "speedup"], // for git | ||
// "branches": ["tip"], // for mercurial | ||
|
||
// The DVCS being used. If not set, it will be automatically | ||
// determined from "repo" by looking at the protocol in the URL | ||
// (if remote), or by looking for special directories, such as | ||
// ".git" (if local). | ||
// "dvcs": "git", | ||
|
||
// The tool to use to create environments. May be "conda", | ||
// "virtualenv" or other value depending on the plugins in use. | ||
// If missing or the empty string, the tool will be automatically | ||
// determined by looking for tools on the PATH environment | ||
// variable. | ||
"environment_type": "conda", | ||
|
||
// the base URL to show a commit for the project. | ||
"show_commit_url": "http://github.com/jni/gala/commit/", | ||
|
||
// The Pythons you'd like to test against. If not provided, defaults | ||
// to the current version of Python used to run `asv`. | ||
// "pythons": ["2.7", "3.3"], | ||
|
||
// The matrix of dependencies to test. Each key is the name of a | ||
// package (in PyPI) and the values are version numbers. An empty | ||
// list or empty string indicates to just test against the default | ||
// (latest) version. null indicates that the package is to not be | ||
// installed. | ||
// | ||
"matrix": { | ||
"environment_type": "conda", | ||
"python": "3.5", | ||
"numpy": [], | ||
"scipy": [], | ||
"pip+numpydoc": [], | ||
"networkx": [], | ||
"h5py": [], | ||
"cython": [], | ||
"pip+viridis": [], | ||
"pyzmq": [], | ||
"scikit-learn": [], | ||
"scikit-image": [], | ||
"pytest": [], | ||
"setuptools": [], | ||
"coverage": [] | ||
}, | ||
|
||
// Combinations of libraries/python versions can be excluded/included | ||
// from the set to test. Each entry is a dictionary containing additional | ||
// key-value pairs to include/exclude. | ||
// | ||
// An exclude entry excludes entries where all values match. The | ||
// values are regexps that should match the whole string. | ||
// | ||
// An include entry adds an environment. Only the packages listed | ||
// are installed. The 'python' key is required. The exclude rules | ||
// do not apply to includes. | ||
// | ||
// In addition to package names, the following keys are available: | ||
// | ||
// - python | ||
// Python version, as in the *pythons* variable above. | ||
// - environment_type | ||
// Environment type, as above. | ||
// - sys_platform | ||
// Platform, as in sys.platform. Possible values for the common | ||
// cases: 'linux2', 'win32', 'cygwin', 'darwin'. | ||
// | ||
// "exclude": [ | ||
// {"python": "3.2", "sys_platform": "win32"}, // skip py3.2 on windows | ||
// {"environment_type": "conda", "six": null}, // don't run without six on conda | ||
// ], | ||
// | ||
// "include": [ | ||
// // additional env for python2.7 | ||
// {"python": "2.7", "numpy": "1.8"}, | ||
// // additional env if run on windows+conda | ||
// {"platform": "win32", "environment_type": "conda", "python": "2.7", "libpython": ""}, | ||
// ], | ||
|
||
// The directory (relative to the current directory) that benchmarks are | ||
// stored in. If not provided, defaults to "benchmarks" | ||
// "benchmark_dir": "benchmarks", | ||
|
||
// The directory (relative to the current directory) to cache the Python | ||
// environments in. If not provided, defaults to "env" | ||
"env_dir": ".asv/env", | ||
|
||
// The directory (relative to the current directory) that raw benchmark | ||
// results are stored in. If not provided, defaults to "results". | ||
"results_dir": ".asv/results", | ||
|
||
// The directory (relative to the current directory) that the html tree | ||
// should be written to. If not provided, defaults to "html". | ||
"html_dir": ".asv/html" | ||
|
||
// The number of characters to retain in the commit hashes. | ||
// "hash_length": 8, | ||
|
||
// `asv` will cache wheels of the recent builds in each | ||
// environment, making them faster to install next time. This is | ||
// number of builds to keep, per environment. | ||
// "wheel_cache_size": 0 | ||
|
||
// The commits after which the regression search in `asv publish` | ||
// should start looking for regressions. Dictionary whose keys are | ||
// regexps matching to benchmark names, and values corresponding to | ||
// the commit (exclusive) after which to start looking for | ||
// regressions. The default is to start from the first commit | ||
// with results. If the commit is `null`, regression detection is | ||
// skipped for the matching benchmark. | ||
// | ||
// "regressions_first_commits": { | ||
// "some_benchmark": "352cdf", // Consider regressions only after this commit | ||
// "another_benchmark": null, // Skip regression detection altogether | ||
// } | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
import os | ||
|
||
from contextlib import contextmanager | ||
from collections import OrderedDict | ||
|
||
import numpy as np | ||
|
||
from gala import imio, features, agglo, classify | ||
from asv.extern.asizeof import asizeof | ||
|
||
|
||
rundir = os.path.dirname(__file__) | ||
## dd: the data directory | ||
dd = os.path.abspath(os.path.join(rundir, '../tests/example-data')) | ||
|
||
|
||
from time import process_time | ||
|
||
|
||
@contextmanager | ||
def timer(): | ||
time = [] | ||
t0 = process_time() | ||
yield time | ||
t1 = process_time() | ||
time.append(t1 - t0) | ||
|
||
|
||
em = features.default.paper_em() | ||
|
||
|
||
def trdata(): | ||
wstr = imio.read_h5_stack(os.path.join(dd, 'train-ws.lzf.h5')) | ||
prtr = imio.read_h5_stack(os.path.join(dd, 'train-p1.lzf.h5')) | ||
gttr = imio.read_h5_stack(os.path.join(dd, 'train-gt.lzf.h5')) | ||
return wstr, prtr, gttr | ||
|
||
|
||
def tsdata(): | ||
wsts = imio.read_h5_stack(os.path.join(dd, 'test-ws.lzf.h5')) | ||
prts = imio.read_h5_stack(os.path.join(dd, 'test-p1.lzf.h5')) | ||
gtts = imio.read_h5_stack(os.path.join(dd, 'test-gt.lzf.h5')) | ||
return wsts, prts, gtts | ||
|
||
|
||
def trgraph(): | ||
ws, pr, ts = trdata() | ||
g = agglo.Rag(ws, pr) | ||
return g | ||
|
||
|
||
def tsgraph(): | ||
ws, pr, ts = tsdata() | ||
g = agglo.Rag(ws, pr, feature_manager=em) | ||
return g | ||
|
||
|
||
def trexamples(): | ||
gt = imio.read_h5_stack(os.path.join(dd, 'train-gt.lzf.h5')) | ||
g = trgraph() | ||
(X, y, w, e), _ = g.learn_agglomerate(gt, em, min_num_epochs=5) | ||
y = y[:, 0] | ||
return X, y | ||
|
||
|
||
def classifier(): | ||
X, y = trexamples() | ||
rf = classify.DefaultRandomForest() | ||
rf.fit(X, y) | ||
return rf | ||
|
||
|
||
def policy(): | ||
rf = classify.DefaultRandomForest() | ||
cl = agglo.classifier_probability(em, rf) | ||
return cl | ||
|
||
|
||
def tsgraph_queue(): | ||
g = tsgraph() | ||
cl = policy() | ||
g.merge_priority_function = cl | ||
g.rebuild_merge_queue() | ||
return g | ||
|
||
def bench_suite(): | ||
times = OrderedDict() | ||
memory = OrderedDict() | ||
wstr, prtr, gttr = trdata() | ||
with timer() as t_build_rag: | ||
g = agglo.Rag(wstr, prtr) | ||
times['build RAG'] = t_build_rag[0] | ||
memory['base RAG'] = asizeof(g) | ||
with timer() as t_features: | ||
g.set_feature_manager(em) | ||
times['build feature caches'] = t_features[0] | ||
memory['feature caches'] = asizeof(g) - memory['base RAG'] | ||
with timer() as t_flat: | ||
_ignore = g.learn_flat(gttr, em) | ||
times['learn flat'] = t_flat[0] | ||
with timer() as t_gala: | ||
(X, y, w, e), allepochs = g.learn_agglomerate(gttr, em, | ||
min_num_epochs=5) | ||
y = y[:, 0] # ignore rand-sign and vi-sign schemes | ||
memory['training data'] = asizeof((X, y, w, e)) | ||
times['learn agglo'] = t_gala[0] | ||
with timer() as t_train_classifier: | ||
cl = classify.DefaultRandomForest() | ||
cl.fit(X, y) | ||
times['classifier training'] = t_train_classifier[0] | ||
memory['classifier training'] = asizeof(cl) | ||
policy = agglo.classifier_probability(em, cl) | ||
wsts, prts, gtts = tsdata() | ||
gtest = agglo.Rag(wsts, prts, merge_priority_function=policy, | ||
feature_manager=em) | ||
with timer() as t_segment: | ||
gtest.agglomerate(np.inf) | ||
times['segment test volume'] = t_segment[0] | ||
memory['segment test volume'] = asizeof(gtest) | ||
return times, memory | ||
|
||
|
||
def print_bench_results(times=None, memory=None): | ||
if times is not None: | ||
print('Timing results:') | ||
for key in times: | ||
print('--- ', key, times[key]) | ||
if memory is not None: | ||
print('Memory results:') | ||
for key in memory: | ||
print('--- ', key, '%.3f MB' % (memory[key] / 1e6)) | ||
|
||
|
||
if __name__ == '__main__': | ||
times, memory = bench_suite() | ||
print_bench_results(times, memory) |
Oops, something went wrong.