In [37]:
import html
import io
import os
import subprocess
import tempfile
import time

from docx import Document
import IPython.display as ipd
from etils import epath as ep
from google.api_core.client_options import ClientOptions
from google.cloud.speech_v2 import SpeechClient
from google.cloud.speech_v2.types import cloud_speech
import jiwer
import pandas as pd
import plotly.graph_objs as go
from pydub import AudioSegment

In [19]:
# Project, region and bucket configuration for the notebook.
# The project ID falls back to the GOOGLE_CLOUD_PROJECT environment variable
# when the form field is left empty or still holds the template placeholder.
import os

PROJECT_ID = "verdant-branch-457906-a2"  # @param {type: "string", placeholder: "[your-project-id]", isTemplate: true}
# Compare against the template placeholder declared in the @param annotation above,
# not against the real project ID wrapped in brackets.
if not PROJECT_ID or PROJECT_ID == "[your-project-id]":
    # Default to "" so a missing variable is detected below instead of
    # silently becoming the literal string "None".
    PROJECT_ID = os.environ.get("GOOGLE_CLOUD_PROJECT", "")
if not PROJECT_ID:
    raise ValueError(
        "No project ID configured: set PROJECT_ID or the GOOGLE_CLOUD_PROJECT environment variable."
    )

# Region used for the Speech-to-Text endpoint and recognizer path.
LOCATION = os.environ.get("GOOGLE_CLOUD_REGION", "europe-west4")
print(f"Using project {PROJECT_ID} in location {LOCATION}")

# Cloud Storage bucket that holds the input audio.
BUCKET_NAME = "bdav42"  # @param {type:"string", isTemplate: true}
BUCKET_URI = f"gs://{BUCKET_NAME}"  # @param {type:"string"}
print(f"Using bucket {BUCKET_URI}")

Using project verdant-branch-457906-a2 in location europe-west4
Using bucket gs://bdav42


In [20]:
# Endpoint host name is derived from LOCATION; the client is bound to that endpoint.
API_ENDPOINT = f"{LOCATION}-speech.googleapis.com"
client = SpeechClient(client_options=ClientOptions(api_endpoint=API_ENDPOINT))

# Audio object to transcribe, addressed relative to BUCKET_URI.
# (Alternative public sample, currently disabled:
#  "gs://github-repo/audio_ai/speech_recognition/attention_is_all_you_need_podcast.wav")
INPUT_LONG_AUDIO_SAMPLE_FILE_URI = f"{BUCKET_URI}/Arvin_converted.wav"

# Recognizer resource path for this project and location; "_" is passed as the recognizer ID.
RECOGNIZER = client.recognizer_path(PROJECT_ID, LOCATION, "_")


In [28]:
def read_audio_file(audio_file_path: str) -> bytes:
    """
    Load the raw bytes of an audio file.

    Paths that start with ``gs://`` are opened through ``ep.Path``; every
    other path is opened with the built-in ``open``.
    """
    if not audio_file_path.startswith("gs://"):
        with open(audio_file_path, "rb") as local_file:
            return local_file.read()

    with ep.Path(audio_file_path).open("rb") as remote_file:
        return remote_file.read()


def save_audio_sample(audio_bytes: bytes, output_file_uri: str) -> None:
    """
    Write audio bytes to ``output_file_uri`` (intended for Google Cloud Storage URIs).

    The destination is wrapped in ``ep.Path``; its parent location is created
    first when it does not exist yet.
    """
    destination = ep.Path(output_file_uri)

    # Only create the parent when it is actually missing.
    if not destination.parent.exists():
        destination.parent.mkdir(parents=True, exist_ok=True)

    with destination.open("wb") as sink:
        sink.write(audio_bytes)


def extract_audio_sample(
    audio_bytes: bytes, duration: int, start_time: int = 0
) -> bytes:
    """
    Extract an audio sample of a given duration from an audio file.

    The sample is taken at a fixed position, not a random one: by default it
    starts at the very beginning of the audio.

    Args:
        audio_bytes: Encoded audio, decoded with ``AudioSegment.from_file``.
        duration: Length of the sample, in seconds.
        start_time: Offset of the sample from the start of the audio, in
            seconds. Defaults to 0 (the beginning of the file).

    Returns:
        The sample exported as WAV bytes.

    Raises:
        ValueError: If ``start_time`` is negative.
    """
    if start_time < 0:
        raise ValueError(f"start_time must be non-negative, got {start_time}")

    audio = AudioSegment.from_file(io.BytesIO(audio_bytes))

    # The segment is sliced in milliseconds, hence the conversion from seconds.
    start_ms = start_time * 1000
    audio_sample = audio[start_ms : start_ms + duration * 1000]

    # Use a separate buffer rather than rebinding the `audio_bytes` parameter.
    sample_buffer = io.BytesIO()
    audio_sample.export(sample_buffer, format="wav")
    return sample_buffer.getvalue()


def play_audio_sample(audio_bytes: bytes) -> None:
    """
    Render an inline audio player for the given audio bytes in a notebook.
    """
    payload = io.BytesIO(audio_bytes).read()
    # NOTE(review): rate=44100 is passed unconditionally; confirm whether
    # ipd.Audio uses it for already-encoded audio bytes.
    player = ipd.Audio(payload, rate=44100)
    ipd.display(player)


def audio_sample_chunk_n(audio_bytes: bytes, num_chunks: int) -> list[bytes]:
    """
    Split an audio sample into ``num_chunks`` consecutive chunks.

    All chunks except the last have the same length (total length floor-divided
    by ``num_chunks``). The last chunk also receives the remainder of that
    division, so the chunks together cover the whole audio.

    Args:
        audio_bytes: Encoded audio, decoded with ``AudioSegment.from_file``.
        num_chunks: Number of chunks to produce; must be at least 1.

    Returns:
        A list of ``num_chunks`` WAV-encoded byte strings, in playback order.

    Raises:
        ValueError: If ``num_chunks`` is smaller than 1.
    """
    # Validate before decoding so bad arguments fail fast with a clear message
    # instead of a ZeroDivisionError further down.
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be at least 1, got {num_chunks}")

    audio = AudioSegment.from_file(io.BytesIO(audio_bytes))
    total_duration = len(audio)
    chunk_duration = total_duration // num_chunks

    chunks = []
    for index in range(num_chunks):
        start_time = index * chunk_duration
        if index == num_chunks - 1:
            # The final chunk runs to the end so the remainder of the floor
            # division is not silently dropped.
            end_time = total_duration
        else:
            end_time = start_time + chunk_duration

        chunk_buffer = io.BytesIO()
        audio[start_time:end_time].export(chunk_buffer, format="wav")
        chunks.append(chunk_buffer.getvalue())

    return chunks


def audio_sample_merge(audio_chunks: list[bytes]) -> bytes:
    """
    Concatenate audio chunks, in list order, into one WAV-encoded sample.
    """
    merged = AudioSegment.empty()
    for encoded_chunk in audio_chunks:
        segment = AudioSegment.from_file(io.BytesIO(encoded_chunk))
        merged += segment

    merged_buffer = io.BytesIO()
    merged.export(merged_buffer, format="wav")
    return merged_buffer.getvalue()


def compress_for_streaming(audio_bytes: bytes) -> bytes:
    """
    Compress audio bytes to MP3 with ffmpeg so the result fits in MAX_CHUNK_SIZE bytes.

    The audio is encoded at 32 kbps first; while the result is larger than
    ``MAX_CHUNK_SIZE`` the bitrate is lowered in 8 kbps steps (24k, 16k, 8k).
    All intermediate files live in a private temporary directory that is
    removed on exit, including when an error occurs.

    NOTE(review): ``MAX_CHUNK_SIZE`` is read from module scope and is not
    defined in the cells visible here; confirm it is set before calling.

    Args:
        audio_bytes: Encoded input audio; written to a ``.wav`` file for ffmpeg.

    Returns:
        The MP3-encoded bytes of the first attempt that fits the size limit.

    Raises:
        FileNotFoundError: If the ``ffmpeg`` executable cannot be found.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
        RuntimeError: If even the lowest bitrate exceeds ``MAX_CHUNK_SIZE``.
    """
    with tempfile.TemporaryDirectory() as work_dir:
        source_path = os.path.join(work_dir, "temp_original.wav")
        target_path = os.path.join(work_dir, "temp_compressed.mp3")

        with open(source_path, "wb") as f:
            f.write(audio_bytes)

        # Bounded bitrate ladder: stop at 8 kbps instead of decrementing into
        # "0k" and negative bitrates, which could never succeed.
        for bitrate_kbps in range(32, 0, -8):
            subprocess.run(
                [
                    "ffmpeg",
                    "-i",
                    source_path,
                    "-b:a",
                    f"{bitrate_kbps}k",
                    "-y",
                    target_path,
                ],
                check=True,  # surface ffmpeg failures instead of reading a stale or missing file
            )

            if os.path.getsize(target_path) <= MAX_CHUNK_SIZE:
                with open(target_path, "rb") as f:
                    return f.read()

    raise RuntimeError(
        f"Could not compress audio below {MAX_CHUNK_SIZE} bytes, even at 8k bitrate."
    )


def parse_streaming_recognize_response(response) -> list[tuple[str, int]]:
    """
    Flatten streaming Speech-to-Text responses into (transcript, end offset) pairs.

    Every result of every streamed message contributes one pair, built from its
    first alternative's transcript and its ``result_end_offset``.
    """
    return [
        (result.alternatives[0].transcript, result.result_end_offset)
        for message in response
        for result in message.results
    ]


def parse_real_time_recognize_response(response) -> list[tuple[str, int]]:
    """
    Turn a single real-time Speech-to-Text response into (transcript, end offset) pairs.

    Each entry of ``response.results`` contributes one pair, built from its
    first alternative's transcript and its ``result_end_offset``.
    """
    return [
        (item.alternatives[0].transcript, item.result_end_offset)
        for item in response.results
    ]


def parse_batch_recognize_response(
    response, audio_sample_file_uri: str = INPUT_LONG_AUDIO_SAMPLE_FILE_URI
) -> list[tuple[str, int]]:
    """
    Turn a batch Speech-to-Text response into (transcript, end offset) pairs.

    The inline transcript for ``audio_sample_file_uri`` is looked up in
    ``response.results``; each of its results contributes one pair, built from
    its first alternative's transcript and its ``result_end_offset``.
    """
    file_transcript = response.results[audio_sample_file_uri].inline_result.transcript
    return [
        (entry.alternatives[0].transcript, entry.result_end_offset)
        for entry in file_transcript.results
    ]


def get_recognize_output(
    audio_bytes: bytes, recognize_results: list[tuple[str, int]]
) -> list[tuple[bytes, str]]:
    """
    Pair every transcript with the slice of audio it covers.

    The audio is cut into consecutive chunks with no gaps or overlaps: each
    chunk starts where the previous one ended.

    Args:
        audio_bytes: Encoded audio, decoded with ``AudioSegment.from_file``.
        recognize_results: ``(transcript, end_offset)`` pairs. Each offset must
            provide ``total_seconds()`` (the function calls it on every entry).

    Returns:
        A list of ``(wav_bytes, transcript)`` pairs in the order of
        ``recognize_results``; empty when there are no results.
    """
    # Nothing to slice; this also avoids indexing into an empty list below.
    if not recognize_results:
        return []

    audio = AudioSegment.from_file(io.BytesIO(audio_bytes))

    # End offsets converted from seconds to milliseconds.
    end_times = [offset.total_seconds() * 1000 for _, offset in recognize_results]

    # Streaming case where the first result's end offset is zero: every chunk
    # then ends at the NEXT result's offset, and the last one at the end of
    # the audio.
    if end_times[0] == 0:
        end_times = end_times[1:] + [len(audio)]

    recognize_output = []
    start_time = 0
    for (transcript, _), end_time in zip(recognize_results, end_times):
        chunk_buffer = io.BytesIO()
        audio[start_time:end_time].export(chunk_buffer, format="wav")
        recognize_output.append((chunk_buffer.getvalue(), transcript))

        # The next chunk begins exactly where this one ended.
        start_time = end_time

    return recognize_output


def print_transcription(audio_sample_bytes: bytes, transcription: str) -> None:
    """
    Show an audio player followed by its transcription text in a notebook.

    Args:
        audio_sample_bytes: Audio to play, passed to ``play_audio_sample``.
        transcription: Plain text to show below the player. It is HTML-escaped
            so characters such as ``<`` or ``&`` are displayed literally.
    """
    # Play the audio sample
    ipd.display(ipd.HTML("<b>Audio:</b>"))
    play_audio_sample(audio_sample_bytes)
    ipd.display(ipd.HTML("<br>"))

    # Display the transcription text. The transcript is data, not markup:
    # escape it before embedding it in the HTML fragment.
    ipd.display(ipd.HTML("<b>Transcription:</b>"))
    safe_text = html.escape(transcription, quote=False)
    formatted_text = f"<pre style='font-family: monospace; white-space: pre-wrap;'>{safe_text}</pre>"
    ipd.display(ipd.HTML(formatted_text))


# def evaluate_stt(
#     actual_transcriptions: list[str],
#     reference_transcriptions: list[str],
#     audio_sample_file_uri: str = INPUT_LONG_AUDIO_SAMPLE_FILE_URI,
# ) -> pd.DataFrame:
#     """
#     Evaluate speech-to-text (STT) transcriptions against reference transcriptions.
#     """
#     audio_uris = [audio_sample_file_uri] * len(actual_transcriptions)
#     evaluations = []
#     for audio_uri, actual_transcription, reference_transcription in zip(
#         audio_uris, actual_transcriptions, reference_transcriptions
#     ):
#         evaluation = {
#             "audio_uri": audio_uri,
#             "actual_transcription": actual_transcription,
#             "reference_transcription": reference_transcription,
#             "wer": jiwer.wer(reference_transcription, actual_transcription),
#             "cer": jiwer.cer(reference_transcription, actual_transcription),
#         }
#         evaluations.append(evaluation)

#     evaluations_df = pd.DataFrame(evaluations)
#     evaluations_df.reset_index(inplace=True, drop=True)
#     return evaluations_df


def plot_evaluation_results(
    evaluations_df: pd.DataFrame,
) -> go.Figure:
    """
    Build a bar chart of the mean WER and CER with a 0.5 baseline drawn as a line.

    Reads the ``wer`` and ``cer`` columns of ``evaluations_df`` and returns the
    figure without displaying it.
    """
    metric_labels = ("WER", "CER")
    metric_means = [evaluations_df[column].mean() for column in ("wer", "cer")]

    mean_bars = go.Bar(
        x=list(metric_labels), y=metric_means, name="Mean Error Rate"
    )
    baseline_line = go.Scatter(
        x=list(metric_labels), y=[0.5, 0.5], mode="lines", name="Baseline (0.5)"
    )

    figure_layout = go.Layout(
        title="Speech-to-Text Evaluation Results",
        xaxis={"title": "Metric"},
        yaxis={"title": "Error Rate", "range": [0, 1]},
        barmode="group",
    )

    return go.Figure(data=[mean_bars, baseline_line], layout=figure_layout)

In [32]:
# Read the full input audio file into memory.
# NOTE: the URI is a gs:// path, which read_audio_file opens through ep.Path; the
# output recorded for this cell is an ImportError asking for fsspec to be installed.
input_audio_bytes = read_audio_file(INPUT_LONG_AUDIO_SAMPLE_FILE_URI)

# Extract a 30-second sample. extract_audio_sample starts at offset 0, so this is
# the first 30 seconds of the file, not a random excerpt.
short_audio_sample_bytes = extract_audio_sample(input_audio_bytes, 30)

# Render an inline player for the sample.
play_audio_sample(short_audio_sample_bytes)

ImportError: To use epath.Path with gs://, fsspec should be installed.'

In [23]:
### Perform batch recognition

# Recognition settings for the batch request: language code "fa-IR", the "chirp_2"
# model, automatic punctuation enabled, and an AutoDetectDecodingConfig for decoding.
batch_recognition_config = cloud_speech.RecognitionConfig(
    language_codes=["fa-IR"],
    model="chirp_2",
    features=cloud_speech.RecognitionFeatures(
        enable_automatic_punctuation=True,
    ),
    auto_decoding_config=cloud_speech.AutoDetectDecodingConfig(),
)

# Point the request at the audio object to transcribe.
audio_metadata = cloud_speech.BatchRecognizeFileMetadata(
    uri=INPUT_LONG_AUDIO_SAMPLE_FILE_URI
)

# Assemble the request. An inline output config is used, and
# parse_batch_recognize_response later reads the transcript from `inline_result`.
batch_recognition_request = cloud_speech.BatchRecognizeRequest(
    config=batch_recognition_config,
    files=[audio_metadata],
    recognition_output_config=cloud_speech.RecognitionOutputConfig(
        inline_response_config=cloud_speech.InlineOutputConfig(),
    ),
    recognizer=RECOGNIZER,
)


In [25]:
# Submit the batch recognition request; the returned operation object is polled below.
operation = client.batch_recognize(request=batch_recognition_request)

# Check once a minute until the operation reports that it is done.
while not operation.done():
    print("Waiting for operation to complete...")
    time.sleep(60)
print("Operation completed.")

# Fetch the finished operation's result.
response = operation.result()

Waiting for operation to complete...
Operation completed.


In [39]:
# Inspect the results: parse the batch response into a list of
# (transcript, end_offset) tuples for the input audio file.

batch_recognize_results = parse_batch_recognize_response(
    response, audio_sample_file_uri=INPUT_LONG_AUDIO_SAMPLE_FILE_URI
)
# Disabled: per-chunk playback with its transcription.
# NOTE(review): `long_audio_sample_bytes` is not defined in the visible cells;
# it must be assigned before this block is re-enabled.
# batch_recognize_output = get_recognize_output(
#     long_audio_sample_bytes, batch_recognize_results
# )
# for audio_sample_bytes, transcription in batch_recognize_output:
#     print_transcription(audio_sample_bytes, transcription)
# Print the complete list of results, then show the container type as the cell output.
print(batch_recognize_results)
type(batch_recognize_results)

[('قربانم بد نیستم شکر خدا شما چه خبرا ما مشغول دیگه امروز دیگه منتظریم ببینیم نتیجه مذاکرات چی میشه و بعد بعضی کانادا هم که کماکان خرابه و آره واقعا حالا انتخابات هم داره نزدیکه من که به کانزروتیو ها رای میدم نه که رای هم کانزروتیو باشه ولی فکر میکنم در موقعیت فعلی مناسب ترن برای کانادا حداقل', datetime.timedelta(seconds=30)), ('الان من فکر میکنم کلا کانسروتیو با رای خیلی بالایی میان من فکر نمیکنم بیان لیبرال ها الان تو اکثریت اند فکر نمیکنم چون خود لیبرال ها خیلی من خودم اگه بودم به کانسروتیو ها رای میدادم البته منطقه ما فرقی نمیکنه چون اینجا ان دی پی حالا مثلا ما حتی به هر کی هم رای بدیم رای ما کم نمیشه اگه ما اتفاقا به ان دی پی رای بدیم شاید بهتر باشه بخاطر اینکه سیپ لیبرال ها رو میگیره آره میدونید چی میگم آره به نظر من شاید اینجوری بشه', datetime.timedelta(seconds=60)), ('اونقدر مهم اینه که فعلاً لیبرال\u200cها حتی امکان نیاد سر کار نه لیبرال دیگه به هر حال به اندازه کافی تر زدن دیگه حالا باید برن آره آره خیلی داغون کردن خب من کانکت کردم با درل برات ایمیل می\u200cکنه اون رنتو مگه 

list

In [41]:
    # document = Document()
    # for result in response.results:
    #     for alternative in result.alternatives:
    #         document.add_paragraph(alternative.transcript)
    #         for word_info in alternative.words:
    #             document.add_paragraph(f"Word: {word_info.word}, Speaker: {word_info.speaker_tag}")
    #         document.add_paragraph("-" * 20)

    # document.save("transcription.docx")
    # print("Transcription saved to transcription.docx")

# Export the transcripts to a Word document: a title heading followed by one
# paragraph per recognition result.
# NOTE(review): the heading, file name and message say "translation", but the
# paragraphs written are the transcripts from batch_recognize_results — confirm the wording.
document = Document()
document.add_heading('Persian Translation Results', level=0)

# Each entry is a (transcript, end_offset) tuple; `duration` holds the second
# element and is only used by the optional line below.
for transcription, duration in batch_recognize_results:
    document.add_paragraph(transcription)
    # document.add_paragraph(f"Duration: {duration}")   # Optional: also write the second tuple element
    document.add_paragraph() # Empty paragraph as a spacer between results

# Save the document under a relative file name.
document.save('persian_translation.docx')
print("Persian translation saved to persian_translation.docx")

Persian translation saved to persian_translation.docx


In [None]:
# Evaluate the results: if you have a reference transcription, you can compare it with the actual transcription.

# batch_recognize_results holds (transcript, end_offset) tuples, so the
# transcript is the FIRST element of each tuple.
actual_transcriptions = [transcript for transcript, _ in batch_recognize_results]

# Placeholder references: replace with the ground-truth text for each result, in order.
reference_transcriptions = [
    """sentence 1""",
    """next""",
    """next""",
    """next""",
]

# Score each (reference, actual) pair by position. zip() stops at the shorter
# list, so results without a matching reference are left out of the evaluation.
evaluation_df = pd.DataFrame(
    [
        {
            "audio_uri": INPUT_LONG_AUDIO_SAMPLE_FILE_URI,
            "actual_transcription": actual,
            "reference_transcription": reference,
            "wer": jiwer.wer(reference, actual),
            "cer": jiwer.cer(reference, actual),
        }
        for actual, reference in zip(actual_transcriptions, reference_transcriptions)
    ]
)
plot_evaluation_results(evaluation_df)