Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: RandomForest conforms to sklearn Estimator interface #146

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Python/docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ Retrieve the Tree Structure

This is an example of how to retrieve the underlying tree structure in the forest. To do that,
we need to use the :meth:`translate_tree() <forestry.RandomForest.translate_tree>` function,
which fills the :ref:`saved_forest <translate-label>` attribute for the corresponding tree.
which fills the :ref:`saved_forest_ <translate-label>` attribute for the corresponding tree.

.. code-block:: Python

Expand All @@ -315,5 +315,5 @@ which fills the :ref:`saved_forest <translate-label>` attribute for the correspo

# Translate the first tree in the forest
fr.translate_tree(0)
print(fr.saved_forest[0])
print(fr.saved_forest_[0])

104 changes: 59 additions & 45 deletions Python/extension/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ extern "C" {
std::vector<double>& coefs,
bool hier_shrinkage,
double lambda_shrinkage
) {
) {
DataFrame* dta_frame = reinterpret_cast<DataFrame *>(dataframe_pt);

forest->setDataframe(dta_frame);
Expand Down Expand Up @@ -498,7 +498,7 @@ extern "C" {
}


forestry* reconstructree(
forestry* reconstructForest(
void* data_ptr,
size_t ntree,
bool replace,
Expand Down Expand Up @@ -532,6 +532,8 @@ extern "C" {
int* na_left_count,
int* na_right_count,
int* na_default_directions,
int* average_counts,
int* split_counts,
size_t* split_idx,
size_t* average_idx,
double* predict_weights,
Expand Down Expand Up @@ -584,10 +586,10 @@ extern "C" {
std::unique_ptr< std::vector< std::vector<int> > > var_ids(
new std::vector< std::vector<int> >
);
std::unique_ptr< std::vector< std::vector<int> > > average_counts(
std::unique_ptr< std::vector< std::vector<int> > > averageCountsAll(
new std::vector< std::vector<int> >
);
std::unique_ptr< std::vector< std::vector<int> > > split_counts(
std::unique_ptr< std::vector< std::vector<int> > > splitCountsAll(
new std::vector< std::vector<int> >
);
std::unique_ptr< std::vector< std::vector<double> > > split_vals(
Expand Down Expand Up @@ -620,8 +622,8 @@ extern "C" {

// Reserve space for each of the vectors equal to ntree
var_ids->reserve(ntree);
average_counts->reserve(ntree);
split_counts->reserve(ntree);
averageCountsAll->reserve(ntree);
splitCountsAll->reserve(ntree);
split_vals->reserve(ntree);
averagingSampleIndex->reserve(ntree);
splittingSampleIndex->reserve(ntree);
Expand All @@ -632,65 +634,78 @@ extern "C" {
treeSeeds->reserve(ntree);
predictWeights->reserve(ntree);

// Now actually populate the vectors
// Deserialization of the data serialized in forestry.py
size_t ind = 0, ind_s = 0, ind_a = 0;
for(size_t i = 0; i < ntree; i++){
// Should be num total nodes
std::vector<int> cur_var_ids((tree_counts[4*i]), 0);
std::vector<int> cur_average_counts((tree_counts[4*i]), 0);
std::vector<int> cur_split_counts((tree_counts[4*i]), 0);
std::vector<double> cur_split_vals(tree_counts[4*i], 0);
std::vector<int> curNaLeftCounts(tree_counts[4*i], 0);
std::vector<int> curNaRightCounts(tree_counts[4*i], 0);
std::vector<int> curNaDefaultDirections(tree_counts[4*i], 0);
std::vector<size_t> curSplittingSampleIndex(tree_counts[4*i+1], 0);
std::vector<size_t> curAveragingSampleIndex(tree_counts[4*i+2], 0);
std::vector<double> cur_predict_weights(tree_counts[4*i], 0);

for(size_t j = 0; j < tree_counts[4*i]; j++){
cur_split_vals.at(j) = thresholds[ind];
curNaLeftCounts.at(j) = na_left_count[ind];
curNaRightCounts.at(j) = na_right_count[ind];
curNaDefaultDirections.at(j) = na_default_directions[ind];
cur_predict_weights.at(j) = predict_weights[ind];
cur_var_ids.at(j) = features[ind];
cur_average_counts.at(j) = features[ind];
cur_split_counts.at(j) = features[ind];

for (size_t i = 0; i < ntree; i++) {
// Py: tree_counts[3 * i] = num_nodes
const size_t numNodes = tree_counts[3 * i];

std::vector<double> curThresholds(numNodes);
std::vector<int> curNaLeftCounts(numNodes);
std::vector<int> curNaRightCounts(numNodes);
std::vector<int> curNaDefaultDirections(numNodes);
std::vector<int> curAverageCounts(numNodes);
std::vector<int> curSplitCounts(numNodes);
std::vector<double> cur_predict_weights(numNodes);
std::vector<int> cur_var_ids(numNodes);

for (size_t j = 0; j < numNodes; j++){
// for split node
curThresholds[j] = thresholds[ind];
curNaLeftCounts[j] = na_left_count[ind];
curNaRightCounts[j] = na_right_count[ind];
curNaDefaultDirections[j] = na_default_directions[ind];

// for leaf node
curAverageCounts[j] = average_counts[ind];
curSplitCounts[j] = split_counts[ind];

// if < 0: flag that this is a leaf node
// if >= 0: split node, feature == var_id - 1
cur_var_ids[j] = features[ind];

// leaf node weight or split node threshold
cur_predict_weights[j] = predict_weights[ind];
ind++;
}

for(size_t j = 0; j < tree_counts[4*i+1]; j++){
curSplittingSampleIndex.at(j) = split_idx[ind_s];

// Py: tree_counts[3 * i + 1] = num_split_idx
const size_t numSplitIdx = tree_counts[3 * i + 1];
std::vector<size_t> curSplittingSampleIndex(numSplitIdx);
for (size_t j = 0; j < numSplitIdx; j++) {
curSplittingSampleIndex[j] = split_idx[ind_s];
ind_s++;
}

for(size_t j = 0; j < tree_counts[4*i+2]; j++){
curAveragingSampleIndex.at(j) = average_idx[ind_a];

// Py: tree_counts[3 * i + 2] = num_av_idx
const size_t numAvIdx = tree_counts[3 * i + 2];
std::vector<size_t> curAveragingSampleIndex(numAvIdx);
for (size_t j = 0; j < numAvIdx; j++) {
curAveragingSampleIndex[j] = average_idx[ind_a];
ind_a++;
}


treeSeeds->push_back(tree_seeds[i]);
var_ids->push_back(cur_var_ids);
average_counts->push_back(cur_average_counts);
split_counts->push_back(cur_split_counts);
split_vals->push_back(cur_split_vals);
averageCountsAll->push_back(curAverageCounts);
splitCountsAll->push_back(curSplitCounts);
split_vals->push_back(curThresholds);
naLeftCounts->push_back(curNaLeftCounts);
naRightCounts->push_back(curNaRightCounts);
naDefaultDirections->push_back(curNaDefaultDirections);
splittingSampleIndex->push_back(curSplittingSampleIndex);
averagingSampleIndex->push_back(curAveragingSampleIndex);
excludedSampleIndex->push_back(std::vector<size_t>());
predictWeights->push_back(cur_predict_weights);
treeSeeds->push_back(tree_seeds[i]);
}

// call reconstructTrees
forest->reconstructTrees(
categoricalFeatureCols_copy,
treeSeeds,
var_ids,
average_counts,
split_counts,
averageCountsAll,
splitCountsAll,
split_vals,
naLeftCounts,
naRightCounts,
Expand All @@ -700,9 +715,8 @@ extern "C" {
excludedSampleIndex,
predictWeights
);

return forest;

}

size_t get_node_count(forestry* forest, int tree_idx) {
Expand Down
4 changes: 3 additions & 1 deletion Python/extension/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ extern "C" {
size_t numColumns,
unsigned int seed
);
forestry* reconstructree(
forestry* reconstructForest(
void* data_ptr,
size_t ntree,
bool replace,
Expand Down Expand Up @@ -92,6 +92,8 @@ extern "C" {
int* na_left_count,
int* na_right_count,
int* na_default_directions,
int* average_counts,
int* split_counts,
size_t* split_idx,
size_t* average_idx,
double* predict_weights,
Expand Down
14 changes: 8 additions & 6 deletions Python/extension/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ void *get_data_wrapper(
);
}

forestry *reconstructree_wrapper(
forestry *reconstructForestWrapper(
void *data_ptr,
size_t ntree,
bool replace,
Expand Down Expand Up @@ -120,12 +120,14 @@ forestry *reconstructree_wrapper(
py::array_t<int> na_left_count,
py::array_t<int> na_right_count,
py::array_t<int> na_default_directions,
py::array_t<int> average_counts,
py::array_t<int> split_counts,
py::array_t<size_t> split_idx,
py::array_t<size_t> average_idx,
py::array_t<double> predict_weights,
py::array_t<unsigned int> tree_seeds
) {
return reconstructree(data_ptr,
return reconstructForest(data_ptr,
ntree,
replace,
sampSize,
Expand Down Expand Up @@ -158,6 +160,8 @@ forestry *reconstructree_wrapper(
static_cast<int *>(na_left_count.request().ptr),
static_cast<int *>(na_right_count.request().ptr),
static_cast<int *>(na_default_directions.request().ptr),
static_cast<int *>(average_counts.request().ptr),
static_cast<int *>(split_counts.request().ptr),
static_cast<size_t *>(split_idx.request().ptr),
static_cast<size_t *>(average_idx.request().ptr),
static_cast<double *>(predict_weights.request().ptr),
Expand Down Expand Up @@ -361,10 +365,8 @@ PYBIND11_MODULE(extension, m)

Some other explanation about the getTreeLeafNodeCount function.
)pbdoc");
m.def("reconstruct_tree", &reconstructree_wrapper, py::return_value_policy::reference, R"pbdoc(
Some help text here

Some other explanation about the reconstructree function.
m.def("reconstruct_forest", &reconstructForestWrapper, py::return_value_policy::reference, R"pbdoc(
Create C++ forestry object
)pbdoc");
m.def("delete_forestry", &delete_forestry, R"pbdoc(
Some help text here
Expand Down
2 changes: 1 addition & 1 deletion Python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies = [
"numpy >= 1.23.5, < 2",
"pandas >= 1.4, < 2",
"statsmodels >= 0.13.5, < 1",
"pydantic >= 1.10.6, < 2",
"deepdiff >= 6.3.0, < 7",
"scikit-learn == 1.2.2",

# Conditional dependencies
Expand Down
Loading
Loading