Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added src and process_data folders containing python scripts for repr…
…oducing experiments in the Mondrian Forests paper
- Loading branch information
Showing
12 changed files
with
2,101 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
------------------------------------------------------------------------------- | ||
The standard MIT License for code in this archive written by Balaji Lakshminarayanan | ||
http://www.opensource.org/licenses/mit-license.php | ||
------------------------------------------------------------------------------- | ||
Copyright (c) 2014 Balaji Lakshminarayanan | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to | ||
deal in the Software without restriction, including without limitation the | ||
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in | ||
all copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING | ||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS | ||
IN THE SOFTWARE. | ||
------------------------------------------------------------------------------- | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,71 @@ | ||
mondrianforest | ||
============== | ||
This folder contains the scripts used in the following paper: | ||
|
||
Code for "Mondrian Forests: Efficient Online Random Forests" | ||
**Mondrian Forests: Efficient Online Random Forests** | ||
|
||
Balaji Lakshminarayanan, Daniel M. Roy, Yee Whye Teh | ||
|
||
[http://arxiv.org/abs/1406.2673](http://arxiv.org/abs/1406.2673) | ||
|
||
Please cite the above paper if you use this code. | ||
|
||
|
||
|
||
I ran my experiments using Enthought python (which includes all the necessary python packages). | ||
If you are running a different version of python, you will need the following python packages | ||
(and possibly other packages) to run the scripts: | ||
|
||
* numpy | ||
* scipy | ||
* matplotlib (for plotting Mondrian partitions) | ||
* pydot and graphviz (for printing Mondrian trees) | ||
* sklearn (for reading libsvm format files) | ||
|
||
|
||
The datasets are not included here; you need to download them from the UCI repository. You can run | ||
experiments using toy data though. Run **commands.sh** in **process_data** folder for automatically | ||
downloading and processing the datasets. I have tested these scripts only on Ubuntu, but it should be straightforward to process datasets in other platforms. | ||
|
||
If you have any questions/comments/suggestions, please contact me at | ||
[balaji@gatsby.ucl.ac.uk](mailto:balaji@gatsby.ucl.ac.uk). | ||
|
||
Code released under MIT license (see COPYING for more info). | ||
|
||
Copyright © 2014 Balaji Lakshminarayanan | ||
|
||
---------------------------------------------------------------------------- | ||
|
||
**List of scripts in the src folder**: | ||
|
||
- mondrianforest.py | ||
- mondrianforest_utils.py | ||
- utils.py | ||
|
||
Help on usage can be obtained by typing the following commands on the terminal: | ||
|
||
./mondrianforest.py -h | ||
|
||
**Example usage**: | ||
|
||
./mondrianforest.py --dataset toy-mf --n_mondrians 100 --budget -1 --normalize_features 1 | ||
|
||
**Examples that draw the Mondrian partition and Mondrian tree**: | ||
|
||
./mondrianforest.py --draw_mondrian 1 --save 1 --n_mondrians 10 --dataset toy-mf --store_every 1 --n_mini 6 --tag demo | ||
./mondrianforest.py --draw_mondrian 1 --save 1 --n_mondrians 1 --dataset toy-mf --store_every 1 --n_mini 6 --tag demo | ||
|
||
**Example on a real-world dataset**: | ||
|
||
*assuming you have successfully run commands.sh in process_data folder* | ||
|
||
./mondrianforest.py --dataset satimage --n_mondrians 100 --budget -1 --normalize_features 1 --save 1 --data_path ../process_data/ --n_minibatches 10 --store_every 1 | ||
|
||
---------------------------------------------------------------------------- | ||
|
||
I generated commands for parameter sweeps using 'build_cmds' script by Jan Gasthaus, available publicly at [https://github.com/jgasthaus/Gitsby/tree/master/pbs/python](https://github.com/jgasthaus/Gitsby/tree/master/pbs/python). | ||
|
||
Some examples of parameter sweeps are: | ||
|
||
./build_cmds ./mondrianforest.py "--op_dir={results}" "--init_id=1:1:6" "--dataset={letter,satimage,usps,dna,dna-61-120}" "--n_mondrians={100}" "--save={1}" "--discount_factor={10.0}" "--budget={-1}" "--n_minibatches={100}" "--bagging={0}" "--store_every={1}" "--normalize_features={1}" "--data_path={../process_data/}" >> run | ||
|
||
Note that the results (predictions, accuracy, log predictive probability on training/test data, runtimes) are stored in the pickle files. | ||
You need to write additional scripts to aggregate the results from these pickle files and generate the plots. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#!/usr/bin/env bash | ||
# script to download datasets for pre-process them | ||
|
||
# download datasets | ||
for dir in 'usps' 'dna' 'satimage' 'letter'; do | ||
echo ${dir} | ||
cd ${dir} | ||
./commands.sh | ||
cd .. | ||
done | ||
|
||
# process libsvm datasets | ||
./process_libsvm_datasets.py usps usps usps.t | ||
./process_libsvm_datasets.py dna dna.scale.tr dna.scale.t # ignoring validation files | ||
./process_libsvm_datasets.py satimage satimage.scale.tr satimage.scale.t # ignoring validation files | ||
# ./process_libsvm_datasets.py letter letter.scale.tr letter.scale.t # ignoring validation files | ||
./process_libsvm_datasets.py letter letter.scale.tr letter.scale.t letter.scale.val # adding validation to training (to make comparisons with OnlineRF by Saffari et al.) | ||
cd dna-61-120 | ||
./process_dna-61-120.py | ||
cd .. | ||
./convert_pickle_2_onlinerf_format.py dna-61-120 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#!/usr/bin/env python | ||
# input dataset is pickle file (to retain same train/test split) | ||
# example: ./convert_pickle_2_onlinerf_format.py dna-61-120 | ||
# | ||
# Output is slightly different from LIBSVM format | ||
# - need to add header of #Samples #Features #Classes #FeatureMinIndex to the files | ||
# - class indices need to start from 0 | ||
|
||
import sys | ||
import os | ||
import cPickle as pickle | ||
from itertools import izip | ||
|
||
name = sys.argv[1] | ||
|
||
data = pickle.load(open(name + '/' + name + '.p', 'rb')) | ||
|
||
d_n_dim = {'magic04': 10, 'pendigits': 16, 'dna-61-120': 60} | ||
d_n_class = {'magic04': 2, 'pendigits': 10, 'dna-61-120': 3} | ||
|
||
feat_id_start = 1 | ||
|
||
def print_file(x, y, name, op_name): | ||
op = open(name + '/' + op_name, 'w') | ||
print>>op, '%s %s %s %s' % (len(y), d_n_dim[name], d_n_class[name], feat_id_start) | ||
for x_, y_ in izip(x, y): | ||
s = ' '.join(['%d:%f' % (i+1, x__) for i, x__ in enumerate(x_)]) | ||
print>>op, '%s %s' % (y_, s) | ||
|
||
print_file(data['x_train'], data['y_train'], name, name + '.orf.train') | ||
print_file(data['x_test'], data['y_test'], name, name + '.orf.test') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#!/usr/bin/env bash | ||
|
||
wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/dna.scale.tr | ||
wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/dna.scale.t | ||
wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/dna.scale.val |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#!/usr/bin/env bash | ||
|
||
wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/letter.scale.t | ||
wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/letter.scale.val | ||
wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/letter.scale.tr |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
#!/usr/bin/env python | ||
# usage: ./process_libsvm_datasets.py name filename_train filename_test filename_val | ||
# name is the folder name as well as the pickled filename | ||
# filename_train and filename_test are assumed to exist in the folder name | ||
# filename_val is optional (validation dataset will be added to training data if supplied) | ||
|
||
import sys | ||
import numpy as np | ||
import cPickle as pickle | ||
from sklearn.datasets import load_svmlight_file | ||
|
||
def get_x_y(filename, name): | ||
(x, y) = load_svmlight_file(filename) | ||
x = x.toarray() | ||
if name == 'pendigits-svm': | ||
x = x[:, 1:] | ||
y = y.astype('int') | ||
if name != 'pendigits-svm': | ||
y -= 1 | ||
return (x, y) | ||
|
||
name, filename_train, filename_test = sys.argv[1:4] | ||
|
||
data = {} | ||
|
||
x, y = get_x_y(name + '/' + filename_train, name) | ||
data['x_train'] = x | ||
data['y_train'] = y | ||
if len(sys.argv) > 4: | ||
filename_val = sys.argv[4] | ||
x, y = get_x_y(name + '/' + filename_val, name) | ||
data['x_train'] = np.vstack((data['x_train'], x)) | ||
data['y_train'] = np.append(data['y_train'], y) | ||
data['n_train'] = data['x_train'].shape[0] | ||
assert len(data['y_train']) == data['n_train'] | ||
|
||
x, y = get_x_y(name + '/' + filename_test, name) | ||
data['x_test'] = x | ||
data['n_test'] = x.shape[0] | ||
data['y_test'] = y | ||
|
||
data['n_dim'] = x.shape[1] | ||
data['n_class'] = len(np.unique(y)) | ||
try: | ||
assert data['n_class'] == max(np.unique(y)) + 1 | ||
except AssertionError: | ||
print 'np.unique(y) = %s' % np.unique(y) | ||
raise AssertionError | ||
data['is_sparse'] = False | ||
|
||
print 'name = %10s, n_dim = %5d, n_class = %5d, n_train = %5d, n_test = %5d' \ | ||
% (name, data['n_dim'], data['n_class'], data['n_train'], data['n_test']) | ||
|
||
pickle.dump(data, open(name + '/' + name + ".p", "wb"), protocol=pickle.HIGHEST_PROTOCOL) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#!/usr/bin/env bash | ||
|
||
wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/satimage.scale.t | ||
wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/satimage.scale.tr | ||
wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/satimage.scale.val |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/usr/bin/env bash | ||
|
||
wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2 | ||
wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2 | ||
bzip2 -d usps.bz2 | ||
bzip2 -d usps.t.bz2 |
Oops, something went wrong.