Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support new window function ntile #46256

Merged
merged 6 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/en/sql-reference/window-functions/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ ClickHouse supports the standard grammar for defining windows and window functio
| `lag/lead(value, offset)` | Not supported. Workarounds: |
| | 1) replace with `any(value) over (.... rows between <offset> preceding and <offset> preceding)`, or `following` for `lead` |
| | 2) use `lagInFrame/leadInFrame`, which are analogous, but respect the window frame. To get behavior identical to `lag/lead`, use `rows between unbounded preceding and unbounded following` |
| `ntile(buckets)` | Supported. Specify the window like `(partition by x order by y rows between unbounded preceding and unbounded following)`. |

## ClickHouse-specific Window Functions

Expand Down
149 changes: 148 additions & 1 deletion src/Processors/Transforms/WindowTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1415,7 +1415,6 @@ void WindowTransform::work()
assert(prev_frame_start <= frame_start);
const auto first_used_block = std::min(next_output_block_number,
std::min(prev_frame_start.block, current_row.block));

if (first_block_number < first_used_block)
{
// fmt::print(stderr, "will drop blocks from {} to {}\n", first_block_number,
Expand Down Expand Up @@ -1970,6 +1969,147 @@ struct WindowFunctionRowNumber final : public WindowFunction
}
};

// Usage: ntile(n). n is the number of buckets.
struct WindowFunctionNtile final : public WindowFunction
{
WindowFunctionNtile(const std::string & name_,
const DataTypes & argument_types_, const Array & parameters_)
: WindowFunction(name_, argument_types_, parameters_, std::make_shared<DataTypeUInt64>())
{
if (argument_types.size() != 1)
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Function {} takes exactly one parameter", name_);
}
auto type_id = argument_types[0]->getTypeId();
if (type_id != TypeIndex::UInt8 && type_id != TypeIndex::UInt16 && type_id != TypeIndex::UInt32 && type_id != TypeIndex::UInt32 && type_id != TypeIndex::UInt64)
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "ntile's argument type must be an unsigned integer (not larger then 64-bit), but got {}", argument_types[0]->getName());
}
}

bool allocatesMemoryInArena() const override { return false; }

void windowInsertResultInto(const WindowTransform * transform,
size_t function_index) override
{
if (!buckets) [[unlikely]]
{
checkWindowFrameType(transform);
const auto & current_block = transform->blockAt(transform->current_row);
const auto & workspace = transform->workspaces[function_index];
const auto & arg_col = *current_block.original_input_columns[workspace.argument_column_indices[0]];
if (!isColumnConst(arg_col))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "ntile's argument must be a constant");
auto type_id = argument_types[0]->getTypeId();
if (type_id == TypeIndex::UInt8)
buckets = arg_col[transform->current_row.row].get<UInt8>();
else if (type_id == TypeIndex::UInt16)
buckets = arg_col[transform->current_row.row].get<UInt16>();
else if (type_id == TypeIndex::UInt32)
buckets = arg_col[transform->current_row.row].get<UInt32>();
else if (type_id == TypeIndex::UInt64)
buckets = arg_col[transform->current_row.row].get<UInt64>();

if (!buckets)
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "ntile's argument must > 0");
}
}
// new partition
if (transform->current_row_number == 1) [[unlikely]]
{
current_partition_rows = 0;
current_partition_inserted_row = 0;
start_row = transform->current_row;
}
current_partition_rows++;

// Only do the action when we meet the last row in this partition.
if (!transform->partition_ended)
return;
else
{
auto current_row = transform->current_row;
current_row.row++;
const auto & end_row = transform->partition_end;
if (current_row != end_row)
{

if (current_row.row < transform->blockRowsNumber(current_row))
return;
if (end_row.block != current_row.block + 1 || end_row.row)
{
return;
}
// else, current_row is the last input row.
}
}
auto bucket_capacity = current_partition_rows / buckets;
auto capacity_diff = current_partition_rows - bucket_capacity * buckets;

// bucket number starts from 1.
UInt64 bucket_num = 1;
while (current_partition_inserted_row < current_partition_rows)
{
auto current_bucket_capacity = bucket_capacity;
if (capacity_diff > 0)
{
current_bucket_capacity += 1;
capacity_diff--;
}
auto left_rows = current_bucket_capacity;
while (left_rows)
{
auto available_block_rows = transform->blockRowsNumber(start_row) - start_row.row;
IColumn & to = *transform->blockAt(start_row).output_columns[function_index];
auto & pod_array = assert_cast<ColumnUInt64 &>(to).getData();
if (left_rows < available_block_rows)
{
pod_array.resize_fill(pod_array.size() + left_rows, bucket_num);
start_row.row += left_rows;
left_rows = 0;
}
else
{
pod_array.resize_fill(pod_array.size() + available_block_rows, bucket_num);
left_rows -= available_block_rows;
start_row.block++;
start_row.row = 0;
}
}
current_partition_inserted_row += current_bucket_capacity;
bucket_num += 1;
}
}
private:
UInt64 buckets = 0;
RowNumber start_row;
UInt64 current_partition_rows = 0;
UInt64 current_partition_inserted_row = 0;

static void checkWindowFrameType(const WindowTransform * transform)
{
if (transform->order_by_indices.empty())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "ntile's window frame must have order by clause");
if (transform->window_description.frame.type != WindowFrame::FrameType::ROWS)
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "ntile's frame type must be ROWS");
}
if (transform->window_description.frame.begin_type != WindowFrame::BoundaryType::Unbounded)
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "ntile's frame start type must be UNBOUNDED PRECEDING");
}

if (transform->window_description.frame.end_type != WindowFrame::BoundaryType::Unbounded)
{
// We must wait all for the partition end and get the total rows number in this
// partition. So before the end of this partition, there is no any block could be
// dropped out.
throw Exception(ErrorCodes::BAD_ARGUMENTS, "ntile's frame end type must be UNBOUNDED FOLLOWING");
}
}
};

// ClickHouse-specific variant of lag/lead that respects the window frame.
template <bool is_lead>
struct WindowFunctionLagLeadInFrame final : public WindowFunction
Expand Down Expand Up @@ -2338,6 +2478,13 @@ void registerWindowFunctions(AggregateFunctionFactory & factory)
parameters);
}, properties}, AggregateFunctionFactory::CaseInsensitive);

factory.registerFunction("ntile", {[](const std::string & name,
const DataTypes & argument_types, const Array & parameters, const Settings *)
{
return std::make_shared<WindowFunctionNtile>(name, argument_types,
parameters);
}, properties}, AggregateFunctionFactory::CaseInsensitive);

factory.registerFunction("nth_value", {[](const std::string & name,
const DataTypes & argument_types, const Array & parameters, const Settings *)
{
Expand Down
201 changes: 201 additions & 0 deletions tests/queries/0_stateless/02661_window_ntile.reference
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
-- { echo }

-- Normal cases
select a, b, ntile(3) over (partition by a order by b rows between unbounded preceding and unbounded following) from(select intDiv(number,10) as a, number%10 as b from numbers(20));
0 0 1
0 1 1
0 2 1
0 3 1
0 4 2
0 5 2
0 6 2
0 7 3
0 8 3
0 9 3
1 0 1
1 1 1
1 2 1
1 3 1
1 4 2
1 5 2
1 6 2
1 7 3
1 8 3
1 9 3
select a, b, ntile(2) over (partition by a order by b rows between unbounded preceding and unbounded following) from(select intDiv(number,10) as a, number%10 as b from numbers(20));
0 0 1
0 1 1
0 2 1
0 3 1
0 4 1
0 5 2
0 6 2
0 7 2
0 8 2
0 9 2
1 0 1
1 1 1
1 2 1
1 3 1
1 4 1
1 5 2
1 6 2
1 7 2
1 8 2
1 9 2
select a, b, ntile(1) over (partition by a order by b rows between unbounded preceding and unbounded following) from(select intDiv(number,10) as a, number%10 as b from numbers(20));
0 0 1
0 1 1
0 2 1
0 3 1
0 4 1
0 5 1
0 6 1
0 7 1
0 8 1
0 9 1
1 0 1
1 1 1
1 2 1
1 3 1
1 4 1
1 5 1
1 6 1
1 7 1
1 8 1
1 9 1
select a, b, ntile(100) over (partition by a order by b rows between unbounded preceding and unbounded following) from(select intDiv(number,10) as a, number%10 as b from numbers(20));
0 0 1
0 1 2
0 2 3
0 3 4
0 4 5
0 5 6
0 6 7
0 7 8
0 8 9
0 9 10
1 0 1
1 1 2
1 2 3
1 3 4
1 4 5
1 5 6
1 6 7
1 7 8
1 8 9
1 9 10
select a, b, ntile(65535) over (partition by a order by b rows between unbounded preceding and unbounded following) from (select 1 as a, number as b from numbers(65535)) limit 100;
1 0 1
1 1 2
1 2 3
1 3 4
1 4 5
1 5 6
1 6 7
1 7 8
1 8 9
1 9 10
1 10 11
1 11 12
1 12 13
1 13 14
1 14 15
1 15 16
1 16 17
1 17 18
1 18 19
1 19 20
1 20 21
1 21 22
1 22 23
1 23 24
1 24 25
1 25 26
1 26 27
1 27 28
1 28 29
1 29 30
1 30 31
1 31 32
1 32 33
1 33 34
1 34 35
1 35 36
1 36 37
1 37 38
1 38 39
1 39 40
1 40 41
1 41 42
1 42 43
1 43 44
1 44 45
1 45 46
1 46 47
1 47 48
1 48 49
1 49 50
1 50 51
1 51 52
1 52 53
1 53 54
1 54 55
1 55 56
1 56 57
1 57 58
1 58 59
1 59 60
1 60 61
1 61 62
1 62 63
1 63 64
1 64 65
1 65 66
1 66 67
1 67 68
1 68 69
1 69 70
1 70 71
1 71 72
1 72 73
1 73 74
1 74 75
1 75 76
1 76 77
1 77 78
1 78 79
1 79 80
1 80 81
1 81 82
1 82 83
1 83 84
1 84 85
1 85 86
1 86 87
1 87 88
1 88 89
1 89 90
1 90 91
1 91 92
1 92 93
1 93 94
1 94 95
1 95 96
1 96 97
1 97 98
1 98 99
1 99 100
-- Bad arguments
select a, b, ntile(3.0) over (partition by a order by b rows between unbounded preceding and unbounded following) from(select intDiv(number,10) as a, number%10 as b from numbers(20)); -- { serverError 36 }
select a, b, ntile('2') over (partition by a order by b rows between unbounded preceding and unbounded following) from(select intDiv(number,10) as a, number%10 as b from numbers(20)); -- { serverError 36 }
select a, b, ntile(0) over (partition by a order by b rows between unbounded preceding and unbounded following) from(select intDiv(number,10) as a, number%10 as b from numbers(20)); -- { serverError 36 }
select a, b, ntile(-2) over (partition by a order by b rows between unbounded preceding and unbounded following) from(select intDiv(number,10) as a, number%10 as b from numbers(20)); -- { serverError 36 }
select a, b, ntile(b + 1) over (partition by a order by b rows between unbounded preceding and unbounded following) from(select intDiv(number,10) as a, number%10 as b from numbers(20)); -- { serverError 36 }
-- Bad window type
select a, b, ntile(2) over (partition by a) from(select intDiv(number,10) as a, number%10 as b from numbers(20)); -- { serverError 36 }
select a, b, ntile(2) over (partition by a order by b rows between 4 preceding and unbounded following) from(select intDiv(number,10) as a, number%10 as b from numbers(20)); -- { serverError 36 }
select a, b, ntile(2) over (partition by a order by b rows between unbounded preceding and 4 following) from(select intDiv(number,10) as a, number%10 as b from numbers(20)); -- { serverError 36 }
select a, b, ntile(2) over (partition by a order by b rows between 4 preceding and 4 following) from(select intDiv(number,10) as a, number%10 as b from numbers(20));; -- { serverError 36 }
select a, b, ntile(2) over (partition by a order by b rows between current row and 4 following) from(select intDiv(number,10) as a, number%10 as b from numbers(20));; -- { serverError 36 }
select a, b, ntile(2) over (partition by a order by b range unbounded preceding) from(select intDiv(number,10) as a, number%10 as b from numbers(20));; -- { serverError 36 }