Skip to content

Commit

Permalink
feat: support non-seekable file-like objects
Browse files Browse the repository at this point in the history
  • Loading branch information
iamDecode committed Nov 5, 2023
1 parent 55de16a commit d0c0034
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
3 changes: 3 additions & 0 deletions sklearn_pmml_model/auto_detect/base.py
Expand Up @@ -23,6 +23,9 @@ def auto_detect_estimator(pmml, **kwargs):
Filename or file object containing PMML data.
"""
if isinstance(pmml, io.IOBase) and not pmml.seekable():
pmml = io.StringIO(pmml.read())

base = PMMLBaseEstimator(pmml=pmml)
target_field_name = base.target_field.attrib['name']
target_field_type = base.field_mapping[target_field_name][1]
Expand Down
37 changes: 36 additions & 1 deletion tests/auto_detect/test_auto_detect.py
@@ -1,5 +1,5 @@
from unittest import TestCase
from io import StringIO
from io import StringIO, UnsupportedOperation
import sklearn_pmml_model
from sklearn_pmml_model.auto_detect import auto_detect_estimator
from sklearn_pmml_model.tree import PMMLTreeClassifier, PMMLTreeRegressor
Expand Down Expand Up @@ -151,6 +151,41 @@ def test_auto_detect_file_object_regressor(self):

assert isinstance(clf, PMMLLinearRegression)

def test_auto_detect_non_seekable_file_object(self):
    """Auto-detection works on a file-like object that cannot seek.

    The stream below reports ``seekable() == False`` and raises
    ``UnsupportedOperation`` from ``seek``, so the test fails with that
    exception if ``auto_detect_estimator`` calls ``seek`` on it.
    """
    class UnseekableStringIO(StringIO):
        # In-memory text stream with seeking disabled.
        def seekable(self) -> bool:
            return False

        def seek(self, __cookie: int, __whence: int = ...) -> int:
            raise UnsupportedOperation('seek')

    # Minimal regression model with a single continuous feature.
    document = """
<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
<DataDictionary>
<DataField name="feature" optype="continuous" dataType="float"/>
<DataField name="Class" optype="continuous" dataType="float"/>
</DataDictionary>
<MiningSchema>
<MiningField name="Class" usageType="target"/>
</MiningSchema>
<RegressionModel>
<MiningSchema>
<MiningField name="feature" usageType="active" invalidValueTreatment="returnInvalid"/>
<MiningField name="Class" usageType="predicted" invalidValueTreatment="returnInvalid"/>
</MiningSchema>
<Output>
<OutputField name="Predicted_Class" optype="continuous" dataType="float" feature="predictedValue"/>
</Output>
<RegressionTable intercept="-1">
<NumericPredictor name="feature" exponent="1" coefficient="0.1"/>
</RegressionTable>
</RegressionModel>
</PMML>
"""
    stream = UnseekableStringIO(document)
    estimator = auto_detect_estimator(stream)

    assert isinstance(estimator, PMMLLinearRegression)

def test_auto_detect_tree_classifier(self):
pmml = path.join(BASE_DIR, '../models/tree-iris.pmml')
assert isinstance(auto_detect_estimator(pmml=pmml), PMMLTreeClassifier)
Expand Down

0 comments on commit d0c0034

Please sign in to comment.