Treelite JSON export: C++, R, Python implementation #144

Merged 26 commits on Aug 17, 2023

@Ilia-Shutov (Collaborator) commented Aug 10, 2023

Introduces the exportTreeliteJson (C++) and export_treelite_json (R) methods.

New R package dependency, required to run the R tests: rjson

Linear forests are not supported; an exception is thrown in that case.

Tested manually by importing the exported model with the Python Treelite library and comparing prediction results (slight precision differences are present).
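
One way to make "slight precision differences" concrete is to compare the two prediction vectors against an explicit tolerance. The following is only a minimal sketch using the standard library: the function names, the tolerance values, and the sample numbers are all assumptions for illustration, not values taken from this PR.

```python
import math


def max_abs_diff(expected, actual):
    """Return the largest absolute element-wise difference."""
    if len(expected) != len(actual):
        raise ValueError("prediction vectors differ in length")
    return max((abs(e - a) for e, a in zip(expected, actual)), default=0.0)


def predictions_match(expected, actual, rel_tol=1e-5, abs_tol=1e-6):
    """True if every pair of predictions agrees within the tolerances.

    The default tolerances are illustrative assumptions, not project policy.
    """
    if len(expected) != len(actual):
        return False
    return all(
        math.isclose(e, a, rel_tol=rel_tol, abs_tol=abs_tol)
        for e, a in zip(expected, actual)
    )


# Made-up numbers standing in for the R and Treelite predictions
r_predictions = [5.1, 6.3, 5.8]
treelite_predictions = [5.1000004, 6.2999997, 5.8]

print(predictions_match(r_predictions, treelite_predictions))
print(max_abs_diff(r_predictions, treelite_predictions))
```

Reporting the maximum difference alongside the pass/fail result makes it easier to judge whether a mismatch is rounding noise or a real export bug.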

R manual test script used:

library(Rforestry)
library(rjson)  # rjson writes data frames column-wise and NA as "NA", which the Python script below expects

wd <- "~/Projects/Rforestry.data/"

set.seed(292313)
test_idx <- sample(nrow(iris), 11)
x_train <- iris[-test_idx, -1]
write(toJSON(x_train), file.path(wd, "R_x_train.json"))

y_train <- iris[-test_idx, 1]
x_test <- iris[test_idx, -1]
write(toJSON(x_test), file.path(wd, "R_x_test.json"))

rf <- forestry(x = x_train, y = y_train, ntree = 1, maxDepth = 1, seed = 2)
#rf <- forestry(x = x_train, y = y_train, nthread = 2)

x_test_predict <- predict(rf, x_test)
print(x_test_predict)
write(toJSON(x_test_predict), file.path(wd, "R_x_test_predict.json"))

write(toJSON(rf@categoricalFeatureMapping), file.path(wd, "R_categoricalFeatureMapping.json"))

export_json_ret <- export_treelite_json(rf)  # function name as given in the description above
writeLines(export_json_ret, file.path(wd, "export_json.json"))

Python manual test script used:

import json
import treelite
import numpy as np

with open('../Rforestry.data/export_json.json', 'r', encoding='utf-8') as f:
    json_str = f.read()
model = treelite.Model.import_from_json(json_str)

with open('../Rforestry.data/R_categoricalFeatureMapping.json', 'r', encoding='utf-8') as f:
    r_categorical_feature_mapping = json.load(f)

# Map each categorical column (1-based R index) to its {label: numeric code} lookup
col_mapping = {
    col_info['categoricalFeatureCol']: dict(
        zip(col_info['uniqueFeatureValues'], col_info['numericFeatureValues'])
    )
    for col_info in r_categorical_feature_mapping
}

with open('../Rforestry.data/R_x_test.json', 'r', encoding='utf-8') as f:
    r_x_test_json = json.load(f)

with open('../Rforestry.data/R_x_test_predict.json', 'r', encoding='utf-8') as f:
    r_x_test_predict_json = json.load(f)

# Transpose the column-oriented JSON into rows; R's "NA" strings become None
r_x_test = [
    [None if x == 'NA' else x for x in row]
    for row in zip(*r_x_test_json.values())
]

for row in r_x_test:
    for column_index_r, mapping in col_mapping.items():
        column_index = column_index_r - 1
        if row[column_index] is None:
            continue
        row[column_index] = mapping[row[column_index]]

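# dtype=float turns any remaining None entries into NaN instead of an object array
X = np.array(r_x_test, dtype=float)
Y_expected = np.array(r_x_test_predict_json)

result = treelite.gtil.predict(model, data=X)
print(result)
# Flatten the prediction so it stacks row-wise against the expected values
results_compare = np.vstack((Y_expected, np.ravel(result)))
print(results_compare)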
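
The remapping loop in the script above can also be exercised on its own, without any exported files. The sketch below is a hedged, self-contained version of that step: the function name `encode_categoricals`, the sample rows, and the numeric codes are hypothetical, and the real codes come from `rf@categoricalFeatureMapping`.

```python
def encode_categoricals(rows, col_mapping, missing="NA"):
    """Replace category labels with numeric codes.

    col_mapping keys are 1-based column indices, as written by R;
    entries equal to `missing` become None and are left unmapped.
    """
    encoded = []
    for row in rows:
        new_row = [None if value == missing else value for value in row]
        for col_index_r, mapping in col_mapping.items():
            col_index = col_index_r - 1  # R is 1-based, Python is 0-based
            if new_row[col_index] is not None:
                new_row[col_index] = mapping[new_row[col_index]]
        encoded.append(new_row)
    return encoded


# Hypothetical mapping for a 4-column frame whose last column is categorical
col_mapping = {4: {"setosa": 1, "versicolor": 2, "virginica": 3}}
rows = [
    [3.5, 1.4, 0.2, "setosa"],
    [2.8, "NA", 1.3, "virginica"],
    [3.0, 5.1, 1.8, "NA"],
]

print(encode_categoricals(rows, col_mapping))
```

Returning a new list rather than mutating the input keeps the raw rows available for debugging when a label is missing from the mapping.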

@edwardwliu (Collaborator) left a comment

Good progress so far, left some comments.

Resolved review threads (now outdated):
.github/workflows/R.yaml
R/R/forestry.R
src/include/license.txt
src/utils.cpp
R/tests/testthat/test-export_treelite_json.R
src/utils.h
Python package name updated Rforestry -> random_forestry
pybind11 code simplified: forest* everywhere, removed excess functions
Using py::return_value_policy::reference to prevent Python GC from deleting the trained tree object
@Ilia-Shutov Ilia-Shutov changed the title Treelite JSON export: C++, R implementation Treelite JSON export: C++, R, Python implementation Aug 11, 2023
@Ilia-Shutov (Collaborator, Author) commented

Pushed the Python part. No new tests were added; it was checked manually only.

I haven't changed the code much because, as I understand it, a PR from @petrovicboban related to the Python part is pending.

A few comments on the current pybind11 C++ code and what could be simplified:

  1. I'm sure that extern "C" { is not needed: as I understand it, we're not building a library that will later be linked against a pybind11 build made with a different compiler. They're built together, so even C++11 features could be used.
  2. My guess as to why there are void* instead of forestry*: by default, pybind11 tries to decide whether memory returned to Python by a function should be managed by the Python garbage collector. With void* it gives up and leaves memory management to C++, so the memory is never deleted.
    This is fixed by py::return_value_policy::reference. I checked that without it the GC seems to delete the memory and the tests crash immediately, every time.

@petrovicboban (Contributor) commented Aug 11, 2023

@Ilia-Shutov I'm a C++ novice; what you see there is just my best effort. If there's anything you think you can improve, feel free to submit a PR, and if the tests still pass, go for it. By the way, we don't have unit tests in C++; that's something that should be done eventually.

It also seems that you have domain knowledge, so feel free to explore the sklearn compat branch; hopefully you can improve it. I'm just a DevOps guy working on CI/CD and general Python, and I don't have any DS domain skills or knowledge.

@edwardwliu (Collaborator) commented

Nice, if possible, let's add a basic Python test since the core function behavior should not change. Can you create a separate PR for the extern "C" change? Generally, if the pipeline tests complete, we should be good to merge in the clean-up.

@Ilia-Shutov (Collaborator, Author) commented Aug 16, 2023

@edwardwliu Please don't merge yet. I've found a place where a void* -> forestry* change was missed; it needs to be fixed or serialization/deserialization will fail. I'll push a commit soon.

UPD: fixed. The serialization/deserialization test case is part of the next PR, which is related to sklearn.

@edwardwliu (Collaborator) commented

@Ilia-Shutov saw the new commit, is this ready to squash and merge?

@edwardwliu edwardwliu merged commit cc8f041 into master Aug 17, 2023
13 checks passed
@Ilia-Shutov Ilia-Shutov deleted the is/json_export branch August 22, 2023 12:29