Standalone Transducer v1.1 #5140

Merged: 61 commits, Jun 20, 2023

Commits (61)
aa0f1a0  fix chunk mask (b-flo, Nov 23, 2022)
9bd640a  remove right context + minor fixes (b-flo, Nov 23, 2022)
028b2d5  monkey patch chunk-by-chunk decoding before rework (b-flo, Nov 23, 2022)
7940ec0  rework v0.1 (b-flo, Dec 6, 2022)
edeacc9  Merge branch 'master' into refactoring (b-flo, Dec 6, 2022)
12636be  bump to v0.2 (b-flo, Dec 7, 2022)
0c57964  update streaming tests (b-flo, Dec 7, 2022)
0e1691b  remove old commented code (b-flo, Dec 7, 2022)
ce615fc  add back buffering (b-flo, Dec 8, 2022)
87acd89  alternative v0.2 (b-flo, Dec 16, 2022)
8ce6802  Merge branch 'master' into refactoring (b-flo, Dec 20, 2022)
0606aa4  add back display_partial_hypotheses option + minor fixes (b-flo, Jan 3, 2023)
bd3062b  Merge branch 'master' into refactoring (b-flo, Jan 3, 2023)
6cc36ba  remove unused code (b-flo, Jan 3, 2023)
17b917b  fix convinput subsampling tests (b-flo, Jan 3, 2023)
97cb471  Merge branch 'master' into refactoring (b-flo, Jan 4, 2023)
7267f5b  remove math lib usage (b-flo, Jan 4, 2023)
f7bbb4e  improve doc and tutorial for left context/chunks (b-flo, Jan 4, 2023)
e4a4317  improve doc and tutorial for left context/chunks (2) (b-flo, Jan 4, 2023)
8b1cd2c  fix streaming test (b-flo, Jan 4, 2023)
722457e  Merge branch 'master' into refactoring (b-flo, Jan 18, 2023)
643e964  v0.2 stable (b-flo, Feb 1, 2023)
4b0f196  Merge branch 'master' into refactoring (b-flo, Feb 1, 2023)
2229ef2  add offline/online ebranchformer + tests (b-flo, Feb 5, 2023)
9756379  add layerdrop (w/ decay) (b-flo, Feb 6, 2023)
0c825d0  update doc (b-flo, Feb 6, 2023)
86b0577  small fix for layerdrop (b-flo, Feb 7, 2023)
db4a344  apply new black (b-flo, Feb 9, 2023)
01893f8  Merge branch 'master' into refactoring (b-flo, Feb 9, 2023)
ee30ce9  add back dec proj bias + remove merge mod dropout (b-flo, Feb 14, 2023)
9ed81b1  slight refactor before adding MEGA decoder (b-flo, Apr 20, 2023)
9ca8a4b  add error calc. fix to avoid conflicts (b-flo, Apr 20, 2023)
f52e2d2  add MEGA decoder + docs (b-flo, Apr 20, 2023)
e06e905  Merge branch 'master' into refactoring (b-flo, Apr 20, 2023)
ad09c47  add chunk mechanism (b-flo, Apr 24, 2023)
7e51a6e  Merge branch 'master' into refactoring (b-flo, Apr 24, 2023)
d341500  fix device mismatch (b-flo, Apr 26, 2023)
cfb3b52  monkey patch states for inference (b-flo, Apr 26, 2023)
5d9398b  add first unit tests for mega (b-flo, Apr 26, 2023)
7de9f26  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 26, 2023)
40043d8  remove unused methods for task (b-flo, Apr 27, 2023)
d108dc7  improve mega coverage (b-flo, Apr 27, 2023)
a466e74  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 27, 2023)
246104d  fix conflict (b-flo, Apr 28, 2023)
72aab0c  Merge branch 'master' into refactoring (b-flo, Apr 28, 2023)
34befa7  fix streaming integration test (b-flo, Apr 28, 2023)
859abef  stitch branches (rwkv, mega/general fixes, etc) (b-flo, Jun 1, 2023)
ae1d400  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 1, 2023)
e3192f9  Merge branch 'master' into refactoring (b-flo, Jun 1, 2023)
5651dad  add missing init files (b-flo, Jun 1, 2023)
17a624c  add missing init files (2) (b-flo, Jun 1, 2023)
a085a20  add rescaling option during inference (b-flo, Jun 5, 2023)
65dc349  Merge branch 'master' into refactoring (b-flo, Jun 5, 2023)
81798b3  add missing guard conditions (b-flo, Jun 5, 2023)
aa0045d  add rwkv tests + fixes (b-flo, Jun 5, 2023)
04cc557  add ninja install through warp-transducer install (b-flo, Jun 6, 2023)
8bc222c  add skip for rwkv tests without gpu (b-flo, Jun 7, 2023)
a59c94b  remove unused import (b-flo, Jun 7, 2023)
4f81c23  improve/fix documentation for new additions (b-flo, Jun 7, 2023)
f1632c0  fix typos (type, docs) (b-flo, Jun 13, 2023)
683a291  Merge branch 'master' into refactoring (b-flo, Jun 13, 2023)
Files changed
README.md: 7 changes (4 additions, 3 deletions)

@@ -75,9 +75,9 @@
  - Data augmentation
  - **Transducer** based end-to-end ASR
    - Architecture:
-     - RNN-based encoder and decoder.
-     - Custom encoder and decoder supporting Transformer, Conformer (encoder), 1D Conv / TDNN (encoder) and causal 1D Conv (decoder) blocks.
-     - VGG2L (RNN/custom encoder) and Conv2D (custom encoder) bottlenecks.
+     - Custom encoder supporting RNNs, Conformer, Branchformer (w/ variants), 1D Conv / TDNN.
+     - Decoder w/ parameters shared across blocks supporting RNN, stateless w/ 1D Conv, [[MEGA]](https://arxiv.org/abs/2209.10655), and [[RWKV]](https://arxiv.org/abs/2305.13048).
+     - Pre-encoder: VGG2L or Conv2D available.
    - Search algorithms:
      - Greedy search constrained to one emission by timestep.
      - Default beam search algorithm [[Graves, 2012]](https://arxiv.org/abs/1211.3711) without prefix search.
@@ -86,6 +86,7 @@
      - N-step Constrained beam search modified from [[Kim et al., 2020]](https://arxiv.org/abs/2002.03577).
      - modified Adaptive Expansion Search based on [[Kim et al., 2021]](https://ieeexplore.ieee.org/abstract/document/9250505) and NSC.
    - Features:
+     - Unified interface for offline and streaming speech recognition.
      - Multi-task learning with various auxiliary losses:
        - Encoder: CTC, auxiliary Transducer and symmetric KL divergence.
        - Decoder: cross-entropy w/ label smoothing.
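The rewritten decoder entry is the heart of this release: one decoder interface now covers RNN, stateless, MEGA and RWKV variants. As a quick orientation, here is a hedged sketch (a plain mapping, not code from this PR; the class names are assumptions based on espnet2 naming conventions):

```python
# Hedged sketch, not from this diff: the decoder families named in the README
# entry above, mapped to the espnet2 class names they are assumed to use.
# Verify against espnet2/tasks/asr_transducer.py before relying on them.
DECODER_FAMILIES = {
    "rnn": "RNNDecoder",              # recurrent decoder, tuple hidden state
    "stateless": "StatelessDecoder",  # embedding w/ 1D Conv, no hidden state
    "mega": "MEGADecoder",            # https://arxiv.org/abs/2209.10655
    "rwkv": "RWKVDecoder",            # https://arxiv.org/abs/2305.13048
}

# All four are assumed to sit behind the same AbsDecoder interface shown
# further down in this diff, so the beam search code stays decoder-agnostic.
for name, cls in DECODER_FAMILIES.items():
    print(f"{name:10s} -> {cls}")
```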
ci/test_integration_espnet2.sh: 2 changes (1 addition, 1 deletion)

@@ -103,7 +103,7 @@ if python3 -c "from warprnnt_pytorch import RNNTLoss" &> /dev/null; then
          --encoder_conf main_conf='{'dynamic_chunk_training': True}' \
          --encoder_conf body_conf='[{'block_type': 'conformer', 'hidden_size': 30, 'linear_size': 30, 'heads': 2, 'conv_mod_kernel_size': 3}]' \
          --decoder_conf='{'embed_size': 30, 'hidden_size': 30}' --joint_network_conf joint_space_size=30 " \
-         --inference-args "--streaming true --chunk_size 2 --left_context 2 --right_context 0"
+         --inference-args "--streaming true --decoding_window 160 --left_context 2"
      done
  fi

doc/espnet2_tutorial.md: 148 changes (104 additions, 44 deletions)

Large diffs are not rendered by default.

espnet2/asr_transducer/activation.py: 2 changes (1 addition, 1 deletion)

@@ -1,4 +1,4 @@
- """Activation functions for Transducer."""
+ """Activation functions for Transducer models."""

  import torch
  from packaging.version import parse as V
espnet2/asr_transducer/beam_search_transducer.py: 21 changes (7 additions, 14 deletions)

@@ -17,8 +17,7 @@ class Hypothesis:
      Args:
          score: Total log-probability.
          yseq: Label sequence as integer ID sequence.
-         dec_state: RNNDecoder or StatelessDecoder state.
-             ((N, 1, D_dec), (N, 1, D_dec) or None) or None
+         dec_state: RNN/MEGA Decoder state (None if Stateless).
          lm_state: RNNLM state. ((N, D_lm), (N, D_lm)) or None

      """
@@ -51,7 +50,7 @@ class BeamSearchTransducer:
          decoder: Decoder module.
          joint_network: Joint network module.
          beam_size: Size of the beam.
-         lm: LM class.
+         lm: LM module.
          lm_weight: LM weight for soft fusion.
          search_type: Search algorithm to use during inference.
          max_sym_exp: Number of maximum symbol expansions at each time step. (TSD)
@@ -146,7 +145,7 @@ def __init__(
          self.score_norm = score_norm
          self.nbest = nbest

-         self.reset_inference_cache()
+         self.reset_cache()

      def __call__(
          self,
@@ -168,16 +167,16 @@ def __call__(
          hyps = self.search_algorithm(enc_out)

          if is_final:
-             self.reset_inference_cache()
+             self.reset_cache()

              return self.sort_nbest(hyps)

          self.search_cache = hyps

          return hyps

-     def reset_inference_cache(self) -> None:
-         """Reset cache for decoder scoring and streaming."""
+     def reset_cache(self) -> None:
+         """Reset cache for streaming decoding."""
          self.decoder.score_cache = {}
          self.search_cache = None

@@ -312,14 +311,7 @@ def default_beam_search(self, enc_out: torch.Tensor) -> List[Hypothesis]:
          max_hyp = max(hyps, key=lambda x: x.score)
          hyps.remove(max_hyp)

-         label = torch.full(
-             (1, 1),
-             max_hyp.yseq[-1],
-             dtype=torch.long,
-             device=self.decoder.device,
-         )
          dec_out, state = self.decoder.score(
-             label,
              max_hyp.yseq,
              max_hyp.dec_state,
          )
@@ -405,6 +397,7 @@ def align_length_sync_decoding(

          B_ = []
          B_enc_out = []
+
          for hyp in B:
              u = len(hyp.yseq) - 1
              t = i - u
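Taken together, the renamed reset_cache, the search_cache assignment, and the is_final branch in __call__ define the streaming contract: partial hypotheses persist between calls, and the final call returns a sorted n-best while clearing the caches. A minimal sketch of that loop, assuming a constructed BeamSearchTransducer named beam_search and a list of encoder output chunks (the is_final keyword is inferred from the body shown above):

```python
# Hedged sketch of the streaming decoding loop implied by __call__ above.
# beam_search: a BeamSearchTransducer instance (decoder + joint network).
# enc_chunks: list of encoder output tensors produced chunk by chunk.
hyps = None
for i, enc_out in enumerate(enc_chunks):
    last_chunk = i == len(enc_chunks) - 1
    # Intermediate calls cache partial hypotheses in search_cache; the final
    # call resets decoder.score_cache/search_cache and returns sorted n-best.
    hyps = beam_search(enc_out, is_final=last_chunk)

best = hyps[0]
print(best.yseq, best.score)  # fields of the Hypothesis dataclass above
```

Note that default_beam_search now passes max_hyp.yseq straight to decoder.score instead of building a (1, 1) label tensor first; turning the sequence into a label is the decoder's job under the new AbsDecoder contract shown next.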
espnet2/asr_transducer/decoder/abs_decoder.py: 92 changes (70 additions, 22 deletions)

@@ -1,7 +1,7 @@
  """Abstract decoder definition for Transducer models."""

  from abc import ABC, abstractmethod
- from typing import Any, List, Optional, Tuple
+ from typing import Any, Dict, List, Optional, Tuple, Union

  import torch

@@ -14,33 +14,40 @@ def forward(self, labels: torch.Tensor) -> torch.Tensor:
          """Encode source label sequences.

          Args:
-             labels: Label ID sequences. (B, L)
+             labels: Label ID sequences.

          Returns:
-             dec_out: Decoder output sequences. (B, T, D_dec)
+             : Decoder output sequences.

          """
          raise NotImplementedError

      @abstractmethod
      def score(
          self,
-         label: torch.Tensor,
          label_sequence: List[int],
-         dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]],
-     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]]:
+         states: Union[
+             List[Dict[str, torch.Tensor]],
+             List[torch.Tensor],
+             Tuple[torch.Tensor, Optional[torch.Tensor]],
+         ],
+     ) -> Tuple[
+         torch.Tensor,
+         Union[
+             List[Dict[str, torch.Tensor]],
+             List[torch.Tensor],
+             Tuple[torch.Tensor, Optional[torch.Tensor]],
+         ],
+     ]:
          """One-step forward hypothesis.

          Args:
-             label: Previous label. (1, 1)
              label_sequence: Current label sequence.
-             dec_state: Previous decoder hidden states.
-                 ((N, 1, D_dec), (N, 1, D_dec) or None) or None
+             states: Decoder hidden states.

          Returns:
-             dec_out: Decoder output sequence. (1, D_dec) or (1, D_emb)
-             dec_state: Decoder hidden states.
-                 ((N, 1, D_dec), (N, 1, D_dec) or None) or None
+             out: Decoder output sequence.
+             states: Decoder hidden states.

          """
          raise NotImplementedError

@@ -49,16 +56,22 @@ def score(
      def batch_score(
          self,
          hyps: List[Any],
-     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]]:
+     ) -> Tuple[
+         torch.Tensor,
+         Union[
+             List[Dict[str, torch.Tensor]],
+             List[torch.Tensor],
+             Tuple[torch.Tensor, Optional[torch.Tensor]],
+         ],
+     ]:
          """One-step forward hypotheses.

          Args:
              hyps: Hypotheses.

          Returns:
-             dec_out: Decoder output sequences. (B, D_dec) or (B, D_emb)
+             out: Decoder output sequences.
              states: Decoder hidden states.
-                 ((N, B, D_dec), (N, B, D_dec) or None) or None

          """
          raise NotImplementedError

@@ -76,35 +89,70 @@ def set_device(self, device: torch.Tensor) -> None:
      @abstractmethod
      def init_state(
          self, batch_size: int
-     ) -> Optional[Tuple[torch.Tensor, Optional[torch.tensor]]]:
+     ) -> Union[
+         List[Dict[str, torch.Tensor]],
+         List[torch.Tensor],
+         Tuple[torch.Tensor, Optional[torch.tensor]],
+     ]:
          """Initialize decoder states.

          Args:
              batch_size: Batch size.

          Returns:
-             : Initial decoder hidden states.
-                 ((N, B, D_dec), (N, B, D_dec) or None) or None
+             : Decoder hidden states.

          """
          raise NotImplementedError

      @abstractmethod
      def select_state(
          self,
-         states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
+         states: Union[
+             List[Dict[str, torch.Tensor]],
+             List[torch.Tensor],
+             Tuple[torch.Tensor, Optional[torch.Tensor]],
+         ],
          idx: int = 0,
-     ) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
+     ) -> Union[
+         List[Dict[str, torch.Tensor]],
+         List[torch.Tensor],
+         Tuple[torch.Tensor, Optional[torch.Tensor]],
+     ]:
          """Get specified ID state from batch of states, if provided.

          Args:
              states: Decoder hidden states.
-                 ((N, B, D_dec), (N, B, D_dec) or None) or None
              idx: State ID to extract.

          Returns:
              : Decoder hidden state for given ID.
-                 ((N, 1, D_dec), (N, 1, D_dec) or None) or None

          """
          raise NotImplementedError

+     @abstractmethod
+     def create_batch_states(
+         self,
+         new_states: List[
+             Union[
+                 List[Dict[str, Optional[torch.Tensor]]],
+                 List[List[torch.Tensor]],
+                 Tuple[torch.Tensor, Optional[torch.Tensor]],
+             ],
+         ],
+     ) -> Union[
+         List[Dict[str, torch.Tensor]],
+         List[torch.Tensor],
+         Tuple[torch.Tensor, Optional[torch.Tensor]],
+     ]:
+         """Create batch of decoder hidden states given a list of new states.
+
+         Args:
+             new_states: Decoder hidden states.
+
+         Returns:
+             : Decoder hidden states.
+
+         """
+         raise NotImplementedError
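To make the widened contract concrete, here is a hedged toy implementation in the spirit of the stateless variant, where every state is None. Only the abstract method names and signatures come from the diff above; the class name, the sizes, and the score_cache attribute (which BeamSearchTransducer.reset_cache clears) are illustrative, and AbsDecoder is assumed to mix in torch.nn.Module as in espnet2:

```python
# Hedged sketch, not from this PR: a minimal decoder satisfying the
# AbsDecoder contract above. A stateless decoder carries no hidden state,
# so every state-related hook can simply return None.
from typing import Any, List

import torch

from espnet2.asr_transducer.decoder.abs_decoder import AbsDecoder


class ToyStatelessDecoder(AbsDecoder):
    def __init__(self, vocab_size: int, embed_size: int = 256) -> None:
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, embed_size)
        self.output_size = embed_size
        self.device = torch.device("cpu")
        self.score_cache: dict = {}  # cleared by BeamSearchTransducer.reset_cache

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        return self.embed(labels)  # (B, L) -> (B, L, D_emb)

    def score(self, label_sequence: List[int], states: Any) -> tuple:
        # One-step output for the last label of the hypothesis.
        label = torch.tensor([[label_sequence[-1]]], device=self.device)
        return self.embed(label)[0], states  # (1, D_emb), state unchanged

    def batch_score(self, hyps: List[Any]) -> tuple:
        labels = torch.tensor([[h.yseq[-1]] for h in hyps], device=self.device)
        return self.embed(labels).squeeze(1), None  # (B, D_emb), no states

    def set_device(self, device: torch.device) -> None:
        self.device = device

    def init_state(self, batch_size: int) -> None:
        return None  # nothing to initialize

    def select_state(self, states: Any, idx: int = 0) -> None:
        return None

    def create_batch_states(self, new_states: List[Any]) -> None:
        return None
```

Stateful decoders return their concrete structures through these same hooks; the Union annotations above suggest tuples for the RNN decoder and lists of tensors or per-block dicts for the newer decoders, though the exact mapping is not visible in this excerpt.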
Empty file.