In [0]:
# Install the inference dependencies for this notebook.
# vLLM is pinned to 0.7.0 because Whisper support requires >=0.7.0; numba is pinned
# manually to work around dependency conflicts. Ray is upgraded without a version pin.
%pip install vllm==0.7.0 pydub numba==0.61.0 databricks-sdk 
%pip install ray --upgrade
%restart_python

# Job parameters

In [0]:
# Collect all notebook widget values into a single dict and print them so the
# parameters used for this run are visible in the cell output.
params = dbutils.widgets.getAll()
print(params)

## Set catalog, schema and model paths

We set the catalog and schema to organise our data and ensure it is stored in the correct location. These values, together with the model IDs, are read from the job parameters above; set them to suit your workspace.

In [0]:
def _model_volume_path(catalog_name: str, schema_name: str, model_id: str) -> str:
    """Build the `/Volumes/...` directory path for a model id.

    Dashes and slashes in `model_id` are replaced with underscores so the id
    collapses into a single directory name.
    """
    directory_name = model_id.replace("-", "_").replace("/", "_")
    return f"/Volumes/{catalog_name}/{schema_name}/data/models/{directory_name}"


# Locations come from the job parameters read above.
catalog = params["catalog"]
schema = params["schema"]

# Speech-to-text model.
transcription_model_id = params["transcription_model_id"]
transcription_model_save_path = _model_volume_path(catalog, schema, transcription_model_id)

# General-purpose LLM used for classification.
llm_model_id = params["llm_model_id"]
llm_model_save_path = _model_volume_path(catalog, schema, llm_model_id)

## Run inference

We run inference on each audio recording and save the results to a Delta table. Our multi-step inference process will:
- Convert audio files to a format suitable for model inference using `ConverttoPrompt`.
- Transcribe audio recordings into text using the Whisper model with `TranscriptionStep`.
- Redact named entities from the transcriptions using `RedactionStep`.
- Classify the redacted text into predefined categories using `ClassificationStep`.

In [0]:
import ray
import os
from ray.util.spark import setup_ray_cluster

import ssl
import time

import pyspark.sql.types as T
import pandas as pd

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from transformers import pipeline
import librosa
import pydub
import numpy as np

### Set up Ray on Databricks

To set up Ray on Databricks, we need to configure the Ray cluster and allocate resources such as CPU and GPU cores. This setup allows us to leverage Ray's distributed computing capabilities for efficient and scalable inference.

For more details, refer to the [Ray on Databricks documentation](https://docs.ray.io/en/latest/ray-on-databricks.html) and [What is Ray on Databricks?](https://docs.databricks.com/aws/en/machine-learning/ray) page.

In [0]:
# Resource configuration for the Ray cluster. The autoscaling minimum and maximum
# below are both 2, so the number of worker nodes is fixed at 2.
num_cpu_cores_per_worker = 20 # number of cores to allocate to Ray per worker
num_cpus_head_node = 10 # number of cores to allocate to Ray on the head node
num_gpu_per_worker = 1 # number of GPUs to allocate to Ray per worker
num_gpus_head_node = 1 # number of GPUs to allocate to Ray on the head node
min_worker_nodes = 2 # autoscaling minimum number of workers
max_worker_nodes = 2 # autoscaling maximum number of workers

# Start the Ray cluster with the resources defined above; the value returned by
# setup_ray_cluster is kept in `ray_conf`.
ray_conf = setup_ray_cluster(
  min_worker_nodes=min_worker_nodes,
  max_worker_nodes=max_worker_nodes,
  num_cpus_head_node= num_cpus_head_node,
  num_gpus_head_node= num_gpus_head_node,
  num_cpus_per_node=num_cpu_cores_per_worker,
  num_gpus_per_node=num_gpu_per_worker
  )


### Define classes to handle each inference step

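Ray Data's `map` and `map_batches` accept either a function or a callable class. We use a callable class for each step so that any expensive setup (such as loading a model) happens once in `__init__` and is reused for every row or batch handled by that actor.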

Convert each .wav file into a normalized NumPy array that can be passed to vLLM as part of a prompt.

In [0]:
class ConverttoPrompt:
    """
    Convert a WAV recording into a normalized, mono float32 waveform for model inference.

    Used as a per-row Ray Data `map` callable: it reads the file referenced by
    `row["file_path"]` and adds the decoded samples and sampling rate to the row.
    """

    def transform(self, audio_filename):
        """
        Decode `audio_filename` and return `(waveform, frame_rate)`.

        Args:
            audio_filename: Path of the WAV file to load.

        Returns:
            A tuple of a 1-D float32 NumPy array and the file's frame rate.
            Samples are divided by the full-scale value of the file's sample
            width, so 16-bit audio is scaled by 2**15.
        """
        audio = pydub.AudioSegment.from_wav(audio_filename)
        samples = np.asarray(audio.get_array_of_samples(), dtype=np.float32)

        # Normalize by the full-scale value for this file's sample width instead of
        # assuming 16-bit audio (sample_width is in bytes: 2 -> 2**15, 4 -> 2**31).
        full_scale = np.float32(1 << (8 * audio.sample_width - 1))
        waveform = samples / full_scale

        if audio.channels > 1:
            # Multi-channel samples are interleaved frame by frame; average the
            # channels so downstream steps always receive a 1-D mono waveform.
            waveform = waveform.reshape((-1, audio.channels)).mean(axis=1)

        return waveform, audio.frame_rate

    def __call__(self, row) -> dict:
        """
        Add `array` (list of samples) and `frame_rate` to `row` and return it.

        Args:
            row: Mapping with a `file_path` entry pointing at a WAV file.

        Returns:
            The same row, extended with the `array` and `frame_rate` entries.
        """
        waveform, frame_rate = self.transform(row["file_path"])
        row["array"] = list(waveform)
        row["frame_rate"] = frame_rate
        return row


Run Whisper inference on the NumPy arrays with vLLM.

In [0]:
class TranscriptionStep:
    """
    Transcribe batches of audio with the Whisper model served through vLLM.

    Used as a Ray Data `map_batches` callable: the model is loaded once in
    `__init__` and reused for every batch handled by the actor.
    """

    def __init__(self):
        # NOTE(review): this context is stored but not referenced anywhere else in
        # this class; kept so the instance attributes stay unchanged.
        self.unverified_context = ssl._create_unverified_context()
        self.transcription_pipeline = LLM(
            model=transcription_model_save_path,
            max_model_len=448,
            max_num_seqs=400,
            kv_cache_dtype="fp8",
            # 40/80 evaluates to 0.5, i.e. half of the GPU memory for this engine.
            gpu_memory_utilization=float(40/80),
        )
        # Sampling settings never change between batches, so build them once here
        # instead of on every call.
        self.sampling_params = SamplingParams(
            temperature=0,
            top_p=1.0,
            max_tokens=500,
        )

    def transform(self, row):
        """
        Build one vLLM prompt per recording in the batch.

        Args:
            row: Batch mapping whose `array` and `frame_rate` entries are
                parallel sequences (one waveform and one rate per recording).

        Returns:
            A list of prompt dicts, each pairing the Whisper start token with
            the `(waveform, frame_rate)` audio input.
        """
        return [
            {
                "prompt": "<|startoftranscript|>",
                "multi_modal_data": {"audio": (array, frame_rate)},
            }
            for array, frame_rate in zip(row["array"], row["frame_rate"])
        ]

    def __call__(self, row) -> dict:
        """
        Transcribe every recording in the batch.

        Args:
            row: Batch mapping containing `array` and `frame_rate` entries.

        Returns:
            The same batch with `array` and `frame_rate` removed and a
            `transcription` list added (one text per recording).
        """
        prompts = self.transform(row)
        outputs = self.transcription_pipeline.generate(prompts, self.sampling_params)

        # The raw audio is no longer needed downstream; drop it from the batch.
        del row["array"]
        del row["frame_rate"]

        row["transcription"] = [output.outputs[0].text for output in outputs]

        return row


Named Entity redaction replaces any mention of organization, person or location name with asterisks.

_Note: this could be done in a more sophisticated manner, here we are just masking mentions of named entities for simplicity._

In [0]:
class RedactionStep:
    """
    Mask named entities in a transcription.

    A token-classification (NER) pipeline locates organization, person and
    location mentions, and each matched span is overwritten with asterisks.
    """

    # Entity types to mask, independent of the B-/I- tagging prefix.
    REDACTED_ENTITY_TYPES = frozenset({"ORG", "PER", "LOC"})

    def __init__(self):
        # NOTE(review): this context is stored but not referenced anywhere else in
        # this class; kept so the instance attributes stay unchanged.
        self.unverified_context = ssl._create_unverified_context()
        self.ner_pipeline = pipeline("ner", device="cuda:0")

    def redact(self, text, pipeline):
        """
        Replace organization, person and location mentions in `text` with `*`.

        Args:
            text: The transcription to redact.
            pipeline: Callable returning a list of entity dicts with `entity`,
                `start` and `end` keys for the given text.

        Returns:
            `{"redacted_text": <masked text>}`, or `{"redacted_text": None}` when
            entity detection or masking fails.
        """
        import logging

        try:
            for entity in pipeline(text):
                # Match on the entity type so both B- and I- tagged tokens are masked.
                entity_type = entity['entity'].rsplit('-', 1)[-1]
                if entity_type in self.REDACTED_ENTITY_TYPES:
                    start, end = entity['start'], entity['end']
                    # Same-length replacement keeps later character offsets valid.
                    text = text[:start] + '*' * (end - start) + text[end:]
        except Exception:
            # Best effort: keep the row and record the failure. The text itself is
            # deliberately not logged because it may contain unredacted entities.
            logging.getLogger(__name__).warning(
                "Named-entity redaction failed; storing None", exc_info=True
            )
            return {'redacted_text': None}

        return {'redacted_text': text}

    def __call__(self, row: dict) -> dict:
        """
        Redact `row["transcription"]` and store the result in `row["redacted_text"]`.

        Args:
            row: Mapping with a `transcription` entry.

        Returns:
            The same row with a `redacted_text` entry added (None on failure).
        """
        redacted = self.redact(row["transcription"], self.ner_pipeline)
        row["redacted_text"] = redacted["redacted_text"]
        return row


Finally, we define a class for zero-shot text classification of the transcribed and redacted audio using a general LLM.

In [0]:
class ClassificationStep:
    """
    Zero-shot classification of redacted transcriptions with a general LLM.

    Used as a Ray Data `map_batches` callable: the model is loaded once in
    `__init__` and every batch is sent to it as a list of chat conversations.
    """

    # Instruction shared by every conversation; defined once instead of per item.
    SYSTEM_PROMPT = "You are an expert at determining the underlying category of a short text passage. Your input is a short text passage and your output is a category. Do not output anything else but one of the following categories that best fits the text passage: 'Politics', 'Sports', 'Entertainment', 'Technology', 'Personal', 'Other'."

    def __init__(self):
        # NOTE(review): this context is stored but not referenced anywhere else in
        # this class; kept so the instance attributes stay unchanged.
        self.unverified_context = ssl._create_unverified_context()
        self.cls_pipeline = LLM(
            model=llm_model_save_path,
            enforce_eager=True,
            # 40/80 evaluates to 0.5, i.e. half of the GPU memory for this engine.
            gpu_memory_utilization=float(40/80)
        )
        self.sampling_params = SamplingParams(temperature=0.1, max_tokens=128)

    def create_prompt(self, redacted_text) -> list:
        """
        Build one chat conversation per text.

        Args:
            redacted_text: Iterable of redacted transcriptions.

        Returns:
            A list of conversations; each is a two-message list holding the
            system instruction followed by the text as the user message.
        """
        return [
            [
                {"role": "system", "content": self.SYSTEM_PROMPT},
                {"role": "user", "content": text},
            ]
            for text in redacted_text
        ]

    def __call__(self, row: dict) -> dict:
        """
        Classify every text in the batch.

        Args:
            row: Batch mapping with a `redacted_text` entry.

        Returns:
            The same batch with a `classification` list added (one model
            response per text).
        """
        conversation = self.create_prompt(row["redacted_text"])
        outputs = self.cls_pipeline.chat(
            conversation,
            sampling_params=self.sampling_params,
            use_tqdm=False
        )
        row['classification'] = [output.outputs[0].text for output in outputs]
        return row

### Run inference via Ray Data native commands

To run the inference, we will use Ray Data's native commands to parallelize and distribute the workload across multiple nodes. This approach ensures efficient processing of large datasets by leveraging Ray's distributed computing capabilities.

In [0]:
# Define and create a temporary directory for Ray to use when writing Delta tables.
# The directory lives under the same catalog/schema volume as the rest of the data,
# and its path is passed to Ray through the RAY_UC_VOLUMES_FUSE_TEMP_DIR variable.
temp_dir = f"/Volumes/{catalog}/{schema}/data/tmp"
dbutils.fs.mkdirs(temp_dir)
os.environ["RAY_UC_VOLUMES_FUSE_TEMP_DIR"] = temp_dir

Read file reference dataframe from Unity Catalog.

In [0]:
# Spark DataFrame over the `recording_file_reference` table; the conversion step
# later reads each row's `file_path` value from it.
df_file_reference = spark.table(f"{catalog}.{schema}.recording_file_reference")

We run the inference pipeline using Ray Data's native commands to parallelize and distribute the workload across multiple (GPU) nodes.

This is our opportunity to make the pipeline run efficiently.

- **map vs. map_batches**:
   - Use `map` for operations that are lightweight and need to be applied to each element individually.
   - Use `map_batches` for operations that are more computationally intensive and can benefit from batch processing to reduce overhead.

- **num_cpus**:
   - Allocate more CPU cores for tasks that are CPU-bound and require significant processing power.
   - Ensure that the total number of CPU cores allocated does not exceed the available cores in your cluster.

- **num_gpus**:
   - Allocate GPUs for tasks that can leverage GPU acceleration, such as deep learning inference.
   - Ensure that the total number of GPUs allocated does not exceed the available GPUs in your cluster.

- **min_size and max_size in ray.data.ActorPoolStrategy**:
   - Set `min_size` to ensure a minimum number of actors are always available, which can help maintain a steady throughput.
   - Set `max_size` based on the maximum parallelism you want to achieve, considering the available resources and the nature of the task.

In [0]:
# Build the four-stage inference pipeline on top of the Spark file-reference table.
ds = (
    ray.data.from_spark(df_file_reference)
    .repartition(200)
    # Stage 1 (CPU, per row): decode each WAV file into a normalized array.
    .map(
        ConverttoPrompt,
        compute=ray.data.ActorPoolStrategy(min_size=10, max_size=100),
        num_cpus=1,
    )
    # Stage 2 (GPU, batched): transcription; each actor reserves half a GPU.
    .map_batches(
        TranscriptionStep,
        compute=ray.data.ActorPoolStrategy(min_size=3, max_size=6),
        num_gpus=0.5,
        batch_size=256,
    )
    # Stage 3 (GPU, per row): entity redaction; fifteen actors share one GPU.
    .map(
        RedactionStep,
        compute=ray.data.ActorPoolStrategy(min_size=1, max_size=45),
        num_gpus=1 / 15,
    )
    # Stage 4 (GPU, batched): classification; each actor reserves half a GPU.
    .map_batches(
        ClassificationStep,
        compute=ray.data.ActorPoolStrategy(min_size=1, max_size=6),
        num_gpus=0.5,
        batch_size=256,
    )
)

We save the processed audio data to a Delta table in Databricks. This allows us to store the results of our inference pipeline in a structured format that can be easily queried and analyzed. For more details on how to use the `write_databricks_table` method, refer to the [Databricks documentation](https://docs.databricks.com/aws/en/machine-learning/ray/connect-spark-ray#write-ray-data-to-spark).

In [0]:
# Persist the pipeline output to the `processed_audio` table, overwriting any
# existing contents; mergeSchema=True is passed through to the writer.
ds.write_databricks_table(f"{catalog}.{schema}.processed_audio", mode='overwrite', mergeSchema=True)