In [1]:
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoTokenizer, AutoModelForCausalLM
import torch
import os
from dotenv import load_dotenv
import sounddevice as sd
import numpy as np
import tkinter as tk
from threading import Thread
# Load environment variables and Hugging Face token
load_dotenv()
huggingface_token = os.getenv('HUGGINGFACE_TOKEN')


text_gen_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=huggingface_token)
text_gen_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", token=huggingface_token)


device_text = "mps" if torch.backends.mps.is_available() else "cpu"
text_gen_model = text_gen_model.to(device_text)

# Initialize the processor and model
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

Loading checkpoint shards: 100%|##########| 2/2 [00:05<00:00,  2.60s/it]
Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.

In [None]:
# Global variable to control the recording state
is_recording = False
audio_data = np.array([])


def record_audio(fs=16000):
    """Continuously record audio until stopped."""
    global is_recording, audio_data
    with sd.InputStream(samplerate=fs, channels=1, callback=callback):
        while is_recording:
            sd.sleep(100)  # Small sleep to avoid locking up the CPU


def callback(indata, frames, time, status):
    """This is called for each audio block from the microphone."""
    global audio_data
    audio_data = np.append(audio_data, indata.copy())


def toggle_recording():
    """Toggle the recording state between start and stop."""
    global is_recording, audio_data
    if not is_recording:
        # Start recording
        is_recording = True
        audio_data = np.array([])
        Thread(target=record_audio).start()  # Start recording in a background thread
    else:
        # Stop recording
        is_recording = False
        process_audio(audio_data)  # Process the recorded audio


def process_audio(audio_data, fs=16000):
    """Process the recorded audio and update the transcription label."""
    # Preprocess the audio to match the model's expected format
    input_values = processor(
        audio_data, return_tensors="pt", sampling_rate=fs
    ).input_values

    # Move to the same device as the model
    input_values = input_values.to(model.device)

    # Perform inference
    with torch.no_grad():
        logits = model(input_values).logits

    # Decode the predicted ids
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]

    # Update the transcription label text in the main thread
    def update_transcription_label():
        transcription_label.config(text=f"Transcribed Text: {transcription}")
        generate_text(transcription)

    root.after(0, update_transcription_label)


def generate_text(transcription):
    print("Starting text generation...")

    def background_generate():
        print("Starting text generation inside...")
        input_ids = text_gen_tokenizer.encode(transcription, return_tensors="pt").to(device_text)
        print(input_ids)
        try:
            # Generate text using the model and tokenizer
            print("iam inside")
            output_sequences = text_gen_model.generate(input_ids, max_length=2000, num_return_sequences=1)
            print("output_sequence", output_sequences)
            generated_text = text_gen_tokenizer.decode(output_sequences[0], skip_special_tokens=True)
            print("Generated text:", generated_text)  # For debugging
        except Exception as e:
            print("An error occurred:", e)
            generated_text = "Error in generating text."

        # Function to update the GUI with the generated text
        def update_generated_text_label():
            generated_text_label.config(text=f"Model Response: {generated_text}")

        # Schedule the GUI update to run in the main thread
        root.after(0, update_generated_text_label)

    # Run the text generation in a background thread to avoid blocking the GUI
    Thread(target=background_generate).start()


# Set up the GUI
root = tk.Tk()
root.title("Voice Recorder")

# Add a record button
record_button = tk.Button(root, text="Start Recording", command=toggle_recording)
record_button.pack(pady=20)

# Add a label widget for displaying the transcription
transcription_label = tk.Label(
    root, text="Transcription will appear here...", wraplength=400
)
transcription_label.pack(pady=10)

generated_text_label = tk.Label(
    root, text="Generated text will appear here...", wraplength=2000
)
generated_text_label.pack(pady=10)


def update_button_text():
    """Update the button text based on the recording state."""
    if is_recording:
        record_button.config(text="Stop Recording")
    else:
        record_button.config(text="Start Recording")
    root.after(100, update_button_text)


root.after(100, update_button_text)  # Check every 100ms to update the button text
root.mainloop()

Starting text generation...
Starting text generation inside...
tensor([[     2,  33574,  14975,    586,  50471,  30825,   9661,  41533, 128362,
           2148,   5167,  18744,  26224, 111965,  92364,  54094,    595,  11047,
           1905]], device='mps:0')
iam inside
output_sequence tensor([[     2,  33574,  14975,    586,  50471,  30825,   9661,  41533, 128362,
           2148,   5167,  18744,  26224, 111965,  92364,  54094,    595,  11047,
           1905,   3228, 179743, 235265,    109, 235285,   1144,    476,   2910,
           5255,   2091, 235269,  17363,    731,   6238, 235265,    590,   1144,
          13920,    576,  31928,   3515, 235290,  14683,   2793,    575,   3590,
            577,    476,   5396,   3001,    576,  73815,    578,   3920, 235265,
            109,    688,   4858,    708,   1009,   8944,    576,    970,  22466,
          66058,    109, 235287,    590,    798,   5598,   8965,    578,  29632,
         235265,    108, 235287,    590,    798,  19631,  17044, 