Skip to content

Commit

Permalink
[spark] Support requirements.txt in model tar file
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 committed Apr 27, 2023
1 parent 4a2d5fc commit 2e4fbe8
Show file tree
Hide file tree
Showing 17 changed files with 247 additions and 127 deletions.
6 changes: 6 additions & 0 deletions docker/spark/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,16 @@ ADD --chmod=644 https://repo1.maven.org/maven2/com/google/protobuf/protobuf-java
# Set environment
ENV PYTORCH_PRECXX11 true
ENV OMP_NUM_THREADS 1
ENV DJL_CACHE_DIR /tmp/.djl.ai
ENV HUGGINGFACE_HUB_CACHE /tmp
ENV TRANSFORMERS_CACHE /tmp

RUN echo 'export SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS -Dai.djl.pytorch.graph_optimizer=false"' >> /opt/hadoop-config/spark-env.sh
RUN echo "export PYTORCH_PRECXX11=true" >> /opt/hadoop-config/spark-env.sh
RUN echo "export OMP_NUM_THREADS=1" >> /opt/hadoop-config/spark-env.sh
RUN echo "export DJL_CACHE_DIR=/tmp/.djl.ai" >> /opt/hadoop-config/spark-env.sh
RUN echo "export HUGGINGFACE_HUB_CACHE=/tmp" >> /opt/hadoop-config/spark-env.sh
RUN echo "export TRANSFORMERS_CACHE=/tmp" >> /opt/hadoop-config/spark-env.sh
RUN echo "spark.yarn.appMasterEnv.PYTORCH_PRECXX11 true" >> /opt/hadoop-config/spark-defaults.conf
RUN echo "spark.executorEnv.PYTORCH_PRECXX11 true" >> /opt/hadoop-config/spark-defaults.conf
RUN echo "spark.hadoop.fs.s3a.connection.maximum 1000" >> /opt/hadoop-config/spark-defaults.conf
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import pandas as pd
from typing import Iterator
from transformers import pipeline
import os
from ...util import files_util, dependency_util
from time import gmtime, strftime


class WhisperSpeechRecognizer:
Expand Down Expand Up @@ -54,6 +57,23 @@ def recognize(self, dataset, generate_kwargs=None, **kwargs):
model=model_name_or_url, chunk_length_s=30, **kwargs)
bc_pipe = sc.broadcast(pipe)

sc = SparkContext._active_spark_context
if self.model_url:
timestamp_suffix = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
cache_dir = os.path.join(files_util.get_cache_dir(), "cache/repo/model/audio/automatic_speech_recognition/",
self.model_name if self.model_name else "model_" + timestamp_suffix)
files_util.download_and_extract(self.model_url, cache_dir)
dependency_util.install(cache_dir)
pipe = pipeline("automatic-speech-recognition", generate_kwargs=generate_kwargs,
model=cache_dir, chunk_length_s=30, **kwargs)
bc_pipe = sc.broadcast(pipe)
elif self.model_name:
pipe = pipeline("automatic-speech-recognition", generate_kwargs=generate_kwargs,
model=self.model_name, chunk_length_s=30, **kwargs)
bc_pipe = sc.broadcast(pipe)
else:
raise ValueError("Either model_url or model_name must be provided.")

@pandas_udf(StringType())
def predict_udf(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
for s in iterator:
Expand Down
23 changes: 12 additions & 11 deletions extensions/spark/setup/djl_spark/task/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,19 @@

"""DJL Spark Tasks Text API."""

from . import huggingface_text_decoder
from . import huggingface_text_encoder
from . import huggingface_text_tokenizer
from . import text_embedder
from . import text_decoder, text_encoder, text_tokenizer, text_embedder, text2text_generator, text_generator

HuggingFaceTextDecoder = huggingface_text_decoder.HuggingFaceTextDecoder
HuggingFaceTextEncoder = huggingface_text_encoder.HuggingFaceTextEncoder
HuggingFaceTextTokenizer = huggingface_text_tokenizer.HuggingFaceTextTokenizer
TextDecoder = text_decoder.TextDecoder
TextEncoder = text_encoder.TextEncoder
TextTokenizer = text_tokenizer.TextTokenizer
TextEmbedder = text_embedder.TextEmbedder
Text2TextGenerator = text2text_generator.Text2TextGenerator
TextGenerator = text_generator.TextGenerator

# Remove unnecessary modules to avoid duplication in API.
del huggingface_text_decoder
del huggingface_text_encoder
del huggingface_text_tokenizer
del text_embedder
del text_decoder
del text_encoder
del text_tokenizer
del text_embedder
del text2text_generator
del text_generator
29 changes: 20 additions & 9 deletions extensions/spark/setup/djl_spark/task/text/text2text_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,28 @@
from pyspark.sql.types import StringType
from typing import Iterator
from transformers import pipeline
import os
from ...util import files_util, dependency_util
from time import gmtime, strftime


class Text2TextGenerator:

def __init__(self, input_col, output_col, engine, model_url=None, model_name=None):
def __init__(self, input_col, output_col, engine, model_url=None, hf_model_id=None):
"""
Initializes the Text2TextGenerator.
:param input_col: The input column
:param output_col: The output column
:param engine: The engine. Currently only PyTorch is supported.
:param model_url: The model URL
:param model_name: The model name
:param hf_model_id: The Huggingface model ID
"""
self.input_col = input_col
self.output_col = output_col
self.engine = engine
self.model_url = model_url
self.model_name = model_name
self.hf_model_id = hf_model_id

def generate(self, dataset, **kwargs):
"""
Expand All @@ -43,16 +46,24 @@ def generate(self, dataset, **kwargs):
:param dataset: input dataset
:return: output dataset
"""
if not self.model_url and not self.model_name:
raise ValueError("Either model_url or model_name must be provided.")
model_name_or_url = self.model_url if self.model_url else self.model_name
if self.model_url:
timestamp_suffix = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
cache_dir = os.path.join(files_util.get_cache_dir(), "cache/repo/model/nlp/text2text_generation/",
self.hf_model_id if self.hf_model_id else "model_" + timestamp_suffix)
files_util.download_and_extract(self.model_url, cache_dir)
dependency_util.install(cache_dir)
model_id_or_path = cache_dir
elif self.hf_model_id:
model_id_or_path = self.hf_model_id
else:
raise ValueError("Either model_url or hf_model_id must be provided.")

@pandas_udf(StringType())
def predict_udf(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
generator = pipeline('text2text-generation', model=model_name_or_url, **kwargs)
pipe = pipeline('text2text-generation', model=model_id_or_path, **kwargs)
for s in iterator:
output = generator(s.tolist())
text = [o["generated_text"] for o in output]
output = pipe(s.tolist())
text = map(lambda d: d['generated_text'], output)
yield pd.Series(text)

return dataset.withColumn(self.output_col, predict_udf(self.input_col))
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,31 @@
from pyspark.sql import DataFrame


class HuggingFaceTextDecoder:
class TextDecoder:

def __init__(self, input_col, output_col, name):
def __init__(self, input_col, output_col, hf_model_id):
"""
Initializes the HuggingFaceTextDecoder.
Initializes the TextDecoder.
:param input_col: The input column
:param output_col: The output column
:param name: The name of the tokenizer
:param hf_model_id: The Huggingface model ID
"""
self.input_col = input_col
self.output_col = output_col
self.name = name
self.hf_model_id = hf_model_id

def decode(self, dataset):
"""
Performs sentence encoding on the provided dataset.
Performs sentence decoding on the provided dataset.
:param dataset: input dataset
:return: output dataset
"""
sc = SparkContext._active_spark_context
decoder = sc._jvm.ai.djl.spark.task.text.HuggingFaceTextDecoder() \
decoder = sc._jvm.ai.djl.spark.task.text.TextDecoder() \
.setInputCol(self.input_col) \
.setOutputCol(self.output_col) \
.setName(self.name)
.setHfModelId(self.hf_model_id)
return DataFrame(decoder.decode(dataset._jdf),
dataset.sparkSession)
12 changes: 6 additions & 6 deletions extensions/spark/setup/djl_spark/task/text/text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,23 @@
class TextEmbedder:

def __init__(self, input_col, output_col, engine, model_url,
output_class=None, translator=None):
output_class=None, translator_factory=None):
"""
Initializes the TextEmbedder.
:param input_col: The input column
:param output_col: The output column
:param engine (optional): The engine
:param engine: The engine
:param model_url: The model URL
:param output_class (optional): The output class
:param translator (optional): The translator. Default is TextEmbeddingTranslator.
:param translator_factory (optional): The translator factory. Default is TextEmbeddingTranslatorFactory.
"""
self.input_col = input_col
self.output_col = output_col
self.engine = engine
self.model_url = model_url
self.output_class = output_class
self.translator = translator
self.translator_factory = translator_factory

def embed(self, dataset):
"""
Expand All @@ -47,8 +47,8 @@ def embed(self, dataset):
embedder = sc._jvm.ai.djl.spark.task.text.TextEmbedder()
if self.output_class is not None:
embedder = embedder.setOutputClass(self.output_class)
if self.translator is not None:
embedder = embedder.setTranslator(self.translator)
if self.translator_factory is not None:
embedder = embedder.setTranslatorFactory(self.translator_factory)
embedder = embedder.setInputCol(self.input_col) \
.setOutputCol(self.output_col) \
.setEngine(self.engine) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@
from pyspark.sql import DataFrame


class HuggingFaceTextEncoder:
class TextEncoder:

def __init__(self, input_col, output_col, name):
def __init__(self, input_col, output_col, hf_model_id):
"""
Initializes the HuggingFaceTextEncoder.
Initializes the TextEncoder.
:param input_col: The input column
:param output_col: The output column
:param name: The name of the tokenizer
:param hf_model_id: The Huggingface model ID
"""
self.input_col = input_col
self.output_col = output_col
self.name = name
self.hf_model_id = hf_model_id

def encode(self, dataset):
"""
Expand All @@ -37,9 +37,9 @@ def encode(self, dataset):
:return: output dataset
"""
sc = SparkContext._active_spark_context
encoder = sc._jvm.ai.djl.spark.task.text.HuggingFaceTextEncoder() \
encoder = sc._jvm.ai.djl.spark.task.text.TextEncoder() \
.setInputCol(self.input_col) \
.setOutputCol(self.output_col) \
.setName(self.name)
.setHfModelId(self.hf_model_id)
return DataFrame(encoder.encode(dataset._jdf),
dataset.sparkSession)
29 changes: 20 additions & 9 deletions extensions/spark/setup/djl_spark/task/text/text_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,28 @@
from pyspark.sql.types import StringType
from typing import Iterator
from transformers import pipeline
import os
from ...util import files_util, dependency_util
from time import gmtime, strftime


class TextGenerator:

def __init__(self, input_col, output_col, engine, model_url=None, model_name=None):
def __init__(self, input_col, output_col, engine, model_url=None, hf_model_id=None):
"""
Initializes the TextGenerator.
:param input_col: The input column
:param output_col: The output column
:param engine: The engine. Currently only PyTorch is supported.
:param model_url: The model URL
:param model_name: The model name
:param hf_model_id: The Huggingface model ID
"""
self.input_col = input_col
self.output_col = output_col
self.engine = engine
self.model_url = model_url
self.model_name = model_name
self.hf_model_id = hf_model_id

def generate(self, dataset, **kwargs):
"""
Expand All @@ -45,17 +48,25 @@ def generate(self, dataset, **kwargs):
:return: output dataset
"""
sc = SparkContext._active_spark_context
if not self.model_url and not self.model_name:
raise ValueError("Either model_url or model_name must be provided.")
model_name_or_url = self.model_url if self.model_url else self.model_name
pipe = pipeline('text-generation', model=model_name_or_url, **kwargs)
bc_pipe = sc.broadcast(pipe)
if self.model_url:
timestamp_suffix = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
cache_dir = os.path.join(files_util.get_cache_dir(), "cache/repo/model/nlp/text_generation/",
self.hf_model_id if self.hf_model_id else "model_" + timestamp_suffix)
files_util.download_and_extract(self.model_url, cache_dir)
dependency_util.install(cache_dir)
pipe = pipeline('text-generation', model=cache_dir, **kwargs)
bc_pipe = sc.broadcast(pipe)
elif self.hf_model_id:
pipe = pipeline('text-generation', model=self.hf_model_id, **kwargs)
bc_pipe = sc.broadcast(pipe)
else:
raise ValueError("Either model_url or hf_model_id must be provided.")

@pandas_udf(StringType())
def predict_udf(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
for s in iterator:
output = bc_pipe.value(s.tolist())
text = [o["generated_text"] for o in output[0]]
text = map(lambda d: d['generated_text'], output[0])
yield pd.Series(text)

return dataset.withColumn(self.output_col, predict_udf(self.input_col))
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,31 @@
from pyspark.sql import DataFrame


class HuggingFaceTextTokenizer:
class TextTokenizer:

def __init__(self, input_col, output_col, name):
def __init__(self, input_col, output_col, hf_model_id):
"""
Initializes the HuggingFaceTextEncoder.
Initializes the TextTokenizer.
:param input_col: The input column
:param output_col: The output column
:param name: The name of the tokenizer
:param hf_model_id: The Huggingface model ID
"""
self.input_col = input_col
self.output_col = output_col
self.name = name
self.hf_model_id = hf_model_id

def tokenize(self, dataset):
"""
Performs sentence encoding on the provided dataset.
Performs sentence tokenization on the provided dataset.
:param dataset: input dataset
:return: output dataset
"""
sc = SparkContext._active_spark_context
tokenizer = sc._jvm.ai.djl.spark.task.text.HuggingFaceTextTokenizer() \
tokenizer = sc._jvm.ai.djl.spark.task.text.TextTokenizer() \
.setInputCol(self.input_col) \
.setOutputCol(self.output_col) \
.setName(self.name)
.setHfModelId(self.hf_model_id)
return DataFrame(tokenizer.tokenize(dataset._jdf),
dataset.sparkSession)
12 changes: 0 additions & 12 deletions extensions/spark/setup/djl_spark/translator/__init__.py

This file was deleted.

12 changes: 0 additions & 12 deletions extensions/spark/setup/djl_spark/translator/vision/__init__.py

This file was deleted.

Loading

0 comments on commit 2e4fbe8

Please sign in to comment.