TF: rework XLA generate tests (#16866)
gante committed Apr 22, 2022
1 parent 3b1bbef commit 6d90d76
Showing 2 changed files with 76 additions and 54 deletions.
78 changes: 47 additions & 31 deletions tests/gpt2/test_modeling_tf_gpt2.py
@@ -16,7 +16,7 @@
import unittest

from transformers import GPT2Config, is_tf_available
from transformers.testing_utils import require_tf, slow
from transformers.testing_utils import get_gpu_count, require_tf, slow

from ..test_configuration_common import ConfigTester
from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
@@ -294,7 +294,7 @@ def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

def create_and_check_gpt2_xla_generate(self, config, input_ids, *args):
def create_and_check_gpt2_xla_generate_fast(self, config, input_ids, *args):
config.eos_token_id = None
config.max_length = 10
model = TFGPT2LMHeadModel(config=config)
@@ -408,9 +408,9 @@ def test_gpt2_lm_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_gpt2_lm_head(*config_and_inputs)

def test_gpt2_xla_generate(self):
def test_gpt2_xla_generate_fast(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_gpt2_xla_generate(*config_and_inputs)
self.model_tester.create_and_check_gpt2_xla_generate_fast(*config_and_inputs)

def test_gpt2_double_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -536,41 +536,57 @@ def test_lm_generate_greedy_distilgpt2_beam_search_special(self):
self.assertListEqual(output_strings, expected_output_string)

@slow
def test_lm_generate_gpt2(self):
@unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
def test_lm_generate_gpt2_greedy_xla(self):
# TODO (Joao): convert this to an example with a batch size>1 with different input lengths that works (and fix
# the underlying problem)
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
# fmt: off
expected_output_ids = [464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290]
# fmt: on
output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

@slow
def test_lm_generate_gpt2_xla_greedy(self):
"""This test gives the exact same results as the non-xla test above"""
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
sentences = ["The dog"]
expected_output_strings = [
"The dog was found in a field near the intersection of West and West Streets.\n\nThe dog",
]
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids

# The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
# fmt: off
expected_output_ids = [464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290]
# fmt: on
xla_generate = tf.function(model.generate, jit_compile=True)
output_ids = model.generate(input_ids, do_sample=False)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_strings)

xla_generate = tf.function(model.generate, jit_compile=True)
output_ids = xla_generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_strings)

@slow
def test_lm_generate_gpt2_xla_sample(self):
@unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
def test_lm_generate_gpt2_sample_xla(self):
# NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same
# output out of the same seed is far from guaranteed. We can, however, confirm that the results are sensible
# and that we can seed both versions.
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# fmt: off
expected_output_ids = [464, 3290, 550, 284, 307, 4376, 287, 281, 4044, 1363, 329, 734, 812, 878, 852, 4376, 757, 329, 2267, 0]
# fmt: on
xla_generate = tf.function(model.generate, jit_compile=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

sentence = ["The dog"]
expected_output_string = [
"The dog must be well educated to do anything. If anything, this must be her best friend"
]
expected_output_string_xla = ["The dog has been named in connection with the murder of a 20-year-old man in!"]
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids

output_ids = xla_generate(input_ids, do_sample=True, seed=[42, 0])
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
output_ids = model.generate(input_ids, do_sample=True, seed=[7, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_string)

xla_generate = tf.function(model.generate, jit_compile=True)
output_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_string_xla)
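The GPT-2 integration tests above follow one pattern: left-pad the prompt with the tokenizer, run `model.generate` eagerly, run the same call through `tf.function(model.generate, jit_compile=True)`, and compare the decoded strings. Below is a minimal standalone sketch of that pattern, outside the unittest harness; the checkpoint, prompt, and greedy settings mirror the test, everything else is illustrative.

```python
import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Decoder-only models are left-padded for generation so that new tokens
# are appended directly after the prompt.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

model = TFGPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tokenizer(["The dog"], return_tensors="tf", padding=True).input_ids

# Compile generate once; calls with the same input shape reuse the compiled program.
xla_generate = tf.function(model.generate, jit_compile=True)

eager_ids = model.generate(input_ids, do_sample=False)
xla_ids = xla_generate(input_ids, do_sample=False)

print(tokenizer.batch_decode(eager_ids, skip_special_tokens=True))
print(tokenizer.batch_decode(xla_ids, skip_special_tokens=True))
```

With greedy decoding the eager and XLA paths are expected to agree; for sampling, the tests above only check that each path is individually reproducible under a fixed seed.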
52 changes: 29 additions & 23 deletions tests/t5/test_modeling_tf_t5.py
@@ -16,7 +16,7 @@
import unittest

from transformers import T5Config, is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.testing_utils import get_gpu_count, require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.utils import cached_property

from ..test_configuration_common import ConfigTester
@@ -227,7 +227,7 @@ def create_and_check_t5_decoder_model_past_large_inputs(
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)

def create_and_check_t5_xla_generate(self, config, input_ids, *args):
def create_and_check_t5_xla_generate_fast(self, config, input_ids, *args):
config.eos_token_id = None
config.max_length = 10
config.do_sample = False
@@ -297,9 +297,9 @@ def test_t5_decoder_model_past_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_decoder_model_past_large_inputs(*config_and_inputs)

def test_t5_model_xla_generate(self):
def test_t5_model_xla_generate_fast(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_xla_generate(*config_and_inputs)
self.model_tester.create_and_check_t5_xla_generate_fast(*config_and_inputs)

def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -481,12 +481,18 @@ def test_train_pipeline_custom_model(self):
@require_tokenizers
class TFT5GenerationIntegrationTests(unittest.TestCase):
@slow
@unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
def test_greedy_xla_generate_simple(self):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")

sentence = "Translate English to German: Today is a beautiful day."
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
# two examples with different lengths to confirm that attention masks are operational in XLA
sentences = [
"Translate English to German: Today is a beautiful day.",
"Translate English to German: I have four cats, three dogs, two birds, and a horse.",
]
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids

xla_generate = tf.function(model.generate, jit_compile=True)

@@ -496,7 +502,10 @@ def test_greedy_xla_generate_simple(self):
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)

expected_output_string = ["Heute ist ein sch枚ner Tag."]
expected_output_string = [
"Heute ist ein sch枚ner Tag.",
"Ich habe vier Katzen, drei Hunde, zwei V枚gel und ein Pferd.",
]

self.assertListEqual(expected_output_string, output_strings)
self.assertListEqual(expected_output_string, output_strings_xla)
@@ -525,31 +534,28 @@ def test_greedy_generate(self):
self.assertListEqual(expected_output_string, output_strings)

@slow
@unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
def test_sample_xla_generate_simple(self):
# NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same
# output out of the same seed is far from guaranteed (unlike this example). We can, however, confirm that the
# results are sensible and that we can seed both versions.
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")

sentence = "Translate English to German: Today is a beautiful day."
sentence = "Translate English to German: I have two bananas"
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
# XLA reorders ops, which causes operations like FP matmul to have slightly different results, causing
# divergences in generate -- especially with sampling.
expected_output_string = ["Heute ist ein sch枚ner Tag."]
expected_output_string_xla = ["Heute ist ein sch枚ne Tage."]
# However, notice that the first tokens are the same, for the same seed
assert expected_output_string[0][:15] == expected_output_string_xla[0][:15]
expected_output_string = ["Ich habe 2 Bananen"]
expected_output_string_xla = ["Ich habe 2 Bananen"]

# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
# seed set -> deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
# seed set -> deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(expected_output_string, output_strings)

# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
xla_generate = tf.function(model.generate, jit_compile=True)
# seed set -> deterministic sampling sequence -> deterministic generation
output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
xla_generate = tf.function(model.generate, jit_compile=True)
# seed set -> deterministic sampling sequence -> deterministic generation
output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
self.assertListEqual(expected_output_string_xla, output_strings_xla)

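The reworked T5 sampling test relies on `generate`'s stateless `seed` argument: with a fixed `[seed, subseed]` pair, the eager and XLA-compiled paths are each deterministic, even though small XLA numerical differences mean the two outputs are not guaranteed to match token for token. A minimal sketch of that seeding pattern, assuming the same checkpoint and prompt as the test above (not the test itself):

```python
import tensorflow as tf
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer(
    "Translate English to German: I have two bananas", return_tensors="tf"
).input_ids

# A fixed [seed, subseed] pair makes the sampling sequence reproducible per backend.
eager_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])

xla_generate = tf.function(model.generate, jit_compile=True)
xla_ids = xla_generate(input_ids, do_sample=True, seed=[42, 0])

# Each path is reproducible on its own, but eager and XLA text may still differ.
print(tokenizer.batch_decode(eager_ids, skip_special_tokens=True))
print(tokenizer.batch_decode(xla_ids, skip_special_tokens=True))
```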
