In [None]:
#
# Summary:
#

#
# This notebook segments long audio files into smaller clips using the phrase timings in their speech-to-text transcripts. Both the audio and the transcripts are read from Azure Blob Storage.
# Here's what it does:
# 
# 1. Loads configuration settings (such as Azure container locations and working directories).
# 2. Downloads audio files and corresponding transcript JSON files from Azure.
# 3. Parses the transcript files to extract speech segments with timestamps and text.
# 4. Uses ffmpeg to cut each segment out of its audio file as a separate .wav clip (16 kHz, mono, 16-bit PCM).
# 5. Uploads the segmented audio files back to Azure, along with metadata like text and duration.
# 6. Cleans up temporary local files used during processing.
#
# The result is a set of labeled audio clips, each matched to its corresponding spoken phrase.
#

# 
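# This notebook reads the following environment variables (e.g. from a .env file).
# SEGMENT_TEXT_INPUT_PATH is an optional blob-name prefix that selects which transcript blobs to download: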
# 

# SEGMENT_LOCAL_WORKING_DIR=
# SEGMENT_AUDIO_INPUT_CONTAINER_SAS_URI=
# SEGMENT_TEXT_INPUT_CONTAINER_SAS_URI=
# SEGMENT_TEXT_INPUT_PATH=
# SEGMENT_OUTPUT_CONTAINER_SAS_URI=


In [None]:
import ffmpeg
import json
import os
import shutil
import time
import uuid
from azure.storage.blob import BlobBlock, ContainerClient
from dotenv import load_dotenv

# Attempts per blob transfer; after the last one the error is re-raised.
MAX_RETRIES: int = 3
# Initial back-off between attempts, in seconds; doubled after each failure.
BASE_DELAY: int = 1

def download_blobs(client: ContainerClient, dest: str, prefix: str=None):
    """Download the blobs of a container (optionally filtered by prefix) into *dest*.

    Blobs are saved flat under their base name, i.e. any virtual-folder part of
    the blob name is dropped. Each download is attempted up to MAX_RETRIES
    times, sleeping BASE_DELAY * 2**(attempt - 1) seconds between attempts; the
    last error is re-raised.

    Args:
        client: container to read from.
        dest: existing local directory the files are written into.
        prefix: optional blob-name prefix; a falsy value downloads every blob.
    """
    blobs = list(client.list_blobs(name_starts_with=prefix or None))
    seen_names = set()

    for blob in blobs:
        file_name = os.path.basename(blob.name)
        if not file_name:
            # Names ending in "/" (folder placeholders) have no file component;
            # joining "" onto dest would make open() target the directory itself.
            continue

        if file_name in seen_names:
            # Flattening the blob name can map two blobs to the same local file.
            print(f"download_blobs() warning: '{blob.name}' overwrites an earlier blob saved as '{file_name}'")
        seen_names.add(file_name)

        local_blob_path = os.path.join(dest, file_name)

        for attempt in range(1, MAX_RETRIES + 1):
            try:
                # "wb" truncates, so a retry never appends to a partial file.
                with open(local_blob_path, "wb") as file:
                    stream = client.download_blob(blob.name)
                    for chunk in stream.chunks():
                        file.write(chunk)
                break
            except Exception as e:
                if attempt == MAX_RETRIES:
                    raise
                print(f"download_blobs() error: {str(e)}")
                delay = BASE_DELAY * (2 ** (attempt - 1))
                time.sleep(delay)

def get_configuration() -> dict[str, str]:
    """Read the segmentation settings from environment variables.

    Returns:
        A dict with the keys ``local_working_dir`` (normalized path),
        ``audio_input_container_sas``, ``text_input_container_sas``,
        ``text_input_path`` (optional blob-name prefix, ``None`` when unset)
        and ``output_container_sas``.

    Raises:
        ValueError: if any required variable is not set. The message names
            every missing variable.
    """
    required = (
        "SEGMENT_LOCAL_WORKING_DIR",
        "SEGMENT_AUDIO_INPUT_CONTAINER_SAS_URI",
        "SEGMENT_TEXT_INPUT_CONTAINER_SAS_URI",
        "SEGMENT_OUTPUT_CONTAINER_SAS_URI",
    )
    missing = [name for name in required if os.getenv(name) is None]
    if missing:
        # Fail fast with a readable message instead of a TypeError from
        # os.path.normpath(None) or an opaque error from the storage client.
        raise ValueError(f"Missing required environment variable(s): {', '.join(missing)}")

    return {
        "local_working_dir": os.path.normpath(os.getenv('SEGMENT_LOCAL_WORKING_DIR')),
        "audio_input_container_sas": os.getenv('SEGMENT_AUDIO_INPUT_CONTAINER_SAS_URI'),
        "text_input_container_sas": os.getenv('SEGMENT_TEXT_INPUT_CONTAINER_SAS_URI'),
        "text_input_path": os.getenv('SEGMENT_TEXT_INPUT_PATH'),
        "output_container_sas": os.getenv('SEGMENT_OUTPUT_CONTAINER_SAS_URI')
    }

def get_container_client_from_sas(sas_uri: str) -> ContainerClient:
    """Build a ContainerClient from a container URL that carries its SAS token."""
    client = ContainerClient.from_container_url(container_url=sas_uri)
    return client

def get_temp_dir(path: str) -> str:
    """Create a new, uniquely named (UUID4) sub-directory of *path*.

    Missing parent directories are created as well.

    Returns:
        The path of the created directory.
    """
    new_dir = os.path.join(path, f"{uuid.uuid4()}")
    os.makedirs(new_dir, exist_ok=True)
    return new_dir

def parse_translation_file(input: str) -> list[dict[str, object]]:
    """Extract the timed phrases from a speech-to-text transcript JSON file.

    The file is expected to hold a top-level ``recognizedPhrases`` list. Each
    item carries ``offsetInTicks`` and ``durationInTicks``, converted here at
    10,000,000 ticks per second. An optional ``nBest`` list may be present;
    the ``display`` string of its first entry is used as the phrase text.

    Args:
        input: path to the transcript JSON file. It is read as UTF-8, and a
            leading byte-order mark is tolerated.

    Returns:
        One ``{"start_time", "end_time", "text"}`` dict per phrase, in file
        order, with times in seconds. The list is empty when the file has no
        ``recognizedPhrases`` key.
    """
    ticks_per_second = 10_000_000

    # Explicit encoding: the platform default (e.g. cp1252 on Windows) can
    # garble or reject non-ASCII transcript text, and "utf-8-sig" also strips
    # a BOM that json.load would otherwise refuse.
    with open(input, "r", encoding="utf-8-sig") as content:
        data = json.load(content)

    segments = []
    if "recognizedPhrases" in data:
        for phrase in data["recognizedPhrases"]:
            start_time = float(phrase["offsetInTicks"]) / ticks_per_second
            duration = float(phrase["durationInTicks"]) / ticks_per_second

            # nBest may be absent, empty or null; any of those means "no text".
            best = phrase.get("nBest") or []
            text = best[0].get("display", "") if best else ""

            segments.append({
                "start_time": start_time,
                "end_time": start_time + duration,
                "text": text
            })

    return segments

def remove_temp_dir(path: str):
    """Recursively delete *path* and everything beneath it.

    Errors from shutil.rmtree (such as the directory not existing) propagate.
    """
    shutil.rmtree(path, ignore_errors=False)

def segment_audio(audio_source: str, text_source: str, dest: str) -> list[dict[str, str]]:
    """Cut every transcribed phrase out of its source audio as a separate WAV clip.

    Each transcript in *text_source* is named after its audio file plus a
    ``.json`` suffix (e.g. ``call.wav.json`` for ``call.wav``). The matching
    audio file is looked up in *audio_source*. Clips are written to *dest* as
    16 kHz mono 16-bit PCM WAV files named ``<stem>_segNNN.wav``, where NNN is
    the phrase's index in the transcript.

    Transcripts without phrases are ignored. Transcripts whose audio file is
    missing are skipped with a warning. Phrases with a non-positive duration
    are skipped.

    Args:
        audio_source: directory holding the audio files.
        text_source: directory holding the transcript JSON files.
        dest: directory the clips are written to.

    Returns:
        One dict per clip with ``name``, ``path``, ``start_time``,
        ``end_time``, ``duration`` (seconds) and ``text``.

    Raises:
        ffmpeg.Error: if ffmpeg fails to cut a clip; its stderr is printed first.
    """
    results = []
    # Sorted so the processing order (and log output) is deterministic.
    for text_file_name in sorted(os.listdir(text_source)):  # [FILENAME].wav.json
        audio_file_name = os.path.splitext(text_file_name)[0]  # [FILENAME].wav
        audio_path = os.path.join(audio_source, audio_file_name)
        stem = os.path.splitext(audio_file_name)[0]  # [FILENAME]

        segments = parse_translation_file(os.path.join(text_source, text_file_name))
        if not segments:
            continue

        if not os.path.isfile(audio_path):
            # Without this check ffmpeg fails on the missing input and the
            # whole run aborts before anything is uploaded.
            print(f"segment_audio() warning: no audio file '{audio_file_name}' for transcript '{text_file_name}', skipping")
            continue

        for i, segment in enumerate(segments):
            start_time = segment["start_time"]
            end_time = segment["end_time"]
            duration = end_time - start_time
            if duration <= 0:
                # A zero/negative length would yield a clip with no audio.
                continue

            segment_file_name = f"{stem}_seg{i:03d}.wav"
            segment_file_path = os.path.join(dest, segment_file_name)

            try:
                (
                    ffmpeg
                    .input(audio_path, ss=start_time, t=duration)
                    .output(segment_file_path, acodec='pcm_s16le', ar='16000', ac=1, format='wav')
                    .run(quiet=True, overwrite_output=True)
                )
            except ffmpeg.Error as e:
                # quiet=True captures ffmpeg's stderr; surface it before re-raising.
                stderr = e.stderr.decode(errors="replace") if e.stderr else ""
                print(f"segment_audio() ffmpeg error for '{segment_file_name}': {stderr}")
                raise

            results.append({
                "name": segment_file_name,
                "path": segment_file_path,
                "start_time": start_time,
                "end_time": end_time,
                "duration": duration,
                "text": segment["text"]
            })

    return results

def upload_blob_list_with_metadata(client: ContainerClient, blobs: list[dict[str, str]]):
    """Upload local files as block blobs with per-file metadata.

    Each file is staged in 4 MiB blocks. Its ``name``, ``duration`` and
    ``text`` are committed as blob metadata in the same ``commit_block_list``
    call, so the blob never exists without its metadata. Each blob is attempted
    up to MAX_RETRIES times, sleeping BASE_DELAY * 2**(attempt - 1) seconds
    between attempts; the last error is re-raised.

    Args:
        client: container the blobs are written to.
        blobs: entries as produced by segment_audio(). Each needs ``name``
            (target blob name), ``path`` (local file), ``duration`` and
            ``text``. Entries whose ``path`` is not an existing file are
            skipped with a warning.
    """
    chunk_size = 4 * 1024 * 1024  # bytes per staged block

    for blob in blobs:
        if not os.path.isfile(blob["path"]):
            print(f"upload_blob_list_with_metadata() warning: '{blob['path']}' not found, skipping")
            continue

        blob_client = client.get_blob_client(blob["name"])
        # NOTE(review): metadata values are sent as HTTP headers, which Azure
        # restricts to ASCII; non-ASCII transcript text will likely be
        # rejected. Confirm, and consider encoding "text" if that matters.
        meta = {
            "name": str(blob["name"]),
            "duration": str(blob["duration"]),
            "text": str(blob["text"])
        }

        for attempt in range(1, MAX_RETRIES + 1):
            try:
                # Rebuilt on every attempt so a retry re-stages from the start.
                block_list = []
                with open(blob["path"], "rb") as file:
                    while chunk := file.read(chunk_size):
                        block_id = str(len(block_list)).zfill(6)
                        blob_client.stage_block(block_id=block_id, data=chunk)
                        block_list.append(BlobBlock(block_id=block_id))

                # Metadata rides on the commit itself, replacing a separate
                # set_blob_metadata request that could fail on its own.
                blob_client.commit_block_list(block_list, metadata=meta)
                break
            except Exception as e:
                if attempt == MAX_RETRIES:
                    raise
                print(f"upload_blob_list_with_metadata() error: {str(e)}")
                delay = BASE_DELAY * (2 ** (attempt - 1))
                time.sleep(delay)

def main():
    """Download inputs, cut them into clips and upload the clips.

    The local working directories are removed even when a step fails.
    """
    print("Notebook Cell Running...")

    print("Loading configuration...")
    load_dotenv()
    configuration = get_configuration()

    print("Creating local input and output directories...")
    temp_dirs = []
    try:
        # Record each directory as soon as it exists so cleanup covers it.
        audio_input_dir = get_temp_dir(configuration["local_working_dir"])
        temp_dirs.append(audio_input_dir)
        text_input_dir = get_temp_dir(configuration["local_working_dir"])
        temp_dirs.append(text_input_dir)
        output_dir = get_temp_dir(configuration["local_working_dir"])
        temp_dirs.append(output_dir)

        print("Obtaining input and output container clients...")
        audio_input_client = get_container_client_from_sas(configuration['audio_input_container_sas'])
        text_input_client = get_container_client_from_sas(configuration['text_input_container_sas'])
        output_client = get_container_client_from_sas(configuration['output_container_sas'])

        print("Downloading input files locally...")
        download_blobs(audio_input_client, audio_input_dir)
        download_blobs(text_input_client, text_input_dir, configuration['text_input_path'])

        print("Segmenting audio...")
        segments = segment_audio(audio_input_dir, text_input_dir, output_dir)

        print("Uploading output files to cloud...")
        upload_blob_list_with_metadata(output_client, segments)
    finally:
        # Previously cleanup ran only on success, leaking downloaded audio and
        # clips in local_working_dir whenever any step raised.
        print("Cleaning local input and output directories...")
        for temp_dir in temp_dirs:
            remove_temp_dir(temp_dir)

    print("Done!")

# Inside a Jupyter kernel __name__ is "__main__", so the cell still runs the
# pipeline; importing this code as a module no longer triggers it.
if __name__ == "__main__":
    main()