# Automatic-Speech-Recognition batch inference reference solution

Tested on:

* MLR 15.4 LTS GPU Runtime
* A collection of `.wav` files from the LJSpeech dataset
* A GPU cluster with 1 driver node and 1 worker node of type `g5.12xlarge` (4x A10G GPUs each)

**IMPORTANT:**

Set these `spark configs` on the cluster before starting it:

* `spark.databricks.pyspark.dataFrameChunk.enabled true`
* `spark.task.resource.gpu.amount 0`

In [0]:
# Install dependencies. vLLM is pinned to 0.7.0, the first release with Whisper support;
# numba is pinned explicitly to work around a dependency conflict.
%pip install -qU databricks-sdk numba==0.61.0 pydub ray vllm==0.7.0  


%restart_python

In [0]:
import warnings


# Silence every Python warning for the rest of the notebook (presumably to keep cell
# output readable); note this also hides deprecation warnings from the libraries below.
warnings.filterwarnings("ignore")

In [0]:
import librosa
import numpy as np
import os
import pyspark.sql.functions as F
import pyspark.sql.types as T
import pandas as pd
import pydub
import ray
import ssl
import time
import torch

from mlflow.utils.databricks_utils import get_databricks_env_vars
from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster
from transformers import pipeline
from util import stage_registered_model, flatten_folder
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

## Setup and start Ray cluster
Best practices for scaling Ray clusters on Databricks are described [here](https://docs.databricks.com/en/machine-learning/ray/scale-ray.html#scale-ray-clusters-on-databricks):
* `num_cpus_*`: always leave 1 CPU core per node for Spark, so the value should be <= (CPU cores per node - 1)

In [0]:
# Resources handed to Ray on every node. A g5.12xlarge node has 48 vCPUs and 4 GPUs;
# one vCPU per node is kept back for Spark.
num_cpu_cores_per_worker = 47  # 48 - 1, per worker node
num_cpus_head_node = 47  # 48 - 1, on the driver node
num_gpu_per_worker = num_gpus_head_node = 4

# Export the notebook's Databricks credentials as environment variables
# (presumably consumed by the MLflow / Databricks SDK clients used further down).
mlflow_dbrx_creds = get_databricks_env_vars("databricks")
for env_key in ("DATABRICKS_HOST", "DATABRICKS_TOKEN"):
    os.environ[env_key] = mlflow_dbrx_creds[env_key]

# Fixed-size Ray-on-Spark cluster: the driver plus exactly one worker node.
ray_conf = setup_ray_cluster(
    min_worker_nodes=1,
    max_worker_nodes=1,
    num_cpus_head_node=num_cpus_head_node,
    num_gpus_head_node=num_gpus_head_node,
    num_cpus_per_node=num_cpu_cores_per_worker,
    num_gpus_per_node=num_gpu_per_worker,
)

In [0]:
# Unity Catalog location used for the input table, the output tables, the temp volume
# and the registered models. To use the models shipped in UC's system catalog instead,
# set CATALOG = "system" and SCHEMA = "ai".
CATALOG = "amine_elhelou"  # or "system"
SCHEMA = "ray_gtm_examples"  # or "ai"
ASR_MODEL_NAME = "whisper-large-v3-turbo"  # alternative: "whisper-large-v3"
PII_MODEL_NAME = "piiranha-v1"
MODEL_ALIAS = "Production"  # registered-model alias resolved for both models
# MODEL_VERSION = 1  # alternative to the alias; not wired into the classes below

## Prerequisite: download the models from the MLflow registry onto every node once, to avoid conflicts from concurrent downloads

In [0]:
from util import run_on_every_node


# Connect this driver process to the Ray cluster launched by setup_ray_cluster above.
ray.init()

In [0]:
@ray.remote(num_cpus=1)
def download_model(catalog,
                   schema,
                   model_name,
                   alias="Production",
                   local_base_path="/local_disk0/models/",
                   overwrite=False):
    """Stage a UC-registered model's weights on the node that runs this Ray task.

    Meant to be dispatched once per node, so that actors created later find the
    weights already on local disk instead of downloading them concurrently.

    Args:
        catalog: Unity Catalog catalog holding the registered model.
        schema: Unity Catalog schema holding the registered model.
        model_name: Registered model name.
        alias: Registered-model alias to resolve (e.g. "Production").
        local_base_path: Local directory under which the weights are staged.
        overwrite: Forwarded to ``stage_registered_model``.

    Returns:
        The path returned by ``stage_registered_model``.
    """
    model_weights_path = stage_registered_model(
        # Use the arguments (the previous version silently used the notebook
        # globals CATALOG/SCHEMA and ignored these parameters).
        catalog=catalog,
        schema=schema,
        model_name=model_name,
        alias=alias,
        local_base_path=local_base_path,
        overwrite=overwrite,
    )
    flatten_folder(model_weights_path)
    return model_weights_path

In [0]:
import mlflow


# Resolve registered models against Unity Catalog (in case it is not the default registry).
mlflow.set_registry_uri("databricks-uc")

# Stage both models on every node, one after the other: ASR model first, then PII model.
for staged_model_name in (ASR_MODEL_NAME, PII_MODEL_NAME):
    _ = run_on_every_node(
        download_model,
        catalog=CATALOG,
        schema=SCHEMA,
        model_name=staged_model_name,
        alias=MODEL_ALIAS,
    )

## Write ray-friendly `__call__`-able classes for batch processing

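For vLLM, one parameter to configure is:

* `gpu_memory_utilization`: the fraction (between 0 and 1) of a GPU's memory that a single vLLM engine may reserve for model weights, activations and KV cache. It does not directly set a number of model instances. To fit N instances on one GPU, give each roughly `1/N` (minus some headroom), and keep it consistent with the Ray `num_gpus` reservation. For example, `whisper-large-v3` (~1.55B parameters) needs ~6.2 GB for FP32 weights (~3.1 GB in FP16), so two instances can share a 24 GB A10G with `gpu_memory_utilization` of roughly `0.45` each.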

### Ingest and transcribe pipelines

In [0]:
class ConverttoPrompt:
    """
    Ray ``map`` UDF that reads one audio file and attaches its waveform to the row.

    For each row it loads ``row["file_path"]`` with pydub, converts the integer PCM
    samples to a mono float32 waveform scaled to [-1.0, 1.0), and adds two keys:
    ``"array"`` (list of float32 samples) and ``"frame_rate"`` (sample rate in Hz).
    """

    def __init__(self):
        pass

    @staticmethod
    def _to_float_mono(samples, channels, sample_width):
        """Convert interleaved signed integer PCM samples to a mono float32 waveform.

        Args:
            samples: Flat sequence of interleaved integer samples, as returned by
                ``AudioSegment.get_array_of_samples()``.
            channels: Number of interleaved channels.
            sample_width: Bytes per sample; determines the full-scale divisor
                (e.g. 2 bytes -> 2**15).

        Returns:
            1-D ``np.float32`` array.
        """
        waveform = np.asarray(samples, dtype=np.float32)
        if channels > 1:
            # Whisper consumes a single channel: average the interleaved channels
            # instead of handing a 2-D (frames, channels) array downstream.
            waveform = waveform.reshape((-1, channels)).mean(axis=1)
        # Scale by the full-scale value of the actual sample width rather than
        # assuming 16-bit audio.
        full_scale = float(2 ** (8 * sample_width - 1))
        return (waveform / full_scale).astype(np.float32)

    def transform(self, audio_filenames):
        """Load a single audio file and return ``(waveform, frame_rate)``.

        Args:
            audio_filenames: Path to ONE ``.wav`` file (the plural name is kept
                for backward compatibility).
        """
        # CHANGE THIS BASED ON YOUR AUDIO FILE SOURCES (e.g. pydub.AudioSegment.from_mp3)
        segment = pydub.AudioSegment.from_wav(audio_filenames)
        waveform = self._to_float_mono(
            segment.get_array_of_samples(), segment.channels, segment.sample_width
        )
        return waveform, segment.frame_rate

    def __call__(self, row: dict) -> dict:
        """Add ``"array"`` and ``"frame_rate"`` for ``row["file_path"]`` and return the row."""
        array, frame_rate = self.transform(row["file_path"])
        row["array"] = list(array)
        row["frame_rate"] = frame_rate
        return row


class WhisperTranscription:
    """
    Ray ``map_batches`` UDF that transcribes a batch of audio waveforms with vLLM.

    Each batch must carry the ``"array"`` and ``"frame_rate"`` columns produced by
    ``ConverttoPrompt``; they are replaced by a ``"transcription"`` column.
    """

    def __init__(self, catalog: str, schema: str, model_name: str,
                 model_alias: str = "Production",
                 gpu_memory_utilization: float = 1,
                 max_tokens: int = 500):
        """Stage the model from the UC registry and build the vLLM engine.

        Args:
            catalog: UC catalog of the registered model (e.g. "system").
            schema: UC schema of the registered model (e.g. "ai").
            model_name: Registered model name (e.g. "whisper_large_v3").
            model_alias: Registered-model alias to resolve.
            gpu_memory_utilization: Fraction (0-1] of this GPU's memory the vLLM
                engine may reserve. Keep it at or below the Ray ``num_gpus`` share
                when other actors run on the same GPU.
            max_tokens: Maximum number of tokens generated per audio input.
        """
        print("Loading model from UC registry...")
        model_weights_path = stage_registered_model(
            catalog=catalog,
            schema=schema,
            model_name=model_name,
            alias=model_alias,
            local_base_path="/local_disk0/models/",
            overwrite=False,
        )
        flatten_folder(model_weights_path)
        # Pass the location to vLLM as a plain string rather than a path object.
        model_weights_path = str(model_weights_path)
        self.WHISPER_MODEL_PATH = model_weights_path

        self.transcription_pipeline = LLM(
            model=model_weights_path,
            # Whisper decoder limit (max_target_positions=448 in the HF config) -
            # verify against the staged model's config.json.
            max_model_len=448,
            max_num_seqs=400,
            kv_cache_dtype="fp8",
            enforce_eager=True,
            gpu_memory_utilization=gpu_memory_utilization,
        )
        # Decoding settings are identical for every batch, so build them once
        # (greedy decoding: temperature 0).
        self.sampling_params = SamplingParams(
            temperature=0,
            top_p=1.0,
            max_tokens=max_tokens,
        )
        print("Model loaded...")

    def transform(self, row):
        """
        Build one vLLM prompt per audio input in the batch, in the format the model expects:
            {"prompt": "<|startoftranscript|>",
             "multi_modal_data": {"audio": (array, frame_rate)}}
        """
        return [
            {"prompt": "<|startoftranscript|>",
             "multi_modal_data": {"audio": (array, frame_rate)}}
            for array, frame_rate in zip(list(row['array']), list(row['frame_rate']))
        ]

    def __call__(self, row: dict) -> dict:
        """Transcribe the batch; drop the raw audio columns and add ``"transcription"``."""
        prompts = self.transform(row)
        outputs = self.transcription_pipeline.generate(prompts, self.sampling_params)

        # The raw waveforms are no longer needed downstream.
        del row['array']
        del row['frame_rate']

        row['transcription'] = [output.outputs[0].text for output in outputs]
        return row

### PII Redaction pipeline

In [0]:
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification


class PIIRedaction:
    """
    Ray ``map`` UDF that redacts PII from one transcript (row) at a time, using a
    token-classification (NER) model staged from the UC registry.
    """

    def __init__(self, catalog: str, schema: str, model_name: str, model_alias: str = "Production"):
        print("Loading PII-redaction model from UC registry...")
        model_weights_path = stage_registered_model(
            catalog=catalog,
            schema=schema,
            model_name=model_name,
            alias=model_alias,
            local_base_path="/local_disk0/models/",
            overwrite=False,
        )
        flatten_folder(model_weights_path)
        model_weights_path = str(model_weights_path)  # path object -> str for from_pretrained
        self.PII_MODEL_PATH = model_weights_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_weights_path)
        self.model = AutoModelForTokenClassification.from_pretrained(model_weights_path)
        # Move the model to its device once, instead of on every record.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _redact_spans(text, token_predictions, outside_label_id, id2label, aggregate_redaction=False):
        """Replace predicted PII character spans of ``text`` with placeholder tags.

        Args:
            text: The text the offsets refer to.
            token_predictions: Iterable of ``(start, end, label_id)`` per token, in text
                order, where ``start``/``end`` are character offsets into ``text``.
                Tokens with ``start == end`` (special/padding tokens) are ignored.
            outside_label_id: Label id meaning "not PII" (the ``"O"`` tag).
            id2label: Mapping from label id to the label name used in the tag.
            aggregate_redaction: If True, a run of PII tokens becomes one ``[redacted]``
                tag; otherwise each run of the same label becomes ``[<label>]``.

        Returns:
            The redacted text.
        """
        masked_text = list(text)

        def apply_redaction(start, end, pii_type):
            # Blank out text[start:end] and put the tag in the first character slot.
            for j in range(start, end):
                masked_text[j] = ''
            masked_text[start] = '[redacted]' if aggregate_redaction else f'[{pii_type}]'

        is_redacting = False
        redaction_start = redaction_end = 0
        current_pii_type = ''

        for start, end, label in token_predictions:
            # Some tokenizers include the preceding space in a word's offsets; skip it
            # so the whitespace between words survives redaction.
            while start < end and text[start].isspace():
                start += 1
            if start == end:  # special / padding / whitespace-only token
                continue

            if label != outside_label_id:
                pii_type = id2label[label]
                if not is_redacting:
                    is_redacting = True
                    redaction_start = start
                    current_pii_type = pii_type
                elif not aggregate_redaction and pii_type != current_pii_type:
                    # Label changed: close the previous run and start a new one.
                    apply_redaction(redaction_start, redaction_end, current_pii_type)
                    redaction_start = start
                    current_pii_type = pii_type
                redaction_end = end
            elif is_redacting:
                # First non-PII token after a run: redact only up to the end of the last
                # PII token (the old code used this token's end and deleted the word).
                apply_redaction(redaction_start, redaction_end, current_pii_type)
                is_redacting = False

        # PII run that extends to the last scanned token.
        if is_redacting:
            apply_redaction(redaction_start, redaction_end, current_pii_type)

        return ''.join(masked_text)

    def _mask_pii(self, text, aggregate_redaction=False):
        """
        Redact PII in ``text`` (adapted from the source code provided with the PII/NER model).

        Args:
            text: Transcript to redact.
            aggregate_redaction: Use a single ``[redacted]`` tag instead of per-label tags.

        Returns:
            ``text`` with every predicted PII span replaced by a tag.
        """
        if not text:
            return text

        # Tokenize ONCE so predictions and offsets stay aligned. Inputs longer than the
        # model's maximum length are split into consecutive windows instead of being
        # truncated, so the whole transcript is scanned for PII.
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            return_offsets_mapping=True,
            return_overflowing_tokens=True,
        )
        offset_mapping = encoded.pop("offset_mapping")
        encoded.pop("overflow_to_sample_mapping", None)  # bookkeeping, not a model input
        inputs = {k: v.to(self.device) for k, v in encoded.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits
        predictions = torch.argmax(logits, dim=-1).cpu()

        token_predictions = [
            (start, end, label)
            for window_offsets, window_labels in zip(offset_mapping.tolist(), predictions.tolist())
            for (start, end), label in zip(window_offsets, window_labels)
        ]
        return self._redact_spans(
            text,
            token_predictions,
            self.model.config.label2id['O'],
            self.model.config.id2label,
            aggregate_redaction,
        )

    def __call__(self, row: dict) -> dict:
        """Add ``"redacted_text"`` (per-label tags) computed from ``row["transcription"]``."""
        row["redacted_text"] = self._mask_pii(row["transcription"], aggregate_redaction=False)
        return row

## Prepare batch job

1. Point to input Delta table containing file paths
2. Select UC model names and `@alias` _(or version)_
3. Write ray inference code
4. Apply batch job and write/materialize outputs to Delta table

### 1. Read the input Delta table containing the audio file paths

In [0]:
# Input Delta table of audio file references; ConverttoPrompt reads each row's `file_path`.
TABLENAME = f"{CATALOG}.{SCHEMA}.recordings_file_reference"  # <catalog>.<schema>.<table>
audio_files_reference_df = spark.table(TABLENAME)
# audio_files_reference_df.count()

### 2. Write the Ray batch pipeline

Using [`map`](https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.map.html#ray.data.Dataset.map) and [`map_batches`](https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.map_batches.html).

Some relevant parameters to define are:
* `num_cpus`: for CPU intensive workloads (i.e. read the audio files) - defines how many ray-cores to use for individual tasks (default is `1 Ray-Core/CPU == 1 CPU Physical Core`). It can be defined as a fraction to oversubscribe a single physical core with multiple tasks

* `num_gpus`: for GPU-intensive workloads - defines how many (fractional) GPU(s) a single task/actor reserves

* `concurrency`: the (min, max) number of tasks/actors to run in parallel, given as a tuple

**IF USING `whisper-large-v3` directly from UC's `system.ai` then set `CATALOG = system` and `SCHEMA = ai`** _(and `model_version = 1`)_

In [0]:
# Pipeline: read each audio file on CPU -> transcribe batches with vLLM Whisper on GPU
# -> redact PII per record on GPU. Ray Data builds this plan lazily; it only runs
# when the dataset is consumed (e.g. written out).
ds = ray.data.from_spark(audio_files_reference_df)
ds = ds.repartition(200)\
        .map(
            ConverttoPrompt,
            # (min, max) actors; 94 = 47 Ray CPUs on the head node + 47 on the worker.
            concurrency=(40,94),
            num_cpus=1,
        )\
        .map_batches(
              WhisperTranscription,
              fn_constructor_kwargs={
                  "catalog": CATALOG,
                  "schema": SCHEMA,
                  "model_name": ASR_MODEL_NAME,
                  "model_alias": MODEL_ALIAS
                  },
              # (min, max) actors. The cluster has 8 GPUs (4 per node x 2 nodes) and each
              # actor reserves 0.6 GPU, so at most 8 of these can be scheduled at once.
              concurrency=(6,12),
              # Ray logical reservation, not a memory cap: 0.6 GPU per actor means only ONE
              # Whisper actor fits per GPU (2 x 0.6 > 1), leaving 0.4 GPU for PII actors.
              # NOTE(review): vLLM's actual memory use is governed by gpu_memory_utilization
              # in WhisperTranscription; verify it leaves room for co-located PII actors.
              num_gpus=.6,
              batch_size = 128  # rows passed to each vLLM generate() call
          )\
          .map(
              PIIRedaction,
              fn_constructor_kwargs={
                  "catalog": CATALOG,
                  "schema": SCHEMA,
                  "model_name": PII_MODEL_NAME,
                  "model_alias": MODEL_ALIAS
                  },
              # (min, max) actors. With one Whisper actor per GPU, 0.4 GPU remains, i.e.
              # 2 PII actors (0.2 each) per GPU -> at most 16 across the 8 GPUs.
              concurrency=(10,24),
              # Ray reservation of 0.2 GPU per actor (up to 5 per otherwise idle GPU);
              # like num_gpus above, this is not an enforced memory limit.
              num_gpus=float(.2)
          )

In [0]:
# Temporary directory for ray-uc-volumes-fuse (to write to Delta natively)
# Scratch directory on a UC Volume, used by Ray's UC-Volumes FUSE staging so the
# pipeline output can be written to a Delta table natively.
VOLUME = "temp"
volume_full_name = f"{CATALOG}.{SCHEMA}.{VOLUME}"
spark.sql(f"CREATE VOLUME IF NOT EXISTS {volume_full_name}")

tmp_dir_fs = f"/Volumes/{CATALOG}/{SCHEMA}/{VOLUME}/tempDoc"
dbutils.fs.mkdirs(tmp_dir_fs)
# Point Ray's UC-Volumes FUSE temp location at the directory created above.
os.environ["RAY_UC_VOLUMES_FUSE_TEMP_DIR"] = tmp_dir_fs

In [0]:
# Run the pipeline and write its output to the silver Delta table: the input columns
# plus `transcription` and `redacted_text` (the raw audio columns are dropped upstream).
ds.write_databricks_table(
  f"{CATALOG}.{SCHEMA}.whisper_transcriptions_redacted_silver_piiranha_v2",
  mode = "overwrite",  # replace existing contents (other modes, e.g. append, exist)
  mergeSchema = True  # allow the table schema to evolve with the output columns
)

In [0]:
# Inspect the redacted transcriptions that were just written.
spark.table(f"{CATALOG}.{SCHEMA}.whisper_transcriptions_redacted_silver_piiranha_v2").display()