Density Estimation Tree made sparse-enabled #802

Merged · 28 commits · Nov 1, 2016

Changes from 1 commit

Commits (28)
84990cb  - Mac gitignore changes. (thejonan, Oct 6, 2016)
dc2a8e8  Merge remote-tracking branch 'upstream/master' (thejonan, Oct 6, 2016)
5b13881  - DTree class templated. (thejonan, Oct 13, 2016)
8b4c907  - First successfull builtd. (thejonan, Oct 13, 2016)
0b58fd9  - DET changes propagated to tests. (thejonan, Oct 14, 2016)
47bf92f  - DET templating ready and tests passing. (thejonan, Oct 14, 2016)
e4a9be0  - More sparse-matrix migration steps. (thejonan, Oct 17, 2016)
3547fff  - First successfull SpMat build. (thejonan, Oct 17, 2016)
11c4b0a  - Sparse DTree finally working. (thejonan, Oct 17, 2016)
4083784  - All DET sparsification works. (thejonan, Oct 18, 2016)
44fd0c0  - OMP test line removed. (thejonan, Oct 18, 2016)
525c617  Merge remote-tracking branch 'upstream/master' (thejonan, Oct 18, 2016)
aa2ad99  - DTree class templated. (thejonan, Oct 13, 2016)
37a9e50  - First successfull builtd. (thejonan, Oct 13, 2016)
f1d4467  - DET changes propagated to tests. (thejonan, Oct 14, 2016)
45ff5ba  - DET templating ready and tests passing. (thejonan, Oct 14, 2016)
9bee58e  - More sparse-matrix migration steps. (thejonan, Oct 17, 2016)
6b3ae6c  - First successfull SpMat build. (thejonan, Oct 17, 2016)
dc8d9b1  - Sparse DTree finally working. (thejonan, Oct 17, 2016)
1fee500  - All DET sparsification works. (thejonan, Oct 18, 2016)
0d0e387  - OMP test line removed. (thejonan, Oct 18, 2016)
98cadd8  Merge remote-tracking branch 'origin/feature/det_sparse' into feature… (thejonan, Oct 18, 2016)
24f6d8c  - Fixed openmp declarations. (thejonan, Oct 18, 2016)
4514e4c  - Sparse matrix speed-ups. (thejonan, Oct 19, 2016)
8d26bb6  - Fix of sparse iteration (thejonan, Oct 19, 2016)
50fa931  - Rowback to faster sparse iteration. (thejonan, Oct 19, 2016)
a60ae8a  - Fixes, based on PR's comments. (thejonan, Oct 19, 2016)
10812ca  - Template fixes for ExtractSplits. (thejonan, Oct 20, 2016)
4 changes: 4 additions & 0 deletions src/mlpack/methods/det/CMakeLists.txt
@@ -5,6 +5,10 @@ set(SOURCES
   # the DET class
   dtree.hpp
   dtree_impl.hpp
Member:
Shouldn't we include dt_utils.hpp and dt_utils_impl.hpp too? I overlooked this in my review :)

Contributor Author:
I've decided to remove them, because I'm not sure they should actually be part of the library itself. Aren't they mostly related to the command-line tool?

Member:
Yes, they actually can be used for the library---they contain some cross-validation code and the Trainer() method which is used for producing a cross-validated tree. To be honest I would prefer if those were refactored into the class constructor itself, since they are not very "mlpack-ish", but either way they are something that a user may want to use so we should leave them in.
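[Editorial aside: to make the Trainer() discussion concrete, here is a minimal sketch of how a library user might call it. The parameter list (folds, useVolumeReg, maxLeafSize, minLeafSize) is assumed from the code visible in this diff, not taken from the actual header, so treat it as illustrative rather than the confirmed mlpack API.]

```cpp
// Hypothetical usage sketch of mlpack::det::Trainer(); the arguments below
// are assumptions inferred from this PR's diff, not a confirmed signature.
#include <mlpack/core.hpp>
#include <mlpack/methods/det/dt_utils.hpp>

int main()
{
  arma::mat dataset;  // one point per column, as elsewhere in mlpack
  mlpack::data::Load("dataset.csv", dataset, true);

  // Produce a cross-validated density estimation tree.
  mlpack::det::DTree<arma::mat, int>* tree =
      mlpack::det::Trainer<arma::mat, int>(
          dataset,
          10,     // assumed: number of cross-validation folds
          false,  // assumed: useVolumeReg
          10,     // assumed: maxLeafSize
          5);     // assumed: minLeafSize

  delete tree;
  return 0;
}
```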


+  # Utility files
+  dt_utils.hpp
+  dt_utils_impl.hpp
 )

 # add directory name to sources
18 changes: 13 additions & 5 deletions src/mlpack/methods/det/dt_utils_impl.hpp
@@ -158,7 +158,8 @@ DTree<MatType, TagType>* mlpack::det::Trainer(MatType& dataset,
     // Some sanity checks. It seems that on some datasets, the error does not
     // increase as the tree is pruned but instead stays the same---hence the
     // "<=" in the final assert.
-    Log::Assert((alpha < std::numeric_limits<double>::max()) || (dtree.SubtreeLeaves() == 1));
+    Log::Assert((alpha < std::numeric_limits<double>::max())
+        || (dtree.SubtreeLeaves() == 1));
     Log::Assert(alpha > oldAlpha);
     Log::Assert(dtree.SubtreeLeavesLogNegError() <= treeSeq.second);
   }
@@ -191,7 +192,8 @@ DTree<MatType, TagType>* mlpack::det::Trainer(MatType& dataset,
   {
     // Break up data into train and test sets.
     const size_t start = fold * testSize;
-    const size_t end = std::min((size_t) (fold + 1) * testSize, (size_t) cvData.n_cols);
+    const size_t end = std::min((size_t) (fold + 1)
+        * testSize, (size_t) cvData.n_cols);

     MatType test = cvData.cols(start, end - 1);
     MatType train(cvData.n_rows, cvData.n_cols - test.n_cols);
@@ -242,7 +244,8 @@ DTree<MatType, TagType>* mlpack::det::Trainer(MatType& dataset,
       cvRegularizationConstants[i] += 2.0 * cvVal / (double) cvData.n_cols;

       // Determine the new alpha value and prune accordingly.
-      double cvOldAlpha = 0.5 * (prunedSequence[i + 1].first + prunedSequence[i + 2].first);
+      double cvOldAlpha = 0.5 * (prunedSequence[i + 1].first
+          + prunedSequence[i + 2].first);
       cvDTree.PruneAndUpdate(cvOldAlpha, train.n_cols, useVolumeReg);
     }

@@ -255,7 +258,8 @@ DTree<MatType, TagType>* mlpack::det::Trainer(MatType& dataset,
     }

     if (prunedSequence.size() > 2)
-      cvRegularizationConstants[prunedSequence.size() - 2] += 2.0 * cvVal / (double) cvData.n_cols;
+      cvRegularizationConstants[prunedSequence.size() - 2] += 2.0 * cvVal
+          / (double) cvData.n_cols;

     #pragma omp critical (DTreeCVUpdate)
     regularizationConstants += cvRegularizationConstants;
@@ -293,7 +297,11 @@ DTree<MatType, TagType>* mlpack::det::Trainer(MatType& dataset,

   // Grow the tree.
   oldAlpha = -DBL_MAX;
-  alpha = dtreeOpt->Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize, minLeafSize);
+  alpha = dtreeOpt->Grow(newDataset,
+                         oldFromNew,
+                         useVolumeReg,
+                         maxLeafSize,
+                         minLeafSize);

   // Prune with optimal alpha.
   while ((oldAlpha < optimalAlpha) && (dtreeOpt->SubtreeLeaves() > 1))
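[Editorial aside on the parallel pattern visible in this file: the cross-validation folds run in a parallel loop, each thread accumulates into a fold-local vector, and the shared total is updated under the named critical section DTreeCVUpdate. A self-contained sketch of that pattern, with names borrowed from the diff and the loop body reduced to a stand-in:]

```cpp
// Sketch of the accumulate-under-named-critical pattern used above.
#include <armadillo>

int main()
{
  const size_t folds = 10, numConstants = 5;
  arma::vec regularizationConstants(numConstants, arma::fill::zeros);

  #pragma omp parallel for
  for (int fold = 0; fold < (int) folds; ++fold)
  {
    // Thread-local partial result for this fold (stand-in for real CV work).
    arma::vec cvRegularizationConstants(numConstants, arma::fill::zeros);
    cvRegularizationConstants.fill(1.0 / folds);

    // Merge; naming the section keeps it from serializing against other
    // unrelated critical sections (like DTreeFindUpdate below).
    #pragma omp critical (DTreeCVUpdate)
    regularizationConstants += cvRegularizationConstants;
  }

  regularizationConstants.print("accumulated:");
  return 0;
}
```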
62 changes: 38 additions & 24 deletions src/mlpack/methods/det/dtree_impl.hpp
@@ -23,7 +23,8 @@ namespace details
  * in a vector, that can easily be iterated afterwards.
  */
 template <typename MatType>
-void ExtractSplits(std::vector<std::pair<typename MatType::elem_type, size_t>>& splitVec,
+void ExtractSplits(std::vector<
+        std::pair<typename MatType::elem_type, size_t>>& splitVec,
                    const MatType& data,
                    size_t dim,
                    size_t start,
@@ -90,7 +91,8 @@ namespace details
       lastVal = ElemType(0);
     }

-    if (i + padding >= minLeafSize && i + padding <= n_elem - minLeafSize)// the normal case
+    // the normal case
+    if (i + padding >= minLeafSize && i + padding <= n_elem - minLeafSize)
     {
       // This makes sense for real continuous data. This kinda corrupts the
       // data and estimation if the data is ordinal.
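[Editorial aside on the sparse case, not part of the diff: in an arma::sp_mat only nonzero entries are stored, so iterating a column visits just those values; the remaining entries are implicit zeros. The padding term above appears to account for those implicit zeros when checking the minLeafSize bounds. A minimal sketch of the idea with plain Armadillo:]

```cpp
#include <armadillo>
#include <iostream>

int main()
{
  // A 6x1 sparse column with two stored (nonzero) values.
  arma::sp_mat data(6, 1);
  data(1, 0) = 2.0;
  data(4, 0) = -1.5;

  // Iterating a sparse column only visits stored values...
  size_t stored = 0;
  for (auto it = data.begin_col(0); it != data.end_col(0); ++it)
    ++stored;

  // ...so the implicit zeros must be tracked separately, which is what
  // a padding count does, in spirit, in the split-extraction code above.
  const size_t padding = data.n_rows - stored;  // here: 6 - 2 = 4

  std::cout << "stored: " << stored << ", implicit zeros: " << padding << "\n";
  return 0;
}
```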
@@ -278,8 +280,6 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
   for (size_t dim = 0; dim < maxVals.n_elem; ++dim)
 #endif
   {
-    // Have to deal with REAL, INTEGER, NOMINAL data differently, so we have to
-    // think of how to do that...
     const ElemType min = minVals[dim];
     const ElemType max = maxVals[dim];

@@ -329,8 +329,10 @@
         // and because the volume is only dependent on the dimension we are
         // splitting, we can assume V_l is just the range of the left and V_r is
         // just the range of the right.
-        double negLeftError = std::pow(position + 1, 2.0) / (split - min);
-        double negRightError = std::pow(points - position - 1, 2.0) / (max - split);
+        double negLeftError = std::pow(position + 1, 2.0)
+            / (split - min);
+        double negRightError = std::pow(points - position - 1, 2.0)
+            / (max - split);

         // If this is better, take it.
         if ((negLeftError + negRightError) >= minDimError)
@@ -344,21 +346,23 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
         }
       }

-      double actualMinDimError = std::log(minDimError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
+      double actualMinDimError = std::log(minDimError)
+          - 2 * std::log((double) data.n_cols)
+          - volumeWithoutDim;

 #pragma omp critical (DTreeFindUpdate)
       if ((actualMinDimError > minError) && dimSplitFound)
       {
-        // Calculate actual error (in logspace) by adding terms back to our
-        // estimate.
-        minError = actualMinDimError;
-        splitDim = dim;
-        splitValue = dimSplitValue;
-        leftError = std::log(dimLeftError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
-        rightError = std::log(dimRightError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
-        splitFound = true;
-      }
+        // Calculate actual error (in logspace) by adding terms back to our
+        // estimate.
+        minError = actualMinDimError;
+        splitDim = dim;
+        splitValue = dimSplitValue;
+        leftError = std::log(dimLeftError) - 2 * std::log((double) data.n_cols)
+            - volumeWithoutDim;
+        rightError = std::log(dimRightError) - 2 * std::log((double) data.n_cols)
+            - volumeWithoutDim;
+        splitFound = true;
+      } // end if better split found in this dimension.
     }
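[Editorial aside for readers tracing the logspace arithmetic, not part of the diff: this bookkeeping follows the standard density estimation tree error, which this implementation appears to use. For a node t holding |t| of the N points in volume V(t):]

```latex
% DET node error, up to sign, and its logspace form:
\[
  -R(t) = \frac{|t|^2}{N^2 \, V(t)},
  \qquad
  \log\bigl(-R(t)\bigr) = 2\log|t| \;-\; 2\log N \;-\; \log V(t).
\]
```

[The per-dimension search above only tracks |t|^2 divided by the range along the candidate split dimension, so converting the winning candidate back to logspace means subtracting 2 log N and the log-volume of the remaining dimensions (volumeWithoutDim); that is exactly what actualMinDimError, leftError, and rightError compute.]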

@@ -451,8 +455,10 @@ double DTree<MatType, TagType>::Grow(MatType& data,
     left = new DTree(maxValsL, minValsL, start, splitIndex, leftError);
     right = new DTree(maxValsR, minValsR, splitIndex, end, rightError);

-    leftG = left->Grow(data, oldFromNew, useVolReg, maxLeafSize, minLeafSize);
-    rightG = right->Grow(data, oldFromNew, useVolReg, maxLeafSize, minLeafSize);
+    leftG = left->Grow(data, oldFromNew, useVolReg, maxLeafSize,
+        minLeafSize);
+    rightG = right->Grow(data, oldFromNew, useVolReg, maxLeafSize,
+        minLeafSize);

     // Store values of R(T~) and |T~|.
     subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
@@ -517,12 +523,15 @@ double DTree<MatType, TagType>::Grow(MatType& data,

     if (right->SubtreeLeaves() > 1)
     {
-      const double exponent = 2 * std::log((double) data.n_cols) + logVolume + right->AlphaUpper();
+      const double exponent = 2 * std::log((double) data.n_cols)
+          + logVolume
+          + right->AlphaUpper();

       tmpAlphaSum += std::exp(exponent);
     }

-    alphaUpper = std::log(tmpAlphaSum) - 2 * std::log((double) data.n_cols) - logVolume;
+    alphaUpper = std::log(tmpAlphaSum) - 2 * std::log((double) data.n_cols)
+        - logVolume;

     double gT;
     if (useVolReg)
@@ -689,7 +698,9 @@ double DTree<MatType, TagType>::ComputeValue(const VecType& query) const
   else
   {
     // Return either of the two children - left or right, depending on the splitValue
-    return (query[splitDim] <= splitValue) ? left->ComputeValue(query) : right->ComputeValue(query);
+    return (query[splitDim] <= splitValue) ?
+        left->ComputeValue(query) :
+        right->ComputeValue(query);
   }

   return 0.0;
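[Editorial aside, not part of the diff: once a tree has been grown, density queries descend recursively from the root through these split comparisons. A minimal sketch, under the assumption that the constructor and the Grow()/ComputeValue() signatures match how this diff uses them:]

```cpp
// Hypothetical sketch of growing a DTree and querying a density estimate;
// the constructor and Grow() arguments are assumed from this diff.
#include <mlpack/methods/det/dtree.hpp>

using namespace mlpack::det;

double EstimateAtFirstPoint(arma::mat& data)  // one point per column
{
  DTree<arma::mat, int> tree(data);  // assumed: root node over the data's bounding box

  // Grow() reorders columns in place and records the permutation here.
  arma::Col<size_t> oldFromNew(data.n_cols);
  for (size_t i = 0; i < data.n_cols; ++i)
    oldFromNew[i] = i;

  tree.Grow(data, oldFromNew, false /* useVolReg */, 10 /* maxLeafSize */,
      5 /* minLeafSize */);

  // ComputeValue() walks left or right at each split until it hits a leaf.
  arma::vec query = data.col(0);
  return tree.ComputeValue(query);
}
```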
@@ -725,12 +736,15 @@ TagType DTree<MatType, TagType>::FindBucket(const VecType& query) const
   else
   {
     // Return the tag from either of the two children - left or right.
-    return (query[splitDim] <= splitValue) ? left->FindBucket(query) : right->FindBucket(query);
+    return (query[splitDim] <= splitValue) ?
+        left->FindBucket(query) :
+        right->FindBucket(query);
   }
 }

 template <typename MatType, typename TagType>
-void DTree<MatType, TagType>::ComputeVariableImportance(arma::vec& importances) const
+void
+DTree<MatType, TagType>::ComputeVariableImportance(arma::vec& importances) const
 {
   // Clear and set to right size.
   importances.zeros(maxVals.n_elem);