From a4bed072a426ec913d8e17777afcd0a4f12b555d Mon Sep 17 00:00:00 2001 From: Dennis Collaris Date: Sun, 5 Nov 2023 13:24:16 +0100 Subject: [PATCH] fix: ensure to reseek before loading segmentation models if needed --- sklearn_pmml_model/auto_detect/base.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sklearn_pmml_model/auto_detect/base.py b/sklearn_pmml_model/auto_detect/base.py index e48bbda..65fdd8b 100644 --- a/sklearn_pmml_model/auto_detect/base.py +++ b/sklearn_pmml_model/auto_detect/base.py @@ -54,8 +54,12 @@ def parse(file: Iterator, seek=False): if all(clf is PMMLTreeClassifier or clf is PMMLLogisticRegression for clf in clfs): if 'multipleModelMethod="majorityVote"' in line or 'multipleModelMethod="average"' in line: + if seek: + pmml.seek(0) return PMMLForestClassifier(pmml=pmml, **kwargs) if 'multipleModelMethod="modelChain"' in line: + if seek: + pmml.seek(0) return PMMLGradientBoostingClassifier(pmml=pmml, **kwargs) raise Exception('Unsupported PMML classifier: invalid segmentation.') @@ -94,8 +98,12 @@ def parse(file: Iterator, seek=False): if all(reg is PMMLTreeRegressor or reg is PMMLLinearRegression for reg in regs): if 'multipleModelMethod="majorityVote"' in line or 'multipleModelMethod="average"' in line: + if seek: + pmml.seek(0) return PMMLForestRegressor(pmml=pmml, **kwargs) if 'multipleModelMethod="sum"' in line: + if seek: + pmml.seek(0) return PMMLGradientBoostingRegressor(pmml=pmml, **kwargs) raise Exception('Unsupported PMML regressor: invalid segmentation.')