In [0]:
%pip install openai-whisper

In [0]:
import ray
from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster

In [0]:

# Ray-on-Spark cluster sizing. Using the same value for min and max worker
# nodes gives a fixed-size cluster (no autoscaling between the two bounds).
num_worker_nodes = 4
num_cpu_cores_per_worker = 4  # total CPUs present in each worker node
num_cpus_head_node = 4
num_gpu_per_worker = 1
num_gpus_head_node = 1

# NOTE(review): the transcription UDF connects with ray_conf[1]; presumably
# setup_ray_cluster returns an (address, remote_connection_address) tuple in
# the installed Ray version — confirm.
ray_conf = setup_ray_cluster(
    min_worker_nodes=num_worker_nodes,
    max_worker_nodes=num_worker_nodes,
    num_cpus_head_node=num_cpus_head_node,
    num_gpus_head_node=num_gpus_head_node,
    num_cpus_per_node=num_cpu_cores_per_worker,
    num_gpus_per_node=num_gpu_per_worker,
)


In [0]:
import ssl
import torch
import whisper
import pyspark.sql.functions as F
import pyspark.sql.types as T
import pandas as pd

In [0]:
class WhisperTranscription:
    """Callable that transcribes the audio file named in a row's ``file_path``.

    It is used as the function passed to Ray Data ``Dataset.map`` with an
    ``ActorPoolStrategy``. The Whisper model is therefore loaded once per actor
    in ``__init__`` and reused for every row that actor handles.
    """

    def __init__(self, model_type="medium", device=None):
        """Load the Whisper model.

        Args:
            model_type: checkpoint name passed to ``whisper.load_model``.
                Defaults to ``"medium"``, the previously hard-coded value.
            device: torch device string. When ``None``, uses ``"cuda"`` if
                ``torch.cuda.is_available()`` and otherwise ``"cpu"``.
        """
        # NOTE(review): this context is created but never installed or used,
        # so it does not change certificate verification for anything
        # (including the model download). Kept only so the attribute still
        # exists; likely removable.
        self.unverified_context = ssl._create_unverified_context()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.transcribe_whisper_model = whisper.load_model(model_type, device=device)

    def transcribe(self, audio_filename, whisper_model):
        """Transcribe one audio file, returning a null result on failure.

        Args:
            audio_filename: path handed to ``whisper_model.transcribe``.
            whisper_model: object exposing ``transcribe(path) -> dict``.

        Returns:
            On success, the dict returned by ``whisper_model.transcribe``.
            If that call raises an ``Exception``, returns
            ``{"language": None, "text": None}``.

        Raises:
            Non-``Exception`` errors such as ``KeyboardInterrupt`` propagate.
            (Previously a ``return`` inside ``finally`` swallowed them.)
        """
        try:
            return whisper_model.transcribe(audio_filename)
        except Exception as e:
            # Best effort: one unreadable file must not fail the whole batch.
            print(f"Transcription failed for {audio_filename!r}: {e}")
            return {"language": None, "text": None}

    def __call__(self, row) -> dict:
        """Transcribe ``row["file_path"]`` and return a fixed-schema output row.

        Returns:
            ``{"text": ..., "language": ...}`` for both success and failure.
            The values are ``None`` when transcription failed. Whisper's other
            keys (e.g. per-segment details) are dropped so that every output
            row has the same columns. ``"text"`` is the first key, so a
            consumer that takes the first output column gets the transcript.
        """
        transcription = self.transcribe(row["file_path"], self.transcribe_whisper_model)
        return {
            "text": transcription.get("text"),
            "language": transcription.get("language"),
        }

In [0]:
@F.pandas_udf(T.StringType())
def transcribe_udf(filepaths: pd.Series) -> pd.Series:
    """Transcribe a batch of audio file paths on the Ray-on-Spark cluster.

    Spark calls this once per batch of the input column. The batch is sent to
    a Ray task. That task maps ``WhisperTranscription`` over the paths with a
    Ray Data actor pool and collects the ``"text"`` output column.

    Args:
        filepaths: audio file paths for this Spark batch.

    Returns:
        Transcripts (``None`` where transcription failed), in the same order
        as ``filepaths``.
    """
    import ray
    import ray.data

    if filepaths.empty:
        # Nothing to do. Avoids asking Ray Data for zero partitions.
        return pd.Series([], dtype=object)

    # Spark can reuse a Python worker across batches, and a second bare
    # ray.init() in the same process raises, so connect only once.
    if not ray.is_initialized():
        ray.init(ray_conf[1])

    @ray.remote
    def ray_data_task(paths: pd.Series) -> pd.Series:
        # Ray Data does not preserve row order by default. Results are matched
        # back to Spark rows by position, so order must be kept.
        ray.data.DataContext.get_current().execution_options.preserve_order = True

        ds = ray.data.from_pandas(pd.DataFrame({"file_path": paths.to_list()}))
        preds = (
            # One row per block so rows can be spread across the actor pool.
            ds.repartition(len(paths))
            .map(
                WhisperTranscription,
                # NOTE(review): at 0.5 GPU per actor, the cluster configured
                # earlier can host far fewer than 100 actors; a new pool (and
                # model load) also happens for every Spark batch.
                compute=ray.data.ActorPoolStrategy(min_size=1, max_size=100),
                num_gpus=0.5,
            )
        )
        final_df = preds.to_pandas()
        # Select by name, not position: output rows are not guaranteed to
        # list their columns in the same order.
        return final_df["text"].reset_index(drop=True)

    return ray.get(ray_data_task.remote(filepaths))

In [0]:
# Load the recording-file reference table; the transcription step reads its `file_path` column.
df_file_reference = spark.read.table("marcell.call_centre_processing.recording_file_reference")

In [0]:
# Lazily add the transcript column; the UDF runs only when an action triggers it.
df_transcriptions = (
    df_file_reference
    .withColumn("transcription", transcribe_udf(F.col("file_path")))
)

In [0]:
# Materialise the pipeline: the transcription UDF actually executes during
# this write. Ray is needed only until the write finishes, so release the
# Ray-on-Spark cluster started by setup_ray_cluster afterwards, even if the
# write fails. Re-running this cell requires re-running the setup cell first.
try:
    df_transcriptions.write.mode("overwrite").saveAsTable("marcell.call_centre_processing.transcriptions")
finally:
    shutdown_ray_cluster()