-
Notifications
You must be signed in to change notification settings - Fork 749
/
mesh_transformer.py
357 lines (303 loc) · 12.6 KB
/
mesh_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for providing data to Mesh TF transformer."""
import functools
from absl import logging
import gin
import mesh_tensorflow.transformer.dataset as transformer_dataset
import seqio
from t5.models import utils as model_utils
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
# Names that older gin configs may still reference even though this module no
# longer defines them.
# NOTE(review): presumably consumed by gin-config loading code elsewhere to
# tolerate stale bindings -- confirm against callers.
DEPRECATED_GIN_REFERENCES = (
    "configurable_vocabulary",
    "get_sentencepiece_model_path",
    "maybe_print_dataset",
    "num_parallel_calls",
    "SentencePieceVocabulary",
    "seqio.sentencepiece_vocabulary.SentencePieceVocabulary",
    "t5.models.mesh_transformer.get_sentencepiece_model_path",
    "train_model",
    "vocabularies.Vocabulary",
    "Vocabulary",
)
@gin.configurable()
def mesh_train_dataset_fn(
    mixture_or_task_name,
    sequence_length,
    vocabulary=None,
    dataset_split=tfds.Split.TRAIN,
    shuffle=True,
    seed=None,
    use_cached=False,
    pack=True):
  """Builds the training tf.data.Dataset for a registered Mixture or Task.

  The result follows the contract of utils.run's `train_dataset_fn` argument
  in the standalone Mesh TF transformer.

  Args:
    mixture_or_task_name: string, name of a Mixture or Task in the seqio
      registry. Must be specified via gin.
    sequence_length: dict mapping each feature key to the maximum int sequence
      length for that feature.
    vocabulary: ignored; accepted only so the signature matches the other
      dataset_fns.
    dataset_split: string, the split to load. Usually "train".
    shuffle: bool, whether to shuffle the dataset.
    seed: tf.int64 scalar tf.Tensor (or None), forwarded as `seed` to the
      Mixture/Task's `get_dataset`. Used for both the global seed and the
      shuffle seed for tf.data.
    use_cached: bool, whether to read the cached version of the dataset.
    pack: bool, whether to pack examples (True) or only pad them (False).

  Returns:
    A tf.data.Dataset of preprocessed, tokenized examples restricted to the
    output features present in the data and packed or padded to
    `sequence_length`.
  """
  del vocabulary  # Unused; kept for dataset_fn signature compatibility.

  task_or_mixture = seqio.get_mixture_or_task(mixture_or_task_name)
  dataset = task_or_mixture.get_dataset(
      sequence_length,
      split=dataset_split,
      use_cached=use_cached,
      shuffle=shuffle,
      num_epochs=None,
      seed=seed)

  # Keep only the declared output features that actually occur in the data.
  present_shapes = tf.data.get_output_shapes(dataset)
  feature_keys = tuple(
      key for key in task_or_mixture.output_features if key in present_shapes)

  # pack_or_pad filters features only when packing. Without packing, extra
  # (e.g. string) features would leak into training examples, so they are
  # dropped explicitly here instead of relying on pack_or_pad.
  def _keep_output_features(example):
    return {key: example[key] for key in feature_keys}

  dataset = dataset.map(
      _keep_output_features,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  keys_needing_eos = {
      key for key, feature in task_or_mixture.output_features.items()
      if feature.add_eos
  }
  return transformer_dataset.pack_or_pad(
      dataset,
      sequence_length,
      pack=pack,
      feature_keys=feature_keys,
      ensure_eos=keys_needing_eos)
@gin.configurable()
def mesh_inference_dataset_fn(
    mixture_or_task_name,
    sequence_length,
    dataset_split,
    shuffle=False,
    seed=None,
    vocabulary=None,
    num_inference_examples=-1,
    use_cached=False,
    priming_sequence_length=None):
  """Returns all tf.data.Datasets for LM inference on a given mixture.

  For Tasks without inputs (such as language modeling), the first
  `priming_sequence_length` tokens in the target are used as the "inputs" for
  inference.

  Args:
    mixture_or_task_name: string, an identifier for a Mixture or Task in the
      appropriate registry. Must be specified via gin.
    sequence_length: dict mapping feature key to the int max sequence length
      for that feature. If set to None, padding is skipped.
    dataset_split: string, which split of the dataset to load. NOTE, this
      function does NOT receive the split specified in utils.run. It needs to
      be specified separately.
    shuffle: Whether or not to shuffle dataset.
    seed: tf.int64 scalar tf.Tensor (or None). Used as shuffle seed for tf.data.
    vocabulary: unused argument, maintains compatibility with other dataset_fns.
    num_inference_examples: maximum number of examples per task to do inference
      on. If None or less than 0, use all examples.
    use_cached: bool, whether to load the cached version of this dataset.
    priming_sequence_length: If the Task only has "targets", select the first
      this many tokens from each target sequence to use as "inputs". This is
      useful for decoder-only language models where you would like to use a
      portion of the targets as a priming sequence for generation.

  Returns:
    A list of mesh_tensorflow.transformer.dataset.EvalDataset tuples, one for
    each subtask that has `dataset_split`. Each holds a lazy dataset fn.

  Raises:
    ValueError: when a returned dataset fn is invoked for a Task that has
      "inputs" while `priming_sequence_length` is set.
  """
  del vocabulary
  mixture_or_task = seqio.get_mixture_or_task(mixture_or_task_name)

  def _split_targets_for_primed_inference(ex):
    # The first `priming_sequence_length` target tokens become the inputs and
    # the remainder stays as the targets.
    ex["inputs"] = ex["targets"][:priming_sequence_length]
    ex["targets"] = ex["targets"][priming_sequence_length:]
    # Right-pad short targets with zeros so "inputs" always has exactly
    # `priming_sequence_length` tokens, then pin the static shape.
    ex["inputs"] = tf.pad(
        ex["inputs"],
        [[0, priming_sequence_length - tf.shape(ex["inputs"])[0]]], "CONSTANT")
    ex["inputs"] = tf.reshape(ex["inputs"], shape=(priming_sequence_length,))
    return ex

  def _prepare_for_unprimed_inference(ex):
    # Prime generation with an empty "inputs" sequence.
    # NOTE(review): int64 may differ from the dtype of "targets"; confirm that
    # downstream consumers accept this.
    ex["inputs"] = tf.constant([], dtype=tf.int64)
    return ex

  def _get_dataset_for_single_task(task, sequence_length):
    """Get a tensorflow.data.Dataset for the provided task."""
    ds = task.get_dataset(
        sequence_length, split=dataset_split, use_cached=use_cached,
        shuffle=shuffle, seed=seed)
    if "inputs" not in ds.element_spec:
      # Decoder-only Task: synthesize "inputs" from the targets (or empty).
      if not priming_sequence_length or priming_sequence_length <= 0:
        logging.warning("Priming sequence length not specified so priming "
                        "with the empty string.")
        ds = ds.map(_prepare_for_unprimed_inference)
      else:
        logging.info("Using the first %d tokens of each target as input.",
                     priming_sequence_length)
        ds = ds.map(_split_targets_for_primed_inference)
    elif priming_sequence_length is not None:
      raise ValueError(
          "Setting a priming sequence length only makes sense for decoder-only "
          "Tasks, which have `targets` but no `inputs`.")

    if sequence_length is None:
      # Honor the documented contract (and mirror mesh_eval_dataset_fn): with
      # no lengths there is nothing to pad to, so pack_or_pad is skipped.
      logging.info(
          "Skipping padding for '%s' since sequence length is None.",
          task.name)
    else:
      eos_keys = set(
          k for k, f in mixture_or_task.output_features.items() if f.add_eos)
      logging.info(
          "Padding '%s' with sequence lengths: %s", task.name, sequence_length)
      ds = transformer_dataset.pack_or_pad(
          ds,
          sequence_length,
          pack=False,
          feature_keys=tuple(task.output_features),
          ensure_eos=eos_keys)
    if num_inference_examples is not None and num_inference_examples >= 0:
      ds = ds.take(num_inference_examples)
    return ds

  outputs = []
  for task in seqio.get_subtasks(mixture_or_task):
    if dataset_split not in task.splits:
      logging.info("Task %s has no '%s' split, skipping inference.",
                   task.name, dataset_split)
      continue
    outputs.append(
        transformer_dataset.EvalDataset(
            task.name,
            functools.partial(
                _get_dataset_for_single_task,
                task=task,
                sequence_length=sequence_length),
            task.postprocess_fn,
            task.metric_fns,
        )
    )
  if not outputs:
    logging.warning("No %s data found for %s.",
                    dataset_split, mixture_or_task_name)
  return outputs
@gin.configurable()
def mesh_eval_dataset_fn(
    mixture_or_task_name,
    sequence_length,
    dataset_split,
    vocabulary=None,
    num_eval_examples=-1,
    use_cached=False,
    pack=False,
    shuffle_eval_examples=False,
    seed=None):
  """Returns all tf.data.Datasets for evaluation on a given mixture.

  This uses the format required for utils.run's `eval_dataset_fn` argument in
  the Mesh TF transformer standalone.

  Args:
    mixture_or_task_name: string, an identifier for a Mixture or Task in the
      appropriate registry. Must be specified via gin.
    sequence_length: dict mapping feature key to the int max sequence length
      for that feature. If set to None, packing and padding will be disabled.
    dataset_split: string, which split of the dataset to load.
    vocabulary: unused argument, maintains compatibility with other dataset_fns.
    num_eval_examples: maximum number of examples per task to use for continuous
      eval. If None or less than 0, use all examples.
    use_cached: bool, whether to load the cached version of this dataset.
    pack: a boolean, whether to pack examples. This is useful for perplexity
      evals but should not be used for iterative decoding.
    shuffle_eval_examples: boolean, whether to shuffle eval examples. Together
      with `num_eval_examples` this lets each iteration evaluate on a
      different slice. Shuffling is applied whenever this is True, regardless
      of `num_eval_examples`.
    seed: tf.int64 scalar tf.Tensor (or None). Used for both the global seed and
      shuffle seed for tf.data.

  Returns:
    A list of mesh_tensorflow.transformer.dataset.EvalDataset tuples, one for
    each subtask that has `dataset_split`. Each holds a lazy dataset fn.
  """
  del vocabulary
  mixture_or_task = seqio.get_mixture_or_task(mixture_or_task_name)

  def _get_dataset_for_single_task(task, sequence_length):
    """Get a tensorflow.data.Dataset for the provided task."""
    if shuffle_eval_examples and seed is None:
      # A single message string (not a tuple), so the log line renders as
      # text rather than as a tuple repr.
      logging.warning("shuffle_eval_examples is true but no seed was "
                      "provided. Using a random seed.")
    ds = task.get_dataset(
        sequence_length, split=dataset_split,
        use_cached=use_cached, shuffle=shuffle_eval_examples, seed=seed,
    )
    eos_keys = set(
        k for k, f in mixture_or_task.output_features.items() if f.add_eos)
    if sequence_length is None:
      logging.info(
          "Skipping packing/padding for '%s' since sequence length is None.",
          task.name)
    else:
      logging.info(
          "%sing '%s' with sequence lengths: %s",
          "Pack" if pack else "Padd", task.name, sequence_length)
      ds = transformer_dataset.pack_or_pad(
          ds,
          sequence_length,
          pack=pack,
          feature_keys=tuple(task.output_features),
          ensure_eos=eos_keys)
    if num_eval_examples is not None and num_eval_examples >= 0:
      ds = ds.take(num_eval_examples)
    return ds

  outputs = []
  for task in seqio.get_subtasks(mixture_or_task):
    if dataset_split not in task.splits:
      logging.info(
          "Task %s has no '%s' split, skipping eval.", task.name, dataset_split
      )
      continue
    outputs.append(
        transformer_dataset.EvalDataset(
            task.name,
            functools.partial(
                _get_dataset_for_single_task,
                task=task,
                sequence_length=sequence_length),
            task.postprocess_fn,
            task.metric_fns,
        )
    )
  if not outputs:
    logging.warning("No %s data found for %s.",
                    dataset_split, mixture_or_task_name)
  return outputs
@gin.configurable()
def tsv_dataset_fn(
    filename,
    sequence_length,
    dataset_split,
    vocabulary,
    shuffle_buffer_size=10000):
  r"""Returns a dataset based on a TSV file formatted as `<input>\t<target>`.

  Args:
    filename: a single file path or an explicit list of file paths. Glob
      patterns are not expanded.
    sequence_length: forwarded unchanged to
      `transformer_dataset.packed_parallel_tsv_dataset`.
    dataset_split: string, forwarded unchanged as the dataset split.
    vocabulary: forwarded unchanged to
      `transformer_dataset.packed_parallel_tsv_dataset`.
    shuffle_buffer_size: int, buffer size used to shuffle the raw lines
      before they are parsed.

  Returns:
    Whatever `transformer_dataset.packed_parallel_tsv_dataset` builds from
    the shuffled lines, with EOS (id 1) appended.
  """
  # `tf.gfile.glob` is currently broken on GCS, so only a file or a list of
  # files is read; no glob expansion is attempted.
  raw_lines = tf.data.TextLineDataset(filename)
  shuffled_lines = raw_lines.shuffle(shuffle_buffer_size)
  return transformer_dataset.packed_parallel_tsv_dataset(
      dataset=shuffled_lines,
      sequence_length=sequence_length,
      vocabulary=vocabulary,
      dataset_split=dataset_split,
      append_eos=True,
      eos_id=1)
@gin.configurable()
def get_vocabulary(mixture_or_task_name=None):
  """Returns the vocabulary to supply as utils.run's `vocabulary` argument.

  This is a gin-configurable wrapper that delegates to
  `t5.models.utils.get_vocabulary`.

  Args:
    mixture_or_task_name: string, name of a Mixture or Task in the registry.
      Must be specified via gin.

  Returns:
    Either a single seqio.vocabularies.Vocabulary, or a tuple of
    seqio.vocabularies.Vocabulary for inputs and targets.
  """
  vocab = model_utils.get_vocabulary(mixture_or_task_name)
  return vocab