# A tutorial on obtaining accurate speech-to-text alignment for long audio and noisy text

This tutorial consists of two parts. 
- The first part corresponds to the paper *Less Peaky And More Accurate CTC Forced Alignment by Label Priors*, published at ICASSP 2024. We demonstrate how to obtain more accurate speech-to-text alignment than with a standard CTC model.
- In the second part, we provide a robust PyTorch-based speech-to-text alignment library for aligning long audio with noisy text. For example, we align the whole book [Walden by Henry David Thoreau](https://www.gutenberg.org/cache/epub/205/pg205-images.html) (115K words) with its [audiobook chapter](https://librivox.org/walden-by-henry-david-thoreau/) (30 minutes in this demo; longer recordings also work) from the [LibriVox project](https://librivox.org/).

## Preparation

In [None]:
# import necessary libraries
import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

import IPython
import sys
from tqdm import tqdm


In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

## Part 1: obtaining more accurate CTC alignment by label priors

### 1.1 Alignment with standard CTC model

### 1.2 Alignment with CTC model with label priors

### 1.3 Fine-tuning CTC model with label priors

In [None]:
# To be added.

## Part 2: obtaining robust alignment for long audio and noisy text

In part 1, we performed forced alignment at the utterance level. In practice, we rarely have a short audio segment (e.g., 10 seconds) paired with an exact, verbatim transcription, as in a laboratory setting such as the [LIBRISPEECH](https://www.openslr.org/12) corpus. Instead, the audio comes in long form (e.g., a one-hour mp3 recording of speech). The transcription of the whole recording can be noisy and non-verbatim, i.e., it may not exactly match what is spoken in the recording. To use such raw speech data for machine learning, we usually need to prepare a corpus of segmented audio. In some applications, we also want to align as much of the long audio and text as possible. In this tutorial, we provide a Python library to support such use cases.

Here, we are facing two challenges:
- **The audio is long**, so it may not be possible to process it as a whole, e.g., due to limited CPU/GPU memory.
- **The transcript is noisy**. It can be a partial transcript with missing words. It may contain significant errors. It may also contain extra content that is not spoken in the audio (e.g., because the corresponding audio is corrupted). It can be any combination of these. The conventional forced alignment algorithm assumes that the audio and text match exactly, so it can produce very poor alignments in this setting.

There are a few existing solutions:
- [Kaldi](https://ieeexplore.ieee.org/document/8268956), [Gentle](https://github.com/lowerquality/gentle) and [this work](https://ieeexplore.ieee.org/document/7404861) employ a weighted finite state transducer (WFST) framework to model the noisy texts. 
- [WhisperX](https://github.com/m-bain/whisperX) uses an attention mechanism to propose rough timestamps for uniformly segmented audio. Then, it performs phone-level or word-level forced alignment with an external aligner.
- [MMS](https://arxiv.org/abs/2305.13516) uses a special `<star>` token to handle missing words in the transcript.
- [SailAlign](https://www.semanticscholar.org/paper/SailAlign%3A-Robust-long-speech-text-alignment-Katsamanis-Georgiou/0b7f86429641b188cc62ec32eee590e8795a3d02) iteratively identifies reliable regions and then narrows down to align the remaining unaligned regions.

This tutorial is based on WFST and thus falls into the first category. Our implementation is based on PyTorch, so any CTC model in PyTorch can be equipped with our library to become a robust aligner. This distinguishes our aligner from existing ones.

### 2.1 Install dependencies

For WFST operations, our library depends on [k2](https://github.com/k2-fsa/k2/), a PyTorch-based WFST library.

In [None]:
# Check python and pytorch's version
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")
!python --version

In [None]:
# k2 (cpu)
# pip install k2==1.24.4.dev20240223+cpu.torch2.2.1 -f https://k2-fsa.github.io/k2/cpu.html
# k2 (gpu)
!pip install k2==1.24.4.dev20240301+cuda12.1.torch2.2.1 -f https://k2-fsa.github.io/k2/cuda.html
!pip install cmudict g2p_en
!pip install git+https://github.com/huangruizhe/lis.git

In [None]:
# !git clone xxx

# Adjust this path to point to your local copy of the alignment code
sys.path.append("/exp/rhuang/meta/audio_latest/examples/asr/librispeech_alignment/alignment")

from torchaudio_k2_aligner import (
    uniform_segmentation_with_overlap,
    align_segments,
    concat_alignments,
    get_final_word_alignment,
    align,
)
from tokenizer import EnglishCharTokenizer
from factor_transducer import make_factor_transducer_word_level_index_with_skip
from k2_icefall_utils import (
    get_best_paths,
    get_texts_with_timestamp,
)

We will use a pre-trained Wav2Vec2 model, [torchaudio.pipelines.MMS_FA](https://pytorch.org/audio/main/generated/torchaudio.pipelines.MMS_FA.html#torchaudio.pipelines.MMS_FA), as the acoustic model.

In [None]:
bundle = torchaudio.pipelines.MMS_FA
model = bundle.get_model(with_star=False).to(device)

In [None]:
LABELS = bundle.get_labels(star="*")
DICTIONARY = bundle.get_dict(star="*")

print(LABELS)

tokenizer = EnglishCharTokenizer(
    token2id=DICTIONARY,
    blk_token="-",
    unk_token="*",
)

### 2.2 Prepare long audio and noisy text

We will demonstrate aligning the whole book, [Walden by Henry David Thoreau](https://www.gutenberg.org/cache/epub/205/pg205-images.html) (of 115K words), with its audiobook chapter (of 30 minutes) in the [LibriVox project](https://librivox.org/walden-by-henry-david-thoreau/).

In [None]:
# Download the whole book
import requests
from bs4 import BeautifulSoup

url = "https://www.gutenberg.org/cache/epub/205/pg205-images.html"
response = requests.get(url)
soup = BeautifulSoup(response.text, "html.parser")

text = soup.get_text()
text = text.replace("\r\n", "\n")

In [None]:
# Download a chapter of the audio book
# !wget https://ia800707.us.archive.org/20/items/walden_librivox/walden_c07.mp3

SPEECH_FILE = "walden_c07.mp3"

In [None]:
# Play the long audio
IPython.display.Audio(SPEECH_FILE)

In [None]:
# Preview the transcript
print(text[:1000])

In [None]:
# Preview the transcript relevant to the audio

# In this running example, the whole audio corresponds to 
# text[271616: 293529]
# or
# text.split()[49489: 53362]

print(text[271400: 271400+1000])

As we can see above, the audio begins with a header ("This is a LibriVox recording ...") that is not in the transcript. On the other hand, because we downloaded the whole book, the text contains a lot of extra material that is not spoken in the audio. The standard forced alignment algorithm clearly cannot handle this case.

In [None]:
# We tokenize the text into the labels of the acoustic model's output
text_tokenized = tokenizer.encode(tokenizer.text_normalize(text))

print(f"There are {len(text_tokenized)} words in the text")

# Preview the tokenization results. This corresponds to the beginning of the audio
print(tokenizer.decode(text_tokenized)[49489: 49489+15])

### 2.3 Use WFST to represent the text

#### 2.3.1 WFST basics

#### 2.3.2 CTC graph, factor transducer and the variants

Now, we represent the whole book, which has over 100K words, as one single WFST decoding graph that allows occasional insertion/deletion/substitution errors. We set `skip_penalty=-0.5` and `return_penalty=-18.0`.
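To build intuition for what the skip arcs buy us, here is a toy, plain-Python sketch. It scores a sequence of aligned word indices under a *simplified* word-level factor automaton. It relies on assumptions of ours, which are not necessarily the exact topology built by `make_factor_transducer_word_level_index_with_skip`:

- A path may enter the text at any word for free.
- Each skipped word costs `skip_penalty`.
- Whatever `return_penalty` governs in the real graph is ignored, so any backward jump is simply rejected.

The function name `toy_factor_path_score` is hypothetical.

```python
import math
from typing import List


def toy_factor_path_score(word_indices: List[int], num_words: int,
                          skip_penalty: float = -0.5) -> float:
    """Score a word-index path under a simplified factor automaton (toy model).

    Assumptions (hypothetical, for illustration only):
      * entering the text at any word index is free (factor automaton);
      * every word skipped between two consecutive aligned words costs `skip_penalty`;
      * backward jumps are not modeled, so they are rejected with -inf.
    """
    for i in word_indices:
        if not 0 <= i < num_words:
            raise ValueError(f"word index {i} is outside [0, {num_words})")
    score = 0.0
    for prev, cur in zip(word_indices, word_indices[1:]):
        gap = cur - prev - 1  # number of words skipped between prev and cur
        if gap < 0:
            return -math.inf  # going backwards (or repeating a word) is not allowed here
        score += gap * skip_penalty
    return score


# A contiguous run of words anywhere in the text costs nothing ...
print(toy_factor_path_score([3, 4, 5], num_words=10))  # 0.0
# ... skipping one word (index 4) costs one skip_penalty ...
print(toy_factor_path_score([3, 5, 6], num_words=10))  # -0.5
# ... and jumping backwards is rejected in this simplified model.
print(toy_factor_path_score([5, 3], num_words=10))     # -inf
```

Charging the penalty per skipped word means that a path which follows the text closely is preferred over one that jumps far ahead.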

In [None]:
decoding_graph, word_index_sym_tab, token_sym_tab = \
    make_factor_transducer_word_level_index_with_skip(
        text_tokenized, 
        blank_penalty=0, 
        skip_penalty=-0.5, 
        return_penalty=-18.0
    )
decoding_graph = decoding_graph.to(device)

print(f"There are {decoding_graph.shape[0]} nodes and {decoding_graph.num_arcs} arcs in the decoding graph for the text of {len(text_tokenized)} words.")
print(f"The decoding graph is on device: {decoding_graph.device}")

### Handle long audio

We want to feed the audio into the neural network to get the frame-wise posteriors (i.e., the emission matrix) over the label vocabulary. However, the audiobook chapter is about 30 minutes long, which is too long to feed into the acoustic model all at once.

A common practice is to split the long audio into short, overlapping segments. The segments are processed independently, and the results are concatenated appropriately to form the final alignment. Since the pretrained Wav2Vec2 model that we use takes the raw waveform as input, we segment the original audio into 15-second segments. Alternatively, one can segment the feature matrix (e.g., filterbank features) that is fed into the acoustic model.

Our library provides functions for the appropriate segmentation and concatenation.

Now, let's use the library to segment the 30-minute audio into 15-second segments, with an overlap of 2 seconds between neighboring segments.
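The segmentation arithmetic can be sketched in plain Python. This hypothetical helper, `overlapping_segment_bounds`, follows the same step and padding logic as the `uniform_segmentation_with_overlap` function defined in this notebook's test cells. Consecutive segments start `segment_size - overlap` samples apart, and the last segment is cut at the end of the signal. It does not model the `shortest_segment_size` rule that drops a very short final segment.

```python
from typing import List, Tuple


def overlapping_segment_bounds(num_samples: int, segment_size: int,
                               overlap: int) -> List[Tuple[int, int]]:
    """Return [start, end) sample ranges of uniformly spaced, overlapping segments.

    Consecutive segments start `segment_size - overlap` samples apart; the last
    segment is truncated at the end of the signal instead of being zero-padded.
    """
    step = segment_size - overlap
    assert step > 0, "overlap must be smaller than segment_size"
    if num_samples <= segment_size:
        return [(0, num_samples)]
    # Number of segments needed so that the last one reaches the end of the signal.
    n_segments = -(-(num_samples - segment_size) // step) + 1  # ceiling division
    return [(k * step, min(k * step + segment_size, num_samples))
            for k in range(n_segments)]


# Toy example: 20 samples, segments of 7 with an overlap of 2 (i.e., a step of 5).
print(overlapping_segment_bounds(20, 7, 2))
# [(0, 7), (5, 12), (10, 17), (15, 20)]
```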

In [None]:
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
waveform = waveform.mean(dim=0, keepdim=True)  # mix down to mono: (1, num_samples)
resample_rate = 16000  # this is the sample rate of the Wav2Vec2 model
waveform = torchaudio.functional.resample(waveform, sample_rate, resample_rate)
sample_rate = resample_rate
print(waveform.shape, sample_rate)

In [None]:
if waveform.dim() == 2:
    waveform.unsqueeze_(-1)
segment_size = sample_rate * 15 + 128  # 15 seconds; use extra 128 waveform samples to make sure we have 750 frames for each full-sized segment
overlap = sample_rate * 2 + 128        # 2 seconds
shortest_segment_size = sample_rate * 0.2  # if the last segment has less than 3200 samples (0.2 seconds), it will be discarded

waveform_segmented, segment_lengths, segment_offsets = uniform_segmentation_with_overlap(
    waveform,
    segment_size, 
    overlap, 
    shortest_segment_size=shortest_segment_size
)
waveform_segmented = waveform_segmented.squeeze(-1)  # (num_segments, segment_size); keeps the batch dim even for one segment
print(waveform_segmented.shape, segment_lengths.shape, segment_offsets.shape)
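
The `+ 128` in `segment_size` is easier to understand with the frame-count arithmetic. The sketch below makes two assumptions. First, it uses the standard wav2vec2 convolutional feature encoder: kernel sizes 10, 3, 3, 3, 3, 2, 2 and strides 5, 2, 2, 2, 2, 2, 2, for a total stride of 320 samples, i.e., 20 ms at 16 kHz. Second, it assumes no padding. We believe `MMS_FA` uses this encoder, but treat that as an assumption. `num_output_frames` is a hypothetical helper.

```python
# Assumed wav2vec2 conv feature encoder: (kernel_size, stride) per layer.
CONV_LAYERS = [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]


def num_output_frames(num_samples: int) -> int:
    """Number of frames produced by an unpadded stack of 1-D convolutions."""
    length = num_samples
    for kernel, stride in CONV_LAYERS:
        length = (length - kernel) // stride + 1
    return length


sr = 16000
print(num_output_frames(sr * 15))        # 749: one frame short of 15 s / 20 ms
print(num_output_frames(sr * 15 + 128))  # 750
# The extra 128 samples cancel out in the step between segment starts:
seg_step = (sr * 15 + 128) - (sr * 2 + 128)
print(seg_step, seg_step // 320)  # 208000 samples = 650 frames of 20 ms
```

Because the step is an exact multiple of 320 samples, every segment offset maps to a whole number of output frames.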

We can listen to a segment to make sure the segmentation pipeline runs correctly.

In [None]:
IPython.display.Audio(waveform_segmented[0], rate=sample_rate)

### 2.4 Obtain alignment

Now, we feed the short segments to the neural-network-based acoustic model in batches. For each segment, the acoustic model produces frame-wise classification scores, `batch_emissions`, over its label vocabulary. These are combined with the WFST `decoding_graph`, which is shared by all segments, to produce the best alignment path for each segment. If the `decoding_graph` is a linear WFST, meaning the transcript contains no insertion/deletion/substitution errors, this is equivalent to conventional forced alignment as provided by this [TorchAudio API](https://pytorch.org/audio/main/tutorials/ctc_forced_alignment_api_tutorial.html).
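The linear-graph special case is classic CTC Viterbi forced alignment. To make it concrete, here is a minimal plain-Python sketch of our own; it is not TorchAudio's or k2's implementation. It expands the transcript with blanks, runs Viterbi over the frames, and backtracks to one label per frame. The inputs are toy nested lists of log-probabilities rather than real emissions, and the helper names `ctc_forced_align` and `toy_frame` are hypothetical.

```python
import math
from typing import List, Optional


def ctc_forced_align(log_probs: List[List[float]], targets: List[int],
                     blank: int = 0) -> Optional[List[int]]:
    """Viterbi CTC forced alignment over a linear (exact-transcript) graph.

    Args:
        log_probs: T rows, each holding C log-probabilities for one frame.
        targets: the transcript as label ids, without blanks.
        blank: id of the CTC blank label.
    Returns:
        The best frame-level label sequence (length T), or None if the
        transcript cannot fit into T frames.
    """
    T = len(log_probs)
    if T == 0:
        return None
    # Blank-augmented targets: blank, y1, blank, y2, ..., yL, blank
    ext = [blank]
    for y in targets:
        ext += [y, blank]
    S = len(ext)
    NEG_INF = -math.inf
    score = [[NEG_INF] * S for _ in range(T)]
    back = [[0] * S for _ in range(T)]
    # A path may start with a blank or with the first target label.
    score[0][0] = log_probs[0][ext[0]]
    if S > 1:
        score[0][1] = log_probs[0][ext[1]]
    for t in range(1, T):
        prev = score[t - 1]
        for s in range(S):
            # Allowed predecessors: stay, advance by one, or skip a blank
            # (only between two different non-blank labels).
            candidates = [s]
            if s >= 1:
                candidates.append(s - 1)
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                candidates.append(s - 2)
            best = max(candidates, key=lambda p: prev[p])
            if prev[best] == NEG_INF:
                continue  # state s is unreachable at frame t
            score[t][s] = prev[best] + log_probs[t][ext[s]]
            back[t][s] = best
    # A path must end in the last label or in the trailing blank.
    finals = [S - 1] if S == 1 else [S - 1, S - 2]
    s = max(finals, key=lambda f: score[T - 1][f])
    if score[T - 1][s] == NEG_INF:
        return None
    path = []
    for t in range(T - 1, -1, -1):
        path.append(ext[s])
        s = back[t][s]
    return path[::-1]


def toy_frame(best: int, num_classes: int = 3) -> List[float]:
    """A toy frame that puts probability 0.8 on `best` and 0.1 on the others."""
    return [math.log(0.8) if c == best else math.log(0.1) for c in range(num_classes)]


emissions = [toy_frame(b) for b in [1, 1, 0, 2, 2]]
print(ctc_forced_align(emissions, [1, 2]))           # [1, 1, 0, 2, 2]
print(ctc_forced_align([toy_frame(1)] * 3, [1, 1]))  # [1, 0, 1]: repeats need a blank
```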

There is something to note about the `decoding_graph`. When we combine it with the neural network output, its input labels must match the network's label vocabulary. Its output labels, however, are our design choice, and they answer the question: what do we align the audio to? Here are two design choices:

- The output labels are word labels or phoneme labels, as in most WFST-based alignment solutions (e.g., in [Gentle](https://github.com/lowerquality/gentle/blob/master/gentle/diff_align.py#L16)). In this case, the alignment paths consist of the word-level or phoneme-level transcripts predicted by the model.

- The output labels are the **word indices** in the transcript, instead of word labels. This has two benefits: (1) we can easily obtain the word labels from the word indices; (2) word indices preserve the ordering of words in the transcript, even if the alignment paths contain words in an order different from the original transcript. As we will see later, this word-ordering information enables efficient and effective post-processing heuristics.
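Both benefits of word-index output labels can be illustrated in a few lines. The helper names and the example indices below are hypothetical and are not part of the library. The sketch looks words up by index and flags positions where the aligned index fails to increase, i.e., where a path departs from the transcript's word order.

```python
from typing import List


def words_from_indices(word_indices: List[int], words: List[str]) -> List[str]:
    """Benefit (1): word labels are a direct lookup from word indices."""
    return [words[i] for i in word_indices]


def out_of_order_positions(word_indices: List[int]) -> List[int]:
    """Benefit (2): positions where the aligned word index does not increase,
    i.e., where the alignment path departs from the transcript's word order."""
    return [k for k in range(1, len(word_indices))
            if word_indices[k] <= word_indices[k - 1]]


words = "when I wrote the following pages or rather the bulk of them".split()
hyp = [2, 3, 4, 9, 5, 6]  # hypothetical aligned word indices for one segment
print(words_from_indices(hyp, words))  # ['wrote', 'the', 'following', 'bulk', 'pages', 'or']
print(out_of_order_positions(hyp))     # [4]: index 5 comes right after 9
```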

In [None]:
if device == torch.device("cpu"):
    batch_size = 2
else:
    batch_size = 32

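frame_shift = int(sample_rate * 0.02)  # one output frame every 20 ms (320 samples at 16 kHz)
output_frames_offset = segment_offsets // frame_shift  # integer frame offset of each segment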

alignment_results = list()
for i in tqdm(range(0, waveform_segmented.size(0), batch_size)):
    batch_waveform_segmented = waveform_segmented[i: i+batch_size].to(device)
    batch_segment_lengths = segment_lengths[i: i+batch_size]
    batch_output_frames_offset = output_frames_offset[i: i+batch_size]

    with torch.inference_mode():
        # Check out the API of the forward function here: https://github.com/pytorch/audio/blob/main/src/torchaudio/pipelines/_wav2vec2/utils.py#L34
        batch_emissions, batch_emissions_lengths = model(batch_waveform_segmented.to(device), batch_segment_lengths.to(device))
    
    # Attach the star dimension manually, see torchaudio issue #3772
    star_dim = torch.empty((batch_emissions.size(0), batch_emissions.size(1), 1), device=batch_emissions.device, dtype=batch_emissions.dtype)
    star_dim[:] = -5.0
    batch_emissions = torch.cat((batch_emissions, star_dim), 2)

    # `token_ids` and `timestamps` will each be a list of lists.
    # Each sublist corresponds to a segment in the batch.
    batch_results = align_segments(
        batch_emissions,
        decoding_graph,
        batch_emissions_lengths,
    )

    # The interpretation of `token.token_id` depends on the decoding graph.
    # Here, in this tutorial, `token.token_id` is the key to the `word_index_sym_tab`
    # and `token_sym_tab` dictionaries.
    for aligned_tokens, offset in zip(batch_results, batch_output_frames_offset):
        for token in aligned_tokens:
            token.timestamp += int(offset)  # This becomes the absolute frame index in the whole audio
            if token.token_id == tokenizer.blk_id:
                continue
            if token.token_id in word_index_sym_tab:
                token.attr["wid"] = word_index_sym_tab[token.token_id]
            if token.token_id in token_sym_tab:
                token.attr["tk"] = token_sym_tab[token.token_id]            

    alignment_results.extend(batch_results)

    # Inference on CPU is too slow. Since this is only a demonstration,
    # we break out of the loop after the first batch on CPU.
    # We can still inspect some partial alignment results.
    if device == torch.device("cpu"):
        break

On a GPU, we now have the alignment results for all 140 15-second segments (on CPU, only for the first batch).

In [None]:
len(alignment_results)

Now we concatenate the segment-level alignment results. Two things need care: (1) mis-aligned results and (2) the overlapping parts of neighboring segments. A common algorithm for this concatenation step is based on [Levenshtein distance](https://en.wikipedia.org/wiki/Levenshtein_distance) (e.g., in [Gentle](https://github.com/lowerquality/gentle/blob/master/gentle/diff_align.py#L16)), which aligns the forced-aligned transcript with the ground-truth transcript. Since we have word indices instead of word labels, we proceed differently: given the alignment results of all segments, we find the [longest increasing subsequence (LIS)](https://en.wikipedia.org/wiki/Longest_increasing_subsequence) of the word indices. This takes $O(N \log N)$ time and $O(N)$ space, compared with $O(N^2)$ for Levenshtein distance, where $N$ is the total length of the segment-wise alignment results. The difference matters when $N$ is large, as it is for long audio.
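Below is a minimal sketch of the textbook $O(N \log N)$ LIS algorithm. It uses binary search over the smallest tail of each run length, plus back-pointers for reconstruction. It is our own illustration, not the implementation in the `lis` package installed above, and `longest_increasing_subsequence` is a hypothetical name. We choose a *strictly* increasing subsequence, so a word index repeated in the overlap of two segments is kept only once.

```python
from bisect import bisect_left
from typing import List


def longest_increasing_subsequence(seq: List[int]) -> List[int]:
    """Return one longest strictly increasing subsequence of `seq` in O(N log N) time."""
    tail_idx: List[int] = []  # tail_idx[k]: index in seq of the smallest tail of a run of length k+1
    tail_val: List[int] = []  # tail_val[k] == seq[tail_idx[k]], kept sorted for binary search
    prev = [-1] * len(seq)    # back-pointers used to reconstruct the subsequence
    for i, x in enumerate(seq):
        k = bisect_left(tail_val, x)  # bisect_left => strictly increasing (equal values never chain)
        if k > 0:
            prev[i] = tail_idx[k - 1]
        if k == len(tail_val):
            tail_idx.append(i)
            tail_val.append(x)
        else:
            tail_idx[k] = i
            tail_val[k] = x
    out = []
    i = tail_idx[-1] if tail_idx else -1
    while i != -1:
        out.append(seq[i])
        i = prev[i]
    return out[::-1]


# Toy word indices from a few overlapping segments: overlaps repeat some words,
# and one outlier (index 900) was aligned to a wrong place in the book.
toy_word_indices = [10, 11, 12, 13, 12, 13, 14, 900, 15, 16, 16, 17]
print(longest_increasing_subsequence(toy_word_indices))  # [10, 11, 12, 13, 14, 15, 16, 17]
```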

Moreover, word indices allow us to design heuristics that remove outliers from the alignment results. All of this is implemented in our alignment library and can be invoked by calling the `concat_alignments` function below.

In [None]:
# TODO:
# The beginning and end of the audio are hard to align to the book.
# We might need VAD or similar to handle them, or predefine the start/end of the audio and text.
# Ignored for now, since we only use the alignment for ASR training.

# `resolved_alignment_results` is a list of `AlignedToken`
# `unaligned_text_indices` is a list of (start_word_index, end_word_index)
#    which corresponds to "holes" in the long text that are not aligned

resolved_alignment_results, unaligned_text_indices = concat_alignments(
    alignment_results, 
    neighborhood_size=5,
)

len(resolved_alignment_results), len(unaligned_text_indices)

Our final word-level alignment results can be obtained by the following, where `word_alignment` is a dictionary from word indices in `text` to an `AlignedWord` object.

In [None]:
word_alignment = get_final_word_alignment(resolved_alignment_results, text, tokenizer)

Finally, let's preview the alignment results:

In [None]:
list(word_alignment.items())[:10]

In [None]:
list(word_alignment.items())[-10:]

As we can see above, the alignment results look pretty good! Despite the book being long and noisy, the aligner successfully locates [the chapter](https://www.gutenberg.org/cache/epub/205/pg205-images.html#chap08) in the whole book that corresponds to the audio. The only exception is the first two words ("The Bean-Field"), which form the chapter title.

In `unaligned_text_indices`, we can see the parts in the book chapter that are not successfully aligned to the audio:
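As a hypothetical illustration of how such holes can be derived, the sketch below takes the word indices kept after the LIS step. It returns the gaps between consecutive aligned words as half-open `[s, e)` ranges. This is not necessarily how `concat_alignments` computes `unaligned_text_indices`, and `find_holes` is our own name.

```python
from typing import List, Tuple


def find_holes(aligned_word_indices: List[int]) -> List[Tuple[int, int]]:
    """Return half-open [s, e) ranges of word indices that lie between two
    aligned words but were not aligned themselves."""
    idx = sorted(set(aligned_word_indices))
    return [(a + 1, b) for a, b in zip(idx, idx[1:]) if b - a > 1]


print(find_holes([2, 3, 4, 7, 8, 12]))  # [(5, 7), (9, 12)]
```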

In [None]:
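# Each (s, e) pair is a half-open range [s, e) of word indices in the long text that is not aligned
text_splitted = text.split()
for s, e in unaligned_text_indices:
    if e - s > 0:
        print(f"[{s}, {e}) ({e - s} words): {' '.join(text_splitted[s:e])}")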

In [None]:
text_splitted = text.split()
[text_splitted[tk.attr['wid']] if 'wid' in tk.attr else tokenizer.id2token[tk.attr['tk']] for tk in resolved_alignment_results[-40:-1]]


In [None]:
resolved_alignment_results[0]

In [None]:
audacity_labels_str = to_audacity_label_format(alignment_results, 0.02, text)

In [None]:
audacity_path = str(SPEECH_FILE)[:-4] + "_audacity.txt"
with open(audacity_path, "w") as fout:
    print(audacity_labels_str, file=fout)
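
The cell above relies on a `to_audacity_label_format` helper that is not defined or imported in this notebook. Here is a hypothetical stand-in, under the assumption that the words have already been reduced to `(word, start_frame, end_frame)` triples. It writes Audacity's plain-text label format: one `start<TAB>end<TAB>label` line per word, with times in seconds, using the 20 ms frame shift of the acoustic model.

```python
from typing import List, Tuple


def format_audacity_labels(words: List[Tuple[str, int, int]],
                           frame_shift: float = 0.02) -> str:
    """Render (word, start_frame, end_frame) triples as Audacity label lines.

    Each line is "start<TAB>end<TAB>label", with times in seconds.
    """
    lines = []
    for word, start_frame, end_frame in words:
        start, end = start_frame * frame_shift, end_frame * frame_shift
        lines.append(f"{start:.3f}\t{end:.3f}\t{word}")
    return "\n".join(lines)


print(format_audacity_labels([("Walden", 0, 25), ("Pond", 30, 52)]))
```

The resulting file can be imported into Audacity as a label track on top of the audio, which makes it easy to spot-check word boundaries by ear.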

In [None]:
_alignment_results = alignment_results

In [None]:
print([f"{x:5}" for x in results_ids[1]])
print([f"{x:5}" for x in results_timestamps[1]])

In [None]:
output_frames_offset

In [None]:
alignment_results

In [None]:
import lis
from torchaudio_k2_aligner import *

hyps = [[token.attr["wid"] for token in aligned_tokens if "wid" in token.attr] for aligned_tokens in alignment_results]
timestamps = [[token.timestamp for token in aligned_tokens if "wid" in token.attr] for aligned_tokens in alignment_results]
num_segments_per_chunk=5
neighbor_threshold=5
device=device

# Find the longest increasing subsequence (LIS)
hyp_list = [i for hyp in hyps for i in hyp]
lis_results = lis.longestIncreasingSubsequence(hyp_list)

# Post-process1: remove outliers from the LIS results
lis_results = remove_outliers(lis_results, scan_range=100, outlier_threshold=60)
print(lis_results)
print(len(lis_results))

# Post-process2: remove isolated aligned words
# Each aligned word should have a neighborhood of at least neighbor_threshold words
rg_min = min(lis_results)
rg_max = max(lis_results)
set_lis_results = set(lis_results)
for i in range(rg_min, rg_max + 1):
    if i in set_lis_results:
        left_neighbors_in_lis = [j for j in range(i-neighbor_threshold, i) if j in set_lis_results]
        right_neighbors_in_lis = [j for j in range(i+1, i+neighbor_threshold+1) if j in set_lis_results]
        num_left_neighbors = i - max(i-neighbor_threshold, rg_min)
        num_right_neighbors = min(i+neighbor_threshold, rg_max) - i
        # remove the word if fewer than 40% of the words on each side of its neighborhood are aligned
        if len(left_neighbors_in_lis) < 0.4 * num_left_neighbors and \
            len(right_neighbors_in_lis) < 0.4 * num_right_neighbors:
            set_lis_results.remove(i)
lis_results = [i for i in lis_results if i in set_lis_results]
print(lis_results)
print(len(lis_results))

# Align LIS results with the original `alignment_results`
alignment_results = get_lis_alignment(lis_results, alignment_results)  # hyp_list and lis_result are both word indices in the long text

In [None]:
# Keep only the aligned tokens which are in LIS
resolved_alignment_results = list()
for aligned_tokens in alignment_results:
    word_start_flag = False
    for token in aligned_tokens:
        if token.attr.get("lis", False):
            resolved_alignment_results.append(token)
            word_start_flag = True
        elif "wid" in token.attr:
            # assert "lis" not in token.attr
            word_start_flag = False
        elif word_start_flag:
            resolved_alignment_results.append(token)

In [None]:
len(resolved_alignment_results)

In [None]:
len(text.split())

In [None]:
import torchaudio_k2_aligner

import importlib
importlib.reload(torchaudio_k2_aligner)

concat_alignments = torchaudio_k2_aligner.concat_alignments

In [None]:
tki = 1
token_ids_indices[tki]

In [None]:
"".join([tokenizer.id2token[token_sym_tab[i]] for i in token_ids_indices[tki]])

In [None]:
[word_index_sym_tab[i] for i in token_ids_indices[tki] if i in word_index_sym_tab]

In [None]:
text_splitted = text.split()
[text_splitted[word_index_sym_tab[i]] for i in token_ids_indices[tki] if i in word_index_sym_tab]

# Test

In [None]:
import k2
import torch


def uniform_segmentation_with_overlap(m, segment_size, overlap, shortest_segment_size=0):
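    '''
    This function cuts the input matrix `m` into overlapping segments.
    `m` can be, e.g., the feature matrix of the input audio or the emission matrix.

    Args:
        m: 3-D tensor of shape (1, T, D).
        segment_size: an integer, the size of each segment.
        overlap: an integer, the number of frames to overlap between segments.
        shortest_segment_size: the last segment is discarded if it is shorter than this.

    Returns:
        A tuple (m_segmented, segment_lengths, segment_offsets), where `m_segmented`
        has shape (L, segment_size, D) (zero-padded at the end), `segment_lengths`
        holds the unpadded length of each segment, and `segment_offsets` holds the
        start index of each segment in `m`.
    '''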
    assert m.ndim == 3
    assert m.size(0) == 1
    # m = m.unsqueeze(0)  # (1, T, D)

    step = segment_size - overlap
    if (m.size(1) - segment_size) % step == 0:
        n_segments = (m.size(1) - segment_size) // step + 1
        padding_size = 0
    else:
        n_segments = (m.size(1) - segment_size) // step + 2
        padding_size = (n_segments - 1) * step + segment_size - m.size(1)
    
    m_padded = torch.nn.functional.pad(m, (0, 0, 0, padding_size))  # Pad the tensor with zeros
    m_segmented = m_padded.unfold(dimension=1, size=segment_size, step=step)
    m_segmented = m_segmented.permute(0, 1, 3, 2)  # `m_segmented` is of shape (1, L, segment_size, D), where L is the number of segments.
    m_segmented = m_segmented[0]  # (L, segment_size, D)
    
    segment_lengths = [segment_size] * m_segmented.size(0)
    segment_lengths[-1] -= padding_size
    segment_lengths = torch.tensor(segment_lengths)

    segment_offsets = torch.arange(0, segment_lengths.size(0) * step, step)

    assert len(segment_lengths) == n_segments
    assert len(segment_lengths) == m_segmented.size(0)
    assert len(segment_lengths) == len(segment_offsets)

    # Discard the last chunk if it is too short
    if segment_lengths[-1] < shortest_segment_size:
        m_segmented = m_segmented[:-1]
        segment_lengths = segment_lengths[:-1]
        segment_offsets = segment_offsets[:-1]
    
    return m_segmented, segment_lengths, segment_offsets

In [None]:
m = torch.rand(1, 20, 3)
m

In [None]:
uniform_segmentation_with_overlap(m, 7, 2, shortest_segment_size=0)

In [None]:
m = torch.rand(1, 7, 3)
segment_size = 7
overlap = 2
step = segment_size - overlap
if (m.size(1) - segment_size) % step == 0:
    n_segments = (m.size(1) - segment_size) // step + 1
    padding_size = 0
else:
    n_segments = (m.size(1) - segment_size) // step + 2
    padding_size = (n_segments - 1) * step + segment_size - m.size(1)

print(n_segments, padding_size)

TODO:
- dataclass for alignment results
- inspect the improvement for walden
- phone-level alignment
- gentle
- quasi-gentle in colab
- training recipe: need to start it tonight