Skip to content

Commit

Permalink
fix: prevent closed file errors when loading from file-like objects
Browse files Browse the repository at this point in the history
  • Loading branch information
iamDecode committed Nov 5, 2023
1 parent e187c8a commit 55de16a
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 42 deletions.
91 changes: 49 additions & 42 deletions sklearn_pmml_model/auto_detect/base.py
@@ -1,3 +1,6 @@
import io
from collections.abc import Iterator

from sklearn_pmml_model.base import PMMLBaseEstimator
from sklearn_pmml_model.datatypes import Category
from sklearn_pmml_model.tree import PMMLTreeClassifier, PMMLTreeRegressor
Expand Down Expand Up @@ -40,32 +43,34 @@ def auto_detect_classifier(pmml, **kwargs):
Filename or file object containing PMML data.
"""
if isinstance(pmml, str):
file = open(pmml, 'r')
else:
pmml.seek(0)
file = pmml

for line in file:
if '<Segmentation' in line:
clfs = [x for x in (detect_classifier(line) for line in file) if x is not None]
file.close()
def parse(file: Iterator, seek=False):
for line in file:
if '<Segmentation' in line:
clfs = [x for x in (detect_classifier(line) for line in file) if x is not None]

if all(clf is PMMLTreeClassifier or clf is PMMLLogisticRegression for clf in clfs):
if 'multipleModelMethod="majorityVote"' in line or 'multipleModelMethod="average"' in line:
return PMMLForestClassifier(pmml=pmml, **kwargs)
if 'multipleModelMethod="modelChain"' in line:
return PMMLGradientBoostingClassifier(pmml=pmml, **kwargs)
if all(clf is PMMLTreeClassifier or clf is PMMLLogisticRegression for clf in clfs):
if 'multipleModelMethod="majorityVote"' in line or 'multipleModelMethod="average"' in line:
return PMMLForestClassifier(pmml=pmml, **kwargs)
if 'multipleModelMethod="modelChain"' in line:
return PMMLGradientBoostingClassifier(pmml=pmml, **kwargs)

raise Exception('Unsupported PMML classifier: invalid segmentation.')
raise Exception('Unsupported PMML classifier: invalid segmentation.')

clf = detect_classifier(line)
if clf:
file.close()
return clf(pmml, **kwargs)
clf = detect_classifier(line)
if clf:
if seek:
pmml.seek(0)
return clf(pmml, **kwargs)

file.close()
raise Exception('Unsupported PMML classifier.')
raise Exception('Unsupported PMML classifier.')

if isinstance(pmml, str):
with io.open(pmml, 'r') as f:
return parse(f)
else:
pmml.seek(0)
return parse(pmml, seek=True)


def auto_detect_regressor(pmml, **kwargs):
Expand All @@ -78,32 +83,34 @@ def auto_detect_regressor(pmml, **kwargs):
Filename or file object containing PMML data.
"""
if isinstance(pmml, str):
file = open(pmml, 'r')
else:
pmml.seek(0)
file = pmml

for line in file:
if '<Segmentation' in line:
regs = [x for x in (detect_regressor(line) for line in file) if x is not None]
file.close()
def parse(file: Iterator, seek=False):
for line in file:
if '<Segmentation' in line:
regs = [x for x in (detect_regressor(line) for line in file) if x is not None]

if all(reg is PMMLTreeRegressor or reg is PMMLLinearRegression for reg in regs):
if 'multipleModelMethod="majorityVote"' in line or 'multipleModelMethod="average"' in line:
return PMMLForestRegressor(pmml=pmml, **kwargs)
if 'multipleModelMethod="sum"' in line:
return PMMLGradientBoostingRegressor(pmml=pmml, **kwargs)

if all(reg is PMMLTreeRegressor or reg is PMMLLinearRegression for reg in regs):
if 'multipleModelMethod="majorityVote"' in line or 'multipleModelMethod="average"' in line:
return PMMLForestRegressor(pmml=pmml, **kwargs)
if 'multipleModelMethod="sum"' in line:
return PMMLGradientBoostingRegressor(pmml=pmml, **kwargs)
raise Exception('Unsupported PMML regressor: invalid segmentation.')

raise Exception('Unsupported PMML regressor: invalid segmentation.')
reg = detect_regressor(line)
if reg:
if seek:
pmml.seek(0)
return reg(pmml, **kwargs)

reg = detect_regressor(line)
if reg:
file.close()
return reg(pmml, **kwargs)
raise Exception('Unsupported PMML regressor.')

file.close()
raise Exception('Unsupported PMML regressor.')
if isinstance(pmml, str):
with io.open(pmml, 'r') as f:
return parse(f)
else:
pmml.seek(0)
return parse(pmml, seek=True)


def detect_classifier(line):
Expand Down
58 changes: 58 additions & 0 deletions tests/auto_detect/test_auto_detect.py
Expand Up @@ -93,6 +93,64 @@ def test_auto_detect_invalid_regressor_segmentation(self):

assert str(cm.exception) == 'Unsupported PMML regressor: invalid segmentation.'

def test_auto_detect_file_object_classifier(self):
    """Auto-detection from an in-memory file object yields a logistic regression.

    The PMML document is passed as a ``StringIO`` buffer instead of a file
    path. It declares a categorical target field and a single
    ``RegressionModel``.
    """
    document = """
<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
<DataDictionary>
<DataField name="feature" optype="continuous" dataType="float"/>
<DataField name="Class" optype="categorical" dataType="string">
<Value value="A"/>
<Value value="B"/>
</DataField>
</DataDictionary>
<MiningSchema>
<MiningField name="Class" usageType="target"/>
</MiningSchema>
<RegressionModel>
<MiningSchema>
<MiningField name="feature" />
<MiningField name="Class" usageType="target" />
</MiningSchema>
<Output>
<OutputField name="probability(A)" optype="continuous" dataType="double" feature="probability" value="A"/>
<OutputField name="probability(B)" optype="continuous" dataType="double" feature="probability" value="B"/>
</Output>
<RegressionTable intercept="-1">
<NumericPredictor name="feature" exponent="1" coefficient="0.1"/>
</RegressionTable>
</RegressionModel>
</PMML>
"""
    # Wrap the document in a file-like object so the non-path branch is used.
    buffer = StringIO(document)
    estimator = auto_detect_estimator(buffer)

    assert isinstance(estimator, PMMLLogisticRegression)

def test_auto_detect_file_object_regressor(self):
    """Auto-detection from an in-memory file object yields a linear regression.

    The PMML document is passed as a ``StringIO`` buffer instead of a file
    path. It declares a continuous ``Class`` field and a single
    ``RegressionModel``.
    """
    document = """
<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
<DataDictionary>
<DataField name="feature" optype="continuous" dataType="float"/>
<DataField name="Class" optype="continuous" dataType="float"/>
</DataDictionary>
<MiningSchema>
<MiningField name="Class" usageType="target"/>
</MiningSchema>
<RegressionModel>
<MiningSchema>
<MiningField name="feature" usageType="active" invalidValueTreatment="returnInvalid"/>
<MiningField name="Class" usageType="predicted" invalidValueTreatment="returnInvalid"/>
</MiningSchema>
<Output>
<OutputField name="Predicted_Class" optype="continuous" dataType="float" feature="predictedValue"/>
</Output>
<RegressionTable intercept="-1">
<NumericPredictor name="feature" exponent="1" coefficient="0.1"/>
</RegressionTable>
</RegressionModel>
</PMML>
"""
    # Wrap the document in a file-like object so the non-path branch is used.
    buffer = StringIO(document)
    regressor = auto_detect_estimator(buffer)

    assert isinstance(regressor, PMMLLinearRegression)

def test_auto_detect_tree_classifier(self):
pmml = path.join(BASE_DIR, '../models/tree-iris.pmml')
assert isinstance(auto_detect_estimator(pmml=pmml), PMMLTreeClassifier)
Expand Down

0 comments on commit 55de16a

Please sign in to comment.