Skip to content

Commit

Permalink
[Flax/JAX] Run jitted tests at every commit (#13090)
Browse files Browse the repository at this point in the history
* up

* up

* up
  • Loading branch information
patrickvonplaten committed Aug 12, 2021
1 parent 773d386 commit 6900dde
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 5 deletions.
27 changes: 27 additions & 0 deletions tests/test_modeling_flax_big_bird.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@


if is_flax_available():
import jax
from transformers.models.big_bird.modeling_flax_big_bird import (
FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice,
Expand Down Expand Up @@ -162,3 +163,29 @@ def test_model_from_pretrained(self):
def test_attention_outputs(self):
if self.test_attn_probs:
super().test_attention_outputs()

@slow
# copied from `test_modeling_flax_common` because it takes much longer than other models
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)

@jax.jit
def model_jitted(input_ids, attention_mask=None, **kwargs):
return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()

with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = model_jitted(**prepared_inputs_dict).to_tuple()

self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):

self.assertEqual(jitted_output.shape, output.shape)
1 change: 0 additions & 1 deletion tests/test_modeling_flax_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,6 @@ def setUp(self):
def test_hidden_states_output(self):
pass

@slow
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

Expand Down
2 changes: 0 additions & 2 deletions tests/test_modeling_flax_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
is_pt_flax_cross_test,
is_staging_test,
require_flax,
slow,
)
from transformers.utils import logging

Expand Down Expand Up @@ -391,7 +390,6 @@ def test_save_load_to_base_pt(self):
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

@slow
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

Expand Down
1 change: 0 additions & 1 deletion tests/test_modeling_flax_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ def test_forward_signature(self):
self.assertListEqual(arg_names[:1], expected_arg_names)

# We neeed to override this test because ViT expects pixel_values instead of input_ids
@slow
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

Expand Down
1 change: 0 additions & 1 deletion tests/test_modeling_flax_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ def test_forward_signature(self):
expected_arg_names = ["input_values", "attention_mask"]
self.assertListEqual(arg_names[:2], expected_arg_names)

@slow
# overwrite because of `input_values`
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
Expand Down

0 comments on commit 6900dde

Please sign in to comment.