## Utils / Housekeeping

In [1]:
import IPython.display
from pathlib import Path

import os
import numpy as np

import torch
from torch import nn
import pandas as pd
import whisper
import torchaudio
import torchaudio.transforms as at
import re

import concurrent.futures
import glob

import os

from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# from tqdm.notebook import tqdm   # for colab
from tqdm import tqdm              # for jupyter

import evaluate

from transformers import (
    AdamW,
    get_linear_schedule_with_warmup
)

from utils import CfgNode

from typing import List, Union, Tuple

## Data Preprocessing

Since we are utilizing HuggingFace's AudioLoader for staging our audio files and transcriptions, we'll need to pair each audio file's path with its corresponding transcript and save this information in a .txt file. Let's code up some functions for handling this pre-processing step.

We'll be using three locally stored datasets for this training task. This notebook is to contain a training pipeline for data stored _locally_. If you'd like to train a Whisper model using Lambda Cloud SLI, see the notebook `lambda_cloud_whisper_train.ipynb`. 

The three datasets are KenSpeech, Broadcast News Swahili, and Babel Swahili Language Pack.

### Utility Functions

In [2]:
def remove_bracketed_text(string: str, return_removed_text: bool=False) -> Union[str, Tuple[str, List[str]]]:
    """
    Remove every ``<...>`` span from a string and return the cleaned string.
    Optionally, also return the distinct spans that were removed.

    :param string: The input string.
    :param return_removed_text: If True, also return the list of removed text.
    :return: The cleaned string, or ``(cleaned_string, removed)`` where ``removed``
        holds each distinct bracketed span once, in order of first appearance.
    """
    # Non-greedy so that "<a> x <b>" is treated as two spans, not one long one.
    pattern = r"<.*?>"
    cleaned_string = re.sub(pattern, "", string)

    if not return_removed_text:
        return cleaned_string

    # dict.fromkeys de-duplicates while keeping first-seen order; a set would
    # make the order of the returned list arbitrary from run to run.
    removed_text = list(dict.fromkeys(re.findall(pattern, string)))
    return cleaned_string, removed_text


def remove_extra_whitespace(lines: List[str]) -> List[str]:
    """
    Normalise the whitespace of every line.

    Args:
        lines: The strings to clean.

    Returns:
        A new list in which, for each input string, every run of whitespace
        has been collapsed to one space and the ends have been trimmed.
    """
    return [re.sub(r"\s+", " ", raw_line).strip() for raw_line in lines]


def remove_parentheses(lines: List[str]) -> List[str]:
    """
    Delete every literal "(())" marker from each line.

    Args:
        lines (List[str]): Strings that may contain the "(())" marker.

    Returns:
        List[str]: A new list with the marker removed from every string;
        other parentheses are left untouched.
    """
    return [text.replace("(())", "") for text in lines]


def find_wav_files(root_folder: str, ending=".wav") -> List[str]:
    """
    Walk the folder tree rooted at ``root_folder`` and collect the path of
    every file whose name ends with ``ending`` (".wav" by default).

    Paths are returned in the order ``os.walk`` yields them.
    """
    return [
        os.path.join(current_dir, entry)
        for current_dir, _subfolders, entries in os.walk(root_folder)
        for entry in entries
        if entry.endswith(ending)
    ]


def extract_filenames(file_paths: List[str], name_to_extract: str=".wav") -> Tuple[List[str], List[str]]:
    """
    Reduce each path to its bare file name, without the given extension.

    Args:
    file_paths: a list of file paths (Windows "\\" or POSIX "/" separated)
    name_to_extract: the file extension to strip from each name (default is ".wav")

    Returns:
    A tuple of two lists - the first list contains the extracted filenames, in
    input order. The second list is always empty: no filesystem access happens
    here, so no path can be "not found". It is kept so that callers can keep
    unpacking two values.
    """
    filenames = []
    failed_paths = []
    for file_path in file_paths:
        # Split on either separator so the result does not depend on which
        # operating system produced the path.
        filename = re.split(r"[\\/]", file_path)[-1]
        # Strip the extension only where it is a suffix, so a name such as
        # "my.wavelet.wav" keeps its inner ".wav".
        if name_to_extract and filename.endswith(name_to_extract):
            filename = filename[: -len(name_to_extract)]
        filenames.append(filename)
    return filenames, failed_paths


def process_lines(lines: List[str]) -> List[Tuple[str, str]]:
    """
    Given a list of lines in the format 'filename=utterance',
    this function returns a list of tuples containing the audio file name
    (with ".wav" appended) and the utterance.

    Only the first '=' separates the two fields, so the utterance itself may
    contain '='. Blank lines are skipped.

    Raises:
        ValueError: if a non-blank line contains no '=' separator.
    """
    processed_lines = []

    for line_number, line in enumerate(lines, start=1):
        # readlines() output can contain whitespace-only lines; they carry no pair.
        if not line.strip():
            continue

        name, separator, utterance = line.partition("=")
        if not separator:
            raise ValueError(
                f"Line {line_number} is not in 'filename=utterance' format: {line!r}"
            )

        # Strip both fields so "name = text" does not yield "name .wav".
        processed_lines.append((name.strip() + ".wav", utterance.strip()))

    return processed_lines

### KenSpeech

This code defines a function named `find_wav_files()` that takes a root folder as an input, and returns a list of paths to all of the .wav files that it finds in the folder tree. The function uses the os.walk() function to iterate over the folders and files in the tree, and appends the path to any .wav files that it finds to the wav_file_paths list.

In [None]:
# where to start our '.wav' search
# NOTE(review): absolute, machine-specific Windows path — point this at your
# own copy of the KenSpeech audio folder before running.
root_folder = r"C:\Users\Hedronstone\Desktop\whisper_event\datasets\swahili\kenspeech\audios"

# Recursively collect the path of every file ending in ".wav" under root_folder.
kenspeech_wav_files = find_wav_files(root_folder)
# Preview the first three paths.
kenspeech_wav_files[:3]

There is but one folder containing all transcripts that we'll need to search. The files' naming convention follows this format: `sample_*.txt` or `tweet_*.txt` where `*` denotes each file's unique id. We'll use the paths collected from our `find_wav_files` function and extract the file names into a list.

In [None]:
# Reduce every audio path to its bare file name (no folder, no ".wav");
# the second return value is the list of paths reported as failed.
filenames, files_not_found = extract_filenames(kenspeech_wav_files)

# Sanity check: how many names were extracted, plus a short sample.
print(f"Length of filenames: {len(filenames)}. \nSamples:")
print(filenames[:10]) # Should print ["tweet_*", "tweet_*", ...]

In [None]:
# NOTE(review): absolute, machine-specific Windows path — adjust before running.
transcriptions_path = r"C:\Users\Hedronstone\Desktop\whisper_event\datasets\swahili\kenspeech\transcripts"
# One transcript path per audio file name, in the same order as `filenames`.
# The "\\" separator is hard-coded, so these paths are Windows-style.
transcriptions_paths = [
    transcriptions_path + "\\" + text_file_name  + ".txt" for text_file_name in filenames
]

print(f"No. of transcriptions: {len(transcriptions_paths)}")
print(transcriptions_paths[:3])

In this updated version of the `TextProcessor` class, we added a new method called process_text that takes a list of transcript paths as an input and processes the text in each transcript file.

The process_text method first reads each transcript file and collects its text in the text_set list; paths that cannot be opened are skipped. It then passes the first line of each file through the remove_parentheses, clean_strings, and remove_extra_spaces methods, which remove parenthesised text, newlines and surrounding whitespace, and runs of repeated whitespace. Finally, it returns the list of processed strings.

To use the process_text method, we create a `TextProcessor` object, and then call the process_text method on the object, passing a list of transcript paths as an argument.

In [None]:
class TextProcessor:
    """Read transcript files and clean the first line of each one."""

    def __init__(self):
        import re
        # The methods below reach the regex module through ``self.re``.
        self.re = re
        # Paths the most recent ``process_text`` call could not open.
        self.not_found = []

    def remove_parentheses(self, string):
        """Remove all text in parentheses from the given string."""
        return self.re.sub(r"\([^)]*\)", "", string)

    def clean_strings(self, string_list):
        """Remove newline characters and leading/trailing whitespace from every string."""
        return [string.replace("\n", "").strip() for string in string_list]

    def remove_extra_spaces(self, string_list):
        """Replace every run of two or more whitespace characters with a single space."""
        pattern = self.re.compile(r"\s{2,}")
        return [pattern.sub(" ", string) for string in string_list]

    def process_text(self, transcript_paths):
        """
        Return the cleaned first line of every readable transcript file.

        Files that do not exist are skipped, so the result can be shorter than
        ``transcript_paths``; the skipped paths are stored in ``self.not_found``
        and their count is printed. An empty file contributes an empty string.
        """
        not_found = []

        text_set = []
        for transcript_path in transcript_paths:
            try:
                with open(transcript_path) as f:
                    # readline() returns "" for an empty file, where indexing
                    # readlines()[0] would raise IndexError.
                    text_set.append(f.readline())
            except FileNotFoundError:
                not_found.append(transcript_path)

        # Expose the skipped paths instead of discarding them silently.
        self.not_found = not_found
        if not_found:
            print(f"{len(not_found)} transcript paths weren't found.")

        # remove `(tweet_*)` from samples, then tidy the whitespace
        new_text_set = self.remove_extra_spaces(
            self.clean_strings(
                self.remove_parentheses(first_line) for first_line in text_set
            )
        )

        return new_text_set
    
def find_strings_with_substring(strings: List[str], substring: str) -> List[str]:
    """
    This function takes a list of strings and a substring and returns a list of strings from the original list
    that contain the substring.
    Args:
    - strings: a list of strings
    - substring: a string that we want to search for (matched literally)

    Returns:
    - a list of strings from the original list that contain the substring, in their original order
    """
    # Plain containment rather than a ".*substring.*" regex: "." does not match
    # a newline, so the regex form misses a substring that appears after the
    # first line of a string.
    return [string for string in strings if substring in string]

In [None]:
# Read every KenSpeech transcript and clean its first line.
tp = TextProcessor()
transcriptions = tp.process_text(transcriptions_paths)
# Preview the first five cleaned transcripts.
transcriptions[:5]

The Radford et al. paper describes a prediction process that begins with a special token. The language being spoken is then predicted based on a unique token for each language in the training set. If there is no speech in an audio segment, the model predicts a token indicating this. The next token specifies whether the task is transcription or translation. Timestamps can also be predicted by including a special token. Finally, the output ends with another special token. For our case, we'll only use the `<|startoftranscript|>` and `<|endoftranscript|>` tokens.

In [None]:
def add_start_token(input_string):
    """Return ``input_string`` with every literal ``<s>`` marker deleted.

    Despite its name, this function adds nothing; it only strips the marker.
    """
    marker = "<s>"
    return input_string.replace(marker, "")

def add_end_token(input_string):
    """Return ``input_string`` with every literal ``< s>`` marker deleted.

    Despite its name, this function adds nothing; it only strips the marker.
    """
    marker = "< s>"
    return input_string.replace(marker, "")

# Run every transcript through both marker-stripping helpers.
transcriptions_with_tokens = [
    add_end_token(add_start_token(transcript)) for transcript in transcriptions
]

# Trim the whitespace left at either end once the markers are gone.
transcriptions_with_tokens = [cleaned.strip() for cleaned in transcriptions_with_tokens]
transcriptions_with_tokens[20:24]

In [None]:
# remove text in angled brackets
files = []
for idx in range(len(transcriptions_with_tokens)):
    files.append(
        remove_bracketed_text(transcriptions_with_tokens[idx])
    )

cleaned_files = remove_extra_whitespace(files)
cleaned_files[20:24]

Finally, we'll pair each transcript with its corresponding filename.

In [None]:
# Turn each transcript path into the audio file name it belongs to,
# e.g. "...\\tweet_1.txt" -> "tweet_1.wav".
substrings = []
for file_path in transcriptions_paths:
    # Split the file path into a list of strings using the "\\" separator
    # and keep the last element (i.e. the filename)
    filename = file_path.split("\\")[-1]
    # Swap the ".txt" extension for ".wav"
    filename = filename.replace(".txt", ".wav")
    substrings.append(filename)


# For every audio file name, the list of audio paths that contain it.
files = [
    find_strings_with_substring(kenspeech_wav_files, substring)
    for substring in substrings
]

# `files` has one entry per transcript path, while `cleaned_files` only has
# entries for transcripts that could actually be read. If any transcript was
# skipped, zip() would pair audio with the wrong text without any error, so
# stop here instead.
if len(files) != len(cleaned_files):
    raise ValueError(
        f"{len(files)} audio files but {len(cleaned_files)} transcripts: "
        "some transcripts are missing, so pairing by position would be misaligned."
    )

kenspeech_data = []
for audio_path, transcript in zip(files, cleaned_files):
    # Keep the first matching audio path for this transcript.
    kenspeech_data.append((audio_path[0], transcript))

kenspeech_data[20:24]

In [None]:
# Rebind kenspeech_data from a list of (audio path, transcript) tuples to a
# two-column DataFrame.
kenspeech_data = pd.DataFrame(kenspeech_data, columns=['file_name', 'transcription'])
kenspeech_data

In [None]:
# Written in CSV format (with a header row, no index column) despite the
# ".txt" extension; the path is relative to the current working directory.
kenspeech_data.to_csv("datasets/swahili/kenspeech_cleaned.txt", index=False)

### Babel Swahili

In [None]:
# Notebook-wide configuration object (CfgNode comes from the local `utils` module).
CN = CfgNode()
# NOTE(review): absolute, machine-specific Windows path — adjust before running.
CN.DATASET_DIR = r"C:\Users\Hedronstone\Desktop\whisper_event\datasets\swahili\babel_swahili_language_pack"
CN.SEED = 3407
# "gpu" when CUDA is available, otherwise "cpu".
CN.DEVICE = "gpu" if torch.cuda.is_available() else "cpu"
# pytorch_lightning helper; seeds the run with CN.SEED.
seed_everything(CN.SEED, workers=True)

In [None]:
# Root of the Babel Swahili language pack; searched once per file type.
babel_root = r"C:\Users\Hedronstone\Desktop\whisper_event\datasets\swahili\babel_swahili_language_pack"
audio_paths = find_wav_files(babel_root)
text_paths = find_wav_files(babel_root, ".txt")

# (names, failed_paths) tuple; index 0 holds the bare audio file names.
audio_files = extract_filenames(audio_paths)

# Preview the first five audio file names.
audio_files[0][:5]

In [None]:
# we will seperate out transcriptions without accompanying audio files
transcriptions_paths = []

for file in audio_files[0]:
    transcriptions_paths.append(find_strings_with_substring(text_paths, file)[0])

# collect valid transcription paths
transcriptions_paths[:5]

In [None]:
# Rebuild text_paths from the matched transcripts, replacing any ".wav" with ".txt".
text_paths = [transcript_path.replace(".wav", ".txt") for transcript_path in transcriptions_paths]
text_paths[:4], audio_paths[:4]

In [None]:
timestamps, utterances, groups = [], [], []

# Each transcript alternates "[timestamp]" lines with utterance lines. Build
# one (audio path, full transcript) pair per file.
for idx in range(len(transcriptions_paths)):

    with open(transcriptions_paths[idx], "r") as f:
        text_file = f.readlines()

    # Start each file with fresh accumulators.
    timestamps, utterances = [], []

    for element in text_file:
        # Remove only the line terminator. Slicing off the last character
        # would eat a real character when the final line has no newline.
        element = element.rstrip("\n")
        if element.startswith('['):
            # "[12.34]" -> 12.34
            timestamps.append(float(element.strip().strip("[]")))
        else:
            utterances.append(element)

    # transcriptions_paths[idx] was matched from audio_paths[idx], so the two
    # lists are aligned by position.
    groups.append((str(audio_paths[idx]), ' '.join(utterances)))

groups[:1]

In [None]:
babel_data = []

for audio_path, raw_transcript in groups:
    brackets_removed = remove_bracketed_text(raw_transcript)
    # Remove the "(())" markers BEFORE collapsing whitespace; otherwise each
    # removed marker leaves a double space in the final text.
    parentheses_removed = remove_parentheses([brackets_removed])
    whitespace_removed = remove_extra_whitespace(parentheses_removed)
    babel_data.append((audio_path, whitespace_removed[0]))

print(len(babel_data))
babel_data[:1]

In [None]:
# Rebind babel_data from a list of (audio path, transcript) tuples to a
# two-column DataFrame.
babel_data = pd.DataFrame(babel_data, columns=['file_name', 'transcription'])
babel_data

In [None]:
# Written in CSV format (with a header row, no index column) despite the
# ".txt" extension; the path is relative to the current working directory.
babel_data.to_csv("datasets/swahili/babel_cleaned.txt", index=False)

### Broadcast News

In [None]:
# NOTE(review): absolute, machine-specific Windows paths — adjust before running.
audio_paths = find_wav_files(r"C:\Users\Hedronstone\Desktop\whisper_event\datasets\swahili\data_broadcastnews_sw", ".wav")
text_paths = find_wav_files(r"C:\Users\Hedronstone\Desktop\whisper_event\datasets\swahili\data_broadcastnews_sw", ".txt")

# Pass the whole list: a single path string would be iterated character by
# character and produce one "file name" per character.
audio_files = extract_filenames(audio_paths)

transcriptions_paths = []

path = r"C:\Users\Hedronstone\Desktop\whisper_event\datasets\swahili\data_broadcastnews_sw\data\train\train_text.txt"

with open(path, "r") as f:
    text_file = f.readlines()

# One (audio file name, utterance) pair per 'filename=utterance' line.
processed_lines = process_lines(text_file)

broadcastnews_train_data = []

for idx in tqdm(range(len(processed_lines))):
    # First audio path that contains this line's audio file name.
    a_path = find_strings_with_substring(audio_paths, processed_lines[idx][0])[0]
    # Despite the name, this holds the utterance text, not a path.
    t_path = processed_lines[idx][1]
    broadcastnews_train_data.append([a_path, t_path])

In [None]:
path = r"C:\Users\Hedronstone\Desktop\whisper_event\datasets\swahili\data_broadcastnews_sw\data\test\test_text.txt"

with open(path, "r") as f:
    text_file = f.readlines()

# One (audio file name, utterance) pair per 'filename=utterance' line.
processed_lines = process_lines(text_file)

# Pair each utterance with the first audio path containing its file name.
broadcastnews_test_data = [
    [find_strings_with_substring(audio_paths, audio_name)[0], utterance]
    for audio_name, utterance in tqdm(processed_lines)
]

In [None]:
# Turn both splits into DataFrames with the same two columns.
broadcastnews_train = pd.DataFrame(broadcastnews_train_data, columns=['file_name', 'transcription'])
broadcastnews_test = pd.DataFrame(broadcastnews_test_data, columns=['file_name', 'transcription'])

# Stack the train rows on top of the test rows.
broadcastnews_combined = pd.concat([
    broadcastnews_train, broadcastnews_test
    ]
)

broadcastnews_combined

In [None]:
# Written in CSV format (with a header row, no index column) despite the
# ".txt" extension; the path is relative to the current working directory.
broadcastnews_combined.to_csv('datasets/swahili/broadcastnews_cleaned.txt', index=False)

## Dataset Metrics

In [None]:
import wave
import os

def get_total_length_of_audio(audio_file_paths):
    """
    Return the summed duration, in seconds, of the WAV files in ``audio_file_paths``.

    Args:
        audio_file_paths: an iterable of path strings. Entries that do not end
            with ".wav" are ignored.

    Returns:
        The total duration in seconds (0 when nothing qualifies).
    """
    total_length = 0
    for path in audio_file_paths:
        if path.endswith(".wav"):
            with wave.open(path, 'rb') as audio_file:
                # Duration = frame count / frame rate. Deriving it from the
                # file size would count the header bytes as audio.
                total_length += audio_file.getnframes() / audio_file.getframerate()
    return total_length


# get_total_length_of_audio returns seconds, so each "/ 60" below gives minutes.
broadcastnews_total_minuntes = get_total_length_of_audio(
    broadcastnews_combined.file_name
) / 60


kenspeech_total_minuntes = get_total_length_of_audio(
    kenspeech_data.file_name
) / 60


babel_total_minuntes = get_total_length_of_audio(
    babel_data.file_name
) / 60 

# The minute totals are divided by 60 again, so the printed figures are hours.
print(f"KenSpeech Total Audio: {kenspeech_total_minuntes / 60}")
print(f"Babel Total Audio: {babel_total_minuntes / 60}")
print(f"Broadcast News Total Audio: {broadcastnews_total_minuntes/ 60}")

In [None]:
# Stack the three cleaned datasets into one frame, in two steps:
# KenSpeech + Babel first, then Broadcast News appended.
kenspeech_and_babel_data = pd.concat([kenspeech_data, babel_data])
combined_dataset = pd.concat([kenspeech_and_babel_data, broadcastnews_combined])

In [None]:
# Total duration of the combined dataset: seconds / 60 gives minutes.
combined_total_minuntes = get_total_length_of_audio(
    combined_dataset.file_name
) / 60

# Minutes / 60: the printed figure is in hours.
print(combined_total_minuntes / 60)

# Save the combined (file_name, transcription) table without the index column.
combined_dataset.to_csv('datasets/swahili/metadata.csv', index=False)

## Test Data Loading Tool

In [None]:
from datasets import load_dataset
# Load the folder that holds metadata.csv through the "audiofolder" builder.
# NOTE(review): drop_metadata=True presumably discards the transcription
# column coming from metadata.csv — confirm this is intended.
# NOTE(review): metadata.csv stores the absolute audio paths collected above;
# verify that the builder accepts them as file_name values.
dataset = load_dataset("audiofolder", data_dir="datasets/swahili", drop_metadata=True)

In [None]:
# Show the first example of the 'train' split.
next(iter(dataset['train']))

In [None]:
# Display the 'train' split object itself.
dataset['train']