Introduces seqio.CollectingMetric class

PiperOrigin-RevId: 526774402
KEHANG authored and SeqIO committed Apr 24, 2023
1 parent a08b322 commit a2b140b
Showing 7 changed files with 434 additions and 47 deletions.
4 changes: 2 additions & 2 deletions seqio/dataset_providers.py
@@ -1068,8 +1068,8 @@ def metric_objs(self) -> Sequence[metrics_lib.Metric]:
     to_return = list(x for x in self._metric_objs_constructor_args)
     if self.metric_fns:
       to_return += [
-          metrics_lib.LegacyMetric.empty(mf, self._postprocess_fn)
-          for mf in self.metric_fns
+          metrics_lib.PassthroughLegacyMetric.from_metric_fn(
+              mf, self._postprocess_fn).empty() for mf in self.metric_fns
       ]
     return to_return

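For orientation, here is a minimal sketch of the new construction path, assuming the usual "from seqio import metrics as metrics_lib" import; the sequence_accuracy metric_fn below is a hypothetical stand-in for whatever the Task registers, and the previous path went through metrics_lib.LegacyMetric.empty(mf, postprocess_fn) directly.

from seqio import metrics as metrics_lib

def sequence_accuracy(targets, predictions):
  # Hypothetical predict-style metric_fn: returns a dict of metric name -> value.
  correct = sum(t == p for t, p in zip(targets, predictions))
  return {"sequence_accuracy": 100.0 * correct / len(targets)}

# from_metric_fn wraps the legacy metric_fn in a PassthroughLegacyMetric
# subclass, and .empty() builds an instance with no model output collected yet,
# mirroring the call in dataset_providers.py above.
metric_obj = metrics_lib.PassthroughLegacyMetric.from_metric_fn(
    sequence_accuracy, None  # second argument is the postprocess_fn
).empty()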
5 changes: 1 addition & 4 deletions seqio/dataset_providers_test.py
@@ -90,10 +90,7 @@ def score_metric_fn_2(targets, scores):
     )

     actual_metric_objs = list(task.metric_objs)
-    expected_metric_objs = input_metric_objs + [
-        metrics_lib.LegacyMetric.empty(score_metric_fn_2, None)
-    ]
-    self.assertListEqual(actual_metric_objs, expected_metric_objs)
+    self.assertLen(actual_metric_objs, 2)

   def test_metric_fn_signature(self):
     # pylint:disable=unused-argument
52 changes: 39 additions & 13 deletions seqio/evaluation.py
@@ -199,7 +199,6 @@ def _permute(x, sorted_order):
       return [x[sorted_order[i]] for i in range(len(sorted_order))]

     model_fn_result = model_fn(cached_model_dataset)  # pytype: disable=missing-parameter  # always-use-return-annotations
-
     if isinstance(model_fn_result, tuple):
       # Some of model functions return a tuple of two outputs per example.
       # e.g., ModelOutputType.PREDICTION_WITH_AUX,
@@ -707,25 +706,48 @@ def _compute_clu_metrics(
       inferences = {}
       for metric_obj in task.metric_objs:
         model_output = all_output[task.name][metric_obj.model_output_type]
+        # When model output type is PREDICTION_WITH_AUX or
+        # SCORE_WITH_INTERMEDIATES, model output is a tuple of two arrays/lists.
+        if isinstance(model_output, tuple):
+          prediction_or_score, aux_value = model_output
+          aux_value = jax.tree_map(
+              np.array, aux_value,
+              is_leaf=lambda x: isinstance(x, list),
+          )
+          model_output = (np.array(prediction_or_score), aux_value)
+        else:
+          model_output = np.array(model_output)
         metric_instance = metric_obj.from_model_output(
             tfds.as_numpy(task_dataset),
             model_output,
             task.output_features,
            self._target_field_name,
         )
-        task_metrics.append(metric_instance.compute())
+        if isinstance(metric_instance, metrics_lib.CollectingMetric):
+          metric_value, targets_and_inferences = metric_instance.actual_compute(
+              tfds.as_numpy(task_dataset),
+              task.output_features,
+              self._target_field_name)
+        else:
+          metric_value = metric_instance.compute()
+          targets_and_inferences = None
+          if hasattr(metric_instance, "targets_and_inferences"):
+            targets_and_inferences = metric_instance.targets_and_inferences
+        task_metrics.append(metric_value)
         # Records inferences for legacy logging compatibility.
-        inferences.update(
-            {
-                key: val
-                for key, val in metric_instance.targets_and_inferences.items()
-                if key != "targets"
-            }
-        )
+        # common ones are score, output, prediction.
+        if targets_and_inferences:
+          for key, val in targets_and_inferences.items():
+            if key == "targets": continue
+            inferences[key] = (val.tolist()
+                               if isinstance(val, np.ndarray) else val)
       # Records targets for legacy logging compatibility.
-      # Each metric_instance should have identical targets.
-      # Chooses the last metric_instance for this recording purpose.
-      targets = metric_instance.targets_and_inferences["targets"]
+      # Each targets_and_inferences should have identical targets.
+      # Chooses the last targets_and_inferences for this recording purpose.
+      if targets_and_inferences:
+        targets = targets_and_inferences["targets"]
+      else:
+        targets = None

       all_metrics[task.name] = {}
       for k, v in itertools.chain(*[m.items() for m in task_metrics]):
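To make the new branch concrete, here is a small illustrative sketch (not part of the commit) of the data the evaluator now consumes: a CollectingMetric-style actual_compute returns a (metric_value, targets_and_inferences) pair, and everything except the "targets" entry is copied into the legacy inferences record, with numpy arrays converted to plain lists. The example values mirror the expectations in evaluation_test.py.

import numpy as np

# Hypothetical result of metric_instance.actual_compute(...).
metric_value = {"sequence_accuracy": 1.0 / 3 * 100}
targets_and_inferences = {
    "targets": ["e5 e6", "e6", "e7"],
    "prediction": ["e5 e7", "e7 e5", "e7 e5"],
    "score": np.array([2, 1, 3]),
}

# Mirror of the recording logic above: targets are kept separately, the rest
# becomes the legacy inferences dict with arrays turned into lists.
inferences = {}
for key, val in targets_and_inferences.items():
  if key == "targets":
    continue
  inferences[key] = val.tolist() if isinstance(val, np.ndarray) else val

targets = targets_and_inferences["targets"]
# inferences == {"prediction": ["e5 e7", "e7 e5", "e7 e5"], "score": [2, 1, 3]}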
@@ -823,7 +845,8 @@ def from_model_output(self,
                                               Tuple[np.ndarray, np.ndarray]],
                         features: Mapping[str, utils.Feature],
                         target_field_name: str = "targets",
-                        mask: Optional[np.ndarray] = None):
+                        mask: Optional[np.ndarray] = None,
+                        indices_2d: Optional[np.ndarray] = None):
     """Calculates the metrics associated with the given task name and model output type.

     Args:
@@ -837,6 +860,8 @@ def from_model_output(self,
       mask: An array of booleans, same length as inputs. Each element indicates
         whether to include the corresponding element in the inputs for metric
         evaluation.
+      indices_2d: 2d-indices of examples in the inputs/model_output. First
+        dimension is shard id, the second is the example id within that shard.

     Returns:
@@ -845,6 +870,7 @@ def from_model_output(self,
     metric_batch = metrics_collection.single_from_model_output(
         inputs=inputs,
         model_output=model_output,
+        indices_2d=indices_2d,
         features=features,
         target_field_name=target_field_name,
         mask=mask)
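For reference on the new indices_2d argument, which the docstring above describes as (shard id, example id within shard) pairs, a hypothetical value for three examples spread over two shards could look like this.

import numpy as np

# Purely illustrative: one row per example, [shard_id, example_id_within_shard].
indices_2d = np.array([
    [0, 0],  # shard 0, example 0
    [0, 1],  # shard 0, example 1
    [1, 0],  # shard 1, example 0
])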
18 changes: 9 additions & 9 deletions seqio/evaluation_test.py
@@ -454,7 +454,7 @@ def predict_fn(
     ) -> evaluation.PredictFnReturnType:
       del ds, model_feature_shapes
       return (
-          [(0, [5, 6]), (1, [7]), (2, [7])]
+          [(0, [5, 6]), (1, [7, 5]), (2, [7, 5])]
           if task.predict_metric_fns
           else self.uncalled_fn
       )
@@ -466,7 +466,7 @@ def predict_with_aux_fn(
       del ds, model_feature_shapes

       indices_and_predictions = (
-          [(0, [5, 6]), (1, [7]), (2, [7])]
+          [(0, [5, 6]), (1, [7, 5]), (2, [7, 5])]
           if task.predict_with_aux_metric_fns
           else self.uncalled_fn
       )
@@ -508,7 +508,7 @@ def test_evaluate_single_task_predict(self):
     )
     all_metrics, _ = self._evaluate_single_task(task)
     self.assertDictClose(
-        {"sequence_accuracy": 2.0 / 3 * 100}, all_metrics[task.name]
+        {"sequence_accuracy": 1.0 / 3 * 100}, all_metrics[task.name]
     )

   def test_evaluate_single_task_score(self):
@@ -524,7 +524,7 @@ def test_evaluate_single_task_both(self):
         score_metric_fns=[_sum_scores_metric],
     )
     all_metrics, _ = self._evaluate_single_task(task)
-    expected = {"sequence_accuracy": 2.0 / 3 * 100, "total_score": 1305}
+    expected = {"sequence_accuracy": 1.0 / 3 * 100, "total_score": 1305}
     self.assertDictClose(expected, all_metrics[task.name])

   def test_evaluate_using_aux_score(self):
@@ -610,7 +610,7 @@ def test_evaluate_single_task_predict_target_field_name(self):
     )
     all_metrics, _ = self._evaluate_single_task(task, target_field_name="foo")
     self.assertDictClose(
-        {"sequence_accuracy": 2.0 / 3 * 100}, all_metrics[task.name]
+        {"sequence_accuracy": 1.0 / 3 * 100}, all_metrics[task.name]
     )

   def test_evaluate_single_task_with_loggers(self):
@@ -630,7 +630,7 @@ def test_evaluate_single_task_with_loggers(self):

     _, evaluator = self._evaluate_single_task(task, loggers=loggers)
     metrics = {
-        "sequence_accuracy": metrics_lib.Scalar(1 / 3 * 100),
+        "sequence_accuracy": metrics_lib.Scalar(0.0 / 3 * 100),
         "total_score": metrics_lib.Scalar(1305),
     }
     for logger in loggers:
@@ -641,9 +641,9 @@ def test_evaluate_single_task_with_loggers(self):
         dataset=evaluator._cached_task_datasets[task.name],
         targets=["e5 e6", "e6", "e7"],
         inferences={
-            "prediction": ["e5 e7", "e7", "e7"],
+            "prediction": ["e5 e7", "e7 e5", "e7 e5"],
             "score": [2, 1, 3],
-            "output": ["e5 e6", "e7", "e7"],
+            "output": ["e5 e6", "e7 e5", "e7 e5"],
         },
     )

@@ -831,7 +845,8 @@ def predict_fn(
     ) -> Optional[evaluation.PredictFnReturnType]:
       del model_feature_shapes
       if ds == mock_ds1:
-        return [(0, [5, 6]), (1, [7])]
+        return [(0, [5, 6]), (1, [7, 5])]
       elif ds == mock_ds2:
         return [(0, [5]), (1, [6]), (2, [7])]

Expand Down
Loading

0 comments on commit a2b140b

Please sign in to comment.