# Transformer Based Batch Inference

This notebook shows how to run batch inference with Spark's distributed capabilities on a multi-machine, multi-GPU setup.

Tested on Databricks ML 14.0 with 4 A100 GPUs (1 driver + 3 workers), using _N24ads_A100_V4_ machines on Azure.


### Install & Upgrade Libraries

In [0]:
!pip install -q --upgrade transformers
!pip install -q --upgrade accelerate
dbutils.library.restartPython()

Note: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.


### GPU Stats

In [0]:
!nvidia-smi

Sat Sep 16 17:41:57 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.103.01   Driver Version: 470.103.01   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  Off  | 00000001:00:00.0 Off |                    0 |
| N/A   40C    P0    44W / 300W |      0MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

### Parameters

The model and tokenizer names below can be swapped for most other causal language models on the Hugging Face Hub.

If you change the model, change the prompt as well. The template below is written specifically for the Llama 2 chat model.

In [0]:
# Model Params
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
MAX_NEW_TOKENS = 300
MIN_LENGTH = 0
REPETITION_PENALTY = 1.2
TEMPERATURE = 0.1
TOP_P = 0.9
TOP_K = 50
DO_SAMPLE = True
USE_CACHE = True

# Tokenizer Params
TOKENIZER_NAME = MODEL_NAME
MAX_TOKENS = 2048

# Run Params (How many articles to use)
MAX_EXAMPLES = 10000

# Storage Params
STORAGE_PATH = "/dbfs/llm-examples"

# Instruction
INSTRUCTION = """Please provide a concise summary for the following article: {text}"""

# Prompt
PROMPT_TEMPLATE = f"""<s>[INST]<<SYS>>
You are a direct and honest assistant. Please provide concise and factual answers and just the answers.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

{INSTRUCTION}
[/INST]
"""

### Storage Operations

Removes any existing storage directory and recreates it, so each run starts from a clean state.

In [0]:
import shutil
import os

# Remove existing files (ignore_errors avoids a failure when the directory does not exist yet)
shutil.rmtree(STORAGE_PATH, ignore_errors=True)

# Build the new directory
os.makedirs(STORAGE_PATH, exist_ok=True)

### Huggingface Login

Skip this step if your model does not require a login. Llama 2 is a gated model, so it does.

In [0]:
# Log in to Hugging Face
from huggingface_hub import notebook_login

notebook_login()


### Retrieve Articles Data

The CNN / Daily Mail articles dataset is retrieved from Hugging Face Datasets, and its train, validation, and test splits are combined into a single PySpark DataFrame.

In [0]:
# Imports
from datasets.utils import logging as dataset_logging, disable_progress_bar
from pyspark.sql import functions as SF
from datasets import load_dataset

# Disable verbose loggers
dataset_logging.set_verbosity_error()
disable_progress_bar()

# Download dataset
dataset = load_dataset(
    path="cnn_dailymail", name="3.0.0", cache_dir=f"{STORAGE_PATH}/hf"
)

# Create spark dataframes
train_df = spark.createDataFrame(data=dataset["train"].to_pandas())
val_df = spark.createDataFrame(data=dataset["validation"].to_pandas())
test_df = spark.createDataFrame(data=dataset["test"].to_pandas())

# Union for all data
articles_df = train_df.union(val_df).union(test_df)



### Sample Data

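Sampling is deterministic: each article ID is hashed with SHA-256, the rows are ordered by the resulting hex digest, and the first `MAX_EXAMPLES` rows are kept. This gives consistent results for benchmarking.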

In [0]:
from pyspark.sql import functions as SF
import hashlib

# Build a UDF that maps an ID to a reproducible, random-looking hex digest
@SF.udf("string")
def generate_hex_from_string(input_string: str) -> str:
    sha256 = hashlib.sha256()
    sha256.update(input_string.encode("utf-8"))
    return sha256.hexdigest()


# Add the digest column
articles_df = articles_df.withColumn(
    "random_string", generate_hex_from_string(SF.col("id"))
)

# Order by randomness and limit dataframe size
articles_df = articles_df.orderBy(SF.col("random_string")).limit(MAX_EXAMPLES)
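
The same idea can be checked outside Spark. The following is a minimal pure-Python sketch, not the notebook's code: the helper names and the IDs are hypothetical, and it only illustrates that ordering by a SHA-256 digest gives the same sample regardless of the input order.

```python
import hashlib


def hex_digest(value: str) -> str:
    """Return the SHA-256 hex digest of a string."""
    return hashlib.sha256(value.encode("utf-8")).hexdigest()


def deterministic_sample(ids, n):
    """Keep the n IDs whose digests sort first."""
    return sorted(ids, key=hex_digest)[:n]


# Hypothetical IDs; the notebook hashes the dataset's "id" column instead.
ids = [f"article-{i}" for i in range(100)]

first = deterministic_sample(ids, 10)
second = deterministic_sample(list(reversed(ids)), 10)

# The sample does not depend on the order the rows arrive in.
print(first == second)
```

Because the digest depends only on the ID, rerunning the notebook on the same dataset should select the same articles, which is what makes timing comparisons between runs meaningful.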

### Generate Prompts

Each article's text is inserted into the prompt template to produce the final instruction passed to the model.

In [0]:
from pyspark.sql import functions as SF

# Build function for generating instructions
@SF.udf("string")
def generate_instructions(article):
    return PROMPT_TEMPLATE.format(text=article)


# Generate instructions
articles_df = articles_df.withColumn(
    "instruction", generate_instructions(SF.col("article"))
)
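
`PROMPT_TEMPLATE` is filled in two stages: the f-string substitutes `INSTRUCTION` when the parameters cell runs, and the `{text}` placeholder that survives is filled here with `str.format`. A minimal sketch of that mechanism, using shortened, hypothetical strings rather than the real prompt:

```python
# Hypothetical, shortened stand-in for INSTRUCTION.
instruction = "Summarize: {text}"

# Stage 1: the f-string inserts the instruction, placeholder included.
template = f"[INST] {instruction} [/INST]"

# Stage 2: str.format fills the remaining {text} placeholder per article.
prompt = template.format(text="Some article body.")

print(template)
print(prompt)
```

Braces inside the article text are safe, because `str.format` parses only the template and not the value being substituted.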

### Execute Data Operations


In [0]:
# Trigger with action
articles_df = spark.createDataFrame(articles_df.toPandas()).repartition(10)

# Cache for performance
articles_df.cache()
print(f"Number of Examples: {articles_df.count()}")

Number of Examples: 10000


### Download Model & Tokenizer

Downloading the model and the tokenizer ahead of time helps them load faster during the multi-machine inference step.

If the model and the tokenizer are in the same repository, only one download will occur. In some cases, for example for Falcon-7B, they can be different. In that case, the code downloads both to the location specified in params.

In [0]:
# External Imports
from huggingface_hub.utils import (
    disable_progress_bars as hfhub_disable_progress_bar,
    logging as hf_logging,   
)
from huggingface_hub import snapshot_download
import os

# Turn off info logging and progress bars for huggingface_hub
hf_logging.set_verbosity_error()
hfhub_disable_progress_bar()

# Download the model 
local_model_path = f"{STORAGE_PATH}/model/"
os.makedirs(local_model_path, exist_ok=True)
model_download = snapshot_download(
    repo_id=MODEL_NAME,
    local_dir=local_model_path,
    local_dir_use_symlinks=False,
    ignore_patterns="*.safetensors*", # This argument is specific to LLAMA. Other models might not need it.
    max_workers=48
)

# Download the tokenizer
if MODEL_NAME == TOKENIZER_NAME:
    local_tokenizer_path = local_model_path
else:
    local_tokenizer_path = f"{STORAGE_PATH}/tokenizer/"
    os.makedirs(local_tokenizer_path, exist_ok=True)
    tokenizer_download = snapshot_download(
        repo_id=TOKENIZER_NAME,
        local_dir=local_tokenizer_path,
        local_dir_use_symlinks=False,
        max_workers=48
    )

### Load Model & Tokenizer

The model and tokenizer are loaded from the downloaded directory for testing.

In [0]:
# Imports
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Params
random_seed = 42

# Random seed set
torch.cuda.manual_seed(random_seed)
torch.manual_seed(random_seed)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(local_tokenizer_path, padding_side="left")
tokenizer.pad_token_id = tokenizer.eos_token_id

# Load Model
model = AutoModelForCausalLM.from_pretrained(
    local_model_path,
    return_dict=True,
    device_map="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    pad_token_id=tokenizer.eos_token_id,
)

# Put model in eval mode
model.eval()

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=2)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNo

### Run For One

The `batch_generate` function takes a list of prompts and returns a list of outputs (the generated text). Generation parameters such as temperature and top_p are set within the function.

Although this cell runs it on a single example, the same function is reused during distributed inference.

In [0]:
# Imports
import torch

# Get sample data
sample_instructions = [x[0] for x in articles_df.select("instruction").limit(2).collect()]

# Define Inference Flow
@torch.inference_mode()
def batch_generate(batch_prompts, tokenizer=tokenizer, model=model):
    batch = tokenizer.batch_encode_plus(
        batch_prompts,
        padding=True,
        truncation=True,
        return_tensors="pt",
        return_token_type_ids=False,
        max_length=MAX_TOKENS
    )
    batch = {k: v.to("cuda") for k, v in batch.items()}
    with torch.no_grad():
        outputs = model.generate(
            **batch,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            top_p=TOP_P,
            temperature=TEMPERATURE,
            min_length=MIN_LENGTH,
            use_cache=USE_CACHE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
        )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Check out one example
print(batch_generate(sample_instructions[:1])[0])

[INST]<<SYS>>
You are a direct and honest assistant. Please provide concise and factual answers and just the answers.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

Please provide a concise summary for the following article: PUBLISHED: . 12:42 EST, 30 August 2012 . | . UPDATED: . 02:28 EST, 31 August 2012 . Guilty: Gordon Dyche has been sentenced to four years in prison for causing death by careless driving . A 'chancer' who killed four members of the same family when he shunted their car off the road and into a reservoir while rushing to work was jailed for four years yesterday. Mechanic Gordon Dyche, 24, had recently completed a driving ban when he tried to overtake two cars in a row on a winding country road because he was worried he would lose pay if he was late. He passed a Volkswagen Passat which was travelling a

### Batch Test

The optimal batch size depends on the GPU. The GPU used in this test has 80 GB of memory, so larger batches make sense; smaller GPUs such as the A10 usually do better with smaller batch sizes.

The code compares multiple batch sizes, and iteration stops when an out-of-memory error is raised.

In [0]:
# Imports
import time

# Get sample data
sample_instructions = [x[0] for x in articles_df.select("instruction").limit(50).collect()]

def batch_size_optimiser():
    batch_sizes = [1, 2, 3, 4, 5, 7, 10, 12, 15, 17, 20, 25, 30]
    success = True
    for size in batch_sizes:
        start = time.perf_counter()
        try:
            batch_generate(batch_prompts=sample_instructions[:size])
        except torch.cuda.OutOfMemoryError:
            success = False
            break
        finally:
            elapsed = round(time.perf_counter() - start, 2)
            unit_time = round(elapsed/size, 2)
            yield {"batch_size": size, "elapsed_time": elapsed, "unit_time": unit_time, "success": success}

for result in batch_size_optimiser():
    print("- - " * 10)
    print(result)

- - - - - - - - - - - - - - - - - - - - 
{'batch_size': 1, 'elapsed_time': 6.81, 'unit_time': 6.81, 'success': True}
- - - - - - - - - - - - - - - - - - - - 
{'batch_size': 2, 'elapsed_time': 8.47, 'unit_time': 4.24, 'success': True}
- - - - - - - - - - - - - - - - - - - - 
{'batch_size': 3, 'elapsed_time': 8.77, 'unit_time': 2.92, 'success': True}
- - - - - - - - - - - - - - - - - - - - 
{'batch_size': 4, 'elapsed_time': 8.98, 'unit_time': 2.25, 'success': True}
- - - - - - - - - - - - - - - - - - - - 
{'batch_size': 5, 'elapsed_time': 8.35, 'unit_time': 1.67, 'success': True}
- - - - - - - - - - - - - - - - - - - - 
{'batch_size': 7, 'elapsed_time': 11.76, 'unit_time': 1.68, 'success': True}
- - - - - - - - - - - - - - - - - - - - 
{'batch_size': 10, 'elapsed_time': 18.24, 'unit_time': 1.82, 'success': True}
- - - - - - - - - - - - - - - - - - - - 
{'batch_size': 12, 'elapsed_time': 21.38, 'unit_time': 1.78, 'success': True}
- - - - - - - - - - - - - - - - - - - - 
{'batch_size': 15,

### Select Batch Size

The selected batch size doesn't necessarily have to be the largest one that succeeded without an OOM error. It is probably better to choose the 2nd or 3rd largest successful batch size, so that OOM errors are less likely during inference.

In [0]:
# Get the batch size with minimum unit time
OPTIMAL_BATCH_SIZE = 15
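
The selection rule described above can also be written as code. The following is a sketch under stated assumptions: `pick_batch_size` and `back_off` are hypothetical names, and the result list contains only the rows visible in the batch test output (the remaining rows are cut off), so the value it returns is not the one chosen in this notebook.

```python
# Rows copied from the visible part of the batch test output.
results = [
    {"batch_size": 1, "unit_time": 6.81, "success": True},
    {"batch_size": 2, "unit_time": 4.24, "success": True},
    {"batch_size": 3, "unit_time": 2.92, "success": True},
    {"batch_size": 4, "unit_time": 2.25, "success": True},
    {"batch_size": 5, "unit_time": 1.67, "success": True},
    {"batch_size": 7, "unit_time": 1.68, "success": True},
    {"batch_size": 10, "unit_time": 1.82, "success": True},
    {"batch_size": 12, "unit_time": 1.78, "success": True},
]


def pick_batch_size(results, back_off=1):
    """Return a successful batch size, `back_off` steps below the largest."""
    ok = sorted(r["batch_size"] for r in results if r["success"])
    if not ok:
        raise ValueError("no batch size succeeded")
    # Clamp so a large back_off never indexes past the smallest size.
    return ok[max(len(ok) - 1 - back_off, 0)]


# Second-largest successful size, leaving headroom against OOM errors.
print(pick_batch_size(results, back_off=1))

# For comparison: the size with the lowest time per prompt.
fastest = min((r for r in results if r["success"]), key=lambda r: r["unit_time"])
print(fastest["batch_size"])
```

The two criteria can disagree: the lowest unit time favours throughput, while backing off from the largest successful size favours stability on long inputs.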

### Distributed Inference Logic

All of the generation logic is wrapped in a Pandas UDF, which helps with the setup on the workers. An iterator-to-iterator architecture is used to handle batching.

If the model runs into an OOM error, the function handles the exception by returning an "OOM" string as the result.

In [0]:
# External Imports
from pyspark.sql import functions as SF
import pandas as pd
from typing import Iterator

# Build Inference Function
@SF.pandas_udf("string", SF.PandasUDFType.SCALAR_ITER)
def run_distributed_inference(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:

    # External Imports
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import pandas as pd
    import torch
    import os

    # Params
    random_seed = 42

    # Random seed set
    torch.cuda.manual_seed(random_seed)
    torch.manual_seed(random_seed)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(local_tokenizer_path, padding_side="left")
    tokenizer.pad_token_id = tokenizer.eos_token_id

    # Load Model
    model = AutoModelForCausalLM.from_pretrained(
        local_model_path,
        return_dict=True,
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Put model in eval mode
    model.eval()

    for prompts in iterator:
        prompts = prompts.to_list()
        try:
            output = batch_generate(
                batch_prompts=prompts, 
                tokenizer=tokenizer, 
                model=model
            )
        except torch.cuda.OutOfMemoryError:
            # If out of memory, return a list of "OOM" strings with the same length as the input
            output = ["OOM"] * len(prompts)

        yield pd.Series(output)
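
The iterator-to-iterator pattern can be illustrated without Spark or a GPU. In this sketch every name is hypothetical: `load_model` stands in for the expensive model and tokenizer load, `MemoryError` stands in for `torch.cuda.OutOfMemoryError`, and plain lists stand in for `pd.Series`.

```python
from typing import Callable, Iterator, List


def load_model() -> Callable[[List[str]], List[str]]:
    """Hypothetical stand-in for the expensive model/tokenizer load."""
    load_model.calls += 1

    def generate(prompts: List[str]) -> List[str]:
        # Pretend that batches of more than two prompts run out of memory.
        if len(prompts) > 2:
            raise MemoryError("batch too large")
        return [p.upper() for p in prompts]

    return generate


load_model.calls = 0


def run_inference(batches: Iterator[List[str]]) -> Iterator[List[str]]:
    generate = load_model()  # setup: runs once per iterator
    for prompts in batches:  # per-batch work
        try:
            result = generate(prompts)
        except MemoryError:
            # Keep the output the same length as the input.
            result = ["OOM"] * len(prompts)
        yield result


batches = iter([["a", "b"], ["c", "d", "e"], ["f"]])
outputs = list(run_inference(batches))
print(outputs)
print(load_model.calls)
```

The point of the pattern is that the setup cost is paid once for the whole stream of batches, and a failing batch produces placeholder rows instead of failing the whole task.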



### Inference Configurations

The number of workers available in the cluster is detected automatically, and the data is repartitioned accordingly. This means that the setup portion of the Pandas UDF, which loads the model and tokenizer, runs only once per partition (one partition per worker) during inference, and the data processing is handled by the iterator.

The `spark.sql.execution.arrow.maxRecordsPerBatch` configuration controls how many rows each batch passed to the Pandas UDF contains.

In [0]:
# Imports
from pyspark import SparkContext

# Auto get number of workers
sc = SparkContext.getOrCreate()

# Subtract 1 to exclude the driver
num_workers = len(sc._jsc.sc().statusTracker().getExecutorInfos()) - 1  

# Set the batch size for the Pandas UDF
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", OPTIMAL_BATCH_SIZE * 2)

# Repartition
articles_df = articles_df.repartition(num_workers)

# Cache DF
articles_df.cache()
articles_df.count()

10000
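
To get a feel for how these settings translate into work per worker, the following sketch estimates the number of batches each partition produces. It assumes 3 workers (as in the cluster described at the top) and an even spread of rows across partitions, which `repartition` only approximates; `batches_per_partition` is a hypothetical helper.

```python
import math


def batches_per_partition(n_rows, n_partitions, max_records_per_batch):
    """Estimate batches per partition, assuming rows are spread evenly."""
    base, extra = divmod(n_rows, n_partitions)
    sizes = [base + (1 if i < extra else 0) for i in range(n_partitions)]
    return [math.ceil(size / max_records_per_batch) for size in sizes]


# Values from this notebook: 10,000 rows, 3 workers (assumed),
# and maxRecordsPerBatch = OPTIMAL_BATCH_SIZE * 2 = 30.
print(batches_per_partition(10_000, 3, 30))
```

With this configuration each call to `batch_generate` inside the UDF can receive up to 30 prompts, which is twice the batch size selected earlier.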

### Run Distributed Inference

In [0]:
import time

# Apply Inference UDF
articles_df = (
    articles_df
    .withColumn("llm_summary", run_distributed_inference(SF.col("instruction")))

)

# Materialize and execute
inference_start_time = time.perf_counter()
articles_pdf = articles_df.toPandas()
inference_elapsed_time = round(time.perf_counter() - inference_start_time, 4)

### Build Spark DF

In [0]:
# Go back to Spark
articles_df = spark.createDataFrame(articles_pdf)

# Cache DF
articles_df.cache()
articles_df.count()

10000

### Clean Summaries

In [0]:
# Imports
from pyspark.sql import functions as SF

# UDF Build
clean_llm_summary = SF.udf(lambda x: x.split("[/INST]")[-1].strip(), "string")

# Apply UDF
articles_df = (
    articles_df.withColumn(
        "cleaned_llm_summary", clean_llm_summary(SF.col("llm_summary"))
    )
)
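
The cleaning step relies on the decoded output still containing the prompt, so everything up to the last `[/INST]` tag is dropped. A small pure-Python sketch of the same logic follows; `clean_summary` is a hypothetical helper, and the `None` check is an addition that the lambda above does not have.

```python
def clean_summary(generated):
    """Return the text after the last [/INST] tag, or None for missing input."""
    if generated is None:
        return None
    return generated.split("[/INST]")[-1].strip()


# Hypothetical model output: prompt followed by the generated summary.
sample = "[INST] Summarize: some article [/INST]\n  A short summary. "

print(clean_summary(sample))
print(clean_summary("OOM"))
```

Rows that hit the out-of-memory fallback pass through unchanged as `OOM`, so it may be worth filtering them out before analysing the summaries.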

### Calculate Token Counts

In [0]:
@SF.udf("int")
def calculate_n_tokens(target_text):
    return len(
        tokenizer.encode_plus(
            target_text,
            padding=True,
            truncation=True,
            return_token_type_ids=False,
            add_special_tokens=False,
            return_attention_mask=True,
            max_length=MAX_TOKENS,
        )["input_ids"]
    )


# Calculate article tokens
articles_df = articles_df.withColumn(
    "article_token_count", calculate_n_tokens(SF.col("article"))
)

# Calculate instruction tokens
articles_df = articles_df.withColumn(
    "instruction_token_count", calculate_n_tokens(SF.col("instruction"))
)

# Calculate generated tokens
articles_df = articles_df.withColumn(
    "generated_token_count", calculate_n_tokens(SF.col("cleaned_llm_summary"))
)

# Cache DF
articles_df.cache()
articles_df.count()

10000

### Display Stats

In [0]:
# Imports
from pyspark.sql import functions as SF
import datetime

text_stats = (
    articles_df.groupBy()
    .agg(
        SF.count(SF.col("id")).alias("articles_count"),
        SF.sum(SF.col("article_token_count")).alias("total_article_tokens"),
        SF.sum(SF.col("instruction_token_count")).alias("total_instruction_tokens"),
        SF.sum(SF.col("generated_token_count")).alias("total_generated_tokens"),
    )
    .first()
)

human_elapsed_time = str(datetime.timedelta(seconds=inference_elapsed_time))

print("-" * 3 + " Input " + "-" * 3)
print(f"Total Article Count: {text_stats['articles_count']}")
print(f"Articles Token Count: {text_stats['total_article_tokens']}")
print(f"With Instructions Token Count: {text_stats['total_instruction_tokens']}")

print("\n" + "-" * 3 + " Output " + "-" * 3)
print(f"Generated Tokens Count: {text_stats['total_generated_tokens']}")
print(f"Inference Elapsed Seconds: {inference_elapsed_time}")
print(f"Inference Elapsed Time: {human_elapsed_time}")

--- Input ---
Total Article Count: 10000
Articles Token Count: 9972345
With Instructions Token Count: 10956621

--- Output ---
Generated Tokens Count: 2979721
Inference Elapsed Seconds: 6484.8818
Inference Elapsed Time: 1:48:04.881800
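
These totals can be turned into throughput figures. The sketch below uses the numbers printed above and assumes that the three worker GPUs did all of the generation; that split is an assumption, not something the stats cell measures.

```python
# Figures copied from the stats output above.
articles = 10_000
generated_tokens = 2_979_721
elapsed_seconds = 6484.8818

# Assumption: inference ran on the 3 worker GPUs only.
n_worker_gpus = 3

tokens_per_second = generated_tokens / elapsed_seconds
tokens_per_second_per_gpu = tokens_per_second / n_worker_gpus
seconds_per_article = elapsed_seconds / articles
tokens_per_article = generated_tokens / articles

print(f"Generated tokens/s (cluster): {tokens_per_second:.1f}")
print(f"Generated tokens/s (per GPU): {tokens_per_second_per_gpu:.1f}")
print(f"Seconds per article: {seconds_per_article:.3f}")
print(f"Generated tokens per article: {tokens_per_article:.1f}")
```

The average number of generated tokens per article comes out close to `MAX_NEW_TOKENS`, which may indicate that many summaries were cut off at the generation limit.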


### Save Results


In [0]:
# Build DBFS path for the table
save_location = f"{STORAGE_PATH}/results".split("/dbfs")[-1]

# Save Table
articles_df.write.mode("overwrite").save(save_location)

# Register Table
_ = spark.sql("DROP TABLE IF EXISTS llm_batch_inference_results")
_ = spark.sql(
    f"CREATE TABLE llm_batch_inference_results USING DELTA LOCATION '{save_location}'"
)
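
The save location is derived by stripping the `/dbfs` prefix, because `/dbfs/...` is the local file path while Spark addresses the same location without that prefix. A short sketch of the path handling, with `str.removeprefix` shown as a possible alternative (it requires Python 3.9 or later):

```python
STORAGE_PATH = "/dbfs/llm-examples"  # same value as in the parameters cell

# Approach used above: split on "/dbfs" and keep the last piece.
via_split = f"{STORAGE_PATH}/results".split("/dbfs")[-1]

# Alternative: strip only a leading "/dbfs".
via_prefix = f"{STORAGE_PATH}/results".removeprefix("/dbfs")

print(via_split)
print(via_prefix)
```

Both give the same result here. They differ only when `/dbfs` appears again later in the path, in which case the split keeps just the part after the last occurrence.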