Skip to content

Commit

Permalink
add required pytest skipif reason
Browse files Browse the repository at this point in the history
  • Loading branch information
Shih-Lun Wu committed Jan 27, 2023
1 parent afee905 commit 40080c6
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 28 deletions.
32 changes: 25 additions & 7 deletions test/espnet2/asr/decoder/test_whisper_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@

# NOTE(Shih-Lun): needed for `persistent` param in
# torch.nn.Module.register_buffer()
is_torch_1_6_plus = V(torch.__version__) >= V("1.6.0")
is_torch_1_7_plus = V(torch.__version__) >= V("1.7.0")
is_python_3_8_plus = sys.version_info >= (3, 8)


@pytest.mark.skipif(not is_python_3_8_plus or not is_torch_1_6_plus)
@pytest.mark.skipif(
not is_python_3_8_plus or not is_torch_1_7_plus,
reason="whisper not supported on python<3.8, torch<1.7",
)
@pytest.fixture()
def whisper_decoder(request):
return OpenAIWhisperDecoder(
Expand All @@ -26,7 +29,10 @@ def whisper_decoder(request):
)


@pytest.mark.skipif(not is_python_3_8_plus or not is_torch_1_6_plus)
@pytest.mark.skipif(
not is_python_3_8_plus or not is_torch_1_7_plus,
reason="whisper not supported on python<3.8, torch<1.7",
)
@pytest.mark.timeout(50)
def test_decoder_init(whisper_decoder):
assert (
Expand All @@ -35,7 +41,10 @@ def test_decoder_init(whisper_decoder):
)


@pytest.mark.skipif(not is_python_3_8_plus or not is_torch_1_6_plus)
@pytest.mark.skipif(
not is_python_3_8_plus or not is_torch_1_7_plus,
reason="whisper not supported on python<3.8, torch<1.7",
)
@pytest.mark.timeout(50)
def test_decoder_reinit_emb():
vocab_size = 1000
Expand All @@ -47,7 +56,10 @@ def test_decoder_reinit_emb():
assert decoder.decoders.token_embedding.num_embeddings == vocab_size


@pytest.mark.skipif(not is_python_3_8_plus or not is_torch_1_6_plus)
@pytest.mark.skipif(
not is_python_3_8_plus or not is_torch_1_7_plus,
reason="whisper not supported on python<3.8, torch<1.7",
)
def test_decoder_invalid_init():
with pytest.raises(AssertionError):
decoder = OpenAIWhisperDecoder(
Expand All @@ -58,7 +70,10 @@ def test_decoder_invalid_init():
del decoder


@pytest.mark.skipif(not is_python_3_8_plus or not is_torch_1_6_plus)
@pytest.mark.skipif(
not is_python_3_8_plus or not is_torch_1_7_plus,
reason="whisper not supported on python<3.8, torch<1.7",
)
@pytest.mark.timeout(50)
def test_decoder_forward_backward(whisper_decoder):
hs_pad = torch.randn(4, 100, 384, device=next(whisper_decoder.parameters()).device)
Expand All @@ -71,7 +86,10 @@ def test_decoder_forward_backward(whisper_decoder):
out.sum().backward()


@pytest.mark.skipif(not is_python_3_8_plus or not is_torch_1_6_plus)
@pytest.mark.skipif(
not is_python_3_8_plus or not is_torch_1_7_plus,
reason="whisper not supported on python<3.8, torch<1.7",
)
@pytest.mark.timeout(50)
def test_decoder_scoring(whisper_decoder):
hs_pad = torch.randn(4, 100, 384, device=next(whisper_decoder.parameters()).device)
Expand Down
30 changes: 24 additions & 6 deletions test/espnet2/asr/encoder/test_whisper_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,40 @@
is_python_3_8_plus = sys.version_info >= (3, 8)


@pytest.mark.skipif(not is_python_3_8_plus or not is_torch_1_7_plus)
@pytest.mark.skipif(
not is_python_3_8_plus or not is_torch_1_7_plus,
reason="whisper not supported on python<3.8, torch<1.7",
)
@pytest.fixture()
def whisper_encoder(request):
encoder = OpenAIWhisperEncoder(whisper_model="tiny")

return encoder


@pytest.mark.skipif(not is_python_3_8_plus or not is_torch_1_7_plus)
@pytest.mark.skipif(
not is_python_3_8_plus or not is_torch_1_7_plus,
reason="whisper not supported on python<3.8, torch<1.7",
)
@pytest.mark.timeout(50)
def test_encoder_init(whisper_encoder):
assert whisper_encoder.output_size() == 384


@pytest.mark.skipif(not is_python_3_8_plus or not is_torch_1_7_plus)
@pytest.mark.skipif(
not is_python_3_8_plus or not is_torch_1_7_plus,
reason="whisper not supported on python<3.8, torch<1.7",
)
def test_encoder_invalid_init():
with pytest.raises(AssertionError):
encoder = OpenAIWhisperEncoder(whisper_model="aaa")
del encoder


@pytest.mark.skipif(not is_python_3_8_plus or not is_torch_1_7_plus)
@pytest.mark.skipif(
not is_python_3_8_plus or not is_torch_1_7_plus,
reason="whisper not supported on python<3.8, torch<1.7",
)
@pytest.mark.timeout(50)
def test_encoder_forward_no_ilens(whisper_encoder):
input_tensor = torch.randn(
Expand All @@ -45,7 +57,10 @@ def test_encoder_forward_no_ilens(whisper_encoder):
assert xs_pad.size() == torch.Size([4, 10, 384])


@pytest.mark.skipif(not is_python_3_8_plus or not is_torch_1_7_plus)
@pytest.mark.skipif(
not is_python_3_8_plus or not is_torch_1_7_plus,
reason="whisper not supported on python<3.8, torch<1.7",
)
@pytest.mark.timeout(50)
def test_encoder_forward_ilens(whisper_encoder):
input_tensor = torch.randn(
Expand All @@ -60,7 +75,10 @@ def test_encoder_forward_ilens(whisper_encoder):
assert torch.equal(olens.cpu(), torch.tensor([2, 3, 5, 10]))


@pytest.mark.skipif(not is_python_3_8_plus or not is_torch_1_7_plus)
@pytest.mark.skipif(
not is_python_3_8_plus or not is_torch_1_7_plus,
reason="whisper not supported on python<3.8, torch<1.7",
)
@pytest.mark.timeout(50)
def test_encoder_backward(whisper_encoder):
whisper_encoder.train()
Expand Down
25 changes: 20 additions & 5 deletions test/espnet2/asr/frontend/test_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,40 @@
is_python_3_8_plus = sys.version_info >= (3, 8)


@pytest.mark.skipif(not is_python_3_8_plus or not is_torch_1_7_plus)
@pytest.mark.skipif(
not is_python_3_8_plus or not is_torch_1_7_plus,
reason="whisper not supported on python<3.8, torch<1.7",
)
@pytest.fixture()
def whisper_frontend(request):
with torch.no_grad():
return WhisperFrontend("tiny")


@pytest.mark.skipif(not is_python_3_8_plus or not is_torch_1_7_plus)
@pytest.mark.skipif(
not is_python_3_8_plus or not is_torch_1_7_plus,
reason="whisper not supported on python<3.8, torch<1.7",
)
@pytest.mark.timeout(50)
def test_frontend_init():
frontend = WhisperFrontend("tiny")
assert frontend.output_size() == 384


@pytest.mark.skipif(not is_python_3_8_plus or not is_torch_1_7_plus)
@pytest.mark.skipif(
not is_python_3_8_plus or not is_torch_1_7_plus,
reason="whisper not supported on python<3.8, torch<1.7",
)
def test_frontend_invalid_init():
with pytest.raises(AssertionError):
frontend = WhisperFrontend("aaa")
del frontend


@pytest.mark.skipif(not is_python_3_8_plus or not is_torch_1_7_plus)
@pytest.mark.skipif(
not is_python_3_8_plus or not is_torch_1_7_plus,
reason="whisper not supported on python<3.8, torch<1.7",
)
@pytest.mark.timeout(50)
def test_frontend_forward_no_ilens(whisper_frontend):
input_tensor = torch.randn(
Expand All @@ -45,7 +57,10 @@ def test_frontend_forward_no_ilens(whisper_frontend):
assert feats.size() == torch.Size([4, 10, 384])


@pytest.mark.skipif(not is_python_3_8_plus or not is_torch_1_7_plus)
@pytest.mark.skipif(
not is_python_3_8_plus or not is_torch_1_7_plus,
reason="whisper not supported on python<3.8, torch<1.7",
)
@pytest.mark.timeout(50)
def test_frontend_forward_ilens(whisper_frontend):
input_tensor = torch.randn(
Expand Down
20 changes: 15 additions & 5 deletions test/espnet2/text/test_whisper_token_id_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,33 @@
is_python_3_8_plus = sys.version_info >= (3, 8)


@pytest.mark.skipif(not is_python_3_8_plus)
@pytest.mark.skipif(
not is_python_3_8_plus, reason="whisper not supported on python<3.8"
)
@pytest.fixture(params=["whisper_multilingual"])
def whisper_token_id_converter(request):
return OpenAIWhisperTokenIDConverter(request.param)


@pytest.mark.skipif(not is_python_3_8_plus)
@pytest.mark.skipif(
not is_python_3_8_plus, reason="whisper not supported on python<3.8"
)
def test_init_invalid():
with pytest.raises(ValueError):
OpenAIWhisperTokenIDConverter("whisper_aaa")


@pytest.mark.skipif(not is_python_3_8_plus)
@pytest.mark.skipif(
not is_python_3_8_plus, reason="whisper not supported on python<3.8"
)
def test_init_en():
id_converter = OpenAIWhisperTokenIDConverter("whisper_en")
assert id_converter.get_num_vocabulary_size() == 50363


@pytest.mark.skipif(not is_python_3_8_plus)
@pytest.mark.skipif(
not is_python_3_8_plus, reason="whisper not supported on python<3.8"
)
def test_ids2tokens(whisper_token_id_converter: OpenAIWhisperTokenIDConverter):
tokens = whisper_token_id_converter.ids2tokens(
[17155, 11, 220, 83, 378, 320, 311, 5503, 307, 1481, 13, 8239, 485]
Expand All @@ -50,7 +58,9 @@ def test_ids2tokens(whisper_token_id_converter: OpenAIWhisperTokenIDConverter):
]


@pytest.mark.skipif(not is_python_3_8_plus)
@pytest.mark.skipif(
not is_python_3_8_plus, reason="whisper not supported on python<3.8"
)
def test_tokens2ids(whisper_token_id_converter: OpenAIWhisperTokenIDConverter):
ids = whisper_token_id_converter.tokens2ids(
[
Expand Down
20 changes: 15 additions & 5 deletions test/espnet2/text/test_whisper_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,40 @@
is_python_3_8_plus = sys.version_info >= (3, 8)


@pytest.mark.skipif(not is_python_3_8_plus)
@pytest.mark.skipif(
not is_python_3_8_plus, reason="whisper not supported on python<3.8"
)
@pytest.fixture(params=["whisper_multilingual"])
def whisper_tokenizer(request):
return OpenAIWhisperTokenizer(request.param)


@pytest.mark.skipif(not is_python_3_8_plus)
@pytest.mark.skipif(
not is_python_3_8_plus, reason="whisper not supported on python<3.8"
)
def test_init_en():
tokenizer = OpenAIWhisperTokenizer("whisper_en")
assert tokenizer.tokenizer.tokenizer.vocab_size == 50257


@pytest.mark.skipif(not is_python_3_8_plus)
@pytest.mark.skipif(
not is_python_3_8_plus, reason="whisper not supported on python<3.8"
)
def test_init_invalid():
with pytest.raises(ValueError):
OpenAIWhisperTokenizer("whisper_aaa")


@pytest.mark.skipif(not is_python_3_8_plus)
@pytest.mark.skipif(
not is_python_3_8_plus, reason="whisper not supported on python<3.8"
)
def test_repr(whisper_tokenizer: OpenAIWhisperTokenizer):
print(whisper_tokenizer)


@pytest.mark.skipif(not is_python_3_8_plus)
@pytest.mark.skipif(
not is_python_3_8_plus, reason="whisper not supported on python<3.8"
)
def test_tokenization_consistency(whisper_tokenizer: OpenAIWhisperTokenizer):
s = "Hi, today's weather is nice. Hmm..."

Expand Down

0 comments on commit 40080c6

Please sign in to comment.