Merge pull request #4891 from slSeanWU/whisper-ci-fix
Add python 3.8 requirement for Whisper & update tests
mergify[bot] committed Jan 27, 2023
2 parents 95b5827 + 40080c6 commit a5a4c23
Showing 7 changed files with 123 additions and 60 deletions.
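The tests are updated to gate whisper-dependent cases with `pytest.mark.skipif` decorators, checking both the Python and torch versions, instead of returning early inside the test bodies. A minimal, self-contained sketch of that gating pattern (the test name below is illustrative, not part of the commit):

import sys

import pytest
import torch
from packaging.version import parse as V

# Whisper requires Python >= 3.8; these tests also need torch >= 1.7.
is_torch_1_7_plus = V(torch.__version__) >= V("1.7.0")
is_python_3_8_plus = sys.version_info >= (3, 8)


@pytest.mark.skipif(
    not is_python_3_8_plus or not is_torch_1_7_plus,
    reason="whisper not supported on python<3.8, torch<1.7",
)
def test_whisper_feature():
    # Collected everywhere, but executed only when both version checks
    # pass; otherwise pytest reports the test as skipped.
    assert True
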
espnet2/text/cleaner.py (2 changes: 1 addition & 1 deletion)
@@ -13,7 +13,7 @@

try:
    from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
-except ImportError:
+except (ImportError, SyntaxError):
    BasicTextNormalizer = None


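The cleaner.py change widens the import guard because, on Python < 3.8, importing whisper can fail with a SyntaxError (whisper targets Python 3.8+) rather than a plain ImportError. A rough sketch of the fallback pattern, with a hypothetical helper (not part of cleaner.py) showing how the None sentinel would be consumed:

try:
    from whisper.normalizers import BasicTextNormalizer
except (ImportError, SyntaxError):
    # whisper is missing, or installed but unimportable on python<3.8
    BasicTextNormalizer = None


def basic_clean(text: str) -> str:
    # Hypothetical helper: fall back to the raw text when the whisper
    # normalizer is unavailable.
    if BasicTextNormalizer is None:
        return text
    return BasicTextNormalizer()(text)
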
test/espnet2/asr/decoder/test_whisper_decoder.py (47 changes: 28 additions & 19 deletions)
@@ -1,3 +1,5 @@
+import sys
+
import pytest
import torch
from packaging.version import parse as V
@@ -10,37 +12,41 @@

# NOTE(Shih-Lun): needed for `persistent` param in
# torch.nn.Module.register_buffer()
-is_torch_1_6_plus = V(torch.__version__) >= V("1.6.0")
+is_torch_1_7_plus = V(torch.__version__) >= V("1.7.0")
+is_python_3_8_plus = sys.version_info >= (3, 8)


+@pytest.mark.skipif(
+    not is_python_3_8_plus or not is_torch_1_7_plus,
+    reason="whisper not supported on python<3.8, torch<1.7",
+)
@pytest.fixture()
def whisper_decoder(request):
-    if not is_torch_1_6_plus:
-        return None
-
    return OpenAIWhisperDecoder(
        vocab_size=VOCAB_SIZE_WHISPER_MULTILINGUAL,
        encoder_output_size=384,
        whisper_model="tiny",
    )


+@pytest.mark.skipif(
+    not is_python_3_8_plus or not is_torch_1_7_plus,
+    reason="whisper not supported on python<3.8, torch<1.7",
+)
@pytest.mark.timeout(50)
def test_decoder_init(whisper_decoder):
-    if not is_torch_1_6_plus:
-        return
-
    assert (
        whisper_decoder.decoders.token_embedding.num_embeddings
        == VOCAB_SIZE_WHISPER_MULTILINGUAL
    )


+@pytest.mark.skipif(
+    not is_python_3_8_plus or not is_torch_1_7_plus,
+    reason="whisper not supported on python<3.8, torch<1.7",
+)
@pytest.mark.timeout(50)
def test_decoder_reinit_emb():
-    if not is_torch_1_6_plus:
-        return
-
    vocab_size = 1000
    decoder = OpenAIWhisperDecoder(
        vocab_size=vocab_size,
@@ -50,10 +56,11 @@ def test_decoder_reinit_emb():
    assert decoder.decoders.token_embedding.num_embeddings == vocab_size


+@pytest.mark.skipif(
+    not is_python_3_8_plus or not is_torch_1_7_plus,
+    reason="whisper not supported on python<3.8, torch<1.7",
+)
def test_decoder_invalid_init():
-    if not is_torch_1_6_plus:
-        return
-
    with pytest.raises(AssertionError):
        decoder = OpenAIWhisperDecoder(
            vocab_size=VOCAB_SIZE_WHISPER_MULTILINGUAL,
@@ -63,11 +70,12 @@ def test_decoder_invalid_init():
        del decoder


+@pytest.mark.skipif(
+    not is_python_3_8_plus or not is_torch_1_7_plus,
+    reason="whisper not supported on python<3.8, torch<1.7",
+)
@pytest.mark.timeout(50)
def test_decoder_forward_backward(whisper_decoder):
-    if not is_torch_1_6_plus:
-        return
-
    hs_pad = torch.randn(4, 100, 384, device=next(whisper_decoder.parameters()).device)
    ys_in_pad = torch.randint(
        0, 3000, (4, 10), device=next(whisper_decoder.parameters()).device
@@ -78,11 +86,12 @@ def test_decoder_forward_backward(whisper_decoder):
    out.sum().backward()


+@pytest.mark.skipif(
+    not is_python_3_8_plus or not is_torch_1_7_plus,
+    reason="whisper not supported on python<3.8, torch<1.7",
+)
@pytest.mark.timeout(50)
def test_decoder_scoring(whisper_decoder):
-    if not is_torch_1_6_plus:
-        return
-
    hs_pad = torch.randn(4, 100, 384, device=next(whisper_decoder.parameters()).device)
    ys_in_pad = torch.randint(
        0, 3000, (4, 10), device=next(whisper_decoder.parameters()).device

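Every test in the decoder file carries the same skipif marker. pytest also allows applying one marker to all tests in a module through a module-level `pytestmark` assignment; a minimal sketch of that alternative (not what this commit does):

import sys

import pytest
import torch
from packaging.version import parse as V

# A single module-level mark skips every test function in this file when
# either version requirement is not met.
pytestmark = pytest.mark.skipif(
    sys.version_info < (3, 8) or V(torch.__version__) < V("1.7.0"),
    reason="whisper not supported on python<3.8, torch<1.7",
)
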
test/espnet2/asr/encoder/test_whisper_encoder.py (45 changes: 27 additions & 18 deletions)
@@ -1,3 +1,5 @@
+import sys
+
import pytest
import torch
from packaging.version import parse as V
@@ -8,40 +10,45 @@

# NOTE(Shih-Lun): needed for `return_complex` param in torch.stft()
is_torch_1_7_plus = V(torch.__version__) >= V("1.7.0")
+is_python_3_8_plus = sys.version_info >= (3, 8)


+@pytest.mark.skipif(
+    not is_python_3_8_plus or not is_torch_1_7_plus,
+    reason="whisper not supported on python<3.8, torch<1.7",
+)
@pytest.fixture()
def whisper_encoder(request):
-    if not is_torch_1_7_plus:
-        return None
-
    encoder = OpenAIWhisperEncoder(whisper_model="tiny")

    return encoder


+@pytest.mark.skipif(
+    not is_python_3_8_plus or not is_torch_1_7_plus,
+    reason="whisper not supported on python<3.8, torch<1.7",
+)
@pytest.mark.timeout(50)
def test_encoder_init(whisper_encoder):
-    if not is_torch_1_7_plus:
-        return
-
    assert whisper_encoder.output_size() == 384


+@pytest.mark.skipif(
+    not is_python_3_8_plus or not is_torch_1_7_plus,
+    reason="whisper not supported on python<3.8, torch<1.7",
+)
def test_encoder_invalid_init():
-    if not is_torch_1_7_plus:
-        return
-
    with pytest.raises(AssertionError):
        encoder = OpenAIWhisperEncoder(whisper_model="aaa")
        del encoder


+@pytest.mark.skipif(
+    not is_python_3_8_plus or not is_torch_1_7_plus,
+    reason="whisper not supported on python<3.8, torch<1.7",
+)
@pytest.mark.timeout(50)
def test_encoder_forward_no_ilens(whisper_encoder):
-    if not is_torch_1_7_plus:
-        return
-
    input_tensor = torch.randn(
        4, 3200, device=next(whisper_encoder.parameters()).device
    )
@@ -50,11 +57,12 @@ def test_encoder_forward_no_ilens(whisper_encoder):
    assert xs_pad.size() == torch.Size([4, 10, 384])


+@pytest.mark.skipif(
+    not is_python_3_8_plus or not is_torch_1_7_plus,
+    reason="whisper not supported on python<3.8, torch<1.7",
+)
@pytest.mark.timeout(50)
def test_encoder_forward_ilens(whisper_encoder):
-    if not is_torch_1_7_plus:
-        return
-
    input_tensor = torch.randn(
        4, 3200, device=next(whisper_encoder.parameters()).device
    )
@@ -67,11 +75,12 @@ def test_encoder_forward_ilens(whisper_encoder):
    assert torch.equal(olens.cpu(), torch.tensor([2, 3, 5, 10]))


+@pytest.mark.skipif(
+    not is_python_3_8_plus or not is_torch_1_7_plus,
+    reason="whisper not supported on python<3.8, torch<1.7",
+)
@pytest.mark.timeout(50)
def test_encoder_backward(whisper_encoder):
-    if not is_torch_1_7_plus:
-        return
-
    whisper_encoder.train()
    input_tensor = torch.randn(
        4, 3200, device=next(whisper_encoder.parameters()).device

test/espnet2/asr/frontend/test_whisper.py (38 changes: 23 additions & 15 deletions)
@@ -1,3 +1,5 @@
+import sys
+
import pytest
import torch
from packaging.version import parse as V
@@ -8,40 +10,45 @@

# NOTE(Shih-Lun): required by `return_complex` in torch.stft()
is_torch_1_7_plus = V(torch.__version__) >= V("1.7.0")
+is_python_3_8_plus = sys.version_info >= (3, 8)


+@pytest.mark.skipif(
+    not is_python_3_8_plus or not is_torch_1_7_plus,
+    reason="whisper not supported on python<3.8, torch<1.7",
+)
@pytest.fixture()
def whisper_frontend(request):
-    if not is_torch_1_7_plus:
-        return None
-
    with torch.no_grad():
        return WhisperFrontend("tiny")


+@pytest.mark.skipif(
+    not is_python_3_8_plus or not is_torch_1_7_plus,
+    reason="whisper not supported on python<3.8, torch<1.7",
+)
@pytest.mark.timeout(50)
def test_frontend_init():
-    if not is_torch_1_7_plus:
-        return
-
    frontend = WhisperFrontend("tiny")
    assert frontend.output_size() == 384


+@pytest.mark.skipif(
+    not is_python_3_8_plus or not is_torch_1_7_plus,
+    reason="whisper not supported on python<3.8, torch<1.7",
+)
def test_frontend_invalid_init():
-    if not is_torch_1_7_plus:
-        return
-
    with pytest.raises(AssertionError):
        frontend = WhisperFrontend("aaa")
        del frontend


+@pytest.mark.skipif(
+    not is_python_3_8_plus or not is_torch_1_7_plus,
+    reason="whisper not supported on python<3.8, torch<1.7",
+)
@pytest.mark.timeout(50)
def test_frontend_forward_no_ilens(whisper_frontend):
-    if not is_torch_1_7_plus:
-        return
-
    input_tensor = torch.randn(
        4, 3200, device=next(whisper_frontend.parameters()).device
    )
@@ -50,11 +57,12 @@ def test_frontend_forward_no_ilens(whisper_frontend):
    assert feats.size() == torch.Size([4, 10, 384])


+@pytest.mark.skipif(
+    not is_python_3_8_plus or not is_torch_1_7_plus,
+    reason="whisper not supported on python<3.8, torch<1.7",
+)
@pytest.mark.timeout(50)
def test_frontend_forward_ilens(whisper_frontend):
-    if not is_torch_1_7_plus:
-        return
-
    input_tensor = torch.randn(
        4, 3200, device=next(whisper_frontend.parameters()).device
    )

test/espnet2/text/test_whisper_token_id_converter.py (19 changes: 19 additions & 0 deletions)
@@ -1,25 +1,41 @@
+import sys
+
import pytest

from espnet2.text.whisper_token_id_converter import OpenAIWhisperTokenIDConverter

pytest.importorskip("whisper")

+is_python_3_8_plus = sys.version_info >= (3, 8)
+

+@pytest.mark.skipif(
+    not is_python_3_8_plus, reason="whisper not supported on python<3.8"
+)
@pytest.fixture(params=["whisper_multilingual"])
def whisper_token_id_converter(request):
    return OpenAIWhisperTokenIDConverter(request.param)


+@pytest.mark.skipif(
+    not is_python_3_8_plus, reason="whisper not supported on python<3.8"
+)
def test_init_invalid():
    with pytest.raises(ValueError):
        OpenAIWhisperTokenIDConverter("whisper_aaa")


+@pytest.mark.skipif(
+    not is_python_3_8_plus, reason="whisper not supported on python<3.8"
+)
def test_init_en():
    id_converter = OpenAIWhisperTokenIDConverter("whisper_en")
    assert id_converter.get_num_vocabulary_size() == 50363


+@pytest.mark.skipif(
+    not is_python_3_8_plus, reason="whisper not supported on python<3.8"
+)
def test_ids2tokens(whisper_token_id_converter: OpenAIWhisperTokenIDConverter):
    tokens = whisper_token_id_converter.ids2tokens(
        [17155, 11, 220, 83, 378, 320, 311, 5503, 307, 1481, 13, 8239, 485]
@@ -42,6 +58,9 @@ def test_ids2tokens(whisper_token_id_converter: OpenAIWhisperTokenIDConverter):
    ]


+@pytest.mark.skipif(
+    not is_python_3_8_plus, reason="whisper not supported on python<3.8"
+)
def test_tokens2ids(whisper_token_id_converter: OpenAIWhisperTokenIDConverter):
    ids = whisper_token_id_converter.tokens2ids(
        [

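The text-side tests combine two guards: `pytest.importorskip("whisper")` skips the whole module when whisper is not installed, and the new skipif markers skip individual tests on Python < 3.8. A small sketch of how the two compose (the test name is illustrative):

import sys

import pytest

# Skip the entire module if whisper is not installed.
pytest.importorskip("whisper")

is_python_3_8_plus = sys.version_info >= (3, 8)


@pytest.mark.skipif(
    not is_python_3_8_plus, reason="whisper not supported on python<3.8"
)
def test_whisper_available():
    # Both guards have passed by the time this body runs, so importing
    # whisper here is expected to succeed.
    import whisper  # noqa: F401
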
test/espnet2/text/test_whisper_tokenizer.py (19 changes: 19 additions & 0 deletions)
@@ -1,29 +1,48 @@
+import sys
+
import pytest

from espnet2.text.whisper_tokenizer import OpenAIWhisperTokenizer

pytest.importorskip("whisper")

+is_python_3_8_plus = sys.version_info >= (3, 8)
+

+@pytest.mark.skipif(
+    not is_python_3_8_plus, reason="whisper not supported on python<3.8"
+)
@pytest.fixture(params=["whisper_multilingual"])
def whisper_tokenizer(request):
    return OpenAIWhisperTokenizer(request.param)


+@pytest.mark.skipif(
+    not is_python_3_8_plus, reason="whisper not supported on python<3.8"
+)
def test_init_en():
    tokenizer = OpenAIWhisperTokenizer("whisper_en")
    assert tokenizer.tokenizer.tokenizer.vocab_size == 50257


+@pytest.mark.skipif(
+    not is_python_3_8_plus, reason="whisper not supported on python<3.8"
+)
def test_init_invalid():
    with pytest.raises(ValueError):
        OpenAIWhisperTokenizer("whisper_aaa")


+@pytest.mark.skipif(
+    not is_python_3_8_plus, reason="whisper not supported on python<3.8"
+)
def test_repr(whisper_tokenizer: OpenAIWhisperTokenizer):
    print(whisper_tokenizer)


+@pytest.mark.skipif(
+    not is_python_3_8_plus, reason="whisper not supported on python<3.8"
+)
def test_tokenization_consistency(whisper_tokenizer: OpenAIWhisperTokenizer):
    s = "Hi, today's weather is nice. Hmm..."

