In [0]:
# vLLM >= 0.7.0 is required because it adds Whisper support; numba is pinned manually to resolve dependency conflicts.
%pip install vllm==0.7.0 pydub numba==0.61.0 databricks-sdk 
%pip install ray --upgrade

In [0]:
# Restart the Python process so that the packages installed by the %pip cell are the ones imported below.
dbutils.library.restartPython()

In [0]:
import ray
import os
from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster
from transformers import pipeline
import torch

import ssl
import time
import pyspark.sql.functions as F
import pyspark.sql.types as T
import pandas as pd

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
import librosa
import pydub
import numpy as np
from mlflow.utils.databricks_utils import get_databricks_env_vars

In [0]:

# Resources handed to Ray for the worker nodes and for the head node.
num_cpu_cores_per_worker = 20  # total CPUs present in each node
num_cpus_head_node = 10
num_gpu_per_worker = 1
num_gpus_head_node = 1

# Export the Databricks credentials as environment variables.
mlflow_dbrx_creds = get_databricks_env_vars("databricks")
os.environ.update(
    {
        "DATABRICKS_HOST": mlflow_dbrx_creds["DATABRICKS_HOST"],
        "DATABRICKS_TOKEN": mlflow_dbrx_creds["DATABRICKS_TOKEN"],
    }
)

# Start a fixed-size (2 worker) Ray cluster; `ray_conf[1]` is later passed to
# `ray.init` by the pandas UDF.
ray_conf = setup_ray_cluster(
    min_worker_nodes=2,
    max_worker_nodes=2,
    num_cpus_head_node=num_cpus_head_node,
    num_gpus_head_node=num_gpus_head_node,
    num_cpus_per_node=num_cpu_cores_per_worker,
    num_gpus_per_node=num_gpu_per_worker,
)


In [0]:
from util import stage_registered_model , flatten_folder

In [0]:
# TODO: explain why vLLM is used here, what the alternative was, and why vLLM was chosen over it.

In [0]:
class ConverttoPrompt:
    """Row-wise transform that decodes a WAV file into normalised audio samples.

    Called with a row dict containing ``file_path``; returns the same dict with
    two extra keys, ``array`` (list of float32 samples) and ``frame_rate``.
    """

    def __init__(self):
        pass

    def transform(self, audio_filenames):
        """Decode one WAV file into mono float32 samples.

        Args:
            audio_filenames: Path of a single WAV file (the plural name is kept
                for backward compatibility).

        Returns:
            tuple: ``(array, frame_rate)`` where ``array`` is a 1-D
            ``numpy.float32`` array scaled by the full-scale amplitude of the
            file's sample width, and ``frame_rate`` is the segment's frame rate.
        """
        segment = pydub.AudioSegment.from_wav(audio_filenames)
        samples = np.array(segment.get_array_of_samples())

        if segment.channels > 1:
            # Samples are interleaved per frame: split them into one column per
            # channel and average down to mono, so downstream always receives
            # a 1-D signal.
            samples = samples.reshape((-1, segment.channels)).mean(axis=1)

        # Full-scale value for the actual sample width (2**15 for 16-bit audio)
        # instead of assuming every file is 16-bit.
        full_scale = float(1 << (8 * segment.sample_width - 1))
        array = (samples / full_scale).astype(np.float32)
        return array, segment.frame_rate

    def __call__(self, row) -> dict:
        """Attach the decoded samples and frame rate of ``row["file_path"]`` to ``row``."""
        array, frame_rate = self.transform(row["file_path"])

        row['array'] = list(array)
        row['frame_rate'] = frame_rate
        return row


class WhisperTranscription:
    """Batch-wise transform that transcribes decoded audio with a vLLM Whisper model.

    Called with a batch dict holding ``array`` and ``frame_rate`` columns;
    returns the same dict with those two columns removed and a
    ``transcription`` column added.
    """

    def __init__(self):
        # NOTE(review): this context is stored but not used by any code in this class.
        self.unverified_context = ssl._create_unverified_context()
        print("Trying to load model...")
        # Alternative model source (kept for reference): stage the registered
        # model onto local disk instead of reading it from the volume.
        # WHISPER_MODEL_PATH = stage_registered_model(
        #                     catalog = 'mlops_pj', 
        #                     schema = "gsk_gsc_cfu_count", 
        #                     model_name = "whisper_large_v3",
        #                     version= 2,
        #                     local_base_path = "/local_disk0/models/",
        #                     overwrite = False)
        # flatten_folder(WHISPER_MODEL_PATH)
        WHISPER_MODEL_PATH = "/Volumes/marcell/call_centre_processing/data/model_artifacts/whisper-medium/"
        self.transcription_pipeline = LLM(
            model=str(WHISPER_MODEL_PATH),
            max_model_len=448,
            max_num_seqs=400,
            kv_cache_dtype="fp8",
            gpu_memory_utilization=40 / 80,
        )
        # Built once per actor rather than once per batch.
        # NOTE(review): max_tokens (500) is larger than max_model_len (448) — confirm intended.
        self.sampling_params = SamplingParams(
            temperature=0,
            top_p=1.0,
            max_tokens=500,
        )
        print("Model loaded...")

    def transform(self, row):
        """Build one Whisper prompt per (array, frame_rate) pair in the batch.

        Args:
            row: Batch dict with parallel ``array`` and ``frame_rate`` columns.

        Returns:
            list[dict]: Prompt dicts carrying the audio as ``multi_modal_data``.
        """
        return [
            {
                "prompt": "<|startoftranscript|>",
                "multi_modal_data": {"audio": (array, frame_rate)},
            }
            for array, frame_rate in zip(list(row['array']), list(row['frame_rate']))
        ]

    def __call__(self, row) -> dict:
        """Transcribe every audio clip in the batch and return the updated batch."""
        prompts = self.transform(row)
        outputs = self.transcription_pipeline.generate(prompts, self.sampling_params)

        # The raw samples are no longer needed once transcribed; drop them from the batch.
        del row['array']
        del row['frame_rate']

        row['transcription'] = [output.outputs[0].text for output in outputs]
        return row




class NERRedaction:
    """Row-wise transform that masks organisation, person and location names.

    Called with a row dict containing ``transcription``; returns the same dict
    with a ``redacted_text`` key added (``None`` when redaction failed).
    """

    # Tags whose character spans are masked. Both the B- (begin) and I-
    # (inside) variants are listed so the first token of an entity is not
    # left unredacted.
    _REDACTED_TAGS = frozenset(
        {'B-ORG', 'I-ORG', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC'}
    )

    def __init__(self):
        # NOTE(review): this context is stored but not used by any code in this class.
        self.unverified_context = ssl._create_unverified_context()
        print("Trying to load NER model...")
        self.ner_pipeline = pipeline("ner", device="cuda:0")
        print("NER model loaded.")

    def redact(self, text, pipeline):
        """Replace every redacted entity span in ``text`` with asterisks.

        Args:
            text: The text to redact.
            pipeline: Callable returning a list of entity dicts with
                ``entity``, ``start`` and ``end`` keys for ``text``.

        Returns:
            dict: ``{'redacted_text': <masked text>}``, or
            ``{'redacted_text': None}`` if the pipeline or masking raised.
        """
        try:
            entities = pipeline(text)
            for entity in entities:
                if entity['entity'] in self._REDACTED_TAGS:
                    start, end = entity['start'], entity['end']
                    # Same-length replacement keeps later offsets valid.
                    text = text[:start] + '*' * (end - start) + text[end:]
        except Exception as e:
            # Best effort: report the failure and emit a null redaction for this row.
            print(e)
            return {'redacted_text': None}

        # Returned outside any `finally` so that KeyboardInterrupt/SystemExit
        # propagate instead of being replaced by the return.
        return {'redacted_text': text}

    def __call__(self, row: dict) -> dict:
        """Redact ``row["transcription"]`` and store the result in ``row["redacted_text"]``."""
        text = row["transcription"]
        redacted_text = self.redact(text, self.ner_pipeline)

        row["redacted_text"] = redacted_text["redacted_text"]

        return row

In [0]:
class TextClassification:
    """Batch-wise transform that assigns a topic category to each redacted text.

    Called with a batch dict holding a ``redacted_text`` column; returns the
    same dict with a ``classification`` column added. Entries whose text is
    ``None`` get a ``None`` classification.
    """

    def __init__(self):
        # NOTE(review): this context is stored but not used by any code in this class.
        self.unverified_context = ssl._create_unverified_context()
        print("Trying to load classification model...")
        MODEL_SAVE_PATH = "/Volumes/marcell/call_centre_processing/data/model_artifacts/phi-4/"
        self.cls_pipeline = LLM(model=MODEL_SAVE_PATH,
                                enforce_eager=True,
                                gpu_memory_utilization=40 / 80)
        self.sampling_params = SamplingParams(temperature=0.5, max_tokens=128)
        print("Classification model loaded.")

    def create_prompt(self, redacted_text) -> list:
        """Build one system+user chat conversation per text.

        Args:
            redacted_text: Iterable of text passages.

        Returns:
            list: One two-message conversation (system, user) per passage.
        """
        prompt = [[
            {"role": "system", "content": "You are an expert at determining the underlying category of a short text passage. Your input is a short text passage and your output is a category. Do not output anything else but one of the following categories that best fits the text passage: 'Politics', 'Sports', 'Entertainment', 'Technology', 'Personal', 'Other'."},
            {"role": "user", "content": text},
        ] for text in redacted_text]

        return prompt

    def __call__(self, row: dict) -> dict:
        """Classify every non-null text of the batch and return the updated batch."""
        texts = list(row["redacted_text"])
        # A None text cannot be sent as message content, so only the non-null
        # entries are classified and the rest keep a None label.
        valid_positions = [i for i, text in enumerate(texts) if text is not None]
        labels = [None] * len(texts)

        if valid_positions:
            conversation = self.create_prompt([texts[i] for i in valid_positions])
            outputs = self.cls_pipeline.chat(conversation,
                                             sampling_params=self.sampling_params,
                                             use_tqdm=False)
            for position, output in zip(valid_positions, outputs):
                labels[position] = output.outputs[0].text

        row['classification'] = labels
        return row

In [0]:
# Reference table of recordings; its `file_path` column is what the pipeline classes read.
df_file_reference = spark.table("marcell.call_centre_processing.recording_file_reference")
# df_file_reference = spark.table("marcell.call_centre_processing.sample_data")
# Number of rows in the reference table (displayed as the cell output).
df_file_reference.count()

## Option 1: Run the pipeline inside a pandas UDF

In [0]:
@F.pandas_udf(T.StringType())
def transcribe_udf(filepaths: pd.Series) -> pd.Series:
    """Transcribe a batch of WAV file paths by running a Ray Data pipeline.

    Args:
        filepaths: Paths of the WAV files to transcribe.

    Returns:
        pd.Series: One transcription per input path, aligned with
        ``filepaths`` (same length, same order).
    """
    import ray
    import ray.data

    if filepaths.empty:
        # Nothing to transcribe; also avoids asking Ray to repartition into 0 blocks.
        return pd.Series([], index=filepaths.index, dtype="object")

    ray.init(ray_conf[1], ignore_reinit_error=True)

    @ray.remote
    def ray_data_task(paths):
        """Decode and transcribe ``paths``; returns file_path/transcription pairs."""
        ds = ray.data.from_pandas(pd.DataFrame(paths, columns=['file_path']))
        print("Length of filepaths ", len(paths))
        preds = (
            ds.repartition(len(paths))
            .map(
                ConverttoPrompt,
                compute=ray.data.ActorPoolStrategy(min_size=1, max_size=40),
                num_cpus=1,
            )
            .map_batches(
                WhisperTranscription,
                compute=ray.data.ActorPoolStrategy(min_size=1, max_size=40),
                num_gpus=float(40 / 80),
                batch_size=256,
            )
        )

        return preds.to_pandas()[["file_path", "transcription"]]

    transcribed = ray.get(ray_data_task.remote(filepaths.to_list()))

    # The pipeline output is not relied on to be in input order, so results are
    # joined back to the inputs by file path rather than by position.
    by_path = (
        transcribed.drop_duplicates("file_path")
        .set_index("file_path")["transcription"]
    )
    return filepaths.map(by_path)

In [0]:
# df_transcriptions = df_file_reference.repartition(1).withColumn("transcription_medium", transcribe_udf(F.col("file_path")))
# df_transcriptions.write.mode("overwrite").saveAsTable("marcell.call_centre_processing.whisper_medium_transcriptions")

## Option 2: Run the pipeline natively with Ray Data

In [0]:
import os
# Temporary directory on a Unity Catalog volume.
# NOTE(review): presumably used by Ray when exchanging data with Spark / Databricks tables — confirm.
os.environ["RAY_UC_VOLUMES_FUSE_TEMP_DIR"] = "/Volumes/marcell/call_centre_processing/data/tempDoc"

In [0]:
# Preview the reference table before it is handed to Ray Data.
df_file_reference.display()

In [0]:
# Four-stage pipeline over the reference table:
# decode audio -> transcribe (Whisper) -> redact entities (NER) -> classify topic.
ds = (
    ray.data.from_spark(df_file_reference)
    .repartition(200)
    .map(
        ConverttoPrompt,
        compute=ray.data.ActorPoolStrategy(min_size=10, max_size=100),
        num_cpus=1,
    )
    .map_batches(
        WhisperTranscription,
        compute=ray.data.ActorPoolStrategy(min_size=3, max_size=6),
        num_gpus=40 / 80,
        batch_size=256,
    )
    .map(
        NERRedaction,
        compute=ray.data.ActorPoolStrategy(min_size=1, max_size=50),
        num_gpus=1 / 15,
    )
    .map_batches(
        TextClassification,
        compute=ray.data.ActorPoolStrategy(min_size=1, max_size=6),
        num_gpus=40 / 80,
        batch_size=256,
    )
)

In [0]:
# Think of tuning this as a binary search:
# adjust the parallelism and narrow in on the right number of workers by trial.

In [0]:
# TODO: explain how to run the same workload on lower-tier hardware,
# and give a detailed, verbose explanation of the approach.

In [0]:
# test the logic DE and get them 

In [0]:

# Write the pipeline output to a Databricks table, overwriting any existing table of that name.
ds.write_databricks_table("marcell.call_centre_processing.whisper_medium_transcriptions_vllm_2", mode = 'overwrite', mergeSchema = True)

In [0]:
# Read the written table back through Spark to inspect the results.
spark.table("marcell.call_centre_processing.whisper_medium_transcriptions_vllm_2").display()