-
Notifications
You must be signed in to change notification settings - Fork 749
/
mtf_model.py
549 lines (492 loc) · 23.4 KB
/
mtf_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mesh Tensorflow T5 Model."""
import functools
import os
import gin
import gin.tf
import mesh_tensorflow as mtf
from mesh_tensorflow import optimize
from mesh_tensorflow.transformer import dataset as transformer_dataset
from mesh_tensorflow.transformer import learning_rate_schedules
from mesh_tensorflow.transformer import utils as mtf_utils
from t5.models import mesh_transformer
from t5.models import utils
from t5.models.t5_model import T5Model
import tensorflow.compat.v1 as tf
def _parse_operative_config(model_dir):
with gin.unlock_config():
gin.parse_config_file(
os.path.join(model_dir, "operative_config.gin"),
skip_unknown=mesh_transformer.DEPRECATED_GIN_REFERENCES)
@gin.configurable
class MtfModel(T5Model):
"""Wrapper class for Mesh-TF models."""
def __init__(
self,
model_dir,
tpu,
tpu_job_name=None,
tpu_zone=None,
gcp_project=None,
tpu_topology="v2-8",
model_parallelism=8,
batch_size=("sequences_per_batch", 1),
sequence_length=None,
model_type="bitransformer",
layout_rules="ensemble:ensemble,batch:batch,d_ff:model,heads:model,vocab:model,experts:batch",
mesh_shape=None,
mesh_devices=None,
autostack=True,
learning_rate_schedule=None,
keep_checkpoint_max=None,
save_checkpoints_steps=5000,
optimizer=None,
predict_fn=None,
variable_filter=None,
ensemble_inputs=None,
iterations_per_loop=100,
extra_gin_bindings=None):
"""Constructor for MtfModel class.
Args:
model_dir: str, directory to save the model.
tpu: str, the TPU address to use.
tpu_job_name: str, name of the TPU worker binary.
tpu_zone: str, GCE zone where the Cloud TPU is located
gcp_project: str, project name for the Cloud TPU-enabled project.
tpu_topology: str, e.g. "2x2" or "v2-8".
model_parallelism: integer, the number of cores per model replica.
batch_size: An integer or a (method, value) pair to pass to
compute_batch_size(). Note that this is the global batch size and not
the per-shard batch size.
sequence_length: an integer or a dict from feature-key to integer
the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}
model_type: str, a model type from mesh tf models.
layout_rules: an input to mtf.convert_to_layout_rules()
mesh_shape: an mtf.Shape or string (e.g., "model:2,batch:4") specifying
how the data/model should be parallelized. If None (default), the mesh
shape will be constructed using the supplied `tpu_topology` and
`model_parallelism` arguments.
mesh_devices: a list of strings, the device names to use for each mesh
slice. Only required for GPU.
autostack: boolean, internally combine variables.
learning_rate_schedule: an optional function taking the scalar name
argument `step` and the numeric argument `total_train_steps` and return
the scalar learning rate.
keep_checkpoint_max: an integer, maximum number of checkpoints to keep.
save_checkpoints_steps: an integer, steps per checkpoint.
optimizer: a class extending optimize.Optimizer, required for training.
predict_fn: an optional function that can be used to override the default
transformer prediction behavior. Must return a tensor of shape
[batch_dim, length_dim] that will be the prediction for each example.
Must accept the following arguments:
- model: a Unitransformer or Bitransformer
- features: a dict representing an example. Every value will be an
mtf.Tensor with shape [batch_dim, length_dim].
- variable_dtype: an mtf.VariableDType
variable_filter: a str, a variable will only be trained if its name
matches this regex. If None (default), train all trainable variables.
ensemble_inputs: an integer, see `train_model` docstring for details.
iterations_per_loop: integer, steps per train loop
extra_gin_bindings: an optional list of strings, extra gin bindings to
pass to `gin.parse_config` after loading the operative config.
"""
mesh_shape = mesh_shape or (
mtf_utils.tpu_mesh_shape(
tpu_topology, model_parallelism) if tpu else "")
sequence_length = sequence_length or {"inputs": 512, "targets": 512}
if isinstance(sequence_length, int):
sequence_length = {"inputs": sequence_length,
"targets": sequence_length}
self._learning_rate_schedule = (
learning_rate_schedule or
learning_rate_schedules.learning_rate_schedule_noam)
self._optimizer = optimizer or optimize.AdafactorOptimizer
self._sequence_length = sequence_length
self._model_dir = model_dir
self._model_type = model_type
self._ensemble_inputs = ensemble_inputs
self._layout_rules = mtf.convert_to_layout_rules(layout_rules)
self._mesh_shape = mtf.convert_to_shape(mesh_shape)
self._mesh_devices = mesh_devices
self._autostack = autostack
self._keep_checkpoint_max = keep_checkpoint_max
self._save_checkpoints_steps = save_checkpoints_steps
self._predict_fn = predict_fn
self._variable_filter = variable_filter
self._ensemble_inputs = ensemble_inputs
self._iterations_per_loop = iterations_per_loop
self._cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu, zone=tpu_zone, project=gcp_project) if tpu else None
self._tpu = tpu
self._tpu_job_name = tpu_job_name
self._estimator = None
# Must be called after _sequence_length, _mesh_shape, and _layout_rules are
# set.
self.batch_size = batch_size
self._gin_bindings = extra_gin_bindings
@property
def batch_size(self):
return self._batch_size
@batch_size.setter
def batch_size(self, batch_size):
if not isinstance(batch_size, int):
self._batch_size = mtf_utils.compute_batch_size(
self._sequence_length, self._mesh_shape, self._layout_rules,
batch_size)
else:
self._batch_size = batch_size
def estimator(self, vocabulary, init_checkpoint=None, disable_tpu=False,
score_in_predict_mode=False, sequence_length=None):
if not self._tpu or disable_tpu:
with gin.unlock_config():
gin.bind_parameter("utils.get_variable_dtype.slice_dtype", "float32")
gin.bind_parameter(
"utils.get_variable_dtype.activation_dtype", "float32")
with gin.unlock_config():
gin.parse_config(self._gin_bindings)
return mtf_utils.get_estimator(
model_type=self._model_type,
vocabulary=vocabulary,
layout_rules=self._layout_rules,
mesh_shape=mtf.Shape([]) if disable_tpu else self._mesh_shape,
mesh_devices=None if disable_tpu else self._mesh_devices,
model_dir=self._model_dir,
batch_size=self.batch_size,
sequence_length=sequence_length or self._sequence_length,
autostack=self._autostack,
learning_rate_schedule=self._learning_rate_schedule,
keep_checkpoint_max=self._keep_checkpoint_max,
save_checkpoints_steps=self._save_checkpoints_steps,
optimizer=self._optimizer,
predict_fn=self._predict_fn,
variable_filter=self._variable_filter,
ensemble_inputs=self._ensemble_inputs,
use_tpu=None if disable_tpu else self._tpu,
tpu_job_name=self._tpu_job_name,
iterations_per_loop=self._iterations_per_loop,
cluster=self._cluster,
init_checkpoint=init_checkpoint,
score_in_predict_mode=score_in_predict_mode)
def train(self, mixture_or_task_name, steps, init_checkpoint=None,
split="train"):
"""Train the model on the given Mixture or Task.
Args:
mixture_or_task_name: str, the name of the Mixture or Task to train on.
Must be pre-registered in the global `TaskRegistry` or
`MixtureRegistry.`
steps: int, the total number of steps to train for.
init_checkpoint: a string, if not None then read in variables from this
checkpoint path when initializing variables. Will only initialize
variables that appear both in the current graph and the checkpoint.
split: str, the mixture/task split to train on.
"""
vocabulary = mesh_transformer.get_vocabulary(mixture_or_task_name)
dataset_fn = functools.partial(
mesh_transformer.mesh_train_dataset_fn,
mixture_or_task_name=mixture_or_task_name,
)
mtf_utils.train_model(
self.estimator(vocabulary, init_checkpoint),
vocabulary,
self._sequence_length,
self.batch_size,
dataset_fn,
steps,
self._ensemble_inputs,
dataset_split=split)
def _predict_or_score_fn(self,
tasks,
vocabulary,
checkpoint_step,
sequence_length,
split,
eval_with_score=False,
**unused_kwargs):
"""Helper function used by eval method to generate predictions or scores.
Args:
tasks: list, list of valid tasks to generate predictions or scores.
vocabulary: a t5.data.vocabulary object or a tuple with separate
vocabularies for inputs and targets,
checkpoint_step: integer, step to evaluate the tasks.
sequence_length: a dict, dictionary with sequence length for inputs and
targets.
split: string, split to run the evaluation on.
eval_with_score: bool, whether to compute log likelihood of targets
instead of predictions.
Returns:
list of decoded predictions or scores depending on eval_with_score flag.
"""
estimator = self.estimator(
vocabulary, score_in_predict_mode=eval_with_score,
sequence_length=sequence_length)
def estimator_input_fn(params):
"""Eval input function for estimator."""
del params
# Concatenate all dataset inputs to only have to do one decode loop
combined_ds = None
for task in tasks:
ds = mesh_transformer.mesh_eval_dataset_fn(
mixture_or_task_name=task.name,
sequence_length=sequence_length,
dataset_split=split)[0].dataset_fn()
ds = ds.map(
utils.filter_features,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
combined_ds = ds if not combined_ds else combined_ds.concatenate(ds)
combined_ds = combined_ds.batch(self.batch_size, drop_remainder=False) # pytype:disable=attribute-error
# Pad the final batch.
combined_ds = transformer_dataset.trim_and_pad_dataset(
combined_ds, length=self.batch_size)
combined_ds = combined_ds.prefetch(tf.data.experimental.AUTOTUNE)
return combined_ds
checkpoint_path = os.path.join(self._model_dir,
"model.ckpt-{}".format(checkpoint_step))
if eval_with_score:
outputs, _ = mtf_utils.score_with_estimator(
estimator,
estimator_input_fn,
checkpoint_step,
self._model_dir,
vocabulary)
else:
outputs = [
tf.compat.as_text(d) for d in mtf_utils.decode(
estimator, estimator_input_fn, vocabulary, checkpoint_path)
]
return outputs
def eval(self,
mixture_or_task_name,
checkpoint_steps=None,
summary_dir=None,
split="validation",
eval_with_score=False,
compute_sequence_length=True):
"""Evaluate the model on the given Mixture or Task.
Args:
mixture_or_task_name: str, the name of the Mixture or Task to evaluate on.
Must be pre-registered in the global `TaskRegistry` or
`MixtureRegistry.`
checkpoint_steps: int, list of ints, or None. If an int or list of ints,
evaluation will be run on the checkpoint files in `model_dir` whose
global steps are closest to the global steps provided. If None, run eval
continuously waiting for new checkpoints. If -1, get the latest
checkpoint from the model directory.
summary_dir: str, path to write TensorBoard events file summaries for
eval. If None, use model_dir/eval_{split}.
split: str, the mixture/task split to evaluate on.
eval_with_score: bool, whether to evaluate using log likelihood scores of
targets instead of decoded predictions.
compute_sequence_length: bool, automatically compute maximum sequence
length to use during eval mode.
"""
_parse_operative_config(self._model_dir)
summary_dir = summary_dir or os.path.join(self._model_dir,
"{}_eval".format(split))
checkpoint_steps = utils.get_checkpoints_iterator(checkpoint_steps,
self._model_dir)
def _get_task_eval_dataset(task, sequence_length, split):
eval_datasets = mesh_transformer.mesh_eval_dataset_fn(
sequence_length=sequence_length,
dataset_split=split,
mixture_or_task_name=task.name,
)
return eval_datasets[0].dataset_fn()
utils.run_eval(
mixture_or_task_name=mixture_or_task_name,
predict_or_score_fn=functools.partial(self._predict_or_score_fn,
eval_with_score=eval_with_score,
split=split),
checkpoint_steps=checkpoint_steps,
dataset_fn=_get_task_eval_dataset,
summary_dir=summary_dir,
split=split,
sequence_length=(None
if compute_sequence_length else self._sequence_length),
batch_size=self._batch_size)
def finetune(self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
mixture_or_task_name,
finetune_steps,
pretrained_model_dir,
pretrained_checkpoint_step=-1,
split="train"):
"""Finetunes a model from an existing checkpoint.
Args:
mixture_or_task_name: str, the name of the Mixture or Task to evaluate on.
Must be pre-registered in the global `TaskRegistry` or
`MixtureRegistry.`
finetune_steps: int, the number of additional steps to train for.
pretrained_model_dir: str, directory with pretrained model checkpoints and
operative config.
pretrained_checkpoint_step: int, checkpoint to initialize weights from. If
-1 (default), use the latest checkpoint from the pretrained model
directory.
split: str, the mixture/task split to finetune on.
"""
if pretrained_checkpoint_step == -1:
checkpoint_step = utils.get_latest_checkpoint_from_dir(
pretrained_model_dir)
else:
checkpoint_step = pretrained_checkpoint_step
_parse_operative_config(pretrained_model_dir)
model_ckpt = "model.ckpt-" + str(checkpoint_step)
self.train(mixture_or_task_name, checkpoint_step + finetune_steps,
init_checkpoint=os.path.join(pretrained_model_dir, model_ckpt),
split=split)
def predict(self, input_file, output_file, checkpoint_steps=-1, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
beam_size=1, temperature=1.0, keep_top_k=-1, vocabulary=None):
"""Predicts targets from the given inputs.
Args:
input_file: str, path to a text file containing newline-separated input
prompts to predict from.
output_file: str, path prefix of output file to write predictions to. Note
the checkpoint step will be appended to the given filename.
checkpoint_steps: int, list of ints, or None. If an int or list of ints,
inference will be run on the checkpoint files in `model_dir` whose
global steps are closest to the global steps provided. If None, run
inference continuously waiting for new checkpoints. If -1, get the
latest checkpoint from the model directory.
beam_size: int, a number >= 1 specifying the number of beams to use for
beam search.
temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1)
0.0 means argmax, 1.0 means sample according to predicted distribution.
keep_top_k: integer, a value between 1 and the vocabulary size. When
sampling, only pick tokens that are in the k most likely.
vocabulary: vocabularies.Vocabulary object to use for tokenization, or
None to use the default SentencePieceVocabulary.
"""
# TODO(sharannarang) : It would be nice to have a function like
# load_checkpoint that loads the model once and then call decode_from_file
# multiple times without having to restore the checkpoint weights again.
# This would be particularly useful in colab demo.
if checkpoint_steps == -1:
checkpoint_steps = utils.get_latest_checkpoint_from_dir(self._model_dir)
_parse_operative_config(self._model_dir)
with gin.unlock_config():
gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
gin.bind_parameter("Bitransformer.decode.temperature", temperature)
gin.bind_parameter("Bitransformer.decode.sampling_keep_top_k", keep_top_k)
gin.bind_parameter("utils.decode_from_file.input_filename", input_file)
gin.bind_parameter("utils.decode_from_file.output_filename", output_file)
if vocabulary is None:
vocabulary = utils.get_vocabulary()
mtf_utils.infer_model(
self.estimator(vocabulary), vocabulary, self._sequence_length,
self.batch_size, self._model_type, self._model_dir, checkpoint_steps)
def score(self,
inputs=None,
targets=None,
mixture_or_task_name=None,
mixture_or_task_split=None,
scores_file=None,
checkpoint_steps=-1,
vocabulary=None):
"""Computes log-likelihood of target per example in targets.
Args:
inputs: optional - a string (filename), or a list of strings (inputs)
targets: optional - a string (filename), or a list of strings (targets)
mixture_or_task_name: optional - a string, the name of the Mixture or Task
to score on. Must be pre-registered in the global `TaskRegistry` or
`MixtureRegistry.` Cannot be supplied in addition to `inputs` and
`targets`.
mixture_or_task_split: optional - a string, the split of the Mixture or
Task to score on. Must be provided if scoring on a Mixture or Task.
scores_file: optional - a string (filename), to write example scores to,
one per line.
checkpoint_steps: int, list of ints, or None. If an int or list of ints,
inference will be run on the checkpoint files in `model_dir` whose
global steps are closest to the global steps provided. If None, run
inference continuously waiting for new checkpoints. If -1, get the
latest checkpoint from the model directory.
vocabulary: vocabularies.Vocabulary object to use for tokenization, or
None to use the default SentencePieceVocabulary.
Returns:
scores: a list of floating point scores matching the dataset order.
targets: a list of scored strings matching the dataset order.
"""
if bool(inputs or targets) == bool(
mixture_or_task_name or mixture_or_task_split):
raise ValueError(
"Either 'inputs' and 'targets' or "
"'mixture_or_task_name' and 'mixture_or_task_split' must be "
"specified, but not both.")
if checkpoint_steps == -1:
checkpoint_steps = utils.get_latest_checkpoint_from_dir(self._model_dir)
_parse_operative_config(self._model_dir)
with gin.unlock_config():
gin.parse_config(self._gin_bindings)
if vocabulary is None:
vocabulary = utils.get_vocabulary(mixture_or_task_name)
estimator = self.estimator(vocabulary, score_in_predict_mode=True)
score_postprocess_fn = functools.partial(
mtf_utils.save_scores, scores_filename=scores_file)
if mixture_or_task_name:
score_dataset_fn = functools.partial(
mesh_transformer.mesh_eval_dataset_fn,
mixture_or_task_name=mixture_or_task_name,
)
return mtf_utils.score_from_dataset(
estimator=estimator, vocabulary=vocabulary,
batch_size=self.batch_size, sequence_length=self._sequence_length,
model_dir=self._model_dir, eval_checkpoint_step=checkpoint_steps,
dataset_split=mixture_or_task_split,
score_dataset_fn=score_dataset_fn,
score_postprocess_fn=score_postprocess_fn)
else:
return mtf_utils.score_from_strings(
estimator=estimator, vocabulary=vocabulary,
model_type=self._model_type, batch_size=self.batch_size,
sequence_length=self._sequence_length, model_dir=self._model_dir,
eval_checkpoint_step=checkpoint_steps, inputs=inputs, targets=targets,
score_postprocess_fn=score_postprocess_fn)
def export(self, export_dir=None, checkpoint_step=-1, beam_size=1,
temperature=1.0, keep_top_k=-1, vocabulary=None,
eval_with_score=False):
"""Exports a TensorFlow SavedModel.
Args:
export_dir: str, a directory in which to export SavedModels. Will use
`model_dir` if unspecified.
checkpoint_step: int, checkpoint to export. If -1 (default), use the
latest checkpoint from the pretrained model directory.
beam_size: int, a number >= 1 specifying the number of beams to use for
beam search.
temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1)
0.0 means argmax, 1.0 means sample according to predicted distribution.
keep_top_k: integer, a value between 1 and the vocabulary size. When
sampling, only pick tokens that are in the k most likely.
vocabulary: vocabularies.Vocabulary object to use for tokenization, or
None to use the default SentencePieceVocabulary.
eval_with_score: If True, compute log-likelihood scores of targets.
If False, do inference to generate outputs.
Returns:
The string path to the exported directory.
"""
if checkpoint_step == -1:
checkpoint_step = utils.get_latest_checkpoint_from_dir(self._model_dir)
_parse_operative_config(self._model_dir)
with gin.unlock_config():
gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
gin.bind_parameter("Bitransformer.decode.temperature", temperature)
gin.bind_parameter("Bitransformer.decode.sampling_keep_top_k", keep_top_k)
if vocabulary is None:
vocabulary = utils.get_vocabulary()
model_ckpt = "model.ckpt-" + str(checkpoint_step)
export_dir = export_dir or self._model_dir
estimator = self.estimator(
vocabulary, disable_tpu=True, score_in_predict_mode=eval_with_score)
return mtf_utils.export_model(
estimator, export_dir, vocabulary,
self._sequence_length, self._model_type, batch_size=self.batch_size,
checkpoint_path=os.path.join(self._model_dir, model_ckpt),
eval_with_score=eval_with_score)