From dbad3136ef4ab502e0eb7b5958f3ebfdcc375e0b Mon Sep 17 00:00:00 2001 From: Andy Xu Date: Sun, 13 Aug 2023 00:04:47 -0700 Subject: [PATCH] Fix problems found when running hugging face text summarization model on large input. (#929) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Pass `batch_mem_size` correctly to storage engines. 2. Fix hugging face crash when the number of input tokens is larger than the limit. 👋 Thanks for submitting a Pull Request to EvaDB! 🙌 We want to make contributing to EvaDB as easy and transparent as possible. Here are a few tips to get you started: - 🔍 Search existing EvaDB [PRs](https://github.com/georgia-tech-db/eva/pulls) to see if a similar PR already exists. - 🔗 Link this PR to an EvaDB [issue](https://github.com/georgia-tech-db/eva/issues) to help us understand what bug fix or feature is being implemented. - 📈 Provide before and after profiling results to help us quantify the improvement your PR provides (if applicable). 👉 Please see our ✅ [Contributing Guide](https://evadb.readthedocs.io/en/stable/source/contribute/index.html) for more details. 
--- evadb/executor/delete_executor.py | 12 ++--- evadb/optimizer/rules/rules.py | 2 + evadb/third_party/huggingface/model.py | 6 +++ test/optimizer/rules/test_batch_mem_size.py | 60 +++++++++++++++++++++ test/udfs/test_hugging_face.py | 47 ++++++++++++++++ 5 files changed, 121 insertions(+), 6 deletions(-) create mode 100644 test/optimizer/rules/test_batch_mem_size.py create mode 100644 test/udfs/test_hugging_face.py diff --git a/evadb/executor/delete_executor.py b/evadb/executor/delete_executor.py index 2bf2f8cf6..766822746 100644 --- a/evadb/executor/delete_executor.py +++ b/evadb/executor/delete_executor.py @@ -45,31 +45,31 @@ def predicate_node_to_filter_clause( left = predicate_node.get_child(0) right = predicate_node.get_child(1) - if type(left) == TupleValueExpression: + if isinstance(left, TupleValueExpression): column = left.name x = table.columns[column] - elif type(left) == ConstantValueExpression: + elif isinstance(left, ConstantValueExpression): value = left.value x = value else: left_filter_clause = self.predicate_node_to_filter_clause(table, left) - if type(right) == TupleValueExpression: + if isinstance(right, TupleValueExpression): column = right.name y = table.columns[column] - elif type(right) == ConstantValueExpression: + elif isinstance(right, ConstantValueExpression): value = right.value y = value else: right_filter_clause = self.predicate_node_to_filter_clause(table, right) - if type(predicate_node) == LogicalExpression: + if isinstance(predicate_node, LogicalExpression): if predicate_node.etype == ExpressionType.LOGICAL_AND: filter_clause = and_(left_filter_clause, right_filter_clause) elif predicate_node.etype == ExpressionType.LOGICAL_OR: filter_clause = or_(left_filter_clause, right_filter_clause) - elif type(predicate_node) == ComparisonExpression: + elif isinstance(predicate_node, ComparisonExpression): assert ( predicate_node.etype != ExpressionType.COMPARE_CONTAINS and predicate_node.etype != ExpressionType.COMPARE_IS_CONTAINED diff 
--git a/evadb/optimizer/rules/rules.py b/evadb/optimizer/rules/rules.py index 883a68e3d..3c5f39c3a 100644 --- a/evadb/optimizer/rules/rules.py +++ b/evadb/optimizer/rules/rules.py @@ -872,6 +872,7 @@ def apply(self, before: LogicalGet, context: OptimizerContext): # read in a batch from storage engine. # Todo: Experiment heuristics. after = SeqScanPlan(None, before.target_list, before.alias) + batch_mem_size = context.db.config.get_value("executor", "batch_mem_size") after.append_child( StoragePlan( before.table_obj, @@ -880,6 +881,7 @@ def apply(self, before: LogicalGet, context: OptimizerContext): sampling_rate=before.sampling_rate, sampling_type=before.sampling_type, chunk_params=before.chunk_params, + batch_mem_size=batch_mem_size, ) ) yield after diff --git a/evadb/third_party/huggingface/model.py b/evadb/third_party/huggingface/model.py index 9e66284b9..04d182f48 100644 --- a/evadb/third_party/huggingface/model.py +++ b/evadb/third_party/huggingface/model.py @@ -33,6 +33,12 @@ class TextHFModel(AbstractHFUdf): Base Model for all HF Models that take in text as input """ + def __call__(self, *args, **kwargs): + # Use truncation=True to handle the case where num of tokens is larger + # than limit + # Ref: https://stackoverflow.com/questions/66954682/token-indices-sequence-length-is-longer-than-the-specified-maximum-sequence-leng + return self.forward(args[0], truncation=True) + def input_formatter(self, inputs: Any): return inputs.values.flatten().tolist() diff --git a/test/optimizer/rules/test_batch_mem_size.py b/test/optimizer/rules/test_batch_mem_size.py new file mode 100644 index 000000000..70033b014 --- /dev/null +++ b/test/optimizer/rules/test_batch_mem_size.py @@ -0,0 +1,60 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +from test.util import get_evadb_for_testing + +from mock import ANY, patch + +from evadb.server.command_handler import execute_query_fetch_all +from evadb.storage.sqlite_storage_engine import SQLStorageEngine + + +class BatchMemSizeTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.evadb = get_evadb_for_testing() + # reset the catalog manager before running each test + cls.evadb.catalog().reset() + + @classmethod + def tearDownClass(cls): + execute_query_fetch_all(cls.evadb, "DROP TABLE IF EXISTS MyCSV;") + + def test_batch_mem_size_for_sqlite_storage_engine(self): + """ + This testcase make sure that the `batch_mem_size` is correctly passed to + the storage engine. 
+ """ + test_batch_mem_size = 100 + self.evadb.config.update_value( + "executor", "batch_mem_size", test_batch_mem_size + ) + create_table_query = """ + CREATE TABLE IF NOT EXISTS MyCSV ( + id INTEGER UNIQUE, + frame_id INTEGER, + video_id INTEGER, + dataset_name TEXT(30), + label TEXT(30), + bbox NDARRAY FLOAT32(4), + object_id INTEGER + );""" + execute_query_fetch_all(self.evadb, create_table_query) + + select_table_query = "SELECT * FROM MyCSV;" + with patch.object(SQLStorageEngine, "read") as mock_read: + mock_read.__iter__.return_value = [] + execute_query_fetch_all(self.evadb, select_table_query) + mock_read.assert_called_with(ANY, test_batch_mem_size) diff --git a/test/udfs/test_hugging_face.py b/test/udfs/test_hugging_face.py new file mode 100644 index 000000000..a2e537a8b --- /dev/null +++ b/test/udfs/test_hugging_face.py @@ -0,0 +1,47 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import pandas as pd +from mock import MagicMock + +from evadb.third_party.huggingface.model import TextHFModel + + +class TestTextHFModel(TextHFModel): + @property + def default_pipeline_args(self) -> dict: + # We need to improve the hugging face interface, passing + # UdfCatalogEntry into UDF is not ideal. 
+ return { + "task": "summarization", + "model": "sshleifer/distilbart-cnn-12-6", + "min_length": 5, + "max_length": 100, + } + + +class HuggingFaceTest(unittest.TestCase): + def test_hugging_face_with_large_input(self): + udf_obj = MagicMock() + udf_obj.metadata = [] + text_summarization_model = TestTextHFModel(udf_obj) + + large_text = pd.DataFrame([{"text": "hello" * 4096}]) + try: + text_summarization_model(large_text) + except IndexError: + self.fail("hugging face with large input raised IndexError.")