fix: chatgpt exact cache (#1203)
jiashenC committed Sep 25, 2023
1 parent 49bdf55 commit c87f7c7
Showing 5 changed files with 166 additions and 30 deletions.
2 changes: 1 addition & 1 deletion evadb/functions/chatgpt.py
@@ -80,7 +80,7 @@ class ChatGPT(AbstractFunction):
def name(self) -> str:
return "ChatGPT"

-    @setup(cacheable=False, function_type="chat-completion", batchable=True)
+    @setup(cacheable=True, function_type="chat-completion", batchable=True)
def setup(
self,
model="gpt-3.5-turbo",
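Context on the one-line fix above: flipping cacheable to True marks ChatGPT as eligible for the optimizer's exact-match reuse, so a repeated (prompt, content) pair is answered from the cache instead of re-calling the model. A minimal, self-contained sketch of the idea (illustrative only, not EvaDB's actual cache machinery):

from functools import lru_cache

CALLS = 0

@lru_cache(maxsize=None)
def cached_chat_completion(prompt: str, content: str) -> str:
    # Stand-in for the real API call; executes only on a cache miss.
    global CALLS
    CALLS += 1
    return f"{prompt} {content}"

cached_chat_completion("What fruit?", "The color of apple is red")
cached_chat_completion("What fruit?", "The color of apple is red")
assert CALLS == 1  # the second call is served from the cache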
8 changes: 8 additions & 0 deletions evadb/functions/function_bootstrap_queries.py
@@ -56,6 +56,13 @@
EvaDB_INSTALLATION_DIR
)

DummyLLM_function_query = """CREATE FUNCTION
IF NOT EXISTS DummyLLM
IMPL '{}/../test/util.py';
""".format(
EvaDB_INSTALLATION_DIR
)

fuzzy_function_query = """CREATE FUNCTION IF NOT EXISTS FuzzDistance
INPUT (Input_Array1 NDARRAY ANYTYPE, Input_Array2 NDARRAY ANYTYPE)
OUTPUT (distance FLOAT(32, 7))
@@ -250,6 +257,7 @@ def init_builtin_functions(db: EvaDBDatabase, mode: str = "debug") -> None:
DummyMultiObjectDetector_function_query,
DummyFeatureExtractor_function_query,
DummyNoInputFunction_function_query,
DummyLLM_function_query,
]
)

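The new bootstrap entry registers DummyLLM (defined in test/util.py at the end of this diff) alongside the other test functions. A sketch of invoking it, mirroring the test suite's helper; the evadb handle and the import path are assumptions based on the tests in this diff:

from evadb.server.command_handler import execute_query_fetch_all

def ask_dummy_llm(evadb):
    # evadb is a database handle like the one the tests build in setUp().
    return execute_query_fetch_all(
        evadb,
        "SELECT DummyLLM('What is the fruit described in this sentence', data) "
        "FROM fruitTable;",
    )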
85 changes: 58 additions & 27 deletions evadb/optimizer/optimizer_utils.py
@@ -24,6 +24,7 @@
from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry
from evadb.constants import CACHEABLE_FUNCTIONS, DEFAULT_FUNCTION_EXPRESSION_COST
from evadb.expression.abstract_expression import AbstractExpression, ExpressionType
from evadb.expression.constant_value_expression import ConstantValueExpression
from evadb.expression.expression_utils import (
conjunction_list_to_expression_tree,
contains_single_column,
@@ -190,10 +191,44 @@ def extract_pushdown_predicate_for_alias(
)


def optimize_cache_key_for_tuple_value_expression(
context: "OptimizerContext", tv_expr: TupleValueExpression
):
catalog = context.db.catalog()
col_catalog_obj = tv_expr.col_object

    # Build an optimized cache key for a TupleValueExpression.
new_keys = []

if isinstance(col_catalog_obj, ColumnCatalogEntry):
table_obj = catalog.get_table_catalog_entry(col_catalog_obj.table_name)
for col in get_table_primary_columns(table_obj):
new_obj = catalog.get_column_catalog_entry(table_obj, col.name)
new_keys.append(
TupleValueExpression(
name=col.name,
table_alias=tv_expr.table_alias,
col_object=new_obj,
col_alias=f"{tv_expr.table_alias}.{col.name}",
)
)
return new_keys

return [tv_expr]


def optimize_cache_key_for_constant_value_expression(
context: "OptimizerContext", cv_expr: ConstantValueExpression
):
    # No additional optimization is needed for a constant value expression.
return [cv_expr]


def optimize_cache_key(context: "OptimizerContext", expr: FunctionExpression):
"""Optimize the cache key
-    It tries to reduce the caching overhead by replacing the caching key with logically equivalent key. For instance, frame data can be replaced with frame id.
+    It tries to reduce the caching overhead by replacing the caching key with
+    a logically equivalent key. For instance, frame data can be replaced with frame id.
Args:
expr (FunctionExpression): expression to optimize the caching key for.
@@ -206,27 +241,19 @@ def optimize_cache_key(context: "OptimizerContext", expr: FunctionExpression):
"""
keys = expr.children
catalog = context.db.catalog()
-    # handle simple one column inputs
-    if len(keys) == 1 and isinstance(keys[0], TupleValueExpression):
-        child = keys[0]
-        col_catalog_obj = child.col_object
-        if isinstance(col_catalog_obj, ColumnCatalogEntry):
-            new_keys = []
-            table_obj = catalog.get_table_catalog_entry(col_catalog_obj.table_name)
-            for col in get_table_primary_columns(table_obj):
-                new_obj = catalog.get_column_catalog_entry(table_obj, col.name)
-                new_keys.append(
-                    TupleValueExpression(
-                        name=col.name,
-                        table_alias=child.table_alias,
-                        col_object=new_obj,
-                        col_alias=f"{child.table_alias}.{col.name}",
-                    )
-                )
-
-            return new_keys
-    return keys
+    optimize_key_mapping_f = {
+        TupleValueExpression: optimize_cache_key_for_tuple_value_expression,
+        ConstantValueExpression: optimize_cache_key_for_constant_value_expression,
+    }
+
+    optimized_keys = []
+    for key in keys:
+        if type(key) not in optimize_key_mapping_f:
+            raise RuntimeError(f"Optimizing the cache key for {type(key)} is not implemented")
+        optimized_keys += optimize_key_mapping_f[type(key)](context, key)
+
+    return optimized_keys
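The rewrite above replaces a single-case special path with a per-type dispatch table, so supporting a new expression type only needs one more entry. A self-contained sketch of the pattern with stand-in classes (not EvaDB's real catalog API):

from dataclasses import dataclass

@dataclass(frozen=True)
class ColumnRef:
    table: str
    column: str

@dataclass(frozen=True)
class Constant:
    value: str

PRIMARY_COLUMNS = {"DETRAC": ["id"]}  # toy stand-in for the catalog

def optimize_column_key(key):
    # Swap a (potentially huge) data column for the table's primary-key columns.
    return [ColumnRef(key.table, c) for c in PRIMARY_COLUMNS[key.table]]

def optimize_constant_key(key):
    # A constant (e.g. an LLM prompt) is already a cheap, exact key.
    return [key]

DISPATCH = {ColumnRef: optimize_column_key, Constant: optimize_constant_key}

def optimize_keys(keys):
    optimized = []
    for key in keys:
        if type(key) not in DISPATCH:
            raise RuntimeError(f"no cache-key optimizer for {type(key)}")
        optimized += DISPATCH[type(key)](key)
    return optimized

# An LLM call keyed on (prompt, frame data) is re-keyed on (prompt, frame id):
print(optimize_keys([Constant("What fruit?"), ColumnRef("DETRAC", "data")]))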


def enable_cache_init(
@@ -280,12 +307,16 @@ def enable_cache_on_expression_tree(


def check_expr_validity_for_cache(expr: FunctionExpression):
-    return (
-        expr.name in CACHEABLE_FUNCTIONS
-        and not expr.has_cache()
-        and len(expr.children) <= 1
-        and isinstance(expr.children[0], TupleValueExpression)
-    )
+    valid = expr.name in CACHEABLE_FUNCTIONS and not expr.has_cache()
+    if len(expr.children) == 1:
+        # Normal function that takes only one parameter.
+        valid &= isinstance(expr.children[0], TupleValueExpression)
+    elif len(expr.children) == 2:
+        # LLM-based function that takes a constant prompt and a column input.
+        valid &= isinstance(expr.children[0], ConstantValueExpression) and isinstance(
+            expr.children[1], TupleValueExpression
+        )
+    return valid
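In effect, two call shapes are now cache-eligible: Func(data), a single TupleValueExpression child, and LLM('prompt', data), a ConstantValueExpression followed by a TupleValueExpression — exactly the shape of the SELECT DummyLLM('What is the fruit described in this sentence', data) query exercised in the new test below.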


def get_expression_execution_cost(
49 changes: 47 additions & 2 deletions test/integration_tests/long/test_reuse.py
@@ -24,7 +24,10 @@
shutdown_ray,
)

from mock import patch

from evadb.configuration.constants import EvaDB_ROOT_DIR
from evadb.models.storage.batch import Batch
from evadb.optimizer.operators import LogicalFunctionScan
from evadb.optimizer.plan_generator import PlanGenerator
from evadb.optimizer.rules.rules import (
@@ -52,12 +55,22 @@ def setUp(self):
self.evadb.catalog().reset()
ua_detrac = f"{EvaDB_ROOT_DIR}/data/ua_detrac/ua_detrac.mp4"
execute_query_fetch_all(self.evadb, f"LOAD VIDEO '{ua_detrac}' INTO DETRAC;")
execute_query_fetch_all(self.evadb, "CREATE TABLE fruitTable (data TEXT(100))")
data_list = [
"The color of apple is red",
"The color of banana is yellow",
]
for data in data_list:
execute_query_fetch_all(
self.evadb, f"INSERT INTO fruitTable (data) VALUES ('{data}')"
)
load_functions_for_testing(self.evadb)
self._load_hf_model()

def tearDown(self):
shutdown_ray()
execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS DETRAC;")
execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS fruitTable;")

def _verify_reuse_correctness(self, query, reuse_batch):
# Fix memory failures on CI when running reuse test cases. An issue with yolo
Expand Down Expand Up @@ -95,12 +108,44 @@ def _reuse_experiment(self, queries):
exec_times.append(timer.total_elapsed_time)
return batches, exec_times

def _strict_reuse_experiment(self, queries):
        # Patch Batch.apply_function_expression for every query after the first;
        # if it is ever called, it raises, proving that results were not reused.
exec_times = []
batches = []
for i, query in enumerate(queries):
timer = Timer()
if i != 0:
with timer, patch.object(
Batch, "apply_function_expression"
) as mock_batch_func:
mock_batch_func.side_effect = Exception("Results are not reused")
batches.append(execute_query_fetch_all(self.evadb, query))
else:
with timer:
batches.append(execute_query_fetch_all(self.evadb, query))
exec_times.append(timer.total_elapsed_time)
return batches, exec_times

def test_reuse_chatgpt(self):
from evadb.constants import CACHEABLE_FUNCTIONS

        # += mutates the imported list in place, so the optimizer's
        # CACHEABLE_FUNCTIONS check also sees DummyLLM.
        CACHEABLE_FUNCTIONS += ["DummyLLM"]
select_query = """SELECT DummyLLM('What is the fruit described in this sentence', data)
FROM fruitTable"""
batches, exec_times = self._strict_reuse_experiment(
[select_query, select_query]
)
self._verify_reuse_correctness(select_query, batches[1])
self.assertTrue(exec_times[0] > exec_times[1])

def test_reuse_when_query_is_duplicate(self):
select_query = """SELECT id, label FROM DETRAC JOIN
LATERAL HFObjectDetector(data) AS Obj(score, label, bbox) WHERE id < 15;"""
-        batches, exec_times = self._reuse_experiment([select_query, select_query])
+        batches, exec_times = self._strict_reuse_experiment(
+            [select_query, select_query]
+        )
self._verify_reuse_correctness(select_query, batches[1])
# reuse should be faster than no reuse
self.assertTrue(exec_times[0] > exec_times[1])

@gpu_skip_marker
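The patched side_effect in _strict_reuse_experiment turns the timing heuristic into a hard guarantee: after the first query, Batch.apply_function_expression can only raise, so the second query succeeds only if every function result comes from the cache. The same trick in isolation, as a minimal self-contained sketch:

from unittest.mock import patch

class Batch:
    def apply_function_expression(self, expr):
        return expr(self)

def run_query(use_cache):
    if use_cache:
        return "cached result"  # reuse path: the function never executes
    return Batch().apply_function_expression(lambda b: "fresh result")

run_query(use_cache=False)  # first run computes (and, in EvaDB, caches)

with patch.object(Batch, "apply_function_expression") as mock_func:
    mock_func.side_effect = Exception("Results are not reused")
    assert run_query(use_cache=True) == "cached result"  # mock never hit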
52 changes: 52 additions & 0 deletions test/util.py
@@ -663,3 +663,55 @@ def name(self) -> str:
def forward(self, df: pd.DataFrame) -> pd.DataFrame:
ret = pd.DataFrame([{"label": "DummyNoInputFunction"}])
return ret


class DummyLLM(AbstractFunction):
@property
def name(self) -> str:
return "DummyLLM"

@decorators.setup(cacheable=True, function_type="chat-completion", batchable=True)
def setup(self, *args, **kwargs):
pass

@decorators.forward(
input_signatures=[
PandasDataframe(
columns=["query", "content", "prompt"],
column_types=[
NdArrayType.STR,
NdArrayType.STR,
NdArrayType.STR,
],
column_shapes=[(1,), (1,), (None,)],
)
],
output_signatures=[
PandasDataframe(
columns=["response"],
column_types=[
NdArrayType.STR,
],
column_shapes=[(1,)],
)
],
)
def forward(self, text_df):
        # Default both query and content to the first column; they are
        # overridden below when separate columns are provided.
        queries = text_df[text_df.columns[0]]
        contents = text_df[text_df.columns[0]]

        if len(text_df.columns) > 1:
            queries = text_df.iloc[:, 0]
            contents = text_df.iloc[:, 1]

        prompt = None
        if len(text_df.columns) > 2:
            prompt = text_df.iloc[0, 2]

        # Concatenate deterministically so repeated runs produce identical
        # responses that can be checked against cached results.
        results = []
        for query, content in zip(queries, contents):
            results.append(("" if prompt is None else prompt) + query + " " + content)

df = pd.DataFrame({"response": results})

return df
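A quick sanity check of DummyLLM's deterministic output (assumptions: test/util.py is importable as test.util from the repository root, and the forward decorator passes the frame through unchanged):

import pandas as pd

from test.util import DummyLLM

df = pd.DataFrame(
    {
        "query": ["What is the fruit described in this sentence"],
        "content": ["The color of apple is red"],
    }
)
print(DummyLLM().forward(df)["response"][0])
# -> What is the fruit described in this sentence The color of apple is red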
