In [0]:
import pyspark.sql.functions as F

import torch
from transformers import pipeline
from transformers.utils import is_flash_attn_2_available
import json



In [0]:
# Number of partitions the recordings are spread over before transcription.
num_partitions = 4 # enforce to number of GPU workers?

# Build the working set of recordings:
#   * modification_timestamp: modificationTime / 1000 converted to a timestamp
#     (presumably modificationTime is epoch milliseconds -- TODO confirm).
#   * recording_timestamp: a random timestamp between 2025-01-01 00:00:00 and
#     2025-01-27 23:59:59, generated with rand() rather than read from the data.
# The result is repartitioned and cached before being displayed.
audio_table = (
    spark.table("marcell.call_centre_processing.recording_file_reference")
    .withColumn(
        "modification_timestamp",
        F.from_unixtime(F.col("modificationTime") / 1000).cast("timestamp"),
    )
    .withColumn(
        "recording_timestamp",
        F.expr(
            "timestamp_seconds(cast(rand() * (unix_timestamp('2025-01-27 23:59:59') - unix_timestamp('2025-01-01 00:00:00')) as int) + unix_timestamp('2025-01-01 00:00:00'))"
        ),
    )
    .repartition(num_partitions)
    .cache()
)
audio_table.display()

In [0]:
import pandas as pd
from typing import Iterator

In [0]:
@F.pandas_udf("string")
def transcribe(path_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Transcribe audio files with Whisper and return the timestamped chunks as JSON.

    This is an iterator-of-Series pandas UDF: the speech-recognition pipeline
    is constructed once, before the first batch is processed, and then reused
    for every batch yielded by ``path_iterator``.

    Args:
        path_iterator: Batches of audio file paths. Each path is handed to the
            Hugging Face pipeline as-is, so it must be readable from the
            process running the UDF.

    Yields:
        One Series per input batch. Each element is the JSON-encoded
        ``"chunks"`` list returned by the pipeline for that file, or ``None``
        when the input path was null.
    """
    # Imported inside the UDF body (presumably so the names resolve in the
    # worker process that executes the UDF rather than only on the driver).
    import torch
    from transformers import pipeline
    from transformers.utils import is_flash_attn_2_available
    import json

    # Use flash attention when the helper reports it as available, otherwise
    # fall back to PyTorch's scaled-dot-product attention.
    # Flash attention is only supported on ampere (A10/A100) architecture;
    # force "sdpa" if you plan on using T4 gpus.
    attention_kwargs = (
        {"attn_implementation": "flash_attention_2"}
        if is_flash_attn_2_available()
        else {"attn_implementation": "sdpa"}
    )

    pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-medium", # select checkpoint from https://huggingface.co/openai/whisper-large-v3#model-details
        # model="distil-whisper/large-v2",
        torch_dtype=torch.float16,
        device="cuda:0", # or mps for Mac devices
        model_kwargs=attention_kwargs,
    )

    def _transcribe_one(path):
        """Run the pipeline on one file and JSON-encode its timestamped chunks."""
        # A null path cannot be transcribed; emit a null result instead of
        # letting the pipeline raise and fail the whole batch.
        if pd.isna(path):
            return None
        outputs = pipe(
            path,
            chunk_length_s=30,
            batch_size=24,
            return_timestamps=True,
        )
        return json.dumps(outputs["chunks"])

    for path_chunk in path_iterator:
        yield path_chunk.apply(_transcribe_one)

In [0]:
# Transcribe every recording via the `transcribe` UDF and overwrite the
# target table (including its schema) with the result.
(
    audio_table
    .withColumn("transcription", transcribe("file_path"))
    .write
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable("marcell.call_centre_processing.transcriptions_udf")
)