# Notebook for GUI of speech editing of audio files (with looping options for convenience)

In [2]:
import torch
from streamreader_audio_only import stream
import os

# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import gc
import numpy as np


import os
from scipy.io import wavfile

CURR_DIR = os.getcwd()

from inference.tts.spec_denoiser import SpecDenoiserInfer
from utils.commons.hparams import set_hparams
import torchaudio
from dataclasses import dataclass
import matplotlib.pyplot as plt


import threading
import queue
import time
import pyaudio
import sys


import torch.multiprocessing as mp


import tkinter as tk
from PIL import Image, ImageTk


from IPython.display import clear_output

from g2p_en.expand import normalize_numbers

The .wav file should have a samplerate of 16000 and be mono. Use the following function if necessary to put your .wav file in the correct format and save it to your specified output path.

In [3]:
from pydub import AudioSegment
def correct_format_for_wav(input_wav_path, output_path):
    converted_audio = AudioSegment.from_wav(input_wav_path)
    converted_audio = converted_audio.set_frame_rate(16000).set_channels(1)
    converted_audio.export(output_path, format="wav")

In [6]:
# user needs to set up these
SAVE_AUDIO_DIR = CURR_DIR + "\\saved_audio"
if not os.path.exists(SAVE_AUDIO_DIR):
    os.mkdir(SAVE_AUDIO_DIR)
# Set this to true if you want to save all segments of audio and not just the flagged ones
SAVE_ALL_AUDIO = False
binary_data_directory = ".\\data\\processed\\binary\\libritts"
Espeak_dll_directory = "C:\Program Files\eSpeak NG\libespeak-ng.dll"
from phonemizer.backend.espeak.wrapper import EspeakWrapper

EspeakWrapper.set_library(Espeak_dll_directory)
whisperX_phoneme_model_directory = "..\\whisperX-main\\facebook"
whisperX_transcription_model_directory = None
data_queue = queue.Queue()
our_model_ckpt_path='checkpoints/spec_denoiser/model_ckpt_steps_568000.ckpt'


## Advanced settings

In [7]:
# some more hyperparameters the user may want to set. Be careful with most of these
silero_sensitivity = 0.2  # higher = more likely to detect silence
req_num_pauses = 2
min_segs_to_keep = 1
req_end_long = True
# not sure if these are correct (other than sample_rate)
sample_rate = 16000
hop_length = 160
segment_length = 30
segment_length = segment_length * hop_length * 4
SILERO_MIN_LENGTH_LONG_SILENCE = 200
# force the audio to go through even if not enough silences are formed if more than this many chunks are appended
MAX_ALLOWED_CHUNKS = 3
# the max amount of time you expect it to take to transcribe and run inference
processing_buffer = 1.5

# The data acquisition process will stop after this number of steps.
# This eliminates the need of process synchronization and makes this
# tutorial simple.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Plot mel with alignment tools

In [8]:
# https://pytorch.org/audio/main/generated/torchaudio.pipelines.Wav2Vec2ASRBundle.html
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H

model_trellis = bundle.get_model().to(DEVICE)
labels_trellis = bundle.get_labels()
dictionary = {c: i for i, c in enumerate(labels_trellis)}

In [9]:
def plot_mel_with_align(waveform, transcript, output_path):

    # Edit transcript into correct format
    transcript = normalize_numbers(transcript)
    transcript = transcript.upper()
    for punc in [".", "?", "!", "...", ",", ":", ";"]:
        transcript = transcript.replace(punc, "")
    transcript = transcript.replace(" ", "|")
    if not transcript.startswith("|"):
        transcript = "|" + transcript
    if not transcript.endswith("|"):
        transcript = transcript + "|"

    with torch.inference_mode():
        emissions, _ = model_trellis(waveform.to(DEVICE))
        emissions = torch.log_softmax(emissions, dim=-1)

    emission = emissions[0].cpu().detach()

    tokens = [dictionary.get(c, 0) for c in transcript]

    def get_trellis(emission, tokens, blank_id=0):
        num_frame = emission.size(0)
        num_tokens = len(tokens)

        trellis = torch.zeros((num_frame, num_tokens))
        trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
        trellis[0, 1:] = -float("inf")
        trellis[-num_tokens + 1 :, 0] = float("inf")

        for t in range(num_frame - 1):
            trellis[t + 1, 1:] = torch.maximum(
                # Score for staying at the same token
                trellis[t, 1:] + emission[t, blank_id],
                # Score for changing to the next token
                trellis[t, :-1] + emission[t, tokens[1:]],
            )
        return trellis

    trellis = get_trellis(emission, tokens)

    @dataclass
    class Point:
        token_index: int
        time_index: int
        score: float

    def backtrack(trellis, emission, tokens, blank_id=0):
        t, j = trellis.size(0) - 1, trellis.size(1) - 1

        path = [Point(j, t, emission[t, blank_id].exp().item())]
        while j > 0:
            # Should not happen but just in case
            assert t > 0

            # 1. Figure out if the current position was stay or change
            # Frame-wise score of stay vs change
            p_stay = emission[t - 1, blank_id]
            p_change = emission[t - 1, tokens[j]]

            # Context-aware score for stay vs change
            stayed = trellis[t - 1, j] + p_stay
            changed = trellis[t - 1, j - 1] + p_change

            # Update position
            t -= 1
            if changed > stayed:
                j -= 1

            # Store the path with frame-wise probability.
            prob = (p_change if changed > stayed else p_stay).exp().item()
            path.append(Point(j, t, prob))
        # Now j == 0, which means, it reached the SoS.
        # Fill up the rest for the sake of visualization
        while t > 0:
            prob = emission[t - 1, blank_id].exp().item()
            path.append(Point(j, t - 1, prob))
            t -= 1

        return path[::-1]

    path = backtrack(trellis, emission, tokens)

    # Merge the labels
    @dataclass
    class Segment:
        label: str
        start: int
        end: int
        score: float

        def __repr__(self):
            return (
                f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
            )

        @property
        def length(self):
            return self.end - self.start

    def merge_repeats(path):
        i1, i2 = 0, 0
        segments = []
        while i1 < len(path):
            while i2 < len(path) and path[i1].token_index == path[i2].token_index:
                i2 += 1
            score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
            segments.append(
                Segment(
                    transcript[path[i1].token_index],
                    path[i1].time_index,
                    path[i2 - 1].time_index + 1,
                    score,
                )
            )
            i1 = i2
        return segments

    segments = merge_repeats(path)

    # Merge words
    def merge_words(segments, separator="|"):
        words = []
        i1, i2 = 0, 0
        while i1 < len(segments):
            if i2 >= len(segments) or segments[i2].label == separator:
                if i1 != i2:
                    segs = segments[i1:i2]
                    word = "".join([seg.label for seg in segs])
                    score = sum(seg.score * seg.length for seg in segs) / sum(
                        seg.length for seg in segs
                    )
                    words.append(
                        Segment(word, segments[i1].start, segments[i2 - 1].end, score)
                    )
                i1 = i2 + 1
                i2 = i1
            else:
                i2 += 1
        return words

    word_segments = merge_words(segments)

    def plot_alignments(
        trellis, segments, word_segments, waveform, sample_rate=bundle.sample_rate
    ):

        fig2, ax2 = plt.subplots()

        # The original waveform
        ratio = waveform.size(0) / sample_rate / trellis.size(0)
        ax2.specgram(waveform, Fs=sample_rate)
        for word in word_segments:
            x0 = ratio * word.start
            x1 = ratio * word.end
            ax2.axvspan(x0, x1, facecolor="none", edgecolor="white", hatch="/")
            # ax2.annotate(f"{word.score:.2f}", (x0, sample_rate * 0.51), annotation_clip=False)

        for seg in segments:
            if seg.label != "|":
                ax2.annotate(
                    seg.label,
                    (seg.start * ratio, sample_rate * 0.55),
                    annotation_clip=False,
                )
        ax2.set_xlabel("time [second]")
        ax2.set_yticks([])
        fig2.tight_layout()
        # plt.ioff()
        fig2.savefig(output_path)
        plt.close(fig2)

    plot_alignments(trellis, segments, word_segments, waveform[0])

# Tkinter GUI 

In [35]:
class GUIViewerApp:
    def __init__(self, root):
        # print("__init__ GUI")
        global key_phrases_dict_orig
        self.root = root
        self.root.title("Transcription and spectrograms live viewer")

        self.num_background_loop = 1
        
        self.start_time = 0.0
        self.stream_delay = 0.0
        self.update_dict_flag = False


        # Event to stop threads
        self.stop_event = threading.Event()

        # Add a Close button
        self.close_button = tk.Button(root, text="Close", command=self.close_window)
        self.close_button.pack(side="top")

        # Pause/Resume button
        self.is_paused = False
        self.pause_button = tk.Button(
            root,
            text="Pause after current (background) loop",
            command=self.toggle_pause_resume,
        )
        self.pause_button.pack()

        # Label for transcript
        self.transcript_label = tk.Label(root, text="...", font=("Arial", 14))
        self.transcript_label.pack(pady=10)

        # Canvas for spectrogram
        self.spectrogram_canvas = tk.Label(root)
        self.spectrogram_canvas.pack(side="left")

        self.spectrogram_canvas.pack()

        self.running = True
        self.rows = []

        # Create a frame to hold the labels and buttons (fixed part)
        self.fixed_frame = tk.Frame(root)
        self.fixed_frame.pack(side="top", fill="x", padx=10, pady=10)

        # Buttons to add/remove rows
        self.add_button = tk.Button(
            self.fixed_frame, text="Add Row", command=self.add_row
        )
        self.add_button.grid(row=1, column=0, padx=5, pady=10)

        self.remove_button = tk.Button(
            self.fixed_frame, text="Remove Row", command=self.remove_row
        )
        self.remove_button.grid(row=1, column=1, padx=5, pady=10)

        self.update_button = tk.Button(
            self.fixed_frame, text="Update Dictionary", command=self.update_dict
        )
        self.update_button.grid(row=1, column=2, padx=5, pady=10)

        # Add labels for the two columns
        self.label1 = tk.Label(
            self.fixed_frame, text="Original Phrase |", font=("Arial", 10, "bold")
        )
        self.label2 = tk.Label(
            self.fixed_frame, text="| Replacement Phrase", font=("Arial", 10, "bold")
        )

        self.label1.grid(row=0, column=0, padx=5, pady=5)
        self.label2.grid(row=0, column=1, padx=5, pady=5)

        # Create a canvas and a vertical scrollbar for scrolling the rows (scrollable part)
        self.row_canvas = tk.Canvas(root, height=200, width=200)
        self.scrollbar = tk.Scrollbar(
            root, orient="vertical", command=self.row_canvas.yview
        )
        self.row_frame = tk.Frame(self.row_canvas)

        self.row_frame.bind(
            "<Configure>",
            lambda e: self.row_canvas.configure(
                scrollregion=self.row_canvas.bbox("all")
            ),
        )

        self.row_canvas.create_window((0, 0), window=self.row_frame, anchor="nw")
        self.row_canvas.configure(yscrollcommand=self.scrollbar.set)
        self.row_canvas.pack(side="left", fill="both", expand=True)
        self.scrollbar.pack(side="right", fill="y")

        self.row_canvas.bind_all("<MouseWheel>", self._on_mouse_wheel)

        # Add rows based on the initial dictionary
        for key, value in key_phrases_dict_orig.items():
            self.add_row(key, value)

        # Start threads
        self.start_threads()



    def toggle_pause_resume(self):
        if self.is_paused:
            self.pause_button.config(text="Pause after current (background) loop")
            self.is_paused = False
        else:
            self.pause_button.config(text="Resume (background) looping")
            self.is_paused = True

    def start_threads(self):
        # print("GUI start_threads")
        # Start the main task in a separate thread
        self.main_thread = threading.Thread(
            target=main_function,
            args=(
                audio_src,
                inp_device,
                self,
            ),
            daemon=True,
        )
        self.main_thread.start()

        # Start worker thread
        self.worker_thread = threading.Thread(
            target=worker_task, args=(self,), daemon=True
        )
        self.worker_thread.start()


    def stop_threads(self):
        # print("GUI stop_threads")
        # Signal threads to stop
        self.stop_event.set()

        # Join threads to ensure they have stopped
        self.main_thread.join(timeout=2)
        # print("main_thread joined")
        self.worker_thread.join(timeout=2)
        # print("worker_thread joined")


    def add_row(self, key="", value=""):
        row = len(self.rows)
        entry1 = tk.Entry(self.row_frame, width=20)
        entry2 = tk.Entry(self.row_frame, width=20)

        entry1.grid(row=row, column=0, padx=5, pady=5)
        entry2.grid(row=row, column=1, padx=5, pady=5)

        entry1.insert(0, key)
        entry2.insert(0, value)

        self.rows.append((entry1, entry2))

        self.row_frame.update_idletasks()
        self.row_canvas.config(scrollregion=self.row_canvas.bbox("all"))

    def remove_row(self):
        if self.rows:
            entry1, entry2 = self.rows.pop()
            entry1.grid_forget()
            entry2.grid_forget()

            self.row_frame.update_idletasks()
            self.row_canvas.config(scrollregion=self.row_canvas.bbox("all"))

    def update_dict(self):
        self.update_dict_flag = True

    def _on_mouse_wheel(self, event):
        self.row_canvas.yview_scroll(int(-1 * (event.delta / 120)), "units")

    def update_gui(self, transcript, mel_path):
        
        
        self.transcript_label.config(
            text=f"{transcript}\n Current background loop number: {self.num_background_loop}"
        )

        image = Image.open(mel_path)
        self.image_reference = ImageTk.PhotoImage(image)
        self.spectrogram_canvas.config(image=self.image_reference)
        self.spectrogram_canvas.image = self.image_reference

        self.root.after(0, self.root.update)


    def clear_queue(self):
        # print("GUI clear_queue")
        try:
            while not data_queue.empty():
                data_queue.get_nowait()  # Remove and discard the item
                data_queue.task_done()
        except queue.Empty:
            pass

    def close_window(self):
        print("**** Closing the GUI ****")
        self.running = False
        self.stop_threads()
        self.clear_queue()
        self.root.quit()
        self.root.destroy()
        try:
            sys.exit(0) # Exit the program completely (this will throw an ExitError, but it's innocuous)
        except SystemExit:
            pass

    def no_op(self):
        pass

# Live speech editing

In [26]:
# Load the Silero VAD model and utilities
silero_model, silero_utils = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False)
(get_speech_ts, _, _, _, _) = silero_utils

Using cache found in C:\Users\fbale/.cache\torch\hub\snakers4_silero-vad_master


In [27]:
# Split wav audio on silence using silero. Can get labels for speech/silence, and timestamps
def split_on_silence_with_padding(audio_segment, silero_model, min_silence_duration_ms=500, sampling_rate=16000):

    wav = audio_segment

    # Detect non-silent (speech) segments
    speech_timestamps = get_speech_ts(
        wav,
        silero_model,
        sampling_rate=sampling_rate,
        min_silence_duration_ms=5,
        window_size_samples=512,
        threshold=silero_sensitivity,
    )
    """ 
    Other parameters that could be passed to Silero to improve things:
    threshold: float = 0.5, 
    sampling_rate: int = 16000, 
    min_speech_duration_ms: int = 250, 
    min_silence_duration_ms: int = 100, 
    window_size_samples: int = 1536, 
    speech_pad_ms: int = 30, 
    return_seconds: bool = False, 
    visualize_probs: bool = False): 
    see https://github.com/snakers4/silero-vad/discussions/201
    """

    number_of_nonsilence = len(speech_timestamps)

    if len(speech_timestamps) == 0:
        return [number_of_nonsilence, [wav, "silence", False, [0, len(wav)]]]

    # Prepare variables for splitting
    chunks = []
    prev_end = 0
    threshold_samples = int((min_silence_duration_ms / 1000) * sampling_rate)
    half_threshold_samples = threshold_samples // 2

    for index, segment in enumerate(speech_timestamps):
        curr_start = segment["start"]
        curr_end = segment["end"]
        flag_add_future_silence = False

        # First let's deal with the endpoint of the segment
        if (index < len(speech_timestamps) - 1) and (len(speech_timestamps) >= 2):
            future_start = speech_timestamps[index + 1]["start"]
        else:
            future_start = len(wav)
        future_silence_duration_samples = future_start - curr_end
        if future_silence_duration_samples > threshold_samples:
            extra_future_silence_samples = (
                future_silence_duration_samples - threshold_samples
            )
            temp_prev_end = curr_end + half_threshold_samples
            temp_next_prev_end = temp_prev_end + extra_future_silence_samples
            flag_add_future_silence = True
        else:
            temp_prev_end = (
                curr_end + future_start
            ) // 2  # Take midpoint between current end and future start as padding

        # Now let's deal with the startpoint of the segment
        silence_duration_samples = curr_start - prev_end

        # Apart from the very first segment, for all the other segments the distance from the previous segment is always less than threshold / 2
        # because of the way we deal with the future endpoints
        if silence_duration_samples > threshold_samples:
            extra_silence_duration_samples = (silence_duration_samples - threshold_samples)

            if extra_silence_duration_samples > 0:
                # This is going to be a silent chunk
                chunks.append(
                    [
                        wav[prev_end : curr_start - half_threshold_samples],
                        "silence",
                        False,
                        [prev_end, curr_start - half_threshold_samples],
                    ]
                )

            # Now append current segment
            chunks.append(
                [
                    wav[curr_start - half_threshold_samples : temp_prev_end],
                    "speech",
                    flag_add_future_silence,
                    [curr_start - half_threshold_samples, temp_prev_end],
                ]
            )
            # set the new prev_end
            prev_end = temp_prev_end

            if flag_add_future_silence:
                # Add future silence
                chunks.append(
                    [
                        wav[temp_prev_end:temp_next_prev_end],
                        "silence",
                        False,
                        [temp_prev_end, temp_next_prev_end],
                    ]
                )
                # set the new prev_end
                prev_end = temp_next_prev_end

        else:
            chunks.append(
                [
                    wav[prev_end:temp_prev_end],
                    "speech",
                    flag_add_future_silence,
                    [prev_end, temp_prev_end],
                ]
            )
            # set the new prev_end
            prev_end = temp_prev_end

            if flag_add_future_silence:
                # Add future silence
                chunks.append(
                    [
                        wav[temp_prev_end:temp_next_prev_end],
                        "silence",
                        False,
                        [temp_prev_end, temp_next_prev_end],
                    ]
                )
                # set the new prev_end
                prev_end = temp_next_prev_end

    # Deal with the very last segment (which should be a silence)
    if prev_end < len(wav):
        chunks.append(
            [wav[prev_end:], "last_segment_short_silence", False, [prev_end, len(wav)]]
        )

    return [number_of_nonsilence, chunks]

#### Load all the various models separately from inference

In [13]:
print("Building whisperX transcription model...")
torch_dtype = torch.float16 if DEVICE == "cuda" else torch.float32
# https://huggingface.co/models?pipeline_tag=text-to-speech&p=1&sort=trending
model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(DEVICE)
processor = AutoProcessor.from_pretrained(model_id)
whisper = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=DEVICE,
)


hparams = set_hparams(exp_name="spec_denoiser")

print("Building our model...")
our_model = SpecDenoiserInfer(
    hparams,
    binary_data_directory,
    our_model_ckpt_path,
    whisperX_model_directory=whisperX_phoneme_model_directory,
    device=DEVICE,
)

print("Warming up models...")
our_model.example_run(
    [
        {
            "item_name": "",
            "text": "this is a libri vox recording",
            "edited_text": "this is a funny joke shows.",
            "wav_fn_orig": "inference/audio_backup/1_space.wav",
            "edited_region": "[4,6]",
            "region": "[4,6]",
            "mfa_textgrid": "",
        }
    ],
    use_MFA=False,
    use_librosa=False,
    save_wav_bool=False,
    disp_wav=False,
    mask_loc_buffer=10,
)
sample_audio, rate = torchaudio.load("inference/audio_backup/1_space.wav")
sample_audio = torchaudio.functional.resample(
    sample_audio, orig_freq=rate, new_freq=sample_rate
)[0].squeeze()
get_speech_ts(
    sample_audio,
    silero_model,
    sampling_rate=sample_rate,
    min_silence_duration_ms=5,
    window_size_samples=512,
    threshold=silero_sensitivity,
)
whisper(
    sample_audio.to("cpu").numpy(),
    chunk_length_s=30,
    stride_length_s=5,
    batch_size=1,
    return_timestamps="word",
)

Building whisperX transcription model...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Building our model...
LOAD DIFFUSION MODEL TIME: 0.7125575542449951
| load 'model_gen' from 'pretrained/hifigan_hifitts\model_ckpt_steps_2168000.ckpt'.
Build Vocoder Time 1.2619240283966064
Vocoder Device cuda
Loaded the voice encoder model on cuda in 0.09 seconds.
WHISPERX LOAD TIME = 5.142096519470215
Warming up models...
Mask loc buffer set to 10 frames, but there are only 14 frames of silence before the first edited word. Using silence midpoint instead.


{'text': ' This is a Librivox recording.',
 'chunks': [{'text': ' This', 'timestamp': (0.0, 0.26)},
  {'text': ' is', 'timestamp': (0.26, 1.52)},
  {'text': ' a', 'timestamp': (1.52, 1.74)},
  {'text': ' Librivox', 'timestamp': (1.74, 2.34)},
  {'text': ' recording.', 'timestamp': (2.34, 2.76)}]}

In [36]:
def main(inp_device, src, app):
    global key_phrases_dict_orig
    global key_phrases_dict
    global key_phrases

    long_pause_punc_list = [". ", "? ", "! ", "... "]
    pause_punc_list = [". ", "? ", "! ", "... ", ", ", ": ", "; "]
    extended_pause_punc_list = pause_punc_list + [" "]

    # Dictionary of the phrases to be replaced (the key is the phrase to be replaced, while the value is the replacement)
    # TODO: Initial uppercases/lowercases?

    print(f"Segment Length in seconds: {segment_length/sample_rate}")
    print(f"Will run for {NUM_ITER*segment_length/sample_rate} seconds")

    @torch.no_grad
    def infer(app):
        global key_phrases_dict_orig
        global key_phrases_dict
        global key_phrases
        import time

        # Counter for keeping track of saved wav files for playback (need to save separate wav files, otherwise permission errors)
        COUNTER_FOR_SD = 0

        # how many key phrases have been found in total
        num_key_phrase = 0

        # will hold the current transcript
        transcript = ""
        cur_transcript = ""

        count_num_audios_saved = 0

        transcription_inference_times = [0]
        our_model_inference_times = [0]
        total_iter_times = []  # This is not used here, so we can remove it.
        num_its_before_transcription_locked = []
        num_its_before_transcription_counter = 0

        total_audio = np.empty([1])

        CHUNKS_SO_FAR = torch.empty(0)
        flag_start_timing_its_before_transcription = True
        flag_at_least_one_transcription = False

        print("**** Start of streaming ****")
        print("Key phrases dict:", key_phrases_dict_orig)

        for i in range(NUM_ITER):
            # Check periodically to exit the loop if stop_event is set
            if app.stop_event.wait(timeout=0.0001):
                break

            chunk = q.get()

            # for some reason on the first get we wait for several audio chunks to come through. If its the first run, just burn through to the end
            # if i==0:
            #    while not q.qsize()==0:
            #        chunk,chunk_time=q.get()
            #    print(time.time()-chunk_time)
            initial_audio_recieved_time = time.time()

            CHUNKS_SO_FAR = torch.cat((CHUNKS_SO_FAR, chunk[:, 0]), 0)

            # Use silero to compute the number of speech bits in the CHUNKS_SO_FAR
            number_of_nonsilence, split_by_silence_chunks_list = (
                split_on_silence_with_padding(
                    CHUNKS_SO_FAR, silero_model, SILERO_MIN_LENGTH_LONG_SILENCE, 16000
                )
            )

            # If there are at least two speech bits or if there is only a speech bit followed by a long silence, then lock the first speech bit for transcription with Whisper
            flag_send_to_whisper = False
            if number_of_nonsilence > 0:
                # Find the first speech bit
                flag_found_first_speech_segment = False
                index_split_regions = 0
                preceding_silence = torch.empty(0)
                while not flag_found_first_speech_segment:
                    if split_by_silence_chunks_list[index_split_regions][1] != "speech":
                        preceding_silence = torch.cat(
                            (
                                preceding_silence,
                                split_by_silence_chunks_list[index_split_regions][0],
                            )
                        )
                    if split_by_silence_chunks_list[index_split_regions][1] == "speech":
                        flag_found_first_speech_segment = True
                        segment_to_send_to_whisper = split_by_silence_chunks_list[
                            index_split_regions
                        ][0]
                    index_split_regions += 1
                # Check if the first speech bit can be sent to Whisper
                # i.e. either there are at least two speech bits or
                # there is a speech bit followed by a long silence
                if (number_of_nonsilence > req_num_pauses) or (
                    split_by_silence_chunks_list[index_split_regions - 1][2] == True
                ):
                    flag_send_to_whisper = True
                    # Reset CHUNKS_SO_FAR to the remaining regions after the first speech region
                    CHUNKS_SO_FAR = torch.empty(0)
                    while index_split_regions in range(
                        len(split_by_silence_chunks_list)
                    ):
                        CHUNKS_SO_FAR = torch.cat(
                            (
                                CHUNKS_SO_FAR,
                                split_by_silence_chunks_list[index_split_regions][0],
                            ),
                            0,
                        )
                        index_split_regions += 1
                elif num_its_before_transcription_counter >= MAX_ALLOWED_CHUNKS:
                    flag_send_to_whisper = True
                    # Reset CHUNKS_SO_FAR to the remaining regions after the first speech region
                    CHUNKS_SO_FAR = torch.empty(0)
                    while index_split_regions in range(
                        len(split_by_silence_chunks_list)
                    ):
                        CHUNKS_SO_FAR = torch.cat(
                            (
                                CHUNKS_SO_FAR,
                                split_by_silence_chunks_list[index_split_regions][0],
                            ),
                            0,
                        )
                        index_split_regions += 1
                    print("Warning - Not enough pauses detected for optimal inference")

            # this is the case of all silence
            elif num_its_before_transcription_counter >= MAX_ALLOWED_CHUNKS:
                segment_to_send_to_whisper = CHUNKS_SO_FAR
                total_audio = np.concatenate(
                    (total_audio, CHUNKS_SO_FAR.numpy()), axis=0
                )
                #######################################################################################################
                wavfile.write(
                    f"{SAVE_AUDIO_DIR}/temp_{COUNTER_FOR_SD}.wav",
                    16000,
                    CHUNKS_SO_FAR.numpy(),
                )
                processed_segment_path = f"{SAVE_AUDIO_DIR}/temp_{COUNTER_FOR_SD}.wav"
                audio_segment, _ = torchaudio.load(processed_segment_path)
                mel_path = f"{SAVE_AUDIO_DIR}/mel_flagged_{COUNTER_FOR_SD}.png"
                plot_mel_with_align(audio_segment, cur_transcript, mel_path)
                data_queue.put(["", mel_path, CHUNKS_SO_FAR.numpy()])
                COUNTER_FOR_SD += 1
                count_num_audios_saved += 1
                num_its_before_transcription_counter = 0
                CHUNKS_SO_FAR = torch.empty(0)

            num_its_before_transcription_counter += 1

            if flag_start_timing_its_before_transcription:
                time_its_before_transcription = time.time()
                flag_start_timing_its_before_transcription = False

            # If the current segment can be passed to whisper for transcription:
            if flag_send_to_whisper:

                # Get the new dictionary with all the punctuated variations of keys and values
                key_phrases_dict = generate_punctuated_key_phrases_dict(
                    key_phrases_dict_orig, extended_pause_punc_list
                )
                key_phrases = list(key_phrases_dict.keys())

                def update_key_phrases_dict():
                    global key_phrases_dict_orig
                    global key_phrases_dict
                    global key_phrases

                    key_phrases_dict_orig.clear()  # Clear the dictionary before updating

                    for entry1, entry2 in app.rows:
                        key = entry1.get().strip()
                        value = entry2.get().strip()
                        if key:  # Only add non-empty keys
                            key_phrases_dict_orig[key] = value

                    key_phrases_dict = generate_punctuated_key_phrases_dict(
                        key_phrases_dict_orig, extended_pause_punc_list
                    )
                    key_phrases = list(key_phrases_dict.keys())

                    print(
                        "Updated dictionary:", key_phrases_dict_orig
                    )  # Print to console for demonstration

                if app.update_dict_flag:
                    update_key_phrases_dict()
                    app.update_dict_flag = False

                flag_at_least_one_transcription = True

                num_its_before_transcription_locked.append(
                    num_its_before_transcription_counter
                )
                num_its_before_transcription_counter = 0
                total_iter_times.append(time.time() - time_its_before_transcription)
                flag_start_timing_its_before_transcription = True

                # Pass the locked audiosegment to Whisper for transcription
                transcription_inference_time_start = time.time()
                transcript = whisper(
                    segment_to_send_to_whisper.numpy(),
                    chunk_length_s=30,
                    stride_length_s=5,
                    batch_size=1,
                    return_timestamps="word",
                )

                transcription_inference_times.append(
                    time.time() - transcription_inference_time_start
                )
                cur_transcript = (
                    " " + " ".join(transcript["text"].lower().split()) + " "
                )

                # Look for flagged words in the transcript
                if any([phrase in cur_transcript for phrase in key_phrases]):
                    wavfile.write(
                        f"{SAVE_AUDIO_DIR}/flagged_{num_key_phrase}.wav",
                        sample_rate,
                        segment_to_send_to_whisper.numpy(),
                    )

                    # Run our speech editing model
                    our_model_inference_time_start = time.time()
                    dataset_info = prep_inp_for_replacement_handle_multiple_key_phrases(
                        cur_transcript, f"{SAVE_AUDIO_DIR}/flagged_{num_key_phrase}.wav"
                    )
                    result_wavs = our_model.example_run(
                        dataset_info, False, False, False, False, 5
                    )
                    our_model_inference_times.append(
                        time.time() - our_model_inference_time_start
                    )

                    wavfile.write(
                        f"{SAVE_AUDIO_DIR}/flagged_edited_{num_key_phrase}.wav",
                        sample_rate,
                        result_wavs[0][1].astype(np.float32),
                    )

                    num_flagged_phrases = sum(
                        x != 0
                        for x in [
                            phrase in cur_transcript.lower() for phrase in key_phrases
                        ]
                    )
                    num_key_phrase += num_flagged_phrases

                    if SAVE_ALL_AUDIO:
                        # Note: here the original sampling rate is 22050 because it's the output of our whisperX model (which works at that sampling rate)
                        total_audio = np.concatenate(
                            (
                                total_audio,
                                preceding_silence.numpy(),
                                torchaudio.functional.resample(
                                    torch.tensor(result_wavs[0][1]),
                                    orig_freq=22050,
                                    new_freq=22050,
                                )
                                .to("cpu")
                                .numpy(),
                            ),
                            axis=0,
                        )

                    #######################################################################################################
                    wavfile.write(
                        f"{SAVE_AUDIO_DIR}/temp_{COUNTER_FOR_SD}_22050.wav",
                        22050,
                        torchaudio.functional.resample(
                            torch.tensor(result_wavs[0][1]),
                            orig_freq=22050,
                            new_freq=22050,
                        )
                        .to("cpu")
                        .numpy()
                        .astype(np.float32),
                    )
                    processed_segment_path = (
                        f"{SAVE_AUDIO_DIR}/temp_{COUNTER_FOR_SD}_22050.wav"
                    )
                    audio_segment, _ = torchaudio.load(processed_segment_path)
                    mel_path = f"{SAVE_AUDIO_DIR}/mel_flagged_{COUNTER_FOR_SD}.png"
                    plot_mel_with_align(
                        audio_segment, dataset_info[0]["edited_text"], mel_path
                    )
                    resampled_inferred_audio = (
                        torchaudio.functional.resample(
                            torch.tensor(result_wavs[0][1]),
                            orig_freq=22050,
                            new_freq=sample_rate,
                        )
                        .to("cpu")
                        .numpy()
                    )


                    data_queue.put(
                        [
                            dataset_info[0]["edited_text_with_marked_words"],
                            mel_path,
                            np.concatenate((preceding_silence.numpy(), resampled_inferred_audio), 0)
                        ]
                    )
                    COUNTER_FOR_SD += 1
                    #######################################################################################################

                    count_num_audios_saved += 1

                else:  # No flagged words
                    total_audio = np.concatenate(
                        (
                            total_audio,
                            preceding_silence.numpy(),
                            segment_to_send_to_whisper.numpy(),
                        ),
                        axis=0,
                    )

                    #######################################################################################################
                    wavfile.write(
                        f"{SAVE_AUDIO_DIR}/temp_{COUNTER_FOR_SD}.wav",
                        16000,
                        segment_to_send_to_whisper.numpy(),
                    )
                    processed_segment_path = (
                        f"{SAVE_AUDIO_DIR}/temp_{COUNTER_FOR_SD}.wav"
                    )
                    audio_segment, _ = torchaudio.load(processed_segment_path)
                    mel_path = f"{SAVE_AUDIO_DIR}/mel_flagged_{COUNTER_FOR_SD}.png"
                    plot_mel_with_align(audio_segment, cur_transcript, mel_path)
                    data_queue.put(
                        [
                            cur_transcript,
                            mel_path,
                            np.concatenate(
                                (
                                    preceding_silence.numpy(),
                                    segment_to_send_to_whisper.numpy(),
                                ),
                                0,
                            )
                        ]
                    )

                    COUNTER_FOR_SD += 1
                    #######################################################################################################

                    count_num_audios_saved += 1

        if SAVE_ALL_AUDIO:
            wavfile.write(
                f"{SAVE_AUDIO_DIR}/complete.wav", 22050, total_audio.astype(np.float32)
            )

        if not flag_at_least_one_transcription:
            total_iter_times.append(time.time() - time_its_before_transcription)

        print(
            f"Average Transcription Inference Time: {np.mean(transcription_inference_times)}. Maximum: {np.max(transcription_inference_times)} Transcription performed on {len(transcription_inference_times)-1} of {NUM_ITER} iterations"
        )
        print(
            f"Average Our Model Inference Time: {np.mean(our_model_inference_times)}. Maximum: {np.max(our_model_inference_times)} .Replacement performed on {len(our_model_inference_times)-1} of {NUM_ITER} iterations"
        )
        print(
            f"Average Total Iteration Time: {np.mean(total_iter_times)}. Maximum: {np.max(total_iter_times)}."
        )
        try:
            print(
                f"Average Number of Iterations before transcription: {np.mean(num_its_before_transcription_locked)}. Maximum: {np.max(num_its_before_transcription_locked)}. Number of locked segments: {len(num_its_before_transcription_locked)}"
            )
        except:
            pass

        print(f"Total Audio Shape: {np.shape(total_audio)}")

        print("**** Iterations finished! ****")

    def generate_punctuated_key_phrases_dict(
        key_phrases_dict, extended_pause_punc_list
    ):
        punct_key_phrases_dict = {}
        key_phrases = list(key_phrases_dict.keys())
        for phrase in key_phrases:
            flag_starts_with_space = False
            if not phrase.startswith(" "):
                flag_starts_with_space = True

            for p in extended_pause_punc_list:
                if not phrase.endswith(p):
                    if not flag_starts_with_space:
                        punct_key_phrases_dict[phrase + p] = (
                            key_phrases_dict[phrase] + p
                        )
                    else:
                        punct_key_phrases_dict[" " + phrase + p] = (
                            " " + key_phrases_dict[phrase] + p
                        )
        return punct_key_phrases_dict

    def prep_inp_for_replacement_handle_multiple_key_phrases(
        transcipt, file_name
    ):  # ,silero_timestamps):
        global key_phrases_dict
        global key_phrases
        # right now if there are multiple key phrases in the transcipt, this tells our model to infer those phrases and everything in between
        # a smarter way to do this would be to make it so that our model supports editing multiple regions

        dataset_info = [{}]
        dataset_info[0][
            "item_name"
        ] = ""  # this should just be used for naming the output file
        dataset_info[0]["text"] = " ".join(
            transcipt.lower().split()
        )  # a transcription of the original text
        for punc in pause_punc_list:
            dataset_info[0]["text"] = dataset_info[0]["text"].replace(punc.strip(), "")
        dataset_info[0][
            "wav_fn_orig"
        ] = file_name  # location of the .wav file to perform inference on

        transcript_words = transcipt.lower().split()
        key_phrase_words = [phrase.split() for phrase in key_phrases]

        # right now if multiple key phrases exist in the transcript, we consider the last one appearing in key_phrases_dict as the one to be replaced, and replace the first occurance of it

        num_ves = 0

        phrase_info_dicts = []

        change_in_region_length = 0
        for phrase in key_phrase_words:
            try:
                word_reg_start = transcript_words.index(phrase[0])
                word_reg_end = transcript_words.index(phrase[-1])
                key_phrase = " " + " ".join(phrase) + " "
                phrase_info_dicts.append(
                    {"phrase": key_phrase, "start": word_reg_start, "end": word_reg_end}
                )
                key = " " + " ".join(phrase) + " "
                change_in_region_length += len(key_phrases_dict[key].split()) - len(
                    phrase
                )
                print(change_in_region_length)
            except ValueError:
                num_ves += 1
        if num_ves >= len(key_phrase_words):
            print(
                "Error: Attempting to replace a word that does not exist in the transcript"
            )
            return 1

        dataset_info[0]["edited_text"] = " " + " ".join(transcipt.lower().split()) + " "
        dataset_info[0]["edited_text_with_marked_words"] = (
            " " + " ".join(transcipt.lower().split()) + " "
        )

        for phrase in key_phrases:
            replacement_phrase = key_phrases_dict[phrase]
            dataset_info[0]["edited_text_with_marked_words"] = dataset_info[0][
                "edited_text_with_marked_words"
            ].replace(phrase, f" |{replacement_phrase.strip()}| ")
            dataset_info[0]["edited_text"] = dataset_info[0]["edited_text"].replace(
                phrase, replacement_phrase
            )

        for punc in pause_punc_list:
            dataset_info[0]["edited_text"] = dataset_info[0]["edited_text"].replace(
                punc.strip(), ""
            )
            dataset_info[0]["edited_text_with_marked_words"] = dataset_info[0][
                "edited_text_with_marked_words"
            ].replace(punc.strip(), "")
        dataset_info[0]["edited_text"] = dataset_info[0]["edited_text"].strip()
        dataset_info[0]["edited_text_with_marked_words"] = dataset_info[0][
            "edited_text_with_marked_words"
        ].strip()

        phrase_info_dicts = sorted(
            phrase_info_dicts, key=lambda d: d["start"]
        )  # this could potentially fail, it assumes that the transcript and phrases are all so that the first occurance of the first word in the phrase is contained in the full phrase
        word_reg_start = phrase_info_dicts[0]["start"]
        word_reg_end = phrase_info_dicts[-1]["end"]

        dataset_info[0][
            "region"
        ] = f"[{word_reg_start+1},{word_reg_end+1}]"  # the region to edit (counting the words that will be changed starting from 1)

        dataset_info[0][
            "edited_region"
        ] = f"[{word_reg_start+1},{word_reg_end+1+change_in_region_length}]"  # word counts in the full edited text of the region which is to be inferred starting from one

        dataset_info[0][
            "mfa_textgrid"
        ] = ""  # we still need to set this to some value even if we are not using MFA

        if not os.path.exists("./transcripts"):
            os.makedirs("./transcripts")
        import json

        with open(f"./transcripts/dataset_info.txt", "a") as f:
            f.write(json.dumps(dataset_info[0]))
            f.write("\n")
        return dataset_info

    ctx = mp.get_context("spawn")
    manager = (
        ctx.Manager()
    )  # for some reason this fixes an issue I was having with multiprocessing https://discuss.pytorch.org/t/using-torch-tensor-over-multiprocessing-queue-process-fails/2847
    q = manager.Queue()
    p = ctx.Process(
        target=stream,
        args=(
            q,
            inp_device,
            src,
            segment_length,
            sample_rate,
            NUM_ITER,
        ),
    )

    # print("p started")
    p.start()
    p_start_time = time.time()
    infer(app)
    # print("infer finished")
    p.join(timeout=1)
    # print("p joined")

In [37]:
def main_function(audio_src, inp_device, app):

    while not app.stop_event.is_set():
        # Check if the GUI has been paused
        if not app.is_paused:
            print("****************")
            print("(Background) loop number", app.num_background_loop)
            print("****************")
            gc.collect()

            data_queue = queue.Queue()
            main(src=audio_src, inp_device=inp_device, app=app)

            app.num_background_loop += 1
        else:
            time.sleep(0.0001)

        # Check periodically to exit the loop if stop_event is set
        if app.stop_event.wait(timeout=0.1):
            break

In [38]:
def worker_task(app):

    FORMAT = pyaudio.paFloat32
    CHANNELS = 1
    RATE = 16000

    pya = pyaudio.PyAudio()

    py_stream = pya.open(format=FORMAT, channels=CHANNELS, rate=RATE, output=True)

    while not app.stop_event.is_set():
        while not data_queue.empty():
            (transcript, mel_path, audio_chunk) = data_queue.get()
            

            if app.running:
                # Check if the root window is still active
                if app.root.winfo_exists():
                    # Schedule the GUI update on the main thread
                    try:
                        app.root.after(0, app.update_gui, transcript, mel_path)
                    except RuntimeError:
                        # The root window may have been destroyed in the meantime
                        break

                
                py_stream.write(audio_chunk.tobytes())
                # py_stream.write(audio_chunk)
                data_queue.task_done()
            else:
                break
        if not app.running:
            break

        while data_queue.empty():
            time.sleep(0.0001)
            if not app.running:
                break
        # Check periodically to exit the loop if stop_event is set
        if app.stop_event.wait(timeout=0.1):
            break

    py_stream.stop_stream()
    py_stream.close()
    pya.terminate()


    print("If it didn't close automatically, please manually close the GUI")
    try:
        while not data_queue.empty():
            data_queue.get_nowait()  # Remove and discard the item
            data_queue.task_done()  # Mark the task as done
    except data_queue.empty():
        pass  # If the queue is already empty, just pass
    # app.root.after(1, app.root.destroy)

### User input

In [41]:
# Don't change this!
inp_device = None  # None for a file

# The path to your .wav file (make sure the sample rate is 16000 and the audio is mono)
audio_src = "test_sample.wav"

# whether to save all audio and the initial dictionary (can be updated in gui but takes a long time before will take effect)
SAVE_ALL_AUDIO = False

# The number of iterations for each background loop; the length of the playback of a loop is NUM_ITER * segment_length / sample_rate
NUM_ITER = 20

# The words/phrases to be edited (the key is the word/phrase to be edited, the value is what it gets edited to )
key_phrases_dict_orig = {
    'blueprints': 'yellowprints',
    'hand':'lend',
    'things':'kings'
}

In [42]:
if __name__ == "__main__":
    clear_output()
    gc.collect()
    tk_root = tk.Tk()
    tk_app = GUIViewerApp(tk_root)
    tk_root.mainloop()
    print("**** End of streaming ****")
    gc.collect()
    torch.cuda.empty_cache()

****************
(Background) loop number 1
****************
Segment Length in seconds: 1.2
Will run for 24.0 seconds
**** Start of streaming ****
Key phrases dict: {'blueprints': 'yellowprints', 'hand': 'lend', 'things': 'kings'}
**** Closing the GUI ****
If it didn't close automatically, please manually close the GUI
**** End of streaming ****


Average Transcription Inference Time: 0.0. Maximum: 0 Transcription performed on 0 of 20 iterations
Average Our Model Inference Time: 0.0. Maximum: 0 .Replacement performed on 0 of 20 iterations
Average Total Iteration Time: 0.0. Maximum: 0.0.
Total Audio Shape: (1,)
**** Iterations finished! ****
