Skip to content

Commit

Permalink
Treelite JSON export: C++, R, Python implementation (#144)
Browse files Browse the repository at this point in the history
* Fix #114 - added export to Treelite JSON in Python and R packages
  • Loading branch information
Ilia-Shutov committed Aug 17, 2023
1 parent 4a6080d commit cc8f041
Show file tree
Hide file tree
Showing 27 changed files with 426 additions and 72 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/R.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: recursive

- name: Setup R
uses: r-lib/actions/setup-r@v2
Expand Down Expand Up @@ -74,6 +76,8 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: recursive

- name: Build package
run: R CMD build R/
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/cpp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: recursive

- name: Install dependencies
run: |
Expand All @@ -25,7 +27,7 @@ jobs:
- name: Lint
working-directory: src/
run: |
clang-tidy *.cpp -header-filter="^.*/src/.*" --use-color -- -I/usr/lib/R/site-library/RcppArmadillo/include -I/usr/lib/R/site-library/Rcpp/include -I/usr/share/R/include -I/RcppThread-2.1.2/inst/include
clang-tidy *.cpp -header-filter="^.*/src/.*" --use-color -- -Irapidjson/include -I/usr/lib/R/site-library/RcppArmadillo/include -I/usr/lib/R/site-library/Rcpp/include -I/usr/share/R/include -I/RcppThread-2.1.2/inst/include
testing:
runs-on: ubuntu-22.04
steps:
Expand Down
5 changes: 4 additions & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,7 @@
path = tests/cpp/include/Catch2
url = https://github.com/catchorg/Catch2.git
branch = devel

[submodule "rapidjson"]
path = src/rapidjson
url = https://github.com/Tencent/rapidjson.git
branch = master
15 changes: 13 additions & 2 deletions Python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,24 @@ For full documentation, see the documentation site (https://random-forestry.read
git clone --recursive https://github.com/forestry-labs/Rforestry.git
cd Rforestry/Python

conda create -n rforestry python pandas build pytest pytest-xdist pytest-sugar pytest-cov
conda create -n rforestry python pandas build pytest pytest-xdist pytest-sugar pytest-cov mypy
conda activate rforestry

python -m build --sdist
pip install dist/random-forestry-*.tar.gz
pytest tests/
```
4. To run and debug from an IDE without installing the full package, run the following script from the same `Rforestry/Python` folder. It builds the binary of the C++ extension and generates Python stubs for it:
```bash
mkdir build

pushd build
cmake -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../random_forestry ../extension
cmake --build .
popd

PYTHONPATH=random_forestry stubgen -m extension -o .
```


### Python Package Usage
Expand All @@ -45,7 +56,7 @@ Then the python code can be called:
import numpy as np
import pandas as pd
from random import randrange
from Rforestry import RandomForest
from random_forestry import RandomForest
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

Expand Down
1 change: 1 addition & 0 deletions Python/extension/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ endif()
set_target_properties(extension PROPERTIES CXX_VISIBILITY_PRESET "hidden" CUDA_VISIBILITY_PRESET "hidden")

target_include_directories(extension PRIVATE src)
target_include_directories(extension PRIVATE src/rapidjson/include)

# VERSION_INFO is defined by setup.py and passed into the C++ code as a
# define (VERSION_INFO) here.
Expand Down
37 changes: 16 additions & 21 deletions Python/extension/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ extern "C" {
}


void* train_forest(
forestry* train_forest(
void* data_ptr,
size_t ntree,
bool replace,
Expand Down Expand Up @@ -247,7 +247,7 @@ extern "C" {
}

void predict_forest(
void* forest_pt,
forestry* forest,
void* dataframe_pt,
double* test_data,
unsigned int seed,
Expand All @@ -264,9 +264,6 @@ extern "C" {
bool hier_shrinkage,
double lambda_shrinkage
) {


forestry* forest = reinterpret_cast<forestry *>(forest_pt);
DataFrame* dta_frame = reinterpret_cast<DataFrame *>(dataframe_pt);

forest->setDataframe(dta_frame);
Expand Down Expand Up @@ -376,7 +373,7 @@ extern "C" {


void predictOOB_forest(
void* forest_pt,
forestry* forest,
void* dataframe_pt,
double* test_data,
bool doubleOOB,
Expand All @@ -390,9 +387,8 @@ extern "C" {
double lambda_shrinkage
) {
if (verbose)
std::cout << forest_pt << std::endl;
std::cout << forest << std::endl;

forestry* forest = reinterpret_cast<forestry *>(forest_pt);
DataFrame* dta_frame = reinterpret_cast<DataFrame *>(dataframe_pt);
forest->setDataframe(dta_frame);

Expand Down Expand Up @@ -456,15 +452,13 @@ extern "C" {
}

void fill_tree_info(
void* forest_ptr,
forestry* forest,
int tree_idx,
std::vector<double>& treeInfo,
std::vector<int>& split_info,
std::vector<int>& av_info
) {

forestry* forest = reinterpret_cast<forestry *>(forest_ptr);

std::unique_ptr<tree_info> info_holder;

info_holder = forest->getForest()->at(tree_idx)->getTreeInfo(forest->getTrainingData());
Expand Down Expand Up @@ -504,7 +498,7 @@ extern "C" {
}


void* reconstructree(
forestry* reconstructree(
void* data_ptr,
size_t ntree,
bool replace,
Expand Down Expand Up @@ -710,24 +704,25 @@ extern "C" {
return forest;

}

size_t get_node_count(void* forest_pt, int tree_idx) {
forestry* forest = reinterpret_cast<forestry *>(forest_pt);

// Returns getNodeCount() of the tree stored at position `tree_idx` in the
// forest's tree collection.
// `forest` is dereferenced without a null check, so callers must pass a valid
// pointer. The lookup goes through at(); assuming getForest() yields a standard
// container, an out-of-range index throws instead of reading out of bounds —
// TODO confirm.
size_t get_node_count(forestry* forest, int tree_idx) {
return(forest->getForest()->at(tree_idx)->getNodeCount());
}

size_t get_split_node_count(void* forest_pt, int tree_idx) {
forestry* forest = reinterpret_cast<forestry *>(forest_pt);
// Returns getSplitNodeCount() of the tree stored at position `tree_idx` in the
// forest's tree collection.
// `forest` is dereferenced without a null check, so callers must pass a valid
// pointer. The lookup goes through at(); assuming getForest() yields a standard
// container, an out-of-range index throws instead of reading out of bounds —
// TODO confirm.
size_t get_split_node_count(forestry* forest, int tree_idx) {
return(forest->getForest()->at(tree_idx)->getSplitNodeCount());
}

size_t get_leaf_node_count(void* forest_pt, int tree_idx) {
forestry* forest = reinterpret_cast<forestry *>(forest_pt);
// Returns getLeafNodeCount() of the tree stored at position `tree_idx` in the
// forest's tree collection.
// `forest` is dereferenced without a null check, so callers must pass a valid
// pointer. The lookup goes through at(); assuming getForest() yields a standard
// container, an out-of-range index throws instead of reading out of bounds —
// TODO confirm.
size_t get_leaf_node_count(forestry* forest, int tree_idx) {
return(forest->getForest()->at(tree_idx)->getLeafNodeCount());
}

void delete_forestry(void* forest_pt, void* dataframe_pt) {
void delete_forestry(forestry* forest, void* dataframe_pt) {
delete(reinterpret_cast<DataFrame* >(dataframe_pt));
delete(reinterpret_cast<forestry* >(forest_pt));
delete(forest);
}
}

// Produces a JSON string for `forest` by delegating to exportJson().
//
// forest   - forest to export; dereferenced without a null check, so it must
//            point to a valid object.
// colSds   - forwarded unchanged to exportJson(); presumably per-column
//            standard deviations — confirm against exportJson().
// colMeans - forwarded unchanged to exportJson(); presumably per-column
//            means — confirm against exportJson().
//
// Returns whatever exportJson() yields for these arguments.
std::string export_json(forestry* forest, const std::vector<double>& colSds, const std::vector<double>& colMeans) {
    forestry& forestRef = *forest;
    return exportJson(forestRef, colSds, colMeans);
}
36 changes: 14 additions & 22 deletions Python/extension/api.h
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
#pragma once

#include <vector>
#include <string>
#include <iostream>
#include <random>
#include "forestry.h"
#include "utils.h"

#ifndef FORESTRYCPP_API_H
#define FORESTRYCPP_API_H

#endif //FORESTRYCPP_API_H


extern "C" {
void* train_forest(
forestry* train_forest(
void* data_ptr,
size_t ntree,
bool replace,
Expand Down Expand Up @@ -61,7 +58,7 @@ extern "C" {
size_t numColumns,
unsigned int seed
);
void* reconstructree(
forestry* reconstructree(
void* data_ptr,
size_t ntree,
bool replace,
Expand Down Expand Up @@ -101,7 +98,7 @@ extern "C" {
unsigned int* tree_seeds
);
void predictOOB_forest(
void* forest_pt,
forestry* forest,
void* dataframe_pt,
double* test_data,
bool doubleOOB,
Expand All @@ -115,7 +112,7 @@ extern "C" {
double lambda_shrinkage
);
void predict_forest(
void* forest_pt,
forestry* forest,
void* dataframe_pt,
double* test_data,
unsigned int seed,
Expand All @@ -133,21 +130,16 @@ extern "C" {
double lambda_shrinkage = 0
);
void fill_tree_info(
void* forest_ptr,
int tree_idx,
std::vector<double>& treeInfo,
std::vector<int>& split_info,
std::vector<int>& av_info
);
void fill_tree_info(
void* forest_ptr,
forestry* forest,
int tree_idx,
std::vector<double>& treeInfo,
std::vector<int>& split_info,
std::vector<int>& av_info
);
size_t get_node_count(void* forest_pt, int tree_idx);
size_t get_split_node_count(void* forest_pt, int tree_idx);
size_t get_leaf_node_count(void* forest_pt, int tree_idx);
void delete_forestry(void* forest_pt, void* dataframe_pt);
}
size_t get_node_count(forestry* forest, int tree_idx);
size_t get_split_node_count(forestry* forest, int tree_idx);
size_t get_leaf_node_count(forestry* forest, int tree_idx);
void delete_forestry(forestry* forest, void* dataframe_pt);
}

std::string export_json(forestry* forest, const std::vector<double>& colSds, const std::vector<double>& colMeans);
Loading

0 comments on commit cc8f041

Please sign in to comment.