forked from mothur/mothur
-
Notifications
You must be signed in to change notification settings - Fork 1
/
decisiontree.hpp
executable file
·77 lines (58 loc) · 2.7 KB
/
decisiontree.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
//
// decisiontree.hpp
// rrf-fs-prototype
//
// Created by Abu Zaher Faridee on 5/28/12.
// Copyright (c) 2012 Schloss Lab. All rights reserved.
//
#ifndef RF_DECISIONTREE_HPP
#define RF_DECISIONTREE_HPP
#include "macros.h"
#include "rftreenode.hpp"
#include "abstractdecisiontree.hpp"
/***********************************************************************/
// Comparator for pair<int, int> values that orders by the pair's second
// element, largest first. The first element takes no part in the comparison,
// so two pairs with equal second elements are treated as equivalent.
struct VariableRankDescendingSorter {
    // The call operator is const-qualified so the functor can be invoked
    // through a const object or const reference, which is what the standard
    // library's Compare requirements expect of a comparison object.
    // std::pair is spelled out so the header does not depend on a
    // using-directive pulled in by another include.
    //
    // Returns true when firstPair must be placed before secondPair, i.e. when
    // firstPair.second is strictly greater than secondPair.second.
    bool operator() (const std::pair<int, int>& firstPair, const std::pair<int, int>& secondPair) const {
        return firstPair.second > secondPair.second;
    }
};
// Comparator for pair<int, double> values that orders by the pair's second
// element, largest first. The first element takes no part in the comparison,
// so two pairs with equal second elements are treated as equivalent.
struct VariableRankDescendingSorterDouble {
    // The call operator is const-qualified so the functor can be invoked
    // through a const object or const reference, which is what the standard
    // library's Compare requirements expect of a comparison object.
    // std::pair is spelled out so the header does not depend on a
    // using-directive pulled in by another include.
    //
    // Returns true when firstPair must be placed before secondPair, i.e. when
    // firstPair.second is strictly greater than secondPair.second.
    bool operator() (const std::pair<int, double>& firstPair, const std::pair<int, double>& secondPair) const {
        return firstPair.second > secondPair.second;
    }
};
/***********************************************************************/
// Decision tree type that publicly derives from AbstractDecisionTree.
// RandomForest is declared a friend, so it can reach the private members
// declared below.
//
// NOTE(review): only declarations are visible in this header. Wherever a
// comment below says "presumably", the description is inferred from the name
// alone and should be confirmed against the implementation file.
class DecisionTree: public AbstractDecisionTree{
friend class RandomForest;
public:
// Constructor. baseDataSet is taken by non-const reference; every other
// argument is taken by value (i.e. copied at the call site).
// featureStandardDeviationThreshold presumably initialises the data member of
// the same name -- TODO confirm in the implementation.
DecisionTree(vector< vector<int> >& baseDataSet,
vector<int> globalDiscardedFeatureIndices,
OptimumFeatureSubsetSelector optimumFeatureSubsetSelector,
string treeSplitCriterion,
float featureStandardDeviationThreshold);
// Destructor: passes rootNode to deleteTreeNodesRecursively().
// rootNode is not declared in this class; presumably it is inherited from
// AbstractDecisionTree -- verify there.
virtual ~DecisionTree(){ deleteTreeNodesRecursively(rootNode); }
// Writes its results through the numCorrect and treeErrorRate reference
// parameters. The meaning of the int return value is not visible here.
int calcTreeVariableImportanceAndError(int& numCorrect, double& treeErrorRate);
// Takes one sample by value and returns an int; presumably the predicted
// class for that sample -- TODO confirm.
int evaluateSample(vector<int> testSample);
// Same out-parameter shape as calcTreeVariableImportanceAndError(): results
// are written through numCorrect and treeErrorRate.
int calcTreeErrorRate(int& numCorrect, double& treeErrorRate);
// Reads `samples` (const) and writes into `shuffledSample` (non-const
// reference). Presumably permutes the column at featureIndex; the role of
// prevFeatureIndex is not visible here -- TODO confirm.
void randomlyShuffleAttribute(const vector< vector<int> >& samples,
const int featureIndex,
const int prevFeatureIndex,
vector< vector<int> >& shuffledSample);
// Convenience wrapper: starts purgeTreeNodesDataRecursively() at rootNode and
// discards its int return value.
void purgeDataSetsFromTree() { purgeTreeNodesDataRecursively(rootNode); }
// Recursive worker used by purgeDataSetsFromTree(). Presumably releases the
// data held by treeNode and its descendants -- TODO confirm.
int purgeTreeNodesDataRecursively(RFTreeNode* treeNode);
// Pruning entry point and its per-node worker. The valid range and meaning of
// pruneAggressiveness are not visible in this header.
void pruneTree(double pruneAggressiveness);
void pruneRecursively(RFTreeNode* treeNode, double pruneAggressiveness);
// Presumably walks from treeNode with testSample (taken by value) and updates
// per-node misclassification counts -- TODO confirm.
void updateMisclassificationCountRecursively(RFTreeNode* treeNode, vector<int> testSample);
// Presumably recomputes the output class stored on treeNode -- TODO confirm.
void updateOutputClassOfNode(RFTreeNode* treeNode);
private:
// Tree-construction helpers. NOTE(review): splitRecursively's parameter is
// named rootNode, the same identifier the inline members above pass as an
// argument; inside the definition the parameter will shadow any member of
// that name.
void buildDecisionTree();
int splitRecursively(RFTreeNode* rootNode);
int findAndUpdateBestFeatureToSplitOn(RFTreeNode* node);
// Both discarded-index lists are taken by value; returns a vector<int>,
// presumably the chosen feature indices -- TODO confirm.
vector<int> selectFeatureSubsetRandomly(vector<int> globalDiscardedFeatureIndices, vector<int> localDiscardedFeatureIndices);
// Presumably a debugging aid that prints the subtree at treeNode labelled with
// caption -- TODO confirm.
int printTree(RFTreeNode* treeNode, string caption);
// Called by the destructor with rootNode; presumably frees treeNode and all
// of its descendants -- TODO confirm.
void deleteTreeNodesRecursively(RFTreeNode* treeNode);
// Data members. Their exact contents are set by code outside this header.
// variableImportanceList: presumably one entry per feature -- TODO confirm.
vector<int> variableImportanceList;
// outOfBagEstimates: int -> int map; key/value meaning not visible here.
map<int, int> outOfBagEstimates;
// Shares its name with the constructor's last parameter.
float featureStandardDeviationThreshold;
};
#endif