diff --git a/rust/cubesql/cubesql/e2e/tests/postgres.rs b/rust/cubesql/cubesql/e2e/tests/postgres.rs
index 3697674360..b65183e8cb 100644
--- a/rust/cubesql/cubesql/e2e/tests/postgres.rs
+++ b/rust/cubesql/cubesql/e2e/tests/postgres.rs
@@ -924,6 +924,92 @@ impl PostgresIntegrationTestSuite {
         Ok(())
     }
 
+    async fn test_temp_tables(&self) -> RunResult<()> {
+        // Create temporary table in current session
+        self.test_simple_query(
+            r#"
+                CREATE TEMPORARY TABLE temp_table AS
+                SELECT 5 AS i, 'c' AS s
+                UNION ALL
+                SELECT 10 AS i, 'd' AS s
+            "#
+            .to_string(),
+            |messages| {
+                let SimpleQueryMessage::CommandComplete(rows) = &messages[0] else {
+                    panic!("Must be CommandComplete");
+                };
+
+                assert_eq!(*rows, 2);
+            },
+        )
+        .await?;
+
+        // Check that we can query it and we get the correct data
+        self.test_simple_query(
+            "SELECT i AS i, s AS s FROM temp_table GROUP BY 1, 2 ORDER BY i ASC".to_string(),
+            |messages| {
+                assert_eq!(messages.len(), 3);
+
+                let SimpleQueryMessage::Row(row) = &messages[0] else {
+                    panic!("Must be Row, 0");
+                };
+
+                assert_eq!(row.get(0), Some("5"));
+                assert_eq!(row.get(1), Some("c"));
+
+                let SimpleQueryMessage::Row(row) = &messages[1] else {
+                    panic!("Must be Row, 1");
+                };
+
+                assert_eq!(row.get(0), Some("10"));
+                assert_eq!(row.get(1), Some("d"));
+
+                let SimpleQueryMessage::CommandComplete(rows) = &messages[2] else {
+                    panic!("Must be CommandComplete, 2");
+                };
+
+                assert_eq!(*rows, 2);
+            },
+        )
+        .await?;
+
+        // Other sessions must have no access to temp tables
+        let new_client = Self::create_client(
+            format!(
+                "host=127.0.0.1 port={} dbname=meow user=test password=test",
+                self.port
+            )
+            .parse()
+            .unwrap(),
+        )
+        .await;
+
+        let result = new_client
+            .simple_query("SELECT i AS i, s AS s FROM temp_table GROUP BY 1, 2 ORDER BY i ASC")
+            .await;
+        assert!(result.is_err());
+
+        // Drop table, make sure we can't query it anymore
+        self.test_simple_query("DROP TABLE temp_table".to_string(), |messages| {
+            let SimpleQueryMessage::CommandComplete(rows) = &messages[0] else {
+                panic!("Must be CommandComplete");
+            };
+
+            assert_eq!(*rows, 0);
+        })
+        .await?;
+
+        let result = self
+            .test_simple_query(
+                "SELECT i AS i, s AS s FROM temp_table GROUP BY 1, 2 ORDER BY i ASC".to_string(),
+                |_| {},
+            )
+            .await;
+        assert!(result.is_err());
+
+        Ok(())
+    }
+
     fn assert_row(&self, message: &SimpleQueryMessage, expected_value: String) {
         if let SimpleQueryMessage::Row(row) = message {
             assert_eq!(row.get(0), Some(expected_value.as_str()));
@@ -973,6 +1059,7 @@ impl AsyncTestSuite for PostgresIntegrationTestSuite {
         self.test_df_panic_handle().await?;
         self.test_simple_query_discard_all().await?;
         self.test_database_change().await?;
+        self.test_temp_tables().await?;
 
         // PostgreSQL doesn't support unsigned integers in the protocol, it's a constraint only
         self.test_snapshot_execute_query(
diff --git a/rust/cubesql/cubesql/src/compile/engine/provider.rs b/rust/cubesql/cubesql/src/compile/engine/provider.rs
index 6c0974624b..23451cfea2 100644
--- a/rust/cubesql/cubesql/src/compile/engine/provider.rs
+++ b/rust/cubesql/cubesql/src/compile/engine/provider.rs
@@ -14,7 +14,10 @@ use datafusion::{
 
 use crate::{
     compile::MetaContext,
-    sql::{session::DatabaseProtocol, ColumnType, SessionManager, SessionState},
+    sql::{
+        session::DatabaseProtocol, temp_tables::TempTableProvider, ColumnType, SessionManager,
+        SessionState,
+    },
     transport::V1CubeMetaExt,
     CubeError,
 };
@@ -272,7 +275,9 @@ impl DatabaseProtocol {
         table_provider: Arc<dyn TableProvider>,
     ) -> Result<String, CubeError> {
         let any = table_provider.as_any();
-        Ok(if let Some(t) = any.downcast_ref::<CubeTableProvider>() {
+        Ok(if let Some(t) = any.downcast_ref::<TempTableProvider>() {
+            t.name().to_string()
+        } else if let Some(t) = any.downcast_ref::<CubeTableProvider>() {
             t.table_name().to_string()
         } else if let Some(_) = any.downcast_ref::<InfoSchemaColumnsProvider>() {
             "information_schema.columns".to_string()
@@ -399,23 +404,28 @@ impl DatabaseProtocol {
                 table.to_ascii_lowercase(),
             ),
             datafusion::catalog::TableReference::Bare { table } => {
-                if table.starts_with("pg_") {
-                    (
-                        context.session_state.database().unwrap_or("db".to_string()),
-                        "pg_catalog".to_string(),
-                        table.to_ascii_lowercase(),
-                    )
+                let table_lower = table.to_ascii_lowercase();
+                let schema = if context.session_state.temp_tables().has(&table_lower) {
+                    "pg_temp_3"
+                } else if table.starts_with("pg_") {
+                    "pg_catalog"
                 } else {
-                    (
-                        context.session_state.database().unwrap_or("db".to_string()),
-                        "public".to_string(),
-                        table.to_ascii_lowercase(),
-                    )
-                }
+                    "public"
+                };
+                (
+                    context.session_state.database().unwrap_or("db".to_string()),
+                    schema.to_string(),
+                    table_lower,
+                )
             }
         };
 
         match schema.as_str() {
+            "pg_temp_3" => {
+                if let Some(temp_table) = context.session_state.temp_tables().get(&table) {
+                    return Some(Arc::new(TempTableProvider::new(table, temp_table)));
+                }
+            }
             "public" => {
                 if let Some(cube) = context
                     .meta
diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs
index 8087e799a0..7f76f099c7 100644
--- a/rust/cubesql/cubesql/src/compile/mod.rs
+++ b/rust/cubesql/cubesql/src/compile/mod.rs
@@ -24,7 +24,7 @@ use datafusion::{
 use itertools::Itertools;
 use log::warn;
 use serde::Serialize;
-use sqlparser::ast::{self, escape_single_quote_string};
+use sqlparser::ast::{self, escape_single_quote_string, ObjectName};
 use std::{
     backtrace::Backtrace, collections::HashMap, env, fmt::Formatter, future::Future, pin::Pin,
     sync::Arc, time::SystemTime,
@@ -84,6 +84,7 @@ use crate::{
             ApproximateCountDistinctVisitor, CastReplacer, RedshiftDatePartReplacer,
             SensitiveDataSanitizer, ToTimestampReplacer, UdfWildcardArgReplacer,
         },
+        temp_tables::TempTableManager,
         types::{CommandCompletion, StatusFlags},
         ColumnFlags, ColumnType, Session, SessionManager, SessionState,
     },
@@ -279,6 +280,7 @@
             if schema_name.to_lowercase() == "information_schema"
                 || schema_name.to_lowercase() == "performance_schema"
                 || schema_name.to_lowercase() == "pg_catalog"
+                || schema_name.to_lowercase() == "pg_temp_3"
             {
                 return self
                     .create_df_logical_plan(stmt.clone(), qtrace, span_id.clone())
@@ -435,6 +437,12 @@
    ) -> CompilationResult<QueryPlan> {
        let plan = match (stmt, &self.state.protocol) {
            (ast::Statement::Query(q), _) => {
+                if let ast::SetExpr::Select(select) = &q.body {
+                    if let Some(into) = &select.into {
+                        return self.select_into_to_plan(into, q, qtrace, span_id).await;
+                    }
+                }
+
                self.select_to_plan(stmt, q, qtrace, span_id.clone()).await
            }
            (ast::Statement::SetTransaction { .. }, _) => Ok(QueryPlan::MetaTabular(
@@ -546,6 +554,34 @@
                    CommandCompletion::Discard(object_type.to_string()),
                ))
            }
+            (
+                ast::Statement::CreateTable {
+                    query: Some(query),
+                    name,
+                    columns,
+                    constraints,
+                    table_properties,
+                    with_options,
+                    temporary,
+                    ..
+                },
+                DatabaseProtocol::PostgreSQL,
+            ) if columns.is_empty()
+                && constraints.is_empty()
+                && table_properties.is_empty()
+                && with_options.is_empty()
+                && *temporary =>
+            {
+                let stmt = ast::Statement::Query(query.clone());
+                self.create_table_to_plan(name, &stmt, qtrace, span_id.clone())
+                    .await
+            }
+            (
+                ast::Statement::Drop {
+                    object_type, names, ..
+                },
+                DatabaseProtocol::PostgreSQL,
+            ) if object_type == &ast::ObjectType::Table => self.drop_table_to_plan(names).await,
            _ => Err(CompilationError::unsupported(format!(
                "Unsupported query type: {}",
                stmt.to_string()
@@ -971,7 +1007,9 @@ WHERE `TABLE_SCHEMA` = '{}'",
                )])],
            )),
        )),
-        QueryPlan::DataFusionSelect(flags, plan, context) => {
+        QueryPlan::DataFusionSelect(flags, plan, context)
+        | QueryPlan::CreateTempTable(flags, plan, context, _, _) => {
+            // EXPLAIN over CREATE TABLE AS shows the SELECT query plan
            let plan = Arc::new(plan);
            let schema = LogicalPlan::explain_schema();
            let schema = schema.to_dfschema_ref().map_err(|err| {
@@ -1220,6 +1258,91 @@
        }
    }
 
+    async fn create_table_to_plan(
+        &self,
+        name: &ast::ObjectName,
+        stmt: &ast::Statement,
+        qtrace: &mut Option<Qtrace>,
+        span_id: Option<Arc<SpanId>>,
+    ) -> Result<QueryPlan, CompilationError> {
+        let ast::Statement::Query(query) = stmt else {
+            return Err(CompilationError::internal(
+                "statement is unexpectedly not a Query".to_string(),
+            ));
+        };
+        let plan = self.select_to_plan(stmt, query, qtrace, span_id).await?;
+        let QueryPlan::DataFusionSelect(flags, plan, ctx) = plan else {
+            return Err(CompilationError::internal(
+                "unable to build DataFusion plan from Query".to_string(),
+            ));
+        };
+
+        let ObjectName(ident_parts) = name;
+        let Some(table_name) = ident_parts.last() else {
+            return Err(CompilationError::internal(
+                "table name contains no ident parts".to_string(),
+            ));
+        };
+        Ok(QueryPlan::CreateTempTable(
+            flags,
+            plan,
+            ctx,
+            table_name.value.to_string(),
+            self.state.temp_tables(),
+        ))
+    }
+
+    async fn select_into_to_plan(
+        &self,
+        into: &ast::SelectInto,
+        query: &Box<ast::Query>,
+        qtrace: &mut Option<Qtrace>,
+        span_id: Option<Arc<SpanId>>,
+    ) -> Result<QueryPlan, CompilationError> {
+        if !into.temporary || !into.table {
+            return Err(CompilationError::unsupported(
+                "only TEMPORARY TABLE is supported for SELECT INTO".to_string(),
+            ));
+        }
+
+        let mut new_query = query.clone();
+        if let ast::SetExpr::Select(ref mut select) = new_query.body {
+            select.into = None
+        } else {
+            return Err(CompilationError::internal(
+                "query is unexpectedly not SELECT".to_string(),
+            ));
+        }
+        let new_stmt = ast::Statement::Query(new_query);
+        self.create_table_to_plan(&into.name, &new_stmt, qtrace, span_id)
+            .await
+    }
+
+    async fn drop_table_to_plan(
+        &self,
+        names: &[ast::ObjectName],
+    ) -> Result<QueryPlan, CompilationError> {
+        if names.len() != 1 {
+            return Err(CompilationError::unsupported(
+                "DROP TABLE supports dropping only one table at a time".to_string(),
+            ));
+        }
+        let ObjectName(ident_parts) = names.first().unwrap();
+        let Some(table_name) = ident_parts.last() else {
+            return Err(CompilationError::internal(
+                "table name contains no ident parts".to_string(),
+            ));
+        };
+        let table_name_lower = table_name.value.to_ascii_lowercase();
+        let temp_tables = self.state.temp_tables();
+        tokio::task::spawn_blocking(move || temp_tables.remove(&table_name_lower))
+            .await
+            .map_err(|err| CompilationError::internal(err.to_string()))?
+            .map_err(|err| CompilationError::internal(err.to_string()))?;
+        let flags = StatusFlags::empty();
+        Ok(QueryPlan::MetaOk(flags, CommandCompletion::DropTable))
+    }
+
    fn create_execution_ctx(&self) -> DFSessionContext {
        let query_planner = Arc::new(CubeQueryPlanner::new(
            self.session_manager.server.transport.clone(),
@@ -1753,6 +1876,14 @@ pub enum QueryPlan {
    MetaTabular(StatusFlags, Box<dataframe::DataFrame>),
    // Query will be executed via Data Fusion
    DataFusionSelect(StatusFlags, LogicalPlan, DFSessionContext),
+    // Query will be executed via DataFusion and saved to session
+    CreateTempTable(
+        StatusFlags,
+        LogicalPlan,
+        DFSessionContext,
+        String,
+        Arc<TempTableManager>,
+    ),
 }
 
 impl fmt::Debug for QueryPlan {
@@ -1775,6 +1906,12 @@
                    flags
                ))
            },
+            QueryPlan::CreateTempTable(flags, _, _, name, _) => {
+                f.write_str(&format!(
+                    "CreateTempTable(StatusFlags: {:?}, LogicalPlan: hidden, DFSessionContext: hidden, Name: {:?}, SessionState: hidden",
+                    flags, name
+                ))
+            },
        }
    }
 }
@@ -1782,7 +1919,8 @@
 impl QueryPlan {
    pub fn as_logical_plan(&self) -> LogicalPlan {
        match self {
-            QueryPlan::DataFusionSelect(_, plan, _) => plan.clone(),
+            QueryPlan::DataFusionSelect(_, plan, _)
+            | QueryPlan::CreateTempTable(_, plan, _, _, _) => plan.clone(),
            QueryPlan::MetaOk(_, _) | QueryPlan::MetaTabular(_, _) => {
                panic!("This query doesnt have a plan, because it already has values for response")
            }
@@ -1791,10 +1929,13 @@
 
    pub async fn as_physical_plan(&self) -> Result<Arc<dyn ExecutionPlan>, CubeError> {
        match self {
-            QueryPlan::DataFusionSelect(_, plan, ctx) => DataFrame::new(ctx.state.clone(), plan)
-                .create_physical_plan()
-                .await
-                .map_err(|e| CubeError::user(e.to_string())),
+            QueryPlan::DataFusionSelect(_, plan, ctx)
+            | QueryPlan::CreateTempTable(_, plan, ctx, _, _) => {
+                DataFrame::new(ctx.state.clone(), plan)
+                    .create_physical_plan()
+                    .await
+                    .map_err(|e| CubeError::user(e.to_string()))
+            }
            QueryPlan::MetaOk(_, _) | QueryPlan::MetaTabular(_, _) => {
                panic!("This query doesnt have a plan, because it already has values for response")
            }
@@ -1803,7 +1944,8 @@
 
    pub fn print(&self, pretty: bool) -> Result<String, CubeError> {
        match self {
-            QueryPlan::DataFusionSelect(_, plan, _) => {
+            QueryPlan::DataFusionSelect(_, plan, _)
+            | QueryPlan::CreateTempTable(_, plan, _, _, _) => {
                if pretty {
                    Ok(plan.display_indent().to_string())
                } else {
@@ -5716,7 +5858,7 @@ from
                output.push(frame.print());
                output_flags = flags;
            }
-            QueryPlan::MetaOk(flags, _) => {
+            QueryPlan::MetaOk(flags, _) | QueryPlan::CreateTempTable(flags, _, _, _, _) => {
                output_flags = flags;
            }
        }
@@ -10223,12 +10365,7 @@ from
            get_test_session(DatabaseProtocol::PostgreSQL, meta).await,
        )
        .await;
-        match select_into_query {
-            Err(CompilationError::Unsupported(msg, _)) => {
-                assert_eq!(msg, "Unsupported query type: SELECT INTO")
-            }
-            _ => panic!("SELECT INTO should throw CompilationError::unsupported"),
-        }
+        assert!(select_into_query.is_ok());
    }
 
    // This tests asserts that our DF fork contains support for IS TRUE|FALSE
diff --git a/rust/cubesql/cubesql/src/sql/mod.rs b/rust/cubesql/cubesql/src/sql/mod.rs
index 147e64632e..d61d105d04 100644
--- a/rust/cubesql/cubesql/src/sql/mod.rs
+++ b/rust/cubesql/cubesql/src/sql/mod.rs
@@ -9,6 +9,7 @@ pub(crate) mod service;
 pub(crate) mod session;
 pub(crate) mod session_manager;
 pub(crate) mod statement;
+pub(crate) mod temp_tables;
 pub(crate) mod types;
 
 pub use auth_service::{
diff --git a/rust/cubesql/cubesql/src/sql/mysql/service.rs b/rust/cubesql/cubesql/src/sql/mysql/service.rs
index dbc893b0d7..7fa82e67ab 100644
--- a/rust/cubesql/cubesql/src/sql/mysql/service.rs
+++ b/rust/cubesql/cubesql/src/sql/mysql/service.rs
@@ -224,6 +224,7 @@ impl MySqlConnection {
 
                return Ok(QueryResponse::ResultSet(status, Box::new(response)))
            }
+            crate::compile::QueryPlan::CreateTempTable(_, _, _, _, _) => return Err(CubeError::internal("CREATE TABLE is not supported over MySQL".to_string())),
        }
    }
 
diff --git a/rust/cubesql/cubesql/src/sql/postgres/extended.rs b/rust/cubesql/cubesql/src/sql/postgres/extended.rs
index 8b1536ebc2..015646897b 100644
--- a/rust/cubesql/cubesql/src/sql/postgres/extended.rs
+++ b/rust/cubesql/cubesql/src/sql/postgres/extended.rs
@@ -3,6 +3,7 @@ use crate::{
     sql::{
         dataframe::{batch_to_dataframe, DataFrame, TableValue},
         statement::PostgresStatementParamsBinder,
+        temp_tables::TempTable,
         writer::BatchWriter,
     },
     CubeError,
@@ -19,7 +20,7 @@ use datafusion::{
     physical_plan::SendableRecordBatchStream,
 };
 use futures::*;
-use pg_srv::protocol::{PortalCompletion, PortalSuspended};
+use pg_srv::protocol::{CommandComplete, PortalCompletion, PortalSuspended};
 
 use crate::transport::SpanId;
 use async_stream::stream;
@@ -510,6 +511,30 @@ impl Portal {
                            Err(err) => return yield Err(CubeError::panic(err).into()),
                        }
                    }
+                    QueryPlan::CreateTempTable(_, plan, ctx, name, temp_tables) => {
+                        let df = DFDataFrame::new(ctx.state.clone(), &plan);
+                        let record_batch = df.collect();
+                        let row_count = match record_batch.await {
+                            Ok(record_batch) => {
+                                let row_count: u32 = record_batch.iter().map(|batch| batch.num_rows() as u32).sum();
+                                let temp_table = TempTable::new(Arc::clone(plan.schema()), vec![record_batch]);
+                                let save_result = tokio::task::spawn_blocking(move || {
+                                    temp_tables.save(&name.to_ascii_lowercase(), temp_table)
+                                }).await;
+                                if let Err(err) = save_result {
+                                    return yield Err(err.into())
+                                };
+                                row_count
+                            }
+                            Err(err) => return yield Err(CubeError::panic(Box::new(err)).into()),
+                        };
+
+                        self.state = Some(PortalState::Finished(FinishedState { description }));
+
+                        return yield Ok(PortalBatch::Completion(PortalCompletion::Complete(
+                            CommandComplete::Select(row_count),
+                        )));
+                    }
                }
            }
            PortalState::InExecutionFrame(frame_state) => {
diff --git a/rust/cubesql/cubesql/src/sql/postgres/shim.rs b/rust/cubesql/cubesql/src/sql/postgres/shim.rs
index 0e303e4143..671cce1da8 100644
--- a/rust/cubesql/cubesql/src/sql/postgres/shim.rs
+++ b/rust/cubesql/cubesql/src/sql/postgres/shim.rs
@@ -69,7 +69,7 @@ impl QueryPlanExt for QueryPlan {
        required_format: protocol::Format,
    ) -> Result<Option<protocol::RowDescription>, ConnectionError> {
        match &self {
-            QueryPlan::MetaOk(_, _) => Ok(None),
+            QueryPlan::MetaOk(_, _) | QueryPlan::CreateTempTable(_, _, _, _, _) => Ok(None),
            QueryPlan::MetaTabular(_, frame) => {
                let mut result = vec![];
 
diff --git a/rust/cubesql/cubesql/src/sql/session.rs b/rust/cubesql/cubesql/src/sql/session.rs
index 56f03d078b..bd26c21e25 100644
--- a/rust/cubesql/cubesql/src/sql/session.rs
+++ b/rust/cubesql/cubesql/src/sql/session.rs
@@ -15,6 +15,7 @@ use crate::{
             DatabaseVariablesToUpdate,
         },
         extended::PreparedStatement,
+        temp_tables::TempTableManager,
     },
     transport::LoadRequestMeta,
     RWLockAsync,
@@ -91,6 +92,9 @@ pub struct SessionState {
     // session db variables
     variables: RwLockSync<Option<DatabaseVariables>>,
 
+    // session temporary tables
+    temp_tables: Arc<TempTableManager>,
+
     properties: RwLockSync<SessionProperties>,
 
     // @todo Remove RWLock after split of Connection & SQLWorker
@@ -124,6 +128,7 @@ impl SessionState {
            client_port,
            protocol,
            variables: RwLockSync::new(None),
+            temp_tables: Arc::new(TempTableManager::new()),
            properties: RwLockSync::new(SessionProperties::new(None, None)),
            auth_context: RwLockSync::new((auth_context, SystemTime::now())),
            transaction: RwLockSync::new(TransactionState::None),
@@ -369,6 +374,10 @@ impl SessionState {
        }
    }
 
+    pub fn temp_tables(&self) -> Arc<TempTableManager> {
+        Arc::clone(&self.temp_tables)
+    }
+
    pub fn get_load_request_meta(&self) -> LoadRequestMeta {
        let application_name = if let Some(var) = self.get_variable("application_name") {
            Some(var.value.to_string())
diff --git a/rust/cubesql/cubesql/src/sql/temp_tables.rs b/rust/cubesql/cubesql/src/sql/temp_tables.rs
new file mode 100644
index 0000000000..2e75134307
--- /dev/null
+++ b/rust/cubesql/cubesql/src/sql/temp_tables.rs
@@ -0,0 +1,144 @@
+use std::{any::Any, collections::HashMap, sync::Arc};
+
+use async_trait::async_trait;
+use datafusion::{
+    arrow::{
+        datatypes::{Schema, SchemaRef},
+        record_batch::RecordBatch,
+    },
+    datasource::TableProvider,
+    error::DataFusionError,
+    logical_plan::{DFSchema, DFSchemaRef, Expr},
+    physical_plan::{memory::MemoryExec, ExecutionPlan},
+};
+
+use crate::{CubeError, RWLockSync};
+
+#[derive(Debug)]
+pub struct TempTableManager {
+    temp_tables: RWLockSync<HashMap<String, Arc<TempTable>>>,
+}
+
+impl TempTableManager {
+    pub fn new() -> Self {
+        Self {
+            temp_tables: RWLockSync::new(HashMap::new()),
+        }
+    }
+
+    pub fn get(&self, name: &str) -> Option<Arc<TempTable>> {
+        self.temp_tables
+            .read()
+            .expect("failed to unlock temp tables for reading")
+            .get(name)
+            .cloned()
+    }
+
+    pub fn has(&self, name: &str) -> bool {
+        self.temp_tables
+            .read()
+            .expect("failed to unlock temp tables for reading")
+            .contains_key(name)
+    }
+
+    pub fn save(&self, name: &str, temp_table: TempTable) -> Result<(), CubeError> {
+        let mut guard = self
+            .temp_tables
+            .write()
+            .expect("failed to unlock temp tables for writing");
+
+        if guard.contains_key(name) {
+            return Err(CubeError::user(format!(
+                "relation \"{}\" already exists",
+                name
+            )));
+        }
+
+        guard.insert(name.to_string(), Arc::new(temp_table));
+        Ok(())
+    }
+
+    pub fn remove(&self, name: &str) -> Result<(), CubeError> {
+        let mut guard = self
+            .temp_tables
+            .write()
+            .expect("failed to unlock temp tables for writing");
+
+        if guard.remove(name).is_none() {
+            return Err(CubeError::user(format!(
+                "table \"{}\" does not exist",
+                name
+            )));
+        }
+
+        Ok(())
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct TempTable {
+    schema: SchemaRef,
+    record_batch: Vec<Vec<RecordBatch>>,
+}
+
+impl TempTable {
+    pub fn new(schema: DFSchemaRef, record_batch: Vec<Vec<RecordBatch>>) -> Self {
+        let arrow_schema = df_schema_to_arrow_schema(&schema);
+        Self {
+            schema: arrow_schema,
+            record_batch,
+        }
+    }
+}
+
+fn df_schema_to_arrow_schema(df_schema: &DFSchema) -> SchemaRef {
+    let arrow_schema = Schema::new_with_metadata(
+        df_schema
+            .fields()
+            .iter()
+            .map(|f| f.field().clone())
+            .collect(),
+        df_schema.metadata().clone(),
+    );
+    Arc::new(arrow_schema)
+}
+
+#[derive(Debug, Clone)]
+pub struct TempTableProvider {
+    name: String,
+    temp_table: Arc<TempTable>,
+}
+
+impl TempTableProvider {
+    pub fn new(name: String, temp_table: Arc<TempTable>) -> Self {
+        Self { name, temp_table }
+    }
+
+    pub fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+#[async_trait]
+impl TableProvider for TempTableProvider {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.temp_table.schema)
+    }
+
+    async fn scan(
+        &self,
+        projection: &Option<Vec<usize>>,
+        _filters: &[Expr],
+        _limit: Option<usize>,
+    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
+        Ok(Arc::new(MemoryExec::try_new(
+            &self.temp_table.record_batch,
+            self.schema(),
+            projection.clone(),
+        )?))
+    }
+}
diff --git a/rust/cubesql/cubesql/src/sql/types.rs b/rust/cubesql/cubesql/src/sql/types.rs
index 274fe742f5..99fbf7b276 100644
--- a/rust/cubesql/cubesql/src/sql/types.rs
+++ b/rust/cubesql/cubesql/src/sql/types.rs
@@ -148,6 +148,7 @@ pub enum CommandCompletion {
     Deallocate,
     DeallocateAll,
     Discard(String),
+    DropTable,
 }
 
 impl CommandCompletion {
@@ -174,6 +175,7 @@ impl CommandCompletion {
            CommandCompletion::Discard(tp) => CommandComplete::Plain(format!("DISCARD {}", tp)),
            // ROWS COUNT
            CommandCompletion::Select(rows) => CommandComplete::Select(rows),
+            CommandCompletion::DropTable => CommandComplete::Plain("DROP TABLE".to_string()),
        }
    }
 }
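
Usage sketch (not part of the patch): the flow this change enables can be driven from any PostgreSQL client; the snippet below mirrors the e2e test above using tokio-postgres. The connection parameters (port 4432, dbname/user/password) are assumptions; substitute whatever your CubeSQL instance is configured with.

use tokio_postgres::{NoTls, SimpleQueryMessage};

#[tokio::main]
async fn main() -> Result<(), tokio_postgres::Error> {
    // Connection parameters are assumed; adjust to your CubeSQL setup.
    let (client, connection) = tokio_postgres::connect(
        "host=127.0.0.1 port=4432 dbname=db user=test password=test",
        NoTls,
    )
    .await?;
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            eprintln!("connection error: {}", e);
        }
    });

    // Planned as QueryPlan::CreateTempTable; the collected record batches are
    // stored in the session-scoped TempTableManager.
    client
        .simple_query("CREATE TEMPORARY TABLE my_temp AS SELECT 1 AS i, 'a' AS s")
        .await?;

    // Bare references to the table now resolve through the pg_temp_3 schema
    // and are served by TempTableProvider; other sessions cannot see it.
    for message in client.simple_query("SELECT i, s FROM my_temp").await? {
        if let SimpleQueryMessage::Row(row) = message {
            println!("i = {:?}, s = {:?}", row.get(0), row.get(1));
        }
    }

    // DROP TABLE removes the entry from the session's TempTableManager.
    client.simple_query("DROP TABLE my_temp").await?;

    Ok(())
}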