diff --git a/Cargo.lock b/Cargo.lock index d9763a1..dde91fe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -789,6 +789,7 @@ dependencies = [ "tempfile", "tiny_http", "toml", + "urlencoding", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index ebbf04c..427d125 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,7 @@ sysinfo = { version = "0.38.4", default-features = false, features = ["system"] self_update = { version = "0.42", default-features = false, features = ["rustls"] } lzma-rs = "0.3" tempfile = "3" +urlencoding = "2.1.3" [dev-dependencies] mockito = "1" diff --git a/src/api.rs b/src/api.rs index 4436ab3..59e9e09 100644 --- a/src/api.rs +++ b/src/api.rs @@ -32,6 +32,7 @@ pub struct ApiClient { pub api_url: String, workspace_id: Option, sandbox_id: Option, + database_id: Option, } impl ApiClient { @@ -117,6 +118,7 @@ impl ApiClient { } profile_config.sandbox }), + database_id: workspace_id.and_then(|ws| crate::config::load_current_database("default", ws)), } } @@ -129,6 +131,7 @@ impl ApiClient { api_url: api_url.to_string(), workspace_id: workspace_id.map(String::from), sandbox_id: None, + database_id: None, } } @@ -167,6 +170,9 @@ impl ApiClient { req = req.header("X-Session-Id", sid); req = req.header("X-Sandbox-Id", sid); } + if let Some(ref db_id) = self.database_id { + req = req.header("X-Database-Id", db_id); + } req } diff --git a/src/command.rs b/src/command.rs index a92193a..a8ac3de 100644 --- a/src/command.rs +++ b/src/command.rs @@ -71,7 +71,7 @@ pub enum Commands { /// Managed databases you create and populate with tables (parquet uploads) Databases { - /// Database name or connection ID (omit to use a subcommand) + /// Database id or description (omit to use a subcommand) name_or_id: Option, /// Workspace ID (defaults to first workspace from login) @@ -557,15 +557,15 @@ pub enum DatabasesCommands { /// Create a new managed database Create { - /// Database name (used as the connection name in SQL: `name.schema.table`) + /// Optional display label (not unique, not an identifier — databases are addressed by id) #[arg(long)] - name: String, + description: Option, /// Schema for tables declared at create time (default: public) #[arg(long, default_value = "public")] schema: String, - /// Table to declare up front (repeatable). Required before load on current API. + /// Table to declare up front (repeatable) #[arg(long = "table")] tables: Vec, @@ -574,6 +574,12 @@ pub enum DatabasesCommands { output: String, }, + /// Set the current database (used by default when no database is specified) + Set { + /// Database id or description + id_or_description: String, + }, + /// Delete a managed database and its tables Delete { /// Database name or connection ID @@ -610,8 +616,9 @@ pub enum DatabasesCommands { pub enum DatabaseTablesCommands { /// List tables in a managed database List { - /// Database name or connection ID - database: String, + /// Database id or description (defaults to current database) + #[arg(long)] + database: Option, /// Filter by schema name #[arg(long)] @@ -624,8 +631,9 @@ pub enum DatabaseTablesCommands { /// Load a parquet file into a table (creates or replaces the table) Load { - /// Database name or connection ID - database: String, + /// Database id or description (defaults to current database) + #[arg(long)] + database: Option, /// Table name table: String, @@ -649,8 +657,9 @@ pub enum DatabaseTablesCommands { /// Delete a table from a managed database Delete { - /// Database name or connection ID - database: String, + /// Database id or description (defaults to current database) + #[arg(long)] + database: Option, /// Table name table: String, diff --git a/src/config.rs b/src/config.rs index 43bdf5b..a0f680d 100644 --- a/src/config.rs +++ b/src/config.rs @@ -101,6 +101,8 @@ pub struct ProfileConfig { pub workspaces: Vec, #[serde(default, skip_serializing_if = "Option::is_none", alias = "session")] pub sandbox: Option, + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub current_databases: HashMap, } #[derive(Debug, Deserialize, Serialize)] @@ -227,6 +229,60 @@ pub fn clear_sandbox(profile: &str) -> Result<(), String> { write_config(&config_path, &content) } +pub fn save_current_database(profile: &str, workspace_id: &str, database_id: &str) -> Result<(), String> { + let config_path = config_path()?; + + let mut config_file: ConfigFile = if config_path.exists() { + let content = fs::read_to_string(&config_path) + .map_err(|e| format!("error reading config file: {e}"))?; + serde_yaml::from_str(&content).map_err(|e| format!("error parsing config file: {e}"))? + } else { + ConfigFile { profiles: HashMap::new() } + }; + + config_file + .profiles + .entry(profile.to_string()) + .or_default() + .current_databases + .insert(workspace_id.to_string(), database_id.to_string()); + + let content = serde_yaml::to_string(&config_file) + .map_err(|e| format!("error serializing config: {e}"))?; + write_config(&config_path, &content) +} + +pub fn load_current_database(profile: &str, workspace_id: &str) -> Option { + let config_path = config_path().ok()?; + if !config_path.exists() { + return None; + } + let content = fs::read_to_string(&config_path).ok()?; + let config_file: ConfigFile = serde_yaml::from_str(&content).ok()?; + config_file.profiles.get(profile)?.current_databases.get(workspace_id).cloned() +} + +pub fn clear_current_database(profile: &str, workspace_id: &str) -> Result<(), String> { + let config_path = config_path()?; + + if !config_path.exists() { + return Ok(()); + } + + let content = fs::read_to_string(&config_path) + .map_err(|e| format!("error reading config file: {e}"))?; + let mut config_file: ConfigFile = + serde_yaml::from_str(&content).map_err(|e| format!("error parsing config file: {e}"))?; + + if let Some(entry) = config_file.profiles.get_mut(profile) { + entry.current_databases.remove(workspace_id); + } + + let content = serde_yaml::to_string(&config_file) + .map_err(|e| format!("error serializing config: {e}"))?; + write_config(&config_path, &content) +} + pub fn resolve_workspace_id(provided: Option, profile_config: &ProfileConfig) -> Result { if let Some(id) = provided { return Ok(id); diff --git a/src/databases.rs b/src/databases.rs index 1ee73da..d6e0e1d 100644 --- a/src/databases.rs +++ b/src/databases.rs @@ -3,30 +3,34 @@ use indicatif::{ProgressBar, ProgressStyle}; use serde::{Deserialize, Serialize}; use std::path::Path; -const MANAGED_SOURCE_TYPE: &str = "managed"; const DEFAULT_SCHEMA: &str = "public"; +/// Summary row returned by `GET /databases` (no `default_connection_id`). #[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)] -pub struct Database { - pub id: String, - pub name: String, - pub source_type: String, +struct DatabaseSummary { + id: String, + description: Option, } #[derive(Deserialize)] -struct ListConnectionsResponse { - connections: Vec, +struct ListDatabasesResponse { + databases: Vec, } -#[derive(Deserialize, Serialize)] -struct DatabaseDetail { - id: String, - name: String, - source_type: String, - #[serde(default)] - table_count: u64, +/// Full record returned by `GET /databases/{id}`. +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)] +pub struct Database { + pub id: String, + pub description: Option, + pub default_connection_id: String, #[serde(default)] - synced_table_count: u64, + attachments: Vec, +} + +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)] +struct DatabaseAttachment { + connection_id: String, + alias: Option, } #[derive(Deserialize)] @@ -56,10 +60,10 @@ struct TableRow { } #[derive(Deserialize, Serialize)] -struct CreateConnectionResponse { +struct CreateDatabaseResponse { id: String, - name: String, - source_type: String, + description: Option, + default_connection_id: String, } #[derive(Deserialize)] @@ -73,43 +77,45 @@ struct LoadManagedTableResponse { arrow_schema_json: String, } -fn is_managed(db: &Database) -> bool { - db.source_type == MANAGED_SOURCE_TYPE +fn fetch_database(api: &ApiClient, id: &str) -> Database { + api.get(&format!("/databases/{id}")) } -pub fn try_resolve_database(api: &ApiClient, name_or_id: &str) -> Result { - let body: ListConnectionsResponse = api.get("/connections"); - let by_id = body - .connections +pub fn try_resolve_database(api: &ApiClient, id_or_description: &str) -> Result { + // Try a direct id lookup first — avoids the list round-trip for the common case. + // Percent-encode the segment so descriptions containing spaces or other URL-unsafe + // characters don't cause a URL parse error before the list fallback can run. + let encoded = urlencoding::encode(id_or_description); + if let Some(db) = api.get_none_if_not_found(&format!("/databases/{encoded}")) { + return Ok(db); + } + + // Fall back to listing and matching by description. + let body: ListDatabasesResponse = api.get("/databases"); + let desc_matches: Vec<&DatabaseSummary> = body + .databases .iter() - .find(|c| c.id == name_or_id) - .cloned(); - let found = by_id.or_else(|| { - body.connections - .iter() - .find(|c| c.name == name_or_id) - .cloned() - }); - match found { - Some(db) if is_managed(&db) => Ok(db), - Some(db) => Err(format!( - "'{}' is not a managed database (source_type: {})", - db.name, db.source_type + .filter(|d| d.description.as_deref() == Some(id_or_description)) + .collect(); + + match desc_matches.len() { + 0 => Err(format!( + "no database with id or description '{id_or_description}'" + )), + 1 => Ok(fetch_database(api, &desc_matches[0].id)), + _ => Err(format!( + "multiple databases have description '{}' — use the database id instead", + id_or_description )), - None => Err(format!("no database named or with id '{name_or_id}'")), } } -pub fn resolve_database(api: &ApiClient, name_or_id: &str) -> Database { - match try_resolve_database(api, name_or_id) { +pub fn resolve_database(api: &ApiClient, id_or_description: &str) -> Database { + match try_resolve_database(api, id_or_description) { Ok(db) => db, Err(e) => { use crossterm::style::Stylize; - if e.contains("not a managed database") { - eprintln!("{}", format!("error: {e}. Use `hotdata connections` for remote sources.").red()); - } else { - eprintln!("{}", format!("error: {e}").red()); - } + eprintln!("{}", format!("error: {e}").red()); std::process::exit(1); } } @@ -119,28 +125,33 @@ fn schema_name(schema: Option<&str>) -> &str { schema.unwrap_or(DEFAULT_SCHEMA) } -/// Build managed-connection `config` with declared schemas/tables. -pub fn build_managed_config(schema: &str, tables: &[String]) -> serde_json::Value { - if tables.is_empty() { - return serde_json::json!({}); +/// Build the request body for `POST /v1/databases`. +pub fn create_database_request( + description: Option<&str>, + schema: &str, + tables: &[String], +) -> serde_json::Value { + let mut req = serde_json::Map::new(); + + if let Some(desc) = description { + req.insert( + "description".to_string(), + serde_json::Value::String(desc.to_string()), + ); } - let table_objs: Vec = tables - .iter() - .map(|t| serde_json::json!({ "name": t })) - .collect(); - serde_json::json!({ - "schemas": [{ "name": schema, "tables": table_objs }] - }) -} -/// Request body for `POST /v1/connections` when creating a managed database. -pub fn create_connection_request(name: &str, schema: &str, tables: &[String]) -> serde_json::Value { - serde_json::json!({ - "name": name, - "source_type": MANAGED_SOURCE_TYPE, - "config": build_managed_config(schema, tables), - "skip_discovery": true, - }) + if !tables.is_empty() { + let table_objs: Vec = tables + .iter() + .map(|t| serde_json::json!({ "name": t })) + .collect(); + req.insert( + "schemas".to_string(), + serde_json::json!([{ "name": schema, "tables": table_objs }]), + ); + } + + serde_json::Value::Object(req) } pub fn managed_table_load_path(connection_id: &str, schema: &str, table: &str) -> String { @@ -164,11 +175,11 @@ pub fn is_parquet_path(path: &str) -> bool { || Path::new(path).extension().and_then(|e| e.to_str()) == Some("parquet") } -fn table_rows_for_database(db_name: &str, tables: Vec) -> Vec { +fn table_rows(tables: Vec) -> Vec { tables .into_iter() .map(|t| TableRow { - full_name: format!("{}.{}.{}", db_name, t.schema, t.table), + full_name: format!("default.{}.{}", t.schema, t.table), schema: t.schema, table: t.table, synced: t.synced, @@ -177,7 +188,12 @@ fn table_rows_for_database(db_name: &str, tables: Vec) -> Vec, pb: &ProgressBar) -> String { +fn finish_upload( + api: &ApiClient, + reader: impl std::io::Read + Send + 'static, + size: Option, + pb: &ProgressBar, +) -> String { let (status, resp_body) = api.post_body("/files", "application/octet-stream", reader, size); pb.finish_and_clear(); @@ -251,7 +267,10 @@ fn upload_parquet_url(api: &ApiClient, url: &str) -> String { }; if !resp.status().is_success() { - eprintln!("error: remote server returned {} for '{url}'", resp.status()); + eprintln!( + "error: remote server returned {} for '{url}'", + resp.status() + ); std::process::exit(1); } @@ -285,9 +304,8 @@ fn collect_tables(api: &ApiClient, connection_id: &str, schema: Option<&str>) -> let mut out = Vec::new(); let mut cursor: Option = None; loop { - let mut params: Vec<(&str, Option)> = vec![ - ("connection_id", Some(connection_id.to_string())), - ]; + let mut params: Vec<(&str, Option)> = + vec![("connection_id", Some(connection_id.to_string()))]; if let Some(s) = schema { params.push(("schema", Some(s.to_string()))); } @@ -314,73 +332,97 @@ fn collect_tables(api: &ApiClient, connection_id: &str, schema: Option<&str>) -> pub fn list(workspace_id: &str, format: &str) { let api = ApiClient::new(Some(workspace_id)); - let body: ListConnectionsResponse = api.get("/connections"); - let databases: Vec<&Database> = body - .connections - .iter() - .filter(|c| is_managed(c)) - .collect(); + let body: ListDatabasesResponse = api.get("/databases"); match format { - "json" => println!("{}", serde_json::to_string_pretty(&databases).unwrap()), - "yaml" => print!("{}", serde_yaml::to_string(&databases).unwrap()), + "json" => println!("{}", serde_json::to_string_pretty(&body.databases).unwrap()), + "yaml" => print!("{}", serde_yaml::to_string(&body.databases).unwrap()), "table" => { - if databases.is_empty() { + if body.databases.is_empty() { use crossterm::style::Stylize; eprintln!("{}", "No databases found.".dark_grey()); eprintln!( "{}", - "Create one with: hotdata databases create --name ".dark_grey() + "Create one with: hotdata databases create".dark_grey() ); } else { - let rows: Vec> = databases + let rows: Vec> = body + .databases .iter() - .map(|d| vec![d.name.clone(), d.id.clone()]) + .map(|d| { + vec![ + d.description.as_deref().unwrap_or("-").to_string(), + d.id.clone(), + ] + }) .collect(); - crate::table::print(&["NAME", "ID"], &rows); + crate::table::print(&["DESCRIPTION", "ID"], &rows); } } _ => unreachable!(), } } -pub fn get(workspace_id: &str, name_or_id: &str, format: &str) { +pub fn get(workspace_id: &str, id_or_description: &str, format: &str) { let api = ApiClient::new(Some(workspace_id)); - let db = resolve_database(&api, name_or_id); - let detail: DatabaseDetail = api.get(&format!("/connections/{}", db.id)); + let db = resolve_database(&api, id_or_description); match format { - "json" => println!("{}", serde_json::to_string_pretty(&detail).unwrap()), - "yaml" => print!("{}", serde_yaml::to_string(&detail).unwrap()), + "json" => println!("{}", serde_json::to_string_pretty(&db).unwrap()), + "yaml" => print!("{}", serde_yaml::to_string(&db).unwrap()), "table" => { use crossterm::style::Stylize; - let label = |l: &str| format!("{:<16}", l).dark_grey().to_string(); - println!("{}{}", label("name:"), detail.name.clone().white()); - println!("{}{}", label("id:"), detail.id.dark_cyan()); + let label = |l: &str| format!("{:<24}", l).dark_grey().to_string(); + println!("{}{}", label("id:"), db.id.clone().dark_cyan()); + println!( + "{}{}", + label("description:"), + db.description.as_deref().unwrap_or("-").white() + ); println!( - "{}{} synced / {} total", - label("tables:"), - detail.synced_table_count.to_string().cyan(), - detail.table_count.to_string().cyan(), + "{}{}", + label("default_connection_id:"), + db.default_connection_id.clone().dark_cyan() ); println!( "{}{}", label("sql_prefix:"), - format!("{}.{{schema}}.{{table}}", detail.name).green() + "default.{schema}.{table} (pass X-Database-Id header when querying)".green() ); + if !db.attachments.is_empty() { + println!("{}({})", label("attached catalogs:"), db.attachments.len()); + for a in &db.attachments { + let alias = a + .alias + .as_deref() + .map(|al| format!(" as {al}")) + .unwrap_or_default(); + println!( + " {}{}", + a.connection_id.clone().dark_cyan(), + alias.dark_grey() + ); + } + } } _ => unreachable!(), } } -pub fn create(workspace_id: &str, name: &str, schema: &str, tables: &[String], format: &str) { +pub fn create( + workspace_id: &str, + description: Option<&str>, + schema: &str, + tables: &[String], + format: &str, +) { use crossterm::style::Stylize; - let body = create_connection_request(name, schema, tables); + let body = create_database_request(description, schema, tables); let api = ApiClient::new(Some(workspace_id)); let spinner = (format == "table").then(|| crate::util::spinner("Creating database...")); - let (status, resp_body) = api.post_raw("/connections", &body); + let (status, resp_body) = api.post_raw("/databases", &body); if let Some(s) = &spinner { s.finish_and_clear(); } @@ -390,7 +432,7 @@ pub fn create(workspace_id: &str, name: &str, schema: &str, tables: &[String], f std::process::exit(1); } - let result: CreateConnectionResponse = match serde_json::from_str(&resp_body) { + let result: CreateDatabaseResponse = match serde_json::from_str(&resp_body) { Ok(v) => v, Err(e) => { eprintln!("error parsing response: {e}"); @@ -398,47 +440,81 @@ pub fn create(workspace_id: &str, name: &str, schema: &str, tables: &[String], f } }; + if let Err(e) = crate::config::save_current_database("default", workspace_id, &result.id) { + use crossterm::style::Stylize; + eprintln!("{}", format!("warning: database created but could not set as current: {e}").yellow()); + } + match format { "json" => println!("{}", serde_json::to_string_pretty(&result).unwrap()), "yaml" => print!("{}", serde_yaml::to_string(&result).unwrap()), "table" => { println!("{}", "Database created".green()); - println!("name: {}", result.name); - println!("id: {}", result.id); + if let Some(desc) = &result.description { + println!("description: {desc}"); + } + println!("id: {}", result.id); } _ => unreachable!(), } } -pub fn delete(workspace_id: &str, name_or_id: &str) { +pub fn set(workspace_id: &str, id_or_description: &str) { + use crossterm::style::Stylize; + let api = ApiClient::new(Some(workspace_id)); + let db = resolve_database(&api, id_or_description); + if let Err(e) = crate::config::save_current_database("default", workspace_id, &db.id) { + eprintln!("{}", format!("error saving current database: {e}").red()); + std::process::exit(1); + } + println!("{}", format!("Current database set to {}", db.id).green()); +} + +fn resolve_current_database(provided: Option<&str>, workspace_id: &str) -> String { + if let Some(id) = provided { + return id.to_string(); + } + match crate::config::load_current_database("default", workspace_id) { + Some(id) => id, + None => { + use crossterm::style::Stylize; + eprintln!( + "{}", + "error: no current database set. Use 'hotdata databases set ' or pass a database id.".red() + ); + std::process::exit(1); + } + } +} + +pub fn delete(workspace_id: &str, id_or_description: &str) { use crossterm::style::Stylize; let api = ApiClient::new(Some(workspace_id)); - let db = resolve_database(&api, name_or_id); - let (status, resp_body) = api.delete_raw(&format!("/connections/{}", db.id)); + let db = resolve_database(&api, id_or_description); + let (status, resp_body) = api.delete_raw(&format!("/databases/{}", db.id)); if !status.is_success() { eprintln!("{}", crate::util::api_error(resp_body).red()); std::process::exit(1); } - println!( - "{}", - format!("Database '{}' deleted.", db.name).green() - ); + // If the deleted database was the current one, clear it so subsequent + // commands don't silently send a stale X-Database-Id header. + if crate::config::load_current_database("default", workspace_id).as_deref() == Some(&db.id) { + let _ = crate::config::clear_current_database("default", workspace_id); + } + + println!("{}", "Database deleted.".green()); } -pub fn tables_list( - workspace_id: &str, - database: &str, - schema: Option<&str>, - format: &str, -) { +pub fn tables_list(workspace_id: &str, database: Option<&str>, schema: Option<&str>, format: &str) { + let database = resolve_current_database(database, workspace_id); let api = ApiClient::new(Some(workspace_id)); - let db = resolve_database(&api, database); - let tables = collect_tables(&api, &db.id, schema); + let db = resolve_database(&api, &database); + let tables = collect_tables(&api, &db.default_connection_id, schema); - let rows = table_rows_for_database(&db.name, tables); + let rows = table_rows(tables); match format { "json" => println!("{}", serde_json::to_string_pretty(&rows).unwrap()), @@ -470,7 +546,7 @@ pub fn tables_list( pub fn tables_load( workspace_id: &str, - database: &str, + database: Option<&str>, table: &str, schema: Option<&str>, file: Option<&str>, @@ -479,8 +555,9 @@ pub fn tables_load( ) { use crossterm::style::Stylize; + let database = resolve_current_database(database, workspace_id); let api = ApiClient::new(Some(workspace_id)); - let db = resolve_database(&api, database); + let db = resolve_database(&api, &database); let schema = schema_name(schema); // clap enforces mutual exclusion; only one of these is ever Some. @@ -495,7 +572,7 @@ pub fn tables_load( _ => unreachable!(), }; - let path = managed_table_load_path(&db.id, schema, table); + let path = managed_table_load_path(&db.default_connection_id, schema, table); let body = load_table_request(&upload_id); let spinner = crate::util::spinner("Loading table..."); @@ -508,12 +585,9 @@ pub fn tables_load( eprintln!("{}", msg.red()); eprintln!( "{}", - format!( - "Declare the table when creating the database, e.g.:\n \ - hotdata databases create --name {} --table {}", - db.name, table - ) - .dark_grey() + "Declare the table when creating the database, e.g.:\n \ + hotdata databases create --table " + .dark_grey() ); } else { eprintln!("{}", msg.red()); @@ -529,25 +603,21 @@ pub fn tables_load( } }; - let full_name = format!("{}.{}.{}", db.name, result.schema_name, result.table_name); + let full_name = format!("default.{}.{}", result.schema_name, result.table_name); println!("{}", "Table loaded".green()); println!("full_name: {}", full_name.green()); println!("rows: {}", result.row_count); } -pub fn tables_delete( - workspace_id: &str, - database: &str, - table: &str, - schema: Option<&str>, -) { +pub fn tables_delete(workspace_id: &str, database: Option<&str>, table: &str, schema: Option<&str>) { use crossterm::style::Stylize; + let database = resolve_current_database(database, workspace_id); let api = ApiClient::new(Some(workspace_id)); - let db = resolve_database(&api, database); + let db = resolve_database(&api, &database); let schema = schema_name(schema); - let path = managed_table_delete_path(&db.id, schema, table); + let path = managed_table_delete_path(&db.default_connection_id, schema, table); let (status, resp_body) = api.delete_raw(&path); if !status.is_success() { @@ -557,7 +627,7 @@ pub fn tables_delete( println!( "{}", - format!("Table '{}.{}.{}' deleted.", db.name, schema, table).green() + format!("Table 'default.{}.{}' deleted.", schema, table).green() ); } @@ -572,115 +642,125 @@ mod tests { } #[test] - fn build_managed_config_empty_without_tables() { - assert_eq!(build_managed_config("public", &[]), serde_json::json!({})); + fn create_database_request_empty_without_description_or_tables() { + let req = create_database_request(None, "public", &[]); + assert_eq!(req, serde_json::json!({})); } #[test] - fn build_managed_config_declares_tables() { - let cfg = build_managed_config("public", &["orders".to_string(), "customers".to_string()]); - assert_eq!( - cfg, - serde_json::json!({ - "schemas": [{ - "name": "public", - "tables": [{ "name": "orders" }, { "name": "customers" }] - }] - }) - ); + fn create_database_request_includes_description() { + let req = create_database_request(Some("my db"), "public", &[]); + assert_eq!(req["description"], "my db"); + assert!(req.get("schemas").is_none()); } #[test] - fn is_managed_only_matches_managed_type() { - let db = Database { - id: "c1".into(), - name: "sales".into(), - source_type: "managed".into(), - }; - assert!(is_managed(&db)); - let pg = Database { - id: "c2".into(), - name: "warehouse".into(), - source_type: "postgres".into(), - }; - assert!(!is_managed(&pg)); + fn create_database_request_includes_schemas_when_tables_declared() { + let req = create_database_request( + Some("sales"), + "public", + &["orders".to_string(), "customers".to_string()], + ); + assert_eq!(req["description"], "sales"); + assert_eq!(req["schemas"][0]["name"], "public"); + assert_eq!(req["schemas"][0]["tables"][0]["name"], "orders"); + assert_eq!(req["schemas"][0]["tables"][1]["name"], "customers"); } #[test] - fn resolve_database_by_name_and_id() { - let mut server = mockito::Server::new(); - let mock = server - .mock("GET", "/connections") - .with_status(200) - .with_body( - r#"{"connections":[ - {"id":"conn_abc","name":"sales","source_type":"managed"}, - {"id":"conn_xyz","name":"warehouse","source_type":"postgres"} - ]}"#, - ) - .expect(2) - .create(); + fn create_database_request_schemas_without_description() { + let req = create_database_request(None, "analytics", &["events".to_string()]); + assert!(req.get("description").is_none()); + assert_eq!(req["schemas"][0]["name"], "analytics"); + } - let api = ApiClient::test_new(&server.url(), "k", Some("ws")); - let by_name = resolve_database(&api, "sales"); - assert_eq!(by_name.id, "conn_abc"); - let by_id = resolve_database(&api, "conn_abc"); - assert_eq!(by_id.name, "sales"); - mock.assert(); + fn full_detail(id: &str, desc: &str, conn_id: &str) -> String { + format!( + r#"{{"id":"{id}","description":"{desc}","default_connection_id":"{conn_id}","attachments":[]}}"# + ) } #[test] - fn try_resolve_database_rejects_non_managed() { + fn resolve_database_by_id_and_description() { let mut server = mockito::Server::new(); - let mock = server - .mock("GET", "/connections") + // by-id path: direct GET /databases/db_abc succeeds + let by_id_mock = server + .mock("GET", "/databases/db_abc") + .with_status(200) + .with_body(full_detail("db_abc", "sales", "conn_1")) + .create(); + // by-description path: GET /databases/warehouse → 404, then list, then detail + let not_id = server + .mock("GET", "/databases/warehouse") + .with_status(404) + .with_body(r#"{"error":"not found"}"#) + .create(); + let list = server + .mock("GET", "/databases") .with_status(200) .with_body( - r#"{"connections":[{"id":"c1","name":"warehouse","source_type":"postgres"}]}"#, + r#"{"databases":[{"id":"db_abc","description":"sales"},{"id":"db_xyz","description":"warehouse"}]}"#, ) .create(); + let detail = server + .mock("GET", "/databases/db_xyz") + .with_status(200) + .with_body(full_detail("db_xyz", "warehouse", "conn_2")) + .create(); - let api = ApiClient::test_new(&server.url(), "k", None); - let err = try_resolve_database(&api, "warehouse").unwrap_err(); - assert!(err.contains("not a managed database")); - mock.assert(); + let api = ApiClient::test_new(&server.url(), "k", Some("ws")); + let by_id = resolve_database(&api, "db_abc"); + assert_eq!(by_id.default_connection_id, "conn_1"); + let by_desc = resolve_database(&api, "warehouse"); + assert_eq!(by_desc.id, "db_xyz"); + by_id_mock.assert(); + not_id.assert(); + list.assert(); + detail.assert(); } #[test] fn try_resolve_database_not_found() { let mut server = mockito::Server::new(); - let mock = server - .mock("GET", "/connections") + // Direct id lookup returns 404 + server + .mock("GET", "/databases/missing") + .with_status(404) + .with_body(r#"{"error":"not found"}"#) + .create(); + // List also returns nothing + server + .mock("GET", "/databases") .with_status(200) - .with_body(r#"{"connections":[]}"#) + .with_body(r#"{"databases":[]}"#) .create(); let api = ApiClient::test_new(&server.url(), "k", None); let err = try_resolve_database(&api, "missing").unwrap_err(); - assert!(err.contains("no database named")); - mock.assert(); + assert!(err.contains("no database with id or description")); } #[test] - fn create_connection_request_includes_declared_tables() { - let body = create_connection_request( - "sales", - "public", - &["orders".to_string(), "customers".to_string()], - ); - assert_eq!(body["name"], "sales"); - assert_eq!(body["source_type"], "managed"); - assert_eq!(body["skip_discovery"], true); - assert_eq!( - body["config"]["schemas"][0]["tables"][0]["name"], - "orders" - ); - } + fn try_resolve_database_rejects_ambiguous_description() { + let mut server = mockito::Server::new(); + // Direct id lookup returns 404 (description isn't a valid id) + server + .mock("GET", "/databases/sales") + .with_status(404) + .with_body(r#"{"error":"not found"}"#) + .create(); + // List returns two entries with the same description + server + .mock("GET", "/databases") + .with_status(200) + .with_body( + r#"{"databases":[{"id":"db_1","description":"sales"},{"id":"db_2","description":"sales"}]}"#, + ) + .create(); - #[test] - fn create_connection_request_empty_config_without_tables() { - let body = create_connection_request("sales", "public", &[]); - assert_eq!(body["config"], serde_json::json!({})); + let api = ApiClient::test_new(&server.url(), "k", None); + let err = try_resolve_database(&api, "sales").unwrap_err(); + assert!(err.contains("multiple databases")); } #[test] @@ -712,19 +792,16 @@ mod tests { } #[test] - fn table_rows_for_database_builds_full_names() { - let rows = table_rows_for_database( - "sales", - vec![InfoTable { - connection: "sales".into(), - schema: "public".into(), - table: "orders".into(), - synced: true, - last_sync: Some("2026-05-19T00:00:00Z".into()), - }], - ); + fn table_rows_uses_default_prefix() { + let rows = table_rows(vec![InfoTable { + connection: "ignored".into(), + schema: "public".into(), + table: "orders".into(), + synced: true, + last_sync: Some("2026-05-19T00:00:00Z".into()), + }]); assert_eq!(rows.len(), 1); - assert_eq!(rows[0].full_name, "sales.public.orders"); + assert_eq!(rows[0].full_name, "default.public.orders"); assert!(rows[0].synced); } @@ -739,7 +816,7 @@ mod tests { ])) .with_status(200) .with_body( - r#"{"tables":[{"connection":"sales","schema":"public","table":"b","synced":true,"last_sync":null}],"has_more":false,"next_cursor":null}"#, + r#"{"tables":[{"connection":"default","schema":"public","table":"b","synced":true,"last_sync":null}],"has_more":false,"next_cursor":null}"#, ) .create(); let page0 = server @@ -750,7 +827,7 @@ mod tests { )) .with_status(200) .with_body( - r#"{"tables":[{"connection":"sales","schema":"public","table":"a","synced":false,"last_sync":null}],"has_more":true,"next_cursor":"cur2"}"#, + r#"{"tables":[{"connection":"default","schema":"public","table":"a","synced":false,"last_sync":null}],"has_more":true,"next_cursor":"cur2"}"#, ) .create(); @@ -764,18 +841,18 @@ mod tests { } #[test] - fn create_posts_managed_connection_with_schemas() { + fn create_posts_to_databases_endpoint() { let mut server = mockito::Server::new(); let mock = server - .mock("POST", "/connections") + .mock("POST", "/databases") .match_header("X-Workspace-Id", "ws-test") .with_status(201) .with_body( - r#"{"id":"conn_new","name":"mydb","source_type":"managed","tables_discovered":1,"discovery_status":"skipped"}"#, + r#"{"id":"db_new","description":"mydb","default_connection_id":"conn_abc"}"#, ) .match_body(mockito::Matcher::JsonString( - serde_json::to_string(&create_connection_request( - "mydb", + serde_json::to_string(&create_database_request( + Some("mydb"), "public", &["gdp".to_string()], )) @@ -784,34 +861,36 @@ mod tests { .create(); let api = ApiClient::test_new(&server.url(), "k", Some("ws-test")); - let body = create_connection_request("mydb", "public", &["gdp".to_string()]); - let (status, resp_body) = api.post_raw("/connections", &body); + let body = create_database_request(Some("mydb"), "public", &["gdp".to_string()]); + let (status, resp_body) = api.post_raw("/databases", &body); assert_eq!(status.as_u16(), 201); - let parsed: CreateConnectionResponse = serde_json::from_str(&resp_body).unwrap(); - assert_eq!(parsed.name, "mydb"); - assert_eq!(parsed.source_type, "managed"); + let parsed: CreateDatabaseResponse = serde_json::from_str(&resp_body).unwrap(); + assert_eq!(parsed.description.as_deref(), Some("mydb")); + assert_eq!(parsed.default_connection_id, "conn_abc"); mock.assert(); } #[test] - fn tables_load_posts_replace_with_upload_id() { + fn tables_load_uses_default_connection_id() { let mut server = mockito::Server::new(); - let list = server - .mock("GET", "/connections") + // resolve_database resolves by id directly + let resolve = server + .mock("GET", "/databases/db_1") .with_status(200) - .with_body( - r#"{"connections":[{"id":"conn1","name":"sales","source_type":"managed"}]}"#, - ) + .with_body(full_detail("db_1", "sales", "conn_default")) .create(); let load = server - .mock("POST", "/connections/conn1/schemas/public/tables/orders/loads") + .mock( + "POST", + "/connections/conn_default/schemas/public/tables/orders/loads", + ) .match_body(mockito::Matcher::JsonString( serde_json::to_string(&load_table_request("upl_123")).unwrap(), )) .with_status(200) .with_body( r#"{ - "connection_id":"conn1", + "connection_id":"conn_default", "schema_name":"public", "table_name":"orders", "row_count":42, @@ -821,40 +900,41 @@ mod tests { .create(); let api = ApiClient::test_new(&server.url(), "k", Some("ws1")); - let db = resolve_database(&api, "sales"); - let path = managed_table_load_path(&db.id, "public", "orders"); + let db = resolve_database(&api, "db_1"); + let path = managed_table_load_path(&db.default_connection_id, "public", "orders"); let body = load_table_request("upl_123"); let (status, resp_body) = api.post_raw(&path, &body); assert!(status.is_success()); let parsed: LoadManagedTableResponse = serde_json::from_str(&resp_body).unwrap(); assert_eq!(parsed.row_count, 42); assert_eq!(parsed.table_name, "orders"); - list.assert(); + resolve.assert(); load.assert(); } #[test] - fn tables_delete_calls_managed_table_endpoint() { + fn tables_delete_uses_default_connection_id() { let mut server = mockito::Server::new(); - let list = server - .mock("GET", "/connections") + let resolve = server + .mock("GET", "/databases/db_1") .with_status(200) - .with_body( - r#"{"connections":[{"id":"conn1","name":"sales","source_type":"managed"}]}"#, - ) + .with_body(full_detail("db_1", "sales", "conn_default")) .create(); let delete = server - .mock("DELETE", "/connections/conn1/schemas/public/tables/orders") + .mock( + "DELETE", + "/connections/conn_default/schemas/public/tables/orders", + ) .with_status(204) .with_body("") .create(); let api = ApiClient::test_new(&server.url(), "k", None); - let db = resolve_database(&api, "sales"); - let path = managed_table_delete_path(&db.id, "public", "orders"); + let db = resolve_database(&api, "db_1"); + let path = managed_table_delete_path(&db.default_connection_id, "public", "orders"); let (status, _) = api.delete_raw(&path); assert_eq!(status.as_u16(), 204); - list.assert(); + resolve.assert(); delete.assert(); } diff --git a/src/main.rs b/src/main.rs index 48f4569..4fb4a38 100644 --- a/src/main.rs +++ b/src/main.rs @@ -56,6 +56,10 @@ struct Cli { command: Option, } +/// Set once after workspace resolution so the database footer can reference it +/// without re-doing config I/O. +static ACTIVE_WORKSPACE_ID: std::sync::OnceLock = std::sync::OnceLock::new(); + fn resolve_workspace(provided: Option) -> String { // HOTDATA_WORKSPACE env var takes priority and blocks --workspace-id flag if let Ok(ws) = std::env::var("HOTDATA_WORKSPACE") { @@ -67,6 +71,7 @@ fn resolve_workspace(provided: Option) -> String { ); std::process::exit(1); } + let _ = ACTIVE_WORKSPACE_ID.set(ws.clone()); return ws; } if sandbox::find_sandbox_run_ancestor().is_some() { @@ -75,7 +80,10 @@ fn resolve_workspace(provided: Option) -> String { } match config::load("default") { Ok(profile) => match config::resolve_workspace_id(provided, &profile) { - Ok(id) => id, + Ok(id) => { + let _ = ACTIVE_WORKSPACE_ID.set(id.clone()); + id + } Err(e) => { eprintln!("error: {e}"); std::process::exit(1); @@ -125,11 +133,25 @@ extern "C" fn print_sandbox_footer() { ); } +extern "C" fn print_database_footer() { + use crossterm::style::Stylize; + if let Some(ws_id) = ACTIVE_WORKSPACE_ID.get() { + if let Some(id) = config::load_current_database("default", ws_id) { + eprintln!( + "{}", + format!("current database: {id} use 'hotdata databases set' to change") + .dark_grey(), + ); + } + } +} + fn main() { // Register before `Cli::parse`, since `--help` / `--version` exit // from inside the parser. Safety: `atexit` is async-signal-safe; // the callback only reads env vars / files and writes to stderr. unsafe { atexit(print_sandbox_footer) }; + unsafe { atexit(print_database_footer) }; dotenvy::dotenv().ok(); let cli = Cli::parse(); @@ -382,17 +404,20 @@ fn main() { databases::list(&workspace_id, &output) } Some(DatabasesCommands::Create { - name, + description, schema, tables, output, }) => databases::create( &workspace_id, - &name, + description.as_deref(), &schema, &tables, &output, ), + Some(DatabasesCommands::Set { id_or_description }) => { + databases::set(&workspace_id, &id_or_description) + } Some(DatabasesCommands::Delete { name_or_id }) => { databases::delete(&workspace_id, &name_or_id) } @@ -405,7 +430,7 @@ fn main() { let (database, schema, table) = parse_db_target(&target); databases::tables_load( &workspace_id, - &database, + Some(database.as_str()), &table, Some(schema.as_str()), file.as_deref(), @@ -420,7 +445,7 @@ fn main() { output, } => databases::tables_list( &workspace_id, - &database, + database.as_deref(), schema.as_deref(), &output, ), @@ -433,7 +458,7 @@ fn main() { upload_id, } => databases::tables_load( &workspace_id, - &database, + database.as_deref(), &table, Some(schema.as_str()), file.as_deref(), @@ -446,7 +471,7 @@ fn main() { schema, } => databases::tables_delete( &workspace_id, - &database, + database.as_deref(), &table, Some(schema.as_str()), ), diff --git a/tests/databases_cli.rs b/tests/databases_cli.rs index d9cd4ba..a479d54 100644 --- a/tests/databases_cli.rs +++ b/tests/databases_cli.rs @@ -28,7 +28,7 @@ fn databases_create_help_documents_table_flag() { assert!(output.status.success()); let help = String::from_utf8_lossy(&output.stdout); assert!(help.contains("--table")); - assert!(help.contains("--name")); + assert!(help.contains("--description")); } #[test] @@ -45,17 +45,6 @@ fn databases_tables_load_help_documents_file_and_upload_id() { assert!(help.contains("parquet")); } -#[test] -fn databases_create_requires_name() { - let output = hotdata().args(["databases", "create"]).output().unwrap(); - assert!(!output.status.success()); - let stderr = String::from_utf8_lossy(&output.stderr); - assert!( - stderr.contains("--name") || stderr.contains("required"), - "stderr: {stderr}" - ); -} - #[test] fn databases_tables_load_rejects_both_file_and_upload_id_at_parse_time() { let output = hotdata() @@ -63,7 +52,6 @@ fn databases_tables_load_rejects_both_file_and_upload_id_at_parse_time() { "databases", "tables", "load", - "mydb", "t1", "--file", "a.parquet",