Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 50 additions & 23 deletions src/catalog/pg_roles.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::session::db_handler::{DbHandler, DbResponse};
use super::where_evaluator::WhereEvaluator;
use crate::PgSqliteError;
use sqlparser::ast::{Select, SelectItem, Expr};
use tracing::debug;
use crate::session::db_handler::{DbHandler, DbResponse};
use sqlparser::ast::{Expr, Select, SelectItem};
use std::collections::HashMap;
use super::where_evaluator::WhereEvaluator;
use tracing::debug;

pub struct PgRolesHandler;

Expand Down Expand Up @@ -31,15 +31,15 @@ impl PgRolesHandler {
"rolconfig".to_string(),
];

// Determine which columns to return
// Determine which output columns to return and which source columns supply values
let selected_columns = Self::get_selected_columns(&select.projection, &all_columns);

// Build default roles (since SQLite doesn't have role management)
let roles = Self::get_default_roles();

// Apply WHERE clause filtering if present
let filtered_roles = if let Some(where_clause) = &select.selection {
Self::apply_where_filter(&roles, where_clause, &selected_columns)?
Self::apply_where_filter(&roles, where_clause)?
} else {
roles
};
Expand All @@ -48,54 +48,82 @@ impl PgRolesHandler {
let mut rows = Vec::new();
for role in filtered_roles {
let mut row = Vec::new();
for column in &selected_columns {
let value = role.get(column).cloned().unwrap_or_else(|| b"".to_vec());
for (_, source_column) in &selected_columns {
let value = role
.get(source_column)
.cloned()
.unwrap_or_else(|| b"".to_vec());
row.push(Some(value));
}
rows.push(row);
}

let rows_count = rows.len();
Ok(DbResponse {
columns: selected_columns,
columns: selected_columns
.into_iter()
.map(|(output_column, _)| output_column)
.collect(),
rows,
rows_affected: rows_count,
})
}

fn get_selected_columns(projection: &[SelectItem], all_columns: &[String]) -> Vec<String> {
fn get_selected_columns(
projection: &[SelectItem],
all_columns: &[String],
) -> Vec<(String, String)> {
let mut selected = Vec::new();

for item in projection {
match item {
SelectItem::Wildcard(_) => {
selected.extend_from_slice(all_columns);
selected.extend(
all_columns
.iter()
.map(|column| (column.clone(), column.clone())),
);
break;
}
SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
let col_name = ident.value.to_lowercase();
if all_columns.contains(&col_name) {
selected.push(col_name);
SelectItem::UnnamedExpr(expr) => {
if let Some(col_name) = Self::extract_source_column(expr) {
if all_columns.contains(&col_name) {
selected.push((col_name.clone(), col_name));
}
}
}
SelectItem::ExprWithAlias { expr: Expr::Identifier(ident), alias } => {
let col_name = ident.value.to_lowercase();
if all_columns.contains(&col_name) {
selected.push(alias.value.clone());
SelectItem::ExprWithAlias { expr, alias } => {
if let Some(col_name) = Self::extract_source_column(expr) {
if all_columns.contains(&col_name) {
selected.push((alias.value.clone(), col_name));
}
}
}
SelectItem::QualifiedWildcard(_, _) => {
// For qualified wildcard like pg_roles.*, return all columns
selected.extend_from_slice(all_columns);
selected.extend(
all_columns
.iter()
.map(|column| (column.clone(), column.clone())),
);
break;
}
_ => {}
}
}

selected
}

fn extract_source_column(expr: &Expr) -> Option<String> {
match expr {
Expr::Identifier(ident) => Some(ident.value.to_lowercase()),
Expr::CompoundIdentifier(parts) => parts.last().map(|ident| ident.value.to_lowercase()),
Expr::Cast { expr, .. } => Self::extract_source_column(expr),
Expr::Nested(expr) => Self::extract_source_column(expr),
_ => None,
}
}

fn get_default_roles() -> Vec<HashMap<String, Vec<u8>>> {
let mut roles = Vec::new();

Expand Down Expand Up @@ -156,7 +184,6 @@ impl PgRolesHandler {
fn apply_where_filter(
roles: &[HashMap<String, Vec<u8>>],
where_clause: &Expr,
_selected_columns: &[String],
) -> Result<Vec<HashMap<String, Vec<u8>>>, PgSqliteError> {
let mut filtered = Vec::new();

Expand All @@ -177,4 +204,4 @@ impl PgRolesHandler {

Ok(filtered)
}
}
}
64 changes: 46 additions & 18 deletions src/catalog/query_interceptor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1054,10 +1054,11 @@ impl CatalogInterceptor {
}
}

fn handle_pg_namespace_query(_select: &Select) -> DbResponse {
// Return basic namespaces
let columns = vec!["oid".to_string(), "nspname".to_string()];
let rows = vec![
fn handle_pg_namespace_query(select: &Select) -> DbResponse {
let all_columns = vec!["oid".to_string(), "nspname".to_string()];
let (columns, column_indices) = Self::extract_selected_columns(select, &all_columns);

let full_rows = vec![
vec![
Some("11".to_string().into_bytes()),
Some("pg_catalog".to_string().into_bytes()),
Expand All @@ -1067,6 +1068,15 @@ impl CatalogInterceptor {
Some("public".to_string().into_bytes()),
],
];
let rows: Vec<Vec<Option<Vec<u8>>>> = full_rows
.into_iter()
.map(|full_row| {
column_indices
.iter()
.map(|&idx| full_row[idx].clone())
.collect()
})
.collect();

let rows_affected = rows.len();
debug!("Returning {} rows for pg_type query with {} columns: {:?}", rows_affected, columns.len(), columns);
Expand Down Expand Up @@ -1713,10 +1723,13 @@ impl CatalogInterceptor {
})
}

/// Extract selected columns from a SELECT query for information_schema views
/// Extract selected output columns and source indices from a SELECT query.
fn extract_selected_columns(select: &Select, all_columns: &[String]) -> (Vec<String>, Vec<usize>) {
if select.projection.len() == 1
&& let SelectItem::Wildcard(_) = &select.projection[0] {
&& matches!(
&select.projection[0],
SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _)
) {
// SELECT * - return all columns
return (all_columns.to_vec(), (0..all_columns.len()).collect::<Vec<_>>());
}
Expand All @@ -1726,29 +1739,44 @@ impl CatalogInterceptor {
let mut indices = Vec::new();
for item in &select.projection {
match item {
SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
let col_name = ident.value.to_string();
if let Some(idx) = all_columns.iter().position(|c| c == &col_name) {
cols.push(col_name);
indices.push(idx);
SelectItem::UnnamedExpr(expr) => {
if let Some(col_name) = Self::extract_projection_source_column(expr) {
if let Some(idx) = all_columns.iter().position(|c| c == &col_name) {
cols.push(all_columns[idx].clone());
indices.push(idx);
}
}
}
SelectItem::UnnamedExpr(Expr::CompoundIdentifier(parts)) => {
// Handle compound identifiers like c.table_name
if let Some(last_part) = parts.last() {
let col_name = last_part.value.to_string();
SelectItem::ExprWithAlias { expr, alias } => {
if let Some(col_name) = Self::extract_projection_source_column(expr) {
if let Some(idx) = all_columns.iter().position(|c| c == &col_name) {
cols.push(col_name);
cols.push(alias.value.clone());
indices.push(idx);
}
}
}
_ => {}
SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _) => {
cols.extend_from_slice(all_columns);
indices.extend(0..all_columns.len());
break;
}
}
}
(cols, indices)
}

fn extract_projection_source_column(expr: &Expr) -> Option<String> {
match expr {
Expr::Identifier(ident) => Some(ident.value.to_lowercase()),
Expr::CompoundIdentifier(parts) => {
parts.last().map(|ident| ident.value.to_lowercase())
}
Expr::Cast { expr, .. } => Self::extract_projection_source_column(expr),
Expr::Nested(expr) => Self::extract_projection_source_column(expr),
_ => None,
}
}

async fn handle_information_schema_schemata_query(select: &Select, _db: &DbHandler) -> DbResponse {
debug!("Handling information_schema.schemata query");

Expand Down Expand Up @@ -3865,4 +3893,4 @@ impl CatalogInterceptor {

Ok(filtered)
}
}
}
56 changes: 56 additions & 0 deletions tests/catalog_alias_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
use pgsqlite::catalog::CatalogInterceptor;
use pgsqlite::session::db_handler::DbHandler;
use std::sync::Arc;

async fn catalog_query(query: &str) -> pgsqlite::session::db_handler::DbResponse {
let db = Arc::new(DbHandler::new(":memory:").unwrap());
CatalogInterceptor::intercept_query(query, db, None)
.await
.expect("query should be intercepted")
.expect("catalog query should succeed")
}

fn text_cell(row: &[Option<Vec<u8>>], index: usize) -> String {
String::from_utf8(row[index].clone().expect("cell should not be NULL")).unwrap()
}

#[tokio::test]
async fn test_pg_catalog_database_alias_projection() {
let response = catalog_query("SELECT oid AS did, datname FROM pg_catalog.pg_database").await;

assert_eq!(response.columns, vec!["did", "datname"]);
assert_eq!(response.rows.len(), 1);
assert_eq!(text_cell(&response.rows[0], 0), "1");
assert_eq!(text_cell(&response.rows[0], 1), "main");
}

#[tokio::test]
async fn test_pg_catalog_roles_alias_projection_uses_source_column() {
let response = catalog_query("SELECT rolname AS rolsuper FROM pg_catalog.pg_roles").await;

assert_eq!(response.columns, vec!["rolsuper"]);
assert_eq!(response.rows.len(), 3);
let role_names: Vec<String> = response.rows.iter().map(|row| text_cell(row, 0)).collect();
assert_eq!(role_names, vec!["postgres", "public", "pgsqlite_user"]);
}

#[tokio::test]
async fn test_pg_catalog_namespace_alias_projection() {
let response = catalog_query("SELECT oid AS did FROM pg_catalog.pg_namespace").await;

assert_eq!(response.columns, vec!["did"]);
assert_eq!(response.rows.len(), 2);
assert_eq!(text_cell(&response.rows[0], 0), "11");
assert_eq!(text_cell(&response.rows[1], 0), "2200");
}

#[tokio::test]
async fn test_information_schema_schemata_alias_projection() {
let response = catalog_query("SELECT catalog_name AS x FROM information_schema.schemata").await;

assert_eq!(response.columns, vec!["x"]);
assert_eq!(response.rows.len(), 3);
for row in &response.rows {
assert_eq!(text_cell(row, 0), "main");
}
}
Loading