diff --git a/sqlx-postgres/src/options/doc.md b/sqlx-postgres/src/options/doc.md index 33dd63b7a8..cb173dffba 100644 --- a/sqlx-postgres/src/options/doc.md +++ b/sqlx-postgres/src/options/doc.md @@ -60,6 +60,7 @@ postgresql://user@localhost postgresql://user:secret@localhost postgresql://user:correct%20horse%20battery%20staple@localhost postgresql://localhost?dbname=mydb&user=postgres&password=postgres +postgresql:///mydb ``` See also [Note: Unix Domain Sockets](#note-unix-domain-sockets) below. @@ -106,8 +107,9 @@ This behavior is _only_ implemented for the environment variables, not the URL p Note: passing the SSL private key via environment variable may be a security risk. # Note: Unix Domain Sockets -If you want to connect to Postgres over a Unix domain socket, you can pass the path -to the _directory_ containing the socket as the `host` parameter. +If you want to connect to Postgres over a Unix domain socket, you can either pass the path +to the _directory_ containing the socket as the `host` parameter, or leave the host part of the URL empty. +When the host part is omitted, the standard Unix socket directory is used. The final path to the socket will be `{host}/.s.PGSQL.{port}` as is standard for Postgres. diff --git a/sqlx-postgres/src/options/mod.rs b/sqlx-postgres/src/options/mod.rs index efbc43989b..17167185d3 100644 --- a/sqlx-postgres/src/options/mod.rs +++ b/sqlx-postgres/src/options/mod.rs @@ -62,7 +62,9 @@ impl PgConnectOptions { let host = var("PGHOSTADDR") .ok() .or_else(|| var("PGHOST").ok()) - .unwrap_or_else(|| default_host(port)); + .unwrap_or_else(|| "localhost".into()); + + let socket = default_socket(port); let username = var("PGUSER").ok().unwrap_or_else(whoami::username); @@ -71,7 +73,7 @@ impl PgConnectOptions { PgConnectOptions { port, host, - socket: None, + socket, username, password: var("PGPASSWORD").ok(), database, @@ -575,7 +577,7 @@ impl PgConnectOptions { } } -fn default_host(port: u16) -> String { +fn default_socket(port: u16) -> Option { // try to check for the existence of a unix socket and uses that let socket = format!(".s.PGSQL.{port}"); let candidates = [ @@ -586,12 +588,11 @@ fn default_host(port: u16) -> String { for candidate in &candidates { if Path::new(candidate).join(&socket).exists() { - return candidate.to_string(); + return Some(PathBuf::from(candidate)); } } - // fallback to localhost if no socket was found - "localhost".to_owned() + None } /// Writer that escapes passed-in PostgreSQL options. diff --git a/sqlx-postgres/src/options/parse.rs b/sqlx-postgres/src/options/parse.rs index e911305698..c1949aa4f2 100644 --- a/sqlx-postgres/src/options/parse.rs +++ b/sqlx-postgres/src/options/parse.rs @@ -273,6 +273,26 @@ fn it_parses_socket_correctly_percent_encoded() { assert_eq!(Some("/var/lib/postgres/".into()), opts.socket); } + +#[test] +fn it_uses_default_socket_if_host_is_skipped() { + let path_to_socket = format!("{}/.s.PGSQL.5432", default_socket_dir()); + let _ = std::fs::File::create(&path_to_socket) + .unwrap_or_else(|e| panic!("error while creating tmp socket file: {}", e)); + + let url = "postgresql:///database"; + let opts = PgConnectOptions::from_str(url).unwrap(); + + let _ = std::fs::remove_file(&path_to_socket) + .unwrap_or_else(|e| panic!("error while deleting tmp socket file: {}", e)); + + assert_eq!( + Some(std::path::PathBuf::from(default_socket_dir())), + opts.socket + ); + assert_eq!(Some("database".into()), opts.database); +} + #[test] fn it_parses_socket_correctly_with_username_percent_encoded() { let url = "postgres://some_user@%2Fvar%2Flib%2Fpostgres/database"; @@ -318,6 +338,48 @@ fn it_returns_the_parsed_url_when_socket() { assert_eq!(expected_url, opts.build_url()); } +#[test] +fn it_returns_the_parsed_url_with_default_socket_when_host_is_not_defined() { + let port = 5432; + let path_to_socket = + std::path::PathBuf::from(format!("{}/.s.PGSQL.{}", default_socket_dir(), port)); + + let _ = std::fs::File::create(&path_to_socket) + .unwrap_or_else(|e| panic!("error while creating tmp socket file: {}", e)); + + let url = "postgresql:///database"; + let opts = PgConnectOptions::from_str(url).unwrap(); + + let _ = std::fs::remove_file(&path_to_socket) + .unwrap_or_else(|e| panic!("error while deleting tmp socket file: {}", e)); + + let encoded_socket = utf8_percent_encode(&default_socket_dir(), NON_ALPHANUMERIC).to_string(); + let encoded_url = format!( + "postgres://{}@{}/database", + whoami::username(), + encoded_socket + ); + let mut expected_url = Url::parse(&encoded_url).unwrap(); + // PgConnectOptions defaults + let query_string = "sslmode=prefer&statement-cache-capacity=100"; + expected_url.set_query(Some(query_string)); + let _ = expected_url.set_port(Some(port)); + + assert_eq!(expected_url, opts.build_url()); +} + +#[cfg(test)] +fn default_socket_dir() -> String { + #[cfg(target_os = "macos")] + { + "/private/tmp".into() + } + #[cfg(target_os = "linux")] + { + "/tmp".into() + } +} + #[test] fn it_returns_the_parsed_url_when_host() { let url = "postgres://username:p@ssw0rd@hostname:5432/database";