Skip to content

Commit

Permalink
[R] fix EncodeChar
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed Nov 27, 2017
1 parent 8a5ec36 commit f42e6c3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 35 deletions.
15 changes: 5 additions & 10 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,23 +273,21 @@ class Booster {
return ret;
}

#pragma warning(disable : 4996)
int GetEvalNames(char** out_strs) const {
int idx = 0;
for (const auto& metric : train_metric_) {
for (const auto& name : metric->GetName()) {
std::strcpy(out_strs[idx], name.c_str());
std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
++idx;
}
}
return idx;
}

#pragma warning(disable : 4996)
int GetFeatureNames(char** out_strs) const {
int idx = 0;
for (const auto& name : boosting_->FeatureNames()) {
std::strcpy(out_strs[idx], name.c_str());
std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
++idx;
}
return idx;
Expand Down Expand Up @@ -719,7 +717,6 @@ int LGBM_DatasetSetFeatureNames(
API_END();
}

#pragma warning(disable : 4996)
int LGBM_DatasetGetFeatureNames(
DatasetHandle handle,
char** feature_names,
Expand All @@ -729,7 +726,7 @@ int LGBM_DatasetGetFeatureNames(
auto inside_feature_name = dataset->feature_names();
*num_feature_names = static_cast<int>(inside_feature_name.size());
for (int i = 0; i < *num_feature_names; ++i) {
std::strcpy(feature_names[i], inside_feature_name[i].c_str());
std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1);
}
API_END();
}
Expand Down Expand Up @@ -1138,7 +1135,6 @@ int LGBM_BoosterSaveModel(BoosterHandle handle,
API_END();
}

#pragma warning(disable : 4996)
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int num_iteration,
int64_t buffer_len,
Expand All @@ -1149,12 +1145,11 @@ int LGBM_BoosterSaveModelToString(BoosterHandle handle,
std::string model = ref_booster->SaveModelToString(num_iteration);
*out_len = static_cast<int64_t>(model.size()) + 1;
if (*out_len <= buffer_len) {
std::strcpy(out_str, model.c_str());
std::memcpy(out_str, model.c_str(), *out_len);
}
API_END();
}

#pragma warning(disable : 4996)
int LGBM_BoosterDumpModel(BoosterHandle handle,
int num_iteration,
int64_t buffer_len,
Expand All @@ -1165,7 +1160,7 @@ int LGBM_BoosterDumpModel(BoosterHandle handle,
std::string model = ref_booster->DumpModel(num_iteration);
*out_len = static_cast<int64_t>(model.size()) + 1;
if (*out_len <= buffer_len) {
std::strcpy(out_str, model.c_str());
std::memcpy(out_str, model.c_str(), *out_len);
}
API_END();
}
Expand Down
33 changes: 8 additions & 25 deletions src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,14 @@
using namespace LightGBM;

LGBM_SE EncodeChar(LGBM_SE dest, const char* src, LGBM_SE buf_len, LGBM_SE actual_len) {
int str_len = static_cast<int>(std::strlen(src));
R_INT_PTR(actual_len)[0] = str_len;
size_t str_len = std::strlen(src);
if (str_len > INT32_MAX) {
Log::Fatal("Don't support large string in R-package.");
}
R_INT_PTR(actual_len)[0] = static_cast<int>(str_len);
if (R_AS_INT(buf_len) < str_len) { return dest; }
auto ptr = R_CHAR_PTR(dest);
int i = 0;
while (src[i] != '\0') {
ptr[i] = src[i];
++i;
}
std::memcpy(ptr, src, str_len);
return dest;
}

Expand Down Expand Up @@ -604,15 +603,7 @@ LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
int64_t out_len = 0;
std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
if (out_len < R_AS_INT(buffer_len)) {
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len);
} else {
if (out_len <= INT32_MAX) {
R_INT_PTR(actual_len)[0] = static_cast<int>(out_len);
} else {
Log::Fatal("Don't support large model in R package.");
}
}
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len);
R_API_END();
}

Expand All @@ -626,14 +617,6 @@ LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle,
int64_t out_len = 0;
std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
if (out_len < R_AS_INT(buffer_len)) {
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len);
} else {
if (out_len <= INT32_MAX) {
R_INT_PTR(actual_len)[0] = static_cast<int>(out_len);
} else {
Log::Fatal("Don't support large model in R package.");
}
}
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len);
R_API_END();
}

0 comments on commit f42e6c3

Please sign in to comment.