Skip to content

Commit

Permalink
Integration test for electra model (#10073)
Browse files Browse the repository at this point in the history
  • Loading branch information
spatil6 committed Feb 8, 2021
1 parent 781220a commit 263fac7
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions tests/test_modeling_electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,3 +344,19 @@ def test_model_from_pretrained(self):
for model_name in ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = ElectraModel.from_pretrained(model_name)
self.assertIsNotNone(model)


@require_torch
class ElectraModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_no_head_absolute_embedding(self):
model = ElectraForPreTraining.from_pretrained("google/electra-small-discriminator")
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
output = model(input_ids)[0]
expected_shape = torch.Size((1, 11))
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
[[-8.9253, -4.0305, -3.9306, -3.8774, -4.1873, -4.1280, 0.9429, -4.1672, 0.9281, 0.0410, -3.4823]]
)

self.assertTrue(torch.allclose(output, expected_slice, atol=1e-4))

0 comments on commit 263fac7

Please sign in to comment.