Skip to content

Commit

Permalink
Python: RandomForest conforms to sklearn Estimator interface. Fix #98
Browse files Browse the repository at this point in the history
  • Loading branch information
Ilia-Shutov committed Aug 22, 2023
1 parent c798f62 commit 368fe2e
Show file tree
Hide file tree
Showing 33 changed files with 1,159 additions and 945 deletions.
4 changes: 2 additions & 2 deletions Python/docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ Retrieve the Tree Structure

This example shows how to retrieve the underlying structure of a tree in the forest. To do so,
call the :meth:`translate_tree() <forestry.RandomForest.translate_tree>` method,
which fills the :ref:`saved_forest <translate-label>` attribute for the corresponding tree.
which fills the :ref:`saved_forest_ <translate-label>` attribute for the corresponding tree.

.. code-block:: Python
Expand All @@ -315,5 +315,5 @@ which fills the :ref:`saved_forest <translate-label>` attribute for the correspo
# Translate the first tree in the forest
fr.translate_tree(0)
print(fr.saved_forest[0])
print(fr.saved_forest_[0])
104 changes: 59 additions & 45 deletions Python/extension/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ extern "C" {
std::vector<double>& coefs,
bool hier_shrinkage,
double lambda_shrinkage
) {
) {
DataFrame* dta_frame = reinterpret_cast<DataFrame *>(dataframe_pt);

forest->setDataframe(dta_frame);
Expand Down Expand Up @@ -498,7 +498,7 @@ extern "C" {
}


forestry* reconstructree(
forestry* reconstructForest(
void* data_ptr,
size_t ntree,
bool replace,
Expand Down Expand Up @@ -532,6 +532,8 @@ extern "C" {
int* na_left_count,
int* na_right_count,
int* na_default_directions,
int* average_counts,
int* split_counts,
size_t* split_idx,
size_t* average_idx,
double* predict_weights,
Expand Down Expand Up @@ -584,10 +586,10 @@ extern "C" {
std::unique_ptr< std::vector< std::vector<int> > > var_ids(
new std::vector< std::vector<int> >
);
std::unique_ptr< std::vector< std::vector<int> > > average_counts(
std::unique_ptr< std::vector< std::vector<int> > > averageCountsAll(
new std::vector< std::vector<int> >
);
std::unique_ptr< std::vector< std::vector<int> > > split_counts(
std::unique_ptr< std::vector< std::vector<int> > > splitCountsAll(
new std::vector< std::vector<int> >
);
std::unique_ptr< std::vector< std::vector<double> > > split_vals(
Expand Down Expand Up @@ -620,8 +622,8 @@ extern "C" {

// Reserve space for each of the vectors equal to ntree
var_ids->reserve(ntree);
average_counts->reserve(ntree);
split_counts->reserve(ntree);
averageCountsAll->reserve(ntree);
splitCountsAll->reserve(ntree);
split_vals->reserve(ntree);
averagingSampleIndex->reserve(ntree);
splittingSampleIndex->reserve(ntree);
Expand All @@ -632,65 +634,78 @@ extern "C" {
treeSeeds->reserve(ntree);
predictWeights->reserve(ntree);

// Now actually populate the vectors
// Deserialization of the data serialized in forestry.py
size_t ind = 0, ind_s = 0, ind_a = 0;
for(size_t i = 0; i < ntree; i++){
// Should be num total nodes
std::vector<int> cur_var_ids((tree_counts[4*i]), 0);
std::vector<int> cur_average_counts((tree_counts[4*i]), 0);
std::vector<int> cur_split_counts((tree_counts[4*i]), 0);
std::vector<double> cur_split_vals(tree_counts[4*i], 0);
std::vector<int> curNaLeftCounts(tree_counts[4*i], 0);
std::vector<int> curNaRightCounts(tree_counts[4*i], 0);
std::vector<int> curNaDefaultDirections(tree_counts[4*i], 0);
std::vector<size_t> curSplittingSampleIndex(tree_counts[4*i+1], 0);
std::vector<size_t> curAveragingSampleIndex(tree_counts[4*i+2], 0);
std::vector<double> cur_predict_weights(tree_counts[4*i], 0);

for(size_t j = 0; j < tree_counts[4*i]; j++){
cur_split_vals.at(j) = thresholds[ind];
curNaLeftCounts.at(j) = na_left_count[ind];
curNaRightCounts.at(j) = na_right_count[ind];
curNaDefaultDirections.at(j) = na_default_directions[ind];
cur_predict_weights.at(j) = predict_weights[ind];
cur_var_ids.at(j) = features[ind];
cur_average_counts.at(j) = features[ind];
cur_split_counts.at(j) = features[ind];

for (size_t i = 0; i < ntree; i++) {
// Py: tree_counts[3 * i] = num_nodes
const size_t numNodes = tree_counts[3 * i];

std::vector<double> curThresholds(numNodes);
std::vector<int> curNaLeftCounts(numNodes);
std::vector<int> curNaRightCounts(numNodes);
std::vector<int> curNaDefaultDirections(numNodes);
std::vector<int> curAverageCounts(numNodes);
std::vector<int> curSplitCounts(numNodes);
std::vector<double> cur_predict_weights(numNodes);
std::vector<int> cur_var_ids(numNodes);

for (size_t j = 0; j < numNodes; j++){
// for split node
curThresholds[j] = thresholds[ind];
curNaLeftCounts[j] = na_left_count[ind];
curNaRightCounts[j] = na_right_count[ind];
curNaDefaultDirections[j] = na_default_directions[ind];

// for leaf node
curAverageCounts[j] = average_counts[ind];
curSplitCounts[j] = split_counts[ind];

// if < 0: flag that this is a leaf node
// if >= 0: split node, feature == var_id - 1
cur_var_ids[j] = features[ind];

// leaf node weight or split node threshold
cur_predict_weights[j] = predict_weights[ind];
ind++;
}

for(size_t j = 0; j < tree_counts[4*i+1]; j++){
curSplittingSampleIndex.at(j) = split_idx[ind_s];

// Py: tree_counts[3 * i + 1] = num_split_idx
const size_t numSplitIdx = tree_counts[3 * i + 1];
std::vector<size_t> curSplittingSampleIndex(numSplitIdx);
for (size_t j = 0; j < numSplitIdx; j++) {
curSplittingSampleIndex[j] = split_idx[ind_s];
ind_s++;
}

for(size_t j = 0; j < tree_counts[4*i+2]; j++){
curAveragingSampleIndex.at(j) = average_idx[ind_a];

// Py: tree_counts[3 * i + 2] = num_av_idx
const size_t numAvIdx = tree_counts[3 * i + 2];
std::vector<size_t> curAveragingSampleIndex(numAvIdx);
for (size_t j = 0; j < numAvIdx; j++) {
curAveragingSampleIndex[j] = average_idx[ind_a];
ind_a++;
}


treeSeeds->push_back(tree_seeds[i]);
var_ids->push_back(cur_var_ids);
average_counts->push_back(cur_average_counts);
split_counts->push_back(cur_split_counts);
split_vals->push_back(cur_split_vals);
averageCountsAll->push_back(curAverageCounts);
splitCountsAll->push_back(curSplitCounts);
split_vals->push_back(curThresholds);
naLeftCounts->push_back(curNaLeftCounts);
naRightCounts->push_back(curNaRightCounts);
naDefaultDirections->push_back(curNaDefaultDirections);
splittingSampleIndex->push_back(curSplittingSampleIndex);
averagingSampleIndex->push_back(curAveragingSampleIndex);
excludedSampleIndex->push_back(std::vector<size_t>());
predictWeights->push_back(cur_predict_weights);
treeSeeds->push_back(tree_seeds[i]);
}

// call reconstructTrees
forest->reconstructTrees(
categoricalFeatureCols_copy,
treeSeeds,
var_ids,
average_counts,
split_counts,
averageCountsAll,
splitCountsAll,
split_vals,
naLeftCounts,
naRightCounts,
Expand All @@ -700,9 +715,8 @@ extern "C" {
excludedSampleIndex,
predictWeights
);

return forest;

}

size_t get_node_count(forestry* forest, int tree_idx) {
Expand Down
4 changes: 3 additions & 1 deletion Python/extension/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ extern "C" {
size_t numColumns,
unsigned int seed
);
forestry* reconstructree(
forestry* reconstructForest(
void* data_ptr,
size_t ntree,
bool replace,
Expand Down Expand Up @@ -92,6 +92,8 @@ extern "C" {
int* na_left_count,
int* na_right_count,
int* na_default_directions,
int* average_counts,
int* split_counts,
size_t* split_idx,
size_t* average_idx,
double* predict_weights,
Expand Down
14 changes: 8 additions & 6 deletions Python/extension/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ void *get_data_wrapper(
);
}

forestry *reconstructree_wrapper(
forestry *reconstructForestWrapper(
void *data_ptr,
size_t ntree,
bool replace,
Expand Down Expand Up @@ -120,12 +120,14 @@ forestry *reconstructree_wrapper(
py::array_t<int> na_left_count,
py::array_t<int> na_right_count,
py::array_t<int> na_default_directions,
py::array_t<int> average_counts,
py::array_t<int> split_counts,
py::array_t<size_t> split_idx,
py::array_t<size_t> average_idx,
py::array_t<double> predict_weights,
py::array_t<unsigned int> tree_seeds
) {
return reconstructree(data_ptr,
return reconstructForest(data_ptr,
ntree,
replace,
sampSize,
Expand Down Expand Up @@ -158,6 +160,8 @@ forestry *reconstructree_wrapper(
static_cast<int *>(na_left_count.request().ptr),
static_cast<int *>(na_right_count.request().ptr),
static_cast<int *>(na_default_directions.request().ptr),
static_cast<int *>(average_counts.request().ptr),
static_cast<int *>(split_counts.request().ptr),
static_cast<size_t *>(split_idx.request().ptr),
static_cast<size_t *>(average_idx.request().ptr),
static_cast<double *>(predict_weights.request().ptr),
Expand Down Expand Up @@ -361,10 +365,8 @@ PYBIND11_MODULE(extension, m)
Some other explanation about the getTreeLeafNodeCount function.
)pbdoc");
m.def("reconstruct_tree", &reconstructree_wrapper, py::return_value_policy::reference, R"pbdoc(
Some help text here
Some other explanation about the reconstructree function.
m.def("reconstruct_forest", &reconstructForestWrapper, py::return_value_policy::reference, R"pbdoc(
Create C++ forestry object
)pbdoc");
m.def("delete_forestry", &delete_forestry, R"pbdoc(
Some help text here
Expand Down
2 changes: 1 addition & 1 deletion Python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies = [
"numpy >= 1.23.5, < 2",
"pandas >= 1.4, < 2",
"statsmodels >= 0.13.5, < 1",
"pydantic >= 1.10.6, < 2",
"deepdiff >= 6.3.0, < 7",
"scikit-learn == 1.2.2",

# Conditional dependencies
Expand Down
Loading

0 comments on commit 368fe2e

Please sign in to comment.