# PySpark Huggingface Inferencing
### Conditional generation

From: https://huggingface.co/docs/transformers/model_doc/t5

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Prefix prepended to every sentence before tokenization.
task_prefix = "translate English to German: "

lines = [
    "The house is wonderful",
    "Welcome to NYC",
    "HuggingFace is a company",
]

input_sequences = [f"{task_prefix}{sentence}" for sentence in lines]

# `max_source_length` is handed to the tokenizer as `max_length` in the next
# cell; `max_target_length` is not referenced in the cells visible here.
max_source_length = 512
max_target_length = 128

# Load the pretrained "t5-small" tokenizer and model.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

In [None]:
# Tokenize the prefixed sentences as one batch, padded to the longest entry.
# `truncation=True` is needed for `max_length` to be enforced; without it the
# tokenizer does not cut over-long inputs down to the limit.
encoded = tokenizer(
    input_sequences,
    padding="longest",
    max_length=max_source_length,
    truncation=True,
    return_tensors="pt",
)
input_ids = encoded.input_ids

# Pass the attention mask explicitly so the padding positions of the shorter
# sentences are masked out during generation.
outputs = model.generate(input_ids, attention_mask=encoded.attention_mask)

In [None]:
# Turn every generated id sequence back into text, dropping special tokens.
tokenizer.batch_decode(outputs, skip_special_tokens=True)

## PySpark

In [None]:
import os
from pathlib import Path
from torchtext.datasets import IMDB

In [None]:
# load IMDB reviews (test) dataset
# Iterating `data` yields (label, text) pairs (see the conversion loop below).
data = IMDB(split='test')
# NOTE(review): `len()` assumes the object returned by torchtext is sized --
# confirm for the installed torchtext version.
len(data)

In [None]:
# Convert to a nested list of strings for PySpark: each review is wrapped in a
# one-element list so it becomes one row of a single-column DataFrame.
# The full review text is kept here; trimming to the first sentence happens
# later, in `preprocess`. The label is not needed.
lines = [[text] for _, text in data]

### Create PySpark DataFrame

In [None]:
from pyspark.sql.types import *

In [None]:
# Build a DataFrame with a single column named "lines", then redistribute it
# across 10 partitions.
df = spark.createDataFrame(lines, schema=["lines"])
df = df.repartition(10)
df.schema

In [None]:
# Fetch a single row to the driver as a quick sanity check.
df.limit(1).collect()

### Save the test dataset as parquet files

In [None]:
# Write the reviews to the "imdb_test" Parquet directory, replacing any
# output left there by a previous run.
df.write.parquet("imdb_test", mode="overwrite")

### Check arrow memory configuration

In [None]:
# Limit Arrow record batches to 512 rows each.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "512")
# This line will fail if the vectorized reader runs out of memory.
# `head(1)` always returns a list (empty for an empty DataFrame), whereas
# `head()` returns a single Row or None; taking `len()` of None would raise
# TypeError instead of reporting the message below.
assert len(df.head(1)) > 0, "`df` should not be empty"

## Inference using Spark ML Model
Note: you can restart the kernel and run from this point to simulate running on a different node or in a different environment.

In [None]:
import sparkext

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

In [None]:
# Reload the pretrained tokenizer and model so this section can run on its
# own after a kernel restart.
checkpoint = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(checkpoint)
model = T5ForConditionalGeneration.from_pretrained(checkpoint)

In [None]:
# Reload the saved reviews and keep at most 100 rows, since inference is slow.
df = spark.read.parquet("imdb_test")
df = df.limit(100)

In [None]:
# Wrap the model and tokenizer for use on a Spark DataFrame.
# The task prefix is plain text that the tokenizer treats case-sensitively,
# so use the same lowercase "translate English to German: " prefix as the
# plain Transformers example at the top of this notebook.
my_model = sparkext.huggingface.Model(
    model, tokenizer, prefix="translate English to German: "
)

In [None]:
# Apply the wrapped model to the reviews DataFrame; the result is written
# out and collected in the next cell.
predictions = my_model.transform(df)

In [None]:
# Save the predictions as text files, replacing any previous output.
# NOTE(review): the text writer expects a single string column -- confirm
# that this matches the schema produced by `my_model.transform`.
predictions.write.mode("overwrite").text("imdb_translations")
# NOTE(review): `predictions` is not cached, so this second action presumably
# re-runs inference -- confirm, and cache the DataFrame if that is too slow.
predictions.collect()

## Inference using Spark DL UDF
Note: you can restart the kernel and run from this point to simulate running on a different node or in a different environment.

In [None]:
import pandas as pd
from pyspark.sql import Column
from pyspark.sql.functions import col, pandas_udf
from sparkext.huggingface import model_udf

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

In [None]:
# Reload the pretrained tokenizer and model so this section can run on its
# own after a kernel restart.
checkpoint = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(checkpoint)
model = T5ForConditionalGeneration.from_pretrained(checkpoint)

In [None]:
# only use first sentence and add prefix for conditional generation
def preprocess(text: "Column", prefix: str = "") -> "Column":
    """Return a column holding `prefix` plus the first sentence of each row.

    Args:
        text: Spark string column to transform, e.g. ``col("lines")``.
        prefix: Text prepended to every first sentence, such as a task prefix.

    Returns:
        The column expression produced by applying a string-typed pandas UDF
        to `text`. Rows that are not strings (nulls) yield null.
    """
    # The UDF is built per call because it closes over `prefix`.
    @pandas_udf("string")
    def _preprocess(batch: pd.Series) -> pd.Series:
        # "First sentence" means everything before the first period.
        # Null rows have no `.split`, so they are passed through as None
        # instead of raising AttributeError.
        return pd.Series(
            [
                prefix + s.split(".", 1)[0] if isinstance(s, str) else None
                for s in batch
            ]
        )

    return _preprocess(text)

In [None]:
# Reload the reviews saved earlier to "imdb_test" and preview them, showing
# at most 120 characters per cell.
df = spark.read.parquet("imdb_test")
df.show(truncate=120)

In [None]:
# Build the model input column: task prefix + first sentence of each review.
# The prefix is lowercase to match the "translate English to German: " prefix
# used in the plain Transformers example at the top of this notebook; the
# tokenizer is case-sensitive.
# only use first 100 rows, since generation takes a while
df1 = (
    df.withColumn("input", preprocess(col("lines"), "translate English to German: "))
    .select("input")
    .limit(100)
)
df1.show(truncate=120)

In [None]:
# Wrap the in-memory model and tokenizer as a Spark UDF.
# An alternative is to pass a model loader function (model + tokenizer +
# kwargs) instead:
# generate = model_udf(model_loader)
#
# note: default return_type is 'string'
# NOTE(review): the keyword arguments below are presumably forwarded to the
# tokenizer / decoder by sparkext -- confirm against `model_udf`.
generate = model_udf(
    model,
    tokenizer=tokenizer,
    max_length=128,
    padding="longest",
    return_tensors="pt",
    truncation=True,
    skip_special_tokens=True,
)

In [None]:
# Apply the generation UDF to the "input" column and preview the result,
# showing at most 60 characters per cell.
preds_col = generate(col("input"))
predictions = df1.withColumn("preds", preds_col)
predictions.show(truncate=60)

In [None]:
# Same preparation as `df1`, but with the English-to-French task prefix.
# The prefix is lowercase, consistent with the "translate English to ..."
# form used in the plain Transformers example at the top of this notebook.
# only use first 100 rows, since generation takes a while
df2 = (
    df.withColumn("input", preprocess(col("lines"), "translate English to French: "))
    .select("input")
    .limit(100)
)

In [None]:
# Apply the generation UDF to the French-prefixed inputs and preview the
# result, showing at most 60 characters per cell.
preds_col = generate(col("input"))
predictions = df2.withColumn("preds", preds_col)
predictions.show(truncate=60)