## Dependencies

approximately 5-6 minutes

In [2]:
%%time
%%bash

pip install -qq sentencepiece jiwer nemo-toolkit[asr]==2.0.0

     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 47.2/47.2 kB 4.3 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.9/63.9 kB 7.1 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 117.0/117.0 kB 11.5 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 180.3/180.3 kB 16.6 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.3/4.3 MB 4.3 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 154.5/154.5 kB 14.8 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 845.4/845.4 kB 36.5 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 79.5/79.5 kB 8.5 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16.0/16.0 MB 59.9 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 823.0/823.0 kB 43.2 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.1/3.1 MB 71.6 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 897.5/897.5 kB 52.0 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.0 requires fsspec==2025.3.0, but you have fsspec 2024.12.0 which is incompatible.


CPU times: user 443 ms, sys: 63.5 ms, total: 507 ms
Wall time: 2min 13s




---

## Hybrid CTC & RNN-t model


* Hybrid RNNT-CTC models is a group of models with both the RNNT and CTC decoders. Training a unified model would speedup the convergence for the CTC models and would enable the user to use a single model which works as both a CTC and RNNT model. This category can be used with any of the ASR models.[1]

* So we can get speed from a CTC decoder and quality from a RNN-t decoder. This is extremely useful for production systems where you need to make partial predictions to show on screen while people are talking, and then make a final prediction. The first requests are usually handled by a fast CTC decoder, and the final prediction is done by RNN-t decoder.

\

<img alt="hybrid" src="https://drive.google.com/uc?id=1e8oe4CfBf8UmvWdm--DK_q126EK9A9tg" width=400>

\

* More about hybrid models:
  * [[1] NeMo docs](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/asr/models.html#hybrid-transducer-ctc)
  * [[2] RNNT + LAS](https://arxiv.org/pdf/1908.10992)
  * [[3] CTC + LAS](https://arxiv.org/pdf/1609.06773)
  * [[4] Hybrid Rescoring 1](https://arxiv.org/pdf/2008.13093)
  * [[5] Hybrid Rescoring 2](https://arxiv.org/pdf/2101.11577)




---



In [3]:
import re
import typing as tp

import torch
import torch.nn as nn
import torchaudio
import soundfile as sf
from jiwer import wer
from tqdm import tqdm
import IPython.display as dsp
from sentencepiece import SentencePieceProcessor

from nemo.collections.asr.models import EncDecHybridRNNTCTCBPEModel

BLANK_IND: int = 1024


def clear(text: str):
  return re.sub(r'[^A-Za-z +]', '', text.lower())

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = EncDecHybridRNNTCTCBPEModel.from_pretrained(
    model_name="stt_en_fastconformer_hybrid_large_pc"
).to(device).eval()

TOKENIZER: SentencePieceProcessor = model.tokenizer.tokenizer

dsp.clear_output()

#### Audio

In [1]:
! gdown https://drive.google.com/uc?id=1eNQt0R7Dm71utkLuRhc9wjTjdWoYrwsk -O best-song-ever.wav

Downloading...
From: https://drive.google.com/uc?id=1eNQt0R7Dm71utkLuRhc9wjTjdWoYrwsk
To: /content/best-song-ever.wav
  0% 0.00/1.09M [00:00<?, ?B/s]100% 1.09M/1.09M [00:00<00:00, 107MB/s]


In [5]:
dsp.display(dsp.Audio('best-song-ever.wav'))

In [6]:
transcription = clear((
    "Never gonna give you up, never gonna let you down "
    "Never gonna run around and desert you "
    "Never gonna make you cry, never gonna say goodbye "
    "Never gonna tell a lie and hurt you"
))
transcription

'never gonna give you up never gonna let you down never gonna run around and desert you never gonna make you cry never gonna say goodbye never gonna tell a lie and hurt you'

### RNN-t inference

**RNN-t** modules:

<img alt="rnnt" src="https://www.mdpi.com/symmetry/symmetry-11-01018/article_deploy/html/images/symmetry-11-01018-g004.png" width=400>


The encoder can be arbitrary, like RNN, DeepSpeech 2 encoder or Сonformer encoder, it can be streamable or non streamable, then whole model will be streamable or non streamable respectively.

Inference stage looks like:

<img alt="rnnt" src="https://drive.google.com/uc?id=1EoSRLSSIg2fSge0yVKakKnbcgWUlCVJJ" width=700>

The prediction network consists of two required parts: embedder and RNN.

<img alt="rnnt" src="https://drive.google.com/uc?id=1SaMiv5F3bDRngNS6ot-TBi12Xo6IFOsX" width=700>

And the joint network can have arbitrary complexity and architecture, but in a simple case, it is a simple DNN.

<img alt="rnnt" src="https://drive.google.com/uc?id=11qccpDLBuAEXvsdkOIB9UZbVXqwD4zJC" width=700>





In [7]:
# Read wav
wav, sr = torchaudio.load('best-song-ever.wav')
wav = wav.to(device)
assert sr == 16_000, sr

# Get mel spectrogram
input_signal_length = torch.tensor([wav.size(-1)], dtype=torch.int32, device=device)
spectrogram, spec_length = model.preprocessor.forward(
    input_signal=wav,
    length=input_signal_length,
)

# Get encoded acoustic embeddings
acoustic_embs, acoustic_embs_length = model.encoder.forward(
    audio_signal=spectrogram, length=spec_length
)

In [9]:
acoustic_embs.size() # f0, f1, ... f212

torch.Size([1, 512, 213])

#### CTC Inference

Let's use the `ctc_decode` function from the previous seminar and make a prediction by argmax.

In [12]:
def ctc_decode(inds: list):
    decoded = []
    last_char_idx = BLANK_IND

    for idx in inds:
        if idx == last_char_idx:
            continue
        elif idx != BLANK_IND:
            decoded.append(idx)
        last_char_idx = idx

    return decoded

In [13]:
logits = model.ctc_decoder.forward(encoder_output=acoustic_embs)

inds = logits.argmax(-1).tolist()[0]
inds = ctc_decode(inds)

ctc_hypothesis = model.tokenizer.tokenizer.decode_ids(inds)
ctc_hypothesis = clear(ctc_hypothesis)
ctc_hypothesis

'theyve got to de to ma theyve ever going to let two down they are got to run round and deserve too they have got to make tooth p i are going to say goodbye if going to say la and where to'

In [14]:
wer(transcription, ctc_hypothesis)

1.0294117647058822

#### RNN-t inference

Use `PredictionNetwork` and `JointNetwork` modules for RNN-t decoding. Sometimes it is useful to limit the number of tokens that will be emitted per frame, try to use this in your code with the `MAX_SYMBOLS_PER_FRAME: int` variable.

In [45]:
class PredictionNetwork(nn.Module):
  def __init__(
      self,
      input_size: int,
      hidden_size: int,
      num_layers: int,
      dropout: float,
      num_embeddings: int,
      embedding_dim: int,
      padding_idx=None,
    ):
    super().__init__()

    # https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
    self.embed = nn.modules.sparse.Embedding(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        padding_idx=padding_idx,
    )
    # https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
    self.lstm = nn.LSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        dropout=dropout,
    )
    self.dropout = nn.Dropout(dropout) if dropout else None

  def forward(
      self,
      y: torch.tensor,
      state: tp.Optional[tp.Tuple[torch.tensor, ...]] = None,
    ) -> tp.Tuple[torch.tensor, tp.Tuple[torch.tensor, ...]]:
    """
    input:
      y_labels (bs, seq_len): ids from tokenizer of labels
      state: lstm state, can be None in the first moment, (see torch docs)
    output:
      g (bs, seq_len, hid_dim): language context
      state: lstm state
    """
    # Get embeddings for labels
    y_embs = self.embed(y) # (bs, seq_len, emb_size)
    y_embs = y_embs.transpose(0, 1) # (seq_len, bs, emb_size)

    # Proccess it with LSTM
    g, state = self.lstm(y_embs, state)

    if self.dropout:
        g = self.dropout(g)

    g = g.transpose(0, 1) # (bs, seq_len, hidden_dim)

    return g, state


In [46]:
prediction_network = PredictionNetwork(
    input_size=640,
    hidden_size=640,
    num_layers=1,
    dropout=0.2,
    num_embeddings=1025,
    embedding_dim=640,
    padding_idx=BLANK_IND,
).to(device).eval()

prediction_network.embed.load_state_dict(
    model.decoder.prediction.embed.state_dict()
)
prediction_network.lstm.load_state_dict(
    model.decoder.prediction.dec_rnn.lstm.state_dict()
)

    


<All keys matched successfully>

In [47]:
class JointNetwork(nn.Module):
  def __init__(
      self,
      pred_emb_size: int,
      enc_emb_size: int,
      hidden_size: int,
      dropout: float,
      vocab_size: int,
    ):
    super().__init__()

    self.pred_proj = nn.Linear(
        pred_emb_size, hidden_size
    )
    self.enc_proj = nn.Linear(
        enc_emb_size, hidden_size
    )
    self.joint_net = nn.Sequential(
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_size, vocab_size + 1)
    )

  def forward(
      self,
      encoder_outputs: torch.tensor,
      decoder_outputs: torch.tensor,
    ) -> torch.Tensor:
    """
    input:
      encoder outputs (B, H1, T): acoustic context
      decoder outputs (B, H2, U): language context
    output:
      joint activation (B, T, U, V+1)
    """
    # Project the output of the encoder/decoder into the latent space and concatenate them
    f = self.enc_proj(encoder_outputs) # (bs, T, H)
    f = f.unsqueeze(1) # (bs, 1, T, H)
    g = self.pred_proj(decoder_outputs) # (bs, U, H)
    g = g.unsqueeze(2) # (bs, U, 1, H)

    # (bs, T, U, H)
    pred_state = f + g

    # Project the following state into the vocab distribution space
    # (bs, T, U, H) ->  (bs, T, U, V + 1)
    out_state = self.joint_net(pred_state)

    return out_state


In [23]:
joint_network = JointNetwork(
    pred_emb_size=640,
    enc_emb_size=512,
    hidden_size=640,
    dropout=0.2,
    vocab_size=1024,
).to(device).eval()

joint_network.pred_proj.load_state_dict(
    model.joint.pred.state_dict()
)
joint_network.enc_proj.load_state_dict(
    model.joint.enc.state_dict()
)
joint_network.joint_net.load_state_dict(
    model.joint.joint_net.state_dict()
)

<All keys matched successfully>

write `rnnt_decoder_inference` function:

<img alt="rnnt" src="https://drive.google.com/uc?id=1EoSRLSSIg2fSge0yVKakKnbcgWUlCVJJ" width=700>


In [24]:
MAX_SYMBOLS_PER_FRAME: int = 100

In [40]:
@torch.inference_mode()
def rnnt_decoder_inference(
    prediction_network: nn.Module,
    joint_network: nn.Module,
    f: torch.Tensor,  # acoustic context
) -> tp.List[int]:
    """
    f - torch.tensor (B, H1, T): acoustic context
    """
    bs, _, T = f.size()
    assert bs == 1, bs

    predicted_ids = []

    y_curr = torch.tensor([[0]], dtype=torch.long, device=device) # start from begin of sequence token
    prediction_network_state = None

    for time_step in tqdm(range(T)):
        is_blank = False
        curr_pred_tokens = 0
        f_i = f[:, :, time_step].unsqueeze(-1) # (bs, 1, H1), H1=512

        while not is_blank and curr_pred_tokens < MAX_SYMBOLS_PER_FRAME:
            g_j, prediction_network_state = prediction_network(
                y_curr, prediction_network_state,
            )
            vocab_distr = joint_network(f_i.mT, g_j)
            vocab_distr = vocab_distr[0, 0, 0, :] # v+1

            _, predicted_token = vocab_distr.max(0)
            predicted_token_id = predicted_token.item()

            if predicted_token_id == BLANK_IND:
                is_blank = True
            else:
                predicted_ids.append(predicted_token_id)
                y_curr = torch.empty_like(y_curr).fill_(predicted_token_id)
            curr_pred_tokens += 1


    return predicted_ids


In [41]:
decoded_outut = rnnt_decoder_inference(
    prediction_network=prediction_network,
    joint_network=joint_network,
    f=acoustic_embs,
)

100%|██████████| 213/213 [00:00<00:00, 1016.90it/s]


In [42]:
rnnt_hypothesis = clear(TOKENIZER.decode_ids(decoded_outut))
rnnt_hypothesis

'theyvere got age to let two down they are going toll round and goodbe if i'

In [43]:
transcription

'never gonna give you up never gonna let you down never gonna run around and desert you never gonna make you cry never gonna say goodbye never gonna tell a lie and hurt you'

In [44]:
wer(transcription, rnnt_hypothesis)

0.9411764705882353

### RNN-t training step

In [48]:
transcription

'never gonna give you up never gonna let you down never gonna run around and desert you never gonna make you cry never gonna say goodbye never gonna tell a lie and hurt you'

In [49]:
transcription_ids = TOKENIZER.encode(transcription)
transcription_ids = torch.tensor(transcription_ids, dtype=torch.long, device=device).unsqueeze(0)
transcription_ids

tensor([[464, 999, 407,  32, 149, 464, 999, 529,  32, 378, 464, 999, 700, 650,
           8,  72,   1,  24,   6,  32, 464, 999, 310,  32,  42, 230, 464, 999,
         301, 293,  48,  19,  15, 464, 999, 648,   5, 250,  15,   8, 495,  18,
           6,  32]], device='cuda:0')

In [50]:
# Read wav
wav, sr = torchaudio.load('best-song-ever.wav')
wav = wav.to(device)
assert sr == 16_000, sr

# Get mel spectrogram
input_signal_length = torch.tensor([wav.size(-1)], dtype=torch.int32, device=device)
spectrogram, spec_length = model.preprocessor.forward(
    input_signal=wav,
    length=input_signal_length,
)

# Get encoded acoustic embeddings
acoustic_embs, acoustic_embs_length = model.encoder.forward(
    audio_signal=spectrogram, length=spec_length
)

In [None]:
cur_token_emb, hidden_state = prediction_network(
    y=transcription_ids,
    state=None,
)
vocab_distributon = joint_network(
    encoder_outputs=acoustic_embs.mTcur_token_emb,
    decoder_outputs=cur_token_emb,
)

In [None]:
vocab_distributon.size()

: 



---



for further reading:
  * [Sequence-to-sequence learning with Transducers](https://lorenlugosch.github.io/posts/2020/11/transducer/)
  * RNN-t optimizations:
    * [Multi-Blank Transducers for Speech Recognition, Hainan Xu et al., NVIDIA, 2024](https://arxiv.org/pdf/2211.03541v2)
    * [Efficient Sequence Transduction by Jointly Predicting Tokens and Durations, Hainan Xu et al., NVIDIA, 2023](https://arxiv.org/abs/2304.06795)
    * [FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization, Jiahui Yu et al., Google, 2021](https://arxiv.org/abs/2010.11148)
    * [Fast Conformer with Linearly Scalable Attention for Efficient Speech Recognition, Dima Rekesh et al., NVIDIA, 2023](https://arxiv.org/abs/2305.05084)
    * [Rnn-Transducer with Stateless Prediction Network, Mohammadreza Ghodsi et al., 2020](https://ieeexplore.ieee.org/document/9054419)
