From e9189fbc8f4b2454cdcf882296927418a8aeeffe Mon Sep 17 00:00:00 2001 From: Anthony Alaribe Date: Sun, 5 Oct 2025 19:07:10 +0200 Subject: [PATCH 1/9] initial query hook support --- datafusion-postgres/src/handlers.rs | 22 +++++++++++++++++++++- datafusion-postgres/src/lib.rs | 2 +- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 24d5beb..c36fc97 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -23,6 +23,16 @@ use tokio::sync::Mutex; use arrow_pg::datatypes::df; use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type}; +#[async_trait] +pub trait QueryHook: Send + Sync { + async fn handle_query( + &self, + query: &str, + session_context: &SessionContext, + client: &dyn ClientInfo, + ) -> Option>>>; +} + #[derive(Debug, Clone, Copy, PartialEq)] pub enum TransactionState { None, @@ -44,7 +54,7 @@ pub struct HandlerFactory { impl HandlerFactory { pub fn new(session_context: Arc, auth_manager: Arc) -> Self { let session_service = - Arc::new(DfSessionService::new(session_context, auth_manager.clone())); + Arc::new(DfSessionService::new(session_context, auth_manager.clone(), None)); HandlerFactory { session_service } } } @@ -70,12 +80,14 @@ pub struct DfSessionService { timezone: Arc>, transaction_state: Arc>, auth_manager: Arc, + query_hook: Option>, } impl DfSessionService { pub fn new( session_context: Arc, auth_manager: Arc, + query_hook: Option>, ) -> DfSessionService { let parser = Arc::new(Parser { session_context: session_context.clone(), @@ -86,6 +98,7 @@ impl DfSessionService { timezone: Arc::new(Mutex::new("UTC".to_string())), transaction_state: Arc::new(Mutex::new(TransactionState::None)), auth_manager, + query_hook, } } @@ -374,6 +387,13 @@ impl SimpleQueryHandler for DfSessionService { } } + // Check query hook first + if let Some(hook) = &self.query_hook { + if let Some(result) = hook.handle_query(query, &self.session_context, client).await { + return result; + } + } + let df_result = self.session_context.sql(query).await; // Handle query execution errors and transaction state diff --git a/datafusion-postgres/src/lib.rs b/datafusion-postgres/src/lib.rs index ba43d00..bdf4903 100644 --- a/datafusion-postgres/src/lib.rs +++ b/datafusion-postgres/src/lib.rs @@ -20,7 +20,7 @@ use tokio_rustls::TlsAcceptor; use crate::auth::AuthManager; use handlers::HandlerFactory; -pub use handlers::{DfSessionService, Parser}; +pub use handlers::{DfSessionService, Parser, QueryHook}; /// re-exports pub use arrow_pg; From 0006fe17434ada0bc154eb804762255127b1664a Mon Sep 17 00:00:00 2001 From: Anthony Alaribe Date: Sun, 5 Oct 2025 19:27:37 +0200 Subject: [PATCH 2/9] introduce QueryHook --- datafusion-postgres/src/handlers.rs | 40 ++++++++++++++++++++++++----- datafusion-postgres/src/lib.rs | 2 +- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index c36fc97..5634d6c 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -23,11 +23,14 @@ use tokio::sync::Mutex; use arrow_pg::datatypes::df; use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type}; +/// Statement type represents a parsed SQL query with its logical plan +pub type Statement = (String, LogicalPlan); + #[async_trait] pub trait QueryHook: Send + Sync { async fn handle_query( &self, - query: &str, + statement: &Statement, session_context: &SessionContext, client: &dyn ClientInfo, ) -> Option>>>; @@ -387,11 +390,25 @@ impl SimpleQueryHandler for DfSessionService { } } - // Check query hook first + // Parse query into logical plan for hook if let Some(hook) = &self.query_hook { - if let Some(result) = hook.handle_query(query, &self.session_context, client).await { - return result; + // Create logical plan from query + let state = self.session_context.state(); + let logical_plan_result = state.create_logical_plan(query).await; + + if let Ok(logical_plan) = logical_plan_result { + // Optimize the logical plan + let optimized_result = state.optimize(&logical_plan); + + if let Ok(optimized) = optimized_result { + // Create Statement tuple and call hook + let statement = (query.to_string(), optimized); + if let Some(result) = hook.handle_query(&statement, &self.session_context, client).await { + return result; + } + } } + // If parsing or optimization fails, we'll continue with normal processing } let df_result = self.session_context.sql(query).await; @@ -443,7 +460,7 @@ impl SimpleQueryHandler for DfSessionService { #[async_trait] impl ExtendedQueryHandler for DfSessionService { - type Statement = (String, LogicalPlan); + type Statement = Statement; type QueryParser = Parser; fn query_parser(&self) -> Arc { @@ -513,6 +530,17 @@ impl ExtendedQueryHandler for DfSessionService { .to_string(); log::debug!("Received execute extended query: {}", query); // Log for debugging + // Check query hook first + if let Some(hook) = &self.query_hook { + if let Some(result) = hook.handle_query(&portal.statement.statement, &self.session_context, client).await { + // Convert Vec to single Response + // For extended query, we expect a single response + if let Some(response) = result?.into_iter().next() { + return Ok(response); + } + } + } + // Check permissions for the query (skip for SET and SHOW statements) if !query.starts_with("set") && !query.starts_with("show") { self.check_query_permission(client, &portal.statement.statement.0) @@ -553,7 +581,7 @@ pub struct Parser { #[async_trait] impl QueryParser for Parser { - type Statement = (String, LogicalPlan); + type Statement = Statement; async fn parse_sql( &self, diff --git a/datafusion-postgres/src/lib.rs b/datafusion-postgres/src/lib.rs index bdf4903..57864e3 100644 --- a/datafusion-postgres/src/lib.rs +++ b/datafusion-postgres/src/lib.rs @@ -20,7 +20,7 @@ use tokio_rustls::TlsAcceptor; use crate::auth::AuthManager; use handlers::HandlerFactory; -pub use handlers::{DfSessionService, Parser, QueryHook}; +pub use handlers::{DfSessionService, Parser, QueryHook, Statement}; /// re-exports pub use arrow_pg; From 81d0f42406d5e1a4b741f7f622be9405b56a476a Mon Sep 17 00:00:00 2001 From: Anthony Alaribe Date: Sun, 5 Oct 2025 20:00:15 +0200 Subject: [PATCH 3/9] switch to accepting datafusion statements --- datafusion-postgres/src/handlers.rs | 66 +++++++++++++++-------------- datafusion-postgres/src/lib.rs | 2 +- 2 files changed, 36 insertions(+), 32 deletions(-) diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index ec73224..3e62525 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -30,9 +30,6 @@ use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type}; use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType}; use datafusion_pg_catalog::sql::PostgresCompatibilityParser; -/// Statement type represents a parsed SQL query with its logical plan -pub type Statement = (String, Option); - #[async_trait] pub trait QueryHook: Send + Sync { async fn handle_query( @@ -40,7 +37,7 @@ pub trait QueryHook: Send + Sync { statement: &Statement, session_context: &SessionContext, client: &dyn ClientInfo, - ) -> Option>>>; + ) -> Option>>; } #[derive(Debug, Clone, Copy, PartialEq)] @@ -66,8 +63,11 @@ pub struct HandlerFactory { impl HandlerFactory { pub fn new(session_context: Arc, auth_manager: Arc) -> Self { - let session_service = - Arc::new(DfSessionService::new(session_context, auth_manager.clone(), None)); + let session_service = Arc::new(DfSessionService::new( + session_context, + auth_manager.clone(), + None, + )); HandlerFactory { session_service } } } @@ -491,26 +491,16 @@ impl SimpleQueryHandler for DfSessionService { self.check_query_permission(client, &query).await?; } - // Parse query into logical plan for hook - if let Some(hook) = &self.query_hook { - // Create logical plan from query - let state = self.session_context.state(); - let logical_plan_result = state.create_logical_plan(query).await; - - if let Ok(logical_plan) = logical_plan_result { - // Optimize the logical plan - let optimized_result = state.optimize(&logical_plan); - - if let Ok(optimized) = optimized_result { - // Create Statement tuple and call hook - let statement = (query.to_string(), optimized); - if let Some(result) = hook.handle_query(&statement, &self.session_context, client).await { - return result; - } + // Call query hook with the parsed statement + if let Some(hook) = &self.query_hook { + let wrapped_statement = Statement::Statement(Box::new(statement.clone())); + if let Some(result) = hook + .handle_query(&wrapped_statement, &self.session_context, client) + .await + { + return result; } } - // If parsing or optimization fails, we'll continue with normal processing - } if let Some(resp) = self .try_respond_set_statements(client, &query_lower) @@ -578,7 +568,7 @@ impl SimpleQueryHandler for DfSessionService { #[async_trait] impl ExtendedQueryHandler for DfSessionService { - type Statement = Statement; + type Statement = (String, Option); type QueryParser = Parser; fn query_parser(&self) -> Arc { @@ -656,11 +646,25 @@ impl ExtendedQueryHandler for DfSessionService { // Check query hook first if let Some(hook) = &self.query_hook { - if let Some(result) = hook.handle_query(&portal.statement.statement, &self.session_context, client).await { - // Convert Vec to single Response - // For extended query, we expect a single response - if let Some(response) = result?.into_iter().next() { - return Ok(response); + // Parse the SQL to get the Statement for the hook + let sql = &portal.statement.statement.0; + let statements = self + .parser + .sql_parser + .parse(sql) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + + if let Some(statement) = statements.into_iter().next() { + let wrapped_statement = Statement::Statement(Box::new(statement)); + if let Some(result) = hook + .handle_query(&wrapped_statement, &self.session_context, client) + .await + { + // Convert Vec to single Response + // For extended query, we expect a single response + if let Some(response) = result?.into_iter().next() { + return Ok(response); + } } } } @@ -837,7 +841,7 @@ impl Parser { #[async_trait] impl QueryParser for Parser { - type Statement = Statement; + type Statement = (String, Option); async fn parse_sql( &self, diff --git a/datafusion-postgres/src/lib.rs b/datafusion-postgres/src/lib.rs index fd9d7ea..9cb568a 100644 --- a/datafusion-postgres/src/lib.rs +++ b/datafusion-postgres/src/lib.rs @@ -20,7 +20,7 @@ use tokio_rustls::TlsAcceptor; use crate::auth::AuthManager; use handlers::HandlerFactory; -pub use handlers::{DfSessionService, Parser, QueryHook, Statement}; +pub use handlers::{DfSessionService, Parser, QueryHook}; /// re-exports pub use arrow_pg; From a373572cd77e331e74134ac191d88e57a5e3850c Mon Sep 17 00:00:00 2001 From: Anthony Alaribe Date: Sun, 5 Oct 2025 20:01:24 +0200 Subject: [PATCH 4/9] remove transaction state --- datafusion-postgres/src/handlers.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 3e62525..a642154 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -40,13 +40,6 @@ pub trait QueryHook: Send + Sync { ) -> Option>>; } -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum TransactionState { - None, - Active, - Failed, -} - // Metadata keys for session-level settings const METADATA_STATEMENT_TIMEOUT: &str = "statement_timeout_ms"; @@ -653,7 +646,7 @@ impl ExtendedQueryHandler for DfSessionService { .sql_parser .parse(sql) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - + if let Some(statement) = statements.into_iter().next() { let wrapped_statement = Statement::Statement(Box::new(statement)); if let Some(result) = hook From c1747476b93e3936bd8648ffbdf633fa793ccea0 Mon Sep 17 00:00:00 2001 From: Anthony Alaribe Date: Mon, 6 Oct 2025 00:03:45 +0200 Subject: [PATCH 5/9] accept a vector of hooks instead of options of hooks --- datafusion-postgres/src/handlers.rs | 26 ++++++++++++++----------- datafusion-postgres/src/lib.rs | 2 +- datafusion-postgres/tests/common/mod.rs | 2 +- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index a642154..207770e 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -55,11 +55,15 @@ pub struct HandlerFactory { } impl HandlerFactory { - pub fn new(session_context: Arc, auth_manager: Arc) -> Self { + pub fn new( + session_context: Arc, + auth_manager: Arc, + query_hooks: Vec>, + ) -> Self { let session_service = Arc::new(DfSessionService::new( session_context, auth_manager.clone(), - None, + query_hooks, )); HandlerFactory { session_service } } @@ -100,14 +104,14 @@ pub struct DfSessionService { parser: Arc, timezone: Arc>, auth_manager: Arc, - query_hook: Option>, + query_hooks: Vec>, } impl DfSessionService { pub fn new( session_context: Arc, auth_manager: Arc, - query_hook: Option>, + query_hooks: Vec>, ) -> DfSessionService { let parser = Arc::new(Parser { session_context: session_context.clone(), @@ -118,7 +122,7 @@ impl DfSessionService { parser, timezone: Arc::new(Mutex::new("UTC".to_string())), auth_manager, - query_hook, + query_hooks, } } @@ -484,8 +488,8 @@ impl SimpleQueryHandler for DfSessionService { self.check_query_permission(client, &query).await?; } - // Call query hook with the parsed statement - if let Some(hook) = &self.query_hook { + // Call query hooks with the parsed statement + for hook in &self.query_hooks { let wrapped_statement = Statement::Statement(Box::new(statement.clone())); if let Some(result) = hook .handle_query(&wrapped_statement, &self.session_context, client) @@ -637,8 +641,8 @@ impl ExtendedQueryHandler for DfSessionService { .to_string(); log::debug!("Received execute extended query: {query}"); // Log for debugging - // Check query hook first - if let Some(hook) = &self.query_hook { + // Check query hooks first + for hook in &self.query_hooks { // Parse the SQL to get the Statement for the hook let sql = &portal.statement.statement.0; let statements = self @@ -961,7 +965,7 @@ mod tests { async fn test_statement_timeout_set_and_show() { let session_context = Arc::new(SessionContext::new()); let auth_manager = Arc::new(AuthManager::new()); - let service = DfSessionService::new(session_context, auth_manager); + let service = DfSessionService::new(session_context, auth_manager, vec![]); let mut client = MockClient::new(); // Test setting timeout to 5000ms @@ -987,7 +991,7 @@ mod tests { async fn test_statement_timeout_disable() { let session_context = Arc::new(SessionContext::new()); let auth_manager = Arc::new(AuthManager::new()); - let service = DfSessionService::new(session_context, auth_manager); + let service = DfSessionService::new(session_context, auth_manager, vec![]); let mut client = MockClient::new(); // Set timeout first diff --git a/datafusion-postgres/src/lib.rs b/datafusion-postgres/src/lib.rs index 9cb568a..cdfe6dd 100644 --- a/datafusion-postgres/src/lib.rs +++ b/datafusion-postgres/src/lib.rs @@ -85,7 +85,7 @@ pub async fn serve( auth_manager: Arc, ) -> Result<(), std::io::Error> { // Create the handler factory with authentication - let factory = Arc::new(HandlerFactory::new(session_context, auth_manager)); + let factory = Arc::new(HandlerFactory::new(session_context, auth_manager, vec![])); serve_with_handlers(factory, opts).await } diff --git a/datafusion-postgres/tests/common/mod.rs b/datafusion-postgres/tests/common/mod.rs index 054b38d..46512bd 100644 --- a/datafusion-postgres/tests/common/mod.rs +++ b/datafusion-postgres/tests/common/mod.rs @@ -20,7 +20,7 @@ pub fn setup_handlers() -> DfSessionService { ) .expect("Failed to setup sesession context"); - DfSessionService::new(Arc::new(session_context), Arc::new(AuthManager::new())) + DfSessionService::new(Arc::new(session_context), Arc::new(AuthManager::new()), vec![]) } #[derive(Debug, Default)] From b3e3360073211267e1342eae64cc4feacd3c8e12 Mon Sep 17 00:00:00 2001 From: Anthony Alaribe Date: Mon, 6 Oct 2025 00:07:59 +0200 Subject: [PATCH 6/9] add a small test --- datafusion-postgres/src/handlers.rs | 42 +++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 207770e..d834172 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -1009,4 +1009,46 @@ mod tests { let timeout = DfSessionService::get_statement_timeout(&client); assert_eq!(timeout, None); } + + struct TestHook; + + #[async_trait] + impl QueryHook for TestHook { + async fn handle_query( + &self, + statement: &Statement, + _ctx: &SessionContext, + _client: &dyn ClientInfo, + ) -> Option>> { + if statement.to_string().contains("magic") { + Some(Ok(vec![Response::EmptyQuery])) + } else { + None + } + } + } + + #[tokio::test] + async fn test_query_hooks() { + let hook = TestHook; + let ctx = SessionContext::new(); + let client = MockClient::new(); + + // Parse a statement that contains "magic" + let parser = PostgresCompatibilityParser::new(); + let statements = parser.parse("SELECT magic").unwrap(); + let stmt = Statement::Statement(Box::new(statements[0].clone())); + + // Hook should intercept + let result = hook.handle_query(&stmt, &ctx, &client).await; + assert!(result.is_some()); + + // Parse a normal statement + let statements = parser.parse("SELECT 1").unwrap(); + let stmt = Statement::Statement(Box::new(statements[0].clone())); + + // Hook should not intercept + let result = hook.handle_query(&stmt, &ctx, &client).await; + assert!(result.is_none()); + } } From 3bd7513a364be0468a33fb1fe17eab91ba7e0aa4 Mon Sep 17 00:00:00 2001 From: Anthony Alaribe Date: Mon, 6 Oct 2025 00:10:25 +0200 Subject: [PATCH 7/9] run cargo fmt --- datafusion-postgres/src/handlers.rs | 8 ++++---- datafusion-postgres/tests/common/mod.rs | 6 +++++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index d834172..5ac4974 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -1033,20 +1033,20 @@ mod tests { let hook = TestHook; let ctx = SessionContext::new(); let client = MockClient::new(); - + // Parse a statement that contains "magic" let parser = PostgresCompatibilityParser::new(); let statements = parser.parse("SELECT magic").unwrap(); let stmt = Statement::Statement(Box::new(statements[0].clone())); - + // Hook should intercept let result = hook.handle_query(&stmt, &ctx, &client).await; assert!(result.is_some()); - + // Parse a normal statement let statements = parser.parse("SELECT 1").unwrap(); let stmt = Statement::Statement(Box::new(statements[0].clone())); - + // Hook should not intercept let result = hook.handle_query(&stmt, &ctx, &client).await; assert!(result.is_none()); diff --git a/datafusion-postgres/tests/common/mod.rs b/datafusion-postgres/tests/common/mod.rs index 46512bd..6c646ff 100644 --- a/datafusion-postgres/tests/common/mod.rs +++ b/datafusion-postgres/tests/common/mod.rs @@ -20,7 +20,11 @@ pub fn setup_handlers() -> DfSessionService { ) .expect("Failed to setup sesession context"); - DfSessionService::new(Arc::new(session_context), Arc::new(AuthManager::new()), vec![]) + DfSessionService::new( + Arc::new(session_context), + Arc::new(AuthManager::new()), + vec![], + ) } #[derive(Debug, Default)] From 822d6240a766884bc32dbd84decf90a805f1ea9d Mon Sep 17 00:00:00 2001 From: Anthony Alaribe Date: Tue, 7 Oct 2025 20:14:05 +0200 Subject: [PATCH 8/9] update trait to accept sqlparser::ast::Statement instead, so we don't need to wrap in a datafusion::Statement --- datafusion-postgres/src/handlers.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 5ac4974..68ac76b 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -27,6 +27,7 @@ use tokio::sync::Mutex; use crate::auth::AuthManager; use arrow_pg::datatypes::df; use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type}; +use datafusion::sql::sqlparser; use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType}; use datafusion_pg_catalog::sql::PostgresCompatibilityParser; @@ -34,7 +35,7 @@ use datafusion_pg_catalog::sql::PostgresCompatibilityParser; pub trait QueryHook: Send + Sync { async fn handle_query( &self, - statement: &Statement, + statement: &sqlparser::ast::Statement, session_context: &SessionContext, client: &dyn ClientInfo, ) -> Option>>; @@ -490,9 +491,8 @@ impl SimpleQueryHandler for DfSessionService { // Call query hooks with the parsed statement for hook in &self.query_hooks { - let wrapped_statement = Statement::Statement(Box::new(statement.clone())); if let Some(result) = hook - .handle_query(&wrapped_statement, &self.session_context, client) + .handle_query(&statement, &self.session_context, client) .await { return result; @@ -652,9 +652,8 @@ impl ExtendedQueryHandler for DfSessionService { .map_err(|e| PgWireError::ApiError(Box::new(e)))?; if let Some(statement) = statements.into_iter().next() { - let wrapped_statement = Statement::Statement(Box::new(statement)); if let Some(result) = hook - .handle_query(&wrapped_statement, &self.session_context, client) + .handle_query(&statement, &self.session_context, client) .await { // Convert Vec to single Response From b366428ca0b509bd50d262d8d57e52e67d46e287 Mon Sep 17 00:00:00 2001 From: Anthony Alaribe Date: Tue, 7 Oct 2025 20:20:45 +0200 Subject: [PATCH 9/9] since handle_query only accepts a single statement, only a single response is expected. --- datafusion-postgres/src/handlers.rs | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 68ac76b..a3a8374 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -38,7 +38,7 @@ pub trait QueryHook: Send + Sync { statement: &sqlparser::ast::Statement, session_context: &SessionContext, client: &dyn ClientInfo, - ) -> Option>>; + ) -> Option>; } // Metadata keys for session-level settings @@ -495,7 +495,7 @@ impl SimpleQueryHandler for DfSessionService { .handle_query(&statement, &self.session_context, client) .await { - return result; + return result.map(|response| vec![response]); } } @@ -656,11 +656,7 @@ impl ExtendedQueryHandler for DfSessionService { .handle_query(&statement, &self.session_context, client) .await { - // Convert Vec to single Response - // For extended query, we expect a single response - if let Some(response) = result?.into_iter().next() { - return Ok(response); - } + return result; } } } @@ -1015,12 +1011,12 @@ mod tests { impl QueryHook for TestHook { async fn handle_query( &self, - statement: &Statement, + statement: &sqlparser::ast::Statement, _ctx: &SessionContext, _client: &dyn ClientInfo, - ) -> Option>> { + ) -> Option> { if statement.to_string().contains("magic") { - Some(Ok(vec![Response::EmptyQuery])) + Some(Ok(Response::EmptyQuery)) } else { None } @@ -1036,18 +1032,18 @@ mod tests { // Parse a statement that contains "magic" let parser = PostgresCompatibilityParser::new(); let statements = parser.parse("SELECT magic").unwrap(); - let stmt = Statement::Statement(Box::new(statements[0].clone())); + let stmt = &statements[0]; // Hook should intercept - let result = hook.handle_query(&stmt, &ctx, &client).await; + let result = hook.handle_query(stmt, &ctx, &client).await; assert!(result.is_some()); // Parse a normal statement let statements = parser.parse("SELECT 1").unwrap(); - let stmt = Statement::Statement(Box::new(statements[0].clone())); + let stmt = &statements[0]; // Hook should not intercept - let result = hook.handle_query(&stmt, &ctx, &client).await; + let result = hook.handle_query(stmt, &ctx, &client).await; assert!(result.is_none()); } }