Skip to content

Commit

Permalink
Adding test and documentation fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffin143 committed May 11, 2019
1 parent 88af420 commit 06efecb
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 4 deletions.
2 changes: 1 addition & 1 deletion HISTORY.md
@@ -1,7 +1,7 @@
### mlpack 3.1.1
###### ????-??-??
* `output` option changed to `predictions` for adaboost and perceptron
binding.Old options are now deprecated and will be preserved until mlpack
binding. Old options are now deprecated and will be preserved until mlpack
4.0.0 (#1882).

* Concatenated ReLU layer (#1843).
Expand Down
2 changes: 1 addition & 1 deletion src/mlpack/methods/adaboost/adaboost_main.cpp
Expand Up @@ -101,7 +101,7 @@ PROGRAM_INFO("AdaBoost",
PRINT_DATASET("predictions") + " with the following command: "
"\n\n" +
PRINT_CALL("adaboost", "input_model", "model", "test", "test_data",
"output", "predictions"),
"predictions", "predictions"),
// See also...
SEE_ALSO("AdaBoost on Wikipedia", "https://en.wikipedia.org/wiki/AdaBoost"),
SEE_ALSO("Improved boosting algorithms using confidence-rated predictions "
Expand Down
4 changes: 2 additions & 2 deletions src/mlpack/methods/perceptron/perceptron_main.cpp
Expand Up @@ -78,7 +78,7 @@ PROGRAM_INFO("Perceptron",
"saving the predicted classes to " + PRINT_DATASET("predictions") + "."
"\n\n" +
PRINT_CALL("perceptron", "input_model", "perceptron_model", "test",
"test_data", "output", "predictions") +
"test_data", "predictions", "predictions") +
"\n\n"
"Note that all of the options may be specified at once: predictions may be "
"calculated right after training a model, and model training can occur even"
Expand Down Expand Up @@ -154,7 +154,7 @@ static void mlpackMain()
// should issue a warning.
RequireAtLeastOnePassed({ "output_model", "output", "predictions" }, false,
"no output will be saved");
// "output" can be removed in mlpack 4.
// "output" can be removed in mlpack 4.0.0.
ReportIgnoredParam({{ "test", false }}, "predictions");

// Check parameter validity.
Expand Down
22 changes: 22 additions & 0 deletions src/mlpack/tests/main_tests/adaboost_test.cpp
Expand Up @@ -205,6 +205,28 @@ BOOST_AUTO_TEST_CASE(AdaBoostTrainingDataOrModelTest)
BOOST_REQUIRE_THROW(mlpackMain(), std::runtime_error);
Log::Fatal.ignoreInput = false;
}
/**
* This test can be removed in mlpack 4.0.0. Testing
* the output and predictions outputs are the same.
*/
BOOST_AUTO_TEST_CASE(AdaBoostOutputPredictionsTest)
{
arma::mat trainData;
if (!data::Load("vc2.csv", trainData))
BOOST_FAIL("Unable to load train dataset vc2.csv!");

arma::Row<size_t> labels;
if (!data::Load("vc2_labels.txt", labels))
BOOST_FAIL("Unable to load label dataset vc2_labels.txt!");

SetInputParam("training", std::move(trainData));
SetInputParam("labels", std::move(labels));

mlpackMain();

CheckMatrices(CLI::GetParam<arma::Row<size_t>>("output"),
CLI::GetParam<arma::Row<size_t>>("predictions"));
}

/**
* Weak learner should be either Decision Stump or Perceptron.
Expand Down
33 changes: 33 additions & 0 deletions src/mlpack/tests/main_tests/perceptron_test.cpp
Expand Up @@ -162,6 +162,39 @@ BOOST_AUTO_TEST_CASE(PerceptronLabelsLessDimensionTest)
CheckMatrices(output, CLI::GetParam<arma::Row<size_t>>("output"));
}

/**
* This test can be removed in mlpack 4.0.0. Testing
* the output and predictions outputs are the same.
*/
BOOST_AUTO_TEST_CASE(PerceptronOutputPredictionsCheck)
{
arma::mat trainX1;
arma::Row<size_t> labelsX1;

// Loading a train data set with 3 classes.
if (!data::Load("vc2.csv", trainX1))
{
BOOST_FAIL("Could not load the train data (vc2.csv)");
}

// Loading the corresponding labels to the dataset.
if (!data::Load("vc2_labels.txt", labelsX1))
{
BOOST_FAIL("Could not load the train data (vc2_labels.csv)");
}

SetInputParam("training", std::move(trainX1)); // Training data.
// Labels for the training data.
SetInputParam("labels", std::move(labelsX1));

// Training model using first training dataset.
mlpackMain();

// Check that the outputs are the same.
CheckMatrices(CLI::GetParam<arma::Row<size_t>>("output"),
CLI::GetParam<arma::Row<size_t>>("predictions"));
}

/**
* Ensure that saved model can be used again.
*/
Expand Down

0 comments on commit 06efecb

Please sign in to comment.