diff --git a/README.md b/README.md index f5ed9d5..34e5d22 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ Positional arguments: Optional arguments: -c, --command COMMAND Run a single command and exit -C, --core Preset of settings to connect to Firebolt Core - -h, --host HOSTNAME Hostname to connect to + -h, --host HOSTNAME Hostname (and port) to connect to -d, --database DATABASE Database name to use -f, --format FORMAT Output format (e.g., TabSeparatedWithNames, PSQL, JSONLines_Compact, Vertical, ...) -e, --extra EXTRA Extra settings in the form --extra = diff --git a/src/args.rs b/src/args.rs index be57c7b..1376826 100644 --- a/src/args.rs +++ b/src/args.rs @@ -1,6 +1,6 @@ use gumdrop::Options; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; +use std::collections::BTreeMap; use std::fs; use crate::utils::{config_path, init_root_path}; @@ -30,12 +30,12 @@ pub struct Args { #[serde(skip_serializing, skip_deserializing)] pub core: bool, - #[options(help = "Hostname to connect to", meta = "HOSTNAME")] + #[options(help = "Hostname (and port) to connect to", meta = "HOSTNAME")] #[serde(default)] pub host: String, #[options(help = "Database name to use")] - #[serde(default)] + #[serde(skip_serializing, skip_deserializing)] pub database: String, #[options(help = "Output format (e.g., TabSeparatedWithNames, PSQL, JSONLines_Compact, Vertical, ...)")] @@ -112,15 +112,27 @@ pub struct Args { } pub fn normalize_extras(extras: Vec, encode: bool) -> Result, Box> { - let x: HashMap<&str, String> = HashMap::from_iter(extras.iter().map(|e| { + let mut x: BTreeMap = BTreeMap::new(); + + for e in &extras { let kv: Vec<&str> = e.split('=').collect(); if kv.len() < 2 { - panic!("Cannot parse '{}': expected key=value format", e) + return Err(format!("Cannot parse '{}': expected key=value format", e).into()); } - let value = kv[1..].join("=").to_string(); - // uri encode params - (kv[0], if encode { urlencoding::encode(&value).into_owned() } else { value }) - })); + + let key = kv[0].to_string(); + let value = kv[1..].join("=").trim().to_string(); + let value = if value.starts_with('\'') && value.ends_with('\'') || value.starts_with('"') && value.ends_with('"') { + value[1..value.len() - 1].to_string() + } else { + value + }; + + let value = if encode { urlencoding::encode(&value).into_owned() } else { value }; + + x.insert(key, value); + } + let mut new_extras: Vec = vec![]; for (key, value) in &x { new_extras.push(format!("{key}={value}")) @@ -164,7 +176,6 @@ pub fn get_args() -> Result> { if args.update_defaults { args.host = args.host.or(default_host); - args.database = args.database.or(String::from("local_dev_db")); if args.core { args.format = args.format.or(String::from("PSQL")); } else { @@ -179,11 +190,6 @@ pub fn get_args() -> Result> { args.concise = args.concise || defaults.concise; args.hide_pii = args.hide_pii || defaults.hide_pii; - args.database = args - .database - .or(args.core.then(|| String::from("firebolt")).unwrap_or(defaults.database)) - .or(String::from("local_dev_db")); - if args.core { args.host = args.host.or(String::from("localhost:3473")); args.jwt = String::from(""); @@ -313,4 +319,25 @@ mod tests { assert!(url.contains("param=value%20with%20spaces")); assert!(!url.contains("param=value%2520with%2520spaces")); // No double encoding } + + #[test] + fn test_params_with_quotes() { + let extras = vec!["param1='value with spaces'".to_string(), "param2=\"value with spaces\"".to_string()]; + let result = normalize_extras(extras, true).unwrap(); + assert_eq!(result[0], "param1=value%20with%20spaces"); + assert_eq!(result[1], "param2=value%20with%20spaces"); + } + + #[test] + fn test_params_with_spaces() { + let extras = vec![ + "param1= value with spaces ".to_string(), + "param2= \"value with spaces\" ".to_string(), + "param3=\" value with spaces \"".to_string(), + ]; + let result = normalize_extras(extras, true).unwrap(); + assert_eq!(result[0], "param1=value%20with%20spaces"); + assert_eq!(result[1], "param2=value%20with%20spaces"); + assert_eq!(result[2], "param3=%20%20value%20with%20spaces%20"); + } }