In [1]:
import os
# Point CUDA-aware build tooling (e.g. CMake find modules, extension builds)
# at the CUDA 12.1 toolkit. Run this cell first so later imports/builds see it.
CUDA_HOME = "/usr/local/cuda-12.1.0"
CUDA_TARGET_DIR = f"{CUDA_HOME}/targets/x86_64-linux"


def prepend_env_var(name: str, value: str, sep: str = os.pathsep) -> None:
    """Prepend ``value`` to environment variable ``name``.

    Joins with ``sep`` when the variable already holds a non-empty value and
    sets it to ``value`` alone otherwise, so an unset variable neither raises
    ``KeyError`` nor gains a dangling separator.

    Args:
        name: Environment variable to modify.
        value: Text to put in front.
        sep: Separator between ``value`` and the existing content
            (``os.pathsep`` for path lists, ``" "`` for flag lists).
    """
    existing = os.environ.get(name, "")
    os.environ[name] = f"{value}{sep}{existing}" if existing else value


os.environ["CUDA_HOME"] = CUDA_HOME
prepend_env_var("PATH", f"{CUDA_HOME}/bin")

# Each prepend goes to the front, so the final search order is
# CUPTI, lib, lib64, then whatever was already set.
# NOTE: the dynamic loader reads LD_LIBRARY_PATH at process start, so this
# mostly affects child processes (compilers, build tools) launched from here.
for lib_dir in ("lib64", "lib", "extras/CUPTI/lib64"):
    prepend_env_var("LD_LIBRARY_PATH", f"{CUDA_HOME}/{lib_dir}")

# Toolkit-root hints under the various names different tools look for.
for root_var in (
    "CUDAToolkit_ROOT_DIR",
    "CUDAToolkit_ROOT",
    "CUDA_TOOLKIT_ROOT_DIR",
    "CUDA_TOOLKIT_ROOT",
    "CUDA_BIN_PATH",
    "CUDA_PATH",
):
    os.environ[root_var] = CUDA_HOME

os.environ["CUDA_INC_PATH"] = CUDA_TARGET_DIR
os.environ["CUDAToolkit_TARGET_DIR"] = CUDA_TARGET_DIR
# CFLAGS is a space-separated list of compiler flags, not a ':'-separated path list.
prepend_env_var("CFLAGS", f"-I{CUDA_TARGET_DIR}/include", sep=" ")

In [2]:
import torch
import torchaudio
import torchaudio.functional as F
from pathlib import Path

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# One verse of audio plus the book-level transcript JSON it belongs to.
audio_path = Path("../outputs") / "openbible_swahili" / "EPH" / "EPH_003" / "EPH_003_001.wav"
json_path = Path("../data") / "openbible_swahili" / "EPH.json"
# File stems double as identifiers: "EPH" for the book, "EPH_003_001" for the verse.
book = json_path.stem
verse_id = audio_path.stem

In [3]:
import json
import re
import string
import unicodedata
from unidecode import unidecode
from num2words import num2words

def preprocess_verse(text: str) -> str:
    """Normalize a verse into plain lower-case words for forced alignment.

    Steps:
      1. Unicode NFKC normalization, then ASCII transliteration (``unidecode``).
      2. Lower-casing and removal of ASCII punctuation.
      3. Spelling out each run of digits as Swahili words (``num2words``).
      4. Collapsing whitespace to single spaces.

    Args:
        text: Raw verse text.

    Returns:
        Lower-case text with single spaces between words and no leading or
        trailing whitespace.
    """
    strip_punct = str.maketrans("", "", string.punctuation)
    # NFKC first: unidecode emits plain ASCII, on which NFKC would be a no-op.
    text = unicodedata.normalize("NFKC", text)
    text = unidecode(text)
    text = text.lower().translate(strip_punct)
    # NOTE(review): confirm the installed num2words supports lang="sw";
    # otherwise num2words raises on any verse that contains digits.
    text = re.sub(r"\d+", lambda m: num2words(int(m.group(0)), lang="sw"), text)
    # num2words output may carry its own casing/punctuation (e.g. hyphens or
    # commas); re-apply the same cleanup so only plain words remain.
    text = text.lower().translate(strip_punct)
    return re.sub(r"\s+", " ", text).strip()

def load_transcript(json_path: Path, verse: str) -> str:
    """Return the raw text of one verse from a book-level transcript JSON.

    The JSON is iterated as a sequence of objects, each with a
    ``"verseNumber"`` reference such as ``"PSA 19:1"`` and a ``"verseText"``.
    Entries whose reference is not of the ``"BOOK CH:VS"`` form (for example
    a chapter heading without ``:``) are skipped rather than crashing.

    Args:
        json_path: Path to the book JSON (e.g. ``EPH.json``).
        verse: Verse id in ``BOOK_CCC_VVV`` form, e.g. ``"PSA_019_001"``.

    Returns:
        The ``verseText`` of the first entry whose reference matches ``verse``.

    Raises:
        IndexError: if no entry matches ``verse``.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    def to_verse_id(verse_number: str):
        """Convert "PSA 19:1" -> "PSA_019_001"; None if not a "BOOK CH:VS" reference."""
        head, sep, tail = verse_number.partition(":")
        head_parts = head.split()
        if not sep or len(head_parts) < 2:
            return None
        book_code, chapter = head_parts[0], head_parts[1]
        verse_no = tail.split(":")[0].strip()
        return f"{book_code}_{chapter.zfill(3)}_{verse_no.zfill(3)}"

    not_found = object()  # sentinel so a null verseText is still returned as-is
    transcript = next(
        (d["verseText"] for d in data if to_verse_id(d["verseNumber"]) == verse),
        not_found,
    )
    if transcript is not_found:
        raise IndexError(f"verse {verse!r} not found in {json_path}")
    return transcript

In [45]:
# Raw verse text -> normalized verse string -> list of words to align.
transcript = load_transcript(json_path=json_path, verse=verse_id)
verse = preprocess_verse(text=transcript)
words = verse.split()

In [5]:
bundle = torchaudio.pipelines.MMS_FA
# Acoustic model without the <star> wildcard token, placed on the chosen device.
model = bundle.get_model(with_star=False)
model = model.to(device)
# Label inventory without the star token. DICTIONARY maps a character to the
# label id used as a target for alignment; its size must match the model's
# output width (checked after computing emissions).
LABELS = bundle.get_labels(star=None)
DICTIONARY = bundle.get_dict(star=None)

In [6]:
chunk_size_s = 15  # seconds per forward pass; presumably keeps memory bounded on long recordings
waveform, sr = torchaudio.load(audio_path)
# torchaudio.load returns (channels, time) and the model reads dim 0 as the
# batch, so a stereo file would become a batch of 2 and break the
# single-sequence alignment later on. Downmix to mono first.
if waveform.size(0) > 1:
    waveform = waveform.mean(dim=0, keepdim=True)
waveform = torchaudio.functional.resample(waveform, sr, bundle.sample_rate)
sr = bundle.sample_rate
chunk_size_frames = chunk_size_s * sr  # measured in samples, despite the name
# Fixed-length windows along the time axis; the last one may be shorter.
chunks = [
    waveform[:, start : start + chunk_size_frames]
    for start in range(0, waveform.shape[1], chunk_size_frames)
]

In [7]:
# Shortest chunk (in samples) that is passed to the model; shorter tails are
# skipped. NOTE(review): 400 looks like the wav2vec2 feature extractor's
# receptive field (one 25 ms frame at 16 kHz) — confirm.
MIN_CHUNK_SAMPLES = 400

emissions = []
with torch.inference_mode():
    for chunk in chunks:
        if chunk.size(1) >= MIN_CHUNK_SAMPLES:
            emission, _ = model(chunk.to(device))
            emissions.append(emission)

if not emissions:
    # torch.cat([]) would fail with an opaque error; say what actually went wrong.
    raise ValueError(
        f"audio {audio_path} is too short: no chunk has >= {MIN_CHUNK_SAMPLES} samples"
    )
# Stitch per-chunk emissions along the frame axis: (batch, frames, num_labels).
emission = torch.cat(emissions, dim=1)
assert len(DICTIONARY) == emission.shape[2]
num_frames = emission.size(1)

In [8]:
# Display the normalized verse text that is force-aligned below.
verse

'kwa sababu hii mimi paulo mfungwa wa kristo yesu kwa ajili yenu ninyi watu wa mataifa'

In [29]:
# Greedy (best-path) score: the sum over frames of the highest per-frame
# log-probability. No alignment of the transcript can score higher, so this
# is the reference point for the forced-alignment score.
# log_softmax stays in log space, avoiding log(0) = -inf when a frame's best
# probability underflows after softmax. It also gives the same result whether
# `emission` holds raw logits or already-normalized log-probabilities.
log_probs = torch.log_softmax(emission, dim=2)
probs = log_probs.exp()
greedy_log_prob_per_frame = log_probs.max(dim=-1).values.squeeze()
greedy_prob = greedy_log_prob_per_frame.exp()
greedy_log_probs = greedy_log_prob_per_frame.sum().item()
greedy_log_probs

-7.465915203094482

In [35]:
def align(emission, tokens):
    """Force-align a token sequence to CTC emissions.

    Args:
        emission: Emission tensor with a leading batch dimension of size 1
            (assumed (1, frames, num_labels) — TODO confirm against caller).
        tokens: Target label ids. Label id 0 is used as the CTC blank, so
            tokens must not contain 0.

    Returns:
        ``(alignments, scores)`` for the single sequence. ``alignments`` holds
        the per-frame label ids of the aligned path. ``scores`` holds the
        per-frame probabilities of that path.
    """
    # Build targets on the emission's own device so placement always matches,
    # instead of relying on the notebook-global `device`.
    targets = torch.tensor([tokens], dtype=torch.int32, device=emission.device)
    alignments, scores = F.forced_align(emission, targets, blank=0)

    alignments, scores = alignments[0], scores[0]  # remove batch dimension for simplicity
    scores = scores.exp()  # forced_align scores are log-probabilities; convert to probability
    return alignments, scores

def compute_alignments(emission, transcript, dictionary):
    """Return per-frame probabilities of the forced alignment of ``transcript``.

    Args:
        emission: Emissions passed straight to ``align``.
        transcript: Iterable of words. Every character of every word is mapped
            through ``dictionary``, and the words are concatenated without a
            separator token.
        dictionary: Mapping from character to label id.

    Returns:
        The per-frame probability tensor from ``align`` (the alignment itself
        is discarded).

    Raises:
        KeyError: if any character has no entry in ``dictionary``. All such
            characters are listed in the message.
        ValueError: if the transcript contains no characters.
    """
    chars = [char for word in transcript for char in word]
    missing = sorted({char for char in chars if char not in dictionary})
    if missing:
        raise KeyError(f"characters not in the alignment dictionary: {missing}")
    if not chars:
        raise ValueError("transcript is empty; nothing to align")
    tokens = [dictionary[char] for char in chars]
    _, scores = align(emission, tokens)
    return scores

In [52]:
# Sum of per-frame log-probabilities along the forced-alignment path.
# NOTE(review): compute_alignments returns exp()'d scores, so this takes the
# log of an exp. A frame whose score underflows to 0 would make the sum -inf.
aligned_probs = compute_alignments(emission=emission, transcript=words, dictionary=DICTIONARY)
aligned_log_probs = aligned_probs.log().sum().item()
aligned_log_probs

-9.359113693237305

In [49]:
# Mean per-frame gap between the forced-alignment log-likelihood and the
# greedy best-path log-likelihood over the same emissions. Assuming `emission`
# holds log-probabilities, this is <= 0. Values near 0 mean the transcript
# path is almost as likely as the unconstrained best path (presumably used as
# an audio/transcript match score). Despite the name, this is a difference of
# log-probabilities, not probabilities.
probability_diff = (aligned_log_probs - greedy_log_probs) / num_frames
probability_diff

-0.11303391470923438