Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Speed up test_crossencoder (#3587)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenroller committed Apr 14, 2021
1 parent 337a688 commit 138bfd5
Showing 1 changed file with 16 additions and 27 deletions.
43 changes: 16 additions & 27 deletions tests/nightly/gpu/test_bert.py
Expand Up @@ -17,47 +17,36 @@ class TestBertModel(unittest.TestCase):
samples on convai2
"""

@testing_utils.retry(ntries=3, log_retry=True)
def test_biencoder(self):
valid, test = testing_utils.train_model(
dict(
task='convai2',
task='integration_tests:overfit',
model='bert_ranker/bi_encoder_ranker',
num_epochs=0.1,
batchsize=8,
learningrate=3e-4,
text_truncate=32,
validation_max_exs=20,
short_final_eval=True,
max_train_steps=500,
batchsize=2,
candidates="inline",
gradient_clip=1.0,
learningrate=1e-3,
text_truncate=8,
)
)
# can't conclude much from the biencoder after that little iterations.
# this test will just make sure it hasn't crashed and the accuracy isn't
# too high
self.assertLessEqual(test['accuracy'], 0.5)
self.assertGreaterEqual(test['accuracy'], 0.5)

@testing_utils.retry(ntries=3, log_retry=True)
def test_crossencoder(self):
valid, test = testing_utils.train_model(
dict(
task='convai2',
task='integration_tests:overfit',
model='bert_ranker/cross_encoder_ranker',
num_epochs=0.002,
batchsize=1,
candidates="inline",
max_train_steps=500,
batchsize=2,
learningrate=1e-3,
gradient_clip=1.0,
type_optimization="all_encoder_layers",
warmup_updates=100,
text_truncate=32,
label_truncate=32,
validation_max_exs=20,
short_final_eval=True,
text_truncate=8,
label_truncate=8,
)
)
# The cross encoder reaches an interesting state MUCH faster
# accuracy should be present and somewhere between 0.2 and 0.8
# (large interval so that it doesn't flake.)
self.assertGreaterEqual(test['accuracy'], 0.03)
self.assertLessEqual(test['accuracy'], 0.8)
self.assertGreaterEqual(test['accuracy'], 0.8)


if __name__ == '__main__':
Expand Down

0 comments on commit 138bfd5

Please sign in to comment.