# PySpark Hugging Face Inferencing
### Sentence Transformers

From: https://huggingface.co/sentence-transformers

In [None]:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# Sentences we want to encode, as a list of strings (a single example here).
sentence = ['This framework generates embeddings for each input sentence']


# Encode the sentences by calling model.encode() on the list.
embedding = model.encode(sentence)

In [None]:
embedding

## PySpark

## Inference using Spark ML Model
Note: you can restart the kernel and run from this point to simulate running on a different node or in a different environment.

In [None]:
import sparkext

In [None]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

In [None]:
# only use first N examples, since this is slow
df = spark.read.parquet("imdb_test").limit(100)

In [None]:
df.show(truncate=80)

In [None]:
my_model = sparkext.huggingface.SentenceTransformerModel(model) \
                .setInputCol("lines") \
                .setOutputCol("embedding")

In [None]:
embeddings = my_model.transform(df)

In [None]:
%%time
results = embeddings.collect()

In [None]:
embeddings.show(truncate=60)

## Inference using Spark DL UDF
Note: you can restart the kernel and run from this point to simulate running on a different node or in a different environment.

### Using model instance on driver

In [None]:
from pyspark.sql.functions import col
from sparkext.huggingface import sentence_transformer_udf

In [None]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("paraphrase-MiniLM-L6-v2")

In [None]:
# only use first N examples, since this is slow
df = spark.read.parquet("imdb_test").limit(100)

In [None]:
df.schema

In [None]:
encode = sentence_transformer_udf(model)

In [None]:
embeddings = df.withColumn("encoding", encode(col("lines")))

In [None]:
%%time
results = embeddings.collect()

In [None]:
embeddings.show(truncate=60)

### Using model_id string on driver

In [None]:
from pyspark.sql.functions import col
from sparkext.huggingface import sentence_transformer_udf

In [None]:
# only use first N examples, since this is slow
df = spark.read.parquet("imdb_test").limit(100)

In [None]:
encode = sentence_transformer_udf("paraphrase-MiniLM-L6-v2")

In [None]:
embeddings = df.withColumn("encoding", encode(col("lines")))

In [None]:
%%time
results = embeddings.collect()

In [None]:
embeddings.show(truncate=60)

### Using model loader

In [None]:
from pyspark.sql.functions import col
from sparkext.huggingface import sentence_transformer_udf

In [None]:
# only use first N examples, since this is slow
df = spark.read.parquet("imdb_test").limit(100)

In [None]:
def model_loader(model_name):
    """Return a SentenceTransformer constructed from ``model_name``.

    The import lives inside the function body, so ``sentence_transformers``
    is only imported when the loader is actually called.
    """
    from sentence_transformers import SentenceTransformer

    transformer = SentenceTransformer(model_name)
    return transformer

In [None]:
encode = sentence_transformer_udf("paraphrase-MiniLM-L6-v2", model_loader=model_loader)

In [None]:
embeddings = df.withColumn("encoding", encode(col("lines")))

In [None]:
%%time
results = embeddings.collect()

In [None]:
embeddings.show(truncate=60)

## Inference using Spark DL API
Note: you can restart the kernel and run from this point to simulate running on a different node or in a different environment.

In [1]:
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, struct
from pyspark.sql.types import ArrayType, FloatType

In [2]:
# only use first N examples, since this is slow
df = spark.read.parquet("imdb_test").limit(100)

                                                                                

In [3]:
df.show(truncate=120)

                                                                                

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   lines|
+------------------------------------------------------------------------------------------------------------------------+
|...But not this one! I always wanted to know "what happened" next. We will never know for sure what happened because ...|
|I found myself getting increasingly angry as this movie progressed.<br /><br />Basically, Dr. Crawford (Dennis Hopper...|
|The comparisons between the 1995 version and this are inevitable. Sadly, this version falls far short.<br /><br />The...|
|Doesn't anyone bother to check where this kind of sludge comes from before blathering on about its supposed revelatio...|
|Don't get me wrong, I love the TV series of League Of Gentlemen. It was funny, twisted and completely inspired. I was...|
|Made it through

In [4]:
def predict_batch_fn():
    """Build the batch prediction function passed to ``predict_batch_udf``.

    The "paraphrase-MiniLM-L6-v2" SentenceTransformer is constructed once,
    when this factory is called, and captured by the returned closure.

    Returns:
        A ``predict(inputs)`` callable that flattens ``inputs`` into a list
        of sentences and returns ``model.encode`` of that list.
    """
    import numpy as np
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("paraphrase-MiniLM-L6-v2")

    def predict(inputs):
        # Flatten to a 1-D list of sentences.  reshape(-1) is used instead of
        # np.squeeze because squeeze collapses a one-row batch to a 0-d array,
        # whose .tolist() is a bare string rather than a one-element list; the
        # model would then be handed a single sentence instead of a batch of
        # one and the result would lose its batch dimension.
        # NOTE(review): assumes the batch holds a single text column — confirm.
        sentences = np.asarray(inputs).reshape(-1).tolist()
        return model.encode(sentences)

    return predict

In [5]:
encode = predict_batch_udf(predict_batch_fn,
                           return_type=ArrayType(FloatType()),
                           batch_size=10)

In [6]:
%%time
# first pass caches model/fn
embeddings = df.withColumn("encoding", encode(struct("lines")))
results = embeddings.collect()

[Stage 4:>                                                          (0 + 1) / 1]

CPU times: user 33.8 ms, sys: 0 ns, total: 33.8 ms
Wall time: 5.89 s


                                                                                

In [7]:
%%time
embeddings = df.withColumn("encoding", encode(struct("lines")))
results = embeddings.collect()

[Stage 7:>                                                          (0 + 1) / 1]

CPU times: user 7.51 ms, sys: 6.32 ms, total: 13.8 ms
Wall time: 1.09 s


                                                                                

In [8]:
%%time
embeddings = df.withColumn("encoding", encode("lines"))
results = embeddings.collect()

CPU times: user 8.51 ms, sys: 750 µs, total: 9.26 ms
Wall time: 1.14 s


                                                                                

In [9]:
%%time
embeddings = df.withColumn("encoding", encode(col("lines")))
results = embeddings.collect()

[Stage 13:>                                                         (0 + 1) / 1]

CPU times: user 10.9 ms, sys: 2.41 ms, total: 13.3 ms
Wall time: 1.06 s


                                                                                

In [10]:
embeddings.show(truncate=60)

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       lines|                                                    encoding|
+------------------------------------------------------------+------------------------------------------------------------+
|...But not this one! I always wanted to know "what happen...|[0.050629996, -0.19899222, 2.6855804E-4, 0.13270335, -0.1...|
|I found myself getting increasingly angry as this movie p...|[-0.11778694, 0.08591189, -0.036073662, 0.055232063, 0.14...|
|The comparisons between the 1995 version and this are ine...|[-0.03128382, -0.18052554, 0.024394799, -0.033730447, -0....|
|Doesn't anyone bother to check where this kind of sludge ...|[0.1475993, -0.1878961, -0.21340893, 0.06103613, 0.140383...|
|Don't get me wrong, I love the TV series of League Of Gen...|[-0.19420478, 0.11641938, 0.0198595, -0.37481567, 0.05207...|
|Made it

### Using Triton Server

#### Start Triton Server on each executor

In [11]:
# Number of partitions (and elements) in nodeRDD below.
# NOTE(review): presumably this should match the number of Spark executors in
# the cluster — confirm against the cluster configuration.
num_executors = 1

# One element per partition, so a function passed to mapPartitions() is invoked
# once per partition.
# NOTE(review): this relies on each partition being scheduled on a different
# executor, which Spark does not strictly guarantee — confirm for your cluster.
nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)

def start_triton(it):
    """Start a Triton server container on this host unless one already exists.

    Intended to be run through ``mapPartitions``.  If a Docker container whose
    name matches "spark-triton" is already listed, nothing is started and the
    function returns immediately (without checking server readiness).
    Otherwise a new container is launched and the function blocks until the
    Triton gRPC endpoint at localhost:8001 reports that the server is ready.

    Args:
        it: Partition iterator supplied by ``mapPartitions``; not consumed.

    Returns:
        ``[True]``, so that ``collect()`` yields one entry per partition.

    Raises:
        TimeoutError: If the newly started server does not become ready
            within the readiness timeout.
    """
    import docker
    import time
    import tritonclient.grpc as grpcclient

    # Upper bound on the readiness polling below, so that a container which
    # exits during start-up fails the task instead of hanging it forever.
    ready_timeout_s = 600
    poll_interval_s = 5

    docker_client = docker.from_env()
    containers = docker_client.containers.list(filters={"name": "spark-triton"})
    if containers:
        print(">>>> containers: {}".format([c.short_id for c in containers]))
    else:
        container = docker_client.containers.run(
            "nvcr.io/nvidia/tritonserver:22.07-py3", "tritonserver --model-repository=/models",
            detach=True,
            device_requests=[docker.types.DeviceRequest(device_ids=["0"], capabilities=[['gpu']])],
            environment=[
                "TRANSFORMERS_CACHE=/cache"
            ],
            name="spark-triton",
            network_mode="host",
            remove=True,
            shm_size="512M",
            volumes={
                "/home/leey/devpub/leewyang/sparkext/examples/models_hf": {"bind": "/models", "mode": "ro"},
                "/home/leey/huggingface/cache": {"bind": "/cache", "mode": "rw"}
            }
        )
        print(">>>> starting triton: {}".format(container.short_id))

        # wait for triton to be running
        time.sleep(15)
        # Separate name from docker_client: these are two unrelated clients.
        triton_client = grpcclient.InferenceServerClient("localhost:8001")
        deadline = time.monotonic() + ready_timeout_s
        ready = False
        while not ready:
            try:
                ready = triton_client.is_server_ready()
            except Exception:
                # Retry boundary: any failure while the endpoint is still
                # coming up is treated as "not ready yet".
                ready = False
            if not ready:
                if time.monotonic() >= deadline:
                    raise TimeoutError(
                        "Triton server was not ready after {} seconds".format(ready_timeout_s)
                    )
                # Sleep on every unsuccessful poll, including a clean
                # "not ready" answer, so the loop never busy-spins.
                time.sleep(poll_interval_s)

    return [True]

nodeRDD.mapPartitions(start_triton).collect()

                                                                                

[True]

#### Run inference

In [12]:
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, struct
from pyspark.sql.types import ArrayType, FloatType

In [13]:
# only use first N examples, since this is slow
df = spark.read.parquet("imdb_test").limit(100)

In [14]:
df.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   lines|
+------------------------------------------------------------------------------------------------------------------------+
|...But not this one! I always wanted to know "what happened" next. We will never know for sure what happened because ...|
|I found myself getting increasingly angry as this movie progressed.<br /><br />Basically, Dr. Crawford (Dennis Hopper...|
|The comparisons between the 1995 version and this are inevitable. Sadly, this version falls far short.<br /><br />The...|
|Doesn't anyone bother to check where this kind of sludge comes from before blathering on about its supposed revelatio...|
|Don't get me wrong, I love the TV series of League Of Gentlemen. It was funny, twisted and completely inspired. I was...|
|Made it through

In [15]:
def triton_fn(triton_uri, model_name):
    """Build a prediction function that forwards batches to a Triton server.

    A gRPC client is created and the model's metadata is fetched once, when
    this factory is called; both are captured by the returned closure.

    Args:
        triton_uri: Address passed to ``grpcclient.InferenceServerClient``.
        model_name: Name of the model to query on the Triton server.

    Returns:
        A ``predict(inputs)`` callable.  ``inputs`` is either a single numpy
        array (sent as the model's first input) or a mapping from input name
        to numpy array.  The result is a single numpy array when the model
        declares one output, otherwise a dict of output name to numpy array.
    """
    import numpy as np
    import tritonclient.grpc as grpcclient

    # Triton datatype name -> numpy dtype that inputs are cast to before being
    # sent.  np.bool_ replaces the np.bool8 alias, which NumPy 2.0 removed.
    np_types = {
        "BOOL": np.dtype(np.bool_),
        "INT8": np.dtype(np.int8),
        "INT16": np.dtype(np.int16),
        "INT32": np.dtype(np.int32),
        "INT64": np.dtype(np.int64),
        "UINT8": np.dtype(np.uint8),
        "UINT16": np.dtype(np.uint16),
        "UINT32": np.dtype(np.uint32),
        "UINT64": np.dtype(np.uint64),
        "FP16": np.dtype(np.float16),
        "FP32": np.dtype(np.float32),
        "FP64": np.dtype(np.float64),
        "BYTES": np.dtype(object),
    }

    client = grpcclient.InferenceServerClient(triton_uri)
    model_meta = client.get_model_metadata(model_name)

    def predict(inputs):
        if isinstance(inputs, np.ndarray):
            # single ndarray input: bound to the model's first declared input
            input_metas = [(model_meta.inputs[0], inputs)]
        else:
            # dict of multiple ndarray inputs, keyed by the model's input names
            input_metas = [(meta, inputs[meta.name]) for meta in model_meta.inputs]

        request = []
        for meta, tensor in input_metas:
            infer_input = grpcclient.InferInput(meta.name, tensor.shape, meta.datatype)
            infer_input.set_data_from_numpy(tensor.astype(np_types[meta.datatype]))
            request.append(infer_input)

        response = client.infer(model_name, inputs=request)

        if len(model_meta.outputs) > 1:
            # return dictionary of numpy arrays
            return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}
        else:
            # return single numpy array
            return response.as_numpy(model_meta.outputs[0].name)

    return predict

In [16]:
encode = predict_batch_udf(triton_fn,
                           triton_uri="localhost:8001",
                           model_name="hf_transformer",
                           return_type=ArrayType(FloatType()),
                           input_tensor_shapes=[[-1,1]],
                           batch_size=10)

In [17]:
%%time
# first pass caches model/fn
embeddings = df.withColumn("encoding", encode(struct("lines")))
results = embeddings.collect()

[Stage 22:>                                                         (0 + 1) / 1]

CPU times: user 24.6 ms, sys: 0 ns, total: 24.6 ms
Wall time: 2.94 s


                                                                                

In [18]:
%%time
embeddings = df.withColumn("encoding", encode(struct("lines")))
results = embeddings.collect()

CPU times: user 15.3 ms, sys: 0 ns, total: 15.3 ms
Wall time: 592 ms


In [19]:
%%time
embeddings = df.withColumn("encoding", encode("lines"))
results = embeddings.collect()

CPU times: user 10.4 ms, sys: 0 ns, total: 10.4 ms
Wall time: 615 ms


In [20]:
%%time
embeddings = df.withColumn("encoding", encode(col("lines")))
results = embeddings.collect()

CPU times: user 13.3 ms, sys: 0 ns, total: 13.3 ms
Wall time: 555 ms


In [21]:
embeddings.show(truncate=60)

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       lines|                                                    encoding|
+------------------------------------------------------------+------------------------------------------------------------+
|...But not this one! I always wanted to know "what happen...|[0.05062989, -0.19899228, 2.6863161E-4, 0.1327033, -0.160...|
|I found myself getting increasingly angry as this movie p...|[-0.11778692, 0.085911795, -0.036073525, 0.055232257, 0.1...|
|The comparisons between the 1995 version and this are ine...|[-0.03128365, -0.18052553, 0.024394818, -0.033730507, -0....|
|Doesn't anyone bother to check where this kind of sludge ...|[0.14759916, -0.18789622, -0.2134091, 0.061035916, 0.1403...|
|Don't get me wrong, I love the TV series of League Of Gen...|[-0.19420485, 0.116419286, 0.019859504, -0.37481552, 0.05...|
|Made it

#### Stop Triton Server on each executor

In [22]:
def stop_triton(it):
    """Stop the Triton server container on this host, if one is listed.

    Intended to be run through ``mapPartitions``.  Containers are looked up
    with the Docker name filter "spark-triton"; the short ids of all matches
    are printed, and the first match (if any) is stopped.

    Args:
        it: Partition iterator supplied by ``mapPartitions``; not consumed.

    Returns:
        ``[True]``, so that ``collect()`` yields one entry per partition.
    """
    import docker

    docker_client = docker.from_env()
    containers = docker_client.containers.list(filters={"name": "spark-triton"})
    print(">>>> stopping containers: {}".format([c.short_id for c in containers]))
    if containers:
        # Only the first match is stopped; allow up to 120 seconds for the
        # container to stop.
        containers[0].stop(timeout=120)

    return [True]

nodeRDD.mapPartitions(stop_triton).collect()

                                                                                

[True]

In [23]:
spark.stop()