Skip to content

Commit

Permalink
support window (ClickHouse#235)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Dec 26, 2022
1 parent 5ab0b59 commit de33c20
Show file tree
Hide file tree
Showing 10 changed files with 685 additions and 14 deletions.
131 changes: 131 additions & 0 deletions utils/local-engine/Parser/RelParser.cpp
@@ -0,0 +1,131 @@
#include "RelParser.h"
#include <string>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <DataTypes/IDataType.h>
#include <Common/Exception.h>

namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int LOGICAL_ERROR;
}
}

namespace local_engine
{
AggregateFunctionPtr RelParser::getAggregateFunction(
DB::String & name, DB::DataTypes arg_types, DB::AggregateFunctionProperties & properties, const DB::Array & parameters)
{
auto & factory = AggregateFunctionFactory::instance();
return factory.get(name, arg_types, parameters, properties);
}

std::optional<String> RelParser::parseFunctionName(UInt32 function_ref)
{
const auto & function_mapping = getFunctionMapping();
auto it = function_mapping.find(std::to_string(function_ref));
if (it == function_mapping.end())
{
return {};
}
auto function_signature = it->second;
auto function_name = function_signature.substr(0, function_signature.find(':'));
return function_name;
}

DB::DataTypes RelParser::parseFunctionArgumentTypes(
const Block & header, const google::protobuf::RepeatedPtrField<substrait::FunctionArgument> & func_args)
{
DB::DataTypes res;
for (const auto & arg : func_args)
{
if (!arg.has_value())
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Expect a FunctionArgument with value field");
}
const auto & value = arg.value();
if (value.has_selection())
{
auto pos = value.selection().direct_reference().struct_field().field();
res.push_back(header.getByPosition(pos).type);
}
else if (value.has_literal())
{
auto [data_type, _] = SerializedPlanParser::parseLiteral(value.literal());
res.push_back(data_type);
}
else
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unknow FunctionArgument: {}", arg.DebugString());
}
}
return res;
}

DB::Names RelParser::parseFunctionArgumentNames(
const Block & header, const google::protobuf::RepeatedPtrField<substrait::FunctionArgument> & func_args)
{
DB::Names res;
for (const auto & arg : func_args)
{
if (!arg.has_value())
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Expect a FunctionArgument with value field");
}
const auto & value = arg.value();
if (value.has_selection())
{
auto pos = value.selection().direct_reference().struct_field().field();
res.push_back(header.getByPosition(pos).name);
}
else if (value.has_literal())
{
auto [_, field] = SerializedPlanParser::parseLiteral(value.literal());
res.push_back(field.dump());
}
else
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unknow FunctionArgument: {}", arg.DebugString());
}
}
return res;
}

RelParserFactory & RelParserFactory::instance()
{
static RelParserFactory factory;
return factory;
}

void RelParserFactory::registerBuilder(UInt32 k, RelParserBuilder builder)
{
auto it = builders.find(k);
if (it != builders.end())
{
throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Duplicated builder key:{}", k);
}
builders[k] = builder;
}

RelParserFactory::RelParserBuilder RelParserFactory::getBuilder(DB::UInt32 k)
{
auto it = builders.find(k);
if (it == builders.end())
{
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Not found builder for key:{}", k);
}
return it->second;
}


void registerWindowRelParser(RelParserFactory & factory);
void registerSortRelParser(RelParserFactory & factory);
void initRelParserFactory()
{
auto & factory = RelParserFactory::instance();
registerWindowRelParser(factory);
registerSortRelParser(factory);
}
}
62 changes: 62 additions & 0 deletions utils/local-engine/Parser/RelParser.h
@@ -0,0 +1,62 @@
#pragma once
#include <map>
#include <optional>
#include <unordered_map>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Core/Field.h>
#include <DataTypes/IDataType.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <Parser/SerializedPlanParser.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <base/types.h>
#include <google/protobuf/repeated_field.h>
#include <substrait/plan.pb.h>
namespace local_engine
{
/// parse a single substrait relation
class RelParser
{
public:
explicit RelParser(SerializedPlanParser * plan_parser_)
:plan_parser(plan_parser_)
{}

virtual ~RelParser() = default;
virtual DB::QueryPlanPtr parse(DB::QueryPlanPtr current_plan_, const substrait::Rel & rel, std::list<const substrait::Rel *> & rel_stack_) = 0;

static AggregateFunctionPtr getAggregateFunction(
DB::String & name,
DB::DataTypes arg_types,
DB::AggregateFunctionProperties & properties,
const DB::Array & parameters = {});

public:
static DB::DataTypePtr parseType(const substrait::Type & type) { return SerializedPlanParser::parseType(type); }
protected:
inline ContextPtr getContext() { return plan_parser->context; }
inline String getUniqueName(const std::string & name) { return plan_parser->getUniqueName(name); }
inline const std::unordered_map<std::string, std::string> & getFunctionMapping() { return plan_parser->function_mapping; }
std::optional<String> parseFunctionName(UInt32 function_ref);
static DB::DataTypes parseFunctionArgumentTypes(const Block & header, const google::protobuf::RepeatedPtrField<substrait::FunctionArgument> & func_args);
static DB::Names parseFunctionArgumentNames(const Block & header, const google::protobuf::RepeatedPtrField<substrait::FunctionArgument> & func_args);

private:
SerializedPlanParser * plan_parser;
};

class RelParserFactory
{
protected:
RelParserFactory() = default;
public:
using RelParserBuilder = std::function<std::shared_ptr<RelParser>(SerializedPlanParser *)>;
static RelParserFactory & instance();
void registerBuilder(UInt32 k, RelParserBuilder builder);
RelParserBuilder getBuilder(DB::UInt32 k);
private:
std::map<UInt32, RelParserBuilder> builders;
};

void initRelParserFactory();

}
28 changes: 21 additions & 7 deletions utils/local-engine/Parser/SerializedPlanParser.cpp
Expand Up @@ -62,6 +62,7 @@
#include <sys/select.h>
#include <Common/CHUtil.h>
#include "SerializedPlanParser.h"
#include <Parser/RelParser.h>

namespace DB
{
Expand Down Expand Up @@ -756,7 +757,8 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel)
}
}

if (need_convert) {
if (need_convert)
{
ActionsDAGPtr convert_action
= ActionsDAG::makeConvertingActions(source, target, DB::ActionsDAG::MatchColumnsMode::Position);
if (convert_action)
Expand Down Expand Up @@ -805,8 +807,15 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel)
break;
}
case substrait::Rel::RelTypeCase::kSort: {
const auto & sort_rel = rel.sort();
query_plan = parseSort(sort_rel);
query_plan = parseSort(rel.sort());
break;
}
case substrait::Rel::RelTypeCase::kWindow: {
const auto win_rel = rel.window();
query_plan = parseOp(win_rel.input());
auto win_parser = RelParserFactory::instance().getBuilder(substrait::Rel::RelTypeCase::kWindow)(this);
std::list<const substrait::Rel *> rel_stack;
query_plan = win_parser->parse(std::move(query_plan), rel, rel_stack);
break;
}
default:
Expand Down Expand Up @@ -885,7 +894,8 @@ void SerializedPlanParser::addPreProjectStepIfNeeded(
}
wrapNullable(to_wrap_nullable, expression, nullable_measure_names);

if (need_pre_project) {
if (need_pre_project)
{
auto expression_before_aggregate = std::make_unique<ExpressionStep>(input, expression);
expression_before_aggregate->setStepDescription("Before Aggregate");
plan.addStep(std::move(expression_before_aggregate));
Expand All @@ -912,10 +922,13 @@ QueryPlanStepPtr SerializedPlanParser::parseAggregate(QueryPlan & plan, const su

if (phase_set.size() > 1)
{
if (phase_set.size() == 2 && has_first_stage && has_inter_stage) {
if (phase_set.size() == 2 && has_first_stage && has_inter_stage)
{
// this will happen in a sql like:
// select sum(a), count(distinct b) from T
} else {
}
else
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "too many aggregate phase!");
}
}
Expand Down Expand Up @@ -968,7 +981,8 @@ QueryPlanStepPtr SerializedPlanParser::parseAggregate(QueryPlan & plan, const su
// if measure arg has nullable version, use it
auto input_column = measure_names.at(i);
auto entry = nullable_measure_names.find(input_column);
if (entry != nullable_measure_names.end()) {
if (entry != nullable_measure_names.end())
{
input_column = entry->second;
}
agg.arguments = ColumnNumbers{plan.getCurrentDataStream().header.getPositionByName(input_column)};
Expand Down
2 changes: 2 additions & 0 deletions utils/local-engine/Parser/SerializedPlanParser.h
Expand Up @@ -2,6 +2,7 @@

#include <Core/Block.h>
#include <Core/ColumnWithTypeAndName.h>
#include <Core/SortDescription.h>
#include <DataTypes/DataTypeFactory.h>
#include <Parser/CHColumnToSparkRow.h>
#include <Processors/Executors/PullingPipelineExecutor.h>
Expand Down Expand Up @@ -168,6 +169,7 @@ DataTypePtr wrapNullableType(bool nullable, DataTypePtr nested_type);

class SerializedPlanParser
{
friend class RelParser;
public:
explicit SerializedPlanParser(const ContextPtr & context);
static void initFunctionEnv();
Expand Down
86 changes: 86 additions & 0 deletions utils/local-engine/Parser/SortRelParser.cpp
@@ -0,0 +1,86 @@
#include "SortRelParser.h"
#include <Parser/RelParser.h>
#include <Processors/QueryPlan/SortingStep.h>

namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
}

namespace local_engine
{

SortRelParser::SortRelParser(SerializedPlanParser * plan_paser_)
: RelParser(plan_paser_)
{}

DB::QueryPlanPtr
SortRelParser::parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list<const substrait::Rel *> & /*rel_stack_*/)
{
const auto & sort_rel = rel.sort();
auto sort_descr = parseSortDescription(sort_rel.sorts());
const auto & settings = getContext()->getSettingsRef();
auto sorting_step = std::make_unique<DB::SortingStep>(
query_plan->getCurrentDataStream(),
sort_descr,
settings.max_block_size,
0, // no limit now
SizeLimits(settings.max_rows_to_sort, settings.max_bytes_to_sort, settings.sort_overflow_mode),
settings.max_bytes_before_remerge_sort,
settings.remerge_sort_lowered_memory_bytes_ratio,
settings.max_bytes_before_external_sort,
getContext()->getTemporaryVolume(),
settings.min_free_disk_space_for_temporary_data);
sorting_step->setStepDescription("Sorting step");
query_plan->addStep(std::move(sorting_step));
return query_plan;
}

DB::SortDescription
SortRelParser::parseSortDescription(const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields, const DB::Block & header)
{
static std::map<int, std::pair<int, int>> direction_map = {{1, {1, -1}}, {2, {1, 1}}, {3, {-1, 1}}, {4, {-1, -1}}};

DB::SortDescription sort_descr;
for (int i = 0, sz = sort_fields.size(); i < sz; ++i)
{
const auto & sort_field = sort_fields[i];

if (!sort_field.expr().has_selection() || !sort_field.expr().selection().has_direct_reference()
|| !sort_field.expr().selection().direct_reference().has_struct_field())
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupport sort field");
}
auto field_pos = sort_field.expr().selection().direct_reference().struct_field().field();

auto direction_iter = direction_map.find(sort_field.direction());
if (direction_iter == direction_map.end())
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsuppor sort direction: {}", sort_field.direction());
}
if (header.columns())
{
auto & col_name = header.getByPosition(field_pos).name;
sort_descr.emplace_back(col_name, direction_iter->second.first, direction_iter->second.second);
sort_descr.back().column_number = field_pos;
}
else
{
sort_descr.emplace_back(field_pos, direction_iter->second.first, direction_iter->second.second);
}
}
return sort_descr;
}

void registerSortRelParser(RelParserFactory & factory)
{
auto builder = [](SerializedPlanParser * plan_parser)
{
return std::make_shared<SortRelParser>(plan_parser);
};
factory.registerBuilder(substrait::Rel::RelTypeCase::kSort, builder);
}
}
19 changes: 19 additions & 0 deletions utils/local-engine/Parser/SortRelParser.h
@@ -0,0 +1,19 @@
#pragma once
#include <Core/Block.h>
#include <Core/SortDescription.h>
#include <Parser/RelParser.h>
#include <google/protobuf/repeated_field.h>
namespace local_engine
{
class SortRelParser : public RelParser
{
public:
explicit SortRelParser(SerializedPlanParser * plan_paser_);
~SortRelParser() override = default;

DB::QueryPlanPtr
parse(DB::QueryPlanPtr query_plan, const substrait::Rel & sort_rel, std::list<const substrait::Rel *> & rel_stack_) override;
static DB::SortDescription parseSortDescription(const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields, const DB::Block & header = {});

};
}

0 comments on commit de33c20

Please sign in to comment.