Update T5 tasks to seqio tasks
PiperOrigin-RevId: 371391065
gauravmishra authored and Copybara-Service committed Apr 30, 2021
1 parent 23d59be commit 1ee9880
Showing 3 changed files with 118 additions and 74 deletions.
@@ -14,13 +14,14 @@
 # limitations under the License.
 
 """Transformer Modifications Mixtures."""
-import t5
 
+import seqio
 from t5.data.glue_utils import get_super_glue_weight_mapping
 
 from transformer_modifications.transformer_modifications import tasks  # pylint: disable=unused-import
 
-MixtureRegistry = t5.data.MixtureRegistry
-TaskRegistry = t5.data.TaskRegistry
+MixtureRegistry = seqio.MixtureRegistry
+TaskRegistry = seqio.TaskRegistry
 
 _super_glue_tasks_envocab = {}
 for task, value in get_super_glue_weight_mapping().items():
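For context on where this registry lands: below is a minimal sketch, not part of this commit, of how a weight mapping like `_super_glue_tasks_envocab` is typically fed into the seqio mixture registry after this migration. The mixture name is hypothetical; only the `seqio.MixtureRegistry.add` call pattern comes from the seqio API.

import seqio

# Hypothetical mixture built from the (task name -> rate) mapping above;
# the mixture name is illustrative, not one registered by this commit.
seqio.MixtureRegistry.add(
    "super_glue_v102_envocab_proportional",
    [(task_name, rate) for task_name, rate in _super_glue_tasks_envocab.items()])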
181 changes: 112 additions & 69 deletions transformer_modifications/transformer_modifications/tasks.py
@@ -17,7 +17,7 @@
 
 import functools
 
-import t5.data
+import seqio
 from t5.data import postprocessors as t5_postprocessors
 from t5.data import preprocessors as t5_preprocessors
 
@@ -29,45 +29,58 @@
 import tensorflow_datasets as tfds
 import t5_closed_book_qa.t5_cbqa.preprocessors as t5_cbqa_preprocessors
 
-TaskRegistry = t5.data.TaskRegistry
-TfdsTask = t5.data.TfdsTask
+TaskRegistry = seqio.TaskRegistry
 
 EN_VOCAB_SPM_PATH = "gs://t5-data/vocabs/cc_en.32000/sentencepiece.model"
 WMT14_CUSTOM_SPM_PATH = "gs://t5-data/vocabs/wmt_ende.37000/spm.model"
 
 
-WMT14_VOCAB_EXTRA_100 = t5.data.SentencePieceVocabulary(
+WMT14_VOCAB_EXTRA_100 = seqio.SentencePieceVocabulary(
     WMT14_CUSTOM_SPM_PATH, extra_ids=100)
-EN_VOCAB_EXTRA_100 = t5.data.SentencePieceVocabulary(
+EN_VOCAB_EXTRA_100 = seqio.SentencePieceVocabulary(
     EN_VOCAB_SPM_PATH, extra_ids=100)
 
+EN_VOCAB_OUTPUT_FEATURES = {
+    "inputs": seqio.Feature(vocabulary=EN_VOCAB_EXTRA_100, add_eos=True),
+    "targets": seqio.Feature(vocabulary=EN_VOCAB_EXTRA_100, add_eos=True)
+}
+
 #================================ English only vocab ===========================
 for version in ("2.2.0", "2.3.0", "2.3.1"):
   TaskRegistry.add(
       "c4_v{}_unsupervised_en32k".format(version.replace(".", "")),
-      TfdsTask,
-      tfds_name="c4/en:{}".format(version),
-      text_preprocessor=functools.partial(
-          t5_preprocessors.rekey, key_map={
-              "inputs": None,
-              "targets": "text"
-          }),
-      token_preprocessor=t5_preprocessors.unsupervised,
-      output_features=t5.data.Feature(vocabulary=EN_VOCAB_EXTRA_100),
+      source=seqio.TfdsDataSource(tfds_name="c4/en:{}".format(version)),
+      preprocessors=[
+          functools.partial(
+              t5_preprocessors.rekey,
+              key_map={
+                  "inputs": None,
+                  "targets": "text"
+              }),
+          seqio.preprocessors.tokenize,
+          seqio.CacheDatasetPlaceholder(),
+          t5_preprocessors.unsupervised,
+          seqio.preprocessors.append_eos_after_trim,
+      ],
+      output_features=EN_VOCAB_OUTPUT_FEATURES,
       metric_fns=[])
 
 
 #================================ XSUM =========================================
 TaskRegistry.add(
     "xsum_v110",
-    TfdsTask,
-    tfds_name="xsum:1.1.0",
-    text_preprocessor=functools.partial(
-        t5_preprocessors.summarize,
-        article_key="document",
-        summary_key="summary"),
+    source=seqio.TfdsDataSource(tfds_name="xsum:1.1.0"),
+    preprocessors=[
+        functools.partial(
+            t5_preprocessors.summarize,
+            article_key="document",
+            summary_key="summary"),
+        seqio.preprocessors.tokenize,
+        seqio.CacheDatasetPlaceholder(),
+        seqio.preprocessors.append_eos_after_trim,
+    ],
     metric_fns=[t5_metrics.rouge],
-    output_features=t5.data.Feature(vocabulary=EN_VOCAB_EXTRA_100),
+    output_features=EN_VOCAB_OUTPUT_FEATURES,
 )
 
 #============================ SuperGLUE English Vocab===========================
Expand All @@ -88,95 +101,125 @@
get_glue_text_preprocessor(b)
]
else:
text_preprocessor = get_glue_text_preprocessor(b)
text_preprocessor = [get_glue_text_preprocessor(b)]
TaskRegistry.add(
"super_glue_%s_v102_envocab" % b.name,
TfdsTask,
tfds_name="super_glue/%s:1.0.2" % b.name,
text_preprocessor=text_preprocessor,
source=seqio.TfdsDataSource(
tfds_name="super_glue/%s:1.0.2" % b.name,
splits=["test"] if b.name in ["axb", "axg"] else None),
preprocessors=text_preprocessor + [
seqio.preprocessors.tokenize,
seqio.CacheDatasetPlaceholder(),
seqio.preprocessors.append_eos_after_trim,
],
metric_fns=get_super_glue_metric(b.name),
output_features=t5.data.Feature(vocabulary=EN_VOCAB_EXTRA_100),
postprocess_fn=get_glue_postprocess_fn(b),
splits=["test"] if b.name in ["axb", "axg"] else None)
output_features=EN_VOCAB_OUTPUT_FEATURES,
postprocess_fn=get_glue_postprocess_fn(b))

# ======================== Definite Pronoun Resolution =========================
TaskRegistry.add(
"dpr_v001_simple_envocab",
TfdsTask,
tfds_name="definite_pronoun_resolution:1.1.0",
text_preprocessor=t5_preprocessors.definite_pronoun_resolution_simple,
source=seqio.TfdsDataSource(tfds_name="definite_pronoun_resolution:1.1.0"),
preprocessors=[
t5_preprocessors.definite_pronoun_resolution_simple,
seqio.preprocessors.tokenize,
seqio.CacheDatasetPlaceholder(),
seqio.preprocessors.append_eos_after_trim,
],
metric_fns=[t5_metrics.accuracy],
output_features=t5.data.Feature(vocabulary=EN_VOCAB_EXTRA_100))
output_features=EN_VOCAB_OUTPUT_FEATURES)

# =================================== WSC ======================================
TaskRegistry.add(
"super_glue_wsc_v102_simple_train_envocab",
TfdsTask,
tfds_name="super_glue/wsc.fixed:1.0.2",
text_preprocessor=functools.partial(
t5_preprocessors.wsc_simple, correct_referent_only=True),
source=seqio.TfdsDataSource(
tfds_name="super_glue/wsc.fixed:1.0.2", splits=["train"]),
preprocessors=[
functools.partial(
t5_preprocessors.wsc_simple, correct_referent_only=True),
seqio.preprocessors.tokenize,
seqio.CacheDatasetPlaceholder(),
seqio.preprocessors.append_eos_after_trim,
],
metric_fns=[],
output_features=t5.data.Feature(vocabulary=EN_VOCAB_EXTRA_100),
splits=["train"])
output_features=EN_VOCAB_OUTPUT_FEATURES)
TaskRegistry.add(
"super_glue_wsc_v102_simple_eval_envocab",
TfdsTask,
tfds_name="super_glue/wsc.fixed:1.0.2",
text_preprocessor=functools.partial(
t5_preprocessors.wsc_simple, correct_referent_only=False),
source=seqio.TfdsDataSource(
tfds_name="super_glue/wsc.fixed:1.0.2", splits=["validation", "test"]),
preprocessors=[
functools.partial(
t5_preprocessors.wsc_simple, correct_referent_only=False),
seqio.preprocessors.tokenize,
seqio.CacheDatasetPlaceholder(),
seqio.preprocessors.append_eos_after_trim,
],
postprocess_fn=t5_postprocessors.wsc_simple,
metric_fns=[t5_metrics.accuracy],
output_features=t5.data.Feature(vocabulary=EN_VOCAB_EXTRA_100),
splits=["validation", "test"])
output_features=EN_VOCAB_OUTPUT_FEATURES)

# ============================ Web Questions ===================================

# This task uses 10% of the train split for validation.
TaskRegistry.add(
"web_questions_open_envocab",
TfdsTask,
tfds_name="web_questions:1.0.0",
splits={
"train": "train[:3417]", # ~90%, matches numbers used by ORQA
"validation": "train[3417:]", # ~10%, matches numbers used by ORQA
"test": "test"
},
text_preprocessor=[t5_cbqa_preprocessors.web_questions_open],
source=seqio.TfdsDataSource(
tfds_name="web_questions:1.0.0",
splits={
"train": "train[:3417]", # ~90%, matches numbers used by ORQA
"validation": "train[3417:]", # ~10%, matches numbers used by ORQA
"test": "test"
}),
preprocessors=[
t5_cbqa_preprocessors.web_questions_open,
seqio.preprocessors.tokenize,
seqio.CacheDatasetPlaceholder(),
seqio.preprocessors.append_eos_after_trim,
],
postprocess_fn=t5_postprocessors.qa,
metric_fns=[t5_metrics.squad],
output_features=t5.data.Feature(vocabulary=EN_VOCAB_EXTRA_100))
output_features=EN_VOCAB_OUTPUT_FEATURES)

# This tasks trains on the full train split.
TaskRegistry.add(
"web_questions_open_test_envocab",
TfdsTask,
tfds_name="web_questions:1.0.0",
splits={
"train": "train",
"validation": "test",
},
text_preprocessor=[t5_cbqa_preprocessors.web_questions_open],
source=seqio.TfdsDataSource(
tfds_name="web_questions:1.0.0",
splits={
"train": "train",
"validation": "test",
}),
preprocessors=[
t5_cbqa_preprocessors.web_questions_open,
seqio.preprocessors.tokenize,
seqio.CacheDatasetPlaceholder(),
seqio.preprocessors.append_eos_after_trim,
],
postprocess_fn=t5_postprocessors.qa,
metric_fns=[t5_metrics.squad],
output_features=t5.data.Feature(vocabulary=EN_VOCAB_EXTRA_100))
output_features=EN_VOCAB_OUTPUT_FEATURES)

# WMT14 en-de t2t task with the custom vocabulary for the original Transformer
# paper relication experiments.
# This is internal because the vocabulary only resides in CNS for now.
b = tfds.translate.wmt_t2t.WmtT2tTranslate.builder_configs["de-en"]

wmt14_t2t_output_features = {
"inputs": t5.data.Feature(vocabulary=WMT14_VOCAB_EXTRA_100, add_eos=True),
"targets": t5.data.Feature(vocabulary=WMT14_VOCAB_EXTRA_100, add_eos=True)
"inputs": seqio.Feature(vocabulary=WMT14_VOCAB_EXTRA_100, add_eos=True),
"targets": seqio.Feature(vocabulary=WMT14_VOCAB_EXTRA_100, add_eos=True)
}
TaskRegistry.add(
"wmt_t2t_ende_v003_vocab_37000",
TfdsTask,
tfds_name="wmt_t2t_translate/de-en:1.0.0",
text_preprocessor=functools.partial(
t5_preprocessors.translate,
source_language=b.language_pair[1],
target_language=b.language_pair[0],
source=seqio.TfdsDataSource(tfds_name="wmt_t2t_translate/de-en:1.0.0"),
preprocessors=[
functools.partial(
t5_preprocessors.translate,
source_language=b.language_pair[1],
target_language=b.language_pair[0],
),
seqio.preprocessors.tokenize,
seqio.CacheDatasetPlaceholder(),
seqio.preprocessors.append_eos_after_trim,
],
metric_fns=[t5_metrics.bleu],
output_features=wmt14_t2t_output_features)
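Once registered through seqio as above, any of these tasks is consumed the same way. A minimal usage sketch follows; it is an assumption based on seqio's public API, not code from this commit, with sequence lengths matching the `_SEQUENCE_LENGTH` used by the test file shown next.

import seqio

task = seqio.get_mixture_or_task("xsum_v110")
ds = task.get_dataset(
    sequence_length={"inputs": 512, "targets": 256},
    split="validation",
    shuffle=False)
# After the preprocessor pipeline above, each element carries tokenized
# "inputs" and "targets" features.
for ex in ds.take(1):
  print({k: v.shape for k, v in ex.items()})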
@@ -19,16 +19,16 @@
 from absl import logging
 from absl.testing import absltest
 from absl.testing import parameterized
-import t5
 
+import seqio
 import tensorflow.compat.v1 as tf
 
 from transformer_modifications.transformer_modifications import tasks  # pylint: disable=unused-import
 
 tf.disable_v2_behavior()
 tf.enable_eager_execution()
 
-TaskRegistry = t5.data.TaskRegistry
+TaskRegistry = seqio.TaskRegistry
 _SEQUENCE_LENGTH = {"inputs": 512, "targets": 256}
 _TASKS = [
     "c4_v220_unsupervised_en32k",
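The _TASKS list is truncated here. For illustration, a sketch of the smoke test such a file typically runs over it, assuming seqio's registry API; this is an assumption, not the actual test body.

for task_name in _TASKS:
  task = seqio.TaskRegistry.get(task_name)
  ds = task.get_dataset(_SEQUENCE_LENGTH, split="train", shuffle=False)
  logging.info("First example of %s: %s", task_name, next(iter(ds)))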
