# Multi-model inference with Ray on Databricks

This notebook demonstrates how to perform multi-model inference using Ray on Databricks. We will:

1. Set up the environment and install the necessary packages.
2. Download and prepare the LJSpeech dataset.
3. Download the models from Hugging Face and save them to a Unity Catalog volume.
4. Define classes for each inference step: audio conversion, transcription, named entity redaction and text classification.
5. Run inference with Ray Data native commands and save the results to a Delta table.

In [None]:
# vLLM 0.7.0 or later supports Whisper; numba is pinned explicitly to avoid dependency conflicts
%pip install vllm==0.7.0 pydub numba==0.61.0 databricks-sdk
%pip install ray --upgrade
%restart_python

## Set catalog and schema

We set the catalog and schema to organise our data and ensure it is stored in the correct location. Change these to suit your workspace.

In [None]:
CATALOG = "marcell"
SCHEMA = "call_centre_processing"

Create the catalog, schema and volume if they don't exist. Also create directories for the compressed archive, the raw audio files and the models.

In [None]:
spark.sql(f"CREATE CATALOG IF NOT EXISTS {CATALOG}")
spark.sql(f"CREATE SCHEMA IF NOT EXISTS {CATALOG}.{SCHEMA}")
spark.sql(f"CREATE VOLUME IF NOT EXISTS {CATALOG}.{SCHEMA}.data")
dbutils.fs.mkdirs(f"/Volumes/{CATALOG}/{SCHEMA}/data/compressed/LJSpeech")
dbutils.fs.mkdirs(f"/Volumes/{CATALOG}/{SCHEMA}/data/raw_audio/LJSpeech")
dbutils.fs.mkdirs(f"/Volumes/{CATALOG}/{SCHEMA}/data/models")

## Download raw audio files

We download the [LJSpeech dataset](https://paperswithcode.com/dataset/ljspeech) and extract it to the raw audio directory. The dataset contains 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books. It is distributed as a `tar.bz2` archive, so we first download the archive and then extract it.

In [None]:
# Download the LJSpeech dataset

import urllib.request

url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
target_file_path = f"/Volumes/{CATALOG}/{SCHEMA}/data/compressed/LJSpeech/LJSpeech-1.1.tar.bz2"
urllib.request.urlretrieve(url, target_file_path)

Extracting the archive can take a long time (over an hour).

In [0]:
# Unzip the LJSpeech dataset

import tarfile

extract_to_path = f"/Volumes/{CATALOG}/{SCHEMA}/data/raw_audio/LJSpeech"
with tarfile.open(target_file_path, 'r:bz2') as tar_ref:
    tar_ref.extractall(extract_to_path)

## Create reference dataframe

We create a reference dataframe that contains the file paths of the raw audio files. We will use this dataframe to parallelize the inference process.

In [0]:
import pyspark.sql.functions as F

df_file_reference = spark.createDataFrame(dbutils.fs.ls(f"/Volumes/{CATALOG}/{SCHEMA}/data/raw_audio/LJSpeech/LJSpeech-1.1/wavs/"))\
  .withColumn("file_path", F.expr("substring(path, 6, length(path))")) # strip the leading "dbfs:" so the path starts with /Volumes/

df_file_reference.display()
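
The `substring(path, 6, length(path))` expression can be hard to read at a glance. Spark's `substring` is 1-indexed, so the result starts at the 6th character and drops the 5-character `dbfs:` scheme while keeping the leading `/`. The plain-Python sketch below does the same thing. `to_fuse_path` is a hypothetical helper that is not used elsewhere in the notebook. It assumes that `dbutils.fs.ls` on a volume returns paths beginning with `dbfs:`.

```python
DBFS_SCHEME = "dbfs:"


def to_fuse_path(path: str) -> str:
    """Return the local FUSE path for a dbfs: URI; leave other paths unchanged."""
    if path.startswith(DBFS_SCHEME):
        # Same result as Spark's substring(path, 6, length(path)) for dbfs: paths
        return path[len(DBFS_SCHEME):]
    return path


print(to_fuse_path("dbfs:/Volumes/marcell/call_centre_processing/data/raw_audio/LJSpeech/LJSpeech-1.1/wavs/LJ001-0001.wav"))
# /Volumes/marcell/call_centre_processing/data/raw_audio/LJSpeech/LJSpeech-1.1/wavs/LJ001-0001.wav
```

Unlike the Spark expression, this helper leaves paths without the `dbfs:` prefix untouched rather than cutting off their first five characters.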

Write the dataframe to a Delta table.

In [0]:
df_file_reference.write.mode("overwrite").option("overwriteSchema", "true").saveAsTable(f"{CATALOG}.{SCHEMA}.recording_file_reference")

## Download models from Hugging Face

We download two models from Hugging Face and save them to the volume. These models are large, so downloading them once and loading them from storage is more efficient than fetching them from the Hugging Face Hub every time an inference worker starts:
- [Whisper-medium](https://huggingface.co/openai/whisper-medium)
- [Phi-4](https://huggingface.co/microsoft/phi-4)

In [None]:
from transformers import pipeline
import torch

Whisper-medium is an automatic speech recognition (ASR) model developed by OpenAI. It transcribes spoken language into written text.

In [None]:
WHISPER_MODEL_SAVE_PATH = f"/Volumes/{CATALOG}/{SCHEMA}/data/models/whisper-medium/"

dbutils.fs.mkdirs(WHISPER_MODEL_SAVE_PATH)

whisper_pipeline = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-medium",
    torch_dtype=torch.float16,
    device="cuda:0"
)

whisper_pipeline.save_pretrained(WHISPER_MODEL_SAVE_PATH)

Phi-4 is a 14-billion-parameter language model developed by Microsoft for text generation and general natural language processing tasks. We use it for simple zero-shot classification.

In [None]:
PHI_MODEL_SAVE_PATH = f"/Volumes/{CATALOG}/{SCHEMA}/data/models/phi-4/"

dbutils.fs.mkdirs(PHI_MODEL_SAVE_PATH)

phi_pipeline = pipeline(
    "text-generation",
    model="microsoft/phi-4",
    model_kwargs={"torch_dtype": "auto"},
    device_map="auto",
)

phi_pipeline.save_pretrained(PHI_MODEL_SAVE_PATH)

## Run inference

We run inference on each audio recording and save the results to a Delta table. Our multi-step inference process will:
- Convert audio files to a format suitable for model inference using `ConverttoPrompt`.
- Transcribe audio recordings into text using the Whisper model with `WhisperTranscription`.
- Redact named entities from the transcriptions using `NERRedaction`.
- Classify the redacted text into predefined categories using `TextClassification`.
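
Before looking at the real classes, it helps to see how the four steps pass columns to each other. The sketch below is a toy, single-process emulation with no Ray and no models. `map_rows`, `map_batches` and the stub stages are hypothetical stand-ins. Row-wise stages receive one dict per record. Batch-wise stages receive one dict of columns, mirroring how `ds.map` and `ds.map_batches` call the classes in this notebook. Ray passes NumPy arrays per column by default; plain lists are used here for simplicity.

```python
from typing import Callable, Dict, List

Row = Dict[str, object]
Batch = Dict[str, list]


def to_batch(rows: List[Row]) -> Batch:
    """Turn a list of row dicts into one dict of column lists."""
    return {key: [row[key] for row in rows] for key in rows[0]}


def to_rows(batch: Batch) -> List[Row]:
    """Turn a dict of column lists back into a list of row dicts."""
    keys = list(batch)
    return [dict(zip(keys, values)) for values in zip(*batch.values())]


def map_rows(rows: List[Row], fn: Callable[[Row], Row]) -> List[Row]:
    return [fn(dict(row)) for row in rows]


def map_batches(rows: List[Row], fn: Callable[[Batch], Batch], batch_size: int) -> List[Row]:
    out: List[Row] = []
    for start in range(0, len(rows), batch_size):
        out.extend(to_rows(fn(to_batch(rows[start:start + batch_size]))))
    return out


# Stub stages: they only mimic which columns each real step reads, adds and drops
def convert(row: Row) -> Row:  # like ConverttoPrompt (row-wise)
    row["array"], row["frame_rate"] = [0.0, 0.5, -0.5], 22050
    return row


def transcribe(batch: Batch) -> Batch:  # like WhisperTranscription (batch-wise)
    n = len(batch["array"])
    del batch["array"], batch["frame_rate"]  # raw audio is dropped after transcription
    batch["transcription"] = ["stub transcription"] * n
    return batch


def redact(row: Row) -> Row:  # like NERRedaction (row-wise)
    row["redacted_text"] = row["transcription"]
    return row


def classify(batch: Batch) -> Batch:  # like TextClassification (batch-wise)
    batch["classification"] = ["Other"] * len(batch["redacted_text"])
    return batch


rows = [{"file_path": f"/Volumes/demo/clip_{i}.wav"} for i in range(5)]
rows = map_rows(rows, convert)
rows = map_batches(rows, transcribe, batch_size=2)
rows = map_rows(rows, redact)
rows = map_batches(rows, classify, batch_size=2)
print(len(rows), sorted(rows[0]))
# 5 ['classification', 'file_path', 'redacted_text', 'transcription']
```

The final rows keep `file_path` and gain `transcription`, `redacted_text` and `classification`. The bulky audio columns disappear after the transcription step.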

In [None]:
import ray
import os
from ray.util.spark import setup_ray_cluster

import ssl
import time

import pyspark.sql.types as T
import pandas as pd

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
import librosa
import pydub
import numpy as np

### Set up Ray on Databricks

To set up Ray on Databricks, we configure the Ray cluster and specify how many CPU cores and GPUs Ray may use on the head node and on each worker node. Ray can then distribute the inference work across the whole cluster.

For more details, refer to the [Ray on Databricks documentation](https://docs.ray.io/en/latest/ray-on-databricks.html) and [What is Ray on Databricks?](https://docs.databricks.com/aws/en/machine-learning/ray) page.

In [None]:
num_cpu_cores_per_worker = 20  # number of CPU cores to allocate to Ray on each worker node
num_cpus_head_node = 10  # number of CPU cores to allocate to Ray on the head node
num_gpu_per_worker = 1  # number of GPUs to allocate to Ray on each worker node
num_gpus_head_node = 1  # number of GPUs to allocate to Ray on the head node
min_worker_nodes = 2  # autoscaling minimum number of worker nodes
max_worker_nodes = 2  # autoscaling maximum number of worker nodes (equal to the minimum, so the size is fixed)

ray_conf = setup_ray_cluster(
    min_worker_nodes=min_worker_nodes,
    max_worker_nodes=max_worker_nodes,
    num_cpus_head_node=num_cpus_head_node,
    num_gpus_head_node=num_gpus_head_node,
    num_cpus_worker_node=num_cpu_cores_per_worker,  # replaces the deprecated num_cpus_per_node
    num_gpus_worker_node=num_gpu_per_worker,  # replaces the deprecated num_gpus_per_node
)
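
As a sanity check, the settings above determine how many logical resources the Ray cluster should expose. The arithmetic below assumes that all requested nodes start and that Ray counts the head node's share alongside the workers'. The variable `num_workers` is introduced here for illustration.

```python
# Values copied from the configuration cell above
num_cpu_cores_per_worker = 20
num_cpus_head_node = 10
num_gpu_per_worker = 1
num_gpus_head_node = 1
num_workers = 2  # min_worker_nodes == max_worker_nodes == 2, so the size is fixed

# Ray sees the head node's share plus each worker's share
total_cpus = num_cpus_head_node + num_workers * num_cpu_cores_per_worker
total_gpus = num_gpus_head_node + num_workers * num_gpu_per_worker
print(f"Ray cluster: {total_cpus} CPUs, {total_gpus} GPUs")  # Ray cluster: 50 CPUs, 3 GPUs
```

On the live cluster, `ray.cluster_resources()` returns the actual totals (for example, under the `"CPU"` and `"GPU"` keys) for comparison.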


### Define classes to handle each inference step

Convert each `.wav` file into a normalized NumPy array and frame rate that vLLM accepts as audio input.

In [None]:
class ConverttoPrompt:
    """
    This class handles the conversion of audio files to a format suitable for model inference.
    It reads audio files, converts them to numpy arrays, and normalizes the audio data.
    """

    def __init__(self):
        pass

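    def transform(self, audio_filename):
        audio = pydub.AudioSegment.from_wav(audio_filename)
        # Whisper expects single-channel audio, so downmix stereo recordings to mono
        if audio.channels > 1:
            audio = audio.set_channels(1)
        samples = np.array(audio.get_array_of_samples())

        # Scale integer PCM samples to floats in [-1.0, 1.0); the divisor is 2**15 for 16-bit audio
        array = np.float32(samples) / float(2 ** (8 * audio.sample_width - 1))
        frame_rate = audio.frame_rate
        return array, frame_rate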

    def __call__(self, row) -> dict:
        array, frame_rate = self.transform(row["file_path"])
        row["array"] = list(array)
        row["frame_rate"] = frame_rate
        return row
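
To show what "normalized" means here without pydub, the following is a minimal standard-library sketch of the same conversion. It assumes 16-bit PCM input, which is the format the `2**15` divisor implies, and it downmixes multi-channel audio to mono by averaging each frame's channels. `wav_to_mono_floats` is a hypothetical helper, not part of the notebook's pipeline.

```python
import array
import io
import sys
import wave


def wav_to_mono_floats(data: bytes) -> tuple[list[float], int]:
    """Decode 16-bit PCM WAV bytes into mono floats in [-1.0, 1.0) and the frame rate."""
    with wave.open(io.BytesIO(data), "rb") as wav:
        if wav.getsampwidth() != 2:
            raise ValueError("this sketch only handles 16-bit PCM audio")
        channels = wav.getnchannels()
        frame_rate = wav.getframerate()
        samples = array.array("h", wav.readframes(wav.getnframes()))
    if sys.byteorder == "big":
        samples.byteswap()  # WAV sample data is little-endian
    # Samples are interleaved per frame; average each frame's channels to get mono
    mono = [
        sum(samples[i:i + channels]) / channels
        for i in range(0, len(samples), channels)
    ]
    # Divide by 2**15, the magnitude of the most negative 16-bit value
    return [s / 2**15 for s in mono], frame_rate


# Build a tiny two-frame stereo WAV in memory to exercise the function
buf = io.BytesIO()
with wave.open(buf, "wb") as out:
    out.setnchannels(2)
    out.setsampwidth(2)
    out.setframerate(22050)
    frames = array.array("h", [16384, -16384, 32767, 32767])
    if sys.byteorder == "big":
        frames.byteswap()
    out.writeframes(frames.tobytes())

floats, rate = wav_to_mono_floats(buf.getvalue())
print(floats, rate)  # [0.0, 0.999969482421875] 22050
```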


Run Whisper inference on the NumPy arrays with vLLM.

In [None]:
class WhisperTranscription:
    """
    This class handles the transcription of audio files using the Whisper model.
    It reads audio data, converts it to the required format, and performs transcription.

    Args:
        model_path (str): Path to the Whisper model.
        max_model_len (int, optional): Maximum length of the model. Defaults to 448.
        max_num_seqs (int, optional): Maximum number of sequences. Defaults to 400.
        kv_cache_dtype (str, optional): Data type for key-value cache. Defaults to "fp8".
        gpu_memory_utilization (float, optional): GPU memory utilization. Defaults to 0.5.
    """

    def __init__(
        self,
        model_path: str,
        max_model_len: int = 448,
        max_num_seqs: int = 400,
        kv_cache_dtype: str = "fp8",
        gpu_memory_utilization: float = 0.5,
    ):
        self.unverified_context = ssl._create_unverified_context()
        self.transcription_pipeline = LLM(
            model=str(model_path),
            max_model_len=max_model_len,
            max_num_seqs=max_num_seqs,
            kv_cache_dtype=kv_cache_dtype,
            gpu_memory_utilization=gpu_memory_utilization,
        )

    def transform(self, row):
        prompts = []
        for array, frame_rate in zip(list(row["array"]), list(row["frame_rate"])):
            prompts.append(
                {
                    "prompt": "<|startoftranscript|>",
                    "multi_modal_data": {"audio": (array, frame_rate)},
                }
            )
        return prompts

    def __call__(self, row: dict) -> dict:  # `row` is a batch: a dict mapping column names to arrays
        sampling_params = SamplingParams(
            temperature=0,
            top_p=1.0,
            max_tokens=500,
        )
        prompts = self.transform(row)
        outputs = self.transcription_pipeline.generate(prompts, sampling_params)

        del row["array"]
        del row["frame_rate"]

        row["transcription"] = [output.outputs[0].text for output in outputs]

        return row


Named entity redaction replaces every mention of an organization, person or location with asterisks of the same length.

_Note: this could be done in a more sophisticated manner; here we simply mask mentions of named entities._

In [None]:
class NERRedaction:
    """
    This class handles the redaction of named entities (NER) from transcriptions.
    It uses a pre-trained NER model to identify and redact sensitive information such as
    organization names, personal names, and locations.

    Methods:
        redact(text, pipeline): Redacts named entities from the given text using the specified pipeline.
        __call__(row): Applies the redaction process to the transcription in the given row.

    Attributes:
        ner_pipeline: Pre-trained NER model pipeline for entity recognition.
    """
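    def __init__(self):
        self.unverified_context = ssl._create_unverified_context()
        # With no model argument, transformers falls back to its default NER model,
        # which is fetched from the Hugging Face Hub rather than from the volume
        self.ner_pipeline = pipeline("ner", device="cuda:0")

    def redact(self, text, pipeline):
        try:
            entities = pipeline(text)
        except Exception:
            # Keep the row, but mark the redaction as failed
            return {'redacted_text': None}

        for entity in entities:
            # Mask organization, person and location tokens (both B- and I- tags)
            if entity['entity'] in ['B-ORG', 'I-ORG', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC']:
                start, end = entity['start'], entity['end']
                # A same-length mask keeps the character offsets of later entities valid
                text = text[:start] + '*' * (end - start) + text[end:]

        return {'redacted_text': text}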

    def __call__(self, row: dict) -> dict:
        text = row["transcription"]
        redacted_text = self.redact(text, self.ner_pipeline)

        row["redacted_text"] = redacted_text["redacted_text"]

        return row
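
The key detail in the redaction loop is that each entity span is replaced with exactly as many asterisks as it has characters. This means the `start`/`end` offsets reported for later entities still point at the right characters. The standalone sketch below isolates that idea. `mask_entities` is a hypothetical helper, and `fake_entities` is a hand-written stand-in that only mimics the `entity`, `start` and `end` keys of NER pipeline output. The sketch masks both `B-` and `I-` tags for the three entity types.

```python
REDACT_TYPES = {"ORG", "PER", "LOC"}


def mask_entities(text: str, entities: list) -> str:
    """Replace each ORG/PER/LOC span with asterisks of equal length."""
    chars = list(text)
    for ent in entities:
        # "I-PER" -> "PER"; other types such as MISC are left untouched
        if ent["entity"].split("-")[-1] in REDACT_TYPES:
            for i in range(ent["start"], ent["end"]):
                chars[i] = "*"
    return "".join(chars)


sample = "Alice met Bob in Paris."
fake_entities = [  # hand-written stand-in for NER pipeline output
    {"entity": "I-PER", "start": 0, "end": 5},
    {"entity": "I-PER", "start": 10, "end": 13},
    {"entity": "I-LOC", "start": 17, "end": 22},
]
print(mask_entities(sample, fake_entities))  # ***** met *** in *****.
```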


Finally, we define a class for zero-shot classification of the redacted transcriptions using a general-purpose LLM (Phi-4).

In [None]:
class TextClassification:
    """
    This class handles the classification of text into predefined categories.
    It uses a pre-trained language model to classify the text based on the content.

    Args:
        model_path (str): Path to the classification model.
        enforce_eager (bool, optional): Whether to enforce eager execution. Defaults to True.
        gpu_memory_utilization (float, optional): GPU memory utilization. Defaults to 0.5.
        temperature (float, optional): Sampling temperature for the model. Defaults to 0.5.
        max_tokens (int, optional): Maximum number of tokens for the model output. Defaults to 128.
    """
    def __init__(
        self,
        model_path: str,
        enforce_eager: bool = True,
        gpu_memory_utilization: float = 0.5,
        temperature: float = 0.5,
        max_tokens: int = 128,
    ):
        self.unverified_context = ssl._create_unverified_context()
        self.cls_pipeline = LLM(
            model=model_path,
            enforce_eager=enforce_eager,
            gpu_memory_utilization=gpu_memory_utilization
        )
        self.sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens)

    def create_prompt(self, redacted_text) -> list:
        # One chat conversation per text in the batch; failed redactions (None) become empty strings
        prompt = [
            [
                {"role": "system", "content": "You are an expert at determining the underlying category of a short text passage. Your input is a short text passage and your output is a category. Do not output anything else but one of the following categories that best fits the text passage: 'Politics', 'Sports', 'Entertainment', 'Technology', 'Personal', 'Other'."},
                {"role": "user", "content": text if text is not None else ""},
            ] for text in redacted_text
        ]
        return prompt

    def __call__(self, row: dict) -> dict:
        conversation = self.create_prompt(row["redacted_text"])
        outputs = self.cls_pipeline.chat(
            conversation,
            sampling_params=self.sampling_params,
            use_tqdm=False
        )
        # Strip surrounding whitespace from the generated label
        row['classification'] = [output.outputs[0].text.strip() for output in outputs]
        return row
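
Even with a strict system prompt, chat models sometimes wrap the label in quotes, add punctuation, or add extra words. A post-processing step could map each raw output onto the fixed label set. The sketch below is one hypothetical way to do that (`normalize_label` is not part of the notebook). It uses simple string matching and falls back to `'Other'` when no category name appears.

```python
CATEGORIES = ("Politics", "Sports", "Entertainment", "Technology", "Personal", "Other")


def normalize_label(raw: str, categories=CATEGORIES, default: str = "Other") -> str:
    """Map free-form model output onto one of the allowed category names."""
    cleaned = raw.strip().strip("'\".").strip().lower()
    # Prefer an exact match once quotes, periods and whitespace are removed
    for category in categories:
        if cleaned == category.lower():
            return category
    # Otherwise take the first category (in list order) whose name appears in the output
    for category in categories:
        if category.lower() in cleaned:
            return category
    return default


for raw in [" 'Sports'. ", "Category: Technology", "I am not sure"]:
    print(repr(raw), "->", normalize_label(raw))
# " 'Sports'. " -> Sports, 'Category: Technology' -> Technology, 'I am not sure' -> Other
```

A helper like this could be applied to each generated text in `__call__`, or as an extra row-wise `ds.map` step after classification.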

### Run inference via Ray Data native commands

To run inference, we use Ray Data's native commands to distribute the workload across the nodes of the Ray cluster. First, we create a temporary directory on the volume for Ray to use when writing Delta tables.

In [None]:
# Define and create a temporary directory for Ray to use when writing Delta tables
temp_dir = f"/Volumes/{CATALOG}/{SCHEMA}/data/tmp"
dbutils.fs.mkdirs(temp_dir)
os.environ["RAY_UC_VOLUMES_FUSE_TEMP_DIR"] = temp_dir

Read the file reference table from Unity Catalog.

In [None]:
df_file_reference = spark.table(f"{CATALOG}.{SCHEMA}.recording_file_reference")

### Ray-native Inference Pipeline

The pipeline cell below chains the four inference steps with Ray Data and distributes them across the (GPU) nodes. The key arguments are:


1. **map and map_batches**:
   - `map`: applies a function to each row of the dataset individually. Here, `ConverttoPrompt` and `NERRedaction` are applied row by row.
   - `map_batches`: applies a function to batches of rows. Here, `WhisperTranscription` and `TextClassification` process batches of up to 256 rows, so each vLLM `generate` or `chat` call receives many prompts at once and vLLM can batch them on the GPU.

2. **num_cpus and num_gpus**:
   - `num_cpus`: the number of (logical) CPUs reserved for each actor. For example, `num_cpus=1` reserves one CPU for each `ConverttoPrompt` actor.
   - `num_gpus`: the number of (logical) GPUs reserved for each actor. For example, `num_gpus=float(40 / 80)` reserves half a GPU for each `WhisperTranscription` actor, so two such actors can share one GPU. These are scheduling reservations only; each vLLM engine's GPU memory use is limited by its `gpu_memory_utilization` setting.

3. **min_size and max_size in ray.data.ActorPoolStrategy**:
   - `min_size`: the minimum number of actors (workers) kept in the pool. For example, `min_size=10` ensures that at least 10 actors are available for `ConverttoPrompt`.
   - `max_size`: the maximum number of actors the pool can scale up to. For example, `max_size=100` allows up to 100 actors for `ConverttoPrompt`.
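
A quick back-of-the-envelope check shows how the fractional `num_gpus` values interact with the pool sizes. The check assumes the 3-GPU cluster configured earlier (one GPU on the head node plus one on each of two workers). The per-stage numbers are copied from the pipeline cell. Exact fractions are used because values such as `1 / 15` are not exact as floats.

```python
from fractions import Fraction

TOTAL_GPUS = 3  # assumption: 1 GPU on the head node + 2 workers x 1 GPU each

# stage -> (GPUs reserved per actor, min_size, max_size), copied from the pipeline cell
stages = {
    "WhisperTranscription": (Fraction(40, 80), 3, 6),
    "NERRedaction": (Fraction(1, 15), 1, 50),
    "TextClassification": (Fraction(40, 80), 1, 6),
}

min_gpus = sum(gpus * lo for gpus, lo, _ in stages.values())
max_gpus = sum(gpus * hi for gpus, _, hi in stages.values())

for name, (gpus, _, _) in stages.items():
    print(f"{name}: {1 / gpus} actors per GPU")
print(f"Reserved at min_size: {float(min_gpus):.2f} of {TOTAL_GPUS} GPUs")   # 2.07 of 3
print(f"Requested at max_size: {float(max_gpus):.2f} of {TOTAL_GPUS} GPUs")  # 9.33 of 3
```

The minimum pool sizes fit on three GPUs, but the maximums add up to roughly 9.3 GPUs. The pools therefore cannot all reach `max_size` at the same time; Ray can only start actors whose resource requests fit on the cluster.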

In [None]:
# Pass the classes (not instances) so that each Ray actor loads its own copy of the model
# instead of loading it on the driver; constructor arguments go in fn_constructor_kwargs
ds = ray.data.from_spark(df_file_reference)

ds = ds.repartition(200) \
    .map(
        ConverttoPrompt,
        compute=ray.data.ActorPoolStrategy(min_size=10, max_size=100),
        num_cpus=1,
    ) \
    .map_batches(
        WhisperTranscription,
        fn_constructor_kwargs={"model_path": WHISPER_MODEL_SAVE_PATH},
        compute=ray.data.ActorPoolStrategy(min_size=3, max_size=6),
        num_gpus=float(40 / 80),
        batch_size=256
    ) \
    .map(
        NERRedaction,
        compute=ray.data.ActorPoolStrategy(min_size=1, max_size=50),
        num_gpus=float(1 / 15)
    ) \
    .map_batches(
        TextClassification,
        fn_constructor_kwargs={"model_path": PHI_MODEL_SAVE_PATH},
        compute=ray.data.ActorPoolStrategy(min_size=1, max_size=6),
        num_gpus=float(40 / 80),
        batch_size=256
    )

We save the pipeline output (transcriptions, redacted text and classifications) to a Delta table in Databricks, where it can be queried and analyzed like any other table. Ray Data executes lazily, so this write is the step that actually runs the pipeline. For more details on the `write_databricks_table` method, refer to the [Databricks documentation](https://docs.databricks.com/aws/en/machine-learning/ray/connect-spark-ray#write-ray-data-to-spark).

In [None]:
ds.write_databricks_table(f"{CATALOG}.{SCHEMA}.processed_audio", mode='overwrite', mergeSchema=True)