Skip to content

Commit

Permalink
feat(cubesql): Rewrites for pushdown of subqueries with empty source (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
waralexrom committed May 14, 2024
1 parent e366406 commit 86a58a5
Show file tree
Hide file tree
Showing 8 changed files with 402 additions and 96 deletions.
4 changes: 2 additions & 2 deletions packages/cubejs-schema-compiler/src/adapter/BaseQuery.js
Original file line number Diff line number Diff line change
Expand Up @@ -2918,10 +2918,10 @@ export class BaseQuery {
},
statements: {
select: 'SELECT {% if distinct %}DISTINCT {% endif %}' +
'{{ select_concat | map(attribute=\'aliased\') | join(\', \') }} \n' +
'{{ select_concat | map(attribute=\'aliased\') | join(\', \') }} {% if from %}\n' +
'FROM (\n' +
'{{ from | indent(2, true) }}\n' +
') AS {{ from_alias }}' +
') AS {{ from_alias }}{% endif %}' +
'{% if filter %}\nWHERE {{ filter }}{% endif %}' +
'{% if group_by %}\nGROUP BY {{ group_by | map(attribute=\'index\') | join(\', \') }}{% endif %}' +
'{% if order_by %}\nORDER BY {{ order_by | map(attribute=\'expr\') | join(\', \') }}{% endif %}' +
Expand Down
75 changes: 45 additions & 30 deletions rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ impl SqlQuery {
index
}

pub fn extend_values(&mut self, values: &Vec<Option<String>>) {
self.values.extend(values.iter().cloned());
}

pub fn replace_sql(&mut self, sql: String) {
self.sql = sql;
}
Expand Down Expand Up @@ -243,6 +247,7 @@ impl CubeScanWrapperNode {
self.clone().set_max_limit_for_node(wrapped_plan),
true,
Vec::new(),
None,
)
.await
.and_then(|SqlGenerationResult { data_source, mut sql, request, column_remapping, .. }| -> result::Result<_, CubeError> {
Expand Down Expand Up @@ -324,7 +329,8 @@ impl CubeScanWrapperNode {
load_request_meta: Arc<LoadRequestMeta>,
node: Arc<LogicalPlan>,
can_rename_columns: bool,
mut values: Vec<Option<String>>,
values: Vec<Option<String>>,
parent_data_source: Option<String>,
) -> Pin<Box<dyn Future<Output = result::Result<SqlGenerationResult, CubeError>> + Send>> {
Box::pin(async move {
match node.as_ref() {
Expand Down Expand Up @@ -435,7 +441,7 @@ impl CubeScanWrapperNode {
Some(Arc::new(cube_scan_node.clone()))
} else {
return Err(CubeError::internal(format!(
"Expected CubeScan node but found: {:?}",
"Expected ubeScan node but found: {:?}",
plan
)));
}
Expand All @@ -448,37 +454,12 @@ impl CubeScanWrapperNode {
} else {
None
};
let mut subqueries_sql = HashMap::new();
for subquery in subqueries.iter() {
let SqlGenerationResult {
data_source: _,
from_alias: _,
column_remapping: _,
sql,
request: _,
} = Self::generate_sql_for_node(
plan.clone(),
transport.clone(),
load_request_meta.clone(),
subquery.clone(),
true,
values,
)
.await?;

let (sql_string, new_values) = sql.unpack();
values = new_values;

let field = subquery.schema().field(0);
subqueries_sql.insert(field.qualified_name(), sql_string);
}

let subqueries_sql = Arc::new(subqueries_sql);
let SqlGenerationResult {
data_source,
from_alias,
column_remapping,
sql,
mut sql,
request,
} = if let Some(ungrouped_scan_node) = ungrouped_scan_node.clone() {
let data_sources = ungrouped_scan_node
Expand All @@ -499,7 +480,7 @@ impl CubeScanWrapperNode {
ungrouped_scan_node
)));
}
let sql = SqlQuery::new("".to_string(), values);
let sql = SqlQuery::new("".to_string(), values.clone());
SqlGenerationResult {
data_source: Some(data_sources[0].clone()),
from_alias: ungrouped_scan_node
Expand All @@ -519,10 +500,37 @@ impl CubeScanWrapperNode {
load_request_meta.clone(),
from.clone(),
true,
values,
values.clone(),
parent_data_source.clone(),
)
.await?
};

let mut subqueries_sql = HashMap::new();
for subquery in subqueries.iter() {
let SqlGenerationResult {
data_source: _,
from_alias: _,
column_remapping: _,
sql: subquery_sql,
request: _,
} = Self::generate_sql_for_node(
plan.clone(),
transport.clone(),
load_request_meta.clone(),
subquery.clone(),
true,
sql.values.clone(),
data_source.clone(),
)
.await?;

let (sql_string, new_values) = subquery_sql.unpack();
sql.extend_values(&new_values);
let field = subquery.schema().field(0);
subqueries_sql.insert(field.qualified_name(), sql_string);
}
let subqueries_sql = Arc::new(subqueries_sql);
let mut next_remapping = HashMap::new();
let alias = alias.or(from_alias.clone());
if let Some(data_source) = data_source {
Expand Down Expand Up @@ -825,6 +833,13 @@ impl CubeScanWrapperNode {
)));
}
}
LogicalPlan::EmptyRelation(_) => Ok(SqlGenerationResult {
data_source: parent_data_source,
from_alias: None,
sql: SqlQuery::new("".to_string(), values.clone()),
column_remapping: None,
request: V1LoadRequestQuery::new(),
}),
// LogicalPlan::Distinct(_) => {}
x => {
return Err(CubeError::internal(format!(
Expand Down
153 changes: 143 additions & 10 deletions rust/cubesql/cubesql/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use self::{
},
parser::parse_sql_to_statement,
qtrace::Qtrace,
rewrite::converter::LogicalPlanToLanguageConverter,
rewrite::converter::{LogicalPlanToLanguageContext, LogicalPlanToLanguageConverter},
};
use crate::{
sql::{
Expand Down Expand Up @@ -1313,7 +1313,11 @@ WHERE `TABLE_SCHEMA` = '{}'",
let mut converter = LogicalPlanToLanguageConverter::new(cube_ctx.clone());
let mut query_params = Some(HashMap::new());
let root = converter
.add_logical_plan_replace_params(&optimized_plan, &mut query_params)
.add_logical_plan_replace_params(
&optimized_plan,
&mut query_params,
&mut LogicalPlanToLanguageContext::default(),
)
.map_err(|e| CompilationError::internal(e.to_string()))?;

let mut finalized_graph = self
Expand Down Expand Up @@ -20115,14 +20119,149 @@ ORDER BY "source"."str0" ASC
}

#[tokio::test]
async fn test_simple_subquery_wrapper_projection() {
async fn test_simple_subquery_wrapper_projection_empty_source() {
if !Rewriter::sql_push_down_enabled() {
return;
}
init_logger();

let query_plan = convert_select_to_query_plan(
"SELECT (SELECT 'male' where 1 group by 'male' having 1 order by 'male' limit 1) as gender, avgPrice FROM KibanaSampleDataEcommerce a"
.to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;

let logical_plan = query_plan.as_logical_plan();
let sql = logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql;
assert!(sql.contains("(SELECT"));
assert!(sql.contains("utf8__male__"));

let _physical_plan = query_plan.as_physical_plan().await.unwrap();
//println!("phys plan {:?}", physical_plan);
}

#[tokio::test]
async fn test_simple_subquery_wrapper_filter_empty_source() {
if !Rewriter::sql_push_down_enabled() {
return;
}
init_logger();

let query_plan = convert_select_to_query_plan(
"SELECT (SELECT customer_gender FROM KibanaSampleDataEcommerce WHERE customer_gender = 'male' LIMIT 1) as gender, avgPrice FROM KibanaSampleDataEcommerce a"
"SELECT avgPrice FROM KibanaSampleDataEcommerce a where customer_gender = (SELECT 'male' )"
.to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;

let logical_plan = query_plan.as_logical_plan();
let sql = logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql;
assert!(sql.contains("(SELECT"));
assert!(sql.contains("utf8__male__"));

let _physical_plan = query_plan.as_physical_plan().await.unwrap();
//println!("phys plan {:?}", physical_plan);
}

#[tokio::test]
async fn test_simple_subquery_wrapper_projection_aggregate_empty_source() {
if !Rewriter::sql_push_down_enabled() {
return;
}
init_logger();

let query_plan = convert_select_to_query_plan(
"SELECT (SELECT 'male'), avg(avgPrice) FROM KibanaSampleDataEcommerce a GROUP BY 1"
.to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;

let logical_plan = query_plan.as_logical_plan();
let sql = logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql;
assert!(sql.contains("(SELECT"));
assert!(sql.contains("utf8__male__"));

let _physical_plan = query_plan.as_physical_plan().await.unwrap();
}

#[tokio::test]
async fn test_simple_subquery_wrapper_filter_in_empty_source() {
if !Rewriter::sql_push_down_enabled() {
return;
}
init_logger();

let query_plan = convert_select_to_query_plan(
"SELECT customer_gender, avgPrice FROM KibanaSampleDataEcommerce a where customer_gender in (select 'male')"
.to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;

let logical_plan = query_plan.as_logical_plan();
let sql = logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql;
assert!(sql.contains("IN (SELECT"));
assert!(sql.contains("utf8__male__"));

let _physical_plan = query_plan.as_physical_plan().await.unwrap();
}

#[tokio::test]
async fn test_simple_subquery_wrapper_filter_and_projection_empty_source() {
if !Rewriter::sql_push_down_enabled() {
return;
}
init_logger();

let query_plan = convert_select_to_query_plan(
"SELECT (select 'male'), avgPrice FROM KibanaSampleDataEcommerce a where customer_gender in (select 'female')"
.to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;

let logical_plan = query_plan.as_logical_plan();

let sql = logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql;
assert!(sql.contains("IN (SELECT"));
assert!(sql.contains("(SELECT"));
assert!(sql.contains("utf8__male__"));
assert!(sql.contains("utf8__female__"));

let _physical_plan = query_plan.as_physical_plan().await.unwrap();
}

#[tokio::test]
async fn test_simple_subquery_wrapper_projection_1() {
if !Rewriter::sql_push_down_enabled() {
return;
}
init_logger();

let query_plan = convert_select_to_query_plan(
"SELECT (SELECT customer_gender FROM KibanaSampleDataEcommerce LIMIT 1) as gender, avgPrice FROM KibanaSampleDataEcommerce a"
.to_string(),
DatabaseProtocol::PostgreSQL,
)
Expand Down Expand Up @@ -20166,12 +20305,6 @@ ORDER BY "source"."str0" ASC
.unwrap()
.sql
.contains("(SELECT"));
assert!(logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql
.contains("LIMIT 1"));

let _physical_plan = query_plan.as_physical_plan().await.unwrap();
}
Expand Down

0 comments on commit 86a58a5

Please sign in to comment.