# A tutorial on obtaining accurate speech-to-text alignment for long audio and noisy text

This tutorial consists of two parts. 
- The first part corresponds to the paper *Less Peaky And More Accurate CTC Forced Alignment by Label Priors* published in ICASSP 2024. We will demonstrate how to obtain more accurate speech-to-text alignment compared to a standard CTC model.
- In the second part, we will provide a robust PyTorch-based speech-to-text alignment library to align long audio and noisy text. For example, we will align the whole book, [Walden by Henry David Thoreau](https://www.gutenberg.org/cache/epub/205/pg205-images.html) (of 115K words), with its [audiobook chapter](https://librivox.org/walden-by-henry-david-thoreau/) (of 30 minutes in this demo, or even longer) in the [LibriVox project](https://librivox.org/).

## Preparation

In [None]:
# import necessary libraries
import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

import IPython
import sys

## Part 1: obtaining more accurate CTC alignment by label priors

### 1.1 Alignment with standard CTC model

### 1.2 Alignment with CTC model with label priors

### 1.3 Fine-tuning CTC model with label priors

## Part 2: obtaining robust alignment for long audio and noisy text

In part 1, we performed forced alignment at the utterance level. In practice, we rarely have a short audio segment (e.g., 10 seconds) paired with its exact, verbatim transcription, as in a laboratory setting (e.g., the [LIBRISPEECH](https://www.openslr.org/12) corpus). Instead, audio comes in long form (e.g., a single mp3 recording containing an hour of speech). The transcription of the whole recording can be noisy and non-verbatim, so it may not exactly match what is spoken in the recording. To use such raw speech data for machine learning, we usually need to prepare a corpus of segmented audio. In some applications, we simply want to align as much of the long audio and text as possible. In this tutorial, we provide a Python library to support such use cases.

Here, we are facing two challenges:
- The audio is long, so it may not be possible to process it as a whole due to, e.g., limited CPU/GPU memory.
- The transcript is noisy. It can be a partial transcript with some missing words. It may have significant errors. It may also contain extra content that is not spoken in the audio (e.g., because the corresponding audio is corrupted). It can be a combination of all these cases. Thus, the conventional, basic forced alignment algorithm can produce very poor alignment results, as it assumes that the audio and text match exactly.

There are a few existing solutions:
- [Kaldi](https://ieeexplore.ieee.org/document/8268956), [Gentle](https://github.com/lowerquality/gentle) and [this work](https://ieeexplore.ieee.org/document/7404861) employ a weighted finite state transducer (WFST) framework to model the noisy texts. 
- [WhisperX](https://github.com/m-bain/whisperX) uses an attention mechanism to propose rough timestamps for uniformly segmented audio. Then, it performs phone-level or word-level forced alignment with an external aligner.
- [MMS](https://arxiv.org/abs/2305.13516) uses a special `<star>` token to handle missing words in the transcript.
- [SailAlign](https://www.semanticscholar.org/paper/SailAlign%3A-Robust-long-speech-text-alignment-Katsamanis-Georgiou/0b7f86429641b188cc62ec32eee590e8795a3d02) iteratively identifies reliable regions and then narrows down to align the remaining unaligned regions.

This tutorial is based on WFST and thus falls into the first category. Our implementation is based on PyTorch. Any CTC model in PyTorch can be combined with our library to become a robust aligner. This distinguishes our aligner from existing ones.

### 2.1 Install dependencies

For WFST operations, our library depends on [k2](https://github.com/k2-fsa/k2/), a PyTorch-based WFST library.

In [None]:
# Check the Python, PyTorch and CUDA versions
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")
!python --version

In [None]:
!pip install k2==1.24.4.dev20240223+cuda11.7.torch2.0.1-cp310 -f https://k2-fsa.github.io/k2/cpu.html
!pip install cmudict g2p_en
!pip install git+https://github.com/huangruizhe/lis.git

In [None]:
# !git clone xxx

sys.path.append("/exp/rhuang/meta/audio_latest/examples/asr/librispeech_alignment/alignment")

from torchaudio_k2_aligner import (
    uniform_segmentation_with_overlap,
)
from tokenizer import EnglishCharTokenizer
from factor_transducer import make_factor_transducer_word_level_index_with_skip
from k2_icefall_utils import (
    get_best_paths,
    get_texts_with_timestamp,
)

We will use a pre-trained Wav2Vec2 model, [torchaudio.pipelines.MMS_FA](https://pytorch.org/audio/main/generated/torchaudio.pipelines.MMS_FA.html#torchaudio.pipelines.MMS_FA), as the acoustic model.

In [None]:
bundle = torchaudio.pipelines.MMS_FA
model = bundle.get_model(with_star=False).to(device)

In [None]:
LABELS = bundle.get_labels(star="*")
DICTIONARY = bundle.get_dict(star="*")

print(LABELS)

tokenizer = EnglishCharTokenizer(
    token2id=DICTIONARY,
    blk_token="-",
    unk_token="*",
)

### 2.2 Prepare long audio and noisy text

We will demonstrate aligning the whole book, [Walden by Henry David Thoreau](https://www.gutenberg.org/cache/epub/205/pg205-images.html) (of 115K words), with its audiobook chapter (of 30 minutes) in the [LibriVox project](https://librivox.org/walden-by-henry-david-thoreau/).

In [None]:
# Download the whole book
import requests
from bs4 import BeautifulSoup

url = "https://www.gutenberg.org/cache/epub/205/pg205-images.html"
response = requests.get(url)
soup = BeautifulSoup(response.text, "html.parser")

text = soup.get_text()
text = text.replace("\r\n", "\n")

In [None]:
# Download a chapter of the audio book
# !wget https://ia800707.us.archive.org/20/items/walden_librivox/walden_c07.mp3

SPEECH_FILE = "walden_c07.mp3"

In [None]:
# Play the long audio
IPython.display.Audio(SPEECH_FILE)

In [None]:
# Preview the transcript
print(text[:1000])

In [None]:
# Preview the transcript relevant to the audio
print(text[271400: 271400+1000])

As we can see above, the audio contains a header "This is a LibriVox recording ..." which is not transcribed. On the other hand, as we have downloaded the whole book, it contains a lot of extra text that's not spoken in the audio. Obviously, the standard forced alignment algorithm will not work in this case.

In [None]:
# We tokenize the text into the labels of the acoustic model's output
text_tokenized = tokenizer.encode(tokenizer.text_normalize(text))

print(f"There are {len(text_tokenized)} words in the text")

# Preview the tokenization results. This corresponds to the beginning of the audio
print(tokenizer.decode(text_tokenized)[49489: 49489+15])
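
To make the tokenization step more concrete, here is a minimal, self-contained sketch of character-level normalization and encoding. It does not use `EnglishCharTokenizer`. The label set `TOY_LABELS` and the helpers `toy_normalize` and `toy_encode` are hypothetical names introduced for illustration, and they only approximate what a character tokenizer for a CTC model might do.

```python
import re
import string

# Hypothetical toy label set (an assumption, not the MMS_FA label order):
# blank, lowercase letters, apostrophe, and a star/unknown token.
TOY_LABELS = ["-"] + list(string.ascii_lowercase) + ["'", "*"]
TOY_TOKEN2ID = {token: i for i, token in enumerate(TOY_LABELS)}


def toy_normalize(text):
    """Lowercase the text and keep only letters, apostrophes and single spaces."""
    text = text.lower().replace("\u2019", "'")  # map curly apostrophes to straight ones
    text = re.sub(r"[^a-z' ]", " ", text)        # everything else becomes a space
    return " ".join(text.split())


def toy_encode(text):
    """Encode the text as a list of words, each word being a list of label ids."""
    return [[TOY_TOKEN2ID[ch] for ch in word] for word in toy_normalize(text).split()]


print(toy_normalize("I went to the Woods, because..."))
print(toy_encode("It's"))
```

Keeping one list of label ids per word, rather than one flat list, makes it possible to map an aligned token back to the index of the word it belongs to.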

### 2.3 Use WFST to represent the text

#### 2.3.1 WFST basics

#### 2.3.2 CTC graph, factor transducer and the variants

Now, we will represent the whole book of over 100K words as a single WFST decoding graph that allows occasional insertion/deletion/substitution errors. We set `skip_penalty=-0.5` and `return_penalty=-18.0`.

In [None]:
decoding_graph, word_index_sym_tab, token_sym_tab = \
    make_factor_transducer_word_level_index_with_skip(
        text_tokenized, 
        blank_penalty=0, 
        skip_penalty=-0.5, 
        return_penalty=-18.0
    )

print(f"There are {decoding_graph.shape[0]} nodes and {decoding_graph.num_arcs} arcs in the decoding graph for the text of {len(text_tokenized)} words.")
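
The idea of a factor transducer with skips can be illustrated without k2. The sketch below works on words instead of frames: a hypothesis may start and end anywhere in the reference (it only has to match a "factor" of it), and it may skip a reference word at a cost. The function `align_to_factor` and the parameter `max_skip` are hypothetical, and the way `skip_penalty` is applied here is our assumption for illustration. It is not the implementation of `make_factor_transducer_word_level_index_with_skip`.

```python
NEG_INF = float("-inf")


def align_to_factor(reference, hypothesis, skip_penalty=-0.5, max_skip=1):
    """Find the best-scoring match of `hypothesis` inside `reference`.

    The match may begin and end at any reference position, and up to `max_skip`
    reference words may be skipped between two consecutive matched words.
    Returns (score, reference_indices), or None if no match exists.
    """
    n, m = len(reference), len(hypothesis)
    if m == 0:
        return 0.0, []
    # score[j][i]: best score with hypothesis[j] matched to reference[i]
    score = [[NEG_INF] * n for _ in range(m)]
    back = [[-1] * n for _ in range(m)]
    for i in range(n):
        if reference[i] == hypothesis[0]:
            score[0][i] = 0.0  # entering the reference anywhere is free
    for j in range(1, m):
        for i in range(n):
            if reference[i] != hypothesis[j]:
                continue
            for skipped in range(max_skip + 1):
                prev = i - 1 - skipped
                if prev < 0:
                    break
                candidate = score[j - 1][prev] + skipped * skip_penalty
                if candidate > score[j][i]:
                    score[j][i] = candidate
                    back[j][i] = prev
    end = max(range(n), key=lambda i: score[m - 1][i])
    if score[m - 1][end] == NEG_INF:
        return None
    path = [end]
    for j in range(m - 1, 0, -1):
        path.append(back[j][path[-1]])
    path.reverse()
    return score[m - 1][end], path


reference = "i went to the woods because i wished to live deliberately".split()
hypothesis = "the woods because wished to live".split()  # one reference word is missing
print(align_to_factor(reference, hypothesis))
```

In the real decoding graph, the same two ingredients (free entry and exit points, plus penalized skip arcs) are combined with the CTC topology, so the search runs over acoustic frames rather than over a word-level hypothesis.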

### Handle long audio

We want to feed the audio into the neural network to get the frame-wise posteriors (i.e., the emission matrix) over the label vocabulary. However, the audiobook chapter is about 30 minutes long. This is too long to feed into the acoustic model all at once.

A common practice is to split the long audio into short overlapping segments. The segments are processed independently, and the results are concatenated appropriately to form the final alignment result. Because the pretrained Wav2Vec2 model that we use takes a raw waveform as input, we will split the original audio into 15-second segments. In practice, we can also segment the feature matrix (e.g., Fbank features) that is fed into the acoustic model.

The library accompanying this tutorial provides functions for appropriate segmentation and concatenation.

Now, let's use the library to segment the 30-minute audio into 15-second segments, with an overlap of 2 seconds between neighboring segments.
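
Before running the segmentation on tensors, it may help to see the bookkeeping in plain Python. The sketch below computes where each segment starts, how many real (unpadded) samples it contains, and how much padding the last segment needs. The helper `plan_segments` is a hypothetical name. It mirrors the arithmetic of overlapping segmentation but does not touch any audio data.

```python
def plan_segments(total_length, segment_size, overlap):
    """Return (offsets, lengths, padding) for overlapping fixed-size segments."""
    step = segment_size - overlap
    assert step > 0, "overlap must be smaller than segment_size"
    if total_length <= segment_size:
        # A single (possibly padded) segment covers the whole input
        return [0], [total_length], segment_size - total_length
    quotient, remainder = divmod(total_length - segment_size, step)
    # One extra segment is needed when the input is not covered exactly
    n_segments = quotient + (1 if remainder == 0 else 2)
    padding = (n_segments - 1) * step + segment_size - total_length
    offsets = [k * step for k in range(n_segments)]
    lengths = [segment_size] * n_segments
    lengths[-1] -= padding  # the last segment holds fewer real samples
    return offsets, lengths, padding


# Toy example: 20 frames, segments of 7 frames, overlap of 2 frames
print(plan_segments(20, 7, 2))

# The setting of this tutorial: 30 minutes at 16 kHz, 15 s segments, 2 s overlap
offsets, lengths, padding = plan_segments(30 * 60 * 16000, 16000 * 15 + 128, 16000 * 2 + 128)
print(len(offsets), lengths[-1], padding)
```

Neighboring segments share `overlap` samples, so each segment boundary is seen twice by the acoustic model, once near the end of a segment and once near the beginning of the next one.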

In [None]:
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
resample_rate = 16000  # this is the sample rate of the Wav2Vec2 model
waveform = torchaudio.functional.resample(waveform, sample_rate, resample_rate)
sample_rate = resample_rate
print(waveform.shape, sample_rate)

In [None]:
import torchaudio_k2_aligner

import importlib
importlib.reload(torchaudio_k2_aligner)

uniform_segmentation_with_overlap = torchaudio_k2_aligner.uniform_segmentation_with_overlap

In [None]:
# torchaudio.load returns a (channels, T) tensor; add a trailing feature
# dimension so that a mono waveform has the (1, T, D) shape expected below
if waveform.dim() == 2:
    waveform.unsqueeze_(-1)
segment_size = 16000 * 15 + 128  # 15 seconds; use extra 128 samples to make sure we have 750 frames for each full-sized segment
overlap = 16000 * 2 + 128        # 2 seconds
shortest_segment_size = 3200     # if the last segment has fewer than 3200 samples (0.2 seconds), it will be discarded

waveform_segmented, segment_lengths, segment_offsets = uniform_segmentation_with_overlap(
    waveform,
    segment_size, 
    overlap, 
    shortest_segment_size=shortest_segment_size
)
waveform_segmented = waveform_segmented.squeeze()
print(waveform_segmented.shape, segment_lengths.shape, segment_offsets.shape)
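
The extra 128 samples in `segment_size` and `overlap` can be checked with a short calculation. The sketch below assumes the standard Wav2Vec2 convolutional feature extractor (seven layers with the kernel sizes and strides listed in `CONV_LAYERS`, giving a total stride of 320 samples). This configuration is an assumption about the acoustic model, and `num_frames` is a hypothetical helper.

```python
# (kernel_size, stride) of each convolutional layer; assumed to be the
# standard Wav2Vec2 feature extractor configuration.
CONV_LAYERS = [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]


def num_frames(num_samples):
    """Number of output frames produced for `num_samples` input samples."""
    length = num_samples
    for kernel_size, stride in CONV_LAYERS:
        length = (length - kernel_size) // stride + 1
    return length


print(num_frames(16000 * 15))        # exactly 15 seconds of audio
print(num_frames(16000 * 15 + 128))  # 15 seconds plus 128 extra samples

# The step between segment starts is a whole number of frames,
# so dividing sample offsets by 320 loses no information.
step = (16000 * 15 + 128) - (16000 * 2 + 128)
print(step, step % 320, step // 320)
```

Under this assumption, exactly 15 seconds of audio would yield one frame fewer than 750, which is why a few extra samples are added. Adding the same 128 samples to the overlap keeps the step at 13 seconds, i.e. a whole number of frames.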

In [None]:
# We can listen to a segment to make sure it is correct
IPython.display.Audio(waveform_segmented[0], rate=sample_rate)

### 2.4 Obtain alignment

In [None]:
if device == torch.device("cpu"):
    batch_size = 2
else:
    batch_size = 32

for i in range(0, waveform_segmented.size(0), batch_size):
    batch_waveform_segmented = waveform_segmented[i: i+batch_size].to(device)
    batch_segment_lengths = segment_lengths[i: i+batch_size]
    batch_segment_offset = segment_offsets[i: i+batch_size] // 320

    with torch.inference_mode():
        # Check out the API of the forward function here: https://github.com/pytorch/audio/blob/main/src/torchaudio/pipelines/_wav2vec2/utils.py#L34
        # `batch_waveform_segmented` was already moved to `device` above
        batch_emissions, output_lengths = model(batch_waveform_segmented, batch_segment_lengths)
    
    # Attach the star dimension manually, see torchaudio issue #3772
    star_dim = torch.empty((batch_emissions.size(0), batch_emissions.size(1), 1), device=batch_emissions.device, dtype=batch_emissions.dtype)
    star_dim[:] = -5.0
    batch_emissions = torch.cat((batch_emissions, star_dim), 2)

    # find best alignment paths using k2 library and the input WFST graph
    best_paths = get_best_paths(batch_emissions, output_lengths, decoding_graph)
    best_paths = best_paths.detach().to('cpu')

    decoding_results = get_texts_with_timestamp(best_paths)
    token_ids_indices = decoding_results["hyps"]  # Note, here the "hyps" are actually indices 
    timestamps = decoding_results["timestamps"]

    # Inference on CPU is too slow for the full recording.
    # Since this is only a demonstration, we stop after the first batch.
    if device == torch.device("cpu"):
        break
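
The timestamps returned for each segment are frame indices relative to the start of that segment. To place them on the timeline of the whole recording, the segment offset has to be added back. The following sketch shows one way this could be done, assuming 320 samples per frame at 16 kHz (20 ms per frame). The helper `to_global_seconds` is hypothetical and is not part of the library.

```python
SAMPLE_RATE = 16000      # samples per second
SAMPLES_PER_FRAME = 320  # assumed stride of the acoustic model


def to_global_seconds(local_frames, segment_offset_samples):
    """Convert frame indices within a segment to seconds in the full recording."""
    offset_frames = segment_offset_samples // SAMPLES_PER_FRAME
    return [
        (offset_frames + frame) * SAMPLES_PER_FRAME / SAMPLE_RATE
        for frame in local_frames
    ]


# The second segment starts 208000 samples (13 seconds) into the recording
print(to_global_seconds([0, 50, 749], segment_offset_samples=208000))
```

This is the role that a per-segment frame offset such as `segment_offsets // 320` would play when the results of all segments are merged.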

In [None]:
tki = 1
token_ids_indices[tki]

In [None]:
"".join([tokenizer.id2token[token_sym_tab[i]] for i in token_ids_indices[tki]])

In [None]:
[word_index_sym_tab[i] for i in token_ids_indices[tki] if i in word_index_sym_tab]

In [None]:
text_splitted = text.split()
[text_splitted[word_index_sym_tab[i]] for i in token_ids_indices[tki] if i in word_index_sym_tab]
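
Within a correctly aligned region, the word indices should increase with time, whereas spurious matches in individual segments tend to break this order. One simple way to filter them is to keep only a longest increasing subsequence (LIS) of the word indices collected across segments. The sketch below is our assumption of how such a filter could look, with the hypothetical helper `longest_increasing_subsequence`. It is not necessarily what the library does internally.

```python
from bisect import bisect_left


def longest_increasing_subsequence(values):
    """Return the positions of one longest strictly increasing subsequence."""
    tails = []      # tails[k]: smallest last value of an increasing run of length k + 1
    tails_pos = []  # position in `values` of tails[k]
    prev = [-1] * len(values)  # back-pointers used to recover the subsequence
    for pos, value in enumerate(values):
        k = bisect_left(tails, value)
        if k == len(tails):
            tails.append(value)
            tails_pos.append(pos)
        else:
            tails[k] = value
            tails_pos[k] = pos
        prev[pos] = tails_pos[k - 1] if k > 0 else -1
    positions = []
    pos = tails_pos[-1] if tails_pos else -1
    while pos != -1:
        positions.append(pos)
        pos = prev[pos]
    return positions[::-1]


# Word indices over time; 120 and 77000 are out-of-order outliers
word_indices = [49489, 49490, 120, 49491, 49492, 77000, 49493, 49494]
kept = longest_increasing_subsequence(word_indices)
print([word_indices[p] for p in kept])
```

The words that survive this filter can serve as anchors, and the regions between them can then be re-aligned or left unaligned, depending on the application.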

# Test

In [None]:
import k2
import torch


def uniform_segmentation_with_overlap(m, segment_size, overlap, shortest_segment_size=0):
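    '''
    Cut the input matrix `m` into overlapping segments of equal size.
    `m` can be, e.g., the feature matrix of the input audio or the emission matrix.
    The last segment is zero-padded to `segment_size` if necessary.

    Args:
        m: 3-D tensor of shape (1, T, D).
        segment_size: an integer, the size of each segment.
        overlap: an integer, the number of frames to overlap between segments.
        shortest_segment_size: an integer; the last segment is discarded if it
            has fewer than this many unpadded frames.

    Returns:
        A tuple (m_segmented, segment_lengths, segment_offsets), where
        `m_segmented` has shape (L, segment_size, D), and `segment_lengths`
        and `segment_offsets` are 1-D tensors of length L.
    '''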
    assert m.ndim == 3
    assert m.size(0) == 1
    # m = m.unsqueeze(0)  # (1, T, D)

    step = segment_size - overlap
    if (m.size(1) - segment_size) % step == 0:
        n_segments = (m.size(1) - segment_size) // step + 1
        padding_size = 0
    else:
        n_segments = (m.size(1) - segment_size) // step + 2
        padding_size = (n_segments - 1) * step + segment_size - m.size(1)
    
    m_padded = torch.nn.functional.pad(m, (0, 0, 0, padding_size))  # Pad the tensor with zeros
    m_segmented = m_padded.unfold(dimension=1, size=segment_size, step=step)
    m_segmented = m_segmented.permute(0, 1, 3, 2)  # `m_segmented` is of shape (1, L, segment_size, D), where L is the number of segments.
    m_segmented = m_segmented[0]  # (L, segment_size, D)
    
    segment_lengths = [segment_size] * m_segmented.size(0)
    segment_lengths[-1] -= padding_size
    segment_lengths = torch.tensor(segment_lengths)

    segment_offsets = torch.arange(0, segment_lengths.size(0) * step, step)

    assert len(segment_lengths) == n_segments
    assert len(segment_lengths) == m_segmented.size(0)
    assert len(segment_lengths) == len(segment_offsets)

    # Discard the last chunk if it is too short
    if segment_lengths[-1] < shortest_segment_size:
        m_segmented = m_segmented[:-1]
        segment_lengths = segment_lengths[:-1]
        segment_offsets = segment_offsets[:-1]
    
    return m_segmented, segment_lengths, segment_offsets

In [None]:
m = torch.rand(1, 20, 3)
m

In [None]:
uniform_segmentation_with_overlap(m, 7, 2, shortest_segment_size=0)

In [None]:
m = torch.rand(1, 7, 3)
segment_size = 7
overlap = 2
step = segment_size - overlap
if (m.size(1) - segment_size) % step == 0:
    n_segments = (m.size(1) - segment_size) // step + 1
    padding_size = 0
else:
    n_segments = (m.size(1) - segment_size) // step + 2
    padding_size = (n_segments - 1) * step + segment_size - m.size(1)

print(n_segments, padding_size)