Skip to content

Commit

Permalink
Fix problems found when running hugging face text summarization model …
Browse files Browse the repository at this point in the history
…on large input. (#929)

1. Pass `batch_mem_size` correctly to storage engines
2. Fix hugging face crash when the number of input tokens is larger than
the limit.

👋 Thanks for submitting a Pull Request to EvaDB!

🙌 We want to make contributing to EvaDB as easy and transparent as
possible. Here are a few tips to get you started:

- 🔍 Search existing EvaDB
[PRs](https://github.com/georgia-tech-db/eva/pulls) to see if a similar
PR already exists.
- 🔗 Link this PR to an EvaDB
[issue](https://github.com/georgia-tech-db/eva/issues) to help us
understand what bug fix or feature is being implemented.
- 📈 Provide before and after profiling results to help us quantify the
improvement your PR provides (if applicable).

👉 Please see our ✅ [Contributing
Guide](https://evadb.readthedocs.io/en/stable/source/contribute/index.html)
for more details.
  • Loading branch information
xzdandy committed Aug 13, 2023
1 parent a5324ba commit dbad313
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 6 deletions.
12 changes: 6 additions & 6 deletions evadb/executor/delete_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,31 +45,31 @@ def predicate_node_to_filter_clause(
left = predicate_node.get_child(0)
right = predicate_node.get_child(1)

if type(left) == TupleValueExpression:
if isinstance(left, TupleValueExpression):
column = left.name
x = table.columns[column]
elif type(left) == ConstantValueExpression:
elif isinstance(left, ConstantValueExpression):
value = left.value
x = value
else:
left_filter_clause = self.predicate_node_to_filter_clause(table, left)

if type(right) == TupleValueExpression:
if isinstance(right, TupleValueExpression):
column = right.name
y = table.columns[column]
elif type(right) == ConstantValueExpression:
elif isinstance(right, ConstantValueExpression):
value = right.value
y = value
else:
right_filter_clause = self.predicate_node_to_filter_clause(table, right)

if type(predicate_node) == LogicalExpression:
if isinstance(predicate_node, LogicalExpression):
if predicate_node.etype == ExpressionType.LOGICAL_AND:
filter_clause = and_(left_filter_clause, right_filter_clause)
elif predicate_node.etype == ExpressionType.LOGICAL_OR:
filter_clause = or_(left_filter_clause, right_filter_clause)

elif type(predicate_node) == ComparisonExpression:
elif isinstance(predicate_node, ComparisonExpression):
assert (
predicate_node.etype != ExpressionType.COMPARE_CONTAINS
and predicate_node.etype != ExpressionType.COMPARE_IS_CONTAINED
Expand Down
2 changes: 2 additions & 0 deletions evadb/optimizer/rules/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,7 @@ def apply(self, before: LogicalGet, context: OptimizerContext):
# read in a batch from storage engine.
# Todo: Experiment heuristics.
after = SeqScanPlan(None, before.target_list, before.alias)
batch_mem_size = context.db.config.get_value("executor", "batch_mem_size")
after.append_child(
StoragePlan(
before.table_obj,
Expand All @@ -880,6 +881,7 @@ def apply(self, before: LogicalGet, context: OptimizerContext):
sampling_rate=before.sampling_rate,
sampling_type=before.sampling_type,
chunk_params=before.chunk_params,
batch_mem_size=batch_mem_size,
)
)
yield after
Expand Down
6 changes: 6 additions & 0 deletions evadb/third_party/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ class TextHFModel(AbstractHFUdf):
Base Model for all HF Models that take in text as input
"""

def __call__(self, *args, **kwargs):
    """Run ``forward`` on the first positional argument with truncation on.

    Only ``args[0]`` is forwarded; any additional positional arguments and
    all keyword arguments are accepted but not passed along.

    ``truncation=True`` is supplied so that inputs whose number of tokens is
    larger than the limit are handled instead of crashing.
    Ref: https://stackoverflow.com/questions/66954682/token-indices-sequence-length-is-longer-than-the-specified-maximum-sequence-leng
    """
    model_input = args[0]
    return self.forward(model_input, truncation=True)

def input_formatter(self, inputs: Any):
    """Flatten ``inputs.values`` and return the elements as a plain list."""
    flattened = inputs.values.flatten()
    return flattened.tolist()

Expand Down
60 changes: 60 additions & 0 deletions test/optimizer/rules/test_batch_mem_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from test.util import get_evadb_for_testing

from mock import ANY, patch

from evadb.server.command_handler import execute_query_fetch_all
from evadb.storage.sqlite_storage_engine import SQLStorageEngine


class BatchMemSizeTest(unittest.TestCase):
    """Checks that the configured ``batch_mem_size`` reaches the storage engine."""

    @classmethod
    def setUpClass(cls):
        # Create the testing EvaDB instance once for the whole class, then
        # reset its catalog so the test starts from a clean state.
        cls.evadb = get_evadb_for_testing()
        cls.evadb.catalog().reset()

    @classmethod
    def tearDownClass(cls):
        # Remove the table created by the test below.
        execute_query_fetch_all(cls.evadb, "DROP TABLE IF EXISTS MyCSV;")

    def test_batch_mem_size_for_sqlite_storage_engine(self):
        """
        This test case makes sure that the `batch_mem_size` set in the
        "executor" config section is the value `SQLStorageEngine.read` is
        called with.
        """
        expected_batch_mem_size = 100
        create_stmt = """
CREATE TABLE IF NOT EXISTS MyCSV (
id INTEGER UNIQUE,
frame_id INTEGER,
video_id INTEGER,
dataset_name TEXT(30),
label TEXT(30),
bbox NDARRAY FLOAT32(4),
object_id INTEGER
);"""
        select_stmt = "SELECT * FROM MyCSV;"

        self.evadb.config.update_value(
            "executor", "batch_mem_size", expected_batch_mem_size
        )
        execute_query_fetch_all(self.evadb, create_stmt)

        # Replace the storage engine's `read` with a mock so the arguments it
        # receives from the query plan can be inspected.
        with patch.object(SQLStorageEngine, "read") as read_mock:
            read_mock.__iter__.return_value = []
            execute_query_fetch_all(self.evadb, select_stmt)
            read_mock.assert_called_with(ANY, expected_batch_mem_size)
47 changes: 47 additions & 0 deletions test/udfs/test_hugging_face.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import pandas as pd
from mock import MagicMock

from evadb.third_party.huggingface.model import TextHFModel


class TestTextHFModel(TextHFModel):
    """``TextHFModel`` subclass with fixed summarization pipeline arguments."""

    @property
    def default_pipeline_args(self) -> dict:
        """Return the pipeline arguments used by the large-input test."""
        # We need to improve the hugging face interface, passing
        # UdfCatalogEntry into UDF is not ideal.
        pipeline_args = {
            "task": "summarization",
            "model": "sshleifer/distilbart-cnn-12-6",
            "min_length": 5,
            "max_length": 100,
        }
        return pipeline_args


class HuggingFaceTest(unittest.TestCase):
    """Regression test for text models receiving very long input."""

    def test_hugging_face_with_large_input(self):
        """Calling the model on a very long text must not raise IndexError."""
        # Stand-in for the UDF catalog entry, exposing an empty `metadata`.
        mocked_udf_entry = MagicMock(metadata=[])
        summarizer = TestTextHFModel(mocked_udf_entry)

        # A single-row frame whose text is "hello" repeated 4096 times.
        oversized_input = pd.DataFrame([{"text": "hello" * 4096}])
        try:
            summarizer(oversized_input)
        except IndexError:
            self.fail("hugging face with large input raised IndexError.")

0 comments on commit dbad313

Please sign in to comment.