Augment and Reduce
Code for Augment & Reduce, a scalable stochastic algorithm for large categorical distributions.
This code replicates the experiments in the paper:
- Francisco J. R. Ruiz, Michalis K. Titsias, Adji B. Dieng, and David M. Blei. Augment and Reduce: Stochastic Inference for Large Categorical Distributions. International Conference on Machine Learning. 2018.
This code trains a linear multiclass classifier on a dataset with a large number of classes.
Please cite this paper if you use this software.
The code is written in Matlab, combined with C++ functions.
flag_mexFile controls whether the C++ code is used. It is strongly recommended to leave this flag active (the default setting), since the C++ code is much faster. To use it, you first need to compile the mex files with the two steps below (they should work on Mac and Unix).
- First, make sure you have the GNU Scientific Library (GSL) installed. If so, open a terminal and run
gsl-config --cflags --libs
Copy the output to the clipboard; you will need it in the second step.
- Second, open Matlab, cd to the repo path, and run the commands below, replacing <TERMINAL_OUTPUT> with the output from the first step.
mex CFLAGS="\$CFLAGS" -largeArrayDims src/infer/compute_psi.cpp -outdir src/infer
mex CFLAGS="\$CFLAGS" -largeArrayDims src/infer/increase_follow_gradients.cpp -outdir src/infer
mex CFLAGS="\$CFLAGS" -largeArrayDims src/aux/keep_first_label_c.cpp -outdir src/aux
mex CFLAGS="\$CFLAGS" <TERMINAL_OUTPUT> -largeArrayDims src/infer/compute_predictions_c.cpp -outdir src/infer
mex CFLAGS="\$CFLAGS" <TERMINAL_OUTPUT> src/aux/multirandperm.cpp -outdir src/aux
The data must be stored in a Matlab struct with the following fields:
- data: a struct containing the data.
  - data.X: the training data (instances x dimensions). It MUST be in sparse matrix format (use the command sparse).
  - data.Y: the training labels (instances x 1). Each element indicates the class (from 1, ..., K).
  - data.test: a struct containing the test data.
    - data.test.X: the test data (test_instances x dimensions). It MUST be in sparse matrix format.
    - data.test.Y: the test labels (test_instances x 1). Each element indicates the class (from 1, ..., K).
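The constraints on these fields can be made concrete with a small sketch. The C++ below is hypothetical and not part of the repo: SparseCSC, to_sparse, check_csc, check_split and check_data are illustrative names. It mirrors the compressed sparse column layout that Matlab uses for sparse matrices (the jc/ir/pr arrays that mex files reach through mxGetJc, mxGetIr and mxGetPr) and checks the requirements listed above: one label per instance, integer labels in 1..K, and the same number of dimensions in the training and test data. Passing K in explicitly is an assumption of this sketch.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical C++ mirror of a Matlab sparse matrix in compressed sparse
// column (CSC) form: the nonzeros of column c are entries jc[c] .. jc[c+1]-1
// of ir (0-based row indices) and pr (values).
struct SparseCSC {
    std::size_t rows = 0, cols = 0;
    std::vector<std::size_t> jc;  // column pointers, size cols + 1
    std::vector<std::size_t> ir;  // row index of each nonzero
    std::vector<double> pr;       // value of each nonzero
};

// Converts a dense row-major matrix to CSC, analogous to Matlab's sparse(A).
SparseCSC to_sparse(const std::vector<double>& dense, std::size_t rows, std::size_t cols) {
    assert(dense.size() == rows * cols);  // precondition: dense is rows x cols
    SparseCSC A;
    A.rows = rows;
    A.cols = cols;
    A.jc.push_back(0);
    for (std::size_t c = 0; c < cols; ++c) {
        for (std::size_t r = 0; r < rows; ++r) {
            double v = dense[r * cols + c];
            if (v != 0.0) {  // store only the nonzero entries
                A.ir.push_back(r);
                A.pr.push_back(v);
            }
        }
        A.jc.push_back(A.ir.size());  // column c ends here
    }
    return A;
}

// Throws if A violates the CSC invariants.
void check_csc(const SparseCSC& A, const std::string& name) {
    if (A.jc.size() != A.cols + 1 || A.jc.front() != 0 ||
        A.jc.back() != A.ir.size() || A.ir.size() != A.pr.size())
        throw std::invalid_argument(name + ": malformed column pointers");
    for (std::size_t c = 0; c < A.cols; ++c)
        if (A.jc[c] > A.jc[c + 1])
            throw std::invalid_argument(name + ": column pointers must be nondecreasing");
    for (std::size_t r : A.ir)
        if (r >= A.rows)
            throw std::invalid_argument(name + ": row index out of range");
}

// Throws unless X is a valid instances x dimensions matrix and Y holds one
// integer label in 1..K per instance.
void check_split(const SparseCSC& X, const std::vector<double>& Y, std::size_t K,
                 const std::string& name) {
    check_csc(X, name + ".X");
    if (Y.size() != X.rows)
        throw std::invalid_argument(name + ".Y must have one label per row of " + name + ".X");
    for (double y : Y)
        if (y != std::floor(y) || y < 1.0 || y > static_cast<double>(K))
            throw std::invalid_argument(name + ".Y labels must be integers in 1..K");
}

// Checks data (training split) and data.test, including matching dimensions.
void check_data(const SparseCSC& X, const std::vector<double>& Y,
                const SparseCSC& testX, const std::vector<double>& testY, std::size_t K) {
    check_split(X, Y, K, "data");
    check_split(testX, testY, K, "data.test");
    if (testX.cols != X.cols)
        throw std::invalid_argument("data.test.X must have as many dimensions as data.X");
}
```

In Matlab itself, issparse(data.X) and issparse(data.test.X) are the direct way to confirm the sparse-format requirement before training.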
Please refer to the main files in src/ for additional information.
You can also obtain the datasets used in the paper.