diff --git a/sklearn_pmml_model/auto_detect/base.py b/sklearn_pmml_model/auto_detect/base.py index e48bbda..65fdd8b 100644 --- a/sklearn_pmml_model/auto_detect/base.py +++ b/sklearn_pmml_model/auto_detect/base.py @@ -54,8 +54,12 @@ def parse(file: Iterator, seek=False): if all(clf is PMMLTreeClassifier or clf is PMMLLogisticRegression for clf in clfs): if 'multipleModelMethod="majorityVote"' in line or 'multipleModelMethod="average"' in line: + if seek: + pmml.seek(0) return PMMLForestClassifier(pmml=pmml, **kwargs) if 'multipleModelMethod="modelChain"' in line: + if seek: + pmml.seek(0) return PMMLGradientBoostingClassifier(pmml=pmml, **kwargs) raise Exception('Unsupported PMML classifier: invalid segmentation.') @@ -94,8 +98,12 @@ def parse(file: Iterator, seek=False): if all(reg is PMMLTreeRegressor or reg is PMMLLinearRegression for reg in regs): if 'multipleModelMethod="majorityVote"' in line or 'multipleModelMethod="average"' in line: + if seek: + pmml.seek(0) return PMMLForestRegressor(pmml=pmml, **kwargs) if 'multipleModelMethod="sum"' in line: + if seek: + pmml.seek(0) return PMMLGradientBoostingRegressor(pmml=pmml, **kwargs) raise Exception('Unsupported PMML regressor: invalid segmentation.')