# 摘要

本文旨在以最快的coding速度跑通tpch q1并得到正确的结果。

因此做了简化。

1，极简的逻辑plan结构。
以python的dict，list，tuple的嵌套组合作为plan的基础数据结构。
以直观的方式认识plan。tpch q1中不需要的内容都没有。
**tpch q1 的plan结构见下文[main]->[查看tpch q1逻辑计划]**

2，复用开源parser和向量化火山执行模型。
parser 用的是 [pglast](https://pglast.readthedocs.io/en/v5/index.html)。向量化执行器用的是 [pyarrow](https://arrow.apache.org/docs/python/api.html)（Apache Arrow）。

3，手工转化逻辑plan到物理执行器。

4，不支持tpch q1中没有的语义检查、运算符、函数和子句。
例如，表达式、别名、函数嵌套、数据类型的语义检查部分。

5，不追求执行速度。没有优化器规则。没做性能和内存优化。

6，代码量少。

某些限制会逐步在后期迭代中去掉。

# 用法

用 JetBrains 的 DataSpell 或 PyCharm 直接打开本文，按提示安装环境即可。

或者自行安装jupyter notebook。或者其它能支持的IDE。

注意：本文用的python 3.9.6。

# import

In [494]:
from pglast import ast,parser,visitors,printers,enums
from pprint import pprint
import pyarrow as arrow
from pyarrow import csv,compute,types
import pandas
from datetime import datetime,timedelta,date
import hashlib

# 设置数据源

## 设置lineitem的数据源

In [495]:
# Path of the lineitem data file and the single-character field separator used in it.
# Both values are stored in the schema metadata of lineitemSchema.
lineitem_path = "../mo-test/tpch100M/lineitem.tbl"
lineitem_delimiter = "|"

# 初始化表lineitem的schema

In [496]:
# create table lineitem ( l_orderkey    bigint not null,
#                              l_partkey     integer not null,
#                              l_suppkey     integer not null,
#                              l_linenumber  integer not null,
#                              l_quantity    double not null,
#                              l_extendedprice  double not null,
#                              l_discount    double not null,
#                              l_tax         double not null,
#                              l_returnflag  varchar(1) not null,
#                              l_linestatus  varchar(1) not null,
#                              l_shipdate    date not null,
#                              l_commitdate  date not null,
#                              l_receiptdate date not null,
#                              l_shipinstruct varchar(25) /*char(25)*/ not null,
#                              l_shipmode     varchar(10) /*char(10)*/ not null,
#                              l_comment      varchar(44) not null,
#                          primary key (l_orderkey, l_linenumber)
#                         );

In [497]:
# Arrow schema of the lineitem table. The third argument of every field is the
# nullable flag; all columns are declared non-nullable.
# The schema metadata carries the table name plus the delimiter and path of the
# backing data file, which CsvTableScan.Open reads back to locate the file.
lineitemSchema = arrow.schema([
    arrow.field("l_orderkey",arrow.int64(),False),
    arrow.field("l_partkey",arrow.int32(),False),
    arrow.field("l_suppkey",arrow.int32(),False),
    arrow.field("l_linenumber",arrow.int32(),False),
    arrow.field("l_quantity",arrow.float64(),False),
    arrow.field("l_extendedprice",arrow.float64(),False),
    arrow.field("l_discount",arrow.float64(),False),
    arrow.field("l_tax",arrow.float64(),False),
    arrow.field("l_returnflag",arrow.utf8(),False),
    arrow.field("l_linestatus",arrow.utf8(),False),
    arrow.field("l_shipdate",arrow.date32(),False),
    arrow.field("l_commitdate",arrow.date32(),False),
    arrow.field("l_receiptdate",arrow.date32(),False),
    arrow.field("l_shipinstruct",arrow.utf8(),False),
    arrow.field("l_shipmode",arrow.utf8(),False),
    arrow.field("l_comment",arrow.utf8(),False),
],{"name":"lineitem",
   "delimiter":lineitem_delimiter,
   "path":lineitem_path})
# lineitemSchema

# 定义常量

In [498]:

# Logical plan tags: keys of the plan dict, plus the relation-kind markers
# stored as the first element of plan[RELATIONS].
RELATIONS = "relations"
SINGLE_RELATION = "singleRelation"
MULTI_RELATION = "multipleRelation"
WHERE = "where"
OUTPUT = "output"
GROUP = "group"
ORDER = "order"
AGGREGATE = "aggregate"
PROJECT = "project"

# Positions of the entries inside a table definition list
# (see FromBuilder.buildTableRef, which builds the list in this order).
RELATION_TYPE_IDX = 0
DB_NAME_IDX = 1
ALIAS_IDX = 2
TABLE_NAME_IDX = 3
SCHEMA_IDX = 4

# Relation types
RELATION_TYPE_TABLE = "table"
RELATION_TYPE_VIEW = "view"
RELATION_TYPE_SUBQUERY = "subquery"

# Expression tags: the first element of every expression tuple
EXPR_OP_COLUMN_REF = "colRef"
EXPR_OP_LESS_EQUAL = "<="
EXPR_OP_PLUS = "+"
EXPR_OP_MINUS = "-"
EXPR_OP_MULTI = "*"
EXPR_OP_CAST = "cast"
EXPR_OP_CONST = "const"

# Tags of function expressions and of the wrapper tuples produced by the
# select-list / group-by / order-by / project builders
FUNC_EXPR = "func"
FUNC_RESULT_REF_EXPR = "func_result_ref"
OUTPUT_EXPR = "output_expr"
GROUP_EXPR = "group_expr"
ORDER_EXPR = "order_expr"
PROJECT_EXPR = "project_expr"

# 初始化Catalog

In [499]:
# Catalog: database name -> table name -> Arrow schema of the table.
Catalog = {
    "tpch":{
        "lineitem":lineitemSchema
    }
}

def isValidName(s):
    """Return True when *s* is neither None nor empty."""
    if s is None:
        return False
    return len(s) != 0


def getColumn(plan,dbName,tableName,colName):
    """Resolve a column name against the relations recorded in the plan.

    :param plan: logical plan; plan[RELATIONS] must already be filled in.
    :param dbName: database name; None or empty falls back to "tpch".
    :param tableName: relation name; None or empty means "search every relation".
    :param colName: name of the column to look up.
    :return: tuple (pyarrow.Field, database name, relation name).
    :raises Exception: when the plan has no relations, the relation or the
        column cannot be found, or the plan holds multiple relations.
    """
    relations = plan.get(RELATIONS,None)
    if relations is None:
        raise Exception("no table defs")

    if not dbName:
        dbName = "tpch"

    if relations[0] != SINGLE_RELATION:
        raise Exception("not implement multiple relations")

    single = relations[1]
    if not tableName:
        # Unqualified column: look for it in every relation.
        # Schema.field(name) raises KeyError for an unknown name, so probe with
        # get_field_index, which returns -1 instead of raising.
        for tableName2,tableDef in single.items():
            schema = tableDef[SCHEMA_IDX]
            colIdx = schema.get_field_index(colName)
            if colIdx >= 0:
                return schema.field(colIdx),dbName,tableName2
        raise Exception(f"no column name {colName} in any table")

    # Qualified column: locate the relation first, then the column.
    if tableName not in single:
        raise Exception(f"no such relation {tableName} in database {dbName}")
    tableDef = single[tableName]
    if tableDef[DB_NAME_IDX] != dbName or tableDef[ALIAS_IDX] != tableName:
        raise Exception(f"invalid database name {dbName} or table name {tableName}")
    schema = tableDef[SCHEMA_IDX]
    colIdx = schema.get_field_index(colName)
    if colIdx < 0:
        raise Exception(f"no column name {colName} in table {tableName}")
    return schema.field(colIdx),dbName,tableName


# Bare expression: the notebook shows the catalog as the cell output.
Catalog

{'tpch': {'lineitem': l_orderkey: int64 not null
  l_partkey: int32 not null
  l_suppkey: int32 not null
  l_linenumber: int32 not null
  l_quantity: double not null
  l_extendedprice: double not null
  l_discount: double not null
  l_tax: double not null
  l_returnflag: string not null
  l_linestatus: string not null
  l_shipdate: date32[day] not null
  l_commitdate: date32[day] not null
  l_receiptdate: date32[day] not null
  l_shipinstruct: string not null
  l_shipmode: string not null
  l_comment: string not null
  -- schema metadata --
  name: 'lineitem'
  delimiter: '|'
  path: '../mo-test/tpch100M/lineitem.tbl'}}

# 取tpch q1 ast

In [500]:
# Parse the TPC-H Q1 text with pglast and keep the first element of the parse result.
q1Stmt = parser.parse_sql(
    "select \
        l_returnflag, \
        l_linestatus, \
        sum(l_quantity) as sum_qty, \
        sum(l_extendedprice) as sum_base_price, \
        sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, \
        sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, \
        avg(l_quantity) as avg_qty, \
        avg(l_extendedprice) as avg_price, \
        avg(l_discount) as avg_disc, \
        count(*) as count_order \
    from \
        lineitem \
    where \
        l_shipdate <= date '1998-12-01' - interval '112' day \
    group by \
        l_returnflag, \
        l_linestatus \
    order by \
        l_returnflag, \
        l_linestatus;"
)[0]

In [501]:
#q1Stmt

# 构建tpch q1的逻辑查询计划

## 逻辑查询计划builder

从SELECT语句构建逻辑查询计划

逻辑查询计划plan定义为字典类型。内部是嵌套的list,dict,tuple等。plan内部用可读的字符串表达信息。
tpch q1 的plan结构见下文[main]->[查看tpch q1逻辑计划]。

In [502]:
# Plan builder base class
class LogicalPlanBuilder:
    """Common base of the logical-plan builders; the methods here do nothing."""

    def __init__(self):
        """Nothing to initialise."""

    def build(self,node,plan : dict):
        """
        Build (part of) the logical plan. The base implementation is a no-op.
        :param node: AST node
        :param plan: logical plan
        :return: None
        """
        return None

## select语句builder

为select语句生成逻辑查询计划。

In [503]:
# SELECT statement builder. Only TPC-H Q1 is supported for now.
class SelectBuilder(LogicalPlanBuilder):
    """Fill the plan from a SELECT statement by running one builder per clause."""

    def __init__(self):
        super().__init__()

    def build(self,select : ast.SelectStmt,plan : dict):
        """Run the clause builders in order: from, where, select list,
        group by, order by, project list. Each one writes into *plan*."""
        FromBuilder().build(select.fromClause,plan)
        WhereBuilder().build(select.whereClause,plan)
        SelectListBuilder().build(select.targetList,plan)
        GroupbyBuilder().build(select.groupClause,plan)
        OrderbyBuilder().build(select.sortClause,plan)
        ProjectListBuilder().build(select.targetList,plan)

## from子句builder

从from子句中提取各个关系表。

In [504]:
class FromBuilder(LogicalPlanBuilder):
    """Builds plan[RELATIONS] from the FROM clause. Only one table is supported."""

    def __init__(self):
        super().__init__()

    def build(self,tableRefs : tuple,plan : dict):
        """Store [SINGLE_RELATION, {table name: table definition}] under plan[RELATIONS].

        :param tableRefs: the FROM clause items of the SELECT statement.
        :param plan: logical plan to fill in.
        :raises Exception: when the FROM clause does not hold exactly one item.
        """
        if len(tableRefs) == 1:
            single = self.buildTableRef(tableRefs[0])
            plan[RELATIONS] = [SINGLE_RELATION,single]
            return
        raise Exception("unsupport multiple table refs")

    def buildTableRef(self,tableRef : ast.RangeVar)-> dict:
        """Look the referenced table up in Catalog and return its definition.

        :param tableRef: table reference; an empty schema name means "tpch".
        :return: dict mapping the table name to the list
            [relation type, database name, alias, original name, schema].
        :raises Exception: when the database or the table is not in Catalog.
        """
        dbName = tableRef.schemaname
        if dbName is None or len(dbName) == 0:
            dbName = "tpch"

        # Fetch the table definition from the catalog. Catalog.get keeps an
        # unknown database on the same error path as an unknown table
        # instead of leaking a bare KeyError.
        tables = Catalog.get(dbName)
        if tables is None or tableRef.relname not in tables:
            raise Exception("no such table in Catalog",tableRef.schemaname,tableRef.relname)

        return {tableRef.relname :
                [RELATION_TYPE_TABLE, #relation type
                dbName, #database name
                tableRef.relname,#alias (the table name is used as-is)
                tableRef.relname,#original name
                tables[tableRef.relname] #schema
                ]}

## 表达式builder

表达式构建基类。也是最复杂的类。

In [505]:
class ExpressionBuilder(LogicalPlanBuilder):
    """Turns pglast expression nodes into nested plan tuples."""

    def __init__(self):
        super().__init__()

    def build(self,node,plan : dict):
        """Build the expression for *node* according to its AST type. The result is a tuple."""
        if isinstance(node,ast.A_Expr):
            return self._build_operator(node,plan)
        if isinstance(node,ast.ColumnRef):
            parts = node.fields
            if len(parts) != 1:
                raise Exception("unsupported column ref",node)
            return EXPR_OP_COLUMN_REF, getColumn(plan,"","",parts[0].sval)
        if isinstance(node,ast.TypeCast):
            operand = self.build(node.arg,plan)
            return EXPR_OP_CAST,operand,node.typeName.names
        if isinstance(node,ast.A_Const):
            if node.isnull:
                return EXPR_OP_CONST,node.isnull
            return EXPR_OP_CONST,node.isnull,node.val
        if isinstance(node,ast.FuncCall):
            return self._build_func_call(node,plan)
        raise Exception("unsupported expr 2",node)

    def _build_operator(self,node,plan):
        """Build a binary operator expression: (tag, left operand, right operand)."""
        if node.kind != enums.parsenodes.A_Expr_Kind.AEXPR_OP:
            raise Exception("unsupported expr 1",node)
        symbol = node.name[0].sval
        tags = {
            "<=": EXPR_OP_LESS_EQUAL,
            "-": EXPR_OP_MINUS,
            "*": EXPR_OP_MULTI,
            "+": EXPR_OP_PLUS,
        }
        if symbol not in tags:
            raise Exception("unsupported operator",node)
        left = self.build(node.lexpr,plan)
        right = self.build(node.rexpr,plan)
        return tags[symbol],left,right

    def _build_func_call(self,node,plan):
        """Build a function call.

        Aggregate calls are appended to plan[AGGREGATE] and replaced by a
        (FUNC_RESULT_REF_EXPR, index) reference; other calls are returned inline.
        """
        func_name = node.funcname[0].sval
        aggregate_call = is_aggregate_func(func_name)

        if node.args is not None:
            args = [self.build(arg,plan) for arg in node.args]
        elif node.agg_star:
            args = "*"
        else:
            raise Exception("function has no args",node)

        if not aggregate_call:
            return FUNC_EXPR,func_name,args

        registered = plan.get(AGGREGATE,[])
        slot = len(registered)
        registered.append((FUNC_EXPR,func_name,args))
        plan[AGGREGATE] = registered
        return FUNC_RESULT_REF_EXPR,slot


def is_aggregate_func(name):
    """Return True when *name* is one of the supported aggregates: count, avg, sum."""
    supported = ("count","avg","sum")
    return name in supported

## where子句builder

构建where表达式。

In [506]:
class WhereBuilder(LogicalPlanBuilder):
    """Builds the WHERE expression and stores it under plan[WHERE]."""

    def __init__(self):
        super().__init__()

    def build(self,node,plan : dict):
        """Build the expression for *node* and store it under plan[WHERE]."""
        plan[WHERE] = ExpressionBuilder().build(node,plan)

## select expr builder

In [507]:
class SelectExprBuilder(ExpressionBuilder):
    """Builds one select-list item as (OUTPUT_EXPR, alias, expression)."""

    def __init__(self):
        super().__init__()

    def build(self,node,plan : dict):
        """Wrap column references and function calls of a ResTarget in an
        OUTPUT_EXPR tuple; everything else is built by the base class unchanged."""
        if not isinstance(node,ast.ResTarget):
            return super().build(node,plan)

        value = node.val
        is_column = isinstance(value,ast.ColumnRef)
        if not is_column and not isinstance(value,ast.FuncCall):
            return super().build(value,plan)

        built = super().build(value,plan)
        alias = node.name
        if not isValidName(alias):
            # No explicit alias: use the column name for a column reference,
            # the string form of the AST node for a function call.
            alias = built[1][0].name if is_column else str(value)
        return OUTPUT_EXPR,alias,built


## select list builder

In [508]:

class SelectListBuilder(LogicalPlanBuilder):
    """Builds every select-list item and stores the list under plan[OUTPUT]."""

    def __init__(self):
        super().__init__()

    def build(self,node,plan : dict):
        """Build each target in *node* with one shared SelectExprBuilder."""
        item_builder = SelectExprBuilder()
        plan[OUTPUT] = [item_builder.build(target,plan) for target in node]

## group expr builder

In [509]:
class GroupExprBuilder(ExpressionBuilder):
    """Builds one GROUP BY item as (GROUP_EXPR, expression)."""

    def __init__(self):
        super().__init__()

    def build(self,node,plan : dict):
        """Build *node* with the base class and tag the result with GROUP_EXPR."""
        return GROUP_EXPR,super().build(node,plan)

## group by子句builder

In [510]:
class GroupbyBuilder(LogicalPlanBuilder):
    """Builds every GROUP BY item and stores the list under plan[GROUP]."""

    def __init__(self):
        super().__init__()

    def build(self,node,plan : dict):
        """Build each expression in *node* with one shared GroupExprBuilder."""
        item_builder = GroupExprBuilder()
        plan[GROUP] = [item_builder.build(item,plan) for item in node]

## order expr builder

In [511]:
class OrderExprBuilder(ExpressionBuilder):
    """Builds one ORDER BY item as (ORDER_EXPR, expression, sort direction)."""

    def __init__(self):
        super().__init__()

    def build(self,node,plan : dict):
        """Build the sort expression of a SortBy node; any other node type is rejected."""
        if not isinstance(node,ast.SortBy):
            raise Exception(f"not implement order expr {node}")
        return ORDER_EXPR,super().build(node.node,plan),node.sortby_dir

## order by builder

In [512]:
class OrderbyBuilder(LogicalPlanBuilder):
    """Builds every ORDER BY item and stores the list under plan[ORDER]."""

    def __init__(self):
        super().__init__()

    def build(self,node,plan : dict):
        """Build each sort item in *node* with one shared OrderExprBuilder."""
        item_builder = OrderExprBuilder()
        plan[ORDER] = [item_builder.build(item,plan) for item in node]

## project expr builder

In [513]:
class ProjectExprBuilder(ExpressionBuilder):
    """Builds one project-list item as (PROJECT_EXPR, alias)."""

    def __init__(self):
        super().__init__()

    def build(self,node,plan : dict):
        """Build the project expression for a select target.

        Column references and function calls of a ResTarget become
        (PROJECT_EXPR, alias); everything else is built by the base class.

        :param node: AST node, normally a ResTarget of the select list.
        :param plan: logical plan.
        :return: expression tuple.
        """
        if not isinstance(node,ast.ResTarget):
            return super().build(node,plan)

        value = node.val
        if isinstance(value,ast.ColumnRef):
            r = super().build(value,plan)
            alias = node.name
            if not isValidName(alias):
                # Fall back to the column name.
                alias = r[1][0].name
            return PROJECT_EXPR, alias
        elif isinstance(value,ast.FuncCall):
            # Only the alias is needed here. Building the call against the real
            # plan would append its aggregates to plan[AGGREGATE] a second time
            # (the select-list builder builds the same targets), so build it
            # against a scratch copy: unsupported input is still rejected, but
            # the shared aggregate list is left untouched.
            scratch = dict(plan)
            scratch[AGGREGATE] = list(plan.get(AGGREGATE,[]))
            super().build(value,scratch)
            alias = node.name
            if not isValidName(alias):
                # Fall back to the string form of the expression.
                alias = str(value)
            return PROJECT_EXPR,alias
        else:
            return super().build(value,plan)


## project list builder

In [514]:
class ProjectListBuilder(LogicalPlanBuilder):
    """Builds every project-list item and stores the list under plan[PROJECT]."""

    def __init__(self):
        super().__init__()

    def build(self,node,plan : dict):
        """Build each target in *node* with one shared ProjectExprBuilder."""
        item_builder = ProjectExprBuilder()
        plan[PROJECT] = [item_builder.build(target,plan) for target in node]

# 构建tpch q1的物理查询计划

## 物理计划builder

In [515]:
class PhysicalPlanBuilder:
    """Turns the logical plan into a chain of executors."""

    def __init__(self):
        """Nothing to initialise."""

    def build(self,plan : dict,node : str):
        """Build the executor for the plan node *node* and return the Executor instance.

        Each node builds its input first, so the chain is
        RELATIONS <- WHERE <- GROUP <- ORDER <- PROJECT.
        """
        if node == RELATIONS:
            rel_info = plan.get(RELATIONS)
            if rel_info[0] != SINGLE_RELATION:
                raise Exception(f"not implement relations {rel_info[0]}")
            rel_def = rel_info[1]
            first_table = list(rel_def.keys())[0]
            schema = rel_def.get(first_table)[SCHEMA_IDX]
            # Column index 16 is dropped by the scan.
            return CsvTableScan(schema,None,[16])

        if node == WHERE:
            source = self.build(plan,RELATIONS)
            # Filter executor on top of the scan
            return FilterExecutor(plan.get(WHERE),source)

        if node == GROUP:
            # Group-by executor on top of the filter
            return GroupbyExecutor(plan,self.build(plan,WHERE))

        if node == ORDER:
            return OrderbyExecutor(plan,self.build(plan,GROUP))

        if node == PROJECT:
            return ProjectListExecutor(plan,self.build(plan,ORDER))

        raise Exception(f"not implement plan")


## 物理执行计划执行器

In [516]:
class Executor:
    """Base class of the physical executors; every method here is a no-op."""

    def __init__(self):
        """Nothing to initialise."""

    def Open(self):
        """Prepare the executor. No-op in the base class."""
        return None

    def Next(self):
        """Produce the next batch. The base class returns None."""
        return None

    def Close(self):
        """Release resources. No-op in the base class."""
        return None

## csv table scan执行器

In [517]:
class CsvTableScan(Executor):
    """Table scan that streams record batches from a delimited text file.

    The file path and delimiter are read from the schema metadata
    (keys b"path" and b"delimiter").
    """

    def __init__(self,schema : arrow.Schema,column_names : list,drop_columns : list):
        """
        :param schema: schema of the produced batches; its metadata locates the file.
        :param column_names: column names handed to the CSV reader, or None.
        :param drop_columns: indexes of the file columns to discard; None keeps all.
        """
        super().__init__()
        self.reader = None
        self.schema = schema
        self.column_names = column_names
        # Indexes of the columns to drop
        self.drop_columns = drop_columns
        self.block_size = 16 * 1024

    def Open(self):
        """Open the file as a streaming CSV reader."""
        meta = self.schema.metadata
        path = meta.get(b"path").decode('utf-8')
        delimiter = meta.get(b"delimiter").decode('utf-8')
        # autogenerate_column_names expects a boolean. The previous code passed
        # the `bool` type object, which only behaved like True because it is truthy.
        read_opts = arrow.csv.ReadOptions(column_names = self.column_names,
                                          block_size = self.block_size,
                                          autogenerate_column_names = True)
        parse_opts = arrow.csv.ParseOptions(delimiter = delimiter)
        convert_opts = arrow.csv.ConvertOptions(column_types = self.schema,
                                                include_columns = self.column_names)
        self.reader = arrow.csv.open_csv(path,read_options = read_opts,parse_options = parse_opts,convert_options = convert_opts)

    def Next(self):
        """Return the next batch without the dropped columns, or None at end of file."""
        # Only the read itself signals end of input, so keep the try body minimal.
        try:
            chunk = self.reader.read_next_batch()
        except StopIteration:
            return None

        dropped = self.drop_columns if self.drop_columns is not None else ()
        needed_arrays = [chunk.column(col_idx)
                         for col_idx in range(chunk.num_columns)
                         if col_idx not in dropped]
        return arrow.RecordBatch.from_arrays(needed_arrays,self.column_names,self.schema)

    def Close(self):
        """Close the reader, if one was opened, and drop all references."""
        if self.reader is not None:
            self.reader.close()
        self.reader = None
        self.schema = None
        self.column_names = None
        self.drop_columns = None

In [518]:
# test
# 笔者实验用的csv文件会多出最后一列。
# csvReader = CsvTableScan(lineitemSchema,None,[16])
# csvReader.Open()
# record1 = csvReader.Next()
# pprint(record1.to_pandas())
# csvReader.Close()

## 表达式执行函数

In [519]:
def exec_expr(expr,records,agg_offset):
    """
    Evaluate an expression tuple on a batch of input data using pyarrow.compute.
    :param expr: expression tuple; expr[0] is the expression tag
    :param records: input data, column oriented
    :param agg_offset: added to the index of an aggregate result reference
                       to get its column position in *records*
    :return: the evaluated column
    """
    kind = expr[0]

    if kind == EXPR_OP_COLUMN_REF:
        return records.column(expr[1][0].name)

    if kind == EXPR_OP_CONST:
        if expr[1]:# NULL
            raise Exception(f"not implement const NULL")
        literal = expr[2]
        # A constant is materialised as one value per input row.
        if isinstance(literal,ast.String):
            return arrow.array([literal.sval]*records.num_rows,arrow.string())
        if isinstance(literal,ast.Integer):
            return arrow.array([literal.ival]*records.num_rows,arrow.int32())
        raise Exception(f"not implement const expr {expr}")

    if kind == EXPR_OP_CAST:
        operand = exec_expr(expr[1],records,agg_offset)
        type_names = expr[2]
        # With more than one name part, the second part is the type name.
        target_type = type_names[0].sval
        if len(type_names) > 1:
            target_type = type_names[1].sval
        if target_type == "date":
            parsed = [datetime.strptime(str(s),"%Y-%m-%d") for s in operand]
            return arrow.array(parsed,arrow.date32())
        if target_type == "interval":
            # The interval text is read as a plain integer.
            return arrow.array([int(str(s)) for s in operand],arrow.int32())
        print(f"target_type {target_type}")
        return compute.cast(operand,target_type)

    if kind in (EXPR_OP_PLUS,EXPR_OP_MINUS,EXPR_OP_MULTI,EXPR_OP_LESS_EQUAL):
        left = exec_expr(expr[1],records,agg_offset)
        right = exec_expr(expr[2],records,agg_offset)
        if kind == EXPR_OP_PLUS:
            return compute.add(left,right)
        if kind == EXPR_OP_MULTI:
            return compute.multiply(left,right)
        if kind == EXPR_OP_LESS_EQUAL:
            return compute.less_equal(left,right)

        # Subtraction
        if not types.is_date32(left.type):
            return compute.subtract(left,right)

        # Special case: date minus an int32 number of days, computed row by row.
        dates = [datetime.strptime(str(d),"%Y-%m-%d") for d in left]
        if not types.is_int32(right.type):
            raise Exception("date minus needs int32")
        deltas = [timedelta(int(str(days))) for days in right]
        shifted = [dates[i] - delta for i,delta in enumerate(deltas)]
        return arrow.array(shifted,arrow.date32())

    if kind == GROUP_EXPR:
        return exec_expr(expr[1],records,agg_offset)

    if kind == OUTPUT_EXPR:
        return exec_expr(expr[2],records,agg_offset)

    if kind == FUNC_RESULT_REF_EXPR:
        return records.column(expr[1]+agg_offset)

    if kind == PROJECT_EXPR:
        return records.column(expr[1])

    raise Exception(f"not implement exec expr {expr}")

## filter执行器
执行where表达式。

### 执行器实现

In [520]:
class FilterExecutor(Executor):
    """Keeps the rows of each child batch for which the filter expression is true."""

    def __init__(self, filter: tuple, child: Executor):
        super().__init__()
        self.filter = filter
        self.child = child

    def Open(self):
        """Open the child executor."""
        self.child.Open()

    def Next(self):
        """Return the next child batch filtered by the expression, or None at end of input."""
        batch = self.child.Next()
        if batch is not None:
            keep = exec_expr(self.filter,batch,0)
            return batch.filter(keep)
        return None

    def Close(self):
        """Close the child executor and drop all references."""
        self.child.Close()
        self.filter = None
        self.child = None
### 测试时间减法

In [521]:
# date1 = datetime.strptime('1998-12-01',"%Y-%m-%d")
# date2_arr = arrow.array([date1]*2,arrow.date32())
# # pprint(date2_arr)
# date4_vals = []
# if types.is_date32(date2_arr.type):
#     for d in date2_arr:
#         print(datetime.strptime(str(d),"%Y-%m-%d"))
#
# date3_arr = arrow.array([112]*2,arrow.int32())
# if types.is_int32(date3_arr.type):
#     for i in date3_arr:
#         print(timedelta(int(str(i))))
#

## groupby执行器

### 聚合函数中间结果

In [522]:
class AggFunc:
    """Interface of an aggregate's intermediate result; every method here is a no-op."""

    def __init__(self):
        """Nothing to initialise."""

    def name(self):
        """Name of the aggregate. The base class returns None."""
        return None

    def add(self,records,row_idx):
        """Fold row *row_idx* of *records* into the result. No-op here."""
        return None

    def addRecords(self,records):
        """Fold all of *records* into the result. No-op here."""
        return None

    def merge(self,other):
        """Merge another intermediate result into this one. No-op here."""
        return None

    def get(self):
        """Final value of the aggregate. The base class returns None."""
        return None

class AggFuncFactory:
    """Creates the intermediate-result object for an aggregate name."""

    def __init__(self):
        """Nothing to initialise."""

    def create(self,name):
        """Return a fresh AggFuncSum / AggFuncCount / AggFuncAvg for
        "sum" / "count" / "avg"; any other name raises."""
        if name == "sum":
            return AggFuncSum()
        if name == "count":
            return AggFuncCount()
        if name == "avg":
            return AggFuncAvg()
        raise Exception(f"not implement agg func {name}")

class AggFuncSum(AggFunc):
    """Running sum of the first argument column."""

    def __init__(self):
        super().__init__()
        self.sum = None

    def name(self):
        return "sum"

    def add(self,records,row_idx):
        """Add the value at *row_idx* of the first argument column to the sum."""
        value = records[0][row_idx]
        # The first value seeds the sum; later ones are added with pyarrow.compute.
        self.sum = value if self.sum is None else compute.add(self.sum,value)

    def addRecords(self,records):
        """Not implemented."""
        return None

    def merge(self,other):
        """Not implemented."""
        return None

    def get(self):
        """Return the sum; None when no row was added."""
        return self.sum

class AggFuncCount(AggFunc):
    """Counts the rows added; the argument values are not inspected."""

    def __init__(self):
        super().__init__()
        self.count = 0

    def name(self):
        return "count"

    def add(self,records,row_idx):
        """Count one more row."""
        self.count += 1

    def addRecords(self,records):
        """Not implemented."""
        return None

    def merge(self,other):
        """Not implemented."""
        return None

    def get(self):
        """Return the number of rows added so far."""
        return self.count

class AggFuncAvg(AggFunc):
    """Average of the first argument column, kept as a running sum and a row count."""

    def __init__(self):
        super().__init__()
        self.count = 0
        self.sum = None

    def name(self):
        return "avg"

    def add(self,records,row_idx):
        """Add the value at *row_idx* of the first argument column and count the row."""
        value = records[0][row_idx]
        # The first value seeds the sum; later ones are added with pyarrow.compute.
        self.sum = value if self.sum is None else compute.add(self.sum,value)
        self.count += 1

    def addRecords(self,records):
        """Not implemented."""
        return None

    def merge(self,other):
        """Not implemented."""
        return None

    def get(self):
        """Return sum divided by count, computed with pyarrow.compute."""
        return compute.divide(self.sum , self.count)

### 取groupby的Field

In [523]:
def get_field_from_groupby(e):
    """Return the pyarrow.Field of a group-by expression that wraps a column reference."""
    inner = e[1]
    if inner[0] != EXPR_OP_COLUMN_REF:
        raise Exception(f"not implement group by {e}")
    # inner is (EXPR_OP_COLUMN_REF, colRef) and the field is the first entry of colRef.
    return inner[1][0]

### 确定scalar的类型

In [524]:
def get_field_from_value(s):
    """Return the pyarrow data type that matches the value *s*."""
    arrow_scalars = (arrow.StringScalar,arrow.Decimal256Scalar,
                     arrow.Int64Scalar,arrow.DoubleScalar)
    if isinstance(s,arrow_scalars):
        # Arrow scalars already carry their type.
        return s.type
    if isinstance(s,int):
        return arrow.int64()
    if isinstance(s,str):
        return arrow.string()
    raise Exception(f"not implement scalar type {s} {type(s)}")

### 数据列转成array

In [525]:
def convert_col_to_array(col,typ):
    """Convert an input column (list of values) into a pyarrow.Array of type *typ*.

    The kind of conversion is chosen from the first value of the column.
    """
    first = col[0]
    if isinstance(first,(arrow.StringScalar,arrow.Decimal256Scalar,arrow.DoubleScalar)):
        # Arrow scalars are unwrapped to Python values first.
        return arrow.array([ss.as_py() for ss in col],typ)
    if isinstance(first,(int,str)):
        return arrow.array(col,typ)
    raise Exception(f"not implement col type {col}")

### 执行器实现

In [526]:
class GroupbyExecutor(Executor):
    """Hash aggregation: consumes the whole child input, groups the rows by the
    group-by expressions, evaluates the aggregates per group, and returns one
    batch holding the output-list expressions."""

    def __init__(self, plan: dict, child: Executor):
        """
        :param plan: logical plan; AGGREGATE, GROUP and OUTPUT are read from it.
        :param child: executor that produces the input batches.
        """
        super().__init__()
        self.plan = plan
        self.child = child
        self.aggregate = plan.get(AGGREGATE,[])
        self.groupby = plan.get(GROUP,[])
        self.output = plan.get(OUTPUT,[])
        self.hash_func = hashlib.sha256()
        # Hash table: group key -> [group-by values, aggregate states]
        self.hash_table = {}
        self.agg_factory = AggFuncFactory()

    def update_agg_func_val(self,hash_key,group_by_vals,param_vals,row_idx):
        """
        Update the intermediate aggregate results of one group with one row.
        Layout of the intermediate result:
            groupby1_val, groupby2_val, ..., agg1_val, agg2_val, ...,
        :param hash_key: key of the group the row belongs to
        :param group_by_vals: values of the group-by expressions
        :param param_vals: values of the aggregate argument expressions
        :param row_idx: index of the row to fold into the aggregates
        :return:
        """
        # Fetch the intermediate result; create it for a group seen the first time.
        hash_val = self.hash_table.get(hash_key,[])
        if len(hash_val) == 0:
            agg_func_vals = []
            for agg_func_expr in self.aggregate:
                agg_func_name = agg_func_expr[1]
                intermediate_result = self.agg_factory.create(agg_func_name)
                agg_func_vals.append(intermediate_result)
            group_by_row_vals = []
            for v in group_by_vals:
                group_by_row_vals.append(v[row_idx])
            # groupby1, groupby2, ..., agg1,agg2, ...,
            hash_val = [group_by_row_vals,agg_func_vals]
            self.hash_table[hash_key] = hash_val

        # Update the intermediate results
        for agg_idx in range(len(self.aggregate)):
            agg_func_val = hash_val[1][agg_idx]
            param_val = param_vals[agg_idx]
            agg_func_val.add(param_val,row_idx)


    def end_agg_func_val(self):
        """Collect the final aggregate values of all groups into one record batch.

        Column layout: groupby1, groupby2, ..., agg1, agg2, ...
        Aggregate columns are named by their index ("0", "1", ...).
        """
        result_cols = []
        for hash_key,hash_val in self.hash_table.items():
            if len(result_cols) == 0:
                result_cols = [[] for _ in range(len(hash_val[0]) + len(hash_val[1]))]
            # Append the group-by values
            for i in range(len(hash_val[0])):
                result_cols[i].append(hash_val[0][i])

            # Append the aggregate results
            begin = len(hash_val[0])
            for i in range(len(hash_val[1])):
                j = begin + i
                result_cols[j].append(hash_val[1][i].get())

        # Define the schema. The type of the first row's value is used directly;
        # the proper approach would be type inference.
        result_types = []
        for i in range(len(result_cols)):
            col = result_cols[i]
            v = col[0]
            if i < len(self.groupby):
                result_types.append(get_field_from_groupby(self.groupby[i]))
            else:
                t = get_field_from_value(v)
                result_types.append((str(i - len(self.groupby)),t))

        schema = arrow.schema(result_types)

        # Build the data arrays
        result_arr = []
        for i in range(len(result_cols)):
            col = result_cols[i]
            result_arr.append(convert_col_to_array(col,schema.field(i).type))
        return arrow.record_batch(result_arr,schema)

    def update_hash_table(self,group_by_vals,hash_keys,records):
        """
        Update the aggregate values of every group (identified by its hash key).
        :param group_by_vals: values of the group-by expressions
        :param hash_keys: per-row hash of the group-by values
        :param records: input batch
        :return:
        """
        # Evaluate the argument expressions of every aggregate
        all_arg_vals = []
        for agg_idx in range(len(self.aggregate)):
            agg_func = self.aggregate[agg_idx]
            # Evaluate each argument
            agg_func_name = agg_func[1]
            agg_args = agg_func[2]
            agg_arg_vals = []
            for arg in agg_args:
                if agg_func_name == "count" and arg == "*":
                    agg_arg_vals.append(arrow.array(["*"]*records.num_rows,arrow.string()))
                else:
                    agg_arg_val = exec_expr(arg,records,0)
                    agg_arg_vals.append(agg_arg_val)
            all_arg_vals.append(agg_arg_vals)

        # Assign every row to its group by hash key
        row_count = len(group_by_vals[0])
        for r in range(row_count):
            hash_key = hash_keys[r]
            # Update the intermediate aggregate results
            self.update_agg_func_val(hash_key,group_by_vals,all_arg_vals,r)

    def Open(self):
        """Open the child executor."""
        self.child.Open()

    def Next(self):
        """Drain the child, aggregate, and return the output batch.

        Returns None when the child has no (more) input or when the input
        contained no rows at all.
        """
        records = self.child.Next()
        if records is None:
            return None
        while records is not None:
            # Evaluate the group-by expressions
            group_by_vals = []
            # One hash per row
            hash_funcs = [hashlib.sha256() for _ in range(records.num_rows)]
            for e in self.groupby:
                val = exec_expr(e,records,0)
                for row_hash,v in zip(hash_funcs,val):
                    encoded = str(v).encode("utf-8")
                    # Length-prefix every value so that adjacent columns cannot
                    # run together: ("ab","c") and ("a","bc") must not share a key.
                    row_hash.update(len(encoded).to_bytes(8,"little"))
                    row_hash.update(encoded)
                group_by_vals.append(val)
            hash_keys = [row_hash.hexdigest() for row_hash in hash_funcs]
            hash_funcs = None
            self.update_hash_table(group_by_vals,hash_keys,records)
            # Next input batch
            records = self.child.Next()

        # No group was created (every input batch was empty): there is nothing
        # to output, and the result schema cannot be derived from an empty table.
        if not self.hash_table:
            return None

        # Assemble the aggregate results
        # groupby1, groupby2, ..., agg1,agg2,...,
        agg_records = self.end_agg_func_val()

        # Evaluate the output list; aggregate references are offset by the
        # number of group-by columns that precede them.
        output_vals = []
        output_types = []
        for e in self.output:
            val = exec_expr(e,agg_records,len(self.groupby))
            output_vals.append(val)
            output_types.append((e[1],get_field_from_value(val[0])))

        output_schema = arrow.schema(output_types)
        output_records = arrow.record_batch(output_vals,output_schema)
        return output_records

    def Close(self):
        """Close the child executor and drop all references."""
        self.child.Close()
        self.plan = None
        self.child = None
        self.aggregate = None
        self.groupby = None
        self.output = None
        self.hash_func = None
        self.hash_table = None
        self.agg_factory = None

## orderby执行器

### 取排序表达式

In [527]:
def get_sort_key_name(e):
    """Return the column name referenced by a sort-by expression.

    Only column-reference expressions are supported; for those the first
    element of the expression payload carries the ``name`` that is returned.

    Raises:
        Exception: if the expression is not a column reference.
    """
    if e[0] != EXPR_OP_COLUMN_REF:
        raise Exception(f"not implement sort key name {e}")
    column = e[1][0]
    return column.name

def get_sort_keys(orderby):
    """Translate the plan's order-by items into ``(name, direction)`` pairs.

    For every item the sort expression is read from index 1 and the sort
    direction from index 2. The direction becomes "descending" for
    SORTBY_DESC and "ascending" for anything else.
    """
    return [
        (
            get_sort_key_name(item[1]),
            "descending"
            if item[2] == enums.parsenodes.SortByDir.SORTBY_DESC
            else "ascending",
        )
        for item in orderby
    ]

### 执行器实现

In [528]:
class OrderbyExecutor(Executor):
    """Executor that reorders the rows of a child batch by the ORDER BY keys."""

    def __init__(self,plan,child):
        """Remember the plan and the child and read the plan's ORDER entry.

        Args:
            plan: logical plan; its ORDER entry (empty list when absent)
                holds the order-by items.
            child: executor that supplies the batches to sort.
        """
        super().__init__()
        self.plan = plan
        self.child = child
        self.orderby = plan.get(ORDER,[])

    def Open(self):
        """Open the child executor."""
        self.child.Open()

    def Next(self):
        """Fetch one batch from the child and return it sorted.

        NOTE(review): only the batch returned by this single
        ``child.Next()`` call is sorted. If the child can yield several
        batches, each one is ordered independently rather than globally -
        confirm that the child always delivers a single batch.

        Returns:
            None when the child returns None; otherwise the child's batch
            with its rows rearranged by ``compute.take``.
        """
        records = self.child.Next()
        if records is None:
            return None

        # Build the (name, direction) sort keys from the order-by items.
        sort_keys = get_sort_keys(self.orderby)
        # Compute the row indices that put the batch in sorted order.
        indices = compute.sort_indices(records,sort_keys)
        # Rearrange the rows according to those indices.
        return compute.take(records,indices)

    def Close(self):
        """Close the child executor and drop the references held here.

        Safe to call more than once: after the first call ``self.child`` is
        None, so later calls skip the child instead of raising AttributeError.
        """
        if self.child is not None:
            self.child.Close()
        self.plan = None
        self.child = None
        self.orderby = None

## project list执行器

In [529]:
class ProjectListExecutor(Executor):
    """Executor that evaluates the plan's project list on each child batch."""

    def __init__(self,plan,child):
        """Remember the plan and the child and read the plan's PROJECT entry.

        Args:
            plan: logical plan; its PROJECT entry (empty list when absent)
                holds the project-list expressions.
            child: executor that supplies the input batches.
        """
        super().__init__()
        self.plan = plan
        self.child = child
        self.project_list = plan.get(PROJECT,[])

    def Open(self):
        """Open the child executor."""
        self.child.Open()

    def Next(self):
        """Fetch one batch from the child and project it.

        Returns:
            None when the child returns None; otherwise a record batch built
            by ``arrow.record_batch`` with one column per project-list entry.
        """
        records = self.child.Next()
        if records is None:
            return records

        # Evaluate every project-list expression against the input batch.
        project_vals = []
        project_types = []
        for e in self.project_list:
            val = exec_expr(e,records,0)
            project_vals.append(val)
            # The first value of the column is used to derive its schema entry.
            # NOTE(review): val[0] presumably fails on a zero-row batch -
            # confirm whether the child can produce one.
            project_types.append((e[1],get_field_from_value(val[0])))

        schema = arrow.schema(project_types)
        project_records = arrow.record_batch(project_vals,schema)
        return project_records

    def Close(self):
        """Close the child executor and drop the references held here.

        Safe to call more than once: after the first call ``self.child`` is
        None, so later calls skip the child instead of raising AttributeError.
        """
        if self.child is not None:
            self.child.Close()
        self.plan = None
        self.child = None
        self.project_list = None

# main

## 查看tpch q1逻辑计划

In [530]:
# Build the logical plan for the parsed TPC-H Q1 statement into a plain dict
# and pretty-print it for inspection.
plan = {}
selBuilder = SelectBuilder()
# build() fills the `plan` dict passed in as its second argument.
selBuilder.build(q1Stmt.stmt,plan)
pprint(plan)

{'aggregate': [('func',
                'sum',
                [('colRef',
                  (pyarrow.Field<l_quantity: double not null>,
                   'tpch',
                   'lineitem'))]),
               ('func',
                'sum',
                [('colRef',
                  (pyarrow.Field<l_extendedprice: double not null>,
                   'tpch',
                   'lineitem'))]),
               ('func',
                'sum',
                [('*',
                  ('colRef',
                   (pyarrow.Field<l_extendedprice: double not null>,
                    'tpch',
                    'lineitem')),
                  ('-',
                   ('const', False, <Integer ival=1>),
                   ('colRef',
                    (pyarrow.Field<l_discount: double not null>,
                     'tpch',
                     'lineitem'))))]),
               ('func',
                'sum',
                [('*',
                  ('*',
                   ('colRef',


## 执行tpch q1物理计划

In [531]:
# Build the physical executor from the logical plan, open it, pull one result
# batch and print it as a pandas DataFrame.
pplan_builder = PhysicalPlanBuilder()
# NOTE(review): `exec` shadows the Python builtin of the same name; it is kept
# because the following cell closes the executor through this variable.
exec = pplan_builder.build(plan,PROJECT)
exec.Open()
# Only a single Next() call is made here.
records = exec.Next()
# csv.write_csv(records,"q1.csv")
pprint(records.to_pandas())


  l_returnflag l_linestatus    sum_qty  sum_base_price  sum_disc_price  \
0            A            F  3774200.0    5.320754e+09    5.054096e+09   
1            N            F    95257.0    1.337378e+08    1.271324e+08   
2            N            O  7338617.0    1.034164e+10    9.824174e+09   
3            R            F  3785523.0    5.337951e+09    5.071819e+09   

     sum_charge    avg_qty     avg_price  avg_disc  count_order  
0  5.256751e+09  25.537587  36002.123829  0.050145       147790  
1  1.322863e+08  25.300664  35521.326916  0.049394         3765  
2  1.021700e+10  25.547128  36001.234322  0.050099       287258  
3  5.274406e+09  25.525944  35994.029214  0.049989       148301  


In [532]:
# Close the executor built in the previous cell.
exec.Close()