Skip to content

Commit

Permalink
fix review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ariesdevil committed Mar 27, 2023
1 parent 7cf33e3 commit 0e30622
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 79 deletions.
2 changes: 1 addition & 1 deletion src/query/service/src/pipelines/pipeline_builder.rs
Expand Up @@ -800,7 +800,7 @@ impl PipelineBuilder {
self.main_pipeline.resize(max_threads)?;
}

self.build_sort_pipeline(input_schema.clone(), sort_desc, sort.plan_id, sort.limit)
self.build_sort_pipeline(input_schema, sort_desc, sort.plan_id, sort.limit)
}

fn build_sort_pipeline(
Expand Down
Expand Up @@ -286,7 +286,7 @@ impl TransformWindow {
if self.frame_started {
return;
}
match &self.frame_kind.start {
match &self.frame_kind.start_bound {
WindowFuncFrameBound::CurrentRow => {
self.frame_started = true;
self.frame_start = self.current_row;
Expand All @@ -303,20 +303,14 @@ impl TransformWindow {
self.frame_started = true;
self.frame_start = self.partition_start;
}
<<<<<<< HEAD
WindowFrameBound::Following(Some(n)) => {
WindowFuncFrameBound::Following(Some(n)) => {
self.frame_start = if self.current_row_in_partition == 1 {
self.add_rows_within_partition(self.current_row, *n)
} else {
self.advance_row(self.prev_frame_start)
.min(self.partition_end)
};
self.frame_started = self.partition_ended || self.frame_start < self.partition_end;
=======
WindowFuncFrameBound::Following(Some(n)) => {
self.frame_start = self.current_row_add_within_partition(*n);
self.frame_started = self.partition_ended || self.frame_start < self.partition_end
>>>>>>> 826550fb47 (fix)
}
WindowFuncFrameBound::Following(_) => {
unreachable!()
Expand All @@ -325,7 +319,7 @@ impl TransformWindow {
}

fn advance_frame_end(&mut self) {
match &self.frame_kind.end {
match &self.frame_kind.end_bound {
WindowFuncFrameBound::CurrentRow => {
self.frame_ended = true;
self.frame_end = self.current_row;
Expand All @@ -341,8 +335,7 @@ impl TransformWindow {
WindowFuncFrameBound::Preceding(_) => {
unreachable!()
}
<<<<<<< HEAD
WindowFrameBound::Following(Some(n)) => {
WindowFuncFrameBound::Following(Some(n)) => {
self.frame_end = if self.current_row_in_partition == 1 {
let next_end = self.add_rows_within_partition(self.current_row, *n);
self.frame_ended = self.partition_ended || next_end < self.partition_end;
Expand All @@ -354,13 +347,6 @@ impl TransformWindow {
self.advance_row(self.prev_frame_end)
}
.min(self.partition_end);
=======
WindowFuncFrameBound::Following(Some(n)) => {
self.frame_end = self.current_row_add_within_partition(*n);
self.frame_ended = self.partition_ended || self.frame_end < self.partition_end;
// Frame end is excluded.
self.frame_end = self.advance_row(self.frame_end).min(self.partition_end);
>>>>>>> 826550fb47 (fix)
}
WindowFuncFrameBound::Following(_) => {
self.frame_ended = self.partition_ended;
Expand Down Expand Up @@ -476,15 +462,16 @@ impl TransformWindow {
}

fn apply_aggregate(&mut self) -> Result<()> {
let WindowFrame {
let WindowFuncFrame {
start_bound,
end_bound,
..
} = &self.frame_kind;
match (start_bound, end_bound) {
(WindowFrameBound::Preceding(None), WindowFrameBound::Following(None)) => {
(WindowFuncFrameBound::Preceding(None), WindowFuncFrameBound::Following(None)) => {
self.apply_aggregate_for_unbounded_frame()
}
(WindowFrameBound::Preceding(None), _) => {
(WindowFuncFrameBound::Preceding(None), _) => {
self.apply_aggregate_for_unbounded_preceding()
}
(_, _) => self.apply_aggregate_common(),
Expand Down Expand Up @@ -654,14 +641,11 @@ mod tests {
use common_pipeline_core::processors::connect;
use common_pipeline_core::processors::port::InputPort;
use common_pipeline_core::processors::port::OutputPort;
<<<<<<< HEAD
use common_pipeline_core::processors::processor::Event;
use common_pipeline_core::processors::Processor;
=======
use common_sql::plans::WindowFuncFrame;
use common_sql::plans::WindowFuncFrameBound;
use common_sql::plans::WindowFuncFrameUnits;
>>>>>>> 826550fb47 (fix)

use super::TransformWindow;
use super::WindowBlock;
Expand Down Expand Up @@ -702,8 +686,8 @@ mod tests {
let mut transform = get_transform_window_with_data(
WindowFuncFrame {
units: WindowFuncFrameUnits::Rows,
start: WindowFuncFrameBound::CurrentRow,
end: WindowFuncFrameBound::CurrentRow,
start_bound: WindowFuncFrameBound::CurrentRow,
end_bound: WindowFuncFrameBound::CurrentRow,
},
Int32Type::from_data(vec![1, 1, 1]),
)?;
Expand All @@ -718,8 +702,8 @@ mod tests {
let mut transform = get_transform_window_with_data(
WindowFuncFrame {
units: WindowFuncFrameUnits::Rows,
start: WindowFuncFrameBound::CurrentRow,
end: WindowFuncFrameBound::CurrentRow,
start_bound: WindowFuncFrameBound::CurrentRow,
end_bound: WindowFuncFrameBound::CurrentRow,
},
Int32Type::from_data(vec![1, 1, 2]),
)?;
Expand All @@ -738,8 +722,8 @@ mod tests {
let mut transform = get_transform_window_with_data(
WindowFuncFrame {
units: WindowFuncFrameUnits::Rows,
start: WindowFuncFrameBound::Following(Some(4)),
end: WindowFuncFrameBound::Following(Some(5)),
start_bound: WindowFuncFrameBound::Following(Some(4)),
end_bound: WindowFuncFrameBound::Following(Some(5)),
},
Int32Type::from_data(vec![1, 1, 1]),
)?;
Expand All @@ -757,8 +741,8 @@ mod tests {
let mut transform = get_transform_window_with_data(
WindowFuncFrame {
units: WindowFuncFrameUnits::Rows,
start: WindowFuncFrameBound::Preceding(Some(2)),
end: WindowFuncFrameBound::Following(Some(5)),
start_bound: WindowFuncFrameBound::Preceding(Some(2)),
end_bound: WindowFuncFrameBound::Following(Some(5)),
},
Int32Type::from_data(vec![1, 1, 1]),
)?;
Expand All @@ -781,8 +765,8 @@ mod tests {
let mut transform = get_transform_window_with_data(
WindowFuncFrame {
units: WindowFuncFrameUnits::Rows,
start: WindowFuncFrameBound::Preceding(Some(2)),
end: WindowFuncFrameBound::Following(Some(1)),
start_bound: WindowFuncFrameBound::Preceding(Some(2)),
end_bound: WindowFuncFrameBound::Following(Some(1)),
},
Int32Type::from_data(vec![1, 1, 1]),
)?;
Expand All @@ -808,8 +792,8 @@ mod tests {
let mut transform = get_transform_window_with_data(
WindowFuncFrame {
units: WindowFuncFrameUnits::Rows,
start: WindowFuncFrameBound::Preceding(None),
end: WindowFuncFrameBound::Following(None),
start_bound: WindowFuncFrameBound::Preceding(None),
end_bound: WindowFuncFrameBound::Following(None),
},
Int32Type::from_data(vec![1, 1, 1, 2]),
)?;
Expand Down Expand Up @@ -840,8 +824,8 @@ mod tests {
let mut transform = get_transform_window(
WindowFuncFrame {
units: WindowFuncFrameUnits::Rows,
start: WindowFuncFrameBound::Preceding(None),
end: WindowFuncFrameBound::Following(None),
start_bound: WindowFuncFrameBound::Preceding(None),
end_bound: WindowFuncFrameBound::Following(None),
},
DataType::Number(NumberDataType::Int32),
)?;
Expand Down Expand Up @@ -876,8 +860,8 @@ mod tests {
let mut transform = get_transform_window(
WindowFuncFrame {
units: WindowFuncFrameUnits::Rows,
start: WindowFuncFrameBound::Preceding(None),
end: WindowFuncFrameBound::Following(None),
start_bound: WindowFuncFrameBound::Preceding(None),
end_bound: WindowFuncFrameBound::Following(None),
},
DataType::Number(NumberDataType::Int32),
)?;
Expand Down Expand Up @@ -938,8 +922,8 @@ mod tests {
let mut transform = get_transform_window(
WindowFuncFrame {
units: WindowFuncFrameUnits::Rows,
start: WindowFuncFrameBound::Preceding(None),
end: WindowFuncFrameBound::Following(None),
start_bound: WindowFuncFrameBound::Preceding(None),
end_bound: WindowFuncFrameBound::Following(None),
},
DataType::Number(NumberDataType::Int32),
)?;
Expand Down Expand Up @@ -998,10 +982,10 @@ mod tests {

{
let mut transform = get_transform_window(
<<<<<<< HEAD
WindowFrame {
start_bound: WindowFrameBound::Preceding(None),
end_bound: WindowFrameBound::Following(None),
WindowFuncFrame {
units: WindowFuncFrameUnits::Rows,
start_bound: WindowFuncFrameBound::Preceding(None),
end_bound: WindowFuncFrameBound::Following(None),
},
DataType::Number(NumberDataType::Int32),
)?;
Expand Down Expand Up @@ -1060,9 +1044,10 @@ mod tests {

{
let mut transform = get_transform_window(
WindowFrame {
start_bound: WindowFrameBound::Preceding(None),
end_bound: WindowFrameBound::Following(Some(1)),
WindowFuncFrame {
units: WindowFuncFrameUnits::Rows,
start_bound: WindowFuncFrameBound::Preceding(None),
end_bound: WindowFuncFrameBound::Following(Some(1)),
},
DataType::Number(NumberDataType::Int32),
)?;
Expand Down Expand Up @@ -1121,15 +1106,10 @@ mod tests {

{
let mut transform = get_transform_window(
WindowFrame {
start_bound: WindowFrameBound::Preceding(Some(1)),
end_bound: WindowFrameBound::Following(Some(1)),
=======
WindowFuncFrame {
units: WindowFuncFrameUnits::Rows,
start: WindowFuncFrameBound::Preceding(Some(1)),
end: WindowFuncFrameBound::Following(Some(1)),
>>>>>>> 826550fb47 (fix)
start_bound: WindowFuncFrameBound::Preceding(Some(1)),
end_bound: WindowFuncFrameBound::Following(Some(1)),
},
DataType::Number(NumberDataType::Int32),
)?;
Expand Down Expand Up @@ -1191,7 +1171,7 @@ mod tests {

#[allow(clippy::type_complexity)]
fn get_transform_window_and_ports(
window_frame: WindowFrame,
window_frame: WindowFuncFrame,
) -> Result<(Box<dyn Processor>, Arc<InputPort>, Arc<OutputPort>)> {
let function = AggregateFunctionFactory::instance()
.get("sum", vec![], vec![DataType::Number(NumberDataType::Int32)])?;
Expand All @@ -1214,9 +1194,10 @@ mod tests {
{
let upstream_output = OutputPort::create();
let downstream_input = InputPort::create();
let (mut transform, input, output) = get_transform_window_and_ports(WindowFrame {
start_bound: WindowFrameBound::Preceding(Some(1)),
end_bound: WindowFrameBound::Following(Some(1)),
let (mut transform, input, output) = get_transform_window_and_ports(WindowFuncFrame {
units: WindowFuncFrameUnits::Rows,
start_bound: WindowFuncFrameBound::Preceding(Some(1)),
end_bound: WindowFuncFrameBound::Following(Some(1)),
})?;

unsafe {
Expand Down Expand Up @@ -1289,9 +1270,10 @@ mod tests {
{
let upstream_output = OutputPort::create();
let downstream_input = InputPort::create();
let (mut transform, input, output) = get_transform_window_and_ports(WindowFrame {
start_bound: WindowFrameBound::Preceding(None),
end_bound: WindowFrameBound::Following(None),
let (mut transform, input, output) = get_transform_window_and_ports(WindowFuncFrame {
units: WindowFuncFrameUnits::Rows,
start_bound: WindowFuncFrameBound::Preceding(None),
end_bound: WindowFuncFrameBound::Following(None),
})?;

unsafe {
Expand Down
2 changes: 1 addition & 1 deletion src/query/sql/src/planner/binder/project.rs
Expand Up @@ -120,7 +120,7 @@ impl Binder {
})
} else {
let mut window_checker = WindowChecker::new(bind_context);
let scalar = window_checker.resolve(&item.scalar, None)?;
let scalar = window_checker.resolve(&item.scalar)?;
Ok(ScalarItem {
scalar,
index: item.index,
Expand Down
4 changes: 2 additions & 2 deletions src/query/sql/src/planner/plans/scalar_expr.rs
Expand Up @@ -449,8 +449,8 @@ pub struct WindowOrderBy {
#[derive(Default, Clone, PartialEq, Eq, Hash, Debug, serde::Serialize, serde::Deserialize)]
pub struct WindowFuncFrame {
pub units: WindowFuncFrameUnits,
pub start: WindowFuncFrameBound,
pub end: WindowFuncFrameBound,
pub start_bound: WindowFuncFrameBound,
pub end_bound: WindowFuncFrameBound,
}

#[derive(Default, Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
Expand Down
6 changes: 5 additions & 1 deletion src/query/sql/src/planner/semantic/type_check.rs
Expand Up @@ -1016,7 +1016,11 @@ impl<'a> TypeChecker<'a> {
agg_func,
partition_by: partitions,
order_by: order_bys,
frame: WindowFuncFrame { units, start, end },
frame: WindowFuncFrame {
units,
start_bound: start,
end_bound: end,
},
};

Ok(Box::new((window_func.into(), return_type)))
Expand Down
21 changes: 10 additions & 11 deletions src/query/sql/src/planner/semantic/window_check.rs
Expand Up @@ -14,7 +14,6 @@

use common_exception::ErrorCode;
use common_exception::Result;
use common_exception::Span;

use crate::plans::AndExpr;
use crate::plans::BoundColumnRef;
Expand All @@ -37,36 +36,36 @@ impl<'a> WindowChecker<'a> {
Self { bind_context }
}

pub fn resolve(&mut self, scalar: &ScalarExpr, span: Span) -> Result<ScalarExpr> {
pub fn resolve(&mut self, scalar: &ScalarExpr) -> Result<ScalarExpr> {
match scalar {
ScalarExpr::BoundColumnRef(_)
| ScalarExpr::BoundInternalColumnRef(_)
| ScalarExpr::ConstantExpr(_) => Ok(scalar.clone()),
ScalarExpr::AndExpr(scalar) => Ok(AndExpr {
left: Box::new(self.resolve(&scalar.left, span)?),
right: Box::new(self.resolve(&scalar.right, span)?),
left: Box::new(self.resolve(&scalar.left)?),
right: Box::new(self.resolve(&scalar.right)?),
}
.into()),
ScalarExpr::OrExpr(scalar) => Ok(OrExpr {
left: Box::new(self.resolve(&scalar.left, span)?),
right: Box::new(self.resolve(&scalar.right, span)?),
left: Box::new(self.resolve(&scalar.left)?),
right: Box::new(self.resolve(&scalar.right)?),
}
.into()),
ScalarExpr::NotExpr(scalar) => Ok(NotExpr {
argument: Box::new(self.resolve(&scalar.argument, span)?),
argument: Box::new(self.resolve(&scalar.argument)?),
}
.into()),
ScalarExpr::ComparisonExpr(scalar) => Ok(ComparisonExpr {
op: scalar.op.clone(),
left: Box::new(self.resolve(&scalar.left, span)?),
right: Box::new(self.resolve(&scalar.right, span)?),
left: Box::new(self.resolve(&scalar.left)?),
right: Box::new(self.resolve(&scalar.right)?),
}
.into()),
ScalarExpr::FunctionCall(func) => {
let args = func
.arguments
.iter()
.map(|arg| self.resolve(arg, span))
.map(|arg| self.resolve(arg))
.collect::<Result<Vec<ScalarExpr>>>()?;
Ok(FunctionCall {
span: func.span,
Expand All @@ -79,7 +78,7 @@ impl<'a> WindowChecker<'a> {
ScalarExpr::CastExpr(cast) => Ok(CastExpr {
span: cast.span,
is_try: cast.is_try,
argument: Box::new(self.resolve(&cast.argument, span)?),
argument: Box::new(self.resolve(&cast.argument)?),
target_type: cast.target_type.clone(),
}
.into()),
Expand Down

0 comments on commit 0e30622

Please sign in to comment.