From d0c0034cfbdfab1b4312deaec0ab3c7324ce0ff3 Mon Sep 17 00:00:00 2001
From: Dennis Collaris
Date: Sun, 5 Nov 2023 11:54:16 +0100
Subject: [PATCH] feat: support non-seekable file-like objects

Buffer non-seekable file objects in memory before handing them to
PMMLBaseEstimator, so that auto_detect_estimator always works with a
seekable object. The in-memory buffer matches the type returned by
read(): io.BytesIO for bytes, io.StringIO for text.

---
NOTE(review): whitespace in this copy of the patch had been collapsed;
line breaks were restored and two-space indentation was assumed, so
verify against the repository before applying. The PMML fixture in the
new test is blank in this copy (markup appears to have been stripped);
restore it from the original commit. The blob ids on the base.py
"index" line predate the bytes-aware revision of that hunk.

 sklearn_pmml_model/auto_detect/base.py |  8 ++++++++
 tests/auto_detect/test_auto_detect.py  | 37 +++++++++++++++++++++++++++++++++++++-
 2 files changed, 44 insertions(+), 1 deletion(-)

diff --git a/sklearn_pmml_model/auto_detect/base.py b/sklearn_pmml_model/auto_detect/base.py
index 2ab3ca8..e48bbda 100644
--- a/sklearn_pmml_model/auto_detect/base.py
+++ b/sklearn_pmml_model/auto_detect/base.py
@@ -23,6 +23,14 @@ def auto_detect_estimator(pmml, **kwargs):
       Filename or file object containing PMML data.
 
   """
+  if isinstance(pmml, io.IOBase) and not pmml.seekable():
+    # Replace the non-seekable stream with an in-memory, seekable copy.
+    # read() returns bytes for binary streams and str for text streams;
+    # io.StringIO rejects bytes, so pick the buffer type from the data.
+    content = pmml.read()
+    buffer_type = io.BytesIO if isinstance(content, bytes) else io.StringIO
+    pmml = buffer_type(content)
+
   base = PMMLBaseEstimator(pmml=pmml)
   target_field_name = base.target_field.attrib['name']
   target_field_type = base.field_mapping[target_field_name][1]
diff --git a/tests/auto_detect/test_auto_detect.py b/tests/auto_detect/test_auto_detect.py
index da5d2d5..8f572db 100644
--- a/tests/auto_detect/test_auto_detect.py
+++ b/tests/auto_detect/test_auto_detect.py
@@ -1,5 +1,5 @@
 from unittest import TestCase
-from io import StringIO
+from io import StringIO, UnsupportedOperation
 import sklearn_pmml_model
 from sklearn_pmml_model.auto_detect import auto_detect_estimator
 from sklearn_pmml_model.tree import PMMLTreeClassifier, PMMLTreeRegressor
@@ -151,6 +151,41 @@ def test_auto_detect_file_object_regressor(self):
 
     assert isinstance(clf, PMMLLinearRegression)
 
+  def test_auto_detect_non_seekable_file_object(self):
+    class NoSeekStringIO(StringIO):
+      def seekable(self) -> bool:
+        return False
+
+      def seek(self, __cookie: int, __whence: int = ...) -> int:
+        raise UnsupportedOperation('seek')
+
+    string = NoSeekStringIO("""
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+    """)
+    clf = auto_detect_estimator(string)
+
+    assert isinstance(clf, PMMLLinearRegression)
+
   def test_auto_detect_tree_classifier(self):
     pmml = path.join(BASE_DIR, '../models/tree-iris.pmml')
     assert isinstance(auto_detect_estimator(pmml=pmml), PMMLTreeClassifier)