Automatic bindings step 3: model parameters #825

Merged: 46 commits, Jan 13, 2017
Changes from 1 commit

Commits (46)
37e3585
Remove unneeded file.
rcurtin Nov 18, 2016
365309b
Improve output and printing by giving each parameter its own function…
rcurtin Nov 18, 2016
eee67ab
Add parameter type for serializable models.
rcurtin Nov 18, 2016
8b405f9
Add first tests for serializable models.
rcurtin Nov 18, 2016
1a6e5f2
Adapt a first program to load and save models.
rcurtin Nov 18, 2016
29073ae
Fix minor bugs in implementation.
rcurtin Nov 23, 2016
39f9af6
Explicitly destroy the CLI object so that things are saved.
rcurtin Nov 23, 2016
3033940
Remove created file.
rcurtin Nov 23, 2016
7c2be9f
Add an actual test for serializing models.
rcurtin Nov 23, 2016
58dfb70
Update documentation for CLI.
rcurtin Nov 23, 2016
8d491ca
Fix approx_kfn program types, transition to model parameters.
rcurtin Nov 23, 2016
024eded
Refactor to clean up cli.hpp significantly.
rcurtin Nov 23, 2016
b10ef6b
Add functions for getting default type and string type.
rcurtin Nov 23, 2016
8d7dbf4
Refactor AdaBoost to use model parameters.
rcurtin Nov 23, 2016
a52e9c6
Avoid copying objects unnecessarily.
rcurtin Nov 23, 2016
6b07e26
Add copy constructor and move constructor.
rcurtin Nov 24, 2016
cc08c1a
Add copy and move constructors.
rcurtin Nov 24, 2016
d8db6d8
Transition some programs to use model parameters.
rcurtin Nov 24, 2016
6154c09
Refactor GMM programs; INPUT_MODEL_IN_REQ() does not exist yet.
rcurtin Nov 27, 2016
30e22cb
Add required model parameters and tests.
rcurtin Nov 27, 2016
3bac181
Fix model parameter names, and call Destroy() on exit.
rcurtin Nov 27, 2016
2967d23
Incremental checkin to change computers.
rcurtin Nov 28, 2016
c5dfbac
Don't forget newlines at the end of output.
rcurtin Nov 28, 2016
926bad9
Finish transition for HMM programs.
rcurtin Nov 28, 2016
9dfe2c3
Remove no-longer-needed bits.
rcurtin Nov 29, 2016
8c9d49f
Add copy/move constructor/assignment operators for LSHSearch.
rcurtin Nov 29, 2016
3d2a624
Convert several more methods to use MODEL parameters.
rcurtin Nov 29, 2016
7c9cb2c
Fix slightly off documentation.
rcurtin Nov 30, 2016
3292197
Convert to model parameters for the last of the easy programs.
rcurtin Nov 30, 2016
07d78d7
Um there's a tornado warning so I better commit this stuff, push it, …
rcurtin Nov 30, 2016
27b01fd
Add copy/move constructor and assignment operators to NeighborSearch,…
rcurtin Dec 1, 2016
cbb4e14
Fix bugs in move and copy constructors and operators.
rcurtin Dec 1, 2016
764f2a9
Use MODEL types in KFN program.
rcurtin Dec 1, 2016
9e6f915
Refactor mlpack_range_search to use MODEL parameters.
rcurtin Dec 1, 2016
4fda9b2
Adapt mlpack_krann to use model parameters.
rcurtin Dec 2, 2016
fd9de0b
Allow a parameter that can load a matrix with a DatasetInfo.
rcurtin Dec 3, 2016
7db083d
Add a test for loading DatasetInfo/mats.
rcurtin Dec 5, 2016
d73ef9f
Fix very important misspelling!
rcurtin Dec 5, 2016
f2af43c
Add a model for easy serialization of HoeffdingTree.
rcurtin Dec 5, 2016
d9b7898
Add a way to access parameters before they are loaded, and add tests.
rcurtin Dec 6, 2016
10bc7cb
Use different parameter type for test file.
rcurtin Dec 6, 2016
ff7e3b8
Add tests for HoeffdingTreeModel.
rcurtin Dec 7, 2016
2baa9f1
Fix ARFF reading bug: make sure we trim before mapping!
rcurtin Dec 7, 2016
1ab8980
Merge different overloads since they do the same thing.
rcurtin Jan 13, 2017
8afd83c
Simplify overloads, and also mark models as loaded.
rcurtin Jan 13, 2017
1061b0b
Remove unused code.
rcurtin Jan 13, 2017
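
The commits above revolve around letting bindings declare serializable model objects as first-class parameters, next to matrices and flags. Below is a minimal, hypothetical sketch of what that looks like in a binding; the macro and accessor names (PARAM_MODEL_IN, PARAM_MODEL_OUT, CLI::ParseCommandLine, CLI::GetParam) are assumptions based on later mlpack releases and may differ from the exact names introduced in this PR.

// Hypothetical sketch of a binding using model parameters; names are
// assumptions (from later mlpack releases), not necessarily the exact API
// added in this PR.
#include <mlpack/core.hpp>
#include <mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp>

using namespace mlpack;
using namespace mlpack::tree;

// Declare a loadable input model and a savable output model.
PARAM_MODEL_IN(HoeffdingTreeModel, "input_model", "Trained tree to start from.",
    "m");
PARAM_MODEL_OUT(HoeffdingTreeModel, "output_model", "Trained tree to save.",
    "M");

int main(int argc, char** argv)
{
  CLI::ParseCommandLine(argc, argv);

  // The CLI layer deserializes the model if an input model file was given.
  HoeffdingTreeModel model = CLI::GetParam<HoeffdingTreeModel>("input_model");

  // ... train or classify with the model here ...

  // Hand the model back; output parameters are serialized when the CLI
  // object is destroyed (cf. the "Explicitly destroy the CLI object" commit).
  CLI::GetParam<HoeffdingTreeModel>("output_model") = std::move(model);

  return 0;
}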
243 changes: 243 additions & 0 deletions src/mlpack/tests/hoeffding_tree_test.cpp
@@ -15,6 +15,7 @@
#include <mlpack/methods/hoeffding_trees/hoeffding_tree.hpp>
#include <mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp>
#include <mlpack/methods/hoeffding_trees/binary_numeric_split.hpp>
#include <mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp>

#include <boost/test/unit_test.hpp>
#include "test_tools.hpp"
@@ -1145,4 +1146,246 @@ BOOST_AUTO_TEST_CASE(MultipleSerializationTest)
}
}

// Test the Hoeffding tree model.
BOOST_AUTO_TEST_CASE(HoeffdingTreeModelTest)
{
  // Generate data.
  arma::mat dataset(4, 3000);
  arma::Row<size_t> labels(3000);
  data::DatasetInfo info(4); // All features are numeric, except the fourth.
  info.MapString("0", 3);
  for (size_t i = 0; i < 3000; i += 3)
  {
    dataset(0, i) = mlpack::math::Random();
    dataset(1, i) = mlpack::math::Random();
    dataset(2, i) = mlpack::math::Random();
    dataset(3, i) = 0.0;
    labels[i] = 0;

    dataset(0, i + 1) = mlpack::math::Random();
    dataset(1, i + 1) = mlpack::math::Random() - 1.0;
    dataset(2, i + 1) = mlpack::math::Random() + 0.5;
    dataset(3, i + 1) = 0.0;
    labels[i + 1] = 2;

    dataset(0, i + 2) = mlpack::math::Random();
    dataset(1, i + 2) = mlpack::math::Random() + 1.0;
    dataset(2, i + 2) = mlpack::math::Random() + 0.8;
    dataset(3, i + 2) = 0.0;
    labels[i + 2] = 1;
  }

  // Train a model on a simple dataset, for all four types of models, and make
  // sure we get reasonable results.
  for (size_t i = 0; i < 4; ++i)
  {
    HoeffdingTreeModel m;
    switch (i)
    {
      case 0:
        m = HoeffdingTreeModel(HoeffdingTreeModel::GINI_HOEFFDING);
        break;

      case 1:
        m = HoeffdingTreeModel(HoeffdingTreeModel::GINI_BINARY);
        break;

      case 2:
        m = HoeffdingTreeModel(HoeffdingTreeModel::INFO_HOEFFDING);
        break;

      case 3:
        m = HoeffdingTreeModel(HoeffdingTreeModel::INFO_BINARY);
        break;
    }

    // We'll take 5 passes over the data.
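    // BuildModel arguments, in the order assumed here: number of classes,
    // batch training flag, success probability, max samples, check interval,
    // min samples, bins, and observations before binning.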
    m.BuildModel(dataset, info, labels, 3, false, 0.99, 1000, 100, 100, 4, 100);
    for (size_t j = 0; j < 4; ++j)
      m.Train(dataset, labels, false);

    // Now make sure the performance is reasonable.
    arma::Row<size_t> predictions, predictions2;
    arma::rowvec probabilities;
    m.Classify(dataset, predictions);
    m.Classify(dataset, predictions2, probabilities);

    size_t correct = 0;
    for (size_t i = 0; i < 3000; ++i)
    {
      // Check consistency of predictions.
      BOOST_REQUIRE_EQUAL(predictions[i], predictions2[i]);

      if (labels[i] == predictions[i])
        ++correct;
    }

    // Require at least 95% accuracy.
    BOOST_REQUIRE_GT(correct, 2850);
  }
}

// Test the Hoeffding tree model in batch mode.
BOOST_AUTO_TEST_CASE(HoeffdingTreeModelBatchTest)
{
  // Generate data.
  arma::mat dataset(4, 3000);
  arma::Row<size_t> labels(3000);
  data::DatasetInfo info(4); // All features are numeric, except the fourth.
  info.MapString("0", 3);
  for (size_t i = 0; i < 3000; i += 3)
  {
    dataset(0, i) = mlpack::math::Random();
    dataset(1, i) = mlpack::math::Random();
    dataset(2, i) = mlpack::math::Random();
    dataset(3, i) = 0.0;
    labels[i] = 0;

    dataset(0, i + 1) = mlpack::math::Random();
    dataset(1, i + 1) = mlpack::math::Random() - 1.0;
    dataset(2, i + 1) = mlpack::math::Random() + 0.5;
    dataset(3, i + 1) = 0.0;
    labels[i + 1] = 2;

    dataset(0, i + 2) = mlpack::math::Random();
    dataset(1, i + 2) = mlpack::math::Random() + 1.0;
    dataset(2, i + 2) = mlpack::math::Random() + 0.8;
    dataset(3, i + 2) = 0.0;
    labels[i + 2] = 1;
  }

  // Train a model on a simple dataset, for all four types of models, and make
  // sure we get reasonable results.
  for (size_t i = 0; i < 4; ++i)
  {
    HoeffdingTreeModel m;
    switch (i)
    {
      case 0:
        m = HoeffdingTreeModel(HoeffdingTreeModel::GINI_HOEFFDING);
        break;

      case 1:
        m = HoeffdingTreeModel(HoeffdingTreeModel::GINI_BINARY);
        break;

      case 2:
        m = HoeffdingTreeModel(HoeffdingTreeModel::INFO_HOEFFDING);
        break;

      case 3:
        m = HoeffdingTreeModel(HoeffdingTreeModel::INFO_BINARY);
        break;
    }

    // Train in batch.
    m.BuildModel(dataset, info, labels, 3, true, 0.99, 1000, 100, 100, 4, 100);

    // Now make sure the performance is reasonable.
    arma::Row<size_t> predictions, predictions2;
    arma::rowvec probabilities;
    m.Classify(dataset, predictions);
    m.Classify(dataset, predictions2, probabilities);

    size_t correct = 0;
    for (size_t i = 0; i < 3000; ++i)
    {
      // Check consistency of predictions.
      BOOST_REQUIRE_EQUAL(predictions[i], predictions2[i]);

      if (labels[i] == predictions[i])
        ++correct;
    }

    // Require at least 95% accuracy.
    BOOST_REQUIRE_GT(correct, 2850);
  }
}

BOOST_AUTO_TEST_CASE(HoeffdingTreeModelSerializationTest)
{
  // Generate data.
  arma::mat dataset(4, 3000);
  arma::Row<size_t> labels(3000);
  data::DatasetInfo info(4); // All features are numeric, except the fourth.
  info.MapString("0", 3);
  for (size_t i = 0; i < 3000; i += 3)
  {
    dataset(0, i) = mlpack::math::Random();
    dataset(1, i) = mlpack::math::Random();
    dataset(2, i) = mlpack::math::Random();
    dataset(3, i) = 0.0;
    labels[i] = 0;

    dataset(0, i + 1) = mlpack::math::Random();
    dataset(1, i + 1) = mlpack::math::Random() - 1.0;
    dataset(2, i + 1) = mlpack::math::Random() + 0.5;
    dataset(3, i + 1) = 0.0;
    labels[i + 1] = 2;

    dataset(0, i + 2) = mlpack::math::Random();
    dataset(1, i + 2) = mlpack::math::Random() + 1.0;
    dataset(2, i + 2) = mlpack::math::Random() + 0.8;
    dataset(3, i + 2) = 0.0;
    labels[i + 2] = 1;
  }

  // Train a model on a simple dataset, for all four types of models, and make
  // sure we get reasonable results.
  for (size_t i = 0; i < 4; ++i)
  {
    HoeffdingTreeModel m, xmlM, textM, binaryM;
    switch (i)
    {
      case 0:
        m = HoeffdingTreeModel(HoeffdingTreeModel::GINI_HOEFFDING);
        break;

      case 1:
        m = HoeffdingTreeModel(HoeffdingTreeModel::GINI_BINARY);
        break;

      case 2:
        m = HoeffdingTreeModel(HoeffdingTreeModel::INFO_HOEFFDING);
        break;

      case 3:
        m = HoeffdingTreeModel(HoeffdingTreeModel::INFO_BINARY);
        break;
    }

    // Train in batch.
    m.BuildModel(dataset, info, labels, 3, true, 0.99, 1000, 100, 100, 4, 100);
    // Give the XML model a different (throwaway) training; the serialization
    // below should overwrite its state with a copy of m.
    xmlM.BuildModel(dataset, info, labels, 3, false, 0.5, 100, 100, 100, 2,
        100);

    // Now make sure the performance is reasonable.
    arma::Row<size_t> predictions, predictionsXml, predictionsText,
        predictionsBinary;
    arma::rowvec probabilities, probabilitiesXml, probabilitiesText,
        probabilitiesBinary;

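    // SerializeObjectAll() is a test helper that (as assumed here) serializes
    // m through XML, text, and binary archives and deserializes the results
    // into xmlM, textM, and binaryM respectively.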
    SerializeObjectAll(m, xmlM, textM, binaryM);

    // Get predictions for all.
    m.Classify(dataset, predictions, probabilities);
    xmlM.Classify(dataset, predictionsXml, probabilitiesXml);
    textM.Classify(dataset, predictionsText, probabilitiesText);
    binaryM.Classify(dataset, predictionsBinary, probabilitiesBinary);

    for (size_t i = 0; i < 3000; ++i)
    {
      // Check consistency of predictions and probabilities.
      BOOST_REQUIRE_EQUAL(predictions[i], predictionsXml[i]);
      BOOST_REQUIRE_EQUAL(predictions[i], predictionsText[i]);
      BOOST_REQUIRE_EQUAL(predictions[i], predictionsBinary[i]);

      BOOST_REQUIRE_CLOSE(probabilities[i], probabilitiesXml[i], 1e-5);
      BOOST_REQUIRE_CLOSE(probabilities[i], probabilitiesText[i], 1e-5);
      BOOST_REQUIRE_CLOSE(probabilities[i], probabilitiesBinary[i], 1e-5);
    }
  }
}

BOOST_AUTO_TEST_SUITE_END();