Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
jiashenC committed Sep 24, 2023
1 parent 8c09e5e commit d620bd2
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 16 deletions.
21 changes: 15 additions & 6 deletions evadb/optimizer/optimizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry
from evadb.constants import CACHEABLE_FUNCTIONS, DEFAULT_FUNCTION_EXPRESSION_COST
from evadb.expression.abstract_expression import AbstractExpression, ExpressionType
from evadb.expression.constant_value_expression import ConstantValueExpression
from evadb.expression.expression_utils import (
conjunction_list_to_expression_tree,
contains_single_column,
Expand All @@ -36,7 +37,6 @@
FunctionExpressionCache,
)
from evadb.expression.tuple_value_expression import TupleValueExpression
from evadb.expression.constant_value_expression import ConstantValueExpression
from evadb.parser.alias import Alias
from evadb.parser.create_statement import ColumnDefinition
from evadb.utils.kv_cache import DiskKVCache
Expand Down Expand Up @@ -191,7 +191,9 @@ def extract_pushdown_predicate_for_alias(
)


def optimize_cache_key_for_tuple_value_expression(context: "OptimizerContext", tv_expr: TupleValueExpression):
def optimize_cache_key_for_tuple_value_expression(
context: "OptimizerContext", tv_expr: TupleValueExpression
):
catalog = context.db.catalog()
col_catalog_obj = tv_expr.col_object

Expand All @@ -215,12 +217,13 @@ def optimize_cache_key_for_tuple_value_expression(context: "OptimizerContext", t
return [tv_expr]


def optimize_cache_key_for_constant_value_expression(context: "OptimizerContext", cv_expr: ConstantValueExpression):
def optimize_cache_key_for_constant_value_expression(
context: "OptimizerContext", cv_expr: ConstantValueExpression
):
# No need to additional optimization for constant value expression.
return [cv_expr]



def optimize_cache_key(context: "OptimizerContext", expr: FunctionExpression):
"""Optimize the cache key
Expand All @@ -243,8 +246,14 @@ def optimize_cache_key(context: "OptimizerContext", expr: FunctionExpression):
return optimize_cache_key_for_tuple_value_expression(context, keys[0])

# Handle ConstantValueExpressin + TupleValueExpression
if len(keys) == 2 and isinstance(keys[0], ConstantValueExpression) and isinstance(keys[1], TupleValueExpression):
return optimize_cache_key_for_constant_value_expression(context, keys[0]) + optimize_cache_key_for_tuple_value_expression(context, keys[1])
if (
len(keys) == 2
and isinstance(keys[0], ConstantValueExpression)
and isinstance(keys[1], TupleValueExpression)
):
return optimize_cache_key_for_constant_value_expression(
context, keys[0]
) + optimize_cache_key_for_tuple_value_expression(context, keys[1])

return keys

Expand Down
28 changes: 18 additions & 10 deletions test/integration_tests/long/test_reuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import gc
import os
import unittest

from mock import patch
from pathlib import Path
from test.markers import gpu_skip_marker, windows_skip_marker
from test.util import (
Expand All @@ -26,7 +24,10 @@
shutdown_ray,
)

from mock import patch

from evadb.configuration.constants import EvaDB_ROOT_DIR
from evadb.models.storage.batch import Batch
from evadb.optimizer.operators import LogicalFunctionScan
from evadb.optimizer.plan_generator import PlanGenerator
from evadb.optimizer.rules.rules import (
Expand All @@ -37,7 +38,6 @@
from evadb.optimizer.rules.rules_manager import RulesManager, disable_rules
from evadb.server.command_handler import execute_query_fetch_all
from evadb.utils.stats import Timer
from evadb.models.storage.batch import Batch


class ReuseTest(unittest.TestCase):
Expand All @@ -55,13 +55,15 @@ def setUp(self):
self.evadb.catalog().reset()
ua_detrac = f"{EvaDB_ROOT_DIR}/data/ua_detrac/ua_detrac.mp4"
execute_query_fetch_all(self.evadb, f"LOAD VIDEO '{ua_detrac}' INTO DETRAC;")
execute_query_fetch_all(self.evadb, f"CREATE TABLE fruitTable (data TEXT(100))")
execute_query_fetch_all(self.evadb, "CREATE TABLE fruitTable (data TEXT(100))")
data_list = [
"The color of apple is red",
"The color of banana is yellow",
]
for data in data_list:
execute_query_fetch_all(self.evadb, f"INSERT INTO fruitTable (data) VALUES ('{data}')")
execute_query_fetch_all(
self.evadb, f"INSERT INTO fruitTable (data) VALUES ('{data}')"
)
load_functions_for_testing(self.evadb)
self._load_hf_model()

Expand Down Expand Up @@ -108,13 +110,15 @@ def _reuse_experiment(self, queries):

def _strict_reuse_experiment(self, queries):
# This test mocks the apply_function_expression, if it is called, it will raise
# an exception.
# an exception.
exec_times = []
batches = []
for i, query in enumerate(queries):
timer = Timer()
if i != 0:
with timer, patch.object(Batch, "apply_function_expression") as mock_batch_func:
with timer, patch.object(
Batch, "apply_function_expression"
) as mock_batch_func:
mock_batch_func.side_effect = Exception("Results are not reused")
batches.append(execute_query_fetch_all(self.evadb, query))
else:
Expand All @@ -126,14 +130,18 @@ def _strict_reuse_experiment(self, queries):
def test_reuse_chatgpt(self):
select_query = """SELECT DummyLLM('What is the fruit described in this sentence', data)
FROM fruitTable"""
batches, exec_times = self._strict_reuse_experiment([select_query, select_query])
batches, exec_times = self._strict_reuse_experiment(
[select_query, select_query]
)
self._verify_reuse_correctness(select_query, batches[1])
self.assertTrue(exec_times[0] > exec_times[1])

def test_reuse_when_query_is_duplicate(self):
select_query = """SELECT id, label FROM DETRAC JOIN
LATERAL HFObjectDetector(data) AS Obj(score, label, bbox) WHERE id < 15;"""
batches, exec_times = self._strict_reuse_experiment([select_query, select_query])
batches, exec_times = self._strict_reuse_experiment(
[select_query, select_query]
)
self._verify_reuse_correctness(select_query, batches[1])
self.assertTrue(exec_times[0] > exec_times[1])

Expand Down Expand Up @@ -299,4 +307,4 @@ def test_drop_table_should_remove_cache(self):
cache_name
)
self.assertIsNone(function_cache)
self.assertFalse(cache_dir.exists())
self.assertFalse(cache_dir.exists())

0 comments on commit d620bd2

Please sign in to comment.