-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #682 from mv1388/hf_evaluate_metrics_support
HF evaluate metrics support
- Loading branch information
Showing
3 changed files
with
106 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
44 changes: 44 additions & 0 deletions
44
aitoolbox/experiment/result_package/hf_evaluate_packages.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from aitoolbox.experiment.result_package.abstract_result_packages import AbstractResultPackage | ||
|
||
|
||
class HFEvaluateResultPackage(AbstractResultPackage):
    def __init__(self, hf_evaluate_metric, use_models_additional_results=True, **kwargs):
        """HuggingFace Evaluate Metrics Result Package

        Result package wrapping around the evaluation metrics provided in the HuggingFace Evaluate package.
        All the metric result names will have the '_HFEvaluate' appended at the end to help distinguish them.

        Github: https://github.com/huggingface/evaluate
        More info on how to use the metrics: https://huggingface.co/docs/evaluate/index

        Args:
            hf_evaluate_metric (evaluate.EvaluationModule): HF Evaluate metric to be used by the result package
            use_models_additional_results (bool): Should the additional results from the model
                (in addition to predictions and references) normally returned from the get_predictions() function be
                added as the additional input to the HF Evaluate metric to perform the evaluation calculation.
            **kwargs: additional parameters or inputs to the HF Evaluate metric being calculated. These can be generally
                inputs available already at the start before making model predictions and thus don't need to be gathered
                from the train/prediction loop.
        """
        super().__init__(pkg_name='HuggingFace Evaluate metrics', **kwargs)

        # Underlying HF Evaluate metric whose compute() performs the actual evaluation
        self.metric = hf_evaluate_metric
        self.use_models_additional_results = use_models_additional_results

    def prepare_results_dict(self):
        """Compute the wrapped HF Evaluate metric and suffix its result names.

        Returns:
            dict: metric results with '_HFEvaluate' appended to every key when the
                underlying metric returns a dict; otherwise the raw compute() output.
        """
        # Start from the package-level metric inputs provided via __init__ **kwargs
        extra_metric_inputs = dict(self.package_metadata)

        if self.use_models_additional_results:
            # Merge in the extra outputs gathered from the model's get_predictions()
            extra_metric_inputs.update(self.additional_results['additional_results'])

        metric_result = self.metric.compute(
            references=self.y_true, predictions=self.y_predicted,
            **extra_metric_inputs
        )

        if isinstance(metric_result, dict):
            # Tag each result name so HF Evaluate metrics are distinguishable downstream
            return {f'{result_name}_HFEvaluate': result_value
                    for result_name, result_value in metric_result.items()}

        return metric_result
59 changes: 59 additions & 0 deletions
59
tests/test_experiment/test_result_package/test_hf_evaluate_packages.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import unittest | ||
|
||
from aitoolbox.experiment.result_package.hf_evaluate_packages import HFEvaluateResultPackage | ||
|
||
|
||
class DummyHFEvaluateMetric:
    """Stand-in for an HF Evaluate metric: compute() echoes back every keyword argument.

    Lets tests assert exactly which inputs the result package forwarded to the metric.
    """

    def compute(self, **metric_inputs):
        # Return the received inputs unchanged for inspection by the test assertions.
        return metric_inputs
|
||
|
||
class TestHFEvaluateResultPackage(unittest.TestCase):
    """Tests for HFEvaluateResultPackage using a dummy metric that echoes its inputs."""

    @staticmethod
    def _build_package(**package_kwargs):
        # Wire up a package around a fresh dummy metric with fixed stub predictions.
        package = HFEvaluateResultPackage(hf_evaluate_metric=DummyHFEvaluateMetric(), **package_kwargs)
        package.y_true = 1
        package.y_predicted = 2
        package.additional_results = {'additional_results': {'aaa': 3}}
        return package

    def test_models_additional_inputs_parsing(self):
        # Model's additional results are forwarded to the metric when enabled
        package = self._build_package(use_models_additional_results=True)
        self.assertEqual(
            package.prepare_results_dict(),
            {'references_HFEvaluate': 1, 'predictions_HFEvaluate': 2, 'aaa_HFEvaluate': 3}
        )

        # ... and omitted when disabled
        package = self._build_package(use_models_additional_results=False)
        self.assertEqual(
            package.prepare_results_dict(),
            {'references_HFEvaluate': 1, 'predictions_HFEvaluate': 2}
        )

    def test_additional_inputs_combination(self):
        # Package-level **kwargs inputs combine with the model's additional results
        package = self._build_package(
            use_models_additional_results=True,
            my_additional_input_1=123, my_additional_input_2='ABCD'
        )
        self.assertEqual(
            package.prepare_results_dict(),
            {'references_HFEvaluate': 1, 'predictions_HFEvaluate': 2,
             'my_additional_input_1_HFEvaluate': 123, 'my_additional_input_2_HFEvaluate': 'ABCD',
             'aaa_HFEvaluate': 3}
        )

        # Package-level **kwargs inputs alone when model results are disabled
        package = self._build_package(
            use_models_additional_results=False,
            my_additional_input_1=123, my_additional_input_2='ABCD'
        )
        self.assertEqual(
            package.prepare_results_dict(),
            {'references_HFEvaluate': 1, 'predictions_HFEvaluate': 2,
             'my_additional_input_1_HFEvaluate': 123, 'my_additional_input_2_HFEvaluate': 'ABCD'}
        )