Skip to content

Commit

Permalink
Pull request "Add staged_predict functions to C API" by @Mb-NextTime
Browse files Browse the repository at this point in the history
…from #2584

Pull Request resolved: #2584
  • Loading branch information
Mb-NextTime authored and andrey-khropov committed Feb 8, 2024
1 parent c7cd9df commit 0adc70a
Show file tree
Hide file tree
Showing 3 changed files with 274 additions and 12 deletions.
118 changes: 106 additions & 12 deletions catboost/libs/model_interface/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,16 +366,16 @@ CATBOOST_API bool SetPredictionTypeString(ModelCalcerHandle* modelHandle, const
return true;
}

CATBOOST_API bool CalcModelPredictionFlat(ModelCalcerHandle* modelHandle, size_t docCount, const float** floatFeatures, size_t floatFeaturesSize, double* result, size_t resultSize) {
CATBOOST_API bool CalcModelPredictionFlatStaged(ModelCalcerHandle* modelHandle, size_t docCount, size_t treeStart, size_t treeEnd, const float** floatFeatures, size_t floatFeaturesSize, double* result, size_t resultSize) {
try {
if (docCount == 1) {
FULL_MODEL_PTR(modelHandle)->CalcFlatSingle(TConstArrayRef<float>(*floatFeatures, floatFeaturesSize), TArrayRef<double>(result, resultSize));
FULL_MODEL_PTR(modelHandle)->CalcFlatSingle(TConstArrayRef<float>(*floatFeatures, floatFeaturesSize), treeStart, treeEnd, TArrayRef<double>(result, resultSize));
} else {
TVector<TConstArrayRef<float>> featuresVec(docCount);
for (size_t i = 0; i < docCount; ++i) {
featuresVec[i] = TConstArrayRef<float>(floatFeatures[i], floatFeaturesSize);
}
FULL_MODEL_PTR(modelHandle)->CalcFlat(featuresVec, TArrayRef<double>(result, resultSize));
FULL_MODEL_PTR(modelHandle)->CalcFlat(featuresVec, treeStart, treeEnd, TArrayRef<double>(result, resultSize));
}
} catch (...) {
Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
Expand All @@ -384,23 +384,37 @@ CATBOOST_API bool CalcModelPredictionFlat(ModelCalcerHandle* modelHandle, size_t
return true;
}

CATBOOST_API bool CalcModelPredictionFlatTransposed(ModelCalcerHandle* modelHandle, size_t docCount, const float** floatFeatures, size_t floatFeaturesSize, double* result, size_t resultSize) {
// Flat-feature prediction over the full tree range: delegates to
// CalcModelPredictionFlatStaged with treeStart = 0 and treeEnd = GetTreeCount(modelHandle),
// and returns whatever the staged call returns.
CATBOOST_API bool CalcModelPredictionFlat(
    ModelCalcerHandle* modelHandle,
    size_t docCount,
    const float** floatFeatures, size_t floatFeaturesSize,
    double* result, size_t resultSize) {
    const size_t firstTree = 0;
    // presumably the total number of trees in the loaded model — confirm against GetTreeCount
    const size_t treeLimit = GetTreeCount(modelHandle);
    return CalcModelPredictionFlatStaged(
        modelHandle,
        docCount,
        firstTree, treeLimit,
        floatFeatures, floatFeaturesSize,
        result, resultSize
    );
}

CATBOOST_API bool CalcModelPredictionFlatTransposedStaged(
ModelCalcerHandle* modelHandle,
size_t docCount,
size_t treeStart, size_t treeEnd,
const float** floatFeatures, size_t floatFeaturesSize,
double* result, size_t resultSize) {
try {
TVector<TConstArrayRef<float>> featuresVec(floatFeaturesSize);
for (size_t i = 0; i < floatFeaturesSize; ++i) {
featuresVec[i] = TConstArrayRef<float>(floatFeatures[i], docCount);
}
FULL_MODEL_PTR(modelHandle)->CalcFlatTransposed(featuresVec, TArrayRef<double>(result, resultSize));
FULL_MODEL_PTR(modelHandle)->CalcFlatTransposed(featuresVec, treeStart, treeEnd, TArrayRef<double>(result, resultSize));
} catch (...) {
Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
return false;
}
return true;
}

CATBOOST_API bool CalcModelPrediction(
// Transposed flat-feature prediction over the full tree range: delegates to
// CalcModelPredictionFlatTransposedStaged with treeStart = 0 and
// treeEnd = GetTreeCount(modelHandle), and returns the staged call's result.
CATBOOST_API bool CalcModelPredictionFlatTransposed(
    ModelCalcerHandle* modelHandle,
    size_t docCount,
    const float** floatFeatures, size_t floatFeaturesSize,
    double* result, size_t resultSize) {
    const size_t firstTree = 0;
    // presumably the total number of trees in the loaded model — confirm against GetTreeCount
    const size_t treeLimit = GetTreeCount(modelHandle);
    return CalcModelPredictionFlatTransposedStaged(
        modelHandle,
        docCount,
        firstTree, treeLimit,
        floatFeatures, floatFeaturesSize,
        result, resultSize
    );
}

CATBOOST_API bool CalcModelPredictionStaged(
ModelCalcerHandle* modelHandle,
size_t docCount,
size_t treeStart, size_t treeEnd,
const float** floatFeatures, size_t floatFeaturesSize,
const char*** catFeatures, size_t catFeaturesSize,
double* result, size_t resultSize) {
Expand All @@ -415,17 +429,34 @@ CATBOOST_API bool CalcModelPrediction(
catFeaturesVec[i][catFeatureIdx] = catFeatures[i][catFeatureIdx];
}
}
FULL_MODEL_PTR(modelHandle)->Calc(floatFeaturesVec, catFeaturesVec, TArrayRef<double>(result, resultSize));
FULL_MODEL_PTR(modelHandle)->Calc(floatFeaturesVec, catFeaturesVec, treeStart, treeEnd, TArrayRef<double>(result, resultSize));
} catch (...) {
Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
return false;
}
return true;
}

CATBOOST_API bool CalcModelPredictionText(
// Prediction from float and string categorical features over the full tree range:
// delegates to CalcModelPredictionStaged with treeStart = 0 and
// treeEnd = GetTreeCount(modelHandle), and returns the staged call's result.
CATBOOST_API bool CalcModelPrediction(
    ModelCalcerHandle* modelHandle,
    size_t docCount,
    const float** floatFeatures, size_t floatFeaturesSize,
    const char*** catFeatures, size_t catFeaturesSize,
    double* result, size_t resultSize) {
    const size_t firstTree = 0;
    // presumably the total number of trees in the loaded model — confirm against GetTreeCount
    const size_t treeLimit = GetTreeCount(modelHandle);
    const bool succeeded = CalcModelPredictionStaged(
        modelHandle, docCount,
        firstTree, treeLimit,
        floatFeatures, floatFeaturesSize,
        catFeatures, catFeaturesSize,
        result, resultSize);
    return succeeded;
}

CATBOOST_API bool CalcModelPredictionTextStaged(
ModelCalcerHandle* modelHandle,
size_t docCount,
size_t treeStart, size_t treeEnd,
const float** floatFeatures, size_t floatFeaturesSize,
const char*** catFeatures, size_t catFeaturesSize,
const char*** textFeatures, size_t textFeaturesSize,
Expand All @@ -445,20 +476,46 @@ CATBOOST_API bool CalcModelPredictionText(
textFeaturesVec[i][textFeatureIdx] = textFeatures[i][textFeatureIdx];
}
}
FULL_MODEL_PTR(modelHandle)->Calc(floatFeaturesVec, catFeaturesVec, textFeaturesVec, TArrayRef<double>(result, resultSize));
FULL_MODEL_PTR(modelHandle)->Calc(
floatFeaturesVec,
catFeaturesVec,
textFeaturesVec,
treeStart,
treeEnd,
TArrayRef<double>(result, resultSize)
);
} catch (...) {
Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
return false;
}
return true;
}

CATBOOST_API bool CalcModelPredictionTextAndEmbeddings(
// Prediction from float, categorical and text features over the full tree range:
// delegates to CalcModelPredictionTextStaged with treeStart = 0 and
// treeEnd = GetTreeCount(modelHandle), and returns the staged call's result.
CATBOOST_API bool CalcModelPredictionText(
    ModelCalcerHandle* modelHandle,
    size_t docCount,
    const float** floatFeatures, size_t floatFeaturesSize,
    const char*** catFeatures, size_t catFeaturesSize,
    const char*** textFeatures, size_t textFeaturesSize,
    double* result, size_t resultSize) {
    const size_t firstTree = 0;
    // presumably the total number of trees in the loaded model — confirm against GetTreeCount
    const size_t treeLimit = GetTreeCount(modelHandle);
    const bool succeeded = CalcModelPredictionTextStaged(
        modelHandle, docCount,
        firstTree, treeLimit,
        floatFeatures, floatFeaturesSize,
        catFeatures, catFeaturesSize,
        textFeatures, textFeaturesSize,
        result, resultSize);
    return succeeded;
}

CATBOOST_API bool CalcModelPredictionTextAndEmbeddingsStaged(
ModelCalcerHandle* modelHandle,
size_t docCount,
size_t treeStart, size_t treeEnd,
const float** floatFeatures, size_t floatFeaturesSize,
const char*** catFeatures, size_t catFeaturesSize,
const char*** textFeatures, size_t textFeaturesSize,
const float*** embeddingFeatures, size_t* embeddingDimensions, size_t embeddingFeaturesSize,
double* result, size_t resultSize) {
try {
Expand Down Expand Up @@ -490,6 +547,8 @@ CATBOOST_API bool CalcModelPredictionTextAndEmbeddings(
catFeaturesVec,
textFeaturesVec,
embeddingFeaturesVec,
treeStart,
treeEnd,
TArrayRef<double>(result, resultSize)
);
} catch (...) {
Expand All @@ -499,8 +558,29 @@ CATBOOST_API bool CalcModelPredictionTextAndEmbeddings(
return true;
}

CATBOOST_API bool CalcModelPredictionSingle(
// Prediction from float, categorical, text and embedding features over the full tree
// range: delegates to CalcModelPredictionTextAndEmbeddingsStaged with treeStart = 0 and
// treeEnd = GetTreeCount(modelHandle), and returns the staged call's result.
CATBOOST_API bool CalcModelPredictionTextAndEmbeddings(
    ModelCalcerHandle* modelHandle,
    size_t docCount,
    const float** floatFeatures, size_t floatFeaturesSize,
    const char*** catFeatures, size_t catFeaturesSize,
    const char*** textFeatures, size_t textFeaturesSize,
    const float*** embeddingFeatures, size_t* embeddingDimensions, size_t embeddingFeaturesSize,
    double* result, size_t resultSize) {
    const size_t firstTree = 0;
    // presumably the total number of trees in the loaded model — confirm against GetTreeCount
    const size_t treeLimit = GetTreeCount(modelHandle);
    const bool succeeded = CalcModelPredictionTextAndEmbeddingsStaged(
        modelHandle, docCount,
        firstTree, treeLimit,
        floatFeatures, floatFeaturesSize,
        catFeatures, catFeaturesSize,
        textFeatures, textFeaturesSize,
        embeddingFeatures, embeddingDimensions, embeddingFeaturesSize,
        result, resultSize);
    return succeeded;
}

CATBOOST_API bool CalcModelPredictionSingleStaged(
ModelCalcerHandle* modelHandle,
size_t treeStart, size_t treeEnd,
const float* floatFeatures, size_t floatFeaturesSize,
const char** catFeatures, size_t catFeaturesSize,
double* result, size_t resultSize) {
Expand All @@ -513,14 +593,28 @@ CATBOOST_API bool CalcModelPredictionSingle(
for (size_t catFeatureIdx = 0; catFeatureIdx < catFeaturesSize; ++catFeatureIdx) {
catFeaturesVec[0][catFeatureIdx] = catFeatures[catFeatureIdx];
}
FULL_MODEL_PTR(modelHandle)->Calc(floatFeaturesVec, catFeaturesVec, TArrayRef<double>(result, resultSize));
FULL_MODEL_PTR(modelHandle)->Calc(floatFeaturesVec, catFeaturesVec, treeStart, treeEnd, TArrayRef<double>(result, resultSize));
} catch (...) {
Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
return false;
}
return true;
}

// Single-document prediction from float and string categorical features over the full
// tree range: delegates to CalcModelPredictionSingleStaged with treeStart = 0 and
// treeEnd = GetTreeCount(modelHandle), and returns the staged call's result.
CATBOOST_API bool CalcModelPredictionSingle(
    ModelCalcerHandle* modelHandle,
    const float* floatFeatures, size_t floatFeaturesSize,
    const char** catFeatures, size_t catFeaturesSize,
    double* result, size_t resultSize) {
    const size_t firstTree = 0;
    // presumably the total number of trees in the loaded model — confirm against GetTreeCount
    const size_t treeLimit = GetTreeCount(modelHandle);
    const bool succeeded = CalcModelPredictionSingleStaged(
        modelHandle,
        firstTree, treeLimit,
        floatFeatures, floatFeaturesSize,
        catFeatures, catFeaturesSize,
        result, resultSize);
    return succeeded;
}

CATBOOST_API bool CalcModelPredictionWithHashedCatFeatures(ModelCalcerHandle* modelHandle, size_t docCount,
const float** floatFeatures, size_t floatFeaturesSize,
const int** catFeatures, size_t catFeaturesSize,
Expand Down

0 comments on commit 0adc70a

Please sign in to comment.