Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for file objects containing bytes (like BytesIO) #53

Merged
merged 1 commit into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 17 additions & 3 deletions sklearn_pmml_model/auto_detect/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ def auto_detect_estimator(pmml, **kwargs):

"""
if isinstance(pmml, io.IOBase) and not pmml.seekable():
pmml = io.StringIO(pmml.read())
content = pmml.read()
if isinstance(content, bytes):
pmml = io.BytesIO(content)
if isinstance(content, str):
pmml = io.StringIO(content)

base = PMMLBaseEstimator(pmml=pmml)
target_field_name = base.target_field.attrib['name']
Expand All @@ -49,6 +53,8 @@ def auto_detect_classifier(pmml, **kwargs):

def parse(file: Iterator, seek=False):
for line in file:
if isinstance(line, bytes):
line = line.decode('utf8')
if '<Segmentation' in line:
clfs = [x for x in (detect_classifier(line) for line in file) if x is not None]

Expand Down Expand Up @@ -93,6 +99,8 @@ def auto_detect_regressor(pmml, **kwargs):

def parse(file: Iterator, seek=False):
for line in file:
if isinstance(line, bytes):
line = line.decode('utf8')
if '<Segmentation' in line:
regs = [x for x in (detect_regressor(line) for line in file) if x is not None]

Expand Down Expand Up @@ -130,13 +138,16 @@ def detect_classifier(line):

Parameters
----------
line : str
line : str, bytes
Line of a PMML file as a string.

pmml : str, object
Filename or file object containing PMML data.

"""
if isinstance(line, bytes):
line = line.decode('utf8')

if '<TreeModel' in line:
return PMMLTreeClassifier

Expand Down Expand Up @@ -167,13 +178,16 @@ def detect_regressor(line):

Parameters
----------
line : str
line : str, bytes
Line of a PMML file as a string.

pmml : str, object
Filename or file object containing PMML data.

"""
if isinstance(line, bytes):
line = line.decode('utf8')

if '<TreeModel' in line:
return PMMLTreeRegressor

Expand Down
201 changes: 199 additions & 2 deletions tests/auto_detect/test_auto_detect.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from unittest import TestCase
from io import StringIO, UnsupportedOperation
from io import open, StringIO, BytesIO, UnsupportedOperation
import sklearn_pmml_model
from sklearn_pmml_model.auto_detect import auto_detect_estimator
from sklearn_pmml_model.tree import PMMLTreeClassifier, PMMLTreeRegressor
Expand Down Expand Up @@ -151,7 +151,47 @@ def test_auto_detect_file_object_regressor(self):

assert isinstance(clf, PMMLLinearRegression)

def test_auto_detect_non_seekable_file_object(self):
def test_auto_detect_non_seekable_file_object_classifier(self):
class NoSeekStringIO(StringIO):
def seekable(self) -> bool:
return False

def seek(self, __cookie: int, __whence: int = ...) -> int:
raise UnsupportedOperation('seek')

string = NoSeekStringIO("""
<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
<DataDictionary>
<DataField name="feature" optype="continuous" dataType="float"/>
<DataField name="Class" optype="categorical" dataType="string">
<Value value="A"/>
<Value value="B"/>
</DataField>
</DataDictionary>
<MiningSchema>
<MiningField name="Class" usageType="target"/>
</MiningSchema>
<RegressionModel>
<MiningSchema>
<MiningField name="feature" />
<MiningField name="Class" usageType="target" />
</MiningSchema>
<Output>
<OutputField name="probability(A)" optype="continuous" dataType="double" feature="probability" value="A"/>
<OutputField name="probability(B)" optype="continuous" dataType="double" feature="probability" value="B"/>
</Output>
<RegressionTable intercept="-1">
<NumericPredictor name="feature" exponent="1" coefficient="0.1"/>
</RegressionTable>
</RegressionModel>
</PMML>
""")
clf = auto_detect_estimator(string)

assert isinstance(clf, PMMLLogisticRegression)


def test_auto_detect_non_seekable_file_object_regressor(self):
class NoSeekStringIO(StringIO):
def seekable(self) -> bool:
return False
Expand Down Expand Up @@ -186,6 +226,163 @@ def seek(self, __cookie: int, __whence: int = ...) -> int:

assert isinstance(clf, PMMLLinearRegression)

def test_auto_detect_bytes_file_object_classifier(self):
bytes = BytesIO(b"""
<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
<DataDictionary>
<DataField name="feature" optype="continuous" dataType="float"/>
<DataField name="Class" optype="categorical" dataType="string">
<Value value="A"/>
<Value value="B"/>
</DataField>
</DataDictionary>
<MiningSchema>
<MiningField name="Class" usageType="target"/>
</MiningSchema>
<RegressionModel>
<MiningSchema>
<MiningField name="feature" />
<MiningField name="Class" usageType="target" />
</MiningSchema>
<Output>
<OutputField name="probability(A)" optype="continuous" dataType="double" feature="probability" value="A"/>
<OutputField name="probability(B)" optype="continuous" dataType="double" feature="probability" value="B"/>
</Output>
<RegressionTable intercept="-1">
<NumericPredictor name="feature" exponent="1" coefficient="0.1"/>
</RegressionTable>
</RegressionModel>
</PMML>
""")
clf = auto_detect_estimator(bytes)

assert isinstance(clf, PMMLLogisticRegression)

def test_auto_detect_non_seekable_bytes_file_object_classifier(self):
class NoSeekBytesIO(BytesIO):
def seekable(self) -> bool:
return False

def seek(self, __cookie: int, __whence: int = ...) -> int:
raise UnsupportedOperation('seek')

bytes = NoSeekBytesIO(b"""
<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
<DataDictionary>
<DataField name="feature" optype="continuous" dataType="float"/>
<DataField name="Class" optype="categorical" dataType="string">
<Value value="A"/>
<Value value="B"/>
</DataField>
</DataDictionary>
<MiningSchema>
<MiningField name="Class" usageType="target"/>
</MiningSchema>
<RegressionModel>
<MiningSchema>
<MiningField name="feature" />
<MiningField name="Class" usageType="target" />
</MiningSchema>
<Output>
<OutputField name="probability(A)" optype="continuous" dataType="double" feature="probability" value="A"/>
<OutputField name="probability(B)" optype="continuous" dataType="double" feature="probability" value="B"/>
</Output>
<RegressionTable intercept="-1">
<NumericPredictor name="feature" exponent="1" coefficient="0.1"/>
</RegressionTable>
</RegressionModel>
</PMML>
""")
clf = auto_detect_estimator(bytes)

assert isinstance(clf, PMMLLogisticRegression)

def test_auto_detect_bytes_file_object_regressor(self):
bytes = BytesIO(b"""
<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
<DataDictionary>
<DataField name="feature" optype="continuous" dataType="float"/>
<DataField name="Class" optype="continuous" dataType="float"/>
</DataDictionary>
<MiningSchema>
<MiningField name="Class" usageType="target"/>
</MiningSchema>
<RegressionModel>
<MiningSchema>
<MiningField name="feature" usageType="active" invalidValueTreatment="returnInvalid"/>
<MiningField name="Class" usageType="predicted" invalidValueTreatment="returnInvalid"/>
</MiningSchema>
<Output>
<OutputField name="Predicted_Class" optype="continuous" dataType="float" feature="predictedValue"/>
</Output>
<RegressionTable intercept="-1">
<NumericPredictor name="feature" exponent="1" coefficient="0.1"/>
</RegressionTable>
</RegressionModel>
</PMML>
""")
clf = auto_detect_estimator(bytes)

assert isinstance(clf, PMMLLinearRegression)

def test_auto_detect_non_seekable_bytes_file_object_regressor(self):
class NoSeekBytesIO(BytesIO):
def seekable(self) -> bool:
return False

def seek(self, __cookie: int, __whence: int = ...) -> int:
raise UnsupportedOperation('seek')

bytes = NoSeekBytesIO(b"""
<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
<DataDictionary>
<DataField name="feature" optype="continuous" dataType="float"/>
<DataField name="Class" optype="continuous" dataType="float"/>
</DataDictionary>
<MiningSchema>
<MiningField name="Class" usageType="target"/>
</MiningSchema>
<RegressionModel>
<MiningSchema>
<MiningField name="feature" usageType="active" invalidValueTreatment="returnInvalid"/>
<MiningField name="Class" usageType="predicted" invalidValueTreatment="returnInvalid"/>
</MiningSchema>
<Output>
<OutputField name="Predicted_Class" optype="continuous" dataType="float" feature="predictedValue"/>
</Output>
<RegressionTable intercept="-1">
<NumericPredictor name="feature" exponent="1" coefficient="0.1"/>
</RegressionTable>
</RegressionModel>
</PMML>
""")
clf = auto_detect_estimator(bytes)

assert isinstance(clf, PMMLLinearRegression)

def test_auto_detect_bytes_random_forest_classifier(self):
pmml = path.join(BASE_DIR, '../models/rf-cat-pima.pmml')
with open(pmml) as file:
content = str.encode(file.read())
assert isinstance(auto_detect_estimator(pmml=BytesIO(content)), PMMLForestClassifier)

def test_auto_detect_bytes_random_forest_regressor(self):
pmml = path.join(BASE_DIR, '../models/rf-cat-pima-regression.pmml')
with open(pmml) as file:
content = str.encode(file.read())
assert isinstance(auto_detect_estimator(pmml=BytesIO(content)), PMMLForestRegressor)
def test_auto_detect_bytes_gradient_boosting_classifier(self):
pmml = path.join(BASE_DIR, '../models/gb-xgboost-iris.pmml')
with open(pmml) as file:
content = str.encode(file.read())
assert isinstance(auto_detect_estimator(pmml=BytesIO(content)), PMMLGradientBoostingClassifier)

def test_auto_detect_bytes_gradient_boosting_regressor(self):
pmml = path.join(BASE_DIR, '../models/gb-gbm-cat-pima-regression.pmml')
with open(pmml) as file:
content = str.encode(file.read())
assert isinstance(auto_detect_estimator(pmml=BytesIO(content)), PMMLGradientBoostingRegressor)

def test_auto_detect_tree_classifier(self):
pmml = path.join(BASE_DIR, '../models/tree-iris.pmml')
assert isinstance(auto_detect_estimator(pmml=pmml), PMMLTreeClassifier)
Expand Down