Skip to content

Commit

Permalink
fix: correctly retrieve votes distribution for float target trees
Browse files Browse the repository at this point in the history
Closes #44
  • Loading branch information
iamDecode committed Aug 15, 2022
1 parent 672ac95 commit 40df678
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions sklearn_pmml_model/tree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,13 @@ def construct_tree(node, classes, field_mapping, i=0, rescale_factor=1):
i += 1

def votes_for(field):
distribution = node.find(f"ScoreDistribution[@value='{field}']")

# Deal with case where target field is a double, but ScoreDistribution value is an integer.
if isinstance(field, float) and field.is_integer():
return node.find(f"ScoreDistribution[@value='{field}']") or node.find(f"ScoreDistribution[@value='{int(field)}']")
if distribution is None and isinstance(field, float) and field.is_integer():
return node.find(f"ScoreDistribution[@value='{int(field)}']")

return node.find(f"ScoreDistribution[@value='{field}']")
return distribution

if not child_nodes:
record_count = node.get('recordCount')
Expand Down

0 comments on commit 40df678

Please sign in to comment.