Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add InterCTC to E-Branchformer encoder, and the ability to save InterCTC inference output to files #5084

Merged
merged 5 commits into from
Apr 7, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
46 changes: 45 additions & 1 deletion espnet2/asr/encoder/e_branchformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import torch
from typeguard import check_argument_types

from espnet2.asr.ctc import CTC
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.layers.cgmlp import ConvolutionalGatingMLP
from espnet2.asr.layers.fastformer import FastSelfAttention
Expand Down Expand Up @@ -207,6 +208,8 @@
linear_units: int = 2048,
positionwise_layer_type: str = "linear",
merge_conv_kernel: int = 3,
interctc_layer_idx=None,
interctc_use_conditioning: bool = False,
):
assert check_argument_types()
super().__init__()
Expand Down Expand Up @@ -378,6 +381,14 @@
)
self.after_norm = LayerNorm(output_size)

if interctc_layer_idx is None:
interctc_layer_idx = []
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None

def output_size(self) -> int:
    """Return the output size held in ``self._output_size``."""
    size = self._output_size
    return size

Expand All @@ -386,6 +397,8 @@
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
layer: int = None,
tjysdsg marked this conversation as resolved.
Show resolved Hide resolved
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Calculate forward propagation.

Expand Down Expand Up @@ -420,11 +433,42 @@
elif self.embed is not None:
xs_pad = self.embed(xs_pad)

xs_pad, masks = self.encoders(xs_pad, masks)
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
if layer is not None and (layer >= 0 and layer < len(self.encoders)):
for layer_idx, encoder_layer in enumerate(self.encoders):
xs_pad, masks = encoder_layer(xs_pad, masks)
if layer_idx >= layer:
break

Check warning on line 442 in espnet2/asr/encoder/e_branchformer_encoder.py

View check run for this annotation

Codecov / codecov/patch

espnet2/asr/encoder/e_branchformer_encoder.py#L439-L442

Added lines #L439 - L442 were not covered by tests
else:
xs_pad, masks = self.encoders(xs_pad, masks)
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
xs_pad, masks = encoder_layer(xs_pad, masks)

if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad

if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]

intermediate_outs.append((layer_idx + 1, encoder_out))

if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)

if isinstance(xs_pad, tuple):
xs_pad = list(xs_pad)
xs_pad[0] = xs_pad[0] + self.conditioning_layer(ctc_out)
xs_pad = tuple(xs_pad)
else:
xs_pad = xs_pad + self.conditioning_layer(ctc_out)

if isinstance(xs_pad, tuple):
xs_pad = xs_pad[0]

xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
65 changes: 57 additions & 8 deletions espnet2/bin/asr_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import logging
import sys
from distutils.version import LooseVersion
from itertools import groupby
from pathlib import Path
from typing import Any, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -32,6 +33,7 @@
from espnet.nets.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
from espnet.nets.beam_search import BeamSearch, Hypothesis
from espnet.nets.beam_search_timesync import BeamSearchTimeSync
from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
from espnet.nets.pytorch_backend.transformer.subsampling import TooShortUttError
from espnet.nets.scorer_interface import BatchScorerInterface
from espnet.nets.scorers.ctc import CTCPrefixScorer
Expand All @@ -46,6 +48,16 @@
except ImportError:
is_transformers_available = False

# Alias for typing: the n-best result list of the inference call.
# Each entry is a 4-tuple that callers unpack as
# (text, token, token_int, hyp):
#   - text: Optional[str]
#   - token: List[str]
#   - token_int: List[int]
#   - hyp: one of Hypothesis / ExtTransHypothesis / TransHypothesis
ListOfHypothesis = List[
Tuple[
Optional[str],
List[str],
List[int],
Union[Hypothesis, ExtTransHypothesis, TransHypothesis],
]
]


class Speech2Text:
"""Speech2Text class
Expand Down Expand Up @@ -362,13 +374,12 @@
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray]
) -> List[
) -> Union[
ListOfHypothesis,
Tuple[
Optional[str],
List[str],
List[int],
Union[Hypothesis, ExtTransHypothesis, TransHypothesis],
]
ListOfHypothesis,
Optional[Dict[int, List[str]]],
],
]:
"""Inference

Expand All @@ -395,7 +406,7 @@
batch = to_device(batch, device=self.device)

# b. Forward Encoder
enc, _ = self.asr_model.encode(**batch)
enc, enc_olens = self.asr_model.encode(**batch)
if self.multi_asr:
enc = enc.unbind(dim=1) # (batch, num_inf, ...) -> num_inf x [batch, ...]
if self.enh_s2t_task or self.multi_asr:
Expand All @@ -420,16 +431,41 @@

else:
# Normal ASR
intermediate_outs = None
if isinstance(enc, tuple):
intermediate_outs = enc[1]
enc = enc[0]
assert len(enc) == 1, len(enc)

# c. Passed the encoder result and the beam search
results = self._decode_single_sample(enc[0])

# Encoder intermediate CTC predictions
if intermediate_outs is not None:
encoder_interctc_res = self._decode_interctc(intermediate_outs)
results = (results, encoder_interctc_res)
assert check_return_type(results)

return results

def _decode_interctc(
    self, intermediate_outs: List[Tuple[int, torch.Tensor]]
) -> Dict[int, List[str]]:
    """Turn intermediate encoder outputs into per-layer token sequences.

    Args:
        intermediate_outs: ``(layer_idx, encoder_out)`` pairs as returned by
            the encoder alongside its final output.

    Returns:
        A dict mapping each ``layer_idx`` to the list of token strings
        obtained from that layer's output: the CTC argmax of the first
        batch item, with consecutive repeats merged and the blank, sos and
        eos ids dropped, looked up in ``self.beam_search.token_list``.
    """
    assert check_argument_types()

    model = self.asr_model
    skipped_ids = [model.blank_id, model.sos, model.eos]
    decoded = {}
    vocab = self.beam_search.token_list

    for layer_idx, encoder_out in intermediate_outs:
        # Greedy (argmax) decoding only; per the review discussion on this
        # line, greedy decoding is expected to cover most use cases.
        frame_ids = model.ctc.argmax(encoder_out)[0]  # batch_size = 1

        # Merge runs of identical ids, then drop the special symbols.
        kept_ids = []
        for token_id, _run in groupby(frame_ids):
            if token_id in skipped_ids:
                continue
            kept_ids.append(token_id)

        decoded[layer_idx] = [vocab[token_id] for token_id in kept_ids]

    return decoded

def _decode_single_sample(self, enc: torch.Tensor):
if self.beam_search_transducer:
logging.info("encoder output length: " + str(enc.shape[0]))
Expand Down Expand Up @@ -686,6 +722,10 @@

else:
# Normal ASR
encoder_interctc_res = None
if isinstance(results, tuple):
results, encoder_interctc_res = results

Check warning on line 727 in espnet2/bin/asr_inference.py

View check run for this annotation

Codecov / codecov/patch

espnet2/bin/asr_inference.py#L727

Added line #L727 was not covered by tests

for n, (text, token, token_int, hyp) in zip(
range(1, nbest + 1), results
):
Expand All @@ -700,6 +740,15 @@
if text is not None:
ibest_writer["text"][key] = text

# Write intermediate predictions to
# encoder_interctc_layer<layer_idx>.txt
ibest_writer = writer[f"1best_recog"]
if encoder_interctc_res is not None:
for idx, text in encoder_interctc_res.items():
ibest_writer[f"encoder_interctc_layer{idx}.txt"][

Check warning on line 748 in espnet2/bin/asr_inference.py

View check run for this annotation

Codecov / codecov/patch

espnet2/bin/asr_inference.py#L747-L748

Added lines #L747 - L748 were not covered by tests
key
] = " ".join(text)


def get_parser():
parser = config_argparse.ArgumentParser(
Expand Down
22 changes: 21 additions & 1 deletion test/espnet2/asr/encoder/test_e_branchformer_encoder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import torch

from espnet2.asr.ctc import CTC
from espnet2.asr.encoder.e_branchformer_encoder import EBranchformerEncoder


Expand All @@ -25,6 +26,14 @@
@pytest.mark.parametrize("linear_units", [1024, 2048])
@pytest.mark.parametrize("merge_conv_kernel", [3, 31])
@pytest.mark.parametrize("layer_drop_rate", [0.0, 0.1])
@pytest.mark.parametrize(
"interctc_layer_idx, interctc_use_conditioning",
[
([], False),
([1], False),
([1], True),
],
)
def test_encoder_forward_backward(
input_layer,
use_linear_after_conv,
Expand All @@ -37,6 +46,8 @@ def test_encoder_forward_backward(
linear_units,
merge_conv_kernel,
layer_drop_rate,
interctc_layer_idx,
interctc_use_conditioning,
):
encoder = EBranchformerEncoder(
20,
Expand All @@ -57,13 +68,22 @@ def test_encoder_forward_backward(
linear_units=linear_units,
merge_conv_kernel=merge_conv_kernel,
layer_drop_rate=layer_drop_rate,
interctc_layer_idx=interctc_layer_idx,
interctc_use_conditioning=interctc_use_conditioning,
)
if input_layer == "embed":
x = torch.randint(0, 10, [2, 32])
else:
x = torch.randn(2, 32, 20, requires_grad=True)
x_lens = torch.LongTensor([32, 28])
y, _, _ = encoder(x, x_lens)

if len(interctc_layer_idx) > 0: # intermediate CTC
encoder.conditioning_layer = torch.nn.Linear(2, 2)
y, _, _ = encoder(x, x_lens, ctc=CTC(odim=2, encoder_output_size=2))
y, intermediate_outs = y
else:
y, _, _ = encoder(x, x_lens)

y.sum().backward()


Expand Down
33 changes: 33 additions & 0 deletions test/espnet2/bin/test_asr_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,36 @@ def test_Speech2Text_pit(asr_config_file_pit, lm_config_file):
assert isinstance(token[0], str)
assert isinstance(token_int[0], int)
assert isinstance(hyp, Hypothesis)


@pytest.mark.execution_timeout(20)
@pytest.mark.parametrize(
    "encoder_class", ["transformer", "conformer", "e_branchformer"]
)
def test_Speech2Text_interctc(asr_config_file, lm_config_file, encoder_class):
    """Check that Speech2Text returns InterCTC predictions when enabled.

    The training configuration file is rewritten in place so that the
    selected encoder uses intermediate CTC at layers 1 and 2; the call is
    then expected to return ``(results, interctc_res)``.
    """
    # Change the configuration file to enable InterCTC.
    # The read handle is closed before the file is reopened for writing.
    with open(asr_config_file, "r", encoding="utf-8") as f:
        asr_train_config = yaml.full_load(f)
    asr_train_config["encoder"] = encoder_class
    asr_train_config["encoder_conf"]["interctc_layer_idx"] = [1, 2]
    asr_train_config["model_conf"]["interctc_weight"] = 0.5
    with open(asr_config_file, "w", encoding="utf-8") as f:
        yaml.dump(asr_train_config, f)

    speech2text = Speech2Text(
        asr_train_config=asr_config_file, lm_train_config=lm_config_file, beam_size=1
    )
    speech = np.random.randn(100000)
    results, interctc_res = speech2text(speech)
    for text, token, token_int, hyp in results:
        assert isinstance(text, str)
        assert isinstance(token[0], str)
        assert isinstance(token_int[0], int)
        assert isinstance(hyp, Hypothesis)

    # Intermediate predictions: {layer index: list of token strings}
    assert isinstance(interctc_res, dict)
    for k, tokens in interctc_res.items():
        assert isinstance(k, int)
        assert isinstance(tokens, list)
        assert isinstance(tokens[0], str)