# 0.0.0 WhisperX/Pyannote Transcription+Diarization Pipeline 

This Jupyter notebook is designed to test and evaluate a new Transcription and Diarization Pipeline with the following objectives:
1. Achieving word-level transcription accuracy to ensure detailed and precise text representation of the audio input.
2. Assessing diarization confidence levels to accurately attribute spoken segments to different speakers and measure the reliability of speaker identification.
3. Enhancing the alignment of transcriptions to be closer to natural sentence segments, thereby improving the readability and usability of the transcribed data.

The notebook uses the Whisper, WhisperX, and pyannote libraries for transcription, alignment, and diarization. With GPU acceleration it processes audio efficiently and writes structured outputs in CSV, TXT, JSON, and WebVTT formats for further analysis. Resources and installation instructions are included below to help with setup and execution.

Resources:
https://towardsdatascience.com/unlock-the-power-of-audio-data-advanced-transcription-and-diarization-with-whisper-whisperx-and-ed9424307281 

# 0.1 Setup
WhisperX documentation found here: https://github.com/m-bain/whisperX
================================================
1. Install Git
2. Install FFMPEG and add to PATH
3. Install Anaconda 

================================================   
4. Create Conda environment
conda create -n whisperxtranscription-env python=3.10
conda activate whisperxtranscription-env

5. Install PyTorch https://pytorch.org/get-started/locally/ 
pip install numpy==1.26.3 torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121

6. Install WhisperX repository and additional packages
pip install whisperx speechbrain jupyter ipywidgets charset-normalizer pandas nltk plotly matplotlib webvtt-py pypi-json srt python-dotenv

7. Create .env file at the same level as this notebook file with the following line
HF_TOKEN="REPLACEWITHHUGGINGFACETOKENHERE"

=================================================
8. For GPU usage:
Install Visual Studio Community https://visualstudio.microsoft.com/downloads/
Install NVIDIA CUDA Toolkit 12.1 https://developer.nvidia.com/cuda-12-1-0-download-archive 

Check PyTorch and CUDA installation
import torch
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))


=================================================
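Fix NumPy (if another install changed its version)
pip uninstall numpy -y
pip install numpy==1.26.3

Fix PyTorch (uninstall, then reinstall with the versions from step 5)
pip uninstall torch torchvision torchaudio -y
pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121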


In [1]:
import numpy
print(numpy.__version__)


1.26.4


In [None]:
%pip install git+https://github.com/m-bain/whisperx.git

# 0.2 Check once to see if CUDA GPU is available and PyTorch is working properly

In [1]:
# Check if CUDA GPU is available to PyTorch
import torch                                                # PyTorch
print(torch.__version__)
cuda_ok = torch.cuda.is_available()                         # True when PyTorch can see a CUDA GPU
if cuda_ok:
    torch.cuda.set_device(0)                                # Use the first GPU; calling this without CUDA raises an error
cuda_ok, (torch.cuda.get_device_name() if cuda_ok else None)  # GPU availability and the name of the GPU

2.3.0+cu121


(True, 'NVIDIA GeForce RTX 4060 Laptop GPU')

# 1.0 Setup - Start here by adjusting variables
1. Choose the batch size, compute type, Whisper model, and file extensions to transcribe.

In [2]:
import os
from tkinter import Tk, filedialog, messagebox
import pandas as pd
import warnings
import torch
import whisperx
import gc
import json
import webvtt
import logging
from dotenv import load_dotenv  # From the python-dotenv package installed in setup step 6


# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning)
logging.getLogger("speechbrain.utils.quirks").setLevel(logging.WARNING)

# Configuration
load_dotenv()  # Load HF_TOKEN from the .env file created in setup step 7
device = "cuda" if torch.cuda.is_available() else "cpu" # Set device to GPU if available, otherwise CPU
if device == "cuda":
    torch.cuda.set_device(0)  # Use the first GPU
language = "en"  # Set the language code en=English, es=Spanish, etc.
task = "transcribe"  # Set the task to "transcribe" or "translate"
batch_size = 16 # Set the batch size for processing
compute_type = "float16" # "float16" is fastest on GPU; use "int8" or "float32" when running on CPU
hf_token = os.getenv('HF_TOKEN') # Hugging Face token used by the diarization model
whisperx_model = "large-v3-turbo" # Set the WhisperX model to use
extensions = ['.WAV', '.m4a', '.mp3'] # Supported audio file extensions


# 2.0 Run - after adjusting variables first

Run this cell. You shouldn't need to change anything here unless you want to output more or fewer file types. The cell mostly defines functions, which are then called at the end of the cell.

1. You should get a popup asking to choose the folder where the files are found (It will also search subfolders).

2. You should then get a popup asking for where the transcription files should be placed (It will replicate the folder structure in which they were found)

3. You will also see a popup asking if you want to anonymize with a pseudonyms.csv file, and if so where it is located.

4. You should then see an output similar to the following (just ignore the warnings):

Model was trained with pyannote.audio 0.0.1, yours is 3.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.3.0+cu121. Bad things might happen unless you revert torch to 1.x.

5. When processing completes, you will see a confirmation message for each file, and the transcripts will be in the output folder you selected.
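
The pseudonyms file from step 3 is read by the code in the next cell with a `name` column and a `pseudonym` column. A minimal sketch of that layout and of how it becomes a lookup table, using only the standard library. The helper name `load_pseudonyms` and the sample names are invented for illustration and are not part of the notebook:

```python
import csv
import io

# Hypothetical contents of pseudonyms.csv. The header row must contain
# the "name" and "pseudonym" columns that the notebook looks up.
PSEUDONYMS_CSV = """name,pseudonym
Alice,Participant A
Bob,Participant B
"""

def load_pseudonyms(handle):
    """Build a {real_name: pseudonym} dict from an open CSV file handle."""
    reader = csv.DictReader(handle)
    return {row["name"]: row["pseudonym"] for row in reader}

pseudonym_dict = load_pseudonyms(io.StringIO(PSEUDONYMS_CSV))
print(pseudonym_dict)
```

With a real file, the same function could be called as `load_pseudonyms(open("pseudonyms.csv", newline="", encoding="utf-8"))`.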


In [None]:
from tkinter import Tk, filedialog, messagebox

# Functions
def find_audio_files(base_dir, extensions):
    # Compare in lower case so that '.wav' and '.WAV' are both matched
    suffixes = tuple(ext.lower() for ext in extensions)
    audio_files = []
    for root, _, files in os.walk(base_dir):
        for file in files:
            if file.lower().endswith(suffixes):
                audio_files.append(os.path.join(root, file))
    return audio_files

def anonymize_text(text, pseudonym_dict):
    for real_name, pseudonym in pseudonym_dict.items():
        text = text.replace(real_name, pseudonym)
    return text

def format_vtt_timestamp(seconds):
    hours, remainder = divmod(seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    milliseconds = int((seconds % 1) * 1000)
    return f"{int(hours):02}:{int(minutes):02}:{int(seconds):02}.{milliseconds:03}"

def save_transcripts(segments, output_dir, relative_path, pseudonym_dict=None):
    if pseudonym_dict:
        for segment in segments:
            segment['text'] = anonymize_text(segment['text'], pseudonym_dict)
    for i, segment in enumerate(segments):
        segment['sentence_number'] = i + 1
    df = pd.DataFrame(segments)
    df['text'] = df['text'].apply(lambda x: x.lstrip())
    cols = ['sentence_number'] + [col for col in df.columns if col != 'sentence_number']
    df = df[cols]

    os.makedirs(output_dir, exist_ok=True)
    base_filename = os.path.splitext(os.path.basename(relative_path))[0]
    csv_path = os.path.join(output_dir, f"{base_filename}_transcription.csv")
    df.to_csv(csv_path, index=False, encoding='utf-8-sig')

    with open(os.path.join(output_dir, f"{base_filename}_transcription.txt"), 'w', encoding='utf-8') as f:
        for segment in segments:
            f.write(f"{segment['text'].strip()}\n")

    json_path = os.path.join(output_dir, f"{base_filename}_transcription.json")
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(segments, f, ensure_ascii=False, indent=4)

    vtt = webvtt.WebVTT()
    for segment in segments:
        caption = webvtt.Caption()
        caption.start = format_vtt_timestamp(segment['start'])
        caption.end = format_vtt_timestamp(segment['end'])
        caption.lines = [f"{segment['sentence_number']}: {segment['text'].strip()}"]
        vtt.captions.append(caption)
    vtt.save(os.path.join(output_dir, f"{base_filename}_transcription.vtt"))

def process_audio_file(audio_file, base_output_dir, relative_path, pseudonym_dict=None):
    try:
        print(f"Processing {audio_file}...")
        # Step 1: transcribe with the Whisper model, then free its GPU memory
        audio = whisperx.load_audio(audio_file)
        model = whisperx.load_model(whisperx_model, device, compute_type=compute_type)
        result = model.transcribe(audio, batch_size=batch_size, language=language, task=task)
        del model
        gc.collect()
        torch.cuda.empty_cache()

        # Step 2: align the transcript to the audio, then free the alignment model
        model_a, metadata = whisperx.load_align_model(language_code=language, device=device)
        result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
        del model_a
        gc.collect()
        torch.cuda.empty_cache()

        # Step 3: diarize and attach speaker labels to the aligned transcript
        diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)
        diarize_segments = diarize_model(audio)
        result = whisperx.assign_word_speakers(diarize_segments, result)

        # Step 4: write the outputs into a folder that mirrors the input structure
        output_dir = os.path.join(base_output_dir, os.path.dirname(relative_path))
        save_transcripts(result["segments"], output_dir, relative_path, pseudonym_dict)
    except Exception as e:
        print(f"Error processing {audio_file}: {e}")

def main():
    # Initialize Tkinter
    root = Tk()
    root.withdraw()  # Hide the main window
    
    # Bring the root window to the front
    root.attributes('-topmost', True)

    # Popup for input folder
    input_folder = filedialog.askdirectory(title="Select Folder Containing Audio/Video Files")
    if not input_folder:
        print("No folder selected. Exiting.")
        return

    # Popup for output folder
    output_folder = filedialog.askdirectory(title="Select Folder to Save Transcriptions")
    if not output_folder:
        print("No output folder selected. Exiting.")
        return

    # Ask if a pseudonyms.csv file will be used
    use_pseudonyms = messagebox.askyesno("Pseudonyms", "Will you use a pseudonyms.csv file to anonymize the transcripts?")
    pseudonym_dict = None

    if use_pseudonyms:
        pseudonyms_file = filedialog.askopenfilename(
            title="Select Pseudonyms CSV File",
            filetypes=[("CSV files", "*.csv")]
        )
        if not pseudonyms_file:
            print("No pseudonyms file selected. Continuing without pseudonymization.")
        else:
            # Load the pseudonyms file
            pseudonyms_df = pd.read_csv(pseudonyms_file)
            pseudonym_dict = dict(zip(pseudonyms_df['name'], pseudonyms_df['pseudonym']))
            print(f"Pseudonyms loaded from {pseudonyms_file}.")

    # Find and process audio files
    audio_files = find_audio_files(input_folder, extensions)
    print(f"Found {len(audio_files)} files to process.")

    for audio_file in audio_files:
        relative_path = os.path.relpath(audio_file, input_folder)
        process_audio_file(audio_file, output_folder, relative_path, pseudonym_dict)
        print(f"Processed {audio_file}")

    print("All files processed.")

if __name__ == "__main__":
    main()
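
`anonymize_text` uses plain `str.replace`, so a short name can also match inside a longer word (for example, "Al" inside "Always"). One possible alternative, which is not part of the original notebook, is to match whole words with a regular expression. The function name `anonymize_whole_words` and the sample names are hypothetical:

```python
import re

def anonymize_whole_words(text, pseudonym_dict):
    """Replace names only where they appear as whole words."""
    if not pseudonym_dict:
        return text
    # Longest names first so that "Ann Marie" is tried before "Ann"
    names = sorted(pseudonym_dict, key=len, reverse=True)
    pattern = re.compile(r"\b(?:" + "|".join(re.escape(n) for n in names) + r")\b")
    return pattern.sub(lambda match: pseudonym_dict[match.group(0)], text)

sample = "Al said that Always On is Al's favourite show."
print(sample.replace("Al", "P1"))                   # plain replace also changes "Always"
print(anonymize_whole_words(sample, {"Al": "P1"}))  # whole-word match leaves "Always" alone
```

The `\b` anchors only work for names that begin and end with a letter, digit, or underscore, so entries such as "Dr." would need a different boundary rule.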


# Combined Script

In [None]:
import os
from tkinter import Tk, filedialog, messagebox
import pandas as pd
import warnings
import torch
import whisperx
import gc
import json
import webvtt
from dotenv import load_dotenv  # From the python-dotenv package installed in setup step 6

# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Configuration
load_dotenv()  # Load HF_TOKEN from the .env file created in setup step 7
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    torch.cuda.set_device(0)  # Use the first GPU
language = "en"  # Set the language code en=English, es=Spanish, etc.
task = "transcribe"  # Set the task to "transcribe" or "translate"
batch_size = 16
compute_type = "float16"  # Use "int8" or "float32" when running on CPU
hf_token = os.getenv('HF_TOKEN')
whisperx_model = "large-v3-turbo"
extensions = ['.WAV', '.m4a', '.mp3']


# Functions
def find_audio_files(base_dir, extensions):
    # Compare in lower case so that '.wav' and '.WAV' are both matched
    suffixes = tuple(ext.lower() for ext in extensions)
    audio_files = []
    for root, _, files in os.walk(base_dir):
        for file in files:
            if file.lower().endswith(suffixes):
                audio_files.append(os.path.join(root, file))
    return audio_files

def anonymize_text(text, pseudonym_dict):
    for real_name, pseudonym in pseudonym_dict.items():
        text = text.replace(real_name, pseudonym)
    return text

def format_vtt_timestamp(seconds):
    hours, remainder = divmod(seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    milliseconds = int((seconds % 1) * 1000)
    return f"{int(hours):02}:{int(minutes):02}:{int(seconds):02}.{milliseconds:03}"

def save_transcripts(segments, output_dir, relative_path, pseudonym_dict=None):
    if pseudonym_dict:
        for segment in segments:
            segment['text'] = anonymize_text(segment['text'], pseudonym_dict)
    for i, segment in enumerate(segments):
        segment['sentence_number'] = i + 1
    df = pd.DataFrame(segments)
    df['text'] = df['text'].apply(lambda x: x.lstrip())
    cols = ['sentence_number'] + [col for col in df.columns if col != 'sentence_number']
    df = df[cols]

    os.makedirs(output_dir, exist_ok=True)
    base_filename = os.path.splitext(os.path.basename(relative_path))[0]
    csv_path = os.path.join(output_dir, f"{base_filename}_transcription.csv")
    df.to_csv(csv_path, index=False, encoding='utf-8-sig')

    with open(os.path.join(output_dir, f"{base_filename}_transcription.txt"), 'w', encoding='utf-8') as f:
        for segment in segments:
            f.write(f"{segment['text'].strip()}\n")

    json_path = os.path.join(output_dir, f"{base_filename}_transcription.json")
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(segments, f, ensure_ascii=False, indent=4)

    vtt = webvtt.WebVTT()
    for segment in segments:
        caption = webvtt.Caption()
        caption.start = format_vtt_timestamp(segment['start'])
        caption.end = format_vtt_timestamp(segment['end'])
        caption.lines = [f"{segment['sentence_number']}: {segment['text'].strip()}"]
        vtt.captions.append(caption)
    vtt.save(os.path.join(output_dir, f"{base_filename}_transcription.vtt"))

def process_audio_file(audio_file, base_output_dir, relative_path, pseudonym_dict=None):
    try:
        print(f"Processing {audio_file}...")
        audio = whisperx.load_audio(audio_file)
        model = whisperx.load_model(whisperx_model, device, compute_type=compute_type)
        result = model.transcribe(audio, batch_size=batch_size, language=language, task=task)
        del model
        gc.collect()
        torch.cuda.empty_cache()

        model_a, metadata = whisperx.load_align_model(language_code=language, device=device)
        result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
        del model_a
        gc.collect()
        torch.cuda.empty_cache()

        diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)
        diarize_segments = diarize_model(audio)
        result = whisperx.assign_word_speakers(diarize_segments, result)

        output_dir = os.path.join(base_output_dir, os.path.dirname(relative_path))
        save_transcripts(result["segments"], output_dir, relative_path, pseudonym_dict)
    except Exception as e:
        print(f"Error processing {audio_file}: {e}")

def main():
    # Initialize Tkinter
    root = Tk()
    root.withdraw()  # Hide the main window
    
    # Bring the root window to the front
    root.attributes('-topmost', True)

    # Popup for input folder
    input_folder = filedialog.askdirectory(title="Select Folder Containing Audio/Video Files")
    if not input_folder:
        print("No folder selected. Exiting.")
        return

    # Popup for output folder
    output_folder = filedialog.askdirectory(title="Select Folder to Save Transcriptions")
    if not output_folder:
        print("No output folder selected. Exiting.")
        return

    # Ask if a pseudonyms.csv file will be used
    use_pseudonyms = messagebox.askyesno("Pseudonyms", "Will you use a pseudonyms.csv file to anonymize the transcripts?")
    pseudonym_dict = None

    if use_pseudonyms:
        pseudonyms_file = filedialog.askopenfilename(
            title="Select Pseudonyms CSV File",
            filetypes=[("CSV files", "*.csv")]
        )
        if not pseudonyms_file:
            print("No pseudonyms file selected. Continuing without pseudonymization.")
        else:
            # Load the pseudonyms file
            pseudonyms_df = pd.read_csv(pseudonyms_file)
            pseudonym_dict = dict(zip(pseudonyms_df['name'], pseudonyms_df['pseudonym']))
            print(f"Pseudonyms loaded from {pseudonyms_file}.")

    # Find and process audio files
    audio_files = find_audio_files(input_folder, extensions)
    print(f"Found {len(audio_files)} files to process.")

    for audio_file in audio_files:
        relative_path = os.path.relpath(audio_file, input_folder)
        process_audio_file(audio_file, output_folder, relative_path, pseudonym_dict)
        print(f"Processed {audio_file}")

    print("All files processed.")

if __name__ == "__main__":
    main()
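
`format_vtt_timestamp` takes the milliseconds from the fractional part of a float, which truncates and can lose a millisecond to floating-point error. A possible variant, offered as a sketch rather than a change to the notebook, rounds to whole milliseconds first and then splits the value with integer arithmetic. The name `format_vtt_timestamp_ms` is hypothetical:

```python
def format_vtt_timestamp_ms(seconds):
    """Format seconds as HH:MM:SS.mmm using integer millisecond arithmetic."""
    total_ms = int(round(float(seconds) * 1000))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1000)
    return f"{hours:02}:{minutes:02}:{secs:02}.{millis:03}"

print(format_vtt_timestamp_ms(3725.5))   # 1 h, 2 min, 5.5 s
print(format_vtt_timestamp_ms(1.001))    # keeps the single millisecond
print(format_vtt_timestamp_ms(59.9996))  # rounds up into the next minute
```

Rounding instead of truncating shifts some timestamps by at most one millisecond, which is usually not noticeable in captions.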


# Attempted Gemini Revision Test

In [None]:
import os
from tkinter import Tk, filedialog, messagebox
import pandas as pd
import warnings
import torch
import whisperx
import gc
import json
import webvtt
import logging
from tqdm import tqdm
import configparser

# --- Configuration (read from config.ini in the working directory) ---
config = configparser.ConfigParser()
if not config.read('config.ini'):
    # config.read() returns an empty list when the file is missing
    raise FileNotFoundError("config.ini not found; create it before running this cell")

DEVICE = config['whisper']['device']
LANGUAGE = config['whisper']['language']
WHISPERX_MODEL = config['whisper']['model']
BATCH_SIZE = config['processing'].getint('batch_size')
COMPUTE_TYPE = config['processing']['compute_type']
EXTENSIONS_STR = config['files']['extensions']
EXTENSIONS = [ext.strip() for ext in EXTENSIONS_STR.split(',')]
HF_TOKEN = config['auth']['hf_token']
TASK = config['whisper']['task']
CHUNK_SIZE = 30  # Passed to model.transcribe() as chunk_size


# --- Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Functions ---
def find_audio_files(base_dir, extensions):
    """Finds all audio files with specified extensions in a directory and its subdirectories."""
    # Compare in lower case so that '.wav' and '.WAV' are both matched
    suffixes = tuple(ext.lower() for ext in extensions)
    audio_files = []
    for root, _, files in os.walk(base_dir):
        for file in files:
            if file.lower().endswith(suffixes):
                audio_files.append(os.path.join(root, file))
    return audio_files

def anonymize_text(text, pseudonym_dict):
    """Replaces real names in text with pseudonyms from a dictionary."""
    if pseudonym_dict:
        for real_name, pseudonym in pseudonym_dict.items():
            text = text.replace(real_name, pseudonym)
    return text

def format_vtt_timestamp(seconds):
    """Formats a time in seconds to the WebVTT timestamp format (HH:MM:SS.mmm)."""
    hours, remainder = divmod(seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    milliseconds = int((seconds % 1) * 1000)
    return f"{int(hours):02}:{int(minutes):02}:{int(seconds):02}.{milliseconds:03}"

def save_transcripts(segments, output_dir, relative_path, pseudonym_dict=None):
    """Saves the transcription segments to CSV, TXT, JSON, and VTT files."""
    os.makedirs(output_dir, exist_ok=True)
    base_filename = os.path.splitext(os.path.basename(relative_path))[0]

    # Clean, anonymize, and number copies of the segments so the caller's dicts stay unchanged
    processed_segments = []
    for i, segment in enumerate(segments):
        segment = segment.copy()
        segment['text'] = segment['text'].lstrip()
        if pseudonym_dict:
            segment['text'] = anonymize_text(segment['text'], pseudonym_dict)
        segment['sentence_number'] = i + 1
        processed_segments.append(segment)

    # Save to CSV
    df = pd.DataFrame(processed_segments)
    cols = ['sentence_number'] + [col for col in df.columns if col != 'sentence_number']
    df = df[cols]
    csv_path = os.path.join(output_dir, f"{base_filename}_transcription.csv")
    df.to_csv(csv_path, index=False, encoding='utf-8-sig')
    logging.info(f"Saved CSV transcription to {csv_path}")

    # Save to TXT
    txt_path = os.path.join(output_dir, f"{base_filename}_transcription.txt")
    with open(txt_path, 'w', encoding='utf-8') as f:
        for segment in processed_segments:
            f.write(f"{segment['text'].strip()}\n")
    logging.info(f"Saved TXT transcription to {txt_path}")

    # Save to JSON
    json_path = os.path.join(output_dir, f"{base_filename}_transcription.json")
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(processed_segments, f, ensure_ascii=False, indent=4)
    logging.info(f"Saved JSON transcription to {json_path}")

    # Save to VTT
    vtt = webvtt.WebVTT()
    for segment in processed_segments:
        caption = webvtt.Caption()
        caption.start = format_vtt_timestamp(segment['start'])
        caption.end = format_vtt_timestamp(segment['end'])
        caption.lines = [f"{segment['sentence_number']}: {segment['text'].strip()}"]
        vtt.captions.append(caption)
    vtt_path = os.path.join(output_dir, f"{base_filename}_transcription.vtt")
    vtt.save(vtt_path)
    logging.info(f"Saved VTT transcription to {vtt_path}")

def process_audio_file(audio_file, base_output_dir, relative_path, pseudonym_dict=None):
    """Processes a single audio file: transcribes, aligns, diarizes, and saves the output."""
    try:
        logging.info(f"Processing {audio_file}...")
        audio = whisperx.load_audio(audio_file)
        model = whisperx.load_model(WHISPERX_MODEL, DEVICE, compute_type=COMPUTE_TYPE)
        result = model.transcribe(audio, batch_size=BATCH_SIZE, language=LANGUAGE, task=TASK, chunk_size=CHUNK_SIZE)
        del model
        gc.collect()
        torch.cuda.empty_cache()

        model_a, metadata = whisperx.load_align_model(language_code=LANGUAGE, device=DEVICE)
        result = whisperx.align(result["segments"], model_a, metadata, audio, DEVICE, return_char_alignments=False)
        del model_a
        gc.collect()
        torch.cuda.empty_cache()

        diarize_model = whisperx.DiarizationPipeline(use_auth_token=HF_TOKEN, device=DEVICE)
        diarize_segments = diarize_model(audio)
        result = whisperx.assign_word_speakers(diarize_segments, result)

        output_dir = os.path.join(base_output_dir, os.path.dirname(relative_path))
        save_transcripts(result["segments"], output_dir, relative_path, pseudonym_dict)
        logging.info(f"Finished processing {audio_file}")
    except FileNotFoundError:
        logging.error(f"Audio file not found: {audio_file}")
    except RuntimeError as e:
        logging.error(f"Runtime error (for example, CUDA out of memory) while processing {audio_file}: {e}")
    except Exception as e:
        logging.error(f"An unexpected error occurred while processing {audio_file}: {e}")

def main():
    """Main function to handle user input and process audio files."""
    root = Tk()
    root.withdraw()
    root.attributes('-topmost', True)

    input_folder = filedialog.askdirectory(title="Select Folder Containing Audio/Video Files")
    if not input_folder:
        logging.info("No input folder selected. Exiting.")
        return

    output_folder = filedialog.askdirectory(title="Select Folder to Save Transcriptions")
    if not output_folder:
        logging.info("No output folder selected. Exiting.")
        return

    use_pseudonyms = messagebox.askyesno("Pseudonyms", "Will you use a pseudonyms.csv file to anonymize the transcripts?")
    pseudonym_dict = None

    if use_pseudonyms:
        pseudonyms_file = filedialog.askopenfilename(
            title="Select Pseudonyms CSV File",
            filetypes=[("CSV files", "*.csv")]
        )
        if not pseudonyms_file:
            logging.info("No pseudonyms file selected. Continuing without pseudonymization.")
        else:
            try:
                pseudonyms_df = pd.read_csv(pseudonyms_file)
                pseudonym_dict = dict(zip(pseudonyms_df['name'], pseudonyms_df['pseudonym']))
                logging.info(f"Pseudonyms loaded from {pseudonyms_file}.")
            except FileNotFoundError:
                messagebox.showerror("Error", f"Pseudonyms file not found: {pseudonyms_file}")
                logging.error(f"Pseudonyms file not found: {pseudonyms_file}")
            except KeyError:
                messagebox.showerror("Error", "Pseudonyms CSV file must have 'name' and 'pseudonym' columns.")
                logging.error("Pseudonyms CSV file must have 'name' and 'pseudonym' columns.")
            except Exception as e:
                messagebox.showerror("Error", f"Error reading pseudonyms file: {e}")
                logging.error(f"Error reading pseudonyms file: {e}")

    audio_files = find_audio_files(input_folder, EXTENSIONS)
    logging.info(f"Found {len(audio_files)} files to process.")

    for audio_file in tqdm(audio_files, desc="Processing Audio Files"):
        relative_path = os.path.relpath(audio_file, input_folder)
        process_audio_file(audio_file, output_folder, relative_path, pseudonym_dict)

    logging.info("All files processed.")

if __name__ == "__main__":
    main()
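
The cell above reads its settings from a `config.ini` file that is not included in the notebook. A sketch of a file with the sections and keys the cell looks up (`whisper`, `processing`, `files`, `auth`), built here with `configparser`. The values are copied from the configuration cell in section 1.0 and the token is a placeholder, so treat them as assumptions to adjust:

```python
import configparser
import io

config = configparser.ConfigParser()
config["whisper"] = {
    "device": "cuda",
    "language": "en",
    "model": "large-v3-turbo",
    "task": "transcribe",
}
config["processing"] = {"batch_size": "16", "compute_type": "float16"}
config["files"] = {"extensions": ".WAV, .m4a, .mp3"}
config["auth"] = {"hf_token": "REPLACEWITHHUGGINGFACETOKENHERE"}

# Serialize to text; saving this text as config.ini next to the notebook
# would give the cell a file it can read.
buffer = io.StringIO()
config.write(buffer)
config_text = buffer.getvalue()
print(config_text)

# Parse it back with the same lookups the cell performs
check = configparser.ConfigParser()
check.read_string(config_text)
batch_size = check["processing"].getint("batch_size")
extensions = [ext.strip() for ext in check["files"]["extensions"].split(",")]
print(batch_size, extensions)
```

Because the file holds the Hugging Face token in plain text, it is worth keeping it out of version control.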

# GUI setup

In [5]:
import os
from tkinter import Tk, filedialog, messagebox  # Still need these for file/folder dialogs
import pandas as pd
import warnings
import torch
import whisperx
import gc
import json
import webvtt
import logging
from dotenv import load_dotenv  # From the python-dotenv package installed in setup step 6
from IPython.display import display
import ipywidgets as widgets

# --- Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
warnings.filterwarnings("ignore", category=UserWarning)
load_dotenv()  # Make HF_TOKEN from the .env file available to os.getenv below

# --- GUI Widgets ---
device_options = ['cuda', 'cpu']
language_default = 'es'
model_options = ["tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3", "large-v3-turbo"]
model_default = "large-v2"
task_options = ["transcribe", "translate"]
task_default = "transcribe"
batch_size_default = 16
compute_type_options = ["float16", "float32"]
compute_type_default = "float16"
extensions_default = '.WAV, .m4a, .mp3'
hf_token_default = os.getenv('HF_TOKEN') or ''

device_widget = widgets.Dropdown(options=device_options, description='Device:', value='cuda')
language_widget = widgets.Text(description='Language Code:', value=language_default)
model_widget = widgets.Dropdown(options=model_options, description='Model:', value=model_default)
task_widget = widgets.Dropdown(options=task_options, description='Task:', value=task_default)
batch_size_widget = widgets.IntSlider(min=1, max=64, step=1, description='Batch Size:', value=batch_size_default)
compute_type_widget = widgets.Dropdown(options=compute_type_options, description='Compute Type:', value=compute_type_default)
extensions_widget = widgets.Text(description='Extensions (comma-separated):', value=extensions_default)
hf_token_widget = widgets.Text(description='HuggingFace Token:', value=hf_token_default, placeholder='Optional')
input_folder_button = widgets.Button(description="Select Input Folder")
output_folder_button = widgets.Button(description="Select Output Folder")
pseudonyms_file_button = widgets.Button(description="Select Pseudonyms CSV (Optional)")
process_button = widgets.Button(description="Start Processing", button_style='primary')

input_folder_path = widgets.Label(value="")
output_folder_path = widgets.Label(value="")
pseudonyms_file_path = widgets.Label(value="")
pseudonym_dict = None

# --- Functions ---
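def find_audio_files(base_dir, extensions):
    # extensions is a comma-separated string such as '.WAV, .m4a, .mp3';
    # strip the spaces after the commas and compare in lower case
    suffixes = tuple(ext.strip().lower() for ext in extensions.split(',') if ext.strip())
    audio_files = []
    for root, _, files in os.walk(base_dir):
        for file in files:
            if file.lower().endswith(suffixes):
                audio_files.append(os.path.join(root, file))
    return audio_files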

def anonymize_text(text, pseudonym_dict):
    if pseudonym_dict:
        for real_name, pseudonym in pseudonym_dict.items():
            text = text.replace(real_name, pseudonym)
    return text

def format_vtt_timestamp(seconds):
    hours, remainder = divmod(seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    milliseconds = int((seconds % 1) * 1000)
    return f"{int(hours):02}:{int(minutes):02}:{int(seconds):02}.{milliseconds:03}"

def save_transcripts(segments, output_dir, relative_path, pseudonym_dict=None):
    os.makedirs(output_dir, exist_ok=True)
    base_filename = os.path.splitext(os.path.basename(relative_path))[0]

    # Work on copies so the caller's segment dicts are left unchanged
    processed_segments = []
    for i, segment in enumerate(segments):
        segment = segment.copy()
        segment['text'] = segment['text'].lstrip()
        if pseudonym_dict:
            segment['text'] = anonymize_text(segment['text'], pseudonym_dict)
        segment['sentence_number'] = i + 1
        processed_segments.append(segment)

    df = pd.DataFrame(processed_segments)
    cols = ['sentence_number'] + [col for col in df.columns if col != 'sentence_number']
    df = df[cols]
    csv_path = os.path.join(output_dir, f"{base_filename}_transcription.csv")
    df.to_csv(csv_path, index=False, encoding='utf-8-sig')  # Same encoding as the earlier cells, so accented text opens correctly in Excel
    logging.info(f"Saved CSV transcription to {csv_path}")

    txt_path = os.path.join(output_dir, f"{base_filename}_transcription.txt")
    with open(txt_path, 'w', encoding='utf-8') as f:
        for segment in processed_segments:
            f.write(f"{segment['text'].strip()}\n")
    logging.info(f"Saved TXT transcription to {txt_path}")

    json_path = os.path.join(output_dir, f"{base_filename}_transcription.json")
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(processed_segments, f, ensure_ascii=False, indent=4)
    logging.info(f"Saved JSON transcription to {json_path}")

    vtt = webvtt.WebVTT()
    for segment in processed_segments:
        caption = webvtt.Caption()
        caption.start = format_vtt_timestamp(segment['start'])
        caption.end = format_vtt_timestamp(segment['end'])
        caption.lines = [f"{segment['sentence_number']}: {segment['text'].strip()}"]
        vtt.captions.append(caption)
    vtt_path = os.path.join(output_dir, f"{base_filename}_transcription.vtt")
    vtt.save(vtt_path)
    logging.info(f"Saved VTT transcription to {vtt_path}")

def process_audio_file(audio_file, base_output_dir, relative_path, pseudonym_dict=None, params=None):
    try:
        logging.info(f"Processing {audio_file}...")
        audio = whisperx.load_audio(audio_file)
        model = whisperx.load_model(params['model'], params['device'], compute_type=params['compute_type'])
        result = model.transcribe(audio, batch_size=params['batch_size'], language=params['language'], task=params['task'])
        del model
        gc.collect()
        torch.cuda.empty_cache()

        model_a, metadata = whisperx.load_align_model(language_code=params['language'], device=params['device'])
        result = whisperx.align(result["segments"], model_a, metadata, audio, params['device'], return_char_alignments=False)
        del model_a
        gc.collect()
        torch.cuda.empty_cache()

        diarize_model = whisperx.DiarizationPipeline(use_auth_token=params['hf_token'], device=params['device'])
        diarize_segments = diarize_model(audio)
        result = whisperx.assign_word_speakers(diarize_segments, result)

        output_dir = os.path.join(base_output_dir, os.path.dirname(relative_path))
        save_transcripts(result["segments"], output_dir, relative_path, pseudonym_dict)
        logging.info(f"Finished processing {audio_file}")
    except FileNotFoundError:
        logging.error(f"Audio file not found: {audio_file}")
    except RuntimeError as e:
        logging.error(f"Runtime error (for example, CUDA out of memory) while processing {audio_file}: {e}")
    except Exception as e:
        logging.error(f"An unexpected error occurred while processing {audio_file}: {e}")

# --- GUI Logic ---
def select_input_folder(b):
    root = Tk()
    root.withdraw()
    folder_selected = filedialog.askdirectory(title="Select Input Folder")
    if folder_selected:
        input_folder_path.value = folder_selected
    root.destroy()

def select_output_folder(b):
    root = Tk()
    root.withdraw()
    folder_selected = filedialog.askdirectory(title="Select Output Folder")
    if folder_selected:
        output_folder_path.value = folder_selected
    root.destroy()

def select_pseudonyms_file(b):
    root = Tk()
    root.withdraw()
    file_selected = filedialog.askopenfilename(
        title="Select Pseudonyms CSV File (Optional)",
        filetypes=[("CSV files", "*.csv")]
    )
    if file_selected:
        pseudonyms_file_path.value = file_selected
        global pseudonym_dict
        try:
            pseudonyms_df = pd.read_csv(file_selected)
            pseudonym_dict = dict(zip(pseudonyms_df['name'], pseudonyms_df['pseudonym']))
            logging.info(f"Pseudonyms loaded from {file_selected}.")
        except FileNotFoundError:
            messagebox.showerror("Error", f"Pseudonyms file not found: {file_selected}")
            logging.error(f"Pseudonyms file not found: {file_selected}")
            pseudonym_dict = None
        except KeyError:
            messagebox.showerror("Error", "Pseudonyms CSV file must have 'name' and 'pseudonym' columns.")
            logging.error("Pseudonyms CSV file must have 'name' and 'pseudonym' columns.")
            pseudonym_dict = None
        except Exception as e:
            messagebox.showerror("Error", f"Error reading pseudonyms file: {e}")
            logging.error(f"Error reading pseudonyms file: {e}")
            pseudonym_dict = None
    else:
        pseudonyms_file_path.value = ""
        pseudonym_dict = None

def start_processing(b):
    input_folder = input_folder_path.value
    output_folder = output_folder_path.value
    extensions = extensions_widget.value
    global pseudonym_dict

    if not input_folder:
        messagebox.showerror("Error", "Please select an input folder.")
        return
    if not output_folder:
        messagebox.showerror("Error", "Please select an output folder.")
        return

    params = {
        'device': device_widget.value,
        'language': language_widget.value,
        'model': model_widget.value,
        'task': task_widget.value,
        'batch_size': batch_size_widget.value,
        'compute_type': compute_type_widget.value,
        'hf_token': hf_token_widget.value
    }

    audio_files = find_audio_files(input_folder, extensions)
    logging.info(f"Found {len(audio_files)} files to process.")

    for audio_file in audio_files:
        relative_path = os.path.relpath(audio_file, input_folder)
        process_audio_file(audio_file, output_folder, relative_path, pseudonym_dict, params)

    messagebox.showinfo("Processing Complete", "All files processed.")
    logging.info("All files processed.")

# --- Widget Event Handling ---
input_folder_button.on_click(select_input_folder)
output_folder_button.on_click(select_output_folder)
pseudonyms_file_button.on_click(select_pseudonyms_file)
process_button.on_click(start_processing)

# --- Display the GUI ---
display(widgets.VBox([
    widgets.HBox([input_folder_button, input_folder_path]),
    widgets.HBox([output_folder_button, output_folder_path]),
    widgets.HBox([pseudonyms_file_button, pseudonyms_file_path]),
    device_widget,
    language_widget,
    model_widget,
    task_widget,
    batch_size_widget,
    compute_type_widget,
    extensions_widget,
    hf_token_widget,
    process_button
]))

VBox(children=(HBox(children=(Button(description='Select Input Folder', style=ButtonStyle()), Label(value=''))…
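
The Extensions field holds a comma-separated string such as `.WAV, .m4a, .mp3`. Splitting on commas alone would leave a leading space on every entry after the first, and `str.endswith` is case-sensitive, so a normalizing step helps before matching file names. A sketch of one way to do this; the helper names `parse_extensions` and `is_audio_file` are hypothetical:

```python
def parse_extensions(raw):
    """Turn a string like '.WAV, .m4a,mp3,' into ('.wav', '.m4a', '.mp3')."""
    cleaned = []
    for part in raw.split(","):
        ext = part.strip().lower()
        if not ext:
            continue  # skip empty entries such as a trailing comma
        if not ext.startswith("."):
            ext = "." + ext  # accept 'mp3' as well as '.mp3'
        cleaned.append(ext)
    return tuple(cleaned)

def is_audio_file(filename, suffixes):
    """Case-insensitive extension check against the parsed suffixes."""
    return filename.lower().endswith(suffixes)

suffixes = parse_extensions(".WAV, .m4a,mp3,")
print(suffixes)
print(is_audio_file("Interview01.wav", suffixes))
print(is_audio_file("notes.txt", suffixes))
```

Passing a tuple to `endswith` checks all suffixes in one call, and an empty tuple simply matches nothing.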