# PySpark LLM Inferencing

This notebook first runs Llama 2 directly in Python, then demonstrates batch LLM inference on a Spark DataFrame using Triton Inference Server (one server per executor).

## Python

In [None]:
from llama import Llama

In [None]:
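import os

# Llama.build() initializes torch.distributed from these variables;
# here a single process (rank 0 of a world of size 1) runs the model
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '8989'

# expose only physical GPU 15 to this process (it becomes cuda:0); change to match your machine
os.environ['CUDA_VISIBLE_DEVICES'] = '15'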

In [None]:
ckpt_dir = "llama-2-7b-chat"
tokenizer_path = "tokenizer.model"

temperature: float = 0.6
top_p: float = 0.9
max_seq_len: int = 128
max_gen_len: int = 64
max_batch_size: int = 4
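The `temperature` and `top_p` values above control how each next token is sampled. As a rough illustration only, the plain-Python sketch below applies temperature scaling and then nucleus (top-p) filtering to a toy list of logits. It is not the `llama` package's implementation; `sample_top_p` and the example logits are hypothetical.

```python
import math
import random


def sample_top_p(logits, temperature=0.6, top_p=0.9, rng=random):
    """Sample a token index with temperature scaling followed by nucleus (top-p) filtering."""
    # Temperature scaling: dividing logits by T < 1 sharpens the distribution.
    scaled = [x / temperature for x in logits]
    # Softmax, subtracting the max for numerical stability.
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Token indices sorted from most to least likely.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # Keep the smallest prefix whose cumulative probability reaches top_p
    # (the most likely token is always kept).
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Renormalize over the kept tokens and draw one of them.
    kept_mass = sum(probs[i] for i in kept)
    weights = [probs[i] / kept_mass for i in kept]
    return rng.choices(kept, weights=weights, k=1)[0]


# Hypothetical usage on toy logits for a 4-token vocabulary:
rng = random.Random(0)
print(sample_top_p([2.0, 1.0, 0.5, -1.0], rng=rng))
```

With `temperature=0.6` these toy logits give roughly 0.78 probability to token 0 and 0.15 to token 1, so `top_p=0.9` restricts sampling to those two tokens; a very small `top_p` reduces to greedy decoding.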

In [None]:
generator = Llama.build(
    ckpt_dir=ckpt_dir,
    tokenizer_path=tokenizer_path,
    max_seq_len=max_seq_len,
    max_batch_size=max_batch_size,
)
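`max_batch_size` caps how many prompts a single `text_completion` call may receive (in Meta's reference `llama` code, generation asserts that the batch is no larger than this), and `max_seq_len` bounds the prompt plus generated tokens. To run more prompts than `max_batch_size`, one option is to split them into batches first. `batched` below is a hypothetical helper, not part of `llama`; on Python 3.12+ `itertools.batched` does the same but yields tuples.

```python
from typing import Iterator, List, Sequence


def batched(items: Sequence[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


# Hypothetical usage with the generator built above:
# all_results = []
# for batch in batched(prompts, max_batch_size):
#     out = generator.text_completion(batch, max_gen_len=max_gen_len,
#                                     temperature=temperature, top_p=top_p)
#     all_results.extend(x["generation"] for x in out)
```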

### Text completion

In [None]:
prompts = [
    "I believe the meaning of life is ",
    "Simply put, the theory of relativity states that ",
    "The history of space travel started with ",
    "The most popular dog breeds include ",
]

In [None]:
results = generator.text_completion(prompts, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p)
results = [x['generation'] for x in results]

In [None]:
results

### Chat

In [None]:
instruction = "Translate the following to German without any additional English explanations"
# Llama 2 chat template; text_completion() already prepends the BOS token, so no literal "<s>" here
instruct_prompts = [
    f"[INST] <<SYS>>\n{instruction}\n<</SYS>>\n\n{prompt.strip()} [/INST]" for prompt in prompts
]
instruct_prompts
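The same single-turn Llama 2 chat template can be factored into a small helper. This is a sketch: `llama2_chat_prompt` is a hypothetical name, while the `[INST]`/`<<SYS>>` markers follow the template used by Meta's reference `llama` code. That code also offers `Llama.chat_completion`, which applies the template itself from a list of role/content messages.

```python
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"


def llama2_chat_prompt(system: str, user: str) -> str:
    """Build a single-turn Llama 2 chat prompt.

    The BOS token is not included: the tokenizer adds it when encoding.
    """
    return f"{B_INST} {B_SYS}{system.strip()}{E_SYS}{user.strip()} {E_INST}"


# Hypothetical usage with the variables above:
# instruct_prompts = [llama2_chat_prompt(instruction, p) for p in prompts]
```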

In [None]:
results = generator.text_completion(instruct_prompts, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p)
results = [x['generation'] for x in results]
results

## PySpark dataset

In [None]:
from torchtext.datasets import IMDB

In [None]:
# load IMDB reviews (test) dataset
data = IMDB(split='test')

In [None]:
# spark.createDataFrame expects a list of rows, so wrap each review in a one-element list
lines = []
for label, text in data:
    # keep only the review text and drop the label
    lines.append([text])
len(lines)

### Create PySpark DataFrame

In [None]:
from pyspark.sql.types import *

In [None]:
df = spark.createDataFrame(lines, ['lines']).repartition(10)
df.schema

In [None]:
df.show(truncate=100)

In [None]:
df.count()

### Save as Parquet

In [None]:
df.write.mode("overwrite").parquet("imdb_test")

## Inference using Spark DL API (Triton server per executor)

### Preprocess dataset

IMDB reviews can be lengthy, so we truncate each review to at most `max_len` characters, cutting back to the last whole word.

**Note**: handling longer texts would require a model with a longer context window, or splitting each review into chunks.
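One way to keep the whole review is chunking. A minimal sketch, assuming chunk size is measured in characters (as the truncation below is) rather than in model tokens: `chunk_words` is a hypothetical helper that splits text into word-boundary pieces of at most `max_len` characters. How to combine the per-chunk LLM outputs (for example, summarizing the chunk summaries) depends on the task and is not shown.

```python
from typing import List


def chunk_words(text: str, max_len: int = 512) -> List[str]:
    """Split text into chunks of at most max_len characters, breaking at whitespace.

    Whitespace runs are collapsed to single spaces; a word longer than
    max_len is hard-split so that no chunk exceeds the limit.
    """
    chunks: List[str] = []
    current = ""
    for word in text.split():
        # Hard-split words that cannot fit in a chunk on their own.
        while len(word) > max_len:
            if current:
                chunks.append(current)
                current = ""
            chunks.append(word[:max_len])
            word = word[max_len:]
        candidate = f"{current} {word}" if current else word
        if len(candidate) <= max_len:
            current = candidate
        else:
            # The current chunk is full; start a new one with this word.
            chunks.append(current)
            current = word
    if current:
        chunks.append(current)
    return chunks


print(chunk_words("the quick brown fox jumps over the lazy dog", max_len=10))
# ['the quick', 'brown fox', 'jumps over', 'the lazy', 'dog']
```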

In [1]:
import pandas as pd

from functools import partial
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, pandas_udf, struct
from pyspark.sql.types import StringType

In [2]:
df = spark.read.parquet("imdb_test")
df.show(truncate=100)

+----------------------------------------------------------------------------------------------------+
|                                                                                               lines|
+----------------------------------------------------------------------------------------------------+
|Documentary content: Amazing man, amazing movement he started, amazing stories- most of them yet ...|
|Even if I had not read Anne Rice's "Queen of the Damned" from the "Vampire Chronicles," I probabl...|
|This movie is about a depressed and emotionally constricted man has a distant relative move in wi...|
|This is possibly the worst thing I've ever seen on television. First, I'm pretty sure it takes it...|
|This show is pathetic. I can't even begin to imagine how anyone with an IQ greater than that of a...|
|The obsession of 'signifie' and 'signifiant' is not enough to make a good film. Pascal Bonitzer s...|
|I am easily pleased. I like bad films. I like films featuring attractive

In [3]:
df.count()

                                                                                

25000

In [4]:
from pyspark.sql import Column

# truncate each review to at most max_len characters, without cutting a word in half
def preprocess(text: Column, max_len: int = 512) -> Column:
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        # texts within the limit are kept as-is; longer ones are cut back to the last space
        return text.apply(lambda s: s if len(s) <= max_len else s[:max_len].rsplit(' ', 1)[0])
    return _preprocess(text)
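The truncation rule can be checked without starting Spark. The hypothetical `truncate_at_word` below is a plain-Python version of the same idea: texts within `max_len` are returned unchanged, a cut that already lands on a space keeps its last word, and a text with no space at all is hard-cut.

```python
def truncate_at_word(s: str, max_len: int = 512) -> str:
    """Cut s to at most max_len characters without splitting its last word."""
    if len(s) <= max_len:
        return s  # short texts are returned unchanged
    cut = s[:max_len]
    if s[max_len].isspace():
        return cut  # the cut already falls on a word boundary
    head, sep, _ = cut.rpartition(" ")
    return head if sep else cut  # no space in the cut: hard cut


print(truncate_at_word("hello world again", 8))   # hello
print(truncate_at_word("hello world again", 11))  # hello world
```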

In [5]:
# use only the first 100 rows, since LLM inference is slow
df1 = df.limit(100).withColumn("input", preprocess(col("lines"))).select("input")
df1.show(truncate=100)

+----------------------------------------------------------------------------------------------------+
|                                                                                               input|
+----------------------------------------------------------------------------------------------------+
|Documentary content: Amazing man, amazing movement he started, amazing stories- most of them yet ...|
|Even if I had not read Anne Rice's "Queen of the Damned" from the "Vampire Chronicles," I probabl...|
|This movie is about a depressed and emotionally constricted man has a distant relative move in wi...|
|This is possibly the worst thing I've ever seen on television. First, I'm pretty sure it takes it...|
|This show is pathetic. I can't even begin to imagine how anyone with an IQ greater than that of a...|
|The obsession of 'signifie' and 'signifiant' is not enough to make a good film. Pascal Bonitzer s...|
|I am easily pleased. I like bad films. I like films featuring attractive

                                                                                

In [6]:
num_rows = df1.count()
num_rows

100

In [7]:
df1.rdd.getNumPartitions()

1

### Start Triton server on each executor

In [8]:
from spark_rapids_ml.llm import TritonLLM

In [9]:
model = TritonLLM()
model.setDockerImage("tensorrt_llm_backend:dev")
model.setProtocol("grpc")
model.setOutputLen(200)

TritonLLM_4ea78573ab4a

In [10]:
print(model.explainParams())

batch_size: size of batch for inference (default: 1)
beam_width: LLM beam width (default: 1)
concurrency: Number of parallel requests to Triton Inference Server (default: 1)
docker_image: Docker image for Triton Inference Server (current: tensorrt_llm_backend:dev)
inputCol: input column name. (default: input)
model_name: Name of model to use in Triton Inference Server (default: ensemble)
model_path: Host path to model directory for Triton Inference Server (undefined)
outputCol: output column name. (default: output)
output_len: Output length (default: 10, current: 200)
prefix: Prompt prefix for LLM (undefined)
protocol: Protocol (http/grpc) used to communicate with Triton Inference Server (default: http, current: grpc)
server: Server hostname for Triton Inference Server (default: localhost)
tokenizer: Tokenizer to use for LLM (default: auto)
topk: TopK for sampling (default: 1)
topp: TopP for sampling (default: 0.0)
verbose: verbose logging (default: False)


In [11]:
model.startServers()

starting 2 server(s).


                                                                                

### Define an instruction / task

In [12]:
# Translation
# model.setPrefix("Translate the following to German without any additional English explanations")
# model.setPrefix("Translate the following to Spanish without any additional English explanations")
# model.setPrefix("Translate the following to Chinese without any additional English explanations")

# Summarization
# model.setPrefix("Summarize the following text in 20 words or less")

# Classification / Sentiment Analysis
# model.setPrefix("Classify the sentiment of the following text as either POSITIVE or NEGATIVE only, without any additional text or explanations") 
# model.setPrefix("Classify the following text as either ABUSIVE or NON-ABUSIVE only, without any additional text or explanations") 

# Content creation
# model.setPrefix("Re-phrase the following text as a pirate")

# Product recommendation
# model.setPrefix("Recommend a movie representative to the following text")

# Search / information extraction
model.setPrefix("Extract all movie titles, actors, and characters from the following text")

TritonLLM_4ea78573ab4a

### Inference w/ LLM

In [13]:
import time

start = time.time()
results = model.transform(df1).collect()
duration = time.time() - start

print(f"Wall time: {duration} s")
print(f"Throughput: {num_rows / duration} rows/s")



Wall time: 29.925328254699707 s
Throughput: 3.3416508968216654 rows/s
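Throughput here is rows divided by wall time: 100 rows / 29.93 s ≈ 3.34 rows/s. The timed region includes `collect()`, which ships every result row back to the driver. A hypothetical `timed` context manager that records the same two numbers with `time.perf_counter()` (a monotonic clock better suited to measuring intervals than `time.time()`) could look like this:

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(num_rows: int, stats: dict):
    """Time the enclosed block and store wall time and rows/s in `stats`."""
    start = time.perf_counter()  # monotonic, high-resolution clock
    try:
        yield
    finally:
        duration = time.perf_counter() - start
        stats["wall_time_s"] = duration
        stats["rows_per_s"] = num_rows / duration if duration > 0 else float("inf")


# Hypothetical usage with the notebook's variables:
# stats = {}
# with timed(num_rows, stats):
#     results = model.transform(df1).collect()
# print(f"Wall time: {stats['wall_time_s']:.2f} s, throughput: {stats['rows_per_s']:.2f} rows/s")
```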


                                                                                

In [14]:
for row in results[:10]:
    print("=" * 80)
    print(row['input'])
    print("-" * 80)
    # keep only the model's reply, i.e. the text after the closing "[/INST]" tag
    output = row['output']
    print(output.partition("[/INST]")[2].strip() + "\n")


<s>[INST] <<SYS>>
Extract all movie titles, actors, and characters from the following text
</SYS>>

The director does not know what to do with a camera... too many options and she always always always picks the wrong one... she let travolta take charge... and he controls the movie from the beginning to the end... the characters are not developed... maybe because we need to watch them singing... no pace at all, sometimes too fast sometimes too slow... miscasted: travolta OK... johansson, she is too grown up to be a 18... even if she is really 20...<br /><br />the happy ending? well it looks like that[/INST]
--------------------------------------------------------------------------------
Sure, here are the movie titles, actors, and characters mentioned in the text:

Movie Titles:

* "The director does not know what to do with a camera" (no specific title mentioned)
Actors:

* John Travolta (mentioned as "Travolta")
Characters:

* The director (not specified)
* John Travolta (playing a ch
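The extraction reply above is free-form text, but it follows a loose layout of `Header:` lines followed by `*` bullets. If that layout holds, a small heuristic parser can turn each reply into structured fields. `parse_sections` is a hypothetical helper; the LLM does not guarantee this format, and the reply above appears to be cut off (likely by the 200-token `output_len`), so some replies will parse only partially.

```python
import re
from typing import Dict, List, Optional

# Section headers such as "Movie Titles:" or "Actors:" on a line of their own.
_HEADER = re.compile(r"^\s*([A-Za-z][A-Za-z /]*):\s*$")
# Bullet items such as "* John Travolta (mentioned as ...)".
_BULLET = re.compile(r"^\s*[*\-]\s+(.*\S)\s*$")


def parse_sections(reply: str) -> Dict[str, List[str]]:
    """Group bullet items under the most recent 'Header:' line."""
    sections: Dict[str, List[str]] = {}
    current: Optional[str] = None
    for line in reply.splitlines():
        header = _HEADER.match(line)
        if header:
            current = header.group(1).strip()
            sections.setdefault(current, [])
            continue
        bullet = _BULLET.match(line)
        if bullet and current is not None:
            sections[current].append(bullet.group(1))
    return sections


# Hypothetical usage on the collected results:
# parsed = [parse_sections(row['output'].partition("[/INST]")[2]) for row in results]
```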

### Stop Triton server on each executor

In [15]:
model.stopServers()

stopping 2 server(s)


                                                                                