Skip to content

Commit

Permalink
Added 'Evaluator.suppressResultFields()' method
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Mar 2, 2023
1 parent 0d6f035 commit 1074ef4
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
16 changes: 15 additions & 1 deletion jpmml_evaluator/__init__.py
Expand Up @@ -120,6 +120,9 @@ def evaluate(self, arguments, nan_as_missing = True):
except Exception as e:
raise self.backend.toJavaError(e)
results = self.backend.loads(results)
if hasattr(self, "dropColumns"):
for dropColumn in self.dropColumns:
del results[dropColumn]
return results

def evaluateAll(self, arguments_df, nan_as_missing = True):
Expand All @@ -131,7 +134,18 @@ def evaluateAll(self, arguments_df, nan_as_missing = True):
except Exception as e:
raise self.backend.toJavaError(e)
result_records = self.backend.loads(result_records)
return DataFrame.from_records(result_records)
results_df = DataFrame.from_records(result_records)
if hasattr(self, "dropColumns"):
for dropColumn in self.dropColumns:
results_df.drop(str(dropColumn), axis = 1, inplace = True)
return results_df

def suppressResultFields(self, resultFields):
if resultFields:
self.dropColumns = [resultField.getName() for resultField in resultFields]
else:
if hasattr(self, "dropColumns"):
del self.dropColumns

class BaseModelEvaluatorBuilder(JavaObject):

Expand Down
18 changes: 18 additions & 0 deletions jpmml_evaluator/tests/__init__.py
Expand Up @@ -97,10 +97,28 @@ def workflow(self, backend, lax):
self.assertEqual(0.0, results["probability(virginica)"])
self.assertTrue(results["report(probability(versicolor))"].startswith("<math "))

evaluator.suppressResultFields([targetField])
self.assertTrue(hasattr(evaluator, "dropColumns"))

results = evaluator.evaluate(arguments)

self.assertEqual(4, len(results))

evaluator.suppressResultFields([])
self.assertFalse(hasattr(evaluator, "dropColumns"))

arguments_df = pandas.read_csv(_resource("Iris.csv"), sep = ",")
print(arguments_df.head(5))

results_df = evaluator.evaluateAll(arguments_df)
print(results_df.head(5))

self.assertEqual((150, 5), results_df.shape)

evaluator.suppressResultFields([targetField])

results_df = evaluator.evaluateAll(arguments_df)

self.assertEqual((150, 4), results_df.shape)

evaluator.suppressResultFields(None)

0 comments on commit 1074ef4

Please sign in to comment.