From 2e4fbe858d38694c454acafa39f5d5df2713dc09 Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Thu, 27 Apr 2023 14:23:02 -0700 Subject: [PATCH] [spark] Support requirements.txt in model tar file --- docker/spark/Dockerfile | 6 ++ .../task/audio/whisper_speech_recognizer.py | 20 +++++ .../setup/djl_spark/task/text/__init__.py | 23 +++--- .../task/text/text2text_generator.py | 29 ++++--- ...ngface_text_decoder.py => text_decoder.py} | 16 ++-- .../djl_spark/task/text/text_embedder.py | 12 +-- ...ngface_text_encoder.py => text_encoder.py} | 14 ++-- .../djl_spark/task/text/text_generator.py | 29 ++++--- ...ce_text_tokenizer.py => text_tokenizer.py} | 16 ++-- .../setup/djl_spark/translator/__init__.py | 12 --- .../djl_spark/translator/vision/__init__.py | 12 --- .../djl_spark/translator/vision/translator.py | 24 ------ .../setup/djl_spark/util/dependency_util.py | 38 +++++++++ .../spark/setup/djl_spark/util/files_util.py | 81 +++++++++++++++++++ ...aceTextDecoder.scala => TextDecoder.scala} | 14 ++-- ...aceTextEncoder.scala => TextEncoder.scala} | 14 ++-- ...extTokenizer.scala => TextTokenizer.scala} | 14 ++-- 17 files changed, 247 insertions(+), 127 deletions(-) rename extensions/spark/setup/djl_spark/task/text/{huggingface_text_decoder.py => text_decoder.py} (76%) rename extensions/spark/setup/djl_spark/task/text/{huggingface_text_encoder.py => text_encoder.py} (79%) rename extensions/spark/setup/djl_spark/task/text/{huggingface_text_tokenizer.py => text_tokenizer.py} (75%) delete mode 100644 extensions/spark/setup/djl_spark/translator/__init__.py delete mode 100644 extensions/spark/setup/djl_spark/translator/vision/__init__.py delete mode 100644 extensions/spark/setup/djl_spark/translator/vision/translator.py create mode 100644 extensions/spark/setup/djl_spark/util/dependency_util.py create mode 100644 extensions/spark/setup/djl_spark/util/files_util.py rename extensions/spark/src/main/scala/ai/djl/spark/task/text/{HuggingFaceTextDecoder.scala => TextDecoder.scala} (83%) rename extensions/spark/src/main/scala/ai/djl/spark/task/text/{HuggingFaceTextEncoder.scala => TextEncoder.scala} (85%) rename extensions/spark/src/main/scala/ai/djl/spark/task/text/{HuggingFaceTextTokenizer.scala => TextTokenizer.scala} (83%) diff --git a/docker/spark/Dockerfile b/docker/spark/Dockerfile index eb6913bf19a..54edbaa6681 100644 --- a/docker/spark/Dockerfile +++ b/docker/spark/Dockerfile @@ -53,10 +53,16 @@ ADD --chmod=644 https://repo1.maven.org/maven2/com/google/protobuf/protobuf-java # Set environment ENV PYTORCH_PRECXX11 true ENV OMP_NUM_THREADS 1 +ENV DJL_CACHE_DIR /tmp/.djl.ai +ENV HUGGINGFACE_HUB_CACHE /tmp +ENV TRANSFORMERS_CACHE /tmp RUN echo 'export SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS -Dai.djl.pytorch.graph_optimizer=false"' >> /opt/hadoop-config/spark-env.sh RUN echo "export PYTORCH_PRECXX11=true" >> /opt/hadoop-config/spark-env.sh RUN echo "export OMP_NUM_THREADS=1" >> /opt/hadoop-config/spark-env.sh +RUN echo "export DJL_CACHE_DIR=/tmp/.djl.ai" >> /opt/hadoop-config/spark-env.sh +RUN echo "export HUGGINGFACE_HUB_CACHE=/tmp" >> /opt/hadoop-config/spark-env.sh +RUN echo "export TRANSFORMERS_CACHE=/tmp" >> /opt/hadoop-config/spark-env.sh RUN echo "spark.yarn.appMasterEnv.PYTORCH_PRECXX11 true" >> /opt/hadoop-config/spark-defaults.conf RUN echo "spark.executorEnv.PYTORCH_PRECXX11 true" >> /opt/hadoop-config/spark-defaults.conf RUN echo "spark.hadoop.fs.s3a.connection.maximum 1000" >> /opt/hadoop-config/spark-defaults.conf diff --git a/extensions/spark/setup/djl_spark/task/audio/whisper_speech_recognizer.py b/extensions/spark/setup/djl_spark/task/audio/whisper_speech_recognizer.py index 1b2c0c4a788..ada64f45c6c 100644 --- a/extensions/spark/setup/djl_spark/task/audio/whisper_speech_recognizer.py +++ b/extensions/spark/setup/djl_spark/task/audio/whisper_speech_recognizer.py @@ -19,6 +19,9 @@ import pandas as pd from typing import Iterator from transformers import pipeline +import os +from ...util import files_util, dependency_util +from time import gmtime, strftime class WhisperSpeechRecognizer: @@ -54,6 +57,23 @@ def recognize(self, dataset, generate_kwargs=None, **kwargs): model=model_name_or_url, chunk_length_s=30, **kwargs) bc_pipe = sc.broadcast(pipe) + sc = SparkContext._active_spark_context + if self.model_url: + timestamp_suffix = strftime("%Y-%m-%d-%H-%M-%S", gmtime()) + cache_dir = os.path.join(files_util.get_cache_dir(), "cache/repo/model/audio/automatic_speech_recognition/", + self.model_name if self.model_name else "model_" + timestamp_suffix) + files_util.download_and_extract(self.model_url, cache_dir) + dependency_util.install(cache_dir) + pipe = pipeline("automatic-speech-recognition", generate_kwargs=generate_kwargs, + model=cache_dir, chunk_length_s=30, **kwargs) + bc_pipe = sc.broadcast(pipe) + elif self.model_name: + pipe = pipeline("automatic-speech-recognition", generate_kwargs=generate_kwargs, + model=self.model_name, chunk_length_s=30, **kwargs) + bc_pipe = sc.broadcast(pipe) + else: + raise ValueError("Either model_url or model_name must be provided.") + @pandas_udf(StringType()) def predict_udf(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: for s in iterator: diff --git a/extensions/spark/setup/djl_spark/task/text/__init__.py b/extensions/spark/setup/djl_spark/task/text/__init__.py index 4b581a897bd..f8aaba072f4 100644 --- a/extensions/spark/setup/djl_spark/task/text/__init__.py +++ b/extensions/spark/setup/djl_spark/task/text/__init__.py @@ -13,18 +13,19 @@ """DJL Spark Tasks Text API.""" -from . import huggingface_text_decoder -from . import huggingface_text_encoder -from . import huggingface_text_tokenizer -from . import text_embedder +from . import text_decoder, text_encoder, text_tokenizer, text_embedder, text2text_generator, text_generator -HuggingFaceTextDecoder = huggingface_text_decoder.HuggingFaceTextDecoder -HuggingFaceTextEncoder = huggingface_text_encoder.HuggingFaceTextEncoder -HuggingFaceTextTokenizer = huggingface_text_tokenizer.HuggingFaceTextTokenizer +TextDecoder = text_decoder.TextDecoder +TextEncoder = text_encoder.TextEncoder +TextTokenizer = text_tokenizer.TextTokenizer TextEmbedder = text_embedder.TextEmbedder +Text2TextGenerator = text2text_generator.Text2TextGenerator +TextGenerator = text_generator.TextGenerator # Remove unnecessary modules to avoid duplication in API. -del huggingface_text_decoder -del huggingface_text_encoder -del huggingface_text_tokenizer -del text_embedder \ No newline at end of file +del text_decoder +del text_encoder +del text_tokenizer +del text_embedder +del text2text_generator +del text_generator \ No newline at end of file diff --git a/extensions/spark/setup/djl_spark/task/text/text2text_generator.py b/extensions/spark/setup/djl_spark/task/text/text2text_generator.py index b5690bbcebb..83c1a87010a 100644 --- a/extensions/spark/setup/djl_spark/task/text/text2text_generator.py +++ b/extensions/spark/setup/djl_spark/task/text/text2text_generator.py @@ -16,11 +16,14 @@ from pyspark.sql.types import StringType from typing import Iterator from transformers import pipeline +import os +from ...util import files_util, dependency_util +from time import gmtime, strftime class Text2TextGenerator: - def __init__(self, input_col, output_col, engine, model_url=None, model_name=None): + def __init__(self, input_col, output_col, engine, model_url=None, hf_model_id=None): """ Initializes the Text2TextGenerator. @@ -28,13 +31,13 @@ def __init__(self, input_col, output_col, engine, model_url=None, model_name=Non :param output_col: The output column :param engine: The engine. Currently only PyTorch is supported. :param model_url: The model URL - :param model_name: The model name + :param hf_model_id: The Huggingface model ID """ self.input_col = input_col self.output_col = output_col self.engine = engine self.model_url = model_url - self.model_name = model_name + self.hf_model_id = hf_model_id def generate(self, dataset, **kwargs): """ @@ -43,16 +46,24 @@ def generate(self, dataset, **kwargs): :param dataset: input dataset :return: output dataset """ - if not self.model_url and not self.model_name: - raise ValueError("Either model_url or model_name must be provided.") - model_name_or_url = self.model_url if self.model_url else self.model_name + if self.model_url: + timestamp_suffix = strftime("%Y-%m-%d-%H-%M-%S", gmtime()) + cache_dir = os.path.join(files_util.get_cache_dir(), "cache/repo/model/nlp/text2text_generation/", + self.hf_model_id if self.hf_model_id else "model_" + timestamp_suffix) + files_util.download_and_extract(self.model_url, cache_dir) + dependency_util.install(cache_dir) + model_id_or_path = cache_dir + elif self.hf_model_id: + model_id_or_path = self.hf_model_id + else: + raise ValueError("Either model_url or hf_model_id must be provided.") @pandas_udf(StringType()) def predict_udf(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: - generator = pipeline('text2text-generation', model=model_name_or_url, **kwargs) + pipe = pipeline('text2text-generation', model=model_id_or_path, **kwargs) for s in iterator: - output = generator(s.tolist()) - text = [o["generated_text"] for o in output] + output = pipe(s.tolist()) + text = map(lambda d: d['generated_text'], output) yield pd.Series(text) return dataset.withColumn(self.output_col, predict_udf(self.input_col)) diff --git a/extensions/spark/setup/djl_spark/task/text/huggingface_text_decoder.py b/extensions/spark/setup/djl_spark/task/text/text_decoder.py similarity index 76% rename from extensions/spark/setup/djl_spark/task/text/huggingface_text_decoder.py rename to extensions/spark/setup/djl_spark/task/text/text_decoder.py index a10d8e29b48..8112b6cff1d 100644 --- a/extensions/spark/setup/djl_spark/task/text/huggingface_text_decoder.py +++ b/extensions/spark/setup/djl_spark/task/text/text_decoder.py @@ -15,31 +15,31 @@ from pyspark.sql import DataFrame -class HuggingFaceTextDecoder: +class TextDecoder: - def __init__(self, input_col, output_col, name): + def __init__(self, input_col, output_col, hf_model_id): """ - Initializes the HuggingFaceTextDecoder. + Initializes the TextDecoder. :param input_col: The input column :param output_col: The output column - :param name: The name of the tokenizer + :param hf_model_id: The Huggingface model ID """ self.input_col = input_col self.output_col = output_col - self.name = name + self.hf_model_id = hf_model_id def decode(self, dataset): """ - Performs sentence encoding on the provided dataset. + Performs sentence decoding on the provided dataset. :param dataset: input dataset :return: output dataset """ sc = SparkContext._active_spark_context - decoder = sc._jvm.ai.djl.spark.task.text.HuggingFaceTextDecoder() \ + decoder = sc._jvm.ai.djl.spark.task.text.TextDecoder() \ .setInputCol(self.input_col) \ .setOutputCol(self.output_col) \ - .setName(self.name) + .setHfModelId(self.hf_model_id) return DataFrame(decoder.decode(dataset._jdf), dataset.sparkSession) diff --git a/extensions/spark/setup/djl_spark/task/text/text_embedder.py b/extensions/spark/setup/djl_spark/task/text/text_embedder.py index 27c3e300748..33bc32684f7 100644 --- a/extensions/spark/setup/djl_spark/task/text/text_embedder.py +++ b/extensions/spark/setup/djl_spark/task/text/text_embedder.py @@ -18,23 +18,23 @@ class TextEmbedder: def __init__(self, input_col, output_col, engine, model_url, - output_class=None, translator=None): + output_class=None, translator_factory=None): """ Initializes the TextEmbedder. :param input_col: The input column :param output_col: The output column - :param engine (optional): The engine + :param engine: The engine :param model_url: The model URL :param output_class (optional): The output class - :param translator (optional): The translator. Default is TextEmbeddingTranslator. + :param translator_factory (optional): The translator factory. Default is TextEmbeddingTranslatorFactory. """ self.input_col = input_col self.output_col = output_col self.engine = engine self.model_url = model_url self.output_class = output_class - self.translator = translator + self.translator_factory = translator_factory def embed(self, dataset): """ @@ -47,8 +47,8 @@ def embed(self, dataset): embedder = sc._jvm.ai.djl.spark.task.text.TextEmbedder() if self.output_class is not None: embedder = embedder.setOutputClass(self.output_class) - if self.translator is not None: - embedder = embedder.setTranslator(self.translator) + if self.translator_factory is not None: + embedder = embedder.setTranslatorFactory(self.translator_factory) embedder = embedder.setInputCol(self.input_col) \ .setOutputCol(self.output_col) \ .setEngine(self.engine) \ diff --git a/extensions/spark/setup/djl_spark/task/text/huggingface_text_encoder.py b/extensions/spark/setup/djl_spark/task/text/text_encoder.py similarity index 79% rename from extensions/spark/setup/djl_spark/task/text/huggingface_text_encoder.py rename to extensions/spark/setup/djl_spark/task/text/text_encoder.py index 8b7c35cd118..44378e5930a 100644 --- a/extensions/spark/setup/djl_spark/task/text/huggingface_text_encoder.py +++ b/extensions/spark/setup/djl_spark/task/text/text_encoder.py @@ -15,19 +15,19 @@ from pyspark.sql import DataFrame -class HuggingFaceTextEncoder: +class TextEncoder: - def __init__(self, input_col, output_col, name): + def __init__(self, input_col, output_col, hf_model_id): """ - Initializes the HuggingFaceTextEncoder. + Initializes the TextEncoder. :param input_col: The input column :param output_col: The output column - :param name: The name of the tokenizer + :param hf_model_id: The Huggingface model ID """ self.input_col = input_col self.output_col = output_col - self.name = name + self.hf_model_id = hf_model_id def encode(self, dataset): """ @@ -37,9 +37,9 @@ def encode(self, dataset): :return: output dataset """ sc = SparkContext._active_spark_context - encoder = sc._jvm.ai.djl.spark.task.text.HuggingFaceTextEncoder() \ + encoder = sc._jvm.ai.djl.spark.task.text.TextEncoder() \ .setInputCol(self.input_col) \ .setOutputCol(self.output_col) \ - .setName(self.name) + .setHfModelId(self.hf_model_id) return DataFrame(encoder.encode(dataset._jdf), dataset.sparkSession) diff --git a/extensions/spark/setup/djl_spark/task/text/text_generator.py b/extensions/spark/setup/djl_spark/task/text/text_generator.py index 97257ffb4be..00ec376cec7 100644 --- a/extensions/spark/setup/djl_spark/task/text/text_generator.py +++ b/extensions/spark/setup/djl_spark/task/text/text_generator.py @@ -17,11 +17,14 @@ from pyspark.sql.types import StringType from typing import Iterator from transformers import pipeline +import os +from ...util import files_util, dependency_util +from time import gmtime, strftime class TextGenerator: - def __init__(self, input_col, output_col, engine, model_url=None, model_name=None): + def __init__(self, input_col, output_col, engine, model_url=None, hf_model_id=None): """ Initializes the TextGenerator. @@ -29,13 +32,13 @@ def __init__(self, input_col, output_col, engine, model_url=None, model_name=Non :param output_col: The output column :param engine: The engine. Currently only PyTorch is supported. :param model_url: The model URL - :param model_name: The model name + :param hf_model_id: The Huggingface model ID """ self.input_col = input_col self.output_col = output_col self.engine = engine self.model_url = model_url - self.model_name = model_name + self.hf_model_id = hf_model_id def generate(self, dataset, **kwargs): """ @@ -45,17 +48,25 @@ def generate(self, dataset, **kwargs): :return: output dataset """ sc = SparkContext._active_spark_context - if not self.model_url and not self.model_name: - raise ValueError("Either model_url or model_name must be provided.") - model_name_or_url = self.model_url if self.model_url else self.model_name - pipe = pipeline('text-generation', model=model_name_or_url, **kwargs) - bc_pipe = sc.broadcast(pipe) + if self.model_url: + timestamp_suffix = strftime("%Y-%m-%d-%H-%M-%S", gmtime()) + cache_dir = os.path.join(files_util.get_cache_dir(), "cache/repo/model/nlp/text_generation/", + self.hf_model_id if self.hf_model_id else "model_" + timestamp_suffix) + files_util.download_and_extract(self.model_url, cache_dir) + dependency_util.install(cache_dir) + pipe = pipeline('text-generation', model=cache_dir, **kwargs) + bc_pipe = sc.broadcast(pipe) + elif self.hf_model_id: + pipe = pipeline('text-generation', model=self.hf_model_id, **kwargs) + bc_pipe = sc.broadcast(pipe) + else: + raise ValueError("Either model_url or hf_model_id must be provided.") @pandas_udf(StringType()) def predict_udf(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: for s in iterator: output = bc_pipe.value(s.tolist()) - text = [o["generated_text"] for o in output[0]] + text = map(lambda d: d['generated_text'], output[0]) yield pd.Series(text) return dataset.withColumn(self.output_col, predict_udf(self.input_col)) diff --git a/extensions/spark/setup/djl_spark/task/text/huggingface_text_tokenizer.py b/extensions/spark/setup/djl_spark/task/text/text_tokenizer.py similarity index 75% rename from extensions/spark/setup/djl_spark/task/text/huggingface_text_tokenizer.py rename to extensions/spark/setup/djl_spark/task/text/text_tokenizer.py index c2d89fd3fab..7a8e9ce129c 100644 --- a/extensions/spark/setup/djl_spark/task/text/huggingface_text_tokenizer.py +++ b/extensions/spark/setup/djl_spark/task/text/text_tokenizer.py @@ -15,31 +15,31 @@ from pyspark.sql import DataFrame -class HuggingFaceTextTokenizer: +class TextTokenizer: - def __init__(self, input_col, output_col, name): + def __init__(self, input_col, output_col, hf_model_id): """ - Initializes the HuggingFaceTextEncoder. + Initializes the TextTokenizer. :param input_col: The input column :param output_col: The output column - :param name: The name of the tokenizer + :param hf_model_id: The Huggingface model ID """ self.input_col = input_col self.output_col = output_col - self.name = name + self.hf_model_id = hf_model_id def tokenize(self, dataset): """ - Performs sentence encoding on the provided dataset. + Performs sentence tokenization on the provided dataset. :param dataset: input dataset :return: output dataset """ sc = SparkContext._active_spark_context - tokenizer = sc._jvm.ai.djl.spark.task.text.HuggingFaceTextTokenizer() \ + tokenizer = sc._jvm.ai.djl.spark.task.text.TextTokenizer() \ .setInputCol(self.input_col) \ .setOutputCol(self.output_col) \ - .setName(self.name) + .setHfModelId(self.hf_model_id) return DataFrame(tokenizer.tokenize(dataset._jdf), dataset.sparkSession) diff --git a/extensions/spark/setup/djl_spark/translator/__init__.py b/extensions/spark/setup/djl_spark/translator/__init__.py deleted file mode 100644 index dca648ed5ed..00000000000 --- a/extensions/spark/setup/djl_spark/translator/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file -# except in compliance with the License. A copy of the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" -# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for -# the specific language governing permissions and limitations under the License. \ No newline at end of file diff --git a/extensions/spark/setup/djl_spark/translator/vision/__init__.py b/extensions/spark/setup/djl_spark/translator/vision/__init__.py deleted file mode 100644 index dca648ed5ed..00000000000 --- a/extensions/spark/setup/djl_spark/translator/vision/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file -# except in compliance with the License. A copy of the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" -# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for -# the specific language governing permissions and limitations under the License. \ No newline at end of file diff --git a/extensions/spark/setup/djl_spark/translator/vision/translator.py b/extensions/spark/setup/djl_spark/translator/vision/translator.py deleted file mode 100644 index 0c1e9bfea2d..00000000000 --- a/extensions/spark/setup/djl_spark/translator/vision/translator.py +++ /dev/null @@ -1,24 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file -# except in compliance with the License. A copy of the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" -# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for -# the specific language governing permissions and limitations under the License. - -from pyspark import SparkContext - - -class ImageClassificationTranslator: - """A translator for Spark Image Classification tasks. - """ - - def __new__(cls, *args, **kwargs): - sc = SparkContext._active_spark_context - return sc._jvm.ai.djl.spark.translator.vision.ImageClassificationTranslator( - ) diff --git a/extensions/spark/setup/djl_spark/util/dependency_util.py b/extensions/spark/setup/djl_spark/util/dependency_util.py new file mode 100644 index 00000000000..ea7fc787fab --- /dev/null +++ b/extensions/spark/setup/djl_spark/util/dependency_util.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python +# +# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. +import os +import subprocess +import sys + + +def install(path): + """Install a Python package. + + :param path: The path to find the requirements.txt. + """ + if os.path.exists(os.path.join(path, "requirements.txt")): + cmd = [python_executable(), "-m", "pip", "install", "-r", os.path.join(path, "requirements.txt")] + try: + subprocess.run(cmd, stderr=subprocess.STDOUT, check=True) + except subprocess.CalledProcessError as e: + print("Error occurred during installing dependency:", e) + + +def python_executable(): + """Returns the path of the Python executable, if it exists. + + :return: The path of the Python executable. + """ + if not sys.executable: + raise RuntimeError("Failed to retrieve the path of the Python executable.") + return sys.executable diff --git a/extensions/spark/setup/djl_spark/util/files_util.py b/extensions/spark/setup/djl_spark/util/files_util.py new file mode 100644 index 00000000000..c1f302388a6 --- /dev/null +++ b/extensions/spark/setup/djl_spark/util/files_util.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# +# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. +import boto3 +import contextlib +import os +import shutil +import tempfile +import tarfile +from urllib.parse import urlparse +from urllib.request import urlopen + + +def get_cache_dir(): + """Get the cache directory. + """ + cache_dir = os.environ.get("DJL_CACHE_DIR") + if not cache_dir: + cache_dir = os.path.join(os.path.expanduser("~"), ".djl.ai") + return cache_dir + + +@contextlib.contextmanager +def tmpdir(suffix="", prefix="tmp"): + """Create a temporary directory with a context manager. The file is deleted when the + context exits. + + :param suffix: If suffix is not None, the file name will end with that suffix. + :param prefix: If prefix is not None, the file name will begin with that prefix. + """ + tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix) + yield tmp + shutil.rmtree(tmp) + + +def s3_download(url, dst): + """Download a file from S3. + + :param url: The S3 URL of the file. + :param dst: The destination where the file will be saved. + """ + url = urlparse(url) + + if url.scheme != "s3": + raise ValueError("Expecting 's3' scheme, got: %s in %s" % (url.scheme, url)) + + bucket, key = url.netloc, url.path.lstrip("/") + s3 = boto3.client("s3") + s3.download_file(bucket, key, dst) + + +def download_and_extract(url, path): + """Download and extract a tar file. + + :param url: The url of the tar file. + :param path: The path to save the file. + """ + if not os.path.exists(path): + os.makedirs(path) + if not os.listdir(path): + with tmpdir() as tmp: + if url.startswith("s3://"): + dst = os.path.join(tmp, "tar_file") + s3_download(url, dst) + with tarfile.open(name=dst, mode="r:gz") as t: + t.extractall(path=path) + elif url.startswith("http://") or url.startswith("https://"): + dst = os.path.join(tmp, "tar_file") + with urlopen(url) as response, open(dst, 'wb') as f: + shutil.copyfileobj(response, f) + with tarfile.open(name=dst, mode="r:gz") as t: + t.extractall(path=path) diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/text/HuggingFaceTextDecoder.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextDecoder.scala similarity index 83% rename from extensions/spark/src/main/scala/ai/djl/spark/task/text/HuggingFaceTextDecoder.scala rename to extensions/spark/src/main/scala/ai/djl/spark/task/text/TextDecoder.scala index b46867a0761..405c69e581c 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/text/HuggingFaceTextDecoder.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextDecoder.scala @@ -24,12 +24,12 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row} * * @param uid An immutable unique ID for the object and its derivatives. */ -class HuggingFaceTextDecoder(override val uid: String) extends BaseTextPredictor[Array[Long], String] +class TextDecoder(override val uid: String) extends BaseTextPredictor[Array[Long], String] with HasInputCol with HasOutputCol { - def this() = this(Identifiable.randomUID("HuggingFaceTextDecoder")) + def this() = this(Identifiable.randomUID("TextDecoder")) - final val tokenizer = new Param[String](this, "tokenizer", "The name of the tokenizer") + final val hfModelId = new Param[String](this, "hfModelId", "The Huggingface model ID") private var inputColIndex : Int = _ @@ -48,11 +48,11 @@ class HuggingFaceTextDecoder(override val uid: String) extends BaseTextPredictor def setOutputCol(value: String): this.type = set(outputCol, value) /** - * Sets the tokenizer parameter. + * Sets the hfModelId parameter. * * @param value the value of the parameter */ - def setTokenizer(value: String): this.type = set(tokenizer, value) + def setHfModelId(value: String): this.type = set(hfModelId, value) setDefault(inputClass, classOf[Array[Long]]) setDefault(outputClass, classOf[String]) @@ -76,9 +76,9 @@ class HuggingFaceTextDecoder(override val uid: String) extends BaseTextPredictor /** @inheritdoc */ override def transformRows(iter: Iterator[Row]): Iterator[Row] = { - val t = HuggingFaceTokenizer.newInstance($(tokenizer)) + val tokenizer = HuggingFaceTokenizer.newInstance($(hfModelId)) iter.map(row => { - Row.fromSeq(row.toSeq :+ t.decode(row.getAs[Seq[Long]]($(inputCol)).toArray)) + Row.fromSeq(row.toSeq :+ tokenizer.decode(row.getAs[Seq[Long]]($(inputCol)).toArray)) }) } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/text/HuggingFaceTextEncoder.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextEncoder.scala similarity index 85% rename from extensions/spark/src/main/scala/ai/djl/spark/task/text/HuggingFaceTextEncoder.scala rename to extensions/spark/src/main/scala/ai/djl/spark/task/text/TextEncoder.scala index 4f5f324de5e..f02e8fbe640 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/text/HuggingFaceTextEncoder.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextEncoder.scala @@ -24,12 +24,12 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row} * * @param uid An immutable unique ID for the object and its derivatives. */ -class HuggingFaceTextEncoder(override val uid: String) extends BaseTextPredictor[String, Encoding] +class TextEncoder(override val uid: String) extends BaseTextPredictor[String, Encoding] with HasInputCol with HasOutputCol { - def this() = this(Identifiable.randomUID("HuggingFaceTextEncoder")) + def this() = this(Identifiable.randomUID("TextEncoder")) - final val tokenizer = new Param[String](this, "tokenizer", "The name of the tokenizer") + final val hfModelId = new Param[String](this, "hfModelId", "The Huggingface model ID") private var inputColIndex : Int = _ @@ -48,11 +48,11 @@ class HuggingFaceTextEncoder(override val uid: String) extends BaseTextPredictor def setOutputCol(value: String): this.type = set(outputCol, value) /** - * Sets the tokenizer parameter. + * Sets the hfModelId parameter. * * @param value the value of the parameter */ - def setTokenizer(value: String): this.type = set(tokenizer, value) + def setHfModelId(value: String): this.type = set(hfModelId, value) setDefault(inputClass, classOf[String]) setDefault(outputClass, classOf[Encoding]) @@ -76,9 +76,9 @@ class HuggingFaceTextEncoder(override val uid: String) extends BaseTextPredictor /** @inheritdoc */ override def transformRows(iter: Iterator[Row]): Iterator[Row] = { - val t = HuggingFaceTokenizer.newInstance($(tokenizer)) + val tokenizer = HuggingFaceTokenizer.newInstance($(hfModelId)) iter.map(row => { - val encoding = t.encode(row.getString(inputColIndex)) + val encoding = tokenizer.encode(row.getString(inputColIndex)) Row.fromSeq(row.toSeq :+ Row(encoding.getIds, encoding.getTypeIds, encoding.getAttentionMask)) }) } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/text/HuggingFaceTextTokenizer.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextTokenizer.scala similarity index 83% rename from extensions/spark/src/main/scala/ai/djl/spark/task/text/HuggingFaceTextTokenizer.scala rename to extensions/spark/src/main/scala/ai/djl/spark/task/text/TextTokenizer.scala index 2684e9b0be4..9bc8361ff1b 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/text/HuggingFaceTextTokenizer.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextTokenizer.scala @@ -24,12 +24,12 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row} * * @param uid An immutable unique ID for the object and its derivatives. */ -class HuggingFaceTextTokenizer(override val uid: String) extends BaseTextPredictor[String, Array[String]] +class TextTokenizer(override val uid: String) extends BaseTextPredictor[String, Array[String]] with HasInputCol with HasOutputCol { - def this() = this(Identifiable.randomUID("HuggingFaceTextTokenizer")) + def this() = this(Identifiable.randomUID("TextTokenizer")) - final val tokenizer = new Param[String](this, "tokenizer", "The name of the tokenizer") + final val hfModelId = new Param[String](this, "hfModelId", "The Huggingface model ID") private var inputColIndex: Int = _ @@ -48,11 +48,11 @@ class HuggingFaceTextTokenizer(override val uid: String) extends BaseTextPredict def setOutputCol(value: String): this.type = set(outputCol, value) /** - * Sets the tokenizer parameter. + * Sets the hfModelId parameter. * * @param value the value of the parameter */ - def setTokenizer(value: String): this.type = set(tokenizer, value) + def setHfModelId(value: String): this.type = set(hfModelId, value) setDefault(inputClass, classOf[String]) setDefault(outputClass, classOf[Array[String]]) @@ -76,9 +76,9 @@ class HuggingFaceTextTokenizer(override val uid: String) extends BaseTextPredict /** @inheritdoc */ override def transformRows(iter: Iterator[Row]): Iterator[Row] = { - val t = HuggingFaceTokenizer.newInstance($(tokenizer)) + val tokenizer = HuggingFaceTokenizer.newInstance($(hfModelId)) iter.map(row => { - Row.fromSeq(row.toSeq :+ t.tokenize(row.getString(inputColIndex)).toArray) + Row.fromSeq(row.toSeq :+ tokenizer.tokenize(row.getString(inputColIndex)).toArray) }) }