Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 65 additions & 64 deletions datafusion-postgres/src/sql.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashSet;
use std::ops::ControlFlow;
use std::sync::Arc;

use datafusion::sql::sqlparser::ast::Expr;
Expand All @@ -13,6 +14,8 @@ use datafusion::sql::sqlparser::ast::Statement;
use datafusion::sql::sqlparser::ast::TableFactor;
use datafusion::sql::sqlparser::ast::TableWithJoins;
use datafusion::sql::sqlparser::ast::Value;
use datafusion::sql::sqlparser::ast::VisitMut;
use datafusion::sql::sqlparser::ast::VisitorMut;
use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
use datafusion::sql::sqlparser::parser::Parser;
use datafusion::sql::sqlparser::parser::ParserError;
Expand Down Expand Up @@ -272,8 +275,16 @@ impl RemoveUnsupportedTypes {

Self { unsupported_types }
}
}

struct RemoveUnsupportedTypesVisitor<'a> {
unsupported_types: &'a HashSet<String>,
}

impl<'a> VisitorMut for RemoveUnsupportedTypesVisitor<'a> {
type Break = ();

fn rewrite_expr_unsupported_types(&self, expr: &mut Expr) {
fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
match expr {
// This is the key part: identify constants with type annotations.
Expr::TypedString { value, data_type } => {
Expand All @@ -297,65 +308,58 @@ impl RemoveUnsupportedTypes {
*expr = *value.clone();
}
}
// Handle binary operations by recursively rewriting both sides.
Expr::BinaryOp { left, right, .. } => {
self.rewrite_expr_unsupported_types(left);
self.rewrite_expr_unsupported_types(right);
}
// Add more match arms for other expression types (e.g., `Function`, `InList`) as needed.
_ => {}
}

ControlFlow::Continue(())
}
}

impl SqlStatementRewriteRule for RemoveUnsupportedTypes {
fn rewrite(&self, mut s: Statement) -> Statement {
// Traverse the AST to find the WHERE clause and rewrite it.
if let Statement::Query(query) = &mut s {
if let SetExpr::Select(select) = query.body.as_mut() {
if let Some(expr) = &mut select.selection {
self.rewrite_expr_unsupported_types(expr);
}
}
}

s
fn rewrite(&self, mut statement: Statement) -> Statement {
let mut visitor = RemoveUnsupportedTypesVisitor {
unsupported_types: &self.unsupported_types,
};
let _ = statement.visit(&mut visitor);
statement
}
}

#[cfg(test)]
mod tests {
use super::*;

macro_rules! assert_rewrite {
($rules:expr, $orig:expr, $rewt:expr) => {
let sql = $orig;
let statement = parse(sql).expect("Failed to parse").remove(0);

let statement = rewrite(statement, $rules);
assert_eq!(statement.to_string(), $rewt);
};
}

#[test]
fn test_alias_rewrite() {
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
vec![Arc::new(AliasDuplicatedProjectionRewrite)];

let sql = "SELECT n.oid, n.* FROM pg_catalog.pg_namespace n";
let statement = parse(sql).expect("Failed to parse").remove(0);

let statement = rewrite(statement, &rules);
assert_eq!(
statement.to_string(),
assert_rewrite!(
&rules,
"SELECT n.oid, n.* FROM pg_catalog.pg_namespace n",
"SELECT n.oid AS __alias_oid, n.* FROM pg_catalog.pg_namespace AS n"
);

let sql = "SELECT oid, * FROM pg_catalog.pg_namespace";
let statement = parse(sql).expect("Failed to parse").remove(0);

let statement = rewrite(statement, &rules);
assert_eq!(
statement.to_string(),
assert_rewrite!(
&rules,
"SELECT oid, * FROM pg_catalog.pg_namespace",
"SELECT oid AS __alias_oid, * FROM pg_catalog.pg_namespace"
);

let sql = "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id";
let statement = parse(sql).expect("Failed to parse").remove(0);

let statement = rewrite(statement, &rules);
assert_eq!(
statement.to_string(),
assert_rewrite!(
&rules,
"SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id",
"SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id"
);
}
Expand All @@ -365,30 +369,21 @@ mod tests {
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
vec![Arc::new(ResolveUnqualifiedIdentifer)];

let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname";
let statement = parse(sql).expect("Failed to parse").remove(0);

let statement = rewrite(statement, &rules);
assert_eq!(
statement.to_string(),
assert_rewrite!(
&rules,
"SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname",
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
);

let sql = "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname";
let statement = parse(sql).expect("Failed to parse").remove(0);

let statement = rewrite(statement, &rules);
assert_eq!(
statement.to_string(),
assert_rewrite!(
&rules,
"SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname",
"SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname"
);

let sql = "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname";
let statement = parse(sql).expect("Failed to parse").remove(0);

let statement = rewrite(statement, &rules);
assert_eq!(
statement.to_string(),
assert_rewrite!(
&rules,
"SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname",
"SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname"
);
}
Expand All @@ -398,21 +393,27 @@ mod tests {
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
vec![Arc::new(RemoveUnsupportedTypes::new())];

let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname";
let statement = parse(sql).expect("Failed to parse").remove(0);

let statement = rewrite(statement, &rules);
assert_eq!(
statement.to_string(),
assert_rewrite!(
&rules,
"SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
);

let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname";
let statement = parse(sql).expect("Failed to parse").remove(0);
assert_rewrite!(
&rules,
"SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.oid = 1 AND n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.oid = 1 AND n.nspname = 'pg_catalog' ORDER BY n.nspname"
);

assert_rewrite!(
&rules,
"SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace'::regclass ORDER BY nspname",
"SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspname"
);

let statement = rewrite(statement, &rules);
assert_eq!(
statement.to_string(),
assert_rewrite!(
&rules,
"SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname",
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
);
}
Expand Down
Loading