<a href="https://colab.research.google.com/github/dylstuart/streaming_asr/blob/main/asr_analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Online ASR with Emformer RNN-T

**Original Authors**: [Jeff Hwang](mailto:jeffhwang@meta.com), [Moto Hira](mailto:moto@meta.com)

**Modified By**: [Dylan Stuart](mailto:dyln.strt@gmail.com)

This notebook shows how to use Emformer RNN-T and the torchaudio streaming API
to perform online speech recognition. It also measures the model's latency: per-chunk processing time, real-time factor (RTF), and time to first token (TTFT).


<div class="alert alert-info"><h4>Note</h4><p>This tutorial requires the FFmpeg libraries and SentencePiece.</p>
</div>




## 0. Overview

Online speech recognition consists of the following steps:

1. Build the inference pipeline.
   Emformer RNN-T is composed of three components: a feature extractor,
   a decoder, and a token processor.
2. Format the waveform into chunks of the expected sizes.
3. Pass the data through the pipeline.
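The three steps above form a loop over chunks. As a minimal, library-free sketch of that control flow (the function `run_streaming_pipeline` and the toy components are hypothetical stand-ins, not torchaudio APIs), note how the decoder state returned by one call is passed into the next:

```python
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple


def run_streaming_pipeline(
    chunks: Iterable[Any],
    extract: Callable[[Any], Any],
    decode: Callable[[Any, Optional[Any]], Tuple[Any, Any]],
    to_text: Callable[[Any], str],
) -> Iterator[str]:
    """Push pre-chunked input through extract -> decode -> to_text, one chunk at a time."""
    state = None  # decoding state carried from one call to the next
    for chunk in chunks:
        features = extract(chunk)
        hypothesis, state = decode(features, state)
        yield to_text(hypothesis)


def toy_decode(features, state):
    # Hypothetical "decoder": its state is a running total of all features seen so far
    total = features + (state or 0)
    return total, total


outputs = list(run_streaming_pipeline([[1, 2], [3, 4], [5]], extract=sum, decode=toy_decode, to_text=str))
print(outputs)  # ['3', '10', '15']
```

In the real pipeline, the Emformer decoder carries both a `state` and a `hypothesis` between calls; the single `state` value here stands in for both.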



## 1. Install dependencies

`torchaudio.io.StreamReader`, used in this notebook, requires `torchaudio==2.8.0`. Pip may report a dependency conflict with the torchvision version preinstalled on Colab; this can be ignored because the notebook does not use torchvision.

```python
!pip install torchaudio==2.8.0 sentencepiece
```

## 2. Preparation




```python
import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)

import IPython
import matplotlib.pyplot as plt
from torchaudio.io import StreamReader
```

## 3. Construct the pipeline

Pre-trained model weights and related pipeline components are
bundled as `torchaudio.pipelines.RNNTBundle`.

We use `torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH`,
an Emformer RNN-T model trained on the LibriSpeech dataset.




```python
bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH

feature_extractor = bundle.get_streaming_feature_extractor()
decoder = bundle.get_decoder()
token_processor = bundle.get_token_processor()
```

Streaming inference works on overlapping input data.
The Emformer RNN-T model treats the newest portion of each input
as the "right context": a preview of future audio.
In each inference call, the model expects the main segment
to start with the right context from the previous inference call.
The following figure illustrates this.
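To make the overlap concrete, the sketch below lists the sample ranges each call covers, assuming each input is `segment` main-segment samples followed by `context` right-context samples, and that the stream is zero-padded before sample 0 (hence the negative start of the first call). `window_layout` is a hypothetical helper and the sizes are toy values:

```python
def window_layout(num_calls: int, segment: int, context: int):
    """Return (main_start, main_end, right_start, right_end) sample indices per call."""
    layout = []
    for k in range(num_calls):
        # Each main segment begins where the previous call's right context began
        main_start = k * segment - context
        main_end = main_start + segment
        right_end = main_end + context  # right context directly follows the main segment
        layout.append((main_start, main_end, main_end, right_end))
    return layout


for call in window_layout(num_calls=3, segment=8, context=2):
    print(call)
# (-2, 6, 6, 8)
# (6, 14, 14, 16)
# (14, 22, 22, 24)
```

Each call advances the stream by `segment` samples, while its last `context` samples are seen twice: once as right context and again as the start of the next main segment.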

<img src="https://download.pytorch.org/torchaudio/tutorial-assets/emformer_rnnt_context.png">

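The sizes of the main segment and right context, along with
the expected sample rate, can be retrieved from the bundle.
The bundle reports segment sizes in feature frames, so we multiply them by `hop_length` to convert them to audio samples.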




```python
sample_rate = bundle.sample_rate
segment_length = bundle.segment_length * bundle.hop_length
context_length = bundle.right_context_length * bundle.hop_length

print(f"Sample rate: {sample_rate}")
print(f"Main segment: {segment_length} samples ({segment_length / sample_rate} seconds)")
print(f"Right context: {context_length} samples ({context_length / sample_rate} seconds)")
```
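As a sanity check of the frame-to-sample conversion, here is the same arithmetic with the values we expect `EMFORMER_RNNT_BASE_LIBRISPEECH` to report: 16 kHz audio, a hop of 160 samples, 16 main-segment frames and 4 right-context frames. Treat these constants as assumptions and compare them with what the cell above prints:

```python
# Assumed bundle values; confirm against bundle.sample_rate, bundle.hop_length,
# bundle.segment_length and bundle.right_context_length.
SAMPLE_RATE = 16_000        # samples per second
HOP_LENGTH = 160            # samples per feature frame (10 ms)
SEGMENT_FRAMES = 16         # feature frames in the main segment
RIGHT_CONTEXT_FRAMES = 4    # feature frames of right context

segment_samples = SEGMENT_FRAMES * HOP_LENGTH        # 2560 samples
context_samples = RIGHT_CONTEXT_FRAMES * HOP_LENGTH  # 640 samples

print(f"Main segment: {segment_samples} samples = {segment_samples / SAMPLE_RATE * 1000:.0f} ms")
print(f"Right context: {context_samples} samples = {context_samples / SAMPLE_RATE * 1000:.0f} ms")
print(f"Samples per inference call: {segment_samples + context_samples}")
```

Under these assumptions, each call sees 200 ms of audio but advances the stream by only 160 ms, so 160 ms is also the per-chunk processing budget for keeping up with real time.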

## 4. Configure the audio stream

Note: this section is kept from the original tutorial. It is still worth running, because it creates the global `streamer` that later cells use.

Next, we configure the input audio stream using `torchaudio.io.StreamReader`.

For details of this API, please refer to the
[StreamReader Basic Usage](./streamreader_basic_tutorial.html) tutorial.




The following audio file was originally published by the LibriVox project,
and it is in the public domain.

https://librivox.org/great-pirate-stories-by-joseph-lewis-french/

It was re-uploaded for the sake of the tutorial.




```python
src = "https://download.pytorch.org/torchaudio/tutorial-assets/greatpiratestories_00_various.mp3"

streamer = StreamReader(src)
streamer.add_basic_audio_stream(frames_per_chunk=segment_length, sample_rate=bundle.sample_rate)

print(streamer.get_src_stream_info(0))
print(streamer.get_out_stream_info(0))
```

As explained above, the Emformer RNN-T model expects overlapping input data;
however, `StreamReader` iterates over the source media without overlap.
We therefore build a helper that caches the tail of each chunk from
`StreamReader` as right context and prepends it to the next chunk.

The following figure illustrates this.

<img src="https://download.pytorch.org/torchaudio/tutorial-assets/emformer_rnnt_streamer_context.png">




```python
class ContextCacher:
    """Cache the end of each input chunk and prepend it to the next chunk.

    Args:
        segment_length (int): The size of the main segment.
            If the incoming chunk is shorter, it is zero-padded to this size.
        context_length (int): The size of the right context that is cached and prepended.
    """

    def __init__(self, segment_length: int, context_length: int):
        self.segment_length = segment_length
        self.context_length = context_length
        # Before the first chunk there is no history, so start from silence
        self.context = torch.zeros([context_length])

    def __call__(self, chunk: torch.Tensor):
        # Zero-pad a short (final) chunk up to the full segment length
        if chunk.size(0) < self.segment_length:
            chunk = torch.nn.functional.pad(chunk, (0, self.segment_length - chunk.size(0)))
        # Prepend the cached context, then cache this chunk's tail for the next call
        chunk_with_context = torch.cat((self.context, chunk))
        self.context = chunk[-self.context_length :]
        return chunk_with_context
```
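To see exactly what the cacher produces, the sketch below mirrors its logic with plain Python lists and toy sizes (a 4-sample segment and 2-sample context). `ListContextCacher` is a hypothetical, list-based stand-in for illustration only; lists replace tensors, and list concatenation replaces `torch.cat` and `torch.nn.functional.pad`:

```python
class ListContextCacher:
    """Hypothetical list-based mirror of ContextCacher, for illustration only."""

    def __init__(self, segment_length: int, context_length: int):
        self.segment_length = segment_length
        self.context_length = context_length
        self.context = [0] * context_length  # zeros stand in for silence before the stream

    def __call__(self, chunk):
        # Right-pad a short final chunk with zeros
        if len(chunk) < self.segment_length:
            chunk = chunk + [0] * (self.segment_length - len(chunk))
        chunk_with_context = self.context + chunk
        # Keep the newest samples; they become the prefix of the next input
        self.context = chunk[-self.context_length:]
        return chunk_with_context


cacher_demo = ListContextCacher(segment_length=4, context_length=2)
samples = list(range(1, 11))  # ten toy samples: 1..10
for start in range(0, len(samples), 4):
    print(cacher_demo(samples[start:start + 4]))
# [0, 0, 1, 2, 3, 4]
# [3, 4, 5, 6, 7, 8]
# [7, 8, 9, 10, 0, 0]
```

Every output has `context_length + segment_length` elements, and its last two elements reappear at the start of the next output, which is the overlap the model expects.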

## 5. Run stream inference

Finally, we run the recognition.

First, we initialize the context cacher, along with the state and
hypothesis that the decoder uses to carry decoding state over
between inference calls.




```python
cacher = ContextCacher(segment_length, context_length)

state, hypothesis = None, None
```

Next, we add a timing utility for the latency measurements.

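```python
from contextlib import contextmanager
from time import perf_counter


@contextmanager
def timed(section, stats):
    """Append the wall-clock duration of the `with` body to stats[section]."""
    start = perf_counter()
    yield
    elapsed = perf_counter() - start
    stats[section].append(elapsed)


# Per-chunk timings in seconds, one list entry per chunk
chunk_stats = {
    "feature_extraction": [],  # feature_extractor call
    "model_forward": [],       # decoder.infer (beam search)
    "decoder": [],             # token_processor (token IDs -> text)
    "total_chunk": [],         # whole per-chunk pipeline
}

# Seconds of audio processed so far
audio_time = 0.0
```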
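One possible refinement, sketched here under the assumption that failed chunks should still be timed: wrapping the `yield` in `try`/`finally` records the elapsed time even when the body raises. `timed_safe` and `demo_stats` are hypothetical names, used only in this illustration:

```python
import time
from collections import defaultdict
from contextlib import contextmanager
from time import perf_counter


@contextmanager
def timed_safe(section, stats):
    """Variant of `timed` that records the elapsed time even if the body raises."""
    start = perf_counter()
    try:
        yield
    finally:
        stats[section].append(perf_counter() - start)


demo_stats = defaultdict(list)  # new sections are created on first use
with timed_safe("sleep", demo_stats):
    time.sleep(0.01)
print(demo_stats["sleep"])  # one duration, roughly 0.01 s
```

A `defaultdict` saves declaring every section up front, but a misspelled section name then silently creates a new metric; the explicit `chunk_stats` dict above raises a `KeyError` instead, which is arguably the safer choice for a benchmark.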

Next, we run the inference.

The helper function `run_inference` consumes the current `stream_iterator`
until the audio ends, records the time spent in each stage in `chunk_stats`,
and returns the timestamp at which each chunk finished processing.




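```python
stream_iterator = streamer.stream()


def _plot(feats, num_iter, unit=25):
    """Plot the Mel spectrogram features of `num_iter` chunks, `unit` chunks per row."""
    unit_dur = segment_length / sample_rate * unit
    num_plots = num_iter // unit + (1 if num_iter % unit else 0)
    # squeeze=False keeps `axes` a 2-D array even when num_plots == 1
    fig, axes = plt.subplots(num_plots, 1, squeeze=False)
    t0 = 0
    for i, ax in enumerate(axes[:, 0]):
        feats_ = feats[i * unit : (i + 1) * unit]
        t1 = t0 + segment_length / sample_rate * len(feats_)
        feats_ = torch.cat([f[2:-2] for f in feats_])  # remove boundary effect and overlap
        ax.imshow(feats_.T, extent=[t0, t1, 0, 1], aspect="auto", origin="lower")
        ax.tick_params(which="both", left=False, labelleft=False)
        ax.set_xlim(t0, t0 + unit_dur)
        t0 = t1
    fig.suptitle("MelSpectrogram Feature")
    plt.tight_layout()


# Note: torch.compile on a module compiles its forward(); decoder.infer() is
# looked up on the wrapped module, so it likely still runs in eager mode.
decoder = torch.compile(decoder)


@torch.inference_mode()
def run_inference():
    """Stream `stream_iterator` to the end, recording per-stage timings in `chunk_stats`.

    Returns the perf_counter() timestamp at which each chunk finished processing.
    """
    global state, hypothesis, audio_time
    chunks = []
    feats = []
    chunk_end_timestamps = []
    for (chunk,) in stream_iterator:
        # Count the real (unpadded) duration so a short final chunk is not over-counted
        audio_time += chunk.size(0) / sample_rate
        with timed("total_chunk", chunk_stats):
            segment = cacher(chunk[:, 0])  # first channel, with cached right context prepended
            with timed("feature_extraction", chunk_stats):
                features, length = feature_extractor(segment)

            with timed("model_forward", chunk_stats):
                # Beam search (beam width 10), carrying state and hypothesis across calls
                hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)

            hypothesis = hypos
            # The "decoder" entry times the token processor (token IDs -> text)
            with timed("decoder", chunk_stats):
                transcript = token_processor(hypos[0][0], lstrip=False)
            chunk_end_timestamps.append(perf_counter())

            chunks.append(chunk)
            feats.append(features)
    return chunk_end_timestamps
```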

We run inference on every audio sample in our mini dataset (the MP3 files under `./samples/`) and collect statistics for each one.

```python
import glob

# Find all MP3 files in our samples directory (sorted for a stable processing order)
sample_paths = sorted(glob.glob("./samples/*.mp3"))
print(f"{len(sample_paths)} samples available for processing")

full_dataset_stats = []  # one chunk_stats dict per sample
full_sample_stats = {"latency": [], "rtf": [], "ttft": []}  # one value per sample and metric

for path in sample_paths:
    # Initialize per-sample statistics (run_inference updates these globals)
    audio_time = 0.0
    chunk_stats = {
        "feature_extraction": [],
        "model_forward": [],
        "decoder": [],
        "total_chunk": [],
    }

    # Reset the context cache, decoder state and hypothesis for the new stream
    cacher = ContextCacher(segment_length, context_length)
    state, hypothesis = None, None

    pipeline_start_timestamp = perf_counter()

    # "latency" covers opening the file, configuring the stream and running inference
    with timed("latency", full_sample_stats):
        streamer = StreamReader(path)
        streamer.add_basic_audio_stream(frames_per_chunk=segment_length, sample_rate=bundle.sample_rate)
        stream_iterator = streamer.stream()
        chunk_end_timestamps = run_inference()

    full_dataset_stats.append(chunk_stats)
    print(f"Finished {path}")

    # RTF: total chunk processing time divided by audio duration (< 1.0 is faster than real time)
    total_processing = sum(chunk_stats["total_chunk"])
    rtf = total_processing / audio_time
    full_sample_stats["rtf"].append(rtf)

    # TTFT is measured from pipeline start to the first token processor output (even if the output is empty)
    full_sample_stats["ttft"].append(chunk_end_timestamps[0] - pipeline_start_timestamp)
    print(f"Sum of chunk processing times: {total_processing}\nAudio file duration: {audio_time}\nRTF: {rtf}")
```
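The two per-sample metrics computed in the loop can be written as small pure functions, which makes their definitions explicit and easy to test. This is a sketch: `real_time_factor` and `time_to_first_output` are hypothetical helpers, and the numbers are invented for illustration. Note that in the loop above, RTF sums only the `total_chunk` timings, so the time the iterator spends decoding and resampling the MP3 is excluded, whereas TTFT starts before the `StreamReader` is created and therefore includes that setup.

```python
def real_time_factor(chunk_times_s, audio_duration_s):
    """Processing time divided by audio duration; below 1.0 means faster than real time."""
    if audio_duration_s <= 0:
        raise ValueError("audio duration must be positive")
    return sum(chunk_times_s) / audio_duration_s


def time_to_first_output(pipeline_start_s, chunk_end_timestamps_s):
    """Delay from pipeline start until the first chunk has been fully processed."""
    if not chunk_end_timestamps_s:
        raise ValueError("no chunks were processed")
    return chunk_end_timestamps_s[0] - pipeline_start_s


# Hypothetical numbers, for illustration only: 3 chunks of 0.16 s of audio each
print(round(real_time_factor([0.05, 0.04, 0.03], audio_duration_s=0.48), 3))  # 0.25
print(round(time_to_first_output(10.0, [10.2, 10.3, 10.4]), 3))  # 0.2
```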


## 6. Analysis

We aggregate the per-chunk statistics over all audio samples and plot the cumulative distribution function (CDF) of each metric.

```python
from collections import defaultdict

import numpy as np


def aggregate_dicts(stats_list):
    """Flatten a list of dicts of lists into a single dict of arrays."""
    aggregated = defaultdict(list)

    for stats in stats_list:
        for key, values in stats.items():
            aggregated[key].extend(values)

    # Convert to numpy arrays for convenience
    return {k: np.array(v) for k, v in aggregated.items()}


def compute_cdf(values):
    """Return sorted values and their empirical cumulative probabilities."""
    x = np.sort(values)
    y = np.arange(1, len(x) + 1) / len(x)
    return x, y


def summarize(values):
    """Return the mean and the 50th, 90th and 99th percentiles."""
    return {
        "mean": np.mean(values),
        "p50": np.percentile(values, 50),
        "p90": np.percentile(values, 90),
        "p99": np.percentile(values, 99),
    }


# Plot the CDF of each aggregated per-chunk metric
aggregated_stats = aggregate_dicts(full_dataset_stats)
plt.figure(figsize=(8, 6))

for metric, values in aggregated_stats.items():
    x, y = compute_cdf(values)
    plt.plot(x * 1000, y, label=metric)  # seconds -> ms

plt.xlabel("Latency (ms)")
plt.ylabel("Cumulative Probability")
plt.title("Streaming ASR Latency CDFs")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

chunk_latency_summary = summarize(aggregated_stats["total_chunk"])

print("Chunk latencies (s)")
for name, value in chunk_latency_summary.items():
    print(f"  {name}: {value}")
```
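Reading a CDF at a chosen threshold gives the share of chunks at or below it; at the main-segment duration (160 ms), that is the share of chunks processed within the real-time budget. A minimal pure-Python sketch, where `fraction_within` is a hypothetical helper and the latencies are invented; with the notebook's data one could pass `aggregated_stats["total_chunk"]` and `segment_length / sample_rate`:

```python
from bisect import bisect_right


def fraction_within(latencies_s, budget_s):
    """Empirical CDF evaluated at budget_s: share of chunks with latency <= budget_s."""
    if len(latencies_s) == 0:  # len() also works for numpy arrays
        return float("nan")
    ordered = sorted(latencies_s)
    # bisect_right counts every value <= budget_s
    return bisect_right(ordered, budget_s) / len(ordered)


# Hypothetical per-chunk latencies (seconds), for illustration only
example = [0.020, 0.035, 0.180, 0.025]
print(fraction_within(example, budget_s=0.16))  # 0.75
```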

Next, we plot histograms of the per-sample real-time factors (RTFs) and times to first token (TTFTs).

```python
# Plot histogram of per-sample RTFs
plt.figure(figsize=(8, 6))
plt.hist(np.array(full_sample_stats["rtf"]), bins=5)
plt.xlabel("RTF")
plt.ylabel("Count")
plt.title("Histogram of RTFs")
plt.grid(True)
plt.show()

# Plot histogram of per-sample TTFTs
plt.figure(figsize=(8, 6))
plt.hist(np.array(full_sample_stats["ttft"]), bins=5)
plt.xlabel("TTFT (s)")
plt.ylabel("Count")
plt.title("Histogram of TTFTs")
plt.grid(True)
plt.show()

ttft_summary = summarize(np.array(full_sample_stats["ttft"]))

print("TTFT (s)")
for name, value in ttft_summary.items():
    print(f"  {name}: {value}")
```

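Finally, we plot the processing latency of each chunk over the course of each audio sample. The red horizontal line marks the real-time budget of 160 ms, the duration of new audio in each main segment; a chunk that takes longer than this to process falls behind a live stream.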

```python
plt.figure(figsize=(10, 6))

# Add a separate line to the plot for each sample
for sample_idx, stats in enumerate(full_dataset_stats):
    latencies = stats.get("total_chunk", [])
    if not latencies:
        continue

    plt.plot(
        range(len(latencies)),
        [latency * 1000 for latency in latencies],  # seconds -> ms
        alpha=0.7,
        label=f"sample_{sample_idx}",
    )

# Real-time budget: each chunk carries segment_length / sample_rate seconds (160 ms) of new audio
plt.axhline(segment_length / sample_rate * 1000, linewidth=4, color="r")

plt.xlabel("Chunk Index")
plt.ylabel("Total Chunk Latency (ms)")
plt.title("Per-Sample Total Chunk Latency Over Time")
plt.grid(True)
plt.tight_layout()  # legend omitted: one entry per sample would clutter the plot
plt.show()
```
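The plot compares each chunk with the 160 ms budget in isolation. In a live stream, a slow chunk also delays the chunks queued behind it. The sketch below simulates this, assuming audio arrives at exactly real-time rate, chunks are processed one at a time in order, and capture or I/O overheads are ignored; the measured timings in this notebook come from reading files, where no such waiting occurs. `live_output_lags` is a hypothetical helper; with the notebook's data one could pass `stats["total_chunk"]` and `segment_length / sample_rate`.

```python
def live_output_lags(proc_times_s, segment_s):
    """Simulate a live stream: chunk k is fully captured at (k + 1) * segment_s and
    starts processing once it has arrived and the previous chunk has finished.
    Returns, per chunk, the delay between its audio being captured and its output."""
    lags = []
    finish = 0.0
    for k, proc in enumerate(proc_times_s):
        arrival = (k + 1) * segment_s
        start = max(arrival, finish)  # wait for the audio, or for the previous chunk
        finish = start + proc
        lags.append(finish - arrival)
    return lags


# Hypothetical processing times (s) with a 0.16 s budget: one slow chunk
# delays the following chunks until the backlog drains.
print([round(x, 3) for x in live_output_lags([0.05, 0.30, 0.05, 0.05], segment_s=0.16)])
# [0.05, 0.3, 0.19, 0.08]
```

Under the same assumptions, the first lag plus `segment_s` approximates how long a live listener waits for the first result, because the first chunk cannot be processed until a full segment of audio has been captured.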