Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add staged_predict functions to C API #2584

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
118 changes: 106 additions & 12 deletions catboost/libs/model_interface/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,16 +366,16 @@ CATBOOST_API bool SetPredictionTypeString(ModelCalcerHandle* modelHandle, const
return true;
}

CATBOOST_API bool CalcModelPredictionFlat(ModelCalcerHandle* modelHandle, size_t docCount, const float** floatFeatures, size_t floatFeaturesSize, double* result, size_t resultSize) {
CATBOOST_API bool CalcModelPredictionFlatStaged(ModelCalcerHandle* modelHandle, size_t docCount, size_t treeStart, size_t treeEnd, const float** floatFeatures, size_t floatFeaturesSize, double* result, size_t resultSize) {
try {
if (docCount == 1) {
FULL_MODEL_PTR(modelHandle)->CalcFlatSingle(TConstArrayRef<float>(*floatFeatures, floatFeaturesSize), TArrayRef<double>(result, resultSize));
FULL_MODEL_PTR(modelHandle)->CalcFlatSingle(TConstArrayRef<float>(*floatFeatures, floatFeaturesSize), treeStart, treeEnd, TArrayRef<double>(result, resultSize));
} else {
TVector<TConstArrayRef<float>> featuresVec(docCount);
for (size_t i = 0; i < docCount; ++i) {
featuresVec[i] = TConstArrayRef<float>(floatFeatures[i], floatFeaturesSize);
}
FULL_MODEL_PTR(modelHandle)->CalcFlat(featuresVec, TArrayRef<double>(result, resultSize));
FULL_MODEL_PTR(modelHandle)->CalcFlat(featuresVec, treeStart, treeEnd, TArrayRef<double>(result, resultSize));
}
} catch (...) {
Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
Expand All @@ -384,23 +384,37 @@ CATBOOST_API bool CalcModelPredictionFlat(ModelCalcerHandle* modelHandle, size_t
return true;
}

CATBOOST_API bool CalcModelPredictionFlatTransposed(ModelCalcerHandle* modelHandle, size_t docCount, const float** floatFeatures, size_t floatFeaturesSize, double* result, size_t resultSize) {
CATBOOST_API bool CalcModelPredictionFlat(ModelCalcerHandle* modelHandle, size_t docCount, const float** floatFeatures, size_t floatFeaturesSize, double* result, size_t resultSize) {
return CalcModelPredictionFlatStaged(modelHandle, docCount, 0, GetTreeCount(modelHandle), floatFeatures, floatFeaturesSize, result, resultSize);
}

CATBOOST_API bool CalcModelPredictionFlatTransposedStaged(
ModelCalcerHandle* modelHandle,
size_t docCount,
size_t treeStart, size_t treeEnd,
const float** floatFeatures, size_t floatFeaturesSize,
double* result, size_t resultSize) {
try {
TVector<TConstArrayRef<float>> featuresVec(floatFeaturesSize);
for (size_t i = 0; i < floatFeaturesSize; ++i) {
featuresVec[i] = TConstArrayRef<float>(floatFeatures[i], docCount);
}
FULL_MODEL_PTR(modelHandle)->CalcFlatTransposed(featuresVec, TArrayRef<double>(result, resultSize));
FULL_MODEL_PTR(modelHandle)->CalcFlatTransposed(featuresVec, treeStart, treeEnd, TArrayRef<double>(result, resultSize));
} catch (...) {
Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
return false;
}
return true;
}

CATBOOST_API bool CalcModelPrediction(
CATBOOST_API bool CalcModelPredictionFlatTransposed(ModelCalcerHandle* modelHandle, size_t docCount, const float** floatFeatures, size_t floatFeaturesSize, double* result, size_t resultSize) {
return CalcModelPredictionFlatTransposedStaged(modelHandle, docCount, 0, GetTreeCount(modelHandle), floatFeatures, floatFeaturesSize, result, resultSize);
}

CATBOOST_API bool CalcModelPredictionStaged(
ModelCalcerHandle* modelHandle,
size_t docCount,
size_t treeStart, size_t treeEnd,
const float** floatFeatures, size_t floatFeaturesSize,
const char*** catFeatures, size_t catFeaturesSize,
double* result, size_t resultSize) {
Expand All @@ -415,19 +429,36 @@ CATBOOST_API bool CalcModelPrediction(
catFeaturesVec[i][catFeatureIdx] = catFeatures[i][catFeatureIdx];
}
}
FULL_MODEL_PTR(modelHandle)->Calc(floatFeaturesVec, catFeaturesVec, TArrayRef<double>(result, resultSize));
FULL_MODEL_PTR(modelHandle)->Calc(floatFeaturesVec, catFeaturesVec, treeStart, treeEnd, TArrayRef<double>(result, resultSize));
} catch (...) {
Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
return false;
}
return true;
}

CATBOOST_API bool CalcModelPredictionText(
CATBOOST_API bool CalcModelPrediction(
ModelCalcerHandle* modelHandle,
size_t docCount,
const float** floatFeatures, size_t floatFeaturesSize,
const char*** catFeatures, size_t catFeaturesSize,
double* result, size_t resultSize) {
return CalcModelPredictionStaged(
modelHandle,
docCount,
0, GetTreeCount(modelHandle),
floatFeatures, floatFeaturesSize,
catFeatures, catFeaturesSize,
result, resultSize
);
}

CATBOOST_API bool CalcModelPredictionTextStaged(
ModelCalcerHandle* modelHandle,
size_t docCount,
size_t treeStart, size_t treeEnd,
const float** floatFeatures, size_t floatFeaturesSize,
const char*** catFeatures, size_t catFeaturesSize,
const char*** textFeatures, size_t textFeaturesSize,
double* result, size_t resultSize) {
try {
Expand All @@ -445,17 +476,43 @@ CATBOOST_API bool CalcModelPredictionText(
textFeaturesVec[i][textFeatureIdx] = textFeatures[i][textFeatureIdx];
}
}
FULL_MODEL_PTR(modelHandle)->Calc(floatFeaturesVec, catFeaturesVec, textFeaturesVec, TArrayRef<double>(result, resultSize));
FULL_MODEL_PTR(modelHandle)->Calc(
floatFeaturesVec,
catFeaturesVec,
textFeaturesVec,
treeStart,
treeEnd,
TArrayRef<double>(result, resultSize)
);
} catch (...) {
Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
return false;
}
return true;
}

CATBOOST_API bool CalcModelPredictionTextAndEmbeddings(
CATBOOST_API bool CalcModelPredictionText(
ModelCalcerHandle* modelHandle,
size_t docCount,
const float** floatFeatures, size_t floatFeaturesSize,
const char*** catFeatures, size_t catFeaturesSize,
const char*** textFeatures, size_t textFeaturesSize,
double* result, size_t resultSize) {
return CalcModelPredictionTextStaged(
modelHandle,
docCount,
0, GetTreeCount(modelHandle),
floatFeatures, floatFeaturesSize,
catFeatures, catFeaturesSize,
textFeatures, textFeaturesSize,
result, resultSize
);
}

CATBOOST_API bool CalcModelPredictionTextAndEmbeddingsStaged(
ModelCalcerHandle* modelHandle,
size_t docCount,
size_t treeStart, size_t treeEnd,
const float** floatFeatures, size_t floatFeaturesSize,
const char*** catFeatures, size_t catFeaturesSize,
const char*** textFeatures, size_t textFeaturesSize,
Expand Down Expand Up @@ -490,6 +547,8 @@ CATBOOST_API bool CalcModelPredictionTextAndEmbeddings(
catFeaturesVec,
textFeaturesVec,
embeddingFeaturesVec,
treeStart,
treeEnd,
TArrayRef<double>(result, resultSize)
);
} catch (...) {
Expand All @@ -499,8 +558,29 @@ CATBOOST_API bool CalcModelPredictionTextAndEmbeddings(
return true;
}

CATBOOST_API bool CalcModelPredictionSingle(
CATBOOST_API bool CalcModelPredictionTextAndEmbeddings(
ModelCalcerHandle* modelHandle,
size_t docCount,
const float** floatFeatures, size_t floatFeaturesSize,
const char*** catFeatures, size_t catFeaturesSize,
const char*** textFeatures, size_t textFeaturesSize,
const float*** embeddingFeatures, size_t* embeddingDimensions, size_t embeddingFeaturesSize,
double* result, size_t resultSize) {
return CalcModelPredictionTextAndEmbeddingsStaged(
modelHandle,
docCount,
0, GetTreeCount(modelHandle),
floatFeatures, floatFeaturesSize,
catFeatures, catFeaturesSize,
textFeatures, textFeaturesSize,
embeddingFeatures, embeddingDimensions, embeddingFeaturesSize,
result, resultSize
);
}

CATBOOST_API bool CalcModelPredictionSingleStaged(
ModelCalcerHandle* modelHandle,
size_t treeStart, size_t treeEnd,
const float* floatFeatures, size_t floatFeaturesSize,
const char** catFeatures, size_t catFeaturesSize,
double* result, size_t resultSize) {
Expand All @@ -513,14 +593,28 @@ CATBOOST_API bool CalcModelPredictionSingle(
for (size_t catFeatureIdx = 0; catFeatureIdx < catFeaturesSize; ++catFeatureIdx) {
catFeaturesVec[0][catFeatureIdx] = catFeatures[catFeatureIdx];
}
FULL_MODEL_PTR(modelHandle)->Calc(floatFeaturesVec, catFeaturesVec, TArrayRef<double>(result, resultSize));
FULL_MODEL_PTR(modelHandle)->Calc(floatFeaturesVec, catFeaturesVec, treeStart, treeEnd, TArrayRef<double>(result, resultSize));
} catch (...) {
Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
return false;
}
return true;
}

CATBOOST_API bool CalcModelPredictionSingle(
ModelCalcerHandle* modelHandle,
const float* floatFeatures, size_t floatFeaturesSize,
const char** catFeatures, size_t catFeaturesSize,
double* result, size_t resultSize) {
return CalcModelPredictionSingleStaged(
modelHandle,
0, GetTreeCount(modelHandle),
floatFeatures, floatFeaturesSize,
catFeatures, catFeaturesSize,
result, resultSize
);
}

CATBOOST_API bool CalcModelPredictionWithHashedCatFeatures(ModelCalcerHandle* modelHandle, size_t docCount,
const float** floatFeatures, size_t floatFeaturesSize,
const int** catFeatures, size_t catFeaturesSize,
Expand Down