Let's build our custom WhisperX-like program that is optimized for our code-switching data.

# Step 0: Installations

MFA: https://montreal-forced-aligner.readthedocs.io/en/stable/installation.html

In [None]:
# It's been a while since I've installed MFA, but I think this is what I did.
# Run these commands in your terminal. They create a conda environment called 'aligner'.
# Then come back to this notebook and select that environment as the kernel for this notebook.

!conda install -c conda-forge mamba
!mamba create -n aligner -c conda-forge montreal-forced-aligner

WhisperX: https://github.com/m-bain/whisperX

In [1]:
!pip install whisperx

Collecting whisperx
  Downloading whisperx-3.7.4-py3-none-any.whl.metadata (16 kB)
Collecting ctranslate2>=4.5.0 (from whisperx)
  Downloading ctranslate2-4.6.3-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (10 kB)
Collecting faster-whisper>=1.1.1 (from whisperx)
  Downloading faster_whisper-1.2.1-py3-none-any.whl.metadata (16 kB)
Collecting nltk>=3.9.1 (from whisperx)
  Using cached nltk-3.9.2-py3-none-any.whl.metadata (3.2 kB)
Collecting pandas<2.3.0,>=2.2.3 (from whisperx)
  Downloading pandas-2.2.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (89 kB)
Collecting av<16.0.0 (from whisperx)
  Using cached av-15.1.0-cp313-cp313-manylinux_2_28_x86_64.whl.metadata (4.6 kB)
Collecting numpy<2.3.0,>=2.1.0 (from whisperx)
  Using cached numpy-2.2.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
Collecting pyannote-audio<4.0.0,>=3.3.2 (from whisperx)
  Using cached pyannote_audio-3.4.0-py2.py3-none-any.whl.metadata (11

Other libraries:

In [7]:
!pip install torch dotenv lingua-language-detector textgrid


Collecting textgrid
  Downloading TextGrid-1.6.1.tar.gz (9.4 kB)
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: textgrid
  Building wheel for textgrid (pyproject.toml) ... [?25ldone
[?25h  Created wheel for textgrid: filename=textgrid-1.6.1-py3-none-any.whl size=10217 sha256=64ff2667c529bd82d87005cd24beeb410a1bf3f445ef506cb259a9b2f2af58a4
  Stored in directory: /home/chengyi/.cache/pip/wheels/cf/06/ab/5166b15996f143ff63554ec508d9b52cbe0bb0b82e2a926446
Successfully built textgrid
Installing collected packages: textgrid
Successfully installed textgrid-1.6.1


In [3]:
import os
from dotenv import load_dotenv

# Load environment variables via python-dotenv, then read the Hugging Face
# token from the HF_KEY variable (None when the variable is not set).
load_dotenv()
HF_TOKEN = os.environ.get("HF_KEY")

In [4]:
# If you want to use a GPU for faster transcriptions (e.g. if you have Colab kernel connected)
import torch

# `device` is kept as a plain string ("cuda" / "cpu") in BOTH branches: the
# WhisperX loaders used later take the device name as a string, and mixing a
# torch.device object (GPU) with a str (CPU) made the variable's type depend
# on the hardware.
if torch.cuda.is_available():
    print("CUDA is available! Using GPU.")
    device = "cuda"
    # You can also get more information about the GPU
    gpu_count = torch.cuda.device_count()
    print(f"Number of GPUs: {gpu_count}")
    for i in range(gpu_count):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
else:
    print("CUDA is not available. Falling back to CPU.")
    device = "cpu"

print(f"Current device: {device}")

CUDA is not available. Falling back to CPU.
Current device: cpu


# Step 1: WhisperX

In [None]:
import whisperx
import gc

# Transcription settings for this run.
audio_file = '../input/DINA1_PS1_IDS1.wav'
batch_size = 16
compute_type = "float32"
model_name = "medium" # Empirically, I've seen that medium performs best at getting the utterance timings!

# 1. Transcribe with original whisper (batched); took ~13 mins on CPU
# NOTE(review): `device` is not defined in this cell; it must already exist in the kernel
# (it is set by the CUDA/CPU selection cell earlier in the notebook).
model = whisperx.load_model(model_name, device, compute_type=compute_type)
audio = whisperx.load_audio(audio_file)
result = model.transcribe(audio, batch_size=batch_size, language='es') # NOTE: For some reason on code-switching audio it's best to set it to the less dominant language

# # Optional: Save to disk
# import json
# import os
# basename = os.path.basename(audio_file).split('.')[0]
# with open(f"../output/{basename}_WhisperX1.json", "w") as f:
#     json.dump(result, f, indent=4)

print(result["segments"]) # before alignment

In [None]:
# 2. Align whisper output; took ~4 mins on CPU
# The alignment model is chosen from the language recorded in the transcription result.
# NOTE: `result` is overwritten here with the aligned output, so this cell is not
# idempotent; re-run the transcription cell before running it a second time.
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)

# # Optional: Save to disk
# import json
# import os
# basename = os.path.basename(audio_file).split('.')[0]
# with open(f"../output/{basename}_WhisperX2.json", "w") as f:
#     json.dump(result, f, indent=4)

print(result["segments"]) # after alignment

In [None]:
# 3. Assign speaker labels; took ~52 mins on CPU
diarize_model = whisperx.diarize.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)

# Run diarization exactly ONCE, passing the known speaker bounds.
# Previously the pipeline was called twice (unconstrained, then with
# min/max speakers) and the constrained result was thrown away, which doubled
# the runtime of the slowest step without the bounds ever taking effect.
# Adjust or drop min_speakers/max_speakers if the speaker count is unknown.
diarize_segments = diarize_model(audio, min_speakers=1, max_speakers=4)
result = whisperx.assign_word_speakers(diarize_segments, result)

# # Optional: Save to disk
# import json
# import os
# basename = os.path.basename(audio_file).split('.')[0]
# with open(f"../output/{basename}_WhisperX3.json", "w") as f:
#     json.dump(result, f, indent=4)


print(diarize_segments)
print(result["segments"]) # segments are now assigned speaker IDs

# Step 2: MFA

Because Whisper's word-level timestamps are unreliable, we additionally run MFA to get better timestamps.

In [253]:
# NOTE Development: load a previously saved WhisperX result from disk to avoid
# having to run step 1 all over again. This replaces any in-memory `result`.
import json
with open("../output/dina1_playsesh_whisperX_medium_final.json", "r") as f:
    result = json.load(f)
# Re-declare the audio path so the later steps work even when step 1 was skipped.
audio_file = '../input/DINA1_PS1_IDS1.wav'

In [254]:
# 1. Split utterances by language

# From another script I made: praat/detect_language.py
import os
import sys
import pandas as pd
from lingua import Language, LanguageDetectorBuilder
def detect_language(detector, text):
    """Label `text` as "Spanish", "English" or "Unknown".

    Text containing any Spanish-specific character (inverted punctuation,
    accented vowels, ñ, ü) is labelled "Spanish" without consulting the model.
    Otherwise `detector.detect_language_of(text)` decides; any result other
    than English or Spanish yields "Unknown".
    """
    spanish_markers = set("¡¿áéíóúñüÁÉÍÓÚÑÜ")
    if not spanish_markers.isdisjoint(text):
        return "Spanish"

    detected = detector.detect_language_of(text)
    if detected == Language.ENGLISH:
        return "English"
    if detected == Language.SPANISH:
        return "Spanish"
    return "Unknown"


# Initialize the language detector for English and Spanish
languages = [Language.ENGLISH, Language.SPANISH]
detector = LanguageDetectorBuilder.from_languages(*languages).build()

# Iterate and detect.
# Every segment gets a 'language' key: segments with missing/blank text are
# labelled "Unknown" (the same fallback detect_language uses) instead of being
# skipped, because later steps read segment['language'] unconditionally and
# would otherwise raise KeyError.
for segment in result['segments']:
    segment_text = segment.get('text') or ''
    if segment_text.strip():
        segment['language'] = detect_language(detector, segment_text.lower())
    else:
        segment['language'] = "Unknown"

# # Optional: Save to disk
# import json
# import os
# basename = os.path.basename(audio_file).split('.')[0]
# with open(f"../output/{basename}_WhisperX3.json", "w") as f:
#     json.dump(result, f, indent=4)

print(result["segments"]) # segments are now assigned languages


[{'start': 4.283, 'end': 6.886, 'text': ' Okay, so grab it to myself.', 'words': [{'word': 'Okay,', 'start': 4.283, 'end': 5.344, 'score': 0.283, 'speaker': 'SPEAKER_01'}, {'word': 'so', 'start': 5.364, 'end': 5.464, 'score': 0.297, 'speaker': 'SPEAKER_01'}, {'word': 'grab', 'start': 5.485, 'end': 5.665, 'score': 0.421, 'speaker': 'SPEAKER_01'}, {'word': 'it', 'start': 6.486, 'end': 6.606, 'score': 0.266, 'speaker': 'SPEAKER_01'}, {'word': 'to', 'start': 6.626, 'end': 6.666, 'score': 0.576, 'speaker': 'SPEAKER_01'}, {'word': 'myself.', 'start': 6.726, 'end': 6.886, 'score': 0.231, 'speaker': 'SPEAKER_01'}], 'speaker': 'SPEAKER_01', 'language': 'English'}, {'start': 6.906, 'end': 12.894, 'text': 'Oh my gosh, I feel like a vlogger.', 'words': [{'word': 'Oh', 'start': 6.906, 'end': 6.946, 'score': 0.004, 'speaker': 'SPEAKER_01'}, {'word': 'my', 'start': 6.966, 'end': 7.006, 'score': 0.015, 'speaker': 'SPEAKER_01'}, {'word': 'gosh,', 'start': 7.026, 'end': 7.787, 'score': 0.451, 'speaker':

In [3]:
# 2. Setting up for MFA
import os
import shutil

# # NOTE: Before doing any MFA ensure that you cleared your cache: Delete Documents/MFA
# !rm -rf ~/Documents/MFA
print("Cleared cache")  # NOTE: printed unconditionally; the command above is commented out

# # Install alignment models
# !mfa model download --ignore_cache acoustic english_us_arpa
# !mfa model download --ignore_cache dictionary english_us_arpa
# !mfa model download --ignore_cache acoustic spanish_mfa
# !mfa model download --ignore_cache dictionary spanish_mfa
print("Models downloaded")  # NOTE: printed unconditionally; the downloads above are commented out

# Create a textgrid with tier 0 being utterances, tier 1 being language
from textgrid import TextGrid, IntervalTier
tg = TextGrid()
utterances_tier = IntervalTier(name="WhisperX - Utterances", minTime=tg.minTime, maxTime=tg.maxTime)
languages_tier = IntervalTier(name="Lingua - Language", minTime=tg.minTime, maxTime=tg.maxTime)
for segment in result['segments']:
  utterances_tier.add(segment['start'], segment['end'], segment['text'])
  # A segment without a detected language gets an empty mark (filtered out below)
  # instead of raising KeyError.
  languages_tier.add(segment['start'], segment['end'], segment.get('language', ''))
tg.append(utterances_tier)
tg.append(languages_tier)
print("Created utterances and language textgrid")

# Split that TextGrid into one tier per detected language.
# Uses its own name (`language_labels`) so the detector's `languages` list from
# the language-detection cell is not overwritten with a set of strings.
language_labels = {interval.mark for interval in languages_tier.intervals} - {''}
languages2tier = {}
for language in sorted(language_labels):  # sorted => deterministic tier order
  tier_name = f"{language} Utterances"
  new_tier = IntervalTier(name=tier_name, minTime=tg.minTime, maxTime=tg.maxTime)
  for utterance_interval, language_interval in zip(utterances_tier.intervals, languages_tier.intervals):
    # Keep only non-empty utterances that belong to this language.
    if language_interval.mark == language and utterance_interval.mark:
      new_tier.add(utterance_interval.minTime, utterance_interval.maxTime, utterance_interval.mark)
  tg.append(new_tier)
  languages2tier[tier_name] = new_tier
print("Finished splitting languages into separate tiers")

# Output these files
english_path = '../chengyi-mfa/input/english/'
spanish_path = '../chengyi-mfa/input/spanish/'
tier_name2dir = {  # Assuming you just have English & Spanish
  "English Utterances": english_path,
  "Spanish Utterances": spanish_path,
}
# Swap only the file extension (a plain .replace('.wav', ...) would also touch
# '.wav' elsewhere in the name and miss other extensions such as '.WAV').
textgrid_name = os.path.splitext(os.path.basename(audio_file))[0] + '.TextGrid'
for out_dir in tier_name2dir.values():
  os.makedirs(out_dir, exist_ok=True)  # writing/copying fails if the folder is missing
for tier_name, out_dir in tier_name2dir.items():
  if tier_name not in languages2tier:
    continue  # no utterances were detected for this language
  new_tg = TextGrid()
  new_tg.append(languages2tier[tier_name])
  new_tg.write(os.path.join(out_dir, textgrid_name))
print("Output files created")

# Make copies of the audio file into those directories (pure Python instead of `!cp`)
for out_dir in tier_name2dir.values():
  shutil.copy(audio_file, out_dir)
print("Copied audio files, script done")

Cleared cache
Models downloaded
Created utterances and language textgrid
Finished splitting languages into separate tiers
Output files created
Copied audio files, script done


In [19]:
# 2.5 Validate the corpora to ensure they are in the correct form (each one took me around 3 minutes).
# Arguments after the corpus path are the dictionary and acoustic model names.
# NOTE: `english_path` / `spanish_path` come from the MFA setup cell above.
!mfa validate "$english_path" english_us_arpa english_us_arpa
!mfa validate "$spanish_path" spanish_mfa spanish_mfa

[2;36m [0m[32mINFO    [0m Setting up corpus information[33m...[0m                                      
[2;36m [0m[32mINFO    [0m Loading corpus from source files[33m...[0m                                   
[2K[35m   1%[0m [38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1/100 [0m [ [33m0:00:01[0m < [36m-:--:--[0m , [31m? it/s[0m ]
[?25h[2;36m [0m[32mINFO    [0m Found [1;36m1[0m speaker across [1;36m1[0m file, average number of utterances per       
[2;36m [0m         speaker: [1;36m268.0[0m                                                        
[2;36m [0m[32mINFO    [0m Initializing multiprocessing jobs[33m...[0m                                  
[2;36m [0m         MFA will only use [1;36m1[0m jobs. Use the --single_speaker flag if you would  
[2;36m [0m         like to split utterances across jobs regardless of their speaker.     
[2;36m [0m[32mINFO    [0m Normalizing text[33m...[0m                                        

In [6]:
# 3. Run MFA on each of those language utterances separately
# The aligned TextGrids are written to <output_path>/english/ and <output_path>/spanish/
# (the shell concatenates the expanded "$output_path" with the literal sub-folder name).
output_path = "../chengyi-mfa/output/"
!mfa align "$english_path" english_us_arpa english_us_arpa "$output_path""english/"
!mfa align "$spanish_path" spanish_mfa spanish_mfa "$output_path""spanish/"

[2;36m [0m[32mINFO    [0m Setting up corpus information[33m...[0m                                      
[2;36m [0m[32mINFO    [0m Loading corpus from source files[33m...[0m                                   
[2K[35m   1%[0m [38;5;237m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1/100 [0m [ [33m0:00:01[0m < [36m-:--:--[0m , [31m? it/s[0m ]
[?25h[2;36m [0m[32mINFO    [0m Found [1;36m1[0m speaker across [1;36m1[0m file, average number of utterances per       
[2;36m [0m         speaker: [1;36m268.0[0m                                                        
[2;36m [0m[32mINFO    [0m Initializing multiprocessing jobs[33m...[0m                                  
[2;36m [0m         MFA will only use [1;36m1[0m jobs. Use the --single_speaker flag if you would  
[2;36m [0m         like to split utterances across jobs regardless of their speaker.     
[2;36m [0m[32mINFO    [0m Normalizing text[33m...[0m                                        

# Step 3: Final adjustments

In [255]:
from textgrid import TextGrid, IntervalTier
import pandas as pd

# 1. Read textgrids
english_tg = TextGrid()
spanish_tg = TextGrid()
english_tg.read("../chengyi-mfa/output/english/DINA1_PS1_IDS1.TextGrid")
spanish_tg.read("../chengyi-mfa/output/spanish/DINA1_PS1_IDS1.TextGrid")

# 2. Grab intervals from the first tier of each file
# (presumably MFA's word tier; TODO confirm the tier order of the MFA output)
english_intervals = list(english_tg[0])
spanish_intervals = list(spanish_tg[0])

# 3. Convert these intervals into records for a dataframe (easier manipulation)
interval_records = [
  {
    "start": interval.minTime,
    "end": interval.maxTime,
    "text": interval.mark,
    "type": "word (mfa)",
  }
  for interval in english_intervals + spanish_intervals
]

# 3.5 Add WhisperX utterances too for order
interval_records += [
  {
    "start": segment["start"],
    "end": segment["end"],
    "text": segment['text'],
    "type": "utterance (whisper)"
  }
  for segment in result['segments']
]

# 3.75 Create dataframe
# Sort by start time and, on ties, put the utterance row BEFORE its words.
# Sorting on "start" alone left ties in an arbitrary order, so a word starting
# exactly at its utterance's start could end up ahead of that utterance and be
# attached to the previous one by the grouping step that follows.
type_order = {"utterance (whisper)": 0, "word (mfa)": 1}
intervals_df = pd.DataFrame(interval_records)
intervals_df["_type_order"] = intervals_df["type"].map(type_order)
intervals_df = (
  intervals_df
  .sort_values(["start", "_type_order"])
  .drop(columns="_type_order")
  .reset_index(drop=True)
)
# Drop empty-text rows
intervals_df = intervals_df[intervals_df["text"] != ""]
intervals_df.head(10)

Unnamed: 0,start,end,text,type
2,4.283,6.886,"Okay, so grab it to myself.",utterance (whisper)
3,4.353,4.593,okay,word (mfa)
4,4.593,4.813,so,word (mfa)
5,4.813,5.693,grab,word (mfa)
7,6.193,6.223,it,word (mfa)
8,6.223,6.283,to,word (mfa)
9,6.283,6.833,myself,word (mfa)
11,6.906,12.894,"Oh my gosh, I feel like a vlogger.",utterance (whisper)
12,6.986,7.326,oh,word (mfa)
14,7.356,7.606,my,word (mfa)


In [256]:
# 3.875 Group each MFA word under the Whisper utterance row that precedes it
utterance_intervals = []
current_utt = None
for record in intervals_df.to_dict("records"):
    row_type = record["type"]

    if row_type == "utterance (whisper)":
        # An utterance row opens a new group
        current_utt = dict(
            start=record["start"],
            end=record["end"],
            text=record["text"],
            words=[],
        )
        utterance_intervals.append(current_utt)
        continue

    # Only MFA words that follow an utterance are kept
    if row_type != "word (mfa)" or current_utt is None:
        continue
    current_utt["words"].append(
        dict(start=record["start"], end=record["end"], word=record["text"])
    )

# 3.9375 Use MFA's first/last word boundaries as the utterance times
# (utterances without any MFA word keep their Whisper times)
for utterance_interval in utterance_intervals:
    mfa_words = utterance_interval['words']
    if not mfa_words:
        continue
    utterance_interval['start'] = mfa_words[0]['start']
    utterance_interval['end'] = mfa_words[-1]['end']

In [257]:
# 4. Adjust WhisperX with MFA output

# Check: the two lists are paired positionally below, so they must line up.
print(len(utterance_intervals), len(result['segments']), "These should be equal")
if len(utterance_intervals) != len(result['segments']):
  # zip() would silently truncate/misalign and write MFA timings onto the wrong
  # segments, so stop here instead of corrupting `result`.
  raise ValueError(
    f"Utterance count mismatch: {len(utterance_intervals)} grouped utterances vs "
    f"{len(result['segments'])} WhisperX segments"
  )

# TODO: The current approach loses information like OOVs, confidence scores, and speaker
# OOVs: (i.e. Whisper transcribes "vlogger" but MFA's model doesn't know how to transcribe that)
# Confidence scores: How confident the Whisper model is at transcribing that word
# Speaker: The word-level speaker labels from WhisperX's diarization step

# Anyway, we can still naively replace Whisper's word-level timings with MFA's
for utterance_interval, result_interval in zip(utterance_intervals, result['segments']):
  result_interval['start'] = utterance_interval['start']
  result_interval['end'] = utterance_interval['end']
  result_interval['words'] = utterance_interval['words']
print("Done changing utterance segments")

# Flatten the per-segment words into the top-level word list
result["word_segments"] = [
  word_interval
  for result_interval in result['segments']
  for word_interval in result_interval['words']
]
print("Done changing word segments")

312 312 These should be equal
Done changing utterance segments
Done changing word segments


In [224]:
# Inspect which top-level keys `result` holds after the MFA adjustment.
result.keys()

dict_keys(['segments', 'word_segments'])

`'segments'` and `'word_segments'` now have a list of segments that were first transcribed by WhisperX and then adjusted with MFA.

# Step 4: Export into Praat

In [258]:
# Visualize format: top-level keys, then the keys of the first entry under each of them
print(result.keys())
for entries in result.values():
  print(entries[0].keys())

dict_keys(['segments', 'word_segments'])
dict_keys(['start', 'end', 'text', 'words', 'speaker', 'language'])
dict_keys(['start', 'end', 'word'])


In [259]:
# Check for overlaps (utterances)
import string

# `remove_punctuation` is not defined anywhere in this notebook, so on a fresh
# kernel this cell raised NameError. Define a fallback only when it is missing,
# so an existing definition (e.g. from an imported helper) keeps precedence.
try:
  remove_punctuation
except NameError:
  def remove_punctuation(text):
    """Fallback: drop ASCII punctuation plus ¡/¿ from `text` and strip whitespace."""
    return text.translate(str.maketrans('', '', string.punctuation + "¡¿")).strip()

# Build a new list instead of deleting from result['segments'] while iterating
# over it: in-loop deletion skipped the element after every removal, and when
# both cases matched it deleted two segments.
kept_segments = []
for current_segment in result['segments']:
  previous_segment = kept_segments[-1] if kept_segments else None
  if previous_segment is None or previous_segment['end'] <= current_segment['start']:
    kept_segments.append(current_segment)  # no overlap with the last kept segment
    continue

  # Guard against segments without words / without text (previously IndexError)
  last_word = previous_segment['words'][-1] if previous_segment['words'] else None
  current_tokens = current_segment['text'].split()

  # Case: In this particular file, it seems that we don't lose any information if we just remove the current one
  if (last_word is not None and current_tokens
      and remove_punctuation(last_word['word'].lower()) == remove_punctuation(current_tokens[0].lower())):
    continue

  # Case: OOV item (i.e. MFA detects <unk> but whisper detects "Chekayo")
  # -> merge the single-word segment into the previous segment's last word
  if last_word is not None and last_word['word'] == "<unk>" and len(current_tokens) == 1:
    previous_segment['end'] = current_segment['end']
    last_word['end'] = current_segment['end']
    last_word['word'] = remove_punctuation(current_segment['text'].lower())
    continue

  # Debug: unresolved overlap, keep the segment and report it
  print("Overlap detected")
  print("Previous:", previous_segment)
  print("Current:", current_segment)
  print()
  kept_segments.append(current_segment)

result['segments'] = kept_segments
# Rebuild the flat word list so words of removed segments do not linger in it
result['word_segments'] = [word for segment in kept_segments for word in segment['words']]

In [260]:
# Check for overlaps (words)
# Report every word that starts before the previous word has ended.
word_segments = result['word_segments']
prev = 0
for index, word in enumerate(word_segments):
  if prev > word['start']:
    # `index` is a position in the WORD list, so print from word_segments
    # (indexing result['segments'] with it showed unrelated utterances or
    # raised IndexError once index exceeded the number of segments).
    print("Overlap detected")
    print("Previous:", word_segments[index-1])
    print("Current:", word)
    print()
  prev = word['end']

In [261]:
from textgrid import TextGrid, IntervalTier

# Initialize the output grid and its three tiers
output_tg_path = "../output/WhisperX_MFA_Demo.TextGrid"
final_tg = TextGrid()
utterances_tier, languages_tier, words_tier = (
  IntervalTier(name=tier_label, minTime=final_tg.minTime, maxTime=final_tg.maxTime)
  for tier_label in ("Utterances", "Languages", "Words")
)

# Fill the tiers. When an interval is rejected with ValueError, retry with the
# start shifted forward by `nudge`.
nudge = 0.001 # For overlaps
for segment in result['segments']:
  seg_start, seg_end = segment['start'], segment['end']
  try:
    utterances_tier.add(seg_start, seg_end, segment['text'])
    languages_tier.add(seg_start, seg_end, segment['language'])
  except ValueError:
    nudged_start = seg_start + nudge
    utterances_tier.add(nudged_start, seg_end, segment['text'])
    languages_tier.add(nudged_start, seg_end, segment['language'])

for word in result['word_segments']:
  word_start, word_end, word_mark = word['start'], word['end'], word['word']
  try:
    words_tier.add(word_start, word_end, word_mark)
  except ValueError:
    words_tier.add(word_start + nudge, word_end, word_mark)

# Attach the tiers in display order and write the file
for tier in (utterances_tier, languages_tier, words_tier):
  final_tg.append(tier)
final_tg.write(output_tg_path)
print("Done")

Done


# Future directions

Action items:
* Bulk transcriptions.
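* Preserving the extra information that is currently lost when Whisper's words are replaced with MFA's (Step 3, sub-step 4): OOVs, confidence scores, and speaker labels.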
* An interface for checking the timestamps & adjusting them (Praat). 
* A fine-tuned version of Whisper on code-switched data may work better as it will be able to handle code-switching within utterances (for example, it will have the vocabulary to do so).
* Multiple passes of the WhisperX -> MFA cycle might produce better results.