Commit

added src and process_data folders containing python scripts for reproducing experiments in the Mondrian Forests paper

balajiln committed Jul 14, 2014
1 parent f2b68d8 commit 846a17f
Showing 12 changed files with 2,101 additions and 3 deletions.
25 changes: 25 additions & 0 deletions COPYING
@@ -0,0 +1,25 @@
-------------------------------------------------------------------------------
The standard MIT License for code in this archive written by Balaji Lakshminarayanan
http://www.opensource.org/licenses/mit-license.php
-------------------------------------------------------------------------------
Copyright (c) 2014 Balaji Lakshminarayanan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to
deal in the Software without restriction, including without limitation the
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
IN THE SOFTWARE.
-------------------------------------------------------------------------------

73 changes: 70 additions & 3 deletions README.md
@@ -1,4 +1,71 @@
mondrianforest
==============
This folder contains the scripts used in the following paper:

Code for "Mondrian Forests: Efficient Online Random Forests"
**Mondrian Forests: Efficient Online Random Forests**

Balaji Lakshminarayanan, Daniel M. Roy, Yee Whye Teh

[http://arxiv.org/abs/1406.2673](http://arxiv.org/abs/1406.2673)

Please cite the above paper if you use this code.



I ran my experiments using Enthought Python (which includes all the necessary packages).
If you are running a different Python distribution, you will need at least the following
packages to run the scripts (an example install command is given after the list):

* numpy
* scipy
* matplotlib (for plotting Mondrian partitions)
* pydot and graphviz (for printing Mondrian trees)
* sklearn (for reading libsvm format files)
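
On a stock Python setup, something like the following should pull these in (scikit-learn provides the `sklearn` module; pydot additionally needs the graphviz system package, e.g. via apt-get on Ubuntu):

    pip install numpy scipy matplotlib pydot scikit-learn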


The datasets are not included here; you need to download them from the UCI repository
(experiments can still be run using the toy data). Run **commands.sh** in the **process_data**
folder to download and process the datasets automatically. I have tested these scripts only on Ubuntu, but it should be straightforward to process the datasets on other platforms.

If you have any questions/comments/suggestions, please contact me at
[balaji@gatsby.ucl.ac.uk](mailto:balaji@gatsby.ucl.ac.uk).

Code released under the MIT license (see COPYING for more info).

Copyright © 2014 Balaji Lakshminarayanan

----------------------------------------------------------------------------

**List of scripts in the src folder**:

- mondrianforest.py
- mondrianforest_utils.py
- utils.py

Help on usage can be obtained by running the following command in the terminal:

./mondrianforest.py -h

**Example usage**:

./mondrianforest.py --dataset toy-mf --n_mondrians 100 --budget -1 --normalize_features 1

**Examples that draw the Mondrian partition and Mondrian tree**:

./mondrianforest.py --draw_mondrian 1 --save 1 --n_mondrians 10 --dataset toy-mf --store_every 1 --n_mini 6 --tag demo
./mondrianforest.py --draw_mondrian 1 --save 1 --n_mondrians 1 --dataset toy-mf --store_every 1 --n_mini 6 --tag demo

**Example on a real-world dataset**:

*assuming you have successfully run commands.sh in the process_data folder*

./mondrianforest.py --dataset satimage --n_mondrians 100 --budget -1 --normalize_features 1 --save 1 --data_path ../process_data/ --n_minibatches 10 --store_every 1

----------------------------------------------------------------------------

I generated commands for parameter sweeps using the 'build_cmds' script by Jan Gasthaus, available publicly at [https://github.com/jgasthaus/Gitsby/tree/master/pbs/python](https://github.com/jgasthaus/Gitsby/tree/master/pbs/python).

An example of such a parameter sweep:

./build_cmds ./mondrianforest.py "--op_dir={results}" "--init_id=1:1:6" "--dataset={letter,satimage,usps,dna,dna-61-120}" "--n_mondrians={100}" "--save={1}" "--discount_factor={10.0}" "--budget={-1}" "--n_minibatches={100}" "--bagging={0}" "--store_every={1}" "--normalize_features={1}" "--data_path={../process_data/}" >> run

Note that the results (predictions, accuracy, log predictive probability on training/test data, runtimes) are stored in pickle files.
You will need to write additional scripts to aggregate the results from these pickle files and to generate plots; a minimal sketch is given below.
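
As a starting point, here is a minimal aggregation sketch. It is not part of this commit, and the output directory and the key name inside the pickle (`acc_test`) are assumptions; inspect one pickle file first to see what mondrianforest.py actually stores.

    import glob
    import numpy as np
    import cPickle as pickle

    # collect test accuracy across runs (hypothetical key name)
    accs = []
    for fname in glob.glob('results/*.p'):   # hypothetical --op_dir location
        results = pickle.load(open(fname, 'rb'))
        accs.append(results['acc_test'])     # assumed key; adjust to the real one
    print 'mean test accuracy = %.3f +/- %.3f' % (np.mean(accs), np.std(accs))
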
21 changes: 21 additions & 0 deletions process_data/commands.sh
@@ -0,0 +1,21 @@
#!/usr/bin/env bash
# script to download and pre-process the datasets

# download datasets
for dir in 'usps' 'dna' 'satimage' 'letter'; do
    echo ${dir}
    cd ${dir}
    ./commands.sh
    cd ..
done

# process libsvm datasets
./process_libsvm_datasets.py usps usps usps.t
./process_libsvm_datasets.py dna dna.scale.tr dna.scale.t # ignoring validation files
./process_libsvm_datasets.py satimage satimage.scale.tr satimage.scale.t # ignoring validation files
# ./process_libsvm_datasets.py letter letter.scale.tr letter.scale.t # ignoring validation files
./process_libsvm_datasets.py letter letter.scale.tr letter.scale.t letter.scale.val # adding validation to training (to make comparisons with OnlineRF by Saffari et al.)
cd dna-61-120
./process_dna-61-120.py
cd ..
./convert_pickle_2_onlinerf_format.py dna-61-120
31 changes: 31 additions & 0 deletions process_data/convert_pickle_2_onlinerf_format.py
@@ -0,0 +1,31 @@
#!/usr/bin/env python
# input dataset is a pickle file (to retain the same train/test split)
# example: ./convert_pickle_2_onlinerf_format.py dna-61-120
#
# Output differs slightly from LIBSVM format:
# - a header line "#Samples #Features #Classes #FeatureMinIndex" is written to each file
# - class indices start from 0

import sys
import os
import cPickle as pickle
from itertools import izip

name = sys.argv[1]

data = pickle.load(open(name + '/' + name + '.p', 'rb'))

d_n_dim = {'magic04': 10, 'pendigits': 16, 'dna-61-120': 60}
d_n_class = {'magic04': 2, 'pendigits': 10, 'dna-61-120': 3}

feat_id_start = 1

def print_file(x, y, name, op_name):
    op = open(name + '/' + op_name, 'w')
    # header: #Samples #Features #Classes #FeatureMinIndex
    print>>op, '%s %s %s %s' % (len(y), d_n_dim[name], d_n_class[name], feat_id_start)
    for x_, y_ in izip(x, y):
        s = ' '.join(['%d:%f' % (i+1, x__) for i, x__ in enumerate(x_)])
        print>>op, '%s %s' % (y_, s)
    op.close()

print_file(data['x_train'], data['y_train'], name, name + '.orf.train')
print_file(data['x_test'], data['y_test'], name, name + '.orf.test')
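
For reference (not part of the commit), the first lines of a converted file would look roughly like this, following the header convention described at the top of the script: #Samples #Features #Classes #FeatureMinIndex on the first line, then one "label index:value ..." row per sample (all values illustrative):

    2000 60 3 1
    0 1:0.500000 2:0.250000 ... 60:0.125000
    2 1:0.000000 2:1.000000 ... 60:0.750000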
5 changes: 5 additions & 0 deletions process_data/dna/commands.sh
@@ -0,0 +1,5 @@
#!/usr/bin/env bash

wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/dna.scale.tr
wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/dna.scale.t
wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/dna.scale.val
5 changes: 5 additions & 0 deletions process_data/letter/commands.sh
@@ -0,0 +1,5 @@
#!/usr/bin/env bash

wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/letter.scale.t
wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/letter.scale.val
wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/letter.scale.tr
54 changes: 54 additions & 0 deletions process_data/process_libsvm_datasets.py
@@ -0,0 +1,54 @@
#!/usr/bin/env python
# usage: ./process_libsvm_datasets.py name filename_train filename_test filename_val
# name is the folder name as well as the pickled filename
# filename_train and filename_test are assumed to exist in the folder name
# filename_val is optional (validation dataset will be added to training data if supplied)

import sys
import numpy as np
import cPickle as pickle
from sklearn.datasets import load_svmlight_file

def get_x_y(filename, name):
    (x, y) = load_svmlight_file(filename)
    x = x.toarray()
    if name == 'pendigits-svm':
        x = x[:, 1:]    # drop the first column for this dataset
    y = y.astype('int')
    if name != 'pendigits-svm':
        y -= 1          # shift labels so that class indices start from 0
    return (x, y)

name, filename_train, filename_test = sys.argv[1:4]

data = {}

x, y = get_x_y(name + '/' + filename_train, name)
data['x_train'] = x
data['y_train'] = y
if len(sys.argv) > 4:
    # if a validation file is supplied, add it to the training data
    filename_val = sys.argv[4]
    x, y = get_x_y(name + '/' + filename_val, name)
    data['x_train'] = np.vstack((data['x_train'], x))
    data['y_train'] = np.append(data['y_train'], y)
data['n_train'] = data['x_train'].shape[0]
assert len(data['y_train']) == data['n_train']

x, y = get_x_y(name + '/' + filename_test, name)
data['x_test'] = x
data['n_test'] = x.shape[0]
data['y_test'] = y

data['n_dim'] = x.shape[1]
data['n_class'] = len(np.unique(y))
try:
    # class labels are expected to be 0, 1, ..., n_class - 1
    assert data['n_class'] == max(np.unique(y)) + 1
except AssertionError:
    print 'np.unique(y) = %s' % np.unique(y)
    raise
data['is_sparse'] = False

print 'name = %10s, n_dim = %5d, n_class = %5d, n_train = %5d, n_test = %5d' \
        % (name, data['n_dim'], data['n_class'], data['n_train'], data['n_test'])

pickle.dump(data, open(name + '/' + name + ".p", "wb"), protocol=pickle.HIGHEST_PROTOCOL)
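
As a quick sanity check (not part of this commit), the resulting pickle can be loaded back to inspect the fields written above; this assumes the satimage dataset has already been downloaded and processed:

    import cPickle as pickle

    data = pickle.load(open('satimage/satimage.p', 'rb'))
    print data['n_train'], data['n_test'], data['n_dim'], data['n_class']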
5 changes: 5 additions & 0 deletions process_data/satimage/commands.sh
@@ -0,0 +1,5 @@
#!/usr/bin/env bash

wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/satimage.scale.t
wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/satimage.scale.tr
wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/satimage.scale.val
6 changes: 6 additions & 0 deletions process_data/usps/commands.sh
@@ -0,0 +1,6 @@
#!/usr/bin/env bash

wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2
wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2
bzip2 -d usps.bz2
bzip2 -d usps.t.bz2
