From 04dce07c1ac9b904087338ad826dbf1939b7284e Mon Sep 17 00:00:00 2001 From: Eir Nym <485399+eirnym@users.noreply.github.com> Date: Sun, 27 Apr 2025 11:57:55 +0200 Subject: [PATCH 1/3] Add ability to execute queries --- src/args.rs | 16 +++- src/cell_to_string.rs | 31 +++---- src/error.rs | 63 ++++++++++++++ src/main.rs | 189 ++++++++++++++---------------------------- src/matching.rs | 78 +++++++++++++++++ src/pattern.rs | 25 ++++++ src/query.rs | 43 ++++++++++ src/select.rs | 38 +++++++++ 8 files changed, 333 insertions(+), 150 deletions(-) create mode 100644 src/error.rs create mode 100644 src/matching.rs create mode 100644 src/pattern.rs create mode 100644 src/query.rs diff --git a/src/args.rs b/src/args.rs index 166bebd..6709606 100644 --- a/src/args.rs +++ b/src/args.rs @@ -19,9 +19,19 @@ pub(crate) struct Args { pub(crate) verbose: u8, #[arg(short = 't', long = "table")] - #[arg(help = "Table or views to query. Can be used multiple times.")] - #[arg(action=ArgAction::Set)] - pub(crate) table: Option>, + #[arg(help = "Table or view to query. Can be used multiple times")] + #[arg(action=ArgAction::Append)] + pub(crate) table: Vec, + + #[arg(short = 's', long = "sql")] + #[arg(help = "SQL query to run. Can be used multiple times")] + #[arg(action=ArgAction::Append)] + pub(crate) query: Vec, + + #[arg(short = 'i', long = "ignore")] + #[arg(help = "Ignore non-readonly queries")] + #[arg(action=ArgAction::SetTrue)] + pub(crate) ignore_non_readonly: bool, #[arg(help = "Pattern to match every cell with")] pub(crate) pattern: String, diff --git a/src/cell_to_string.rs b/src/cell_to_string.rs index 111f196..ce2c2cb 100644 --- a/src/cell_to_string.rs +++ b/src/cell_to_string.rs @@ -1,5 +1,3 @@ -use std::fmt::Display; - use log::warn; use sqlx::sqlite::SqliteValueRef; use sqlx::Decode; @@ -8,14 +6,7 @@ use sqlx::Type; use sqlx::TypeInfo; use sqlx::ValueRef; -// REVIEW: better way convert an error to a string? -fn errr_format(value: impl Display) -> String { - format!("{value}") -} - pub(crate) fn sqlite_cell_to_string(value_ref: SqliteValueRef) -> Result, String> { - // TODO: add an option to override types in some extent - if value_ref.is_null() { return Ok(None); } @@ -24,41 +15,45 @@ pub(crate) fn sqlite_cell_to_string(value_ref: SqliteValueRef) -> Result>::compatible(&type_info) { - let value = >::decode(value_ref).map_err(errr_format)?; + let value = + >::decode(value_ref).map_err(|value| value.to_string())?; return Ok(Some(value)); } // // INTEGER, INT4 if >::compatible(&type_info) { - let value = >::decode(value_ref).map_err(errr_format)?; + let value = + >::decode(value_ref).map_err(|value| value.to_string())?; return Ok(Some(format!("{value}"))); } // REAL if >::compatible(&type_info) { - let value = >::decode(value_ref).map_err(errr_format)?; + let value = + >::decode(value_ref).map_err(|value| value.to_string())?; return Ok(Some(format!("{value}"))); } // BOOL? if >::compatible(&type_info) { - let value = >::decode(value_ref).map_err(errr_format)?; + let value = + >::decode(value_ref).map_err(|value| value.to_string())?; return Ok(Some(format!("{value}"))); } // DateTime if as Type>::compatible(&type_info) { let value = as Decode>::decode(value_ref) - .map_err(errr_format)?; + .map_err(|value| value.to_string())?; return Ok(Some(value.to_rfc3339())); } // Date if >::compatible(&type_info) { - let value = - >::decode(value_ref).map_err(errr_format)?; + let value = >::decode(value_ref) + .map_err(|value| value.to_string())?; return Ok(Some(value.format("%Y-%m-%d").to_string())); } // Time if >::compatible(&type_info) { - let value = - >::decode(value_ref).map_err(errr_format)?; + let value = >::decode(value_ref) + .map_err(|value| value.to_string())?; return Ok(Some(value.format("%H:%M:%S").to_string())); } diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..9609940 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,63 @@ +use log::Level; +use sqlparser::parser::ParserError; + +pub(crate) enum SQLError { + Regex(regex::Error), + QueryError(QueryError), + ParseError(ParserError), + SqlX((String, sqlx::Error)), + ConvertCell((String, String)), +} + +pub(crate) enum QueryError { + ReadOnlyQueryAllowed, +} + +impl SQLError { + /// Report and return error code if needed + pub fn report(&self, level: Level) -> i32 { + match self { + SQLError::Regex(error) => { + log::log!(level, "Regex error: {error}"); + + 64 + } + SQLError::QueryError(query_error) => { + match query_error { + QueryError::ReadOnlyQueryAllowed => { + log::log!(level, "Only readonly query is allowed"); + } + }; + + 65 + } + SQLError::ParseError(error) => { + log::log!(level, "Unable to parse SQL: {error}"); + + 66 + } + SQLError::SqlX((context, error)) => { + let context = if context.is_empty() { + "".to_owned() + } else { + format!(" ({context})") + }; + + log::log!(level, "Query execution error{context}: {error}"); + + 74 + } + SQLError::ConvertCell((context, error)) => { + let context = if context.is_empty() { + "".to_owned() + } else { + format!(" ({context})") + }; + + log::log!(level, "Query execution error{context}: {error}"); + + 73 + } + } + } +} diff --git a/src/main.rs b/src/main.rs index d3657f9..3363a27 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,19 @@ mod args; mod cell_to_string; +mod error; +mod matching; +mod pattern; +mod query; mod select; +use error::SQLError; +use log::Level; +use matching::sqlite_check_rows; +use pattern::{Pattern, PatternKind}; +use query::{prepare_queries, SelectVariant}; use sqlparser::dialect::SQLiteDialect; -use sqlx::sqlite::SqliteConnectOptions; -use sqlx::Error; -use sqlx::{Column, Executor, Pool, Row, Sqlite, SqlitePool}; +use sqlx::{sqlite::SqliteConnectOptions, Executor as _, Pool, Row as _, Sqlite, SqlitePool}; #[tokio::main()] async fn main() { @@ -23,33 +30,32 @@ async fn main() { .init() .unwrap(); - let pattern = match regex::Regex::new(&args.pattern) { + let pattern = match Pattern::new(args.pattern.as_str(), PatternKind::Regex) { Ok(pattern) => pattern, - Err(err) => { - log::error!("Unable to compile pattern: {}", err); - std::process::exit(64); - } - }; - - let select_variant = match args.table { - None => SelectVariant::WholeDB, - Some(table_names) => SelectVariant::SpecificTables(table_names), + Err(err) => std::process::exit(err.report(Level::Error)), }; - match process_sqlite_database(args.database_uri, &pattern, select_variant).await { + match process_sqlite_database( + args.database_uri, + pattern, + args.table, + args.query, + args.ignore_non_readonly, + ) + .await + { Ok(_) => {} - Err(err) => { - log::error!("Unable read tables from database: {}", err); - std::process::exit(74); - } + Err(err) => std::process::exit(err.report(Level::Error)), } } async fn process_sqlite_database( database_uri: String, - pattern: ®ex::Regex, - select_variant: SelectVariant, -) -> Result<(), Error> { + pattern: Pattern, + tables: Vec, + queries: Vec, + ignore_non_read: bool, +) -> Result<(), SQLError> { let dialect = SQLiteDialect {}; let options: SqliteConnectOptions = match database_uri.parse::() { @@ -68,130 +74,55 @@ async fn process_sqlite_database( } }; - match select_variant { - SelectVariant::WholeDB => match sqlite_select_tables(&db).await { - Ok(table_names) => { - let mut table_names = table_names; - sqlite_select_from_tables(&db, &mut table_names, pattern, &dialect).await + let select_variant = prepare_queries( + tables.into_iter(), + queries.into_iter(), + &dialect, + ignore_non_read, + )?; + + let queries = match select_variant { + SelectVariant::Queries(queries) => queries, + SelectVariant::WholeDB => { + let tables = sqlite_select_tables(&db).await?; + let select_variant = prepare_queries( + tables.into_iter(), + Vec::new().into_iter(), + &dialect, + ignore_non_read, + )?; + match select_variant { + SelectVariant::WholeDB => Vec::new(), + SelectVariant::Queries(queries) => queries, } - Err(err) => Err(err), - }, - SelectVariant::SpecificTables(table_names) => { - let mut table_names = table_names.into_iter(); - sqlite_select_from_tables(&db, &mut table_names, pattern, &dialect).await } - } -} + }; -async fn sqlite_select_from_tables( - db: &Pool, - table_names: &mut Iter, - pattern: ®ex::Regex, - dialect: &SQLiteDialect, -) -> Result<(), Error> -where - Iter: Iterator, -{ - for table_name in table_names { - let select_query = select::generate_select(table_name.as_str(), dialect); - sqlite_check_rows(&table_name, db, select_query.as_str(), pattern).await; + for (query_id, query) in queries { + sqlite_check_rows(&db, query_id.as_str(), query.as_str(), &pattern).await } Ok(()) } -async fn sqlite_select_tables(db: &Pool) -> Result, Error> { - let select_query = "SELECT name - FROM sqlite_schema - WHERE type ='table'"; +async fn sqlite_select_tables(db: &Pool) -> Result, SQLError> { + let select_query = "SELECT name FROM sqlite_schema WHERE type = 'table'"; log::debug!("Execute query: {select_query}"); - let result = db.fetch_all(select_query).await?; + + let result = db + .fetch_all(select_query) + .await + .map_err(|err| SQLError::SqlX(("fetch tables".into(), err)))?; Ok(result .into_iter() .filter_map(|row| match row.try_get::("name") { Ok(value) => Some(value), Err(err) => { - log::warn!("Error while reading from table `sqlite_schema`: {}", err); + SQLError::SqlX(("fetch tables".into(), err)).report(Level::Warn); None } - })) -} - -async fn sqlite_check_rows( - table_name: &String, - db: &Pool, - select_query: &str, - pattern: ®ex::Regex, -) { - use futures::TryStreamExt; - use std::sync::atomic::AtomicI64; - use std::sync::atomic::Ordering; - - log::debug!("Execute query: {select_query}"); - let mut rows = db.fetch(select_query); - - log::debug!("==> {table_name}"); - let row_idx: AtomicI64 = AtomicI64::new(-1); - loop { - row_idx.fetch_add(1, Ordering::SeqCst); - let idx = row_idx.load(Ordering::SeqCst); - - let row = match rows.try_next().await { - Ok(None) => break, - Ok(Some(row)) => row, - Err(err) => { - log::warn!( - "Error while reading row {idx} from table `{table_name}`: {}", - err - ); - continue; - } - }; - - sqlite_process_row(idx, row, table_name, pattern); - } -} - -fn sqlite_process_row( - row_idx: i64, - row: sqlx::sqlite::SqliteRow, - table_name: &String, - pattern: ®ex::Regex, -) { - use sqlx::TypeInfo; - let columns = row.columns(); - for column in columns { - let index = column.ordinal(); - let column_name = column.name().to_owned(); - - let value_ref = match row.try_get_raw(index) { - Ok(value_ref) => value_ref, - Err(err) => { - log::warn!("Error while reading row {row_idx} from table {table_name} column {column_name}: {}", err); - continue; - } - }; - - let value_str = match cell_to_string::sqlite_cell_to_string(value_ref) { - Ok(Some(value_str)) => value_str, - Ok(None) => continue, - Err(err) => { - let column_type = column.type_info().name(); - log::warn!("Error while converting data from row {row_idx} from table {table_name} column {column_name} of type {column_type}: {}", err); - continue; - } - }; - - if pattern.is_match(&value_str) { - println!("{table_name}::{row_idx}::{column_name} => {value_str:?}"); - } - } -} - -#[non_exhaustive] -enum SelectVariant { - WholeDB, - SpecificTables(Vec), + }) + .collect()) } diff --git a/src/matching.rs b/src/matching.rs new file mode 100644 index 0000000..41df1f8 --- /dev/null +++ b/src/matching.rs @@ -0,0 +1,78 @@ +use crate::cell_to_string::sqlite_cell_to_string; +use crate::{Pattern, SQLError}; +use log::Level; +use sqlx::{Column, Executor, Pool, Row, Sqlite}; + +pub async fn sqlite_check_rows( + db: &Pool, + query_id: &str, + select_query: &str, + pattern: &Pattern, +) { + use futures::TryStreamExt; + use std::sync::atomic::AtomicI64; + use std::sync::atomic::Ordering; + + log::debug!("{query_id}: {select_query}"); + + let mut rows = db.fetch(select_query); + + let row_idx: AtomicI64 = AtomicI64::new(-1); + loop { + row_idx.fetch_add(1, Ordering::SeqCst); + let idx = row_idx.load(Ordering::SeqCst); + + let row = match rows.try_next().await { + Ok(None) => break, + Ok(Some(row)) => row, + Err(err) => { + log::warn!( + "Error while reading row {idx} while executing query: {}", + err + ); + continue; + } + }; + + sqlite_process_row(idx, row, query_id, pattern); + } +} + +fn sqlite_process_row( + row_idx: i64, + row: sqlx::sqlite::SqliteRow, + query_id: &str, + pattern: &Pattern, +) { + use sqlx::TypeInfo; + let columns = row.columns(); + for column in columns { + let index = column.ordinal(); + let column_name = column.name().to_owned(); + let column_type = column.type_info().name(); + + let error_context = + format!("Reading row {row_idx} from table {query_id} column {column_name} of type {column_type}"); + + let value_ref = match row.try_get_raw(index) { + Ok(value_ref) => value_ref, + Err(error) => { + SQLError::SqlX((error_context, error)).report(Level::Warn); + continue; + } + }; + + let value_str = match sqlite_cell_to_string(value_ref) { + Ok(Some(value_str)) => value_str, + Ok(None) => continue, + Err(error) => { + SQLError::ConvertCell((error_context, error)).report(Level::Warn); + continue; + } + }; + + if pattern.is_match(&value_str) { + println!("{query_id}::{row_idx}::{column_name} => {value_str:?}"); + } + } +} diff --git a/src/pattern.rs b/src/pattern.rs new file mode 100644 index 0000000..cb8a3e9 --- /dev/null +++ b/src/pattern.rs @@ -0,0 +1,25 @@ +use crate::error::SQLError; + +pub(crate) enum PatternKind { + Regex, +} + +pub(crate) enum Pattern { + Regex(regex::Regex), +} + +impl Pattern { + pub fn new(pattern: &str, kind: PatternKind) -> Result { + match kind { + PatternKind::Regex => regex::Regex::new(pattern) + .map(Self::Regex) + .map_err(SQLError::Regex), + } + } + + pub fn is_match(&self, value: &str) -> bool { + match self { + Pattern::Regex(regex) => regex.is_match(value), + } + } +} diff --git a/src/query.rs b/src/query.rs new file mode 100644 index 0000000..0a6b0d1 --- /dev/null +++ b/src/query.rs @@ -0,0 +1,43 @@ +use sqlparser::dialect::Dialect; + +use crate::error::SQLError; +use crate::select::{escape_table_name, generate_select, read_verify_query}; + +#[non_exhaustive] +pub(crate) enum SelectVariant { + WholeDB, + Queries(Vec<(String, String)>), +} + +pub(crate) fn prepare_queries( + table: T, + queries: T, + dialect: &impl Dialect, + ignore_non_read: bool, +) -> Result +where + T: Iterator, +{ + let mut queries_result: Vec<(String, String)> = table + .map(|table_name| { + ( + format!("Table {}", escape_table_name(table_name.as_str(), dialect)), + generate_select(&table_name, dialect), + ) + }) + .collect(); + + let mut idx = 0usize; + queries.into_iter().try_fold((), |_, sql| { + read_verify_query(&sql, dialect, ignore_non_read, &mut idx)? + .into_iter() + .for_each(|query| queries_result.push((format!("Query #{idx}"), query))); + Ok(()) + })?; + + Ok(if queries_result.is_empty() { + SelectVariant::WholeDB + } else { + SelectVariant::Queries(queries_result) + }) +} diff --git a/src/select.rs b/src/select.rs index 09f80ea..d9332fd 100644 --- a/src/select.rs +++ b/src/select.rs @@ -1,8 +1,11 @@ use sqlparser::ast::helpers::attached_token::AttachedToken; use sqlparser::ast::*; use sqlparser::dialect::Dialect; +use sqlparser::parser::Parser; use sqlparser::tokenizer::Span; +use crate::error::{QueryError, SQLError}; + /// /// Generates wildcard select for given dialect: /// @@ -70,3 +73,38 @@ pub(crate) fn generate_select(table_name: &str, dialect: &impl Dialect) -> Strin ast.to_string() } + +pub(crate) fn escape_table_name(table_name: &str, dialect: &impl Dialect) -> String { + Ident { + value: table_name.to_owned(), + quote_style: dialect.identifier_quote_style(table_name), + span: Span::empty(), + } + .to_string() +} + +/// Checks and reformat select +pub(crate) fn read_verify_query( + sql: &str, + dialect: &impl Dialect, + ignore_non_read: bool, + idx: &mut usize, +) -> Result, SQLError> { + let ast = Parser::parse_sql(dialect, sql).map_err(SQLError::ParseError)?; + let mut acc: Vec = Vec::new(); + ast.iter().try_fold((), |_, statement| { + if matches!(statement, Statement::Query(_)) { + acc.push(statement.to_string()); + *idx += 1; + Ok(()) + } else { + if ignore_non_read { + Ok(()) + } else { + Err(SQLError::QueryError(QueryError::ReadOnlyQueryAllowed)) + } + } + })?; + + Ok(acc) +} From 6cb2a7e01aa50c34b06b36bb4ac7ee296c31390f Mon Sep 17 00:00:00 2001 From: Eir Nym <485399+eirnym@users.noreply.github.com> Date: Sun, 27 Apr 2025 12:17:48 +0200 Subject: [PATCH 2/3] Improve error handling --- src/cell_to_string.rs | 5 +---- src/error.rs | 32 +++++++++++++++++--------------- src/main.rs | 35 ++++++++++++++--------------------- src/matching.rs | 33 +++++++++++++++------------------ src/pattern.rs | 2 +- src/query.rs | 2 +- src/select.rs | 17 +++++++++-------- 7 files changed, 58 insertions(+), 68 deletions(-) diff --git a/src/cell_to_string.rs b/src/cell_to_string.rs index ce2c2cb..5ecafdb 100644 --- a/src/cell_to_string.rs +++ b/src/cell_to_string.rs @@ -1,9 +1,7 @@ -use log::warn; use sqlx::sqlite::SqliteValueRef; use sqlx::Decode; use sqlx::Sqlite; use sqlx::Type; -use sqlx::TypeInfo; use sqlx::ValueRef; pub(crate) fn sqlite_cell_to_string(value_ref: SqliteValueRef) -> Result, String> { @@ -68,6 +66,5 @@ pub(crate) fn sqlite_cell_to_string(value_ref: SqliteValueRef) -> Result { log::log!(level, "Only readonly query is allowed"); } - }; + } 65 } @@ -37,27 +39,27 @@ impl SQLError { 66 } SQLError::SqlX((context, error)) => { - let context = if context.is_empty() { - "".to_owned() - } else { - format!(" ({context})") - }; - - log::log!(level, "Query execution error{context}: {error}"); + let context = format_context(context); + log::log!(level, "SQL error{context}: {error}"); 74 } SQLError::ConvertCell((context, error)) => { - let context = if context.is_empty() { - "".to_owned() - } else { - format!(" ({context})") - }; + let context = format_context(context); - log::log!(level, "Query execution error{context}: {error}"); + log::log!(level, "Cell conversion error{context}: {error}"); 73 } } } } + +#[inline] +fn format_context(context: &String) -> String { + if context.is_empty() { + String::new() + } else { + format!(" ({context})") + } +} diff --git a/src/main.rs b/src/main.rs index 3363a27..b02de25 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,8 +6,8 @@ mod pattern; mod query; mod select; +use error::Level; use error::SQLError; -use log::Level; use matching::sqlite_check_rows; use pattern::{Pattern, PatternKind}; use query::{prepare_queries, SelectVariant}; @@ -19,7 +19,7 @@ use sqlx::{sqlite::SqliteConnectOptions, Executor as _, Pool, Row as _, Sqlite, async fn main() { let args = args::parse_args(); // Set default log level to 2. - let quiet_level: i16 = 2 + args.verbose as i16 - args.quiet as i16; + let quiet_level: i16 = 2 + i16::from(args.verbose) - i16::from(args.quiet); stderrlog::new() .module(module_path!()) @@ -30,7 +30,7 @@ async fn main() { .init() .unwrap(); - let pattern = match Pattern::new(args.pattern.as_str(), PatternKind::Regex) { + let pattern = match Pattern::new(args.pattern.as_str(), &PatternKind::Regex) { Ok(pattern) => pattern, Err(err) => std::process::exit(err.report(Level::Error)), }; @@ -44,7 +44,7 @@ async fn main() { ) .await { - Ok(_) => {} + Ok(()) => {} Err(err) => std::process::exit(err.report(Level::Error)), } } @@ -58,21 +58,14 @@ async fn process_sqlite_database( ) -> Result<(), SQLError> { let dialect = SQLiteDialect {}; - let options: SqliteConnectOptions = match database_uri.parse::() { - Ok(options) => options.read_only(true).immutable(true), - Err(err) => { - log::error!("Database URI error: {}", err); - std::process::exit(64); - } - }; + let options: SqliteConnectOptions = database_uri + .parse::() + .map(|options| options.read_only(true).immutable(true)) + .map_err(|error| SQLError::SqlX(("Database URI".into(), error)))?; - let db = match SqlitePool::connect_with(options).await { - Ok(db) => db, - Err(err) => { - log::error!("Database connection error: {}", err); - std::process::exit(74); - } - }; + let db = SqlitePool::connect_with(options) + .await + .map_err(|error| SQLError::SqlX(("Database connection".into(), error)))?; let select_variant = prepare_queries( tables.into_iter(), @@ -87,19 +80,19 @@ async fn process_sqlite_database( let tables = sqlite_select_tables(&db).await?; let select_variant = prepare_queries( tables.into_iter(), - Vec::new().into_iter(), + vec![].into_iter(), &dialect, ignore_non_read, )?; match select_variant { - SelectVariant::WholeDB => Vec::new(), + SelectVariant::WholeDB => vec![], SelectVariant::Queries(queries) => queries, } } }; for (query_id, query) in queries { - sqlite_check_rows(&db, query_id.as_str(), query.as_str(), &pattern).await + sqlite_check_rows(&db, query_id.as_str(), query.as_str(), &pattern).await; } Ok(()) diff --git a/src/matching.rs b/src/matching.rs index 41df1f8..889f062 100644 --- a/src/matching.rs +++ b/src/matching.rs @@ -1,6 +1,7 @@ use crate::cell_to_string::sqlite_cell_to_string; +use crate::error::Level; use crate::{Pattern, SQLError}; -use log::Level; + use sqlx::{Column, Executor, Pool, Row, Sqlite}; pub async fn sqlite_check_rows( @@ -10,37 +11,34 @@ pub async fn sqlite_check_rows( pattern: &Pattern, ) { use futures::TryStreamExt; - use std::sync::atomic::AtomicI64; + use std::sync::atomic::AtomicU64; use std::sync::atomic::Ordering; log::debug!("{query_id}: {select_query}"); let mut rows = db.fetch(select_query); - let row_idx: AtomicI64 = AtomicI64::new(-1); + let row_counter: AtomicU64 = AtomicU64::new(0); loop { - row_idx.fetch_add(1, Ordering::SeqCst); - let idx = row_idx.load(Ordering::SeqCst); + let row_idx = row_counter.load(Ordering::SeqCst); let row = match rows.try_next().await { Ok(None) => break, Ok(Some(row)) => row, - Err(err) => { - log::warn!( - "Error while reading row {idx} while executing query: {}", - err - ); + Err(error) => { + SQLError::SqlX((format!("{query_id}::{row_idx}"), error)).report(Level::Warn); continue; } }; - sqlite_process_row(idx, row, query_id, pattern); + sqlite_process_row(row_idx, &row, query_id, pattern); + row_counter.fetch_add(1, Ordering::SeqCst); } } fn sqlite_process_row( - row_idx: i64, - row: sqlx::sqlite::SqliteRow, + row_idx: u64, + row: &sqlx::sqlite::SqliteRow, query_id: &str, pattern: &Pattern, ) { @@ -50,14 +48,12 @@ fn sqlite_process_row( let index = column.ordinal(); let column_name = column.name().to_owned(); let column_type = column.type_info().name(); - - let error_context = - format!("Reading row {row_idx} from table {query_id} column {column_name} of type {column_type}"); + let row_id = format!("{query_id}::{row_idx}::{column_name}"); let value_ref = match row.try_get_raw(index) { Ok(value_ref) => value_ref, Err(error) => { - SQLError::SqlX((error_context, error)).report(Level::Warn); + SQLError::SqlX((row_id, error)).report(Level::Warn); continue; } }; @@ -66,13 +62,14 @@ fn sqlite_process_row( Ok(Some(value_str)) => value_str, Ok(None) => continue, Err(error) => { + let error_context = format!("{row_id} cell type {column_type}"); SQLError::ConvertCell((error_context, error)).report(Level::Warn); continue; } }; if pattern.is_match(&value_str) { - println!("{query_id}::{row_idx}::{column_name} => {value_str:?}"); + println!("{row_id} => {value_str}"); } } } diff --git a/src/pattern.rs b/src/pattern.rs index cb8a3e9..60d1832 100644 --- a/src/pattern.rs +++ b/src/pattern.rs @@ -9,7 +9,7 @@ pub(crate) enum Pattern { } impl Pattern { - pub fn new(pattern: &str, kind: PatternKind) -> Result { + pub fn new(pattern: &str, kind: &PatternKind) -> Result { match kind { PatternKind::Regex => regex::Regex::new(pattern) .map(Self::Regex) diff --git a/src/query.rs b/src/query.rs index 0a6b0d1..36918f5 100644 --- a/src/query.rs +++ b/src/query.rs @@ -28,7 +28,7 @@ where .collect(); let mut idx = 0usize; - queries.into_iter().try_fold((), |_, sql| { + queries.into_iter().try_fold((), |(), sql| { read_verify_query(&sql, dialect, ignore_non_read, &mut idx)? .into_iter() .for_each(|query| queries_result.push((format!("Query #{idx}"), query))); diff --git a/src/select.rs b/src/select.rs index d9332fd..a639086 100644 --- a/src/select.rs +++ b/src/select.rs @@ -1,5 +1,8 @@ use sqlparser::ast::helpers::attached_token::AttachedToken; -use sqlparser::ast::*; +use sqlparser::ast::{ + GroupByExpr, Ident, Select, SelectFlavor, SelectItem, SetExpr, Statement, TableFactor, + TableWithJoins, WildcardAdditionalOptions, +}; use sqlparser::dialect::Dialect; use sqlparser::parser::Parser; use sqlparser::tokenizer::Span; @@ -91,18 +94,16 @@ pub(crate) fn read_verify_query( idx: &mut usize, ) -> Result, SQLError> { let ast = Parser::parse_sql(dialect, sql).map_err(SQLError::ParseError)?; - let mut acc: Vec = Vec::new(); - ast.iter().try_fold((), |_, statement| { + let mut acc: Vec = vec![]; + ast.iter().try_fold((), |(), statement| { if matches!(statement, Statement::Query(_)) { acc.push(statement.to_string()); *idx += 1; Ok(()) + } else if ignore_non_read { + Ok(()) } else { - if ignore_non_read { - Ok(()) - } else { - Err(SQLError::QueryError(QueryError::ReadOnlyQueryAllowed)) - } + Err(SQLError::QueryError(QueryError::ReadOnlyQueryAllowed)) } })?; From 9b236dd1183b90766dd317436a1f523deeaf9dc1 Mon Sep 17 00:00:00 2001 From: Eir Nym <485399+eirnym@users.noreply.github.com> Date: Sun, 27 Apr 2025 13:06:43 +0200 Subject: [PATCH 3/3] Read queries from stdin or files --- src/error.rs | 7 ++++++ src/main.rs | 62 +++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/src/error.rs b/src/error.rs index 211b824..dd1b565 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,6 +9,7 @@ pub(crate) enum SQLError { ParseError(ParserError), SqlX((String, sqlx::Error)), ConvertCell((String, String)), + Io((String, std::io::Error)), } pub(crate) enum QueryError { @@ -38,6 +39,12 @@ impl SQLError { 66 } + SQLError::Io((context, error)) => { + let context = format_context(context); + log::log!(level, "IO error{context}: {error}"); + + 70 + } SQLError::SqlX((context, error)) => { let context = format_context(context); log::log!(level, "SQL error{context}: {error}"); diff --git a/src/main.rs b/src/main.rs index b02de25..a15b1a0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,9 @@ mod pattern; mod query; mod select; +use std::fs::OpenOptions; +use std::io::stdin; + use error::Level; use error::SQLError; use matching::sqlite_check_rows; @@ -35,11 +38,16 @@ async fn main() { Err(err) => std::process::exit(err.report(Level::Error)), }; + let queries = match read_queries(args.query) { + Ok(queries) => queries, + Err(err) => std::process::exit(err.report(Level::Error)), + }; + match process_sqlite_database( args.database_uri, pattern, args.table, - args.query, + queries, args.ignore_non_readonly, ) .await @@ -119,3 +127,55 @@ async fn sqlite_select_tables(db: &Pool) -> Result, SQLError }) .collect()) } + +fn read_queries(queries: Vec) -> Result, SQLError> { + let mut acc = vec![]; + + queries.into_iter().try_fold((), |(), query| { + if query.is_empty() { + return Ok(()); + } + + if query == "-" { + return read_query(&mut stdin(), "").map(|query| { + acc.push(query); + }); + } + + match query.strip_prefix('@') { + None => { + acc.push(query); + Ok(()) + } + Some(filename) => read_from_file(filename).map(|query| { + acc.push(query); + }), + } + })?; + + Ok(acc) +} + +#[inline] +fn read_from_file(filename: &str) -> Result { + let mut file = OpenOptions::new() + .read(true) + .write(false) + .create(false) + .open(filename) + .expect("Unable to open"); + + read_query(&mut file, filename) +} + +#[inline] +fn read_query(file: &mut File, filename: &str) -> Result +where + File: std::io::Read, +{ + let mut query = String::new(); + + file.read_to_string(&mut query) + .map_err(|error| SQLError::Io((format!("read {filename}"), error))) + .map(|_| query) +}