Skip to content

Commit

Permalink
feat: sample before applying predicate
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav274 committed Sep 27, 2022
1 parent dceada7 commit 0703940
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 142 deletions.
1 change: 1 addition & 0 deletions eva/executor/storage_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def exec(self) -> Iterator[Batch]:
self.node.video,
self.node.batch_mem_size,
predicate=self.node.predicate,
sampling_rate=self.node.sampling_rate,
)
else:
return StorageEngine.read(self.node.video, self.node.batch_mem_size)
8 changes: 8 additions & 0 deletions eva/optimizer/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,15 @@ def __init__(
alias: str,
predicate: AbstractExpression = None,
target_list: List[AbstractExpression] = None,
sampling_rate: int = None,
children=None,
):
self._video = video
self._dataset_metadata = dataset_metadata
self._alias = alias
self._predicate = predicate
self._target_list = target_list
self._sampling_rate = sampling_rate
super().__init__(OperatorType.LOGICALGET, children)

@property
Expand Down Expand Up @@ -179,6 +181,10 @@ def target_list(self):
def target_list(self, target_list):
self._target_list = target_list

@property
def sampling_rate(self):
return self._sampling_rate

def __eq__(self, other):
is_subtree_equal = super().__eq__(other)
if not isinstance(other, LogicalGet):
Expand All @@ -190,6 +196,7 @@ def __eq__(self, other):
and self.alias == other.alias
and self.predicate == other.predicate
and self.target_list == other.target_list
and self.sampling_rate == other.sampling_rate
)

def __hash__(self) -> int:
Expand All @@ -201,6 +208,7 @@ def __hash__(self) -> int:
self.dataset_metadata,
self.predicate,
tuple(self.target_list or []),
self.sampling_rate,
)
)

Expand Down
98 changes: 19 additions & 79 deletions eva/optimizer/rules/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ class RuleType(Flag):
# REWRITE RULES(LOGICAL -> LOGICAL)
EMBED_FILTER_INTO_GET = auto()
EMBED_FILTER_INTO_DERIVED_GET = auto()
PUSHDOWN_FILTER_THROUGH_SAMPLE = auto()
PUSHDOWN_PROJECT_THROUGH_SAMPLE = auto()
EMBED_SAMPLE_INTO_GET = auto()
EMBED_PROJECT_INTO_DERIVED_GET = auto()
EMBED_PROJECT_INTO_GET = auto()
PUSHDOWN_FILTER_THROUGH_JOIN = auto()
Expand Down Expand Up @@ -170,8 +169,7 @@ class Promise(IntEnum):
EMBED_PROJECT_INTO_GET = auto()
EMBED_FILTER_INTO_DERIVED_GET = auto()
EMBED_PROJECT_INTO_DERIVED_GET = auto()
PUSHDOWN_FILTER_THROUGH_SAMPLE = auto()
PUSHDOWN_PROJECT_THROUGH_SAMPLE = auto()
EMBED_SAMPLE_INTO_GET = auto()
PUSHDOWN_FILTER_THROUGH_JOIN = auto()


Expand Down Expand Up @@ -274,6 +272,7 @@ def apply(self, before: LogicalFilter, context: OptimizerContext):
alias=lget.alias,
predicate=pushdown_pred,
target_list=lget.target_list,
sampling_rate=lget.sampling_rate,
children=lget.children,
)
if unsupported_pred:
Expand All @@ -284,11 +283,12 @@ def apply(self, before: LogicalFilter, context: OptimizerContext):
else:
return before


class EmbedSampleIntoGet(Rule):
def __init__(self):
pattern = Pattern(OperatorType.LOGICALSAMPLE)
pattern.append_child(Pattern(OperatorType.LOGICALGET))
super().__init__(RuleType.EMBED_FILTER_INTO_GET, pattern)
super().__init__(RuleType.EMBED_SAMPLE_INTO_GET, pattern)

def promise(self):
return Promise.EMBED_SAMPLE_INTO_GET
Expand All @@ -301,30 +301,19 @@ def check(self, before: LogicalSample, context: OptimizerContext):
return False

def apply(self, before: LogicalSample, context: OptimizerContext):
sample_freq = before.sample_freq
sample_freq = before.sample_freq.value
lget: LogicalGet = before.children[0]
# System only supports pushing basic range predicates on id
video_alias = lget.video.alias
col_alias = f"{video_alias}.id"
pushdown_pred, unsupported_pred = extract_pushdown_predicate(
predicate, col_alias
new_get_opr = LogicalGet(
lget.video,
lget.dataset_metadata,
alias=lget.alias,
predicate=lget.predicate,
target_list=lget.target_list,
sampling_rate=sample_freq,
children=lget.children,
)
if pushdown_pred:
new_get_opr = LogicalGet(
lget.video,
lget.dataset_metadata,
alias=lget.alias,
predicate=pushdown_pred,
target_list=lget.target_list,
children=lget.children,
)
if unsupported_pred:
unsupported_opr = LogicalFilter(unsupported_pred)
unsupported_opr.append_child(new_get_opr)
return unsupported_opr
return new_get_opr
else:
return before
return new_get_opr


class EmbedProjectIntoGet(Rule):
def __init__(self):
Expand All @@ -348,6 +337,7 @@ def apply(self, before: LogicalProject, context: OptimizerContext):
alias=lget.alias,
predicate=lget.predicate,
target_list=target_list,
sampling_rate=lget.sampling_rate,
children=lget.children,
)

Expand Down Expand Up @@ -411,56 +401,6 @@ def apply(self, before: LogicalProject, context: OptimizerContext):
return new_opr


class PushdownFilterThroughSample(Rule):
def __init__(self):
pattern = Pattern(OperatorType.LOGICALFILTER)
pattern_sample = Pattern(OperatorType.LOGICALSAMPLE)
pattern_sample.append_child(Pattern(OperatorType.LOGICALGET))
pattern.append_child(pattern_sample)
super().__init__(RuleType.PUSHDOWN_FILTER_THROUGH_SAMPLE, pattern)

def promise(self):
return Promise.PUSHDOWN_FILTER_THROUGH_SAMPLE

def check(self, before: Operator, context: OptimizerContext):
# nothing else to check if logical match found return true
return True

def apply(self, before: LogicalFilter, context: OptimizerContext):
sample = before.children[0]
logical_get = sample.children[0]
new_filter = LogicalFilter(before.predicate)
new_filter.append_child(logical_get)
sample.clear_children()
sample.append_child(new_filter)
return sample


class PushdownProjectThroughSample(Rule):
def __init__(self):
pattern = Pattern(OperatorType.LOGICALPROJECT)
pattern_sample = Pattern(OperatorType.LOGICALSAMPLE)
pattern_sample.append_child(Pattern(OperatorType.LOGICALGET))
pattern.append_child(pattern_sample)
super().__init__(RuleType.PUSHDOWN_PROJECT_THROUGH_SAMPLE, pattern)

def promise(self):
return Promise.PUSHDOWN_PROJECT_THROUGH_SAMPLE

def check(self, before: Operator, context: OptimizerContext):
# nothing else to check if logical match found return true
return True

def apply(self, before: LogicalProject, context: OptimizerContext):
sample = before.children[0]
logical_get = sample.children[0]
new_project = LogicalProject(before.target_list)
new_project.append_child(logical_get)
sample.clear_children()
sample.append_child(new_project)
return sample


# Join Queries
class PushDownFilterThroughJoin(Rule):
def __init__(self):
Expand Down Expand Up @@ -758,6 +698,7 @@ def apply(self, before: LogicalGet, context: OptimizerContext):
before.dataset_metadata,
batch_mem_size=batch_mem_size,
predicate=before.predicate,
sampling_rate=before.sampling_rate,
)
)
return after
Expand Down Expand Up @@ -1038,10 +979,9 @@ def __init__(self):
self._rewrite_rules = [
EmbedFilterIntoGet(),
# EmbedFilterIntoDerivedGet(),
PushdownFilterThroughSample(),
EmbedProjectIntoGet(),
# EmbedProjectIntoDerivedGet(),
PushdownProjectThroughSample(),
EmbedSampleIntoGet(),
PushDownFilterThroughJoin(),
]

Expand Down
8 changes: 8 additions & 0 deletions eva/planner/storage_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class StoragePlan(AbstractPlan):
limit (int): limit on data records to be retrieved
total_shards (int): number of shards of data (if sharded)
curr_shard (int): current curr_shard if data is sharded
sampling_rate (int): uniform sampling rate
"""

def __init__(
Expand All @@ -43,6 +44,7 @@ def __init__(
total_shards: int = 0,
curr_shard: int = 0,
predicate: AbstractExpression = None,
sampling_rate: int = None,
):
super().__init__(PlanOprType.STORAGE_PLAN)
self._video = video
Expand All @@ -53,6 +55,7 @@ def __init__(
self._total_shards = total_shards
self._curr_shard = curr_shard
self._predicate = predicate
self._sampling_rate = sampling_rate

@property
def video(self):
Expand Down Expand Up @@ -86,6 +89,10 @@ def curr_shard(self):
def predicate(self):
return self._predicate

@property
def sampling_rate(self):
return self._sampling_rate

def __hash__(self) -> int:
return hash(
(
Expand All @@ -98,5 +105,6 @@ def __hash__(self) -> int:
self.total_shards,
self.curr_shard,
self.predicate,
self.sampling_rate,
)
)
7 changes: 3 additions & 4 deletions eva/readers/opencv_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,11 @@ def _read(self) -> Iterator[Dict]:
frame_id += 1
else:
for begin, end in range_list:

# align begin with sampling rate
if begin % self._sampling_rate:
begin += self._sampling_rate - (
begin % self._sampling_rate
)

begin += self._sampling_rate - (begin % self._sampling_rate)
print(begin, end + 1, self._sampling_rate)
for frame_id in range(begin, end + 1, self._sampling_rate):
video.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
_, frame = video.read()
Expand Down
24 changes: 23 additions & 1 deletion test/integration_tests/test_select_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,29 @@ def test_select_and_sample(self):
self.assertEqual(len(actual_batch), len(expected_batch[0]))
# Since frames are fetched in random order, this test might be flaky
# Disabling it for time being
# self.assertEqual(actual_batch, expected_batch[0])
self.assertEqual(actual_batch, expected_batch[0])

def test_aaselect_and_sample_with_predicate(self):
select_query = (
"SELECT name, id,data FROM MyVideo SAMPLE 2 WHERE id > 5 ORDER BY id;"
)
actual_batch = execute_query_fetch_all(select_query)
expected_batch = list(create_dummy_batches(filters=range(6, NUM_FRAMES, 2)))
self.assertEqual(actual_batch, expected_batch[0])

select_query = (
"SELECT name, id,data FROM MyVideo SAMPLE 4 WHERE id > 2 ORDER BY id;"
)
actual_batch = execute_query_fetch_all(select_query)
print(actual_batch)
expected_batch = list(create_dummy_batches(filters=range(4, NUM_FRAMES, 4)))
self.assertEqual(actual_batch, expected_batch[0])

select_query = "SELECT name, id,data FROM MyVideo SAMPLE 2 WHERE id > 2 AND id < 8 ORDER BY id;"
actual_batch = execute_query_fetch_all(select_query)
print(actual_batch)
expected_batch = list(create_dummy_batches(filters=range(4, 8, 2)))
self.assertEqual(actual_batch, expected_batch[0])

@pytest.mark.torchtest
def test_lateral_join(self):
Expand Down
41 changes: 3 additions & 38 deletions test/optimizer/rules/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@
LogicalGet,
LogicalProject,
LogicalQueryDerivedGet,
LogicalSample,
)
from eva.optimizer.rules.rules import (
EmbedFilterIntoDerivedGet,
EmbedFilterIntoGet,
EmbedProjectIntoDerivedGet,
EmbedProjectIntoGet,
EmbedSampleIntoGet,
LogicalCreateMaterializedViewToPhysical,
LogicalCreateToPhysical,
LogicalCreateUDFToPhysical,
Expand All @@ -60,8 +60,6 @@
LogicalUploadToPhysical,
Promise,
PushDownFilterThroughJoin,
PushdownFilterThroughSample,
PushdownProjectThroughSample,
RulesManager,
)
from eva.server.command_handler import execute_query_fetch_all
Expand All @@ -86,10 +84,7 @@ def test_rules_promises_order(self):
Promise.EMBED_PROJECT_INTO_DERIVED_GET > Promise.IMPLEMENTATION_DELIMETER
)
self.assertTrue(
Promise.PUSHDOWN_FILTER_THROUGH_SAMPLE > Promise.IMPLEMENTATION_DELIMETER
)
self.assertTrue(
Promise.PUSHDOWN_PROJECT_THROUGH_SAMPLE > Promise.IMPLEMENTATION_DELIMETER
Promise.EMBED_SAMPLE_INTO_GET > Promise.IMPLEMENTATION_DELIMETER
)
self.assertTrue(
Promise.EMBED_FILTER_INTO_GET > Promise.IMPLEMENTATION_DELIMETER
Expand Down Expand Up @@ -144,10 +139,9 @@ def test_supported_rules(self):
supported_rewrite_rules = [
EmbedFilterIntoGet(),
# EmbedFilterIntoDerivedGet(),
PushdownFilterThroughSample(),
EmbedProjectIntoGet(),
EmbedSampleIntoGet(),
# EmbedProjectIntoDerivedGet(),
PushdownProjectThroughSample(),
PushDownFilterThroughJoin(),
]
self.assertEqual(
Expand Down Expand Up @@ -255,35 +249,6 @@ def test_simple_project_into_derived_get(self):
self.assertFalse(rewrite_opr is logi_derived_get)
self.assertEqual(rewrite_opr.target_list, target_list)

# PushdownFilterThroughSample
def test_pushdown_filter_thru_sample(self):
rule = PushdownFilterThroughSample()
predicate = MagicMock()
constexpr = MagicMock()
logi_get = LogicalGet(MagicMock(), MagicMock(), MagicMock())
sample = LogicalSample(constexpr, [logi_get])
logi_filter = LogicalFilter(predicate, [sample])
rewrite_opr = rule.apply(logi_filter, MagicMock())
self.assertIsInstance(rewrite_opr, LogicalSample)
print(rewrite_opr.children[0])
self.assertIsInstance(rewrite_opr.children[0], LogicalFilter)
self.assertIsInstance(rewrite_opr.children[0].children[0], LogicalGet)

# PushdownProjectThroughSample
def test_pushdown_project_thru_sample(self):
rule = PushdownProjectThroughSample()
target_list = MagicMock()
constexpr = MagicMock()
logi_get = LogicalGet(MagicMock(), MagicMock(), MagicMock())
sample = LogicalSample(constexpr, [logi_get])
logi_project = LogicalProject(target_list, [sample])

rewrite_opr = rule.apply(logi_project, MagicMock())
self.assertTrue(rewrite_opr is sample)
self.assertFalse(rewrite_opr.children[0] is logi_project)
self.assertTrue(logi_get is rewrite_opr.children[0].children[0])
self.assertEqual(rewrite_opr.children[0].target_list, target_list)

def test_should_pushdown_filter_through_join(self):
query = """SELECT id, label
FROM MyVideo JOIN LATERAL
Expand Down
Loading

0 comments on commit 0703940

Please sign in to comment.