Skip to content

Commit

Permalink
feat: add support for file objects containing bytes (like BytesIO)
Browse files Browse the repository at this point in the history
  • Loading branch information
iamDecode committed Nov 22, 2023
1 parent a4bed07 commit 237c1f2
Show file tree
Hide file tree
Showing 2 changed files with 216 additions and 5 deletions.
20 changes: 17 additions & 3 deletions sklearn_pmml_model/auto_detect/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ def auto_detect_estimator(pmml, **kwargs):
"""
if isinstance(pmml, io.IOBase) and not pmml.seekable():
pmml = io.StringIO(pmml.read())
content = pmml.read()
if isinstance(content, bytes):
pmml = io.BytesIO(content)
if isinstance(content, str):
pmml = io.StringIO(content)

base = PMMLBaseEstimator(pmml=pmml)
target_field_name = base.target_field.attrib['name']
Expand All @@ -49,6 +53,8 @@ def auto_detect_classifier(pmml, **kwargs):

def parse(file: Iterator, seek=False):
for line in file:
if isinstance(line, bytes):
line = line.decode('utf8')
if '<Segmentation' in line:
clfs = [x for x in (detect_classifier(line) for line in file) if x is not None]

Expand Down Expand Up @@ -93,6 +99,8 @@ def auto_detect_regressor(pmml, **kwargs):

def parse(file: Iterator, seek=False):
for line in file:
if isinstance(line, bytes):
line = line.decode('utf8')
if '<Segmentation' in line:
regs = [x for x in (detect_regressor(line) for line in file) if x is not None]

Expand Down Expand Up @@ -130,13 +138,16 @@ def detect_classifier(line):
Parameters
----------
line : str
line : str, bytes
Line of a PMML file as a string.
pmml : str, object
Filename or file object containing PMML data.
"""
if isinstance(line, bytes):
line = line.decode('utf8')

if '<TreeModel' in line:
return PMMLTreeClassifier

Expand Down Expand Up @@ -167,13 +178,16 @@ def detect_regressor(line):
Parameters
----------
line : str
line : str, bytes
Line of a PMML file as a string.
pmml : str, object
Filename or file object containing PMML data.
"""
if isinstance(line, bytes):
line = line.decode('utf8')

if '<TreeModel' in line:
return PMMLTreeRegressor

Expand Down
201 changes: 199 additions & 2 deletions tests/auto_detect/test_auto_detect.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from unittest import TestCase
from io import StringIO, UnsupportedOperation
from io import open, StringIO, BytesIO, UnsupportedOperation
import sklearn_pmml_model
from sklearn_pmml_model.auto_detect import auto_detect_estimator
from sklearn_pmml_model.tree import PMMLTreeClassifier, PMMLTreeRegressor
Expand Down Expand Up @@ -151,7 +151,47 @@ def test_auto_detect_file_object_regressor(self):

assert isinstance(clf, PMMLLinearRegression)

def test_auto_detect_non_seekable_file_object(self):
def test_auto_detect_non_seekable_file_object_classifier(self):
class NoSeekStringIO(StringIO):
def seekable(self) -> bool:
return False

def seek(self, __cookie: int, __whence: int = ...) -> int:
raise UnsupportedOperation('seek')

string = NoSeekStringIO("""
<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
<DataDictionary>
<DataField name="feature" optype="continuous" dataType="float"/>
<DataField name="Class" optype="categorical" dataType="string">
<Value value="A"/>
<Value value="B"/>
</DataField>
</DataDictionary>
<MiningSchema>
<MiningField name="Class" usageType="target"/>
</MiningSchema>
<RegressionModel>
<MiningSchema>
<MiningField name="feature" />
<MiningField name="Class" usageType="target" />
</MiningSchema>
<Output>
<OutputField name="probability(A)" optype="continuous" dataType="double" feature="probability" value="A"/>
<OutputField name="probability(B)" optype="continuous" dataType="double" feature="probability" value="B"/>
</Output>
<RegressionTable intercept="-1">
<NumericPredictor name="feature" exponent="1" coefficient="0.1"/>
</RegressionTable>
</RegressionModel>
</PMML>
""")
clf = auto_detect_estimator(string)

assert isinstance(clf, PMMLLogisticRegression)


def test_auto_detect_non_seekable_file_object_regressor(self):
class NoSeekStringIO(StringIO):
def seekable(self) -> bool:
return False
Expand Down Expand Up @@ -186,6 +226,163 @@ def seek(self, __cookie: int, __whence: int = ...) -> int:

assert isinstance(clf, PMMLLinearRegression)

def test_auto_detect_bytes_file_object_classifier(self):
bytes = BytesIO(b"""
<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
<DataDictionary>
<DataField name="feature" optype="continuous" dataType="float"/>
<DataField name="Class" optype="categorical" dataType="string">
<Value value="A"/>
<Value value="B"/>
</DataField>
</DataDictionary>
<MiningSchema>
<MiningField name="Class" usageType="target"/>
</MiningSchema>
<RegressionModel>
<MiningSchema>
<MiningField name="feature" />
<MiningField name="Class" usageType="target" />
</MiningSchema>
<Output>
<OutputField name="probability(A)" optype="continuous" dataType="double" feature="probability" value="A"/>
<OutputField name="probability(B)" optype="continuous" dataType="double" feature="probability" value="B"/>
</Output>
<RegressionTable intercept="-1">
<NumericPredictor name="feature" exponent="1" coefficient="0.1"/>
</RegressionTable>
</RegressionModel>
</PMML>
""")
clf = auto_detect_estimator(bytes)

assert isinstance(clf, PMMLLogisticRegression)

def test_auto_detect_non_seekable_bytes_file_object_classifier(self):
class NoSeekBytesIO(BytesIO):
def seekable(self) -> bool:
return False

def seek(self, __cookie: int, __whence: int = ...) -> int:
raise UnsupportedOperation('seek')

bytes = NoSeekBytesIO(b"""
<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
<DataDictionary>
<DataField name="feature" optype="continuous" dataType="float"/>
<DataField name="Class" optype="categorical" dataType="string">
<Value value="A"/>
<Value value="B"/>
</DataField>
</DataDictionary>
<MiningSchema>
<MiningField name="Class" usageType="target"/>
</MiningSchema>
<RegressionModel>
<MiningSchema>
<MiningField name="feature" />
<MiningField name="Class" usageType="target" />
</MiningSchema>
<Output>
<OutputField name="probability(A)" optype="continuous" dataType="double" feature="probability" value="A"/>
<OutputField name="probability(B)" optype="continuous" dataType="double" feature="probability" value="B"/>
</Output>
<RegressionTable intercept="-1">
<NumericPredictor name="feature" exponent="1" coefficient="0.1"/>
</RegressionTable>
</RegressionModel>
</PMML>
""")
clf = auto_detect_estimator(bytes)

assert isinstance(clf, PMMLLogisticRegression)

def test_auto_detect_bytes_file_object_regressor(self):
bytes = BytesIO(b"""
<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
<DataDictionary>
<DataField name="feature" optype="continuous" dataType="float"/>
<DataField name="Class" optype="continuous" dataType="float"/>
</DataDictionary>
<MiningSchema>
<MiningField name="Class" usageType="target"/>
</MiningSchema>
<RegressionModel>
<MiningSchema>
<MiningField name="feature" usageType="active" invalidValueTreatment="returnInvalid"/>
<MiningField name="Class" usageType="predicted" invalidValueTreatment="returnInvalid"/>
</MiningSchema>
<Output>
<OutputField name="Predicted_Class" optype="continuous" dataType="float" feature="predictedValue"/>
</Output>
<RegressionTable intercept="-1">
<NumericPredictor name="feature" exponent="1" coefficient="0.1"/>
</RegressionTable>
</RegressionModel>
</PMML>
""")
clf = auto_detect_estimator(bytes)

assert isinstance(clf, PMMLLinearRegression)

def test_auto_detect_non_seekable_bytes_file_object_regressor(self):
class NoSeekBytesIO(BytesIO):
def seekable(self) -> bool:
return False

def seek(self, __cookie: int, __whence: int = ...) -> int:
raise UnsupportedOperation('seek')

bytes = NoSeekBytesIO(b"""
<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
<DataDictionary>
<DataField name="feature" optype="continuous" dataType="float"/>
<DataField name="Class" optype="continuous" dataType="float"/>
</DataDictionary>
<MiningSchema>
<MiningField name="Class" usageType="target"/>
</MiningSchema>
<RegressionModel>
<MiningSchema>
<MiningField name="feature" usageType="active" invalidValueTreatment="returnInvalid"/>
<MiningField name="Class" usageType="predicted" invalidValueTreatment="returnInvalid"/>
</MiningSchema>
<Output>
<OutputField name="Predicted_Class" optype="continuous" dataType="float" feature="predictedValue"/>
</Output>
<RegressionTable intercept="-1">
<NumericPredictor name="feature" exponent="1" coefficient="0.1"/>
</RegressionTable>
</RegressionModel>
</PMML>
""")
clf = auto_detect_estimator(bytes)

assert isinstance(clf, PMMLLinearRegression)

def test_auto_detect_bytes_random_forest_classifier(self):
pmml = path.join(BASE_DIR, '../models/rf-cat-pima.pmml')
with open(pmml) as file:
content = str.encode(file.read())
assert isinstance(auto_detect_estimator(pmml=BytesIO(content)), PMMLForestClassifier)

def test_auto_detect_bytes_random_forest_regressor(self):
pmml = path.join(BASE_DIR, '../models/rf-cat-pima-regression.pmml')
with open(pmml) as file:
content = str.encode(file.read())
assert isinstance(auto_detect_estimator(pmml=BytesIO(content)), PMMLForestRegressor)
def test_auto_detect_bytes_gradient_boosting_classifier(self):
pmml = path.join(BASE_DIR, '../models/gb-xgboost-iris.pmml')
with open(pmml) as file:
content = str.encode(file.read())
assert isinstance(auto_detect_estimator(pmml=BytesIO(content)), PMMLGradientBoostingClassifier)

def test_auto_detect_bytes_gradient_boosting_regressor(self):
pmml = path.join(BASE_DIR, '../models/gb-gbm-cat-pima-regression.pmml')
with open(pmml) as file:
content = str.encode(file.read())
assert isinstance(auto_detect_estimator(pmml=BytesIO(content)), PMMLGradientBoostingRegressor)

def test_auto_detect_tree_classifier(self):
pmml = path.join(BASE_DIR, '../models/tree-iris.pmml')
assert isinstance(auto_detect_estimator(pmml=pmml), PMMLTreeClassifier)
Expand Down

0 comments on commit 237c1f2

Please sign in to comment.