In [1]:
from datasets import load_dataset

dataset = load_dataset("bookbot/common-voice-accent-gb")

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import torch
import torchaudio.functional as F
from pathlib import Path

import torch
import torchaudio
import torchaudio.transforms as T


from scipy.io.wavfile import write
from dataclasses import dataclass
from string import punctuation
from typing import List
from num2words import num2words
from unidecode import unidecode
from num2words import num2words
import unicodedata
import re
import string

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
output_dir = Path("./cv-gb-alignment-result")
output_dir.mkdir(exist_ok=True)
# load MMS aligner model
bundle = torchaudio.pipelines.MMS_FA
model = bundle.get_model().to(device)
chunk_size_s = 15
DICTIONARY = bundle.get_dict()
MMS_SUBSAMPLING_RATIO = 400


def preprocess_text(text: str) -> str:
    text = unidecode(text)
    text = unicodedata.normalize("NFKC", text)
    text = text.lower()
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = re.sub(r"\d+", lambda x: num2words(int(x.group(0)), lang="en"), text)
    text = re.sub("\\s+", " ", text)
    return text

def align(emission, tokens, device):
    targets = torch.tensor([tokens], dtype=torch.int32, device=device)
    alignments, scores = F.forced_align(emission, targets, blank=0)

    alignments, scores = alignments[0], scores[0]  # remove batch dimension for simplicity
    scores = scores.exp()  # convert back to probability
    return alignments, scores


def unflatten(list_, lengths):
    assert len(list_) == sum(lengths)
    i = 0
    ret = []
    for l in lengths:
        ret.append(list_[i : i + l])
        i += l
    return ret

def compute_alignments(emission, transcript, dictionary, device):
    tokens = [dictionary[char] for word in transcript for char in word]
    alignment, scores = align(emission, tokens, device)
    token_spans = F.merge_tokens(alignment, scores)
    word_spans = unflatten(token_spans, [len(word) for word in transcript])
    return word_spans

def get_word_segments(datum):
    transcript = datum["sentence"]
    transcript = preprocess_text(transcript)
    words = transcript.split()
    audio = datum["audio"]
    sampling_rate = audio["sampling_rate"]
    audio_array = torch.from_numpy(audio["array"])
    audio_id = Path(audio["path"]).stem

    resampler = T.Resample(sampling_rate, bundle.sample_rate, dtype=audio_array.dtype)
    resampled_waveform = resampler(audio_array)

    # split audio into chunks to avoid OOM and faster inference
    chunk_size_frames = chunk_size_s * bundle.sample_rate
    resampled_waveform = torch.unsqueeze(resampled_waveform, 0).float()
    chunks = [
        resampled_waveform[:, i : i + chunk_size_frames]
        for i in range(0, resampled_waveform.shape[1], chunk_size_frames)
    ]

    # collect per-chunk emissions, rejoin
    emissions = []
    with torch.inference_mode():
        for chunk in chunks:
            # NOTE: we could pad here, but it'll need to be removed later
            # skipping for simplicity, since it's at most 25ms
            print(chunk.size(1))
            if chunk.size(1) >= MMS_SUBSAMPLING_RATIO:
                emission, _ = model(chunk.to(device))
                print(emission.shape)
                emissions.append(emission)


    emission = torch.cat(emissions, dim=1)
    num_frames = emission.size(1)
    assert len(DICTIONARY) == emission.shape[2]

    # perform forced-alignment
    word_spans = compute_alignments(emission, words, DICTIONARY, device)
    assert len(word_spans) == len(words)

    # collect verse-level segments
    segments, labels, start = [], [], 0
    for word, span in zip(words, word_spans):
        ratio = resampled_waveform.size(1) / num_frames
        x0 = int(ratio * span[0].start)
        x1 = int(ratio * span[-1].end)
        segment = resampled_waveform[:, x0:x1]
        segments.append(segment)
        labels.append(word)

    for segment, label in zip(segments, labels):
        audio_name = audio_id + "-" + label
        # write audio
        audio_path = (output_dir / audio_name).with_suffix(".wav")
        write(audio_path, bundle.sample_rate, segment.squeeze().numpy())

        # write transcript
        transcript_path = (output_dir / audio_name).with_suffix(".txt")
        with open(transcript_path, "w") as f:
            f.write(label)

In [6]:
from tqdm.auto import tqdm
for datum in tqdm(dataset["train"].select(range(1))):
    get_word_segments(datum)

100%|██████████| 1/1 [00:00<00:00, 15.43it/s]

110976
torch.Size([1, 346, 29])



