Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions rust/cubesql/cubesql/e2e/tests/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,92 @@ impl PostgresIntegrationTestSuite {
Ok(())
}

async fn test_temp_tables(&self) -> RunResult<()> {
// Create temporary table in current session
self.test_simple_query(
r#"
CREATE TEMPORARY TABLE temp_table AS
SELECT 5 AS i, 'c' AS s
UNION ALL
SELECT 10 AS i, 'd' AS s
"#
.to_string(),
|messages| {
let SimpleQueryMessage::CommandComplete(rows) = &messages[0] else {
panic!("Must be CommandComplete");
};

assert_eq!(*rows, 2);
},
)
.await?;

// Check that we can query it and we get the correct data
self.test_simple_query(
"SELECT i AS i, s AS s FROM temp_table GROUP BY 1, 2 ORDER BY i ASC".to_string(),
|messages| {
assert_eq!(messages.len(), 3);

let SimpleQueryMessage::Row(row) = &messages[0] else {
panic!("Must be Row, 0");
};

assert_eq!(row.get(0), Some("5"));
assert_eq!(row.get(1), Some("c"));

let SimpleQueryMessage::Row(row) = &messages[1] else {
panic!("Must be Row, 1");
};

assert_eq!(row.get(0), Some("10"));
assert_eq!(row.get(1), Some("d"));

let SimpleQueryMessage::CommandComplete(rows) = &messages[2] else {
panic!("Must be CommandComplete, 2");
};

assert_eq!(*rows, 2);
},
)
.await?;

// Other sessions must have no access to temp tables
let new_client = Self::create_client(
format!(
"host=127.0.0.1 port={} dbname=meow user=test password=test",
self.port
)
.parse()
.unwrap(),
)
.await;

let result = new_client
.simple_query("SELECT i AS i, s AS s FROM temp_table GROUP BY 1, 2 ORDER BY i ASC")
.await;
assert!(result.is_err());

// Drop table, make sure we can't query it anymore
self.test_simple_query("DROP TABLE temp_table".to_string(), |messages| {
let SimpleQueryMessage::CommandComplete(rows) = &messages[0] else {
panic!("Must be CommandComplete");
};

assert_eq!(*rows, 0);
})
.await?;

let result = self
.test_simple_query(
"SELECT i AS i, s AS s FROM temp_table GROUP BY 1, 2 ORDER BY i ASC".to_string(),
|_| {},
)
.await;
assert!(result.is_err());

Ok(())
}

fn assert_row(&self, message: &SimpleQueryMessage, expected_value: String) {
if let SimpleQueryMessage::Row(row) = message {
assert_eq!(row.get(0), Some(expected_value.as_str()));
Expand Down Expand Up @@ -973,6 +1059,7 @@ impl AsyncTestSuite for PostgresIntegrationTestSuite {
self.test_df_panic_handle().await?;
self.test_simple_query_discard_all().await?;
self.test_database_change().await?;
self.test_temp_tables().await?;

// PostgreSQL doesn't support unsigned integers in the protocol, it's a constraint only
self.test_snapshot_execute_query(
Expand Down
38 changes: 24 additions & 14 deletions rust/cubesql/cubesql/src/compile/engine/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ use datafusion::{

use crate::{
compile::MetaContext,
sql::{session::DatabaseProtocol, ColumnType, SessionManager, SessionState},
sql::{
session::DatabaseProtocol, temp_tables::TempTableProvider, ColumnType, SessionManager,
SessionState,
},
transport::V1CubeMetaExt,
CubeError,
};
Expand Down Expand Up @@ -272,7 +275,9 @@ impl DatabaseProtocol {
table_provider: Arc<dyn datasource::TableProvider>,
) -> Result<String, CubeError> {
let any = table_provider.as_any();
Ok(if let Some(t) = any.downcast_ref::<CubeTableProvider>() {
Ok(if let Some(t) = any.downcast_ref::<TempTableProvider>() {
t.name().to_string()
} else if let Some(t) = any.downcast_ref::<CubeTableProvider>() {
t.table_name().to_string()
} else if let Some(_) = any.downcast_ref::<PostgresSchemaColumnsProvider>() {
"information_schema.columns".to_string()
Expand Down Expand Up @@ -399,23 +404,28 @@ impl DatabaseProtocol {
table.to_ascii_lowercase(),
),
datafusion::catalog::TableReference::Bare { table } => {
if table.starts_with("pg_") {
(
context.session_state.database().unwrap_or("db".to_string()),
"pg_catalog".to_string(),
table.to_ascii_lowercase(),
)
let table_lower = table.to_ascii_lowercase();
let schema = if context.session_state.temp_tables().has(&table_lower) {
"pg_temp_3"
} else if table.starts_with("pg_") {
"pg_catalog"
} else {
(
context.session_state.database().unwrap_or("db".to_string()),
"public".to_string(),
table.to_ascii_lowercase(),
)
}
"public"
};
(
context.session_state.database().unwrap_or("db".to_string()),
schema.to_string(),
table_lower,
)
}
};

match schema.as_str() {
"pg_temp_3" => {
if let Some(temp_table) = context.session_state.temp_tables().get(&table) {
return Some(Arc::new(TempTableProvider::new(table, temp_table)));
}
}
"public" => {
if let Some(cube) = context
.meta
Expand Down
Loading