Skip to content

Commit

Permalink
LATERAL JOIN: Checkpoint 2: Optimizer Implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav274 committed Feb 28, 2021
1 parent 7db752a commit 127a9a4
Show file tree
Hide file tree
Showing 8 changed files with 275 additions and 21 deletions.
15 changes: 14 additions & 1 deletion src/optimizer/generators/seq_scan_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@
# limitations under the License.
from src.optimizer.generators.base import Generator
from src.optimizer.operators import LogicalGet, Operator, LogicalFilter, \
LogicalProject, LogicalUnion, LogicalOrderBy, LogicalLimit
LogicalProject, LogicalUnion, LogicalOrderBy, LogicalLimit, LogicalJoin, \
LogicalFunctionScan
from src.planner.seq_scan_plan import SeqScanPlan
from src.planner.union_plan import UnionPlan
from src.planner.storage_plan import StoragePlan
from src.planner.orderby_plan import OrderByPlan
from src.planner.limit_plan import LimitPlan
from src.planner.nested_loop_join_plan import NestedLoopJoin
from src.planner.function_scan import FunctionScan


class ScanGenerator(Generator):
Expand Down Expand Up @@ -59,6 +62,10 @@ def _visit_logical_limit(self, operator: LogicalLimit):
limitplan.append_child(self._plan)
self._plan = limitplan

def _visit_logical_join(self, operator: LogicalJoin):
joinplan = NestedLoopJoin(operator.join_type, operator.join_predicate)
self._plan = joinplan

def _visit(self, operator: Operator):
if isinstance(operator, LogicalUnion):
self._visit_logical_union(operator)
Expand All @@ -67,6 +74,12 @@ def _visit(self, operator: Operator):
for child in operator.children:
self._visit(child)

if isinstance(operator, LogicalFunctionScan):
self._plan = FunctionScan(operator.func_expr)

if isinstance(operator, LogicalJoin):
self._visit_logical_join(operator)

if isinstance(operator, LogicalOrderBy):
self._visit_logical_orderby(operator)

Expand Down
70 changes: 70 additions & 0 deletions src/optimizer/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from src.catalog.models.df_metadata import DataFrameMetadata
from src.expression.constant_value_expression import ConstantValueExpression
from src.parser.table_ref import TableRef
from src.parser.types import JoinType
from src.expression.abstract_expression import AbstractExpression
from src.catalog.models.df_column import DataFrameColumn
from src.catalog.models.udf_io import UdfIO
Expand All @@ -40,6 +41,8 @@ class OperatorType(IntEnum):
LOGICALUNION = auto()
LOGICALORDERBY = auto()
LOGICALLIMIT = auto()
LOGICALJOIN = auto()
LOGICAL_FUNCTION_SCAN = auto()


class Operator:
Expand Down Expand Up @@ -401,3 +404,70 @@ def __eq__(self, other):
return (is_subtree_equal
and self.table_metainfo == other.table_metainfo
and self.path == other.path)


class LogicalJoin(Operator):
"""
Logical node for join operators
Attributes:
join_type: JoinType
Join type provided by the user - Lateral, Inner, Outer
left: TableRef
Left join table
right: TableRef
Right join table
join_predicate: AbstractExpression
condition/predicate expression used to join the tables
"""

def __init__(self,
join_type: JoinType,
join_predicate: AbstractExpression = None,
children: List = None):
super().__init__(OperatorType.LOGICALJOIN, children)
self._join_type = join_type
self._join_predicate = join_predicate

@property
def join_type(self):
return self._join_type

@property
def join_predicate(self):
return self._join_predicate

def __eq__(self, other):
is_subtree_equal = super().__eq__(other)
if not isinstance(other, LogicalJoin):
return False
return (is_subtree_equal
and self.join_type == other.join_type
and self.join_predicate == other.join_predicate)


class LogicalFunctionScan(Operator):
"""
Logical node for function table scans
Attributes:
func_expr: AbstractExpression
function_expression that yield a table like output
"""

def __init__(self,
func_expr: AbstractExpression,
children: List = None):
super().__init__(OperatorType.LOGICAL_FUNCTION_SCAN, children)
self._func_expr = func_expr

@property
def func_expr(self):
return self._func_expr

def __eq__(self, other):
is_subtree_equal = super().__eq__(other)
if not isinstance(other, LogicalFunctionScan):
return False
return (is_subtree_equal
and self.func_expr == other.func_expr)
33 changes: 26 additions & 7 deletions src/optimizer/statement_to_opr_convertor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
LogicalInsert, LogicalCreate,
LogicalCreateUDF, LogicalLoadData,
LogicalQueryDerivedGet, LogicalUnion,
LogicalOrderBy, LogicalLimit)
LogicalOrderBy, LogicalLimit, LogicalJoin,
LogicalFunctionScan)
from src.parser.statement import AbstractStatement
from src.parser.select_statement import SelectStatement
from src.parser.insert_statement import InsertTableStatement
Expand Down Expand Up @@ -46,16 +47,34 @@ def _populate_column_map(self, dataset: DataFrameMetadata):
self._column_map[column.name.lower()] = column

def visit_table_ref(self, video: TableRef):
"""Bind table ref object and convert to Logical get operator
"""Bind table ref object and convert to Logical Get
or Logical Join
Arguments:
video {TableRef} -- [Input table ref object created by the parser]
"""
catalog_vid_metadata = bind_dataset(video.table_info)

self._populate_column_map(catalog_vid_metadata)

self._plan = LogicalGet(video, catalog_vid_metadata)
if (video.table_info):
catalog_vid_metadata = bind_dataset(video.table_info)
self._populate_column_map(catalog_vid_metadata)
self._plan = LogicalGet(video, catalog_vid_metadata)

if (video.join):
if isinstance(video.join.left, TableRef):
self.visit_table_ref(video.join.left)
if isinstance(video.join.left, AbstractExpression):
self._plan = LogicalFunctionScan(func_expr=video.join.left)
left_child_plan = self._plan

if isinstance(video.join.right, TableRef):
self.visit_table_ref(video.join.right)
if isinstance(video.join.right, AbstractExpression):
self._plan = LogicalFunctionScan(func_expr=video.join.right)
right_child_plan = self._plan

self._plan = LogicalJoin(join_type=video.join.join_type,
join_predicate=video.join.predicate)
self._plan.append_child(left_child_plan)
self._plan.append_child(right_child_plan)

def visit_select(self, statement: SelectStatement):
"""converter for select statement
Expand Down
50 changes: 50 additions & 0 deletions src/planner/abstract_join_plan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# coding=utf-8
# Copyright 2018-2020 EVA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Abstract class for all the join planners
"""
from src.expression.abstract_expression import AbstractExpression
from src.planner.abstract_plan import AbstractPlan
from src.parser.types import JoinType

from src.planner.types import PlanNodeType


class AbstractJoin(AbstractPlan):
"""Abstract class for all the join based planners
Arguments:
join_type: JoinType
type of join, INNER, OUTER , LATERAL etc
join_predicate: AbstractExpression
An expression used for joining
"""

def __init__(self,
node_type: PlanNodeType,
join_type: JoinType,
join_predicate: AbstractExpression):
super().__init__(node_type)
self._join_type = join_type
self._join_predicate = join_predicate

@property
def join_type(self) -> AbstractExpression:
return self._join_type

@property
def join_predicate(self) -> AbstractExpression:
return self._join_predicate
35 changes: 35 additions & 0 deletions src/planner/function_scan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# coding=utf-8
# Copyright 2018-2020 EVA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from src.planner.types import PlanNodeType
from src.planner.abstract_plan import AbstractPlan
from src.expression.function_expression import FunctionExpression


class FunctionScan(AbstractPlan):
"""
This plan used to store metadata to perform function table scan.
Arguments:
"""

def __init__(self, func_expr: FunctionExpression):
self._func_expr = func_expr
super().__init__(PlanNodeType.FUNCTION_SCAN)

@property
def func_expr(self):
return self._func_expr
34 changes: 34 additions & 0 deletions src/planner/nested_loop_join_plan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# coding=utf-8
# Copyright 2018-2020 EVA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from src.planner.types import PlanNodeType
from src.planner.abstract_join_plan import AbstractJoin
from src.expression.abstract_expression import AbstractExpression
from src.parser.types import JoinType


class NestedLoopJoin(AbstractJoin):
"""
This plan is used for storing information required for nested loop join.
Arguments:
join_type: JoinType
join_predicate: AbstractExpression
"""

def __init__(self,
join_type: JoinType,
join_predicate: AbstractExpression):
super().__init__(PlanNodeType.JOIN, join_type, join_predicate)
24 changes: 13 additions & 11 deletions src/planner/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import unique, IntEnum
from enum import unique, IntEnum, auto


@unique
class PlanNodeType(IntEnum):
SEQUENTIAL_SCAN = 1
STORAGE_PLAN = 2
PP_FILTER = 3
INSERT = 4
CREATE = 5
CREATE_UDF = 6
LOAD_DATA = 7
UNION = 8
ORDER_BY = 9
LIMIT = 10
SEQUENTIAL_SCAN = auto()
STORAGE_PLAN = auto()
PP_FILTER = auto()
INSERT = auto()
CREATE = auto()
CREATE_UDF = auto()
LOAD_DATA = auto()
UNION = auto()
ORDER_BY = auto()
LIMIT = auto()
JOIN = auto()
FUNCTION_SCAN = auto()
# add other types
35 changes: 33 additions & 2 deletions test/optimizer/test_statement_to_opr_convertor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,16 @@
LogicalQueryDerivedGet, LogicalCreate,
LogicalCreateUDF, LogicalInsert,
LogicalLoadData, LogicalUnion,
LogicalOrderBy, LogicalLimit)
LogicalOrderBy, LogicalLimit, LogicalJoin,
LogicalFunctionScan)

from src.expression.tuple_value_expression import TupleValueExpression
from src.expression.constant_value_expression import ConstantValueExpression
from src.expression.comparison_expression import ComparisonExpression
from src.expression.function_expression import FunctionExpression
from src.expression.abstract_expression import ExpressionType

from src.parser.types import ParserOrderBySortType
from src.parser.types import ParserOrderBySortType, JoinType


class StatementToOprTest(unittest.TestCase):
Expand Down Expand Up @@ -410,3 +412,32 @@ def test_should_return_false_for_unequal_plans(self):
self.assertNotEqual(query_derived_plan, create_plan)
self.assertNotEqual(insert_plan, query_derived_plan)
self.assertNotEqual(load_plan, insert_plan)

@patch('src.optimizer.statement_to_opr_convertor.bind_dataset')
@patch('src.optimizer.statement_to_opr_convertor.bind_columns_expr')
@patch('src.optimizer.statement_to_opr_convertor.bind_predicate_expr')
def test_should_handle_lateral_join(self, mock_p, mock_c, mock_d):
m = MagicMock()
mock_p.return_value = mock_c.return_value = mock_d.return_value = m
stmt = Parser().parse("""SELECT id FROM DETRAC,
LATERAL UNNEST(ObjDet(frame)) WHERE id > 3;""")[0]
converter = StatementToPlanConvertor()
actual_plan = converter.visit(stmt)

right = FunctionExpression(func=None, name='UNNEST')
child = FunctionExpression(func=None, name='ObjDet')
child.append_child(TupleValueExpression('frame'))
right.append_child(child)
left = LogicalGet(TableRef(TableInfo('DETRAC')), m)
join = LogicalJoin(JoinType.LATERAL_JOIN)
join.append_child(left)
join.append_child(LogicalFunctionScan(right))
filter = LogicalFilter(
ComparisonExpression(
ExpressionType.COMPARE_GREATER,
TupleValueExpression('id'),
ConstantValueExpression(3)))
filter.append_child(join)
expected_plan = LogicalProject([TupleValueExpression('id')])
expected_plan.append_child(filter)
self.assertEqual(actual_plan, expected_plan)

0 comments on commit 127a9a4

Please sign in to comment.