Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions tests/encoder_decoder/test_modeling_tf_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,8 +509,7 @@ def test_pt_tf_equivalence(self):
model = TFEncoderDecoderModel(encoder_decoder_config)
model(**inputs_dict)

@slow
def test_real_model_save_load_from_pretrained(self):
def test_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()
input_ids = ids_tensor([13, 5], model_2.config.encoder.vocab_size)
decoder_input_ids = ids_tensor([13, 1], model_2.config.decoder.vocab_size)
Expand Down Expand Up @@ -542,7 +541,10 @@ def test_real_model_save_load_from_pretrained(self):
@require_tf
class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
"hf-internal-testing/tiny-random-bert",
"hf-internal-testing/tiny-random-bert",
)

def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = TFBertModel(config, name="encoder")
Expand Down Expand Up @@ -637,7 +639,10 @@ def test_bert2bert_summarization(self):
@require_tf
class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "../gpt2")
return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
"hf-internal-testing/tiny-random-bert",
"hf-internal-testing/tiny-random-gpt2",
)

def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = TFBertModel(config, name="encoder")
Expand Down Expand Up @@ -726,7 +731,10 @@ def test_bert2gpt2_summarization(self):
@require_tf
class TFRoBertaEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base")
return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
"hf-internal-testing/tiny-random-roberta",
"hf-internal-testing/tiny-random-roberta",
)

def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = TFRobertaModel(config, name="encoder")
Expand Down Expand Up @@ -782,7 +790,10 @@ def prepare_config_and_inputs(self):
@require_tf
class TFRembertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained("google/rembert", "google/rembert")
return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
"hf-internal-testing/tiny-random-rembert",
"hf-internal-testing/tiny-random-rembert",
)

def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = TFRemBertModel(config, name="encoder")
Expand Down