Skip to content

Commit

Permalink
add simple k2 unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
b-flo committed Jul 5, 2023
1 parent 56a4fbf commit 8c70533
Showing 1 changed file with 72 additions and 2 deletions.
74 changes: 72 additions & 2 deletions test/espnet2/asr_transducer/test_espnet_transducer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,17 @@ def stats_file(tmp_path: Path):
return p


def prepare(model, input_size, vocab_size, batch_size):
def prepare(model, input_size, vocab_size, batch_size, use_k2_modified_loss=False):
n_token = vocab_size - 1

feat_len = [15, 11]
label_len = [13, 9]

# (b-flo): For k2 "modified", we need to ensure that T >= U after subsampling.
if use_k2_modified_loss:
feat_len = [i * 5 for i in label_len]
else:
feat_len = [15, 11]

feats = torch.randn(batch_size, max(feat_len), input_size)
labels = (torch.rand(batch_size, max(label_len)) * n_token % n_token).long()

Expand Down Expand Up @@ -356,6 +361,71 @@ def test_model_training(
_ = model(feats, feat_len, labels, label_len)


@pytest.mark.parametrize(
"k2_params",
[
{},
{"lm_scale": 0.25, "am_scale": 0.5},
{"loss_type": "modified"},
],
)
def test_model_training_with_k2(k2_params):
pytest.importorskip("k2")

batch_size = 2
input_size = 10

token_list = ["<blank>", "a", "b", "c", "<space>"]
vocab_size = len(token_list)

encoder = Encoder(
input_size,
[
{
"block_type": "conformer",
"hidden_size": 4,
"linear_size": 4,
"conv_mod_kernel_size": 3,
}
],
)
decoder = RNNDecoder(vocab_size, embed_size=8, hidden_size=8)

joint_network = JointNetwork(
vocab_size,
encoder.output_size,
decoder.output_size,
)

model = ESPnetASRTransducerModel(
vocab_size,
token_list,
frontend=None,
normalize=None,
specaug=None,
encoder=encoder,
decoder=decoder,
joint_network=joint_network,
use_k2_pruned_loss=True,
k2_pruned_loss_args=k2_params,
report_cer=True,
report_wer=True,
)

feats, labels, feat_len, label_len = prepare(
model,
input_size,
vocab_size,
batch_size,
use_k2_modified_loss=True,
)

_ = model(feats, feat_len, labels, label_len)

model.training = False
_ = model(feats, feat_len, labels, label_len)


@pytest.mark.parametrize("extract_feats", [True, False])
def test_collect_feats(extract_feats):
token_list = ["<blank>", "a", "b", "c", "<space>"]
Expand Down

0 comments on commit 8c70533

Please sign in to comment.