
# Forced Alignment with Wav2Vec2

Running forced alignment with PyTorch wav2vec2 models.


## Overview

The process of alignment looks like the following.

1. Estimate the frame-wise label probability from the audio waveform.
2. Generate the trellis matrix, which represents the probability of
   labels being aligned at each time step.
3. Find the most likely path through the trellis matrix.

In this example, we use the ``Wav2Vec2`` model from ``torchaudio`` for
acoustic feature extraction.




## Preparation

First, we import the necessary packages and fetch the data we will work on.




In [1]:
import os
from dataclasses import dataclass

import torch
import torchaudio

import IPython
import matplotlib
import matplotlib.pyplot as plt
# Default size (width, height in inches) for newly created matplotlib figures.
matplotlib.rcParams.update({"figure.figsize": [16.0, 4.8]})

## Generate frame-wise label probability

The first step is to generate the label class probability of each audio
frame. We can use a Wav2Vec2 model that is trained for ASR. Here we use
:py:data:`torchaudio.pipelines.WAV2VEC2_ASR_LARGE_LV60K_960H`.

``torchaudio`` provides easy access to pretrained models with associated
labels.

<div class="alert alert-info"><h4>Note</h4><p>In the subsequent sections, we will compute the probabilities in the
   log domain to avoid numerical instability. For this purpose, we
   normalize the ``emission`` with :py:func:`torch.log_softmax`.</p></div>




In [5]:
# Audio file whose speech will be aligned (hard-coded absolute path).
SPEECH_FILE = (
    "/home/paperspace/repos/transformers/examples/research_projects/"
    "wav2vec2/wavs/g2.wav"
)

# Transcript of SPEECH_FILE (hard-coded absolute path).
TEXT_FILE = (
    "/home/paperspace/repos/transformers/examples/research_projects/"
    "wav2vec2/transcript_clean.txt"
)

# Name of the torchaudio.pipelines bundle to load (looked up by attribute name).
MODEL_NAME = "WAV2VEC2_ASR_LARGE_LV60K_960H"

# Sampling rate (Hz) the model expects.
# NOTE: most wav2vec models assume 16 kHz audio.
MODEL_SR = 16000


@dataclass
class Point:
    '''One step of an alignment path: a transcript token assigned to one frame.

    Indices are in emission/transcript coordinates, not trellis coordinates.
    '''

    token_index: int  # index into the transcript / token list
    time_index: int  # index of the emission frame
    score: float  # frame-wise probability (backtrack stores exp of a log-prob)


# Merge the labels
@dataclass
class Segment:
    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        return self.end - self.start



def load_audio(wav_file, model_sr=MODEL_SR):
    '''Load `wav_file` using torchaudio and resample it to `model_sr` if needed.

    Args:
        wav_file: path to an audio file readable by ``torchaudio.load``.
        model_sr: target sampling rate in Hz (defaults to ``MODEL_SR``).

    Returns:
        The waveform tensor from ``torchaudio.load`` (``(channels, time)``),
        resampled to `model_sr` when the file's rate differs.
    '''
    wav, sr = torchaudio.load(wav_file)
    if sr != model_sr:
        # Resample to the *requested* rate, not the module-level default,
        # so a caller-supplied `model_sr` is actually honoured.
        print(f'Resampling from {sr} to {model_sr}')
        wav = torchaudio.transforms.Resample(sr, model_sr)(wav)
    return wav


def clean_text(o):
    '''Normalize one raw transcript line.

    Every comma is turned into a space and leading/trailing newline
    characters are stripped; nothing else is changed.
    '''
    return o.replace(',', ' ').strip('\n')

def load_transcript(text_file):
    '''Load the transcripts stored in `text_file`.

    Assumes one transcription per line. Each line is cleaned with
    ``clean_text``, its words are joined with the word-separator token ``|``,
    and it is upper-cased to match the wav2vec token outputs.

    Args:
        text_file: path to a UTF-8 text file.

    Returns:
        A list with one normalized transcript string per line of the file.
    '''
    # load text file; `with` closes the handle deterministically
    with open(text_file, encoding="utf8") as f:
        lines = f.readlines()
    # cleanup the text
    lines = [clean_text(line) for line in lines]
    # Replace word boundaries with the special token `|`. `split()` with no
    # argument collapses runs of whitespace, so the double space clean_text
    # leaves for ", " (or stray leading/trailing spaces) does not create empty
    # words and therefore no doubled or dangling `|` tokens.
    lines = ['|'.join(line.split()) for line in lines]
    # make all characters upper-case for wav2vec token outputs
    lines = [line.upper() for line in lines]
    return lines
    
    
def load_model(model_name):
    '''Look up ``torchaudio.pipelines.<model_name>`` and build its model.

    Returns:
        ``(model, bundle)``: the model from ``bundle.get_model()`` and the
        bundle itself (callers use it for e.g. ``get_labels()``).

    Raises:
        ValueError: if ``torchaudio.pipelines`` has no (truthy) attribute
            named `model_name`.
    '''
    bundle = getattr(torchaudio.pipelines, model_name, None)
    if not bundle:
        raise ValueError(f'Could not find model: {model_name}')
    return bundle.get_model(), bundle



def get_emissions(model, audio):
    '''Get per-frame label log-probabilities from `model` for `audio`.

    Args:
        model: callable returning ``(emissions, extra)``; only the first
            element is used.
        audio: waveform tensor passed straight to `model`.

    Returns:
        ``log_softmax`` of the model emissions over the last (label) axis.
    '''
    # Inference only: with autograd disabled the forward pass does not keep
    # every intermediate activation alive for a backward pass that never runs.
    with torch.no_grad():
        emissions, _ = model(audio)
        emissions = torch.log_softmax(emissions, dim=-1)
    return emissions


def get_trellis(emission, tokens, blank_id=0):
    '''Build the alignment trellis for `tokens` over the frames of `emission`.

    ``trellis[t, j]`` is the best log-score of having consumed the first ``j``
    tokens after ``t`` frames: staying on a token adds the blank
    log-probability of the frame, advancing to the next token adds that
    token's log-probability.

    Args:
        emission: 2-D tensor of label log-probabilities, ``emission[frame, label]``.
        tokens: sequence of label indices for the transcript.
        blank_id: label index of the blank token.

    Returns:
        Tensor of shape ``(num_frame + 1, num_tokens + 1)``.
    '''
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    # Trellis has extra dimensions for both time axis and tokens.
    # The extra dim for tokens represents <SoS> (start-of-sentence)
    # The extra dim for time axis is for simplification of the code.
    trellis = torch.empty((num_frame + 1, num_tokens + 1))
    trellis[0, 0] = 0
    # Staying at <SoS> means emitting the blank label on every frame so far.
    trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0)
    # No token can have been consumed before the first frame.
    trellis[0, 1:] = -float("inf")
    # The last `num_tokens` rows of the <SoS> column are set to +inf. The +inf
    # only propagates into cells from which the last token can no longer be
    # reached in the remaining frames, so the last column is unaffected.
    # `max(.., 0)` keeps the slice well-defined (empty) for an empty `tokens`.
    trellis[max(num_frame + 1 - num_tokens, 0):, 0] = float("inf")

    for t in range(num_frame):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens],
        )
    return trellis


def backtrack(trellis, emission, tokens, blank_id=0):
    '''Recover the most likely alignment path from a trellis.

    The walk starts at the frame with the highest score in the last trellis
    column (all tokens consumed). It then moves backwards one frame at a time,
    deciding whether the current token was kept (blank emitted) or entered
    (token emitted).

    Args:
        trellis: tensor as built by ``get_trellis``, shape
            ``(num_frame + 1, num_tokens + 1)``.
        emission: label log-probabilities, ``emission[frame, label]``.
        tokens: label indices of the transcript.
        blank_id: label index of the blank token.

    Returns:
        List of ``Point`` in increasing time order, in emission/transcript
        coordinates (not trellis coordinates).

    Raises:
        ValueError: if the backward walk runs out of frames before every
            token has been assigned.
    '''
    # Note:
    # j and t are indices for trellis, which has extra dimensions
    # for time and tokens at the beginning.
    # Trellis row `t` corresponds to emission frame `t - 1`, and trellis
    # column `j` corresponds to transcript token `j - 1`.
    j = trellis.size(1) - 1
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        # 1. Figure out whether trellis cell (t, j) was reached by staying on
        #    token j (blank emitted at frame t - 1) or by changing from token
        #    j - 1 (tokens[j - 1] emitted at frame t - 1).
        p_stay = emission[t - 1, blank_id]
        p_change = emission[t - 1, tokens[j - 1]]
        stayed = trellis[t - 1, j] + p_stay
        changed = trellis[t - 1, j - 1] + p_change
        is_change = changed > stayed

        # 2. Store the path with the frame-wise probability of the label that
        #    was actually emitted (the blank is `blank_id`, not a fixed 0).
        prob = (p_change if is_change else p_stay).exp().item()
        # Return token index and time index in non-trellis coordinate.
        path.append(Point(j - 1, t - 1, prob))

        # 3. Update the token
        if is_change:
            j -= 1
            if j == 0:
                break
    else:
        raise ValueError("Failed to align")
    return path[::-1]


def merge_repeats(path, transcript):
    '''Collapse consecutive path points that share a token into Segments.

    Args:
        path: sequence of ``Point`` objects ordered by time.
        transcript: indexable of labels; ``transcript[token_index]`` is the
            label of a segment.

    Returns:
        One ``Segment`` per run of equal ``token_index`` values. Each spans
        ``[first time_index, last time_index + 1)`` and is scored by the mean
        point score of its run.
    '''
    segments = []
    total = len(path)
    start = 0
    while start < total:
        token = path[start].token_index
        stop = start + 1
        while stop < total and path[stop].token_index == token:
            stop += 1
        run = path[start:stop]
        segments.append(
            Segment(
                transcript[token],
                run[0].time_index,
                run[-1].time_index + 1,
                sum(point.score for point in run) / len(run),
            )
        )
        start = stop
    return segments


def merge_words(segments, separator="|"):
    '''Group character segments into word segments, splitting on `separator`.

    Args:
        segments: sequence of ``Segment`` objects ordered by time.
        separator: label marking a word boundary; separator segments are
            dropped from the output.

    Returns:
        One ``Segment`` per non-empty run of non-separator segments. The label
        is the concatenation of the run's labels, and the span runs from the
        first segment's ``start`` to the last one's ``end``. The score is the
        mean of the run's scores weighted by segment length.
    '''
    words = []
    run = []
    # A trailing None acts as a final boundary so the last run is emitted.
    for seg in [*segments, None]:
        if seg is not None and seg.label != separator:
            run.append(seg)
            continue
        if run:
            label = "".join([s.label for s in run])
            weighted = sum(s.score * s.length for s in run)
            frames = sum(s.length for s in run)
            words.append(Segment(label, run[0].start, run[-1].end, weighted / frames))
            run = []
    return words

In [6]:
# example run through


def run_forced_alignment(
    model_name,
    speech_file,
    text_file,
    device=None,
):
    '''Force-align a transcript to a speech recording.

    Args:
        model_name: name of a ``torchaudio.pipelines`` bundle.
        speech_file: path to the audio file to align.
        text_file: path to a transcript file, of which only the first line is
            used; if no file exists at this path, the string itself is used
            as the transcript.
        device: torch device for the model; defaults to CUDA when available,
            otherwise CPU.

    Returns:
        dict with ``'character_segs'`` (one ``Segment`` per transcript
        character, including ``|`` separators) and ``'word_segs'`` (one
        ``Segment`` per word). Start/end values are emission frame indices.

    Raises:
        KeyError: if the transcript contains characters that are not model labels.
        ValueError: if `model_name` is unknown or no alignment is found.
    '''
    # set the hardware device
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load the model
    model, bundle = load_model(model_name)
    model.eval()
    model.to(device)

    # map each model label to its class index
    labels = bundle.get_labels()
    dictionary = {c: i for i, c in enumerate(labels)}

    # read in the transcript
    if os.path.isfile(text_file):
        transcript = load_transcript(text_file)[0]
    else:
        print(f'Note: using "{text_file}" as the transcript')
        # Normalize literal text like a transcript line (commas dropped, words
        # joined with '|', upper-cased) so it maps onto the model labels.
        # Text already in that form is left unchanged.
        transcript = '|'.join(clean_text(text_file).split()).upper()

    # turn the transcript into tokens; validate before the expensive forward pass
    unknown = sorted(set(transcript) - set(dictionary))
    if unknown:
        raise KeyError(f'Transcript characters not in the model labels: {unknown}')
    tokens = [dictionary[c] for c in transcript]

    # load the audio at the sampling rate the bundle expects
    audio = load_audio(speech_file, model_sr=int(bundle.sample_rate))
    audio = audio.to(device)

    # get the token probabilities; autograd is not needed for inference
    with torch.no_grad():
        emissions = get_emissions(model, audio)
    emissions = emissions[0].detach().cpu()

    # populate the trellis
    trellis = get_trellis(emissions, tokens)

    # walkback to find the most likely trellis path
    path = backtrack(trellis, emissions, tokens)

    # merge the paths with repeated labels
    segments = merge_repeats(path, transcript)

    # merge the words based on the separator token '|'
    word_segments = merge_words(segments)

    return {
        'character_segs': segments,
        'word_segs': word_segments,
    }

In [7]:
# Align the example recording against its transcript.
segs = run_forced_alignment(
    model_name=MODEL_NAME,
    speech_file=SPEECH_FILE,
    text_file=TEXT_FILE,
)

In [8]:
# Display the character-level segments.
segs["character_segs"]

[F	(1.00): [   20,    27),
 O	(1.00): [   27,    28),
 U	(0.99): [   28,    30),
 R	(0.63): [   30,    34),
 |	(0.94): [   34,    37),
 S	(1.00): [   37,    41),
 C	(1.00): [   41,    49),
 O	(1.00): [   49,    51),
 R	(1.00): [   51,    53),
 E	(1.00): [   53,    55),
 |	(1.00): [   55,    58),
 A	(1.00): [   58,    59),
 N	(1.00): [   59,    60),
 D	(0.52): [   60,    62),
 |	(1.00): [   62,    64),
 S	(1.00): [   64,    68),
 E	(1.00): [   68,    71),
 V	(1.00): [   71,    73),
 E	(0.93): [   73,    75),
 N	(0.86): [   75,    77),
 |	(1.00): [   77,    78),
 Y	(1.00): [   78,    79),
 E	(0.50): [   79,    81),
 A	(0.83): [   81,    83),
 R	(1.00): [   83,    86),
 S	(1.00): [   86,    88),
 |	(1.00): [   88,    90),
 A	(1.00): [   90,    93),
 G	(1.00): [   93,    97),
 O	(1.00): [   97,   100),
 |	(0.96): [  100,   124),
 O	(1.00): [  124,   125),
 U	(1.00): [  125,   127),
 R	(0.88): [  127,   129),
 |	(1.00): [  129,   132),
 F	(1.00): [  132,   138),
 A	(1.00): [  138,   142),
 