
Commit

all tests are passing locally
Sebastien Ehrhardt committed May 3, 2024
1 parent bff7fc3 commit 48ecc7e
Showing 2 changed files with 21 additions and 51 deletions.
@@ -466,11 +466,15 @@ def check_pt_tf_equivalence(self, tf_model, pt_model, tf_inputs_dict):
         self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
 
     def check_pt_to_tf_equivalence(self, config, decoder_config, tf_inputs_dict):
+        if _run_slow_tests:
+            config._attn_implementation = "eager"
+            decoder_config._attn_implementation = "eager"
+
         encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
         # Output all for aggressive testing
         encoder_decoder_config.output_hidden_states = True
         # All models tested in this file have attentions
-        encoder_decoder_config.output_attentions = True
+        encoder_decoder_config.output_attentions = _run_slow_tests
 
         pt_model = VisionEncoderDecoderModel(encoder_decoder_config)
 
@@ -481,11 +485,17 @@ def check_pt_to_tf_equivalence(self, config, decoder_config, tf_inputs_dict):
         self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
 
     def check_tf_to_pt_equivalence(self, config, decoder_config, tf_inputs_dict):
+        # When loading a model from TF we use the default attention
+        # mode (sdpa), so we do not expect attentions
+        config_output_attention = config.output_attentions
+        config.output_attentions = False
+
         encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
+
         # Output all for aggressive testing
         encoder_decoder_config.output_hidden_states = True
         # TODO: A generalizable way to determine this attribute
-        encoder_decoder_config.output_attentions = True
+        encoder_decoder_config.output_attentions = False
 
         tf_model = TFVisionEncoderDecoderModel(encoder_decoder_config)
         # Make sure model is built before saving
@@ -496,6 +506,8 @@ def check_tf_to_pt_equivalence(self, config, decoder_config, tf_inputs_dict):
             pt_model = VisionEncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True)
 
         self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
+        # Revert mutable object modification
+        config.output_attentions = config_output_attention
 
     def test_encoder_decoder_model(self):
         config_inputs_dict = self.prepare_config_and_inputs()
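The TF equivalence checks above all follow the same pattern: switch to the eager attention implementation when slow tests run (the diff's own comment notes that the default sdpa mode is not expected to return attentions) and only request attentions in that case. Below is a minimal sketch of that pattern, reusing only names that appear in the diff (_run_slow_tests, _attn_implementation, VisionEncoderDecoderConfig.from_encoder_decoder_configs); the helper name build_encoder_decoder_config is hypothetical, not part of the test file.

from transformers import VisionEncoderDecoderConfig


def build_encoder_decoder_config(config, decoder_config, run_slow_tests):
    # Hypothetical helper illustrating the pattern in the hunks above.
    if run_slow_tests:
        # Per the diff, sdpa is not expected to return attention weights, so
        # slow tests that compare attentions force eager attention on both
        # sub-configs before building the combined config.
        config._attn_implementation = "eager"
        decoder_config._attn_implementation = "eager"

    encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
    # Output all for aggressive testing
    encoder_decoder_config.output_hidden_states = True
    # Attentions are only requested (and compared) on the slow, eager path.
    encoder_decoder_config.output_attentions = run_slow_tests
    return encoder_decoder_config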
@@ -324,6 +324,7 @@ def test_save_and_load_from_encoder_decoder_pretrained(self):
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_save_and_load_encoder_decoder_model(**input_ids_dict)
 
+    @slow
     def test_encoder_decoder_model_output_attentions(self):
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
@@ -412,9 +413,6 @@ def check_encoder_decoder_model_output_attentions(
         pixel_values=None,
         **kwargs,
     ):
-        if not _run_slow_tests:
-            return
-
         # make the decoder inputs a different shape from the encoder inputs to harden the test
         decoder_input_ids = decoder_input_ids[:, :-1]
         decoder_attention_mask = decoder_attention_mask[:, :-1]
@@ -461,6 +459,8 @@ def check_encoder_decoder_model_output_attentions(
         )
 
     def get_encoder_decoder_model(self, config, decoder_config):
+        if _run_slow_tests:
+            config._attn_implementation = "eager"
         encoder_model = DeiTModel(config).eval()
         decoder_model = BertLMHeadModel(decoder_config).eval()
         return encoder_model, decoder_model
@@ -526,6 +526,8 @@ def get_pretrained_model_and_inputs(self):
         return model, inputs
 
     def get_encoder_decoder_model(self, config, decoder_config):
+        if _run_slow_tests:
+            config._attn_implementation = "eager"
         encoder_model = ViTModel(config).eval()
         decoder_model = BertLMHeadModel(decoder_config).eval()
         return encoder_model, decoder_model
@@ -565,29 +567,6 @@ def prepare_config_and_inputs(self):
             "labels": decoder_token_labels,
         }
 
-    def check_encoder_decoder_model_output_attentions(
-        self,
-        config,
-        decoder_config,
-        decoder_input_ids,
-        decoder_attention_mask,
-        labels=None,
-        pixel_values=None,
-        **kwargs,
-    ):
-        if not _run_slow_tests:
-            return
-
-        super().check_encoder_decoder_model_output_attentions(
-            config,
-            decoder_config,
-            decoder_input_ids,
-            decoder_attention_mask,
-            labels=None,
-            pixel_values=None,
-            **kwargs,
-        )
-
 
 @require_torch
 class Swin2BartModelTest(EncoderDecoderMixin, unittest.TestCase):
@@ -677,6 +656,8 @@ def test_real_model_save_load_from_pretrained(self):
 @require_torch
 class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
     def get_encoder_decoder_model(self, config, decoder_config):
+        if _run_slow_tests:
+            config._attn_implementation = "eager"
         encoder_model = ViTModel(config).eval()
         decoder_model = TrOCRForCausalLM(decoder_config).eval()
         return encoder_model, decoder_model
@@ -704,29 +685,6 @@ def prepare_config_and_inputs(self):
             "labels": decoder_input_ids,
         }
 
-    def check_encoder_decoder_model_output_attentions(
-        self,
-        config,
-        decoder_config,
-        decoder_input_ids,
-        decoder_attention_mask,
-        labels=None,
-        pixel_values=None,
-        **kwargs,
-    ):
-        if not _run_slow_tests:
-            return
-
-        super().check_encoder_decoder_model_output_attentions(
-            config,
-            decoder_config,
-            decoder_input_ids,
-            decoder_attention_mask,
-            labels=None,
-            pixel_values=None,
-            **kwargs,
-        )
-
     # there are no published pretrained TrOCR checkpoints for now
     def test_real_model_save_load_from_pretrained(self):
         pass
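The second file applies a complementary change: the per-method "if not _run_slow_tests: return" guards are removed in favour of marking the relevant test with @slow. A minimal sketch of that gating pattern, assuming the standard transformers.testing_utils decorators and a hypothetical test class name (ExampleEncoderDecoderTest), not the actual classes in this file:

import unittest

from transformers.testing_utils import require_torch, slow


@require_torch
class ExampleEncoderDecoderTest(unittest.TestCase):
    @slow
    def test_encoder_decoder_model_output_attentions(self):
        # Skipped unless slow tests are enabled (RUN_SLOW=1), which replaces
        # the removed in-method _run_slow_tests early returns.
        self.assertTrue(True)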