# import

In [309]:
from pglast import ast,parser,visitors,printers,enums
from pprint import pprint
import pyarrow as arrow
from pyarrow import csv,compute,types
import pandas
from datetime import datetime,timedelta,date
import hashlib

# 取tpch q1 ast

In [310]:
q1Stmt = parser.parse_sql(
    "select \
        l_returnflag, \
        l_linestatus, \
        sum(l_quantity) as sum_qty, \
        sum(l_extendedprice) as sum_base_price, \
        sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, \
        sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, \
        avg(l_quantity) as avg_qty, \
        avg(l_extendedprice) as avg_price, \
        avg(l_discount) as avg_disc, \
        count(*) as count_order \
    from \
        lineitem \
    where \
        l_shipdate <= date '1998-12-01' - interval '112' day \
    group by \
        l_returnflag, \
        l_linestatus \
    order by \
        l_returnflag, \
        l_linestatus;"
)[0]

In [311]:
#q1Stmt

# 表lineitem的schema

In [312]:
# create table lineitem ( l_orderkey    bigint not null,
#                              l_partkey     integer not null,
#                              l_suppkey     integer not null,
#                              l_linenumber  integer not null,
#                              l_quantity    decimal(15,2) not null,
#                              l_extendedprice  decimal(15,2) not null,
#                              l_discount    decimal(15,2) not null,
#                              l_tax         decimal(15,2) not null,
#                              l_returnflag  varchar(1) not null,
#                              l_linestatus  varchar(1) not null,
#                              l_shipdate    date not null,
#                              l_commitdate  date not null,
#                              l_receiptdate date not null,
#                              l_shipinstruct varchar(25) /*char(25)*/ not null,
#                              l_shipmode     varchar(10) /*char(10)*/ not null,
#                              l_comment      varchar(44) not null,
#                          primary key (l_orderkey, l_linenumber)
#                         );

In [313]:
lineitemSchema = arrow.schema([
    arrow.field("l_orderkey",arrow.int64(),False),
    arrow.field("l_partkey",arrow.int32(),False),
    arrow.field("l_suppkey",arrow.int32(),False),
    arrow.field("l_linenumber",arrow.int32(),False),
    arrow.field("l_quantity",arrow.decimal256(21,2),False),
    arrow.field("l_extendedprice",arrow.decimal256(21,2),False),
    arrow.field("l_discount",arrow.decimal256(21,2),False),
    arrow.field("l_tax",arrow.decimal256(21,2),False),
    arrow.field("l_returnflag",arrow.utf8(),False),
    arrow.field("l_linestatus",arrow.utf8(),False),
    arrow.field("l_shipdate",arrow.date32(),False),
    arrow.field("l_commitdate",arrow.date32(),False),
    arrow.field("l_receiptdate",arrow.date32(),False),
    arrow.field("l_shipinstruct",arrow.utf8(),False),
    arrow.field("l_shipmode",arrow.utf8(),False),
    arrow.field("l_comment",arrow.utf8(),False),
],{"name":"lineitem",
   "delimiter":"|",
   "path":"/Users/pengzhen/Documents/GitHub/mo-test/tpch100M/lineitem.tbl_10"})
# lineitemSchema

# 常量

In [314]:

# logical plan 标签
RELATIONS = "relations"
SINGLE_RELATION = "singleRelation"
MULTI_RELATION = "multipleRelation"
WHERE = "where"
OUTPUT = "output"
GROUP = "group"
ORDER = "order"
AGGREGATE = "aggregate"

RELATION_TYPE_IDX = 0
DB_NAME_IDX = 1
ALIAS_IDX = 2
TABLE_NAME_IDX = 3
SCHEMA_IDX = 4

RELATION_TYPE_TABLE = "table"
RELATION_TYPE_VIEW = "view"
RELATION_TYPE_SUBQUERY = "subquery"

EXPR_OP_COLUMN_REF = "colRef"
EXPR_OP_LESS_EQUAL = "<="
EXPR_OP_PLUS = "+"
EXPR_OP_MINUS = "-"
EXPR_OP_MULTI = "*"
EXPR_OP_CAST = "cast"
EXPR_OP_CONST = "const"

FUNC_EXPR = "func"
FUNC_RESULT_REF_EXPR = "func_result_ref"
OUTPUT_EXPR = "output_expr"
GROUP_EXPR = "group_expr"
ORDER_EXPR = "order_expr"

# 取schema

In [315]:
Catalog = {
    "tpch":{
        "lineitem":lineitemSchema
    }
}

def isValidName(s):
    return not (s is None or len(s) == 0)


def getColumn(plan,dbName,tableName,colName):
    '''
    取列定义
    :param plan:
    :param dbName:
    :param tableName:
    :param colName:
    :return:
    '''
    relations = plan.get(RELATIONS,None)
    if relations is None:
        raise Exception("no table defs")

    if len(dbName) == 0:
        dbName = "tpch"

    if relations[0] == SINGLE_RELATION:
        single = relations[1]
        if len(tableName) == 0:
            #print(f"single {single}")
            for tableName2,tableDef in single.items():
                schema = tableDef[SCHEMA_IDX]
                colDef = schema.field(colName)
                if colDef is not None:
                    return colDef,dbName,tableName2
            raise Exception(f"no column name {colName} in table {tableName}")

        if tableName in single:
            tableDef = single[tableName]
            if tableDef[DB_NAME_IDX] == dbName and tableDef[ALIAS_IDX] == tableName:
                schema = tableDef[SCHEMA_IDX]
                colDef = schema.field(colName)
                if colDef is not None:
                    return colDef,dbName,tableName
                else:
                    raise Exception(f"no column name {colName} in table {tableName}")
            else:
                raise Exception(f"invalid database name {dbName} or table name {tableName}")
        else:
            raise Exception(f"no such relation {tableName} in database {dbName}")
    else:
        raise Exception("not implement multiple relations")



Catalog

{'tpch': {'lineitem': l_orderkey: int64 not null
  l_partkey: int32 not null
  l_suppkey: int32 not null
  l_linenumber: int32 not null
  l_quantity: decimal256(21, 2) not null
  l_extendedprice: decimal256(21, 2) not null
  l_discount: decimal256(21, 2) not null
  l_tax: decimal256(21, 2) not null
  l_returnflag: string not null
  l_linestatus: string not null
  l_shipdate: date32[day] not null
  l_commitdate: date32[day] not null
  l_receiptdate: date32[day] not null
  l_shipinstruct: string not null
  l_shipmode: string not null
  l_comment: string not null
  -- schema metadata --
  name: 'lineitem'
  delimiter: '|'
  path: '/Users/pengzhen/Documents/GitHub/mo-test/tpch100M/lineitem.tbl_10'}}

# 构建tpch q1的逻辑查询计划

## 逻辑查询计划builder

从各种SQL语句构建逻辑查询计划

逻辑查询计划定义为字典类型。

In [316]:
class LogicalPlanBuilder:
    def __init__(self):
        pass

    def build(self,node,plan : dict):
        pass

## select语句builder

为select语句生成逻辑查询计划。

In [317]:
class SelectBuilder(LogicalPlanBuilder):
    def __init__(self):
        super().__init__()

    def build(self,select : ast.SelectStmt,plan : dict):
        fb = FromBuilder()
        fb.build(select.fromClause,plan)

        wb = WhereBuilder()
        wb.build(select.whereClause,plan)

        selectList = SelectListBuilder()
        selectList.build(select.targetList,plan)

        groupby = GroupbyBuilder()
        groupby.build(select.groupClause,plan)

        orderby = OrderbyBuilder()
        orderby.build(select.sortClause,plan)
        pass

## from子句builder

从from子句中提取各个关系表。

In [318]:
class FromBuilder(LogicalPlanBuilder):
    def __init__(self):
        super().__init__()
    def build(self,tableRefs : tuple,plan : dict):
        #print("tableRefs")
        if len(tableRefs) == 1:
            single = self.buildTableRef(tableRefs[0])
            plan[RELATIONS] = [SINGLE_RELATION,single]
            return
        raise Exception("unsupport multiple table refs")
        pass

    def buildTableRef(self,tableRef : ast.RangeVar)-> dict:
        #print("tableRef")
        dbName = tableRef.schemaname
        if dbName is None or len(dbName) == 0:
            dbName = "tpch"

        #print(Catalog[dbName])
        if tableRef.relname in Catalog[dbName] :
            return {tableRef.relname :
                    [RELATION_TYPE_TABLE, #relation type
                    dbName, #database name
                    tableRef.relname,#alias=
                    tableRef.relname,#original name
                    Catalog[dbName][tableRef.relname] #schema
                    ]}
        else:
            raise Exception("no such table in Catalog",tableRef.schemaname,tableRef.relname)
        pass

## 表达式builder

表达式构建基类。也是最复杂的类。

In [319]:
class ExpressionBuilder(LogicalPlanBuilder):
    def __init__(self):
        super().__init__()

    def build(self,node,plan : dict):
        if isinstance(node,ast.A_Expr):
            if node.kind == enums.parsenodes.A_Expr_Kind.AEXPR_OP:
                opName = node.name[0].sval
                if opName == "<=":
                    l = self.build(node.lexpr,plan)
                    r = self.build(node.rexpr,plan)
                    return EXPR_OP_LESS_EQUAL,l,r
                elif opName == "-":
                    l = self.build(node.lexpr,plan)
                    r = self.build(node.rexpr,plan)
                    return EXPR_OP_MINUS,l,r
                elif opName == "*":
                    l = self.build(node.lexpr,plan)
                    r = self.build(node.rexpr,plan)
                    return EXPR_OP_MULTI,l,r
                elif opName == "+":
                    l = self.build(node.lexpr,plan)
                    r = self.build(node.rexpr,plan)
                    return EXPR_OP_PLUS,l,r
                else:
                    raise Exception("unsupported operator",node)
            else:
                raise Exception("unsupported expr 1",node)
        elif isinstance(node,ast.ColumnRef):
            fields = node.fields
            if len(fields) == 1:
                colName = fields[0].sval
                #print(f"{fields}, {colName}")
                colRef = getColumn(plan,"","",colName)
                return EXPR_OP_COLUMN_REF, colRef
            raise Exception("unsupported column ref",node)
        elif isinstance(node,ast.TypeCast):
            #print(f"type_cast {node.arg} \nto_type {node.typeName}")

            e = self.build(node.arg,plan)
            t = node.typeName.names
            return EXPR_OP_CAST,e,t
        elif isinstance(node,ast.A_Const):
            if node.isnull:
                return EXPR_OP_CONST,node.isnull
            return EXPR_OP_CONST,node.isnull,node.val
        elif isinstance(node,ast.FuncCall):
            # pprint(node)
            func_name = node.funcname[0].sval
            is_agg_func = False
            if is_aggregate_func(func_name):
                is_agg_func = True
            args = None
            if node.args is not None:
                args = []
                for arg in node.args:
                    arg_e = self.build(arg,plan)
                    args.append(arg_e)
            elif node.agg_star:
                args = "*"
            else:
                raise Exception("function has no args",node)
            if is_agg_func:
                # print(f"agg func {func_name}")
                aggs = plan.get(AGGREGATE,[])
                agg_idx = len(aggs)
                aggs.append((FUNC_EXPR,func_name,args))
                plan[AGGREGATE] = aggs
                return FUNC_RESULT_REF_EXPR,agg_idx
            return FUNC_EXPR,func_name,args
        else:
            raise Exception("unsupported expr 2",node)
        pass


def is_aggregate_func(name):
    return name in ["count","avg","sum"]

## where子句builder

构建where表达式。

In [320]:
class WhereBuilder(LogicalPlanBuilder):
    def __init__(self):
        super().__init__()

    def build(self,node,plan : dict):
        #print("where ",node)
        eb = ExpressionBuilder()
        ret = eb.build(node,plan)
        plan[WHERE] = ret

## select expr builder

In [321]:
class SelectExprBuilder(ExpressionBuilder):
    def __init__(self):
        super().__init__()

    def build(self,node,plan : dict):
        if isinstance(node,ast.ResTarget):
            value = node.val
            if isinstance(value,ast.ColumnRef):
                r = super().build(value,plan)
                #print(f"-->{r}")
                alias = node.name
                if not isValidName(alias):
                    #取列名
                    alias = r[1][0].name
                return OUTPUT_EXPR, alias, r
            elif isinstance(value,ast.FuncCall):
                r = super().build(value,plan)
                alias = node.name
                if not isValidName(alias):
                    #取表达式字符串
                    alias = str(value)
                return OUTPUT_EXPR,alias,r
            else:
                return super().build(value,plan)
        else:
            return super().build(node,plan)


## select list builder

In [322]:

class SelectListBuilder(LogicalPlanBuilder):
    def __init__(self):
        super().__init__()

    def build(self,node,plan : dict):
        #print("select list",node)
        #pprint(node)
        selExprBuilder = SelectExprBuilder()
        output = []
        for expr in node:
            #pprint(expr)
            o = selExprBuilder.build(expr,plan)
            output.append(o)
        plan[OUTPUT] = output

## group expr builder

In [323]:
class GroupExprBuilder(ExpressionBuilder):
    def __init__(self):
        super().__init__()

    def build(self,node,plan : dict):
        r = super().build(node,plan)
        return GROUP_EXPR,r

## group by子句builder

In [324]:
class GroupbyBuilder(LogicalPlanBuilder):
    def __init__(self):
        super().__init__()

    def build(self,node,plan : dict):
        geb = GroupExprBuilder()
        groupby = []
        for expr in node:
            r = geb.build(expr,plan)
            groupby.append(r)
        plan[GROUP] = groupby
        pass

## order expr builder

In [325]:
class OrderExprBuilder(ExpressionBuilder):
    def __init__(self):
        super().__init__()

    def build(self,node,plan : dict):
        if isinstance(node,ast.SortBy):
            r = super().build(node.node,plan)
            return ORDER_EXPR,r,node.sortby_dir
        else:
            raise Exception(f"not implement order expr {node}")

## order by builder

In [326]:
class OrderbyBuilder(LogicalPlanBuilder):
    def __init__(self):
        super().__init__()

    def build(self,node,plan : dict):
        oeb = OrderExprBuilder()
        orderby = []
        for expr in node:
            r = oeb.build(expr,plan)
            orderby.append(r)
        plan[ORDER] = orderby
        pass

# 构建tpch q1的物理查询计划

## 物理执行计划执行器

In [327]:
class Executor:
    def __init__(self):
        pass

    def Open(self):
        pass

    def Next(self):
        pass

    def Close(self):
        pass

## csv table scan执行器

In [328]:
class CsvTableScan(Executor):
    def __init__(self,schema : arrow.Schema,column_names : list,drop_columns : list):
        super().__init__()
        self.reader = None
        self.schema = schema
        self.column_names = column_names
        self.drop_columns = drop_columns
        self.block_size = 16 * 1024

    def Open(self):
        # 打开文件
        meta = self.schema.metadata
        path = meta.get(b"path").decode('utf-8')
        delimiter = meta.get(b"delimiter").decode('utf-8')
        # pprint(meta)
        # pprint(meta.get(b"path").decode('utf-8'))
        # pprint(meta.get(b"delimiter").decode('utf-8'))
        read_opts = arrow.csv.ReadOptions(column_names = self.column_names,
                                          block_size = self.block_size,
                                          autogenerate_column_names = bool)
        parse_opts = arrow.csv.ParseOptions(delimiter = delimiter)
        convert_opts = arrow.csv.ConvertOptions(column_types = self.schema,
                                                include_columns = self.column_names)
        self.reader = arrow.csv.open_csv(path,read_options = read_opts,parse_options = parse_opts,convert_options = convert_opts)

    def Next(self):
        chunk = self.reader.read_next_batch()
        needed_arrays = []
        for col_idx in range(chunk.num_columns):
            if col_idx in self.drop_columns:
                continue
            needed_arrays.append(chunk.column(col_idx))
        # pprint(self.schema)
        # pprint(needed_arrays)
        ret_chunk = arrow.RecordBatch.from_arrays(needed_arrays,self.column_names,self.schema)
        return ret_chunk

    def Close(self):
        self.reader.close()

In [329]:
csvReader = CsvTableScan(lineitemSchema,None,[16])
csvReader.Open()
record1 = csvReader.Next()
pprint(record1.to_pandas())
csvReader.Close()

   l_orderkey  l_partkey  l_suppkey  l_linenumber l_quantity l_extendedprice  \
0           1      15519        785             1      17.00        24386.67   
1           1       6731        732             2      36.00        58958.28   
2           1       6370        371             3       8.00        10210.96   
3           1        214        465             4      28.00        31197.88   
4           1       2403        160             5      24.00        31329.60   
5           1       1564         67             6      32.00        46897.92   
6           2      10617        138             1      38.00        58049.18   
7           3        430        181             1      45.00        59869.35   
8           3       1904        658             2      49.00        88489.10   
9           3      12845        370             3      27.00        47461.68   

  l_discount l_tax l_returnflag l_linestatus  l_shipdate l_commitdate  \
0       0.04  0.02            N            O  

## 表达式执行函数

In [330]:
def exec_expr(expr,records):
    expr_type = expr[0]
    if expr_type == EXPR_OP_COLUMN_REF:
        return records.column(expr[1][0].name)
    elif expr_type == EXPR_OP_CONST:
        if expr[1]:# NULL
            raise Exception(f"not implement const NULL")
        else:
            if isinstance(expr[2],ast.String):
                return arrow.array([expr[2].sval]*records.num_rows,arrow.string())
            elif isinstance(expr[2],ast.Integer):
                return arrow.array([expr[2].ival]*records.num_rows,arrow.int32())
            else:
                raise Exception(f"not implement const expr {expr}")
    elif expr_type == EXPR_OP_CAST:
        #直接转换
        l = exec_expr(expr[1],records)
        target_type = expr[2][0].sval
        if len(expr[2]) > 1:
            target_type = expr[2][1].sval
        if target_type == "date":
            date_vals = []
            for s in l:
                date_vals.append(datetime.strptime(str(s),"%Y-%m-%d"))
            return arrow.array(date_vals,arrow.date32())
        elif target_type == "interval":
            int_vals = []
            for s in l:
                int_vals.append(int(str(s)))
            return arrow.array(int_vals,arrow.int32())
        else:
            print(f"target_type {target_type}")
            return compute.cast(l,target_type)
    elif expr_type == EXPR_OP_PLUS:
        l = exec_expr(expr[1],records)
        r = exec_expr(expr[2],records)
        return compute.add(l,r)
    elif expr_type == EXPR_OP_MINUS:
        l = exec_expr(expr[1],records)
        r = exec_expr(expr[2],records)

        if types.is_date32(l.type):
            #对时间的减法做特殊处理
            date_time_vals = []
            for d in l:
                date_time_vals.append(datetime.strptime(str(d),"%Y-%m-%d"))
            time_delta_vals = []
            if types.is_int32(r.type):
                for i in r:
                    time_delta_vals.append(timedelta(int(str(i))))
            else:
                raise Exception("date minus needs int32")

            res_vals = [date_time_vals[i] - time_delta_vals[i] for i in range(len(time_delta_vals))]
            return arrow.array(res_vals,arrow.date32())
        else:
            return compute.subtract(l,r)
    elif expr_type == EXPR_OP_MULTI:
        l = exec_expr(expr[1],records)
        r = exec_expr(expr[2],records)
        return compute.multiply(l,r)
    elif expr_type == EXPR_OP_LESS_EQUAL:
        l = exec_expr(expr[1],records)
        r = exec_expr(expr[2],records)
        return compute.less_equal(l,r)
    elif expr_type == GROUP_EXPR:
        return exec_expr(expr[1],records)
    else:
        raise Exception(f"not implement exec expr {expr}")
    pass

## filter执行器
执行where子句

In [331]:
class FilterExecutor(Executor):
    def __init__(self, filter: tuple, child: Executor):
        super().__init__()
        self.filter = filter
        self.child = child

    def Open(self):
        self.child.Open()

    def Next(self):
        child_records = self.child.Next()
        mask = exec_expr(self.filter,child_records)
        return child_records.filter(mask)

    def Close(self):
        self.child.Close()

### 测试时间减法

In [332]:
date1 = datetime.strptime('1998-12-01',"%Y-%m-%d")
date2_arr = arrow.array([date1]*2,arrow.date32())
# pprint(date2_arr)
date4_vals = []
if types.is_date32(date2_arr.type):
    for d in date2_arr:
        print(datetime.strptime(str(d),"%Y-%m-%d"))

date3_arr = arrow.array([112]*2,arrow.int32())
if types.is_int32(date3_arr.type):
    for i in date3_arr:
        print(timedelta(int(str(i))))



1998-12-01 00:00:00
1998-12-01 00:00:00
112 days, 0:00:00
112 days, 0:00:00


## groupby执行器

In [333]:
class GroupbyExecutor(Executor):
    def __init__(self, plan: dict, child: Executor):
        super().__init__()
        self.plan = plan
        self.child = child
        self.aggregate = plan.get(AGGREGATE,[])
        self.groupby = plan.get(GROUP,[])
        self.hash_func = hashlib.sha256()
        #构建哈希表
        self.hash_table = {}


    def gen_hash(self,records):
        hash_vals = []
        for input in records:
            self.hash_func.update(str(input).encode("utf-8"))
            val = self.hash_func.hexdigest()
            # print(input,type(val))
            hash_vals.append(val)
        return hash_vals

    def update_agg_func_val(self,hash_key,param_vals,row_idx):
        hash_val = self.hash_table.get(hash_key,{})
        for agg_idx in range(len(self.aggregate)):
            agg_partial = hash_val.get(agg_idx,None)
            agg_func = self.aggregate[agg_idx]
            agg_func_name = agg_func[1]
            if agg_func_name == "sum":
                pprint(f"row_idx {row_idx} param_vals {param_vals[agg_idx]}")
                param_val_row = param_vals[agg_idx][0][row_idx]
                if agg_partial is None:
                    agg_partial = [agg_func_name,param_val_row]
                else:
                    partial_add_val = compute.add(agg_partial[1],param_val_row)
                    agg_partial = [agg_func_name,partial_add_val]
            elif agg_func_name == "count":
                if agg_partial is None:
                    agg_partial = [agg_func_name,1]
                else:
                    agg_partial[1] = agg_partial[1]+1
                    agg_partial = [agg_func_name,agg_partial[1]]
            elif agg_func_name == "avg":
                param_val_row = param_vals[agg_idx][0][row_idx]
                if agg_partial is None:
                    agg_partial = [agg_func_name,1,param_val_row]
                else:
                    agg_partial[1] = agg_partial[1]+1
                    partial_add_val = compute.add(agg_partial[2],param_val_row)
                    agg_partial = [agg_func_name,agg_partial[1],partial_add_val]
            else:
                raise Exception(f"not implement agg func {agg_func}")
            hash_val[agg_idx] = agg_partial
    def end_agg_func_val(self):
        #TODO:
        pass

    def update_hash_table(self,hash_vals,records):
        # 计算聚合函数参数表达式
        param_vals = []
        for agg_idx in range(len(self.aggregate)):
            agg_func = self.aggregate[agg_idx]
            # pprint(f"agg_func {agg_func}")
            # 计算聚合函数每个参数
            agg_func_name = agg_func[1]
            agg_args = agg_func[2]
            agg_arg_vals = []
            for arg in agg_args:
                # pprint(f"agg arg {arg}")
                if agg_func_name == "count" and arg == "*":
                    agg_arg_vals.append(arrow.array(["*"]*records.num_rows,arrow.string()))
                else:
                    agg_arg_val = exec_expr(arg,records)
                    agg_arg_vals.append(agg_arg_val)
            param_vals.append(agg_arg_vals)

        # pprint(f"param_vals {param_vals}")

        pprint(f"hash_vals {hash_vals}")
        # 根据哈希key分组
        row_count = len(hash_vals[0])
        col_count = len(hash_vals)
        for r in range(row_count):
            hash_key = []
            for c in range(col_count):
                hash_key.append(hash_vals[c][r])
            #更新聚合函数的中间结果
            self.update_agg_func_val(",".join(hash_key),param_vals,r)

    def Open(self):
        self.child.Open()

    def Next(self):
        records = self.child.Next()
        while records is not None:
            # 计算groupby表达式,求groupby值的哈希
            group_by_hash_vals = []
            for e in self.groupby:
                # pprint(e)
                val = exec_expr(e,records)
                hash_val = self.gen_hash(val)
                group_by_hash_vals.append(hash_val)
            # pprint(group_by_hash_vals)
            self.update_hash_table(group_by_hash_vals,records)
            # 下一批输入
            records = self.child.Next()

        #完成聚合函数的结果
        ret_records = self.end_agg_func_val()
        return ret_records

    def Close(self):
        self.child.Close()

## 物理计划builder

In [334]:
class PhysicalPlanBuilder:
    def __init__(self):
        pass

    def build(self,plan : dict,node : str):
        if node == RELATIONS:
            rel_info = plan.get(RELATIONS)
            if rel_info[0] == SINGLE_RELATION:
                rel_def = rel_info[1]
                table_name = list(rel_def.keys())[0]
                table_def = rel_def.get(table_name)
                schema = table_def[SCHEMA_IDX]
                return CsvTableScan(schema,None,[16])
            else:
                raise Exception(f"not implement relations {rel_info[0]}")

        elif node == WHERE :
            child = self.build(plan,RELATIONS)
            # 生成filter执行器
            filter = plan.get(WHERE)
            filter_exec = FilterExecutor(filter,child)
            return filter_exec
        elif node == GROUP:
            child = self.build(plan,WHERE)
            # 生成groupby执行器
            groupby = GroupbyExecutor(plan,child)
            return groupby
        else:
            raise Exception(f"not implement plan")


# main

## 执行逻辑计划

In [335]:
plan = {}
selBuilder = SelectBuilder()
selBuilder.build(q1Stmt.stmt,plan)
pprint(plan)

{'aggregate': [('func',
                'sum',
                [('colRef',
                  (pyarrow.Field<l_quantity: decimal256(21, 2) not null>,
                   'tpch',
                   'lineitem'))]),
               ('func',
                'sum',
                [('colRef',
                  (pyarrow.Field<l_extendedprice: decimal256(21, 2) not null>,
                   'tpch',
                   'lineitem'))]),
               ('func',
                'sum',
                [('*',
                  ('colRef',
                   (pyarrow.Field<l_extendedprice: decimal256(21, 2) not null>,
                    'tpch',
                    'lineitem')),
                  ('-',
                   ('const', False, <Integer ival=1>),
                   ('colRef',
                    (pyarrow.Field<l_discount: decimal256(21, 2) not null>,
                     'tpch',
                     'lineitem'))))]),
               ('func',
                'sum',
                [('*',
         


## 执行物理计划

In [336]:
pplan_builder = PhysicalPlanBuilder()
exec = pplan_builder.build(plan,GROUP)
exec.Open()
records = exec.Next()
pprint(records.to_pandas())


('hash_vals '
 "[['8ce86a6ae65d3692e7305e2c58ac62eebd97d3d943e093f577da25c36988246b', "
 "'15c841dff8407a197a8be456d43283d973c0ba690724b24da5e5d0d6dbc70423', "
 "'c6194eb92ed46a0996c1cab8662c10bc6b176ddc6599998d35c2e6eb0a357364', "
 "'06ff7b7828c546ebf947a94cd81e3d2f89c05b5e1f70d85ed1e3da47847e33e1', "
 "'4dd65d504cb177906c39754c46ba99ea9699a0499ecc4f11a64bd709c26bdda0', "
 "'79b73eb424a23572d5a72cc8ca3eff14309ca5c41552f5c106e814347d7c1363', "
 "'a1f2843a5b3ff3bcd7d4a5ea4126565e7c1f727181ff3455bdf0c90d6ef27381', "
 "'7f52ded415fedfc85466ded02511189c89bcdf919e8195ba9965e52c212ffbb3', "
 "'b14b104072b2e22296cfb90e8382ca03c43d0221005a2be583d1ecdf9b4f0b7b', "
 "'cd67c3dac764b35e35372c97422867301e1f8e623be064bbbebcd01de4ca6b3a'], "
 "['0d546db37a959327185f8f0df9bfd0633a684b2641e3edf35f66deb8fc013855', "
 "'b82849c8c169a1d4b20f8be13c80f51b5b7889d9ee49a32770ce2536e10a6fc0', "
 "'c18bf9bb1fad956e10f5563252581b27e8f6030546b13c8205931896ea72d31f', "
 "'295354c5cc52e69e83030d67af2e3bcb490296cf856

StopIteration: 

/*### 执行filter方法1*/

In [None]:
# date1 = datetime.strptime('1998-12-01',"%Y-%m-%d")
# interval1 = timedelta(1500)
# date2 = date1 - interval1
# date2_arr = arrow.array([date2]*records.num_rows,arrow.date32())
# pprint(date2_arr)
# # date1_arr = arrow.array([date1]*records.num_rows,arrow.date32())
# # interval_arr = arrow.array([112]*records.num_rows,arrow.int32())
# # # arrow 没有时间相减的函数
# # date3_arr = compute.subtract(date1_arr,interval_arr)
#
# #cast1 = compute.cast()
# mask1 = compute.less_equal(records.column("l_shipdate"),date2_arr)
# pprint(mask1.to_pandas())
# res1 = records.filter(mask1)
# pprint(res1.to_pandas())

/*### 执行filter方法2*/

In [None]:
# table1 = arrow.Table.from_batches([records])
# le1 = compute.less_equal(compute.field("l_orderkey"),1)
# res2 = table1.filter(le1)
# pprint(res2.to_pandas())


In [None]:
exec.Close()