<a href="https://colab.research.google.com/github/dylstuart/streaming_asr/blob/main/asr_analysis_torch_profile.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Online ASR with Emformer RNN-T

**Original Authors**: [Jeff Hwang](mailto:jeffhwang@meta.com), [Moto Hira](mailto:moto@meta.com)

**Modified By**: [Dylan Stuart](mailto:dyln.strt@gmail.com)

This notebook shows how to use Emformer RNN-T and the streaming API
to perform online speech recognition, and measures the model's latency by profiling the CPU time of its decoder with `torch.profiler`.


<div class="alert alert-info"><h4>Note</h4><p>This tutorial requires FFmpeg libraries and SentencePiece.</p>
</div>




## 0. Overview

Performing online speech recognition consists of the following steps:

1. Build the inference pipeline.
   Emformer RNN-T is composed of three components: a feature extractor,
   a decoder, and a token processor.
2. Format the waveform into chunks of the expected sizes.
3. Pass the data through the pipeline.
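The three steps above can be sketched as a single loop. This is a minimal, framework-free illustration rather than the torchaudio API: `run_pipeline` and the toy components in the usage example are hypothetical stand-ins for the feature extractor, decoder, and token processor built in section 3. What it shows is that the decoder's `state` and `hypothesis` are threaded from one chunk to the next.

```python
from typing import Any, Callable, Iterable, Optional, Tuple

# Hypothetical component signatures, standing in for the real torchaudio objects
Extract = Callable[[Any], Any]
Decode = Callable[[Any, Optional[Any], Optional[Any]], Tuple[Any, Any]]
ToText = Callable[[Any], str]


def run_pipeline(chunks: Iterable[Any], extract: Extract, decode: Decode, to_text: ToText) -> str:
    """Feed each chunk through extract -> decode, then turn the final hypothesis into text."""
    state: Optional[Any] = None
    hypothesis: Optional[Any] = None
    for chunk in chunks:
        features = extract(chunk)
        # The decoder consumes and returns its running state and hypothesis
        hypothesis, state = decode(features, state, hypothesis)
    return to_text(hypothesis) if hypothesis is not None else ""


# Toy components: "features" are upper-cased strings, the "hypothesis" is the
# list of everything decoded so far, and the "state" is just its length.
def toy_decode(features, state, hypothesis):
    tokens = (hypothesis or []) + [features]
    return tokens, len(tokens)


print(run_pipeline(["hel", "lo"], str.upper, toy_decode, "".join))  # HELLO
```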



## 1. Install dependencies

The `torchaudio.io.StreamReader` API used in this notebook requires `torchaudio==2.8.0`. Installing it may report a dependency conflict with the torchvision version preinstalled on Colab; this can be ignored, since this notebook does not use torchvision.

In [None]:
!pip install torchaudio==2.8.0 sentencepiece
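Because the notebook depends on a pinned torchaudio release, it can help to confirm which version is actually installed before continuing (for example, after a Colab runtime restart). Below is a minimal sketch using only the standard library's `importlib.metadata`. The helper names `installed_version` and `matches_pin` are hypothetical, and ignoring a local build suffix such as `+cu128` when comparing is an assumption about how the wheel is versioned.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed distribution version, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def matches_pin(package: str, required: str) -> bool:
    """True if `package` is installed at exactly `required`, ignoring any +local suffix."""
    version = installed_version(package)
    return version is not None and version.split("+")[0] == required


print("torchaudio:", installed_version("torchaudio"), "| pin OK:", matches_pin("torchaudio", "2.8.0"))
```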

## 2. Preparation




In [None]:
import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)

import IPython
import matplotlib.pyplot as plt
from torchaudio.io import StreamReader

## 3. Construct the pipeline

Pre-trained model weights and related pipeline components are
bundled as `torchaudio.pipelines.RNNTBundle`.

We use `torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH`,
which is an Emformer RNN-T model trained on the LibriSpeech dataset.




In [None]:
bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH

feature_extractor = bundle.get_streaming_feature_extractor()
decoder = bundle.get_decoder()
token_processor = bundle.get_token_processor()

Streaming inference works on input data with overlap.
The Emformer RNN-T model treats the newest portion of the input data
as the "right context", a preview of future context.
In each inference call, the model expects the main segment
to start from the right context of the previous inference call.
The following figure illustrates this.

<img src="https://download.pytorch.org/torchaudio/tutorial-assets/emformer_rnnt_context.png">
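The overlap in the figure can also be expressed as index arithmetic. The sketch below is illustrative only (`model_input_windows` is a hypothetical helper, not part of torchaudio). For each inference call, it lists where the main segment starts, where the right context starts, and where the input ends, ignoring the padding needed at the end of the stream. Consecutive inputs overlap by exactly `context_length` samples, because each call's main segment begins where the previous call's right context began.

```python
from typing import List, Tuple


def model_input_windows(num_samples: int, segment_length: int, context_length: int) -> List[Tuple[int, int, int]]:
    """Return (main_start, context_start, end) sample indices for each inference call.

    Assumes segment_length > 0. Windows that run past num_samples would be
    zero-padded in practice.
    """
    windows = []
    main_start = 0
    while main_start < num_samples:
        context_start = main_start + segment_length  # right context follows the main segment
        windows.append((main_start, context_start, context_start + context_length))
        main_start = context_start  # the next main segment starts at this right context
    return windows


windows = model_input_windows(num_samples=10, segment_length=4, context_length=2)
print(windows)  # [(0, 4, 6), (4, 8, 10), (8, 12, 14)]
```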

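The sizes of the main segment and right context, along with
the expected sample rate, can be retrieved from the bundle.
`segment_length` and `right_context_length` are given in feature frames,
so we multiply them by `hop_length` to convert them to audio samples.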




In [None]:
sample_rate = bundle.sample_rate
segment_length = bundle.segment_length * bundle.hop_length
context_length = bundle.right_context_length * bundle.hop_length

print(f"Sample rate: {sample_rate}")
print(f"Main segment: {segment_length} samples ({segment_length / sample_rate} seconds)")
print(f"Right context: {context_length} samples ({context_length / sample_rate} seconds)")
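As a worked example of the conversion from feature frames to samples: the constants below are assumptions (they match the output shown in the upstream torchaudio tutorial for this bundle: 16 kHz audio, a hop of 160 samples, a 16-frame main segment, and a 4-frame right context), so compare them with what the cell above prints.

```python
# Assumed bundle values; verify them against the printout above
SAMPLE_RATE = 16000          # samples per second
HOP_LENGTH = 160             # samples between successive feature frames
SEGMENT_FRAMES = 16          # main segment, in feature frames
RIGHT_CONTEXT_FRAMES = 4     # right context, in feature frames

segment_samples = SEGMENT_FRAMES * HOP_LENGTH        # 2560
context_samples = RIGHT_CONTEXT_FRAMES * HOP_LENGTH  # 640

print(f"Main segment: {segment_samples} samples ({segment_samples / SAMPLE_RATE} s)")   # 0.16 s
print(f"Right context: {context_samples} samples ({context_samples / SAMPLE_RATE} s)")  # 0.04 s
# Each inference call therefore sees 2560 + 640 = 3200 samples (0.2 s) of audio,
# while the stream advances by 2560 samples (0.16 s) per call.
```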

## 4. Configure the audio stream

[NOTE: This section is carried over from the original tutorial. Run it anyway: it creates the global `streamer` used by the cells in section 5.]

Next, we configure the input audio stream using `torchaudio.io.StreamReader`.

For details of this API, please refer to the
[StreamReader Basic Usage](./streamreader_basic_tutorial.html) tutorial in the torchaudio documentation.




The following audio file was originally published by the LibriVox project,
and it is in the public domain.

https://librivox.org/great-pirate-stories-by-joseph-lewis-french/

It was re-uploaded for the sake of the tutorial.




In [None]:
src = "https://download.pytorch.org/torchaudio/tutorial-assets/greatpiratestories_00_various.mp3"

streamer = StreamReader(src)
streamer.add_basic_audio_stream(frames_per_chunk=segment_length, sample_rate=bundle.sample_rate)

print(streamer.get_src_stream_info(0))
print(streamer.get_out_stream_info(0))
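With `frames_per_chunk=segment_length`, each iteration of the stream yields `segment_length` samples per channel, except possibly the last chunk, which holds whatever audio remains; this is why `ContextCacher` zero-pads short chunks. The pure-Python sketch below (the helper `chunk_sizes` is hypothetical) computes the resulting chunk sizes for a given number of samples, assuming that behaviour.

```python
from typing import List


def chunk_sizes(num_samples: int, frames_per_chunk: int) -> List[int]:
    """Sizes of the chunks a fixed-size chunker yields; the last one may be shorter."""
    full, remainder = divmod(num_samples, frames_per_chunk)
    return [frames_per_chunk] * full + ([remainder] if remainder else [])


# e.g. 6000 samples (0.375 s at 16 kHz) split into 2560-sample chunks
print(chunk_sizes(6000, 2560))  # [2560, 2560, 880]
```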

As explained above, the Emformer RNN-T model expects input data with
overlap; however, `StreamReader` iterates over the source media without overlap.
So we make a helper structure that caches the last `context_length` samples
of each chunk from `StreamReader` and prepends them to the next chunk.

The following figure illustrates this.

<img src="https://download.pytorch.org/torchaudio/tutorial-assets/emformer_rnnt_streamer_context.png">




In [None]:
class ContextCacher:
    """Cache the end of input data and prepend it to the next input data.

    Args:
        segment_length (int): The size of the main segment.
            If the incoming segment is shorter, then the segment is padded.
        context_length (int): The size of the context, cached and prepended.
    """

    def __init__(self, segment_length: int, context_length: int):
        self.segment_length = segment_length
        self.context_length = context_length
        # Silence serves as the "previous" context for the first chunk
        self.context = torch.zeros([context_length])

    def __call__(self, chunk: torch.Tensor):
        # Zero-pad a short (final) chunk up to the full segment length
        if chunk.size(0) < self.segment_length:
            chunk = torch.nn.functional.pad(chunk, (0, self.segment_length - chunk.size(0)))
        # Prepend the tail of the previous chunk
        chunk_with_context = torch.cat((self.context, chunk))
        # Remember this chunk's tail for the next call
        self.context = chunk[-self.context_length :]
        return chunk_with_context
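To see what `ContextCacher` produces, here is the same logic re-expressed on plain Python lists with tiny sizes (a segment of 4 samples and a context of 2). `ListContextCacher` is a hypothetical stand-in written only for illustration; it mirrors the padding, prepending, and caching steps of the tensor version so that the overlap between consecutive outputs is easy to inspect.

```python
from typing import List


class ListContextCacher:
    """List-based mirror of ContextCacher, for illustration only."""

    def __init__(self, segment_length: int, context_length: int):
        self.segment_length = segment_length
        self.context_length = context_length
        self.context = [0] * context_length  # zeros before the first chunk

    def __call__(self, chunk: List[int]) -> List[int]:
        # Zero-pad a short chunk up to the segment length
        chunk = chunk + [0] * (self.segment_length - len(chunk))
        chunk_with_context = self.context + chunk
        # Cache the tail for the next call
        self.context = chunk[-self.context_length:]
        return chunk_with_context


cacher_demo = ListContextCacher(segment_length=4, context_length=2)
for demo_chunk in ([1, 2, 3, 4], [5, 6, 7, 8], [9]):
    print(cacher_demo(demo_chunk))
# [0, 0, 1, 2, 3, 4]
# [3, 4, 5, 6, 7, 8]
# [7, 8, 9, 0, 0, 0]
```

Each output begins with the last two samples of the previous chunk, which is exactly the overlap the model expects.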

## 5. Run stream inference

Finally, we run the recognition.

First, we initialize the context cacher, along with the `state` and
`hypothesis` variables that the decoder uses to carry decoding state
over between inference calls.




In [None]:
cacher = ContextCacher(segment_length, context_length)

state, hypothesis = None, None

Next, we run the inference.

We wrap the inference loop in a helper function, `run_inference`, which
consumes the entire source stream, advances the global `audio_time` counter,
and profiles the decoder forward passes (`decoder.infer`) with `torch.profiler`.




In [None]:
from torch.profiler import profile, ProfilerActivity

stream_iterator = streamer.stream()

@torch.inference_mode()
def run_inference():
    global state, hypothesis, audio_time

    # Stream the audio and extract features chunk by chunk (not profiled).
    feats = []
    for (chunk,) in stream_iterator:
        # Count the real chunk length; the final chunk may be shorter than a segment
        audio_time += chunk.size(0) / sample_rate
        segment = cacher(chunk[:, 0])
        feats.append(feature_extractor(segment))

    # Profile only the decoder forward passes. A single profiler spans every
    # chunk, so the returned statistics cover the whole stream rather than
    # only the last chunk.
    with profile(activities=[ProfilerActivity.CPU], record_shapes=True, profile_memory=False, with_flops=True) as prof:
        for features, length in feats:
            hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
            hypothesis = hypos

    return prof
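`torch.profiler` reports aggregate CPU time per operator. If you also want a per-chunk latency distribution, one lightweight option is to time each decoder call with a wall-clock timer and summarize the samples afterwards. The helpers below (`timed_call`, `summarize_latencies`) are hypothetical and framework-agnostic. In this notebook you would pass a lambda wrapping `decoder.infer(...)`; that instrumentation is an assumption about how you might extend the loop, not something the code above does.

```python
import math
import statistics
import time
from typing import Any, Callable, Dict, List, Tuple


def timed_call(fn: Callable[[], Any]) -> Tuple[Any, float]:
    """Run fn() and return (result, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start


def summarize_latencies(latencies: List[float]) -> Dict[str, float]:
    """Mean, median, and nearest-rank 95th percentile of a non-empty list of latencies."""
    ordered = sorted(latencies)
    p95_index = max(0, math.ceil(0.95 * len(ordered)) - 1)
    return {
        "mean": statistics.mean(ordered),
        "p50": statistics.median(ordered),
        "p95": ordered[p95_index],
    }


# Usage with a stand-in workload; in the notebook the lambda would wrap decoder.infer(...)
latencies = [timed_call(lambda: sum(range(10_000)))[1] for _ in range(20)]
print(summarize_latencies(latencies))
```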

Run inference on every audio sample in the mini dataset (the `.mp3` files under `./samples`) and print profiling statistics for each one.

In [None]:
import glob
from torch.profiler import profile, ProfilerActivity

# Find all mp3 files in our samples directory (expected next to this notebook)
sample_paths = sorted(glob.glob("./samples/*.mp3"))
print(f"{len(sample_paths)} samples available for processing")

total_cpu_time = 0.0  # self CPU time summed over all samples, in µs

for path in sample_paths:
    # Initialize per-sample statistics
    audio_time = 0

    # Reset the cache, state, hypothesis
    cacher = ContextCacher(segment_length, context_length)
    state, hypothesis = None, None

    streamer = StreamReader(path)
    streamer.add_basic_audio_stream(frames_per_chunk=segment_length, sample_rate=bundle.sample_rate)
    stream_iterator = streamer.stream()
    profile_stats = run_inference()
    print(f"{path}: {audio_time:.2f} s of audio")
    print(profile_stats.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=20))

    # cpu_time_total includes time spent in child ops, so summing it would
    # double-count nested calls; sum each op's self time instead.
    total_cpu_time += sum(item.self_cpu_time_total for item in profile_stats.key_averages())

print(f"Total CPU time (µs): {total_cpu_time:.2f}")
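A common way to turn such measurements into a single latency figure is the real-time factor (RTF): processing time divided by the duration of the audio processed, where values below 1 mean the model keeps up with real time. The sketch below is a hedged illustration. `real_time_factor` is a hypothetical helper. It treats summed profiler CPU time (in microseconds, as printed above) as the processing time, which is only a proxy for wall-clock latency. It also assumes you accumulate the per-sample `audio_time` values across files, which the cell above does not do.

```python
def real_time_factor(cpu_time_us: float, audio_seconds: float) -> float:
    """Processing time divided by audio duration; < 1.0 means faster than real time."""
    if audio_seconds <= 0:
        raise ValueError("audio_seconds must be positive")
    return (cpu_time_us / 1e6) / audio_seconds


# e.g. 80,000 µs of CPU time spent decoding 0.16 s of audio
print(real_time_factor(80_000, 0.16))  # 0.5
```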
