# PySpark Huggingface Inferencing
### Sentence Transformers

From: https://huggingface.co/sentence-transformers

In [None]:
from sentence_transformers import SentenceTransformer

# Load the model identified by "paraphrase-MiniLM-L6-v2".
model = SentenceTransformer("paraphrase-MiniLM-L6-v2")

# The sentences we want to encode; here a list holding a single example.
sentence = ["This framework generates embeddings for each input sentence"]

# Sentences are encoded by calling model.encode().
embedding = model.encode(sentence)

In [None]:
embedding

## PySpark

## Inference using Spark ML Model
Note: you can restart the kernel and run from this point to simulate running on a different node or in a different environment.

In [None]:
import sparkext

In [None]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

In [None]:
# Only use the first 100 rows of the "imdb_test" parquet data, since this is slow.
df = spark.read.parquet("imdb_test").limit(100)

In [None]:
df.show(truncate=80)

In [None]:
# Wrap the SentenceTransformer instance for the Spark ML Model approach,
# configured with input column "lines" and output column "embedding".
my_model = (
    sparkext.huggingface.SentenceTransformerModel(model)
    .setInputCol("lines")
    .setOutputCol("embedding")
)

In [None]:
embeddings = my_model.transform(df)

In [None]:
%%time
results = embeddings.collect()

In [None]:
embeddings.show(truncate=60)

## Inference using Spark DL UDF
Note: you can restart the kernel and run from this point to simulate running on a different node or in a different environment.

### Using model instance on driver

In [None]:
from pyspark.sql.functions import col
from sparkext.huggingface import sentence_transformer_udf

In [None]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("paraphrase-MiniLM-L6-v2")

In [None]:
# Only use the first 100 rows of the "imdb_test" parquet data, since this is slow.
df = spark.read.parquet("imdb_test").limit(100)

In [None]:
df.schema

In [None]:
encode = sentence_transformer_udf(model)

In [None]:
embeddings = df.withColumn("encoding", encode(col("lines")))

In [None]:
%%time
results = embeddings.collect()

In [None]:
embeddings.show(truncate=60)

### Using model_id string on driver

In [None]:
from pyspark.sql.functions import col
from sparkext.huggingface import sentence_transformer_udf

In [None]:
# Only use the first 100 rows of the "imdb_test" parquet data, since this is slow.
df = spark.read.parquet("imdb_test").limit(100)

In [None]:
encode = sentence_transformer_udf("paraphrase-MiniLM-L6-v2")

In [None]:
embeddings = df.withColumn("encoding", encode(col("lines")))

In [None]:
%%time
results = embeddings.collect()

In [None]:
embeddings.show(truncate=60)

### Using model loader

In [None]:
from pyspark.sql.functions import col
from sparkext.huggingface import sentence_transformer_udf

In [None]:
# Only use the first 100 rows of the "imdb_test" parquet data, since this is slow.
df = spark.read.parquet("imdb_test").limit(100)

In [None]:
def model_loader(model_name):
    """Build and return a SentenceTransformer for ``model_name``.

    The ``sentence_transformers`` import lives inside the function body, so
    it is executed when the loader is called rather than when it is defined.
    """
    from sentence_transformers import SentenceTransformer

    loaded = SentenceTransformer(model_name)
    return loaded

In [None]:
encode = sentence_transformer_udf("paraphrase-MiniLM-L6-v2", model_loader=model_loader)

In [None]:
embeddings = df.withColumn("encoding", encode(col("lines")))

In [None]:
%%time
results = embeddings.collect()

In [None]:
embeddings.show(truncate=60)

## Inference using Spark DL API
Note: you can restart the kernel and run from this point to simulate running on a different node or in a different environment.

In [1]:
from pyspark.ml.udf import model_udf
from pyspark.sql.functions import col, struct
from pyspark.sql.types import ArrayType, FloatType

In [2]:
# Only use the first 100 rows of the "imdb_test" parquet data, since this is slow.
df = spark.read.parquet("imdb_test").limit(100)

                                                                                

In [3]:
df.show(truncate=120)

[Stage 1:>                                                          (0 + 1) / 1]

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   lines|
+------------------------------------------------------------------------------------------------------------------------+
|...But not this one! I always wanted to know "what happened" next. We will never know for sure what happened because ...|
|I found myself getting increasingly angry as this movie progressed.<br /><br />Basically, Dr. Crawford (Dennis Hopper...|
|The comparisons between the 1995 version and this are inevitable. Sadly, this version falls far short.<br /><br />The...|
|Doesn't anyone bother to check where this kind of sludge comes from before blathering on about its supposed revelatio...|
|Don't get me wrong, I love the TV series of League Of Gentlemen. It was funny, twisted and completely inspired. I was...|
|Made it through

                                                                                

In [4]:
def model_fn():
    """Load the sentence-transformer model and return a batch predict function.

    The imports and the model construction happen inside this function, so
    they run wherever ``model_fn`` is called.

    Returns:
        A callable ``predict(inputs)`` that encodes a batch of sentences with
        the "paraphrase-MiniLM-L6-v2" model and returns the result of
        ``model.encode``.
    """
    import numpy as np
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("paraphrase-MiniLM-L6-v2")

    def predict(inputs):
        """Encode one batch of sentences.

        NOTE(review): assumes ``inputs`` is a (batch, 1) array of strings,
        matching the ``[[-1, 1]]`` input shape given to ``model_udf`` --
        confirm against the caller.
        """
        # Flatten to a 1-D list of sentences. A bare np.squeeze() would also
        # drop the batch dimension when the batch holds a single row, turning
        # the input into a plain str; model.encode() would then return one
        # 1-D embedding instead of a batch containing one embedding.
        sentences = np.asarray(inputs).reshape(-1).tolist()
        return model.encode(sentences)

    return predict

In [5]:
# Build the UDF from model_fn: results are returned as arrays of floats and
# rows are processed with a batch size of 10.
# NOTE(review): [-1, 1] presumably means a variable-size batch with one value
# per row -- confirm against the model_udf documentation.
encode = model_udf(
    model_fn,
    input_tensor_shapes=[[-1, 1]],
    return_type=ArrayType(FloatType()),
    batch_size=10,
)

In [6]:
embeddings = df.withColumn("encoding", encode(struct("lines")))

In [7]:
embeddings.show(truncate=60)

[Stage 4:>                                                          (0 + 1) / 1]

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       lines|                                                    encoding|
+------------------------------------------------------------+------------------------------------------------------------+
|...But not this one! I always wanted to know "what happen...|[0.050629996, -0.19899222, 2.6855804E-4, 0.13270335, -0.1...|
|I found myself getting increasingly angry as this movie p...|[-0.11778694, 0.08591189, -0.036073662, 0.055232063, 0.14...|
|The comparisons between the 1995 version and this are ine...|[-0.03128382, -0.18052554, 0.024394799, -0.033730447, -0....|
|Doesn't anyone bother to check where this kind of sludge ...|[0.1475993, -0.1878961, -0.21340893, 0.06103613, 0.140383...|
|Don't get me wrong, I love the TV series of League Of Gen...|[-0.19420478, 0.11641938, 0.0198595, -0.37481567, 0.05207...|
|Made it

                                                                                

In [8]:
%%time
results = embeddings.collect()

[Stage 7:>                                                          (0 + 1) / 1]

CPU times: user 12.6 ms, sys: 5.2 ms, total: 17.8 ms
Wall time: 5.2 s


                                                                                

In [9]:
spark.stop()