seq2seq trainer test refactor. (#66)
* Refactor on test_seq2seq_trainer.py (input handler refactored).

* Code formatting.

* Docstring refactored, comment added.
devrimcavusoglu committed Oct 13, 2022
1 parent 1e0e8f1 commit 15b1f5e
Showing 7 changed files with 82 additions and 10 deletions.
3 changes: 2 additions & 1 deletion tests/training/test_seq2seq_trainer.py
@@ -98,6 +98,7 @@ def __call__(self, eval_pred: EvalPrediction) -> EvalPrediction:
predictions=eval_pred.predictions[0],
label_ids=eval_pred.label_ids
)

return super().__call__(eval_pred)


@@ -124,7 +125,7 @@ def trainer_params(temp_output_dir, temp_result_dir,
"rouge"
]
},
"metric_input_handler": {"type": "pass_through"},
"metric_input_handler": {"type": "language-generation"},
"args": {
"type": "seq2seq",
"output_dir": temp_output_dir + "/checkpoints",
3 changes: 3 additions & 0 deletions trapper/metrics/input_handlers/__init__.py
@@ -1,4 +1,7 @@
from trapper.metrics.input_handlers.input_handler import MetricInputHandler
from trapper.metrics.input_handlers.language_generation_input_handler import (
MetricInputHandlerForLanguageGeneration,
)
from trapper.metrics.input_handlers.question_answering_input_handler import (
MetricInputHandlerForQuestionAnswering,
)
14 changes: 8 additions & 6 deletions trapper/metrics/input_handlers/input_handler.py
@@ -57,6 +57,13 @@ def _extract_metadata(self, instance: IndexedInstance) -> None:
"""
return None

def preprocess(self, eval_pred: EvalPrediction) -> EvalPrediction:
processed_predictions = eval_pred.predictions.argmax(-1)
processed_label_ids = eval_pred.label_ids
return EvalPrediction(
predictions=processed_predictions, label_ids=processed_label_ids
)

def __call__(
self,
eval_pred: EvalPrediction,
@@ -72,12 +79,7 @@
Returns: Processed EvalPrediction.
"""
processsed_predictions = eval_pred.predictions.argmax(-1)
processed_label_ids = eval_pred.label_ids
processed_eval_pred = EvalPrediction(
predictions=processsed_predictions, label_ids=processed_label_ids
)
return processed_eval_pred
return self.preprocess(eval_pred)


MetricInputHandler.register("default")(MetricInputHandler)
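
With this refactor, task-specific input handlers override `preprocess` instead of `__call__`; the base class's `__call__` now simply delegates to `preprocess`. Below is a minimal sketch of a custom handler written against the new contract; the class name, the "label-cleanup" registration key, and the -100 handling are hypothetical and shown only to illustrate the override point.

import numpy as np
from transformers import EvalPrediction

from trapper.metrics.input_handlers import MetricInputHandler


@MetricInputHandler.register("label-cleanup")  # hypothetical registration key
class LabelCleanupInputHandler(MetricInputHandler):
    """Illustrative handler: reuse the default argmax, then drop ignored label ids."""

    def preprocess(self, eval_pred: EvalPrediction) -> EvalPrediction:
        # Reuse the base class behavior (argmax over the logits).
        eval_pred = super().preprocess(eval_pred)
        # Replace the ignore index (-100) with 0 so downstream metrics never see
        # negative label ids (an illustrative choice, not trapper's default).
        label_ids = np.where(eval_pred.label_ids == -100, 0, eval_pred.label_ids)
        return EvalPrediction(predictions=eval_pred.predictions, label_ids=label_ids)

Callers are unaffected: the trainer still invokes the handler as `handler(eval_pred)`; only the override point moves from `__call__` to `preprocess`.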
66 changes: 66 additions & 0 deletions trapper/metrics/input_handlers/language_generation_input_handler.py
@@ -0,0 +1,66 @@
import numpy as np
from transformers import EvalPrediction

from trapper.data.tokenizers import TokenizerWrapper
from trapper.metrics.input_handlers import MetricInputHandler


@MetricInputHandler.register("language-generation")
class MetricInputHandlerForLanguageGeneration(MetricInputHandler):
"""
`MetricInputHandlerForLanguageGeneration` provides the conversion from token ids
to decoded strings for predictions and labels and prepares them for the metric
computation.
Args:
tokenizer_wrapper (TokenizerWrapper): Required to convert token ids to strings.
"""

_contexts = list()

def __init__(
self,
tokenizer_wrapper: TokenizerWrapper,
):
super(MetricInputHandlerForLanguageGeneration, self).__init__()
self._tokenizer_wrapper = tokenizer_wrapper

@property
def tokenizer(self):
return self._tokenizer_wrapper.tokenizer

def preprocess(self, eval_pred: EvalPrediction) -> EvalPrediction:
if isinstance(eval_pred.predictions, tuple):
eval_pred = EvalPrediction(
# Models like T5 return a tuple of
# (logits, encoder_last_hidden_state) instead of only the logits
predictions=eval_pred.predictions[0],
label_ids=eval_pred.label_ids,
)
eval_pred = super(MetricInputHandlerForLanguageGeneration, self).preprocess(
eval_pred
)

# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/examples/pytorch/translation/run_translation.py#L540
references = np.where(
eval_pred.label_ids != -100,
eval_pred.label_ids,
self.tokenizer.pad_token_id,
)

# Batch decoding is intentionally avoided because jury expects a list of
# lists of strings for language-generation metrics.
predictions = np.array(
[
[self.tokenizer.decode(pred, skip_special_tokens=True)]
for pred in eval_pred.predictions
]
)
references = np.array(
[
[self.tokenizer.decode(ref, skip_special_tokens=True)]
for ref in references
]
)

return EvalPrediction(predictions=predictions, label_ids=references)
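
To make the new handler's behavior concrete, here is a rough usage sketch. The `from_pretrained` construction of the tokenizer wrapper and the "t5-small" checkpoint are assumptions for illustration, and the ids and shapes are made up.

import numpy as np
from transformers import EvalPrediction

from trapper.data.tokenizers import TokenizerWrapper
from trapper.metrics.input_handlers import MetricInputHandlerForLanguageGeneration

# Assumed setup: any TokenizerWrapper whose tokenizer can decode ids would do;
# from_pretrained is assumed to be available on TokenizerWrapper.
tokenizer_wrapper = TokenizerWrapper.from_pretrained("t5-small")
handler = MetricInputHandlerForLanguageGeneration(tokenizer_wrapper=tokenizer_wrapper)

# Fake logits for a batch of two length-4 sequences, and label ids where -100
# marks padded positions (as produced by the data collator).
logits = np.random.rand(2, 4, tokenizer_wrapper.tokenizer.vocab_size)
label_ids = np.array([[37, 1712, 1, -100], [100, 19, 1, -100]])

processed = handler(EvalPrediction(predictions=logits, label_ids=label_ids))
# processed.predictions and processed.label_ids now hold a single-element list of
# decoded strings per example, i.e. the list-of-lists-of-strings layout that jury
# expects for language-generation metrics.

This mirrors the updated test fixture above, which now selects the handler by its registered name via "metric_input_handler": {"type": "language-generation"}.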
trapper/metrics/input_handlers/question_answering_input_handler.py
@@ -42,7 +42,7 @@ def _decode_answer(self, context: List[int], start, end) -> str:
answer = context[start - 1 : end - 1]
return self.tokenizer.decode(answer).lstrip()

def __call__(self, eval_pred: EvalPrediction) -> EvalPrediction:
def preprocess(self, eval_pred: EvalPrediction) -> EvalPrediction:
predictions, references = eval_pred.predictions, eval_pred.label_ids
predicted_starts, predicted_ends = predictions[0].argmax(-1), predictions[
1
@@ -30,7 +30,7 @@ def label_mapper(self):
def _id_to_label(self, id_: int) -> str:
return self.label_mapper.get_label(id_)

def __call__(self, eval_pred: EvalPrediction) -> EvalPrediction:
def preprocess(self, eval_pred: EvalPrediction) -> EvalPrediction:
predictions, references = eval_pred.predictions, eval_pred.label_ids
all_predicted_ids = np.argmax(predictions, axis=2)
all_label_ids = references
2 changes: 1 addition & 1 deletion trapper/metrics/jury.py
@@ -1,9 +1,9 @@
from typing import Any, Dict, List, Optional, Union

import jury
from allennlp.common import Params
from transformers import EvalPrediction

from trapper.common import Params
from trapper.metrics.metric import Metric, MetricParam


