Merge pull request #1038 from rcurtin/num_classes
Refactor classifiers to take numClasses parameters
rcurtin committed Jul 25, 2017
2 parents cea8fe2 + e717023 commit dc70207
Showing 16 changed files with 227 additions and 127 deletions.
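The practical effect at call sites: each classifier now takes the number of classes as an explicit argument rather than inferring it from the labels. A rough before/after sketch (illustrative only; data, labels, p, iterations, and tolerance are placeholders):

// Before this commit: AdaBoost deduced the class count from the labels itself.
// AdaBoost<Perceptron<>> ab(data, labels, p, iterations, tolerance);

// After this commit: the caller passes numClasses explicitly.
const size_t numClasses = arma::max(labels) + 1;  // assumes labels in {0, ..., k - 1}
AdaBoost<Perceptron<>> ab(data, labels, numClasses, p, iterations, tolerance);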
6 changes: 4 additions & 2 deletions src/mlpack/methods/adaboost/adaboost.hpp
@@ -95,6 +95,7 @@ class AdaBoost
*/
AdaBoost(const MatType& data,
const arma::Row<size_t>& labels,
const size_t numClasses,
const WeakLearnerType& other,
const size_t iterations = 100,
const double tolerance = 1e-6);
@@ -114,7 +115,7 @@ class AdaBoost
double& Tolerance() { return tolerance; }

//! Get the number of classes this model is trained on.
size_t Classes() const { return classes; }
size_t NumClasses() const { return numClasses; }

//! Get the number of weak learners in the model.
size_t WeakLearners() const { return alpha.size(); }
@@ -142,6 +143,7 @@ class AdaBoost
*/
void Train(const MatType& data,
const arma::Row<size_t>& labels,
const size_t numClasses,
const WeakLearnerType& learner,
const size_t iterations = 100,
const double tolerance = 1e-6);
@@ -163,7 +165,7 @@

private:
//! The number of classes in the model.
size_t classes;
size_t numClasses;
// The tolerance for change in rt and when to stop.
double tolerance;

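For an already-constructed model, Train() follows the same pattern. A minimal self-contained sketch of the new interface, assuming 0-indexed contiguous labels and the mlpack 2.x header layout (the dataset here is synthetic and the numbers are arbitrary):

#include <mlpack/methods/adaboost/adaboost.hpp>
#include <mlpack/methods/perceptron/perceptron.hpp>

using namespace mlpack::adaboost;
using namespace mlpack::perceptron;

int main()
{
  // Synthetic data: 10 dimensions, 300 points, 3 classes.
  arma::mat data = arma::randu<arma::mat>(10, 300);
  arma::Row<size_t> labels(300);
  for (size_t i = 0; i < labels.n_elem; ++i)
    labels[i] = i % 3;

  const size_t numClasses = arma::max(labels) + 1;

  // Weak learner whose settings the boosted learners will copy.
  Perceptron<> p(data, labels, numClasses);

  AdaBoost<Perceptron<>> ab;  // empty constructor
  ab.Train(data, labels, numClasses, p, 100 /* iterations */, 1e-6 /* tolerance */);

  arma::Row<size_t> predictions;
  ab.Classify(data, predictions);
}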
21 changes: 11 additions & 10 deletions src/mlpack/methods/adaboost/adaboost_impl.hpp
@@ -44,11 +44,12 @@ template<typename WeakLearnerType, typename MatType>
AdaBoost<WeakLearnerType, MatType>::AdaBoost(
const MatType& data,
const arma::Row<size_t>& labels,
const size_t numClasses,
const WeakLearnerType& other,
const size_t iterations,
const double tol)
{
Train(data, labels, other, iterations, tol);
Train(data, labels, numClasses, other, iterations, tol);
}

// Empty constructor.
@@ -64,6 +65,7 @@ template<typename WeakLearnerType, typename MatType>
void AdaBoost<WeakLearnerType, MatType>::Train(
const MatType& data,
const arma::Row<size_t>& labels,
const size_t numClasses,
const WeakLearnerType& other,
const size_t iterations,
const double tolerance)
@@ -72,9 +74,8 @@ void AdaBoost<WeakLearnerType, MatType>::Train(
wl.clear();
alpha.clear();

// Count the number of classes.
classes = (arma::max(labels) - arma::min(labels)) + 1;
this->tolerance = tolerance;
this->numClasses = numClasses;

// crt is the cumulative rt value for terminating the optimization when rt is
// changing by less than the tolerance.
@@ -89,11 +90,12 @@ void AdaBoost<WeakLearnerType, MatType>::Train(
MatType tempData(data);

// This matrix is a helper matrix used to calculate the final hypothesis.
arma::mat sumFinalH = arma::zeros<arma::mat>(classes, predictedLabels.n_cols);
arma::mat sumFinalH = arma::zeros<arma::mat>(numClasses,
predictedLabels.n_cols);

// Load the initial weights into a 2-D matrix.
const double initWeight = 1.0 / double(data.n_cols * classes);
arma::mat D(classes, data.n_cols);
const double initWeight = 1.0 / double(data.n_cols * numClasses);
arma::mat D(numClasses, data.n_cols);
D.fill(initWeight);

// Weights are stored in this row vector.
@@ -117,7 +119,7 @@ void AdaBoost<WeakLearnerType, MatType>::Train(
weights = arma::sum(D);

// Use the existing weak learner to train a new one with new weights.
WeakLearnerType w(other, tempData, labels, weights);
WeakLearnerType w(other, tempData, labels, numClasses, weights);
w.Classify(tempData, predictedLabels);

// Now from predictedLabels, build ht, the weak hypothesis
@@ -165,7 +167,6 @@ void AdaBoost<WeakLearnerType, MatType>::Train(
D(k, j) /= expo;
zt += D(k, j); // * exp(-1 * alphat * yt(j,k) * ht(j,k));


// Add to the final hypothesis matrix.
// sumFinalH(k, j) += (alphat * ht(k, j));
if (k == labels(j))
@@ -208,7 +209,7 @@ void AdaBoost<WeakLearnerType, MatType>::Classify(
arma::Row<size_t>& predictedLabels)
{
arma::Row<size_t> tempPredictedLabels(test.n_cols);
arma::mat cMatrix(classes, test.n_cols);
arma::mat cMatrix(numClasses, test.n_cols);

cMatrix.zeros();
predictedLabels.set_size(test.n_cols);
@@ -240,7 +241,7 @@ template<typename Archive>
void AdaBoost<WeakLearnerType, MatType>::Serialize(Archive& ar,
const unsigned int /* version */)
{
ar & data::CreateNVP(classes, "classes");
ar & data::CreateNVP(numClasses, "classes");
ar & data::CreateNVP(tolerance, "tolerance");
ar & data::CreateNVP(ztProduct, "ztProduct");
ar & data::CreateNVP(alpha, "alpha");
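One consequence of the Train() change above: each boosting round now constructs its learner as WeakLearnerType w(other, tempData, labels, numClasses, weights), so a type used as WeakLearnerType must provide roughly the following shape (informal sketch of the implied contract, not a class mlpack ships):

#include <mlpack/core.hpp>

template<typename MatType = arma::mat>
class MyWeakLearner
{
 public:
  // Copy hyperparameters from `other`, then train on the weighted data.
  MyWeakLearner(const MyWeakLearner& other,
                const MatType& data,
                const arma::Row<size_t>& labels,
                const size_t numClasses,
                const arma::rowvec& weights);

  // Fill `predictions` with one label per column of `test`.
  void Classify(const MatType& test, arma::Row<size_t>& predictions);
};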
5 changes: 4 additions & 1 deletion src/mlpack/methods/adaboost/adaboost_main.cpp
@@ -202,8 +202,11 @@ int main(int argc, char *argv[])
else if (weakLearner == "perceptron")
m.WeakLearnerType() = AdaBoostModel::WeakLearnerTypes::PERCEPTRON;

const size_t numClasses = m.Mappings().n_elem;
Log::Info << numClasses << " classes in dataset." << endl;

Timer::Start("adaboost_training");
m.Train(trainingData, labels, iterations, tolerance);
m.Train(trainingData, labels, numClasses, iterations, tolerance);
Timer::Stop("adaboost_training");
}
else
7 changes: 4 additions & 3 deletions src/mlpack/methods/adaboost/adaboost_model.cpp
@@ -92,6 +92,7 @@ AdaBoostModel::~AdaBoostModel()
//! Train the model.
void AdaBoostModel::Train(const mat& data,
const Row<size_t>& labels,
const size_t numClasses,
const size_t iterations,
const double tolerance)
{
@@ -101,13 +102,13 @@ void AdaBoostModel::Train(const mat& data,
delete dsBoost;

DecisionStump<> ds(data, labels, max(labels) + 1);
dsBoost = new AdaBoost<DecisionStump<>>(data, labels, ds, iterations,
tolerance);
dsBoost = new AdaBoost<DecisionStump<>>(data, labels, numClasses, ds,
iterations, tolerance);
}
else if (weakLearnerType == WeakLearnerTypes::PERCEPTRON)
{
Perceptron<> p(data, labels, max(labels) + 1);
pBoost = new AdaBoost<Perceptron<>>(data, labels, p, iterations,
pBoost = new AdaBoost<Perceptron<>>(data, labels, numClasses, p, iterations,
tolerance);
}
}
1 change: 1 addition & 0 deletions src/mlpack/methods/adaboost/adaboost_model.hpp
@@ -77,6 +77,7 @@ class AdaBoostModel
//! Train the model.
void Train(const arma::mat& data,
const arma::Row<size_t>& labels,
const size_t numClasses,
const size_t iterations,
const double tolerance);

15 changes: 8 additions & 7 deletions src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -40,12 +40,12 @@ class DecisionStump
*
* @param data Input, training data.
* @param labels Labels of training data.
* @param classes Number of distinct classes in labels.
* @param numClasses Number of distinct classes in labels.
* @param bucketSize Minimum size of bucket when splitting.
*/
DecisionStump(const MatType& data,
const arma::Row<size_t>& labels,
const size_t classes,
const size_t numClasses,
const size_t bucketSize = 10);

/**
@@ -62,6 +62,7 @@ class DecisionStump
DecisionStump(const DecisionStump<>& other,
const MatType& data,
const arma::Row<size_t>& labels,
const size_t numClasses,
const arma::rowvec& weights);

/**
@@ -78,12 +79,12 @@
*
* @param data Dataset to train on.
* @param labels Labels for each point in the dataset.
* @param classes Number of classes in the dataset.
* @param numClasses Number of classes in the dataset.
* @param bucketSize Minimum size of bucket when splitting.
*/
void Train(const MatType& data,
const arma::Row<size_t>& labels,
const size_t classes,
const size_t numClasses,
const size_t bucketSize);

/**
Expand All @@ -94,13 +95,13 @@ class DecisionStump
* @param data Dataset to train on.
* @param labels Labels for each point in the dataset.
* @param weights Weights for each point in the dataset.
* @param classes Number of classes in the dataset.
* @param numClasses Number of classes in the dataset.
* @param bucketSize Minimum size of bucket when splitting.
*/
void Train(const MatType& data,
const arma::Row<size_t>& labels,
const arma::rowvec& weights,
const size_t classes,
const size_t numClasses,
const size_t bucketSize);

/**
Expand Down Expand Up @@ -134,7 +135,7 @@ class DecisionStump

private:
//! The number of classes (we must store this for boosting).
size_t classes;
size_t numClasses;
//! The minimum number of points in a bucket.
size_t bucketSize;

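Used directly, DecisionStump follows the same convention; a minimal sketch assuming labels in {0, ..., numClasses - 1} and the mlpack 2.x header layout:

#include <mlpack/methods/decision_stump/decision_stump.hpp>

using namespace mlpack::decision_stump;

int main()
{
  arma::mat data = arma::randu<arma::mat>(5, 200);
  arma::Row<size_t> labels(200);
  for (size_t i = 0; i < labels.n_elem; ++i)
    labels[i] = i % 2;

  const size_t numClasses = arma::max(labels) + 1;
  DecisionStump<> stump(data, labels, numClasses, 10 /* bucketSize */);

  arma::Row<size_t> predictions;
  stump.Classify(data, predictions);
}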
28 changes: 14 additions & 14 deletions src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -9,7 +9,6 @@
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/

#ifndef MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_IMPL_HPP
#define MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_IMPL_HPP

@@ -24,15 +23,15 @@ namespace decision_stump {
*
* @param data Input, training data.
* @param labels Labels of data.
* @param classes Number of distinct classes in labels.
* @param numClasses Number of distinct classes in labels.
* @param bucketSize Minimum size of bucket when splitting.
*/
template<typename MatType>
DecisionStump<MatType>::DecisionStump(const MatType& data,
const arma::Row<size_t>& labels,
const size_t classes,
const size_t numClasses,
const size_t bucketSize) :
classes(classes),
numClasses(numClasses),
bucketSize(bucketSize)
{
arma::rowvec weights;
@@ -44,7 +43,7 @@ DecisionStump<MatType>::DecisionStump(const MatType& data,
*/
template<typename MatType>
DecisionStump<MatType>::DecisionStump() :
classes(1),
numClasses(1),
bucketSize(0),
splitDimension(0),
split(1),
@@ -60,10 +59,10 @@ DecisionStump<MatType>::DecisionStump() :
template<typename MatType>
void DecisionStump<MatType>::Train(const MatType& data,
const arma::Row<size_t>& labels,
const size_t classes,
const size_t numClasses,
const size_t bucketSize)
{
this->classes = classes;
this->numClasses = numClasses;
this->bucketSize = bucketSize;

// Pass to unweighted training function.
@@ -80,10 +79,10 @@ template<typename MatType>
void DecisionStump<MatType>::Train(const MatType& data,
const arma::Row<size_t>& labels,
const arma::rowvec& weights,
const size_t classes,
const size_t numClasses,
const size_t bucketSize)
{
this->classes = classes;
this->numClasses = numClasses;
this->bucketSize = bucketSize;

// Pass to weighted training function.
@@ -186,8 +185,9 @@ template<typename MatType>
DecisionStump<MatType>::DecisionStump(const DecisionStump<>& other,
const MatType& data,
const arma::Row<size_t>& labels,
const size_t numClasses,
const arma::rowvec& weights) :
classes(other.classes),
numClasses(numClasses),
bucketSize(other.bucketSize)
{
Train<true>(data, labels, weights);
@@ -205,7 +205,7 @@ void DecisionStump<MatType>::Serialize(Archive& ar,

// This is straightforward; just serialize all of the members of the class.
// None need special handling.
ar & CreateNVP(classes, "classes");
ar & CreateNVP(numClasses, "classes");
ar & CreateNVP(bucketSize, "bucketSize");
ar & CreateNVP(splitDimension, "splitDimension");
ar & CreateNVP(split, "split");
@@ -469,7 +469,7 @@ double DecisionStump<MatType>::CalculateEntropy(
double entropy = 0.0;
size_t j;

arma::rowvec numElem(classes);
arma::rowvec numElem(numClasses);
numElem.fill(0);

// Variable to accumulate the weight in this subview_row.
@@ -484,7 +484,7 @@ double DecisionStump<MatType>::CalculateEntropy(
accWeight += weights(j);
}

for (j = 0; j < classes; j++)
for (j = 0; j < numClasses; j++)
{
const double p1 = ((double) numElem(j) / accWeight);

@@ -499,7 +499,7 @@
for (j = 0; j < labels.n_elem; j++)
numElem(labels(j))++;

for (j = 0; j < classes; j++)
for (j = 0; j < numClasses; j++)
{
const double p1 = ((double) numElem(j) / labels.n_elem);

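For context on the CalculateEntropy() hunks above: the stump accumulates a (possibly weighted) count per class and converts those counts into probabilities over numClasses bins. A standalone sketch of that kind of weighted entropy computation (illustrative only; mlpack's actual code path and sign convention may differ):

#include <armadillo>
#include <cmath>

// Weighted Shannon entropy of `labels` over `numClasses` classes.
double WeightedEntropy(const arma::Row<size_t>& labels,
                       const arma::rowvec& weights,
                       const size_t numClasses)
{
  arma::rowvec numElem(numClasses);
  numElem.fill(0);

  double accWeight = 0.0;
  for (size_t j = 0; j < labels.n_elem; ++j)
  {
    numElem(labels(j)) += weights(j);
    accWeight += weights(j);
  }

  double entropy = 0.0;
  for (size_t j = 0; j < numClasses; ++j)
  {
    const double p = numElem(j) / accWeight;
    if (p > 0.0)
      entropy -= p * std::log2(p);  // contribution of class j
  }
  return entropy;
}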
