# PySpark Huggingface Inferencing
## Conditional generation

From: https://huggingface.co/docs/transformers/model_doc/t5

### Using PyTorch

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

max_source_length = 512
max_target_length = 128

task_prefix = "translate English to German: "

lines = [
    "The house is wonderful",
    "Welcome to NYC",
    "HuggingFace is a company"
]

input_sequences = [task_prefix + l for l in lines]

In [None]:
input_ids = tokenizer(input_sequences, 
                      padding="longest", 
                      max_length=max_source_length,
                      return_tensors="pt").input_ids
outputs = model.generate(input_ids)

In [None]:
[tokenizer.decode(o, skip_special_tokens=True) for o in outputs]

In [None]:
model.framework

### Using TensorFlow

In [None]:
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

In [None]:
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

max_source_length = 512
max_target_length = 128

task_prefix = "translate English to German: "

lines = [
    "The house is wonderful",
    "Welcome to NYC",
    "HuggingFace is a company"
]

input_sequences = [task_prefix + l for l in lines]

In [None]:
input_ids = tokenizer(input_sequences, 
                      padding="longest", 
                      max_length=max_source_length,
                      return_tensors="tf").input_ids
outputs = model.generate(input_ids)

In [None]:
[tokenizer.decode(o, skip_special_tokens=True) for o in outputs]

In [None]:
model.framework

## PySpark

In [None]:
import os
from pathlib import Path
from torchtext.datasets import IMDB

In [None]:
# load IMDB reviews (test) dataset
data = IMDB(split='test')
len(data)

In [None]:
# convert to nested array of string for pyspark
lines = []
for label, text in data:
    # only take first sentence of IMDB review
    lines.append([text])

### Create PySpark DataFrame

In [None]:
from pyspark.sql.types import *

In [None]:
df = spark.createDataFrame(lines, ['lines']).repartition(10)
df.schema

In [None]:
df.take(1)

### Save the test dataset as parquet files

In [None]:
df.write.mode("overwrite").parquet("imdb_test")

### Check arrow memory configuration

In [None]:
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "512")
# This line will fail if the vectorized reader runs out of memory
assert len(df.head()) > 0, "`df` should not be empty"

## Inference using Spark ML Model
Note: you can restart the kernel and run from this point to simulate running in a different node or environment.

In [None]:
import pandas as pd
import sparkext
from pyspark.sql.functions import col, pandas_udf

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

In [None]:
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

In [None]:
# only use first N examples, since this is slow
df = spark.read.parquet("imdb_test").limit(100)
df.show(truncate=120)

In [None]:
# only use first sentence and add prefix for conditional generation
def preprocess(text: pd.Series, prefix: str = "") -> pd.Series:
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        return pd.Series([prefix + s.split(".")[0] for s in text])
    return _preprocess(text)

In [None]:
# add prefix, only use first 100 rows, since generation takes a while
df1 = df.withColumn("input", preprocess(col("lines"), "Translate English to German: ")).select("input")
df1.show(truncate=120)

In [None]:
my_model = sparkext.huggingface.Model(model, tokenizer, 
                    max_length=128, padding="longest", return_tensors="pt", truncation=True, skip_special_tokens=True) \
                    .setInputCol("input") \
                    .setOutputCol("translation")

**Note**: "AutoModel from string" doesn't work here, because the T5ForConditionalGeneration model actually adds a 
language modeling head on top of the standard T5 model, where the AutoModel only loads the standard T5 model.
See: https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5ForConditionalGeneration
```
my_model = sparkext.huggingface.Model("t5-small")
```

In [None]:
predictions = my_model.transform(df1)

In [None]:
%%time
predictions.write.mode("overwrite").parquet("imdb_translations")
results = predictions.collect()

In [None]:
results[:5]

## Inference using Spark DL UDF (PyTorch)
Note: you can restart the kernel and run from this point to simulate running in a different node or environment.

In [None]:
import pandas as pd
from pyspark.sql.functions import col, pandas_udf
from sparkext.huggingface import model_udf

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

In [None]:
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

In [None]:
# only use first N examples, since this is slow
df = spark.read.parquet("imdb_test").limit(100)
df.show(truncate=120)

In [None]:
# only use first sentence and add prefix for conditional generation
def preprocess(text: pd.Series, prefix: str = "") -> pd.Series:
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        return pd.Series([prefix + s.split(".")[0] for s in text])
    return _preprocess(text)

In [None]:
# only use first 100 rows, since generation takes a while
df1 = df.withColumn("input", preprocess(col("lines"), "Translate English to German: ")).select("input").limit(100)

In [None]:
df1.show(truncate=120)

In [None]:
# note: default return_type is 'string'
generate = model_udf(model, tokenizer=tokenizer,
                     max_length=128, padding="longest", return_tensors="pt", truncation=True, skip_special_tokens=True)

In [None]:
predictions = df1.withColumn("preds", generate(col("input")))

In [None]:
predictions.show(truncate=60)

In [None]:
%%time
preds = predictions.collect()

In [None]:
# only use first 100 rows, since generation takes a while
df2 = df.withColumn("input", preprocess(col("lines"), "Translate English to French: ")).select("input").limit(100)

In [None]:
df2.show(truncate=120)

In [None]:
predictions = df2.withColumn("preds", generate(col("input")))

In [None]:
predictions.show(truncate=60)

In [None]:
%%time
preds = predictions.collect()

## Inference using Spark DL UDF (TensorFlow)
Note: you can restart the kernel and run from this point to simulate running in a different node or environment.

In [None]:
import pandas as pd
from pyspark.sql.functions import col, pandas_udf
from sparkext.huggingface import model_udf

In [None]:
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

In [None]:
# only use first N examples, since this is slow
df = spark.read.parquet("imdb_test").limit(100)
df.show(truncate=120)

In [None]:
# only use first sentence and add prefix for conditional generation
def preprocess(text: pd.Series, prefix: str = "") -> pd.Series:
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        return pd.Series([prefix + s.split(".")[0] for s in text])
    return _preprocess(text)

In [None]:
# only use first 100 rows, since generation takes a while
df1 = df.withColumn("input", preprocess(col("lines"), "Translate English to German: ")).select("input").limit(100)

In [None]:
df1.show(truncate=120)

In [None]:
# Need to use a model_loader since spark doesn't serialize this model correctly
def model_loader(model_id):
    from transformers import TFT5ForConditionalGeneration, T5Tokenizer
    model = TFT5ForConditionalGeneration.from_pretrained(model_id)
    tokenizer = T5Tokenizer.from_pretrained(model_id)
    return model, tokenizer

In [None]:
# note: default return_type for model_udf is 'string'
generate = model_udf("t5-small", tokenizer=tokenizer, model_loader=model_loader,
                     max_length=128, padding="longest", return_tensors="tf", truncation=True, skip_special_tokens=True)

In [None]:
predictions = df1.withColumn("preds", generate(col("input")))

In [None]:
predictions.show(truncate=60)

In [None]:
%%time
preds = predictions.collect()

In [None]:
# only use first 100 rows, since generation takes a while
df2 = df.withColumn("input", preprocess(col("lines"), "Translate English to French: ")).select("input").limit(100)

In [None]:
df2.show(truncate=120)

In [None]:
predictions = df2.withColumn("preds", generate(col("input")))

In [None]:
predictions.show(truncate=60)

In [None]:
%%time
preds = predictions.collect()

## Inference using Spark DL API (PyTorch)
Note: you can restart the kernel and run from this point to simulate running in a different node or environment.

In [1]:
import pandas as pd
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, pandas_udf, struct
from pyspark.sql.types import StringType

In [2]:
# only use first N examples, since this is slow
df = spark.read.parquet("imdb_test").limit(100)
df.show(truncate=120)

                                                                                

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   lines|
+------------------------------------------------------------------------------------------------------------------------+
|...But not this one! I always wanted to know "what happened" next. We will never know for sure what happened because ...|
|I found myself getting increasingly angry as this movie progressed.<br /><br />Basically, Dr. Crawford (Dennis Hopper...|
|The comparisons between the 1995 version and this are inevitable. Sadly, this version falls far short.<br /><br />The...|
|Doesn't anyone bother to check where this kind of sludge comes from before blathering on about its supposed revelatio...|
|Don't get me wrong, I love the TV series of League Of Gentlemen. It was funny, twisted and completely inspired. I was...|
|Made it through

In [3]:
# only use first sentence and add prefix for conditional generation
def preprocess(text: pd.Series, prefix: str = "") -> pd.Series:
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        return pd.Series([prefix + s.split(".")[0] for s in text])
    return _preprocess(text)

In [4]:
# only use first 100 rows, since generation takes a while
df1 = df.withColumn("input", preprocess(col("lines"), "Translate English to German: ")).select("input").limit(100)

In [5]:
df1.show(truncate=120)

[Stage 4:>                                                          (0 + 1) / 1]

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|                                                                                           Translate English to German: |
|                         Translate English to German: I found myself getting increasingly angry as this movie progressed|
|                           Translate English to German: The comparisons between the 1995 version and this are inevitable|
|Translate English to German: Doesn't anyone bother to check where this kind of sludge comes from before blathering on...|
|                            Translate English to German: Don't get me wrong, I love the TV series of League Of Gentlemen|
|           Tran

                                                                                

In [6]:
def predict_batch_fn():
    import numpy as np
    from transformers import T5ForConditionalGeneration, T5Tokenizer
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    
    def predict(inputs):
        flattened = np.squeeze(inputs).tolist()   # convert 2d numpy array of string into flattened python list
        input_ids = tokenizer(flattened, 
                              padding="longest", 
                              max_length=128,
                              return_tensors="pt").input_ids
        output_ids = model.generate(input_ids)
        string_outputs = [tokenizer.decode(o, skip_special_tokens=True) for o in output_ids]
        return string_outputs
    
    return predict

In [7]:
generate = predict_batch_udf(predict_batch_fn,
                             return_type=StringType(),
                             batch_size=10)

In [8]:
%%time
# first pass caches model/fn
predictions = df1.withColumn("preds", generate(struct("input")))
preds = predictions.collect()

[Stage 7:>                                                          (0 + 1) / 1]

CPU times: user 32.3 ms, sys: 0 ns, total: 32.3 ms
Wall time: 13.5 s


                                                                                

In [9]:
%%time
predictions = df1.withColumn("preds", generate(struct("input")))
preds = predictions.collect()

[Stage 10:>                                                         (0 + 1) / 1]

CPU times: user 18.7 ms, sys: 0 ns, total: 18.7 ms
Wall time: 5.77 s


                                                                                

In [10]:
%%time
predictions = df1.withColumn("preds", generate("input"))
preds = predictions.collect()

[Stage 13:>                                                         (0 + 1) / 1]

CPU times: user 40.1 ms, sys: 0 ns, total: 40.1 ms
Wall time: 5.67 s


                                                                                

In [11]:
%%time
predictions = df1.withColumn("preds", generate(col("input")))
preds = predictions.collect()

[Stage 16:>                                                         (0 + 1) / 1]

CPU times: user 8.55 ms, sys: 4.95 ms, total: 13.5 ms
Wall time: 5.68 s


                                                                                

In [12]:
predictions.show(truncate=60)

[Stage 19:>                                                         (0 + 1) / 1]

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|                               Translate English to German: |                                    Übersetzen Sie Englisch.|
|Translate English to German: I found myself getting incre...| Ich sah mich immer ärgerlicher, als dieser Film weiterging.|
|Translate English to German: The comparisons between the ...|Die Vergleiche zwischen der Version 1995 und diese sind u...|
|Translate English to German: Doesn't anyone bother to che...|    Warum hat man sich nicht angefreut, zu überprüfen, woher|
|Translate English to German: Don't get me wrong, I love t...|Verstehen Sie mich nicht falsch, ich liebe die TV-Serie L...|
|Transla

                                                                                

In [13]:
# only use first 100 rows, since generation takes a while
df2 = df.withColumn("input", preprocess(col("lines"), "Translate English to French: ")).select("input").limit(100)

In [14]:
df2.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|                                                                                           Translate English to French: |
|                         Translate English to French: I found myself getting increasingly angry as this movie progressed|
|                           Translate English to French: The comparisons between the 1995 version and this are inevitable|
|Translate English to French: Doesn't anyone bother to check where this kind of sludge comes from before blathering on...|
|                            Translate English to French: Don't get me wrong, I love the TV series of League Of Gentlemen|
|           Tran

In [15]:
%%time
# first pass caches model/fn
predictions = df2.withColumn("preds", generate(struct("input")))
preds = predictions.collect()

[Stage 25:>                                                         (0 + 1) / 1]

CPU times: user 12.3 ms, sys: 4.58 ms, total: 16.9 ms
Wall time: 13.5 s


                                                                                

In [16]:
%%time
predictions = df2.withColumn("preds", generate(struct("input")))
preds = predictions.collect()

[Stage 28:>                                                         (0 + 1) / 1]

CPU times: user 16 ms, sys: 3.72 ms, total: 19.7 ms
Wall time: 5.83 s


                                                                                

In [17]:
%%time
predictions = df2.withColumn("preds", generate("input"))
preds = predictions.collect()

[Stage 31:>                                                         (0 + 1) / 1]

CPU times: user 19 ms, sys: 0 ns, total: 19 ms
Wall time: 5.87 s


                                                                                

In [18]:
%%time
predictions = df2.withColumn("preds", generate(col("input")))
preds = predictions.collect()

[Stage 34:>                                                         (0 + 1) / 1]

CPU times: user 12.7 ms, sys: 0 ns, total: 12.7 ms
Wall time: 5.84 s


                                                                                

In [19]:
predictions.show(truncate=60)

[Stage 37:>                                                         (0 + 1) / 1]

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|                               Translate English to French: |                                                           :|
|Translate English to French: I found myself getting incre...|  Je me suis rendu de plus en plus en colère à mesure que ce|
|Translate English to French: The comparisons between the ...|Les comparaisons entre la version de 1995 et cette versio...|
|Translate English to French: Doesn't anyone bother to che...|          Ne s'agit-il pas de vérifier où viennent ces boues|
|Translate English to French: Don't get me wrong, I love t...|Ne m'oubliez pas, je m'aime la série de télévision de League|
|Transla

                                                                                

### Using Triton Server

#### Start Triton Server on each executor

In [20]:
num_executors = 1

nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)

def start_triton(it):
    import docker
    import time
    import tritonclient.grpc as grpcclient
    
    client=docker.from_env()
    containers=client.containers.list(filters={"name": "spark-triton"})
    if containers:
        print(">>>> containers: {}".format([c.short_id for c in containers]))
    else:
        container=client.containers.run(
            "nvcr.io/nvidia/tritonserver:22.07-py3", "tritonserver --model-repository=/models",
            detach=True,
            device_requests=[docker.types.DeviceRequest(device_ids=["0"], capabilities=[['gpu']])],
            environment=[
                "TRANSFORMERS_CACHE=/cache"
            ],
            name="spark-triton",
            network_mode="host",
            remove=True,
            shm_size="256M",
            volumes={
                "/home/leey/devpub/leewyang/sparkext/examples/models_hf": {"bind": "/models", "mode": "ro"},
                "/home/leey/huggingface/cache": {"bind": "/cache", "mode": "rw"}
            }
        )
        print(">>>> starting triton: {}".format(container.short_id))

        # wait for triton to be running
        time.sleep(15)
        client = grpcclient.InferenceServerClient("localhost:8001")
        ready = False
        while not ready:
            try:
                ready = client.is_server_ready()
            except Exception as e:
                time.sleep(5)

    return [True]

nodeRDD.mapPartitions(start_triton).collect()

                                                                                

[True]

#### Run inference

In [21]:
import pandas as pd
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, pandas_udf, struct
from pyspark.sql.types import StringType

In [22]:
# only use first N examples, since this is slow
df = spark.read.parquet("imdb_test").limit(100)
df.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   lines|
+------------------------------------------------------------------------------------------------------------------------+
|...But not this one! I always wanted to know "what happened" next. We will never know for sure what happened because ...|
|I found myself getting increasingly angry as this movie progressed.<br /><br />Basically, Dr. Crawford (Dennis Hopper...|
|The comparisons between the 1995 version and this are inevitable. Sadly, this version falls far short.<br /><br />The...|
|Doesn't anyone bother to check where this kind of sludge comes from before blathering on about its supposed revelatio...|
|Don't get me wrong, I love the TV series of League Of Gentlemen. It was funny, twisted and completely inspired. I was...|
|Made it through

In [23]:
# only use first sentence and add prefix for conditional generation
def preprocess(text: pd.Series, prefix: str = "") -> pd.Series:
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        return pd.Series([prefix + s.split(".")[0] for s in text])
    return _preprocess(text)

In [24]:
# only use first 100 rows, since generation takes a while
df1 = df.withColumn("input", preprocess(col("lines"), "Translate English to German: ")).select("input").limit(100)

In [25]:
df1.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|                                                                                           Translate English to German: |
|                         Translate English to German: I found myself getting increasingly angry as this movie progressed|
|                           Translate English to German: The comparisons between the 1995 version and this are inevitable|
|Translate English to German: Doesn't anyone bother to check where this kind of sludge comes from before blathering on...|
|                            Translate English to German: Don't get me wrong, I love the TV series of League Of Gentlemen|
|           Tran

In [26]:
def triton_fn(triton_uri, model_name):
    import numpy as np
    import tritonclient.grpc as grpcclient
    
    np_types = {
      "BOOL": np.dtype(np.bool8),
      "INT8": np.dtype(np.int8),
      "INT16": np.dtype(np.int16),
      "INT32": np.dtype(np.int32),
      "INT64": np.dtype(np.int64),
      "FP16": np.dtype(np.float16),
      "FP32": np.dtype(np.float32),
      "FP64": np.dtype(np.float64),
      "FP64": np.dtype(np.double),
      "BYTES": np.dtype(object)
    }

    client = grpcclient.InferenceServerClient(triton_uri)
    model_meta = client.get_model_metadata(model_name)
    
    def predict(inputs):
        if isinstance(inputs, np.ndarray):
            # single ndarray input
            request = [grpcclient.InferInput(model_meta.inputs[0].name, inputs.shape, model_meta.inputs[0].datatype)]
            request[0].set_data_from_numpy(inputs.astype(np_types[model_meta.inputs[0].datatype]))
        else:
            # dict of multiple ndarray inputs
            request = [grpcclient.InferInput(i.name, inputs[i.name].shape, i.datatype) for i in model_meta.inputs]
            for i in request:
                i.set_data_from_numpy(inputs[i.name()].astype(np_types[i.datatype()]))
        
        response = client.infer(model_name, inputs=request)
        
        if len(model_meta.outputs) > 1:
            # return dictionary of numpy arrays
            return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}
        else:
            # return single numpy array
            return response.as_numpy(model_meta.outputs[0].name)
        
    return predict

In [27]:
generate = predict_batch_udf(triton_fn,
                             triton_uri="localhost:8001",
                             model_name="hf_generation",
                             return_type=StringType(),
                             input_tensor_shapes=[[-1,1]],
                             batch_size=10)

In [28]:
%%time
# first pass caches model/fn
predictions = df1.withColumn("preds", generate(struct("input")))
preds = predictions.collect()

[Stage 46:>                                                         (0 + 1) / 1]

CPU times: user 13.2 ms, sys: 20 ms, total: 33.2 ms
Wall time: 6.7 s


                                                                                

In [29]:
%%time
predictions = df1.withColumn("preds", generate(struct("input")))
preds = predictions.collect()

[Stage 49:>                                                         (0 + 1) / 1]

CPU times: user 13.8 ms, sys: 1.12 ms, total: 15 ms
Wall time: 5.97 s


                                                                                

In [30]:
%%time
predictions = df1.withColumn("preds", generate("input"))
preds = predictions.collect()

[Stage 52:>                                                         (0 + 1) / 1]

CPU times: user 10.6 ms, sys: 3.72 ms, total: 14.4 ms
Wall time: 5.92 s


                                                                                

In [31]:
%%time
predictions = df1.withColumn("preds", generate(col("input")))
preds = predictions.collect()

[Stage 55:>                                                         (0 + 1) / 1]

CPU times: user 13 ms, sys: 1.5 ms, total: 14.5 ms
Wall time: 6.05 s


                                                                                

In [32]:
predictions.show(truncate=60)

[Stage 58:>                                                         (0 + 1) / 1]

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|                               Translate English to German: |                                    Übersetzen Sie Englisch.|
|Translate English to German: I found myself getting incre...| Ich sah mich immer ärgerlicher, als dieser Film weiterging.|
|Translate English to German: The comparisons between the ...|Die Vergleiche zwischen der Version 1995 und diese sind u...|
|Translate English to German: Doesn't anyone bother to che...|    Warum hat man sich nicht angefreut, zu überprüfen, woher|
|Translate English to German: Don't get me wrong, I love t...|Verstehen Sie mich nicht falsch, ich liebe die TV-Serie L...|
|Transla

                                                                                

In [33]:
# only use first 100 rows, since generation takes a while
df2 = df.withColumn("input", preprocess(col("lines"), "Translate English to French: ")).select("input").limit(100)

In [34]:
df2.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|                                                                                           Translate English to French: |
|                         Translate English to French: I found myself getting increasingly angry as this movie progressed|
|                           Translate English to French: The comparisons between the 1995 version and this are inevitable|
|Translate English to French: Doesn't anyone bother to check where this kind of sludge comes from before blathering on...|
|                            Translate English to French: Don't get me wrong, I love the TV series of League Of Gentlemen|
|           Tran

In [35]:
%%time
predictions = df2.withColumn("preds", generate(struct("input")))
preds = predictions.collect()

[Stage 64:>                                                         (0 + 1) / 1]

CPU times: user 12.6 ms, sys: 881 µs, total: 13.5 ms
Wall time: 6.57 s


                                                                                

In [36]:
%%time
predictions = df2.withColumn("preds", generate("input"))
preds = predictions.collect()

[Stage 67:>                                                         (0 + 1) / 1]

CPU times: user 15.7 ms, sys: 4.94 ms, total: 20.7 ms
Wall time: 5.95 s


                                                                                

In [37]:
%%time
predictions = df2.withColumn("preds", generate(col("input")))
preds = predictions.collect()

[Stage 70:>                                                         (0 + 1) / 1]

CPU times: user 21 ms, sys: 1.27 ms, total: 22.2 ms
Wall time: 5.94 s


                                                                                

In [38]:
predictions.show(truncate=60)

[Stage 73:>                                                         (0 + 1) / 1]

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|                               Translate English to French: |                                                           :|
|Translate English to French: I found myself getting incre...|  Je me suis rendu de plus en plus en colère à mesure que ce|
|Translate English to French: The comparisons between the ...|Les comparaisons entre la version de 1995 et cette versio...|
|Translate English to French: Doesn't anyone bother to che...|          Ne s'agit-il pas de vérifier où viennent ces boues|
|Translate English to French: Don't get me wrong, I love t...|Ne m'oubliez pas, je m'aime la série de télévision de League|
|Transla

                                                                                

#### Stop Triton Server on each executor

In [39]:
def stop_triton(it):
    import docker
    import time
    
    client=docker.from_env()
    containers=client.containers.list(filters={"name": "spark-triton"})
    print(">>>> stopping containers: {}".format([c.short_id for c in containers]))
    if containers:
        container=containers[0]
        container.stop(timeout=120)

    return [True]

nodeRDD.mapPartitions(stop_triton).collect()

                                                                                

[True]

In [40]:
spark.stop()