Skip to content

Commit

Permalink
Merge pull request #5084 from tjysdsg/interctc_patch
Browse files Browse the repository at this point in the history
Add InterCTC to E-Branchformer encoder, and the ability to save InterCTC inference output to files
  • Loading branch information
mergify[bot] committed Apr 7, 2023
2 parents 4679463 + 6b33255 commit a505a23
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 10 deletions.
48 changes: 47 additions & 1 deletion espnet2/asr/encoder/e_branchformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import torch
from typeguard import check_argument_types

from espnet2.asr.ctc import CTC
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.layers.cgmlp import ConvolutionalGatingMLP
from espnet2.asr.layers.fastformer import FastSelfAttention
Expand Down Expand Up @@ -207,6 +208,8 @@ def __init__(
linear_units: int = 2048,
positionwise_layer_type: str = "linear",
merge_conv_kernel: int = 3,
interctc_layer_idx=None,
interctc_use_conditioning: bool = False,
):
assert check_argument_types()
super().__init__()
Expand Down Expand Up @@ -378,6 +381,14 @@ def __init__(
)
self.after_norm = LayerNorm(output_size)

if interctc_layer_idx is None:
interctc_layer_idx = []
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None

def output_size(self) -> int:
    """Return the stored output size of this encoder.

    Returns:
        int: The value kept in ``self._output_size``.
    """
    size: int = self._output_size
    return size

Expand All @@ -386,13 +397,17 @@ def forward(
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
max_layer: int = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Calculate forward propagation.
Args:
xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
ilens (torch.Tensor): Input length (#batch).
prev_states (torch.Tensor): Not to be used now.
ctc (CTC): Intermediate CTC module.
max_layer (int): Layer depth below which InterCTC is applied.
Returns:
torch.Tensor: Output tensor (#batch, L, output_size).
torch.Tensor: Output length (#batch).
Expand Down Expand Up @@ -420,11 +435,42 @@ def forward(
elif self.embed is not None:
xs_pad = self.embed(xs_pad)

xs_pad, masks = self.encoders(xs_pad, masks)
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
if max_layer is not None and 0 <= max_layer < len(self.encoders):
for layer_idx, encoder_layer in enumerate(self.encoders):
xs_pad, masks = encoder_layer(xs_pad, masks)
if layer_idx >= max_layer:
break
else:
xs_pad, masks = self.encoders(xs_pad, masks)
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
xs_pad, masks = encoder_layer(xs_pad, masks)

if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad

if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]

intermediate_outs.append((layer_idx + 1, encoder_out))

if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)

if isinstance(xs_pad, tuple):
xs_pad = list(xs_pad)
xs_pad[0] = xs_pad[0] + self.conditioning_layer(ctc_out)
xs_pad = tuple(xs_pad)
else:
xs_pad = xs_pad + self.conditioning_layer(ctc_out)

if isinstance(xs_pad, tuple):
xs_pad = xs_pad[0]

xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
65 changes: 57 additions & 8 deletions espnet2/bin/asr_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import logging
import sys
from distutils.version import LooseVersion
from itertools import groupby
from pathlib import Path
from typing import Any, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -32,6 +33,7 @@
from espnet.nets.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
from espnet.nets.beam_search import BeamSearch, Hypothesis
from espnet.nets.beam_search_timesync import BeamSearchTimeSync
from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
from espnet.nets.pytorch_backend.transformer.subsampling import TooShortUttError
from espnet.nets.scorer_interface import BatchScorerInterface
from espnet.nets.scorers.ctc import CTCPrefixScorer
Expand All @@ -46,6 +48,16 @@
except ImportError:
is_transformers_available = False

# Type alias for a list of decoding hypotheses.
# Each entry is a 4-tuple that callers unpack as (text, token, token_int, hyp).
ListOfHypothesis = List[
Tuple[
Optional[str],  # text; callers check it against None before writing it out
List[str],  # token
List[int],  # token_int
Union[Hypothesis, ExtTransHypothesis, TransHypothesis],  # hyp
]
]


class Speech2Text:
"""Speech2Text class
Expand Down Expand Up @@ -362,13 +374,12 @@ def __init__(
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray]
) -> List[
) -> Union[
ListOfHypothesis,
Tuple[
Optional[str],
List[str],
List[int],
Union[Hypothesis, ExtTransHypothesis, TransHypothesis],
]
ListOfHypothesis,
Optional[Dict[int, List[str]]],
],
]:
"""Inference
Expand All @@ -395,7 +406,7 @@ def __call__(
batch = to_device(batch, device=self.device)

# b. Forward Encoder
enc, _ = self.asr_model.encode(**batch)
enc, enc_olens = self.asr_model.encode(**batch)
if self.multi_asr:
enc = enc.unbind(dim=1) # (batch, num_inf, ...) -> num_inf x [batch, ...]
if self.enh_s2t_task or self.multi_asr:
Expand All @@ -420,16 +431,41 @@ def __call__(

else:
# Normal ASR
intermediate_outs = None
if isinstance(enc, tuple):
intermediate_outs = enc[1]
enc = enc[0]
assert len(enc) == 1, len(enc)

# c. Passed the encoder result and the beam search
results = self._decode_single_sample(enc[0])

# Encoder intermediate CTC predictions
if intermediate_outs is not None:
encoder_interctc_res = self._decode_interctc(intermediate_outs)
results = (results, encoder_interctc_res)
assert check_return_type(results)

return results

def _decode_interctc(
    self, intermediate_outs: List[Tuple[int, torch.Tensor]]
) -> Dict[int, List[str]]:
    """Greedily decode the encoder's intermediate CTC outputs into tokens.

    For every intermediate output, the per-frame argmax of the CTC module is
    taken, consecutive repeats are collapsed, and blank/sos/eos ids are
    dropped before mapping the remaining ids to token strings.

    Args:
        intermediate_outs: Pairs of ``(layer_idx, encoder_out)`` as produced
            by the encoder. Only the first batch entry of each ``encoder_out``
            is decoded (batch size is expected to be 1).

    Returns:
        Mapping from ``layer_idx`` to the list of decoded token strings.
    """
    assert check_argument_types()

    # A set gives O(1) membership tests and tolerates duplicated ids.
    exclude_ids = {
        self.asr_model.blank_id,
        self.asr_model.sos,
        self.asr_model.eos,
    }
    token_list = self.beam_search.token_list
    res = {}

    for layer_idx, encoder_out in intermediate_outs:
        # batch_size = 1. Convert to plain Python ints once so that grouping,
        # membership tests and list indexing do not operate on 0-d tensors.
        frame_ids = self.asr_model.ctc.argmax(encoder_out)[0].tolist()
        res[layer_idx] = [
            token_list[token_id]
            for token_id, _ in groupby(frame_ids)
            if token_id not in exclude_ids
        ]

    return res

def _decode_single_sample(self, enc: torch.Tensor):
if self.beam_search_transducer:
logging.info("encoder output length: " + str(enc.shape[0]))
Expand Down Expand Up @@ -686,6 +722,10 @@ def inference(

else:
# Normal ASR
encoder_interctc_res = None
if isinstance(results, tuple):
results, encoder_interctc_res = results

for n, (text, token, token_int, hyp) in zip(
range(1, nbest + 1), results
):
Expand All @@ -700,6 +740,15 @@ def inference(
if text is not None:
ibest_writer["text"][key] = text

# Write intermediate predictions to
# encoder_interctc_layer<layer_idx>.txt
ibest_writer = writer[f"1best_recog"]
if encoder_interctc_res is not None:
for idx, text in encoder_interctc_res.items():
ibest_writer[f"encoder_interctc_layer{idx}.txt"][
key
] = " ".join(text)


def get_parser():
parser = config_argparse.ArgumentParser(
Expand Down
22 changes: 21 additions & 1 deletion test/espnet2/asr/encoder/test_e_branchformer_encoder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import torch

from espnet2.asr.ctc import CTC
from espnet2.asr.encoder.e_branchformer_encoder import EBranchformerEncoder


Expand All @@ -25,6 +26,14 @@
@pytest.mark.parametrize("linear_units", [1024, 2048])
@pytest.mark.parametrize("merge_conv_kernel", [3, 31])
@pytest.mark.parametrize("layer_drop_rate", [0.0, 0.1])
@pytest.mark.parametrize(
"interctc_layer_idx, interctc_use_conditioning",
[
([], False),
([1], False),
([1], True),
],
)
def test_encoder_forward_backward(
input_layer,
use_linear_after_conv,
Expand All @@ -37,6 +46,8 @@ def test_encoder_forward_backward(
linear_units,
merge_conv_kernel,
layer_drop_rate,
interctc_layer_idx,
interctc_use_conditioning,
):
encoder = EBranchformerEncoder(
20,
Expand All @@ -57,13 +68,22 @@ def test_encoder_forward_backward(
linear_units=linear_units,
merge_conv_kernel=merge_conv_kernel,
layer_drop_rate=layer_drop_rate,
interctc_layer_idx=interctc_layer_idx,
interctc_use_conditioning=interctc_use_conditioning,
)
if input_layer == "embed":
x = torch.randint(0, 10, [2, 32])
else:
x = torch.randn(2, 32, 20, requires_grad=True)
x_lens = torch.LongTensor([32, 28])
y, _, _ = encoder(x, x_lens)

if len(interctc_layer_idx) > 0: # intermediate CTC
encoder.conditioning_layer = torch.nn.Linear(2, 2)
y, _, _ = encoder(x, x_lens, ctc=CTC(odim=2, encoder_output_size=2))
y, intermediate_outs = y
else:
y, _, _ = encoder(x, x_lens)

y.sum().backward()


Expand Down
33 changes: 33 additions & 0 deletions test/espnet2/bin/test_asr_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,36 @@ def test_Speech2Text_pit(asr_config_file_pit, lm_config_file):
assert isinstance(token[0], str)
assert isinstance(token_int[0], int)
assert isinstance(hyp, Hypothesis)


@pytest.mark.execution_timeout(20)
@pytest.mark.parametrize(
    "encoder_class", ["transformer", "conformer", "e_branchformer"]
)
def test_Speech2Text_interctc(asr_config_file, lm_config_file, encoder_class):
    """Check that Speech2Text returns InterCTC predictions when enabled.

    The ASR config file is rewritten in place to select ``encoder_class`` and
    to enable InterCTC on layers 1 and 2; the call result is then expected to
    be a ``(results, interctc_res)`` pair.
    """
    # Change the configuration file to enable InterCTC.
    # Read inside a context manager so the handle is always closed.
    with open(asr_config_file, "r", encoding="utf-8") as f:
        asr_train_config = yaml.full_load(f)
    asr_train_config["encoder"] = encoder_class
    asr_train_config["encoder_conf"]["interctc_layer_idx"] = [1, 2]
    asr_train_config["model_conf"]["interctc_weight"] = 0.5
    with open(asr_config_file, "w", encoding="utf-8") as f:
        yaml.dump(asr_train_config, f)

    speech2text = Speech2Text(
        asr_train_config=asr_config_file, lm_train_config=lm_config_file, beam_size=1
    )
    speech = np.random.randn(100000)
    results, interctc_res = speech2text(speech)
    for text, token, token_int, hyp in results:
        assert isinstance(text, str)
        assert isinstance(token[0], str)
        assert isinstance(token_int[0], int)
        assert isinstance(hyp, Hypothesis)

    # Intermediate predictions are keyed by layer index.
    assert isinstance(interctc_res, dict)
    for k, tokens in interctc_res.items():
        assert isinstance(k, int)
        assert isinstance(tokens, list)
        assert isinstance(tokens[0], str)

0 comments on commit a505a23

Please sign in to comment.