Skip to content

Commit

Permalink
Make _test_xla_generate less flaky (huggingface#22996)
Browse files Browse the repository at this point in the history
* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
  • Loading branch information
2 people authored and gojiteji committed Jun 5, 2023
1 parent a339003 commit 0fce494
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion tests/test_modeling_tf_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1868,7 +1868,18 @@ def _generate_and_check_results(model, inputs_dict):
generated = model.generate(inputs, **generate_kwargs).numpy()
generate_xla = tf.function(model.generate, jit_compile=True)
generated_xla = generate_xla(inputs, **generate_kwargs).numpy()
self.assertListEqual(generated.tolist(), generated_xla.tolist())

# Due to numerical instability, let's fail the test only if there are more than 10% of input sequences give
# different outputs between XLA and non-XLA versions. If there are less than 10 examples, let's be strict
# and not allow any difference.
diff = [[], []]
for _generated, _generated_xla in zip(generated.tolist(), generated_xla.tolist()):
if _generated != _generated_xla:
diff[0].append(_generated)
diff[1].append(_generated_xla)
ratio = len(diff[0]) / len(generated)
if ratio > 0.1 or (len(diff[0]) > 0 and len(generated) < 10):
self.assertListEqual(diff[0], diff[1])

for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
Expand Down

0 comments on commit 0fce494

Please sign in to comment.