Skip to content

Commit

Permalink
RUST-1048 Add method for retrieving the URI's default database from `…
Browse files Browse the repository at this point in the history
…Client` (#488)
  • Loading branch information
WindSoilder committed Oct 28, 2021
1 parent cdd7583 commit 578954b
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 1 deletion.
12 changes: 12 additions & 0 deletions src/client/mod.rs
Expand Up @@ -168,6 +168,18 @@ impl Client {
Database::new(self.clone(), name, Some(options))
}

/// Gets a handle to the default database specified in the `ClientOptions` or MongoDB connection
/// string used to construct this `Client`.
///
/// If no default database was specified, `None` will be returned.
pub fn default_database(&self) -> Option<Database> {
self.inner
.options
.default_database
.as_ref()
.map(|db_name| self.database(db_name))
}

async fn list_databases_common(
&self,
filter: impl Into<Option<Document>>,
Expand Down
16 changes: 15 additions & 1 deletion src/client/options/mod.rs
Expand Up @@ -517,6 +517,12 @@ pub struct ClientOptions {
#[builder(default)]
pub server_selection_timeout: Option<Duration>,

/// Default database for this client.
///
/// By default, no default database is specified.
#[builder(default)]
pub default_database: Option<String>,

#[builder(default, setter(skip))]
pub(crate) socket_timeout: Option<Duration>,

Expand Down Expand Up @@ -702,6 +708,7 @@ struct ClientOptionsParser {
pub zlib_compression: Option<i32>,
pub direct_connection: Option<bool>,
pub credential: Option<Credential>,
pub default_database: Option<String>,
max_staleness: Option<Duration>,
tls_insecure: Option<bool>,
auth_mechanism: Option<AuthMechanism>,
Expand Down Expand Up @@ -931,6 +938,7 @@ impl From<ClientOptionsParser> for ClientOptions {
retry_writes: parser.retry_writes,
socket_timeout: parser.socket_timeout,
direct_connection: parser.direct_connection,
default_database: parser.default_database,
driver_info: None,
credential: parser.credential,
cmap_event_handler: None,
Expand Down Expand Up @@ -969,6 +977,9 @@ impl ClientOptions {
///
/// The format of a MongoDB connection string is described [here](https://docs.mongodb.com/manual/reference/connection-string/#connection-string-formats).
///
/// Note that [default_database](ClientOptions::default_database) will be set from
/// `/defaultauthdb` in connection string.
///
/// The following options are supported in the options query string:
///
/// * `appName`: maps to the `app_name` field
Expand Down Expand Up @@ -1468,7 +1479,7 @@ impl ClientOptionsParser {
credential.source = options
.auth_source
.clone()
.or(db)
.or(db.clone())
.or_else(|| Some("admin".into()));
} else if authentication_requested {
return Err(ErrorKind::InvalidArgument {
Expand All @@ -1481,6 +1492,9 @@ impl ClientOptionsParser {
}
};

// set default database.
options.default_database = db;

if options.tls.is_none() && options.srv {
options.tls = Some(Tls::Enabled(Default::default()));
}
Expand Down
38 changes: 38 additions & 0 deletions src/client/options/test.rs
Expand Up @@ -232,3 +232,41 @@ async fn parse_unknown_options() {
.await;
parse_uri("maxstalenessms", Some("maxstalenessseconds")).await;
}

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn parse_with_default_database() {
let uri = "mongodb://localhost/abc";

assert_eq!(
ClientOptions::parse(uri).await.unwrap(),
ClientOptions {
hosts: vec![ServerAddress::Tcp {
host: "localhost".to_string(),
port: None
}],
original_uri: Some(uri.into()),
default_database: Some("abc".to_string()),
..Default::default()
}
);
}

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn parse_with_no_default_database() {
let uri = "mongodb://localhost/";

assert_eq!(
ClientOptions::parse(uri).await.unwrap(),
ClientOptions {
hosts: vec![ServerAddress::Tcp {
host: "localhost".to_string(),
port: None
}],
original_uri: Some(uri.into()),
default_database: None,
..Default::default()
}
);
}
8 changes: 8 additions & 0 deletions src/sync/client/mod.rs
Expand Up @@ -119,6 +119,14 @@ impl Client {
Database::new(self.async_client.database_with_options(name, options))
}

/// Gets a handle to the default database specified in the `ClientOptions` or MongoDB connection
/// string used to construct this `Client`.
///
/// If no default database was specified, `None` will be returned.
pub fn default_database(&self) -> Option<Database> {
self.async_client.default_database().map(Database::new)
}

/// Gets information about each database present in the cluster the Client is connected to.
pub fn list_databases(
&self,
Expand Down
30 changes: 30 additions & 0 deletions src/sync/test.rs
Expand Up @@ -75,6 +75,36 @@ fn client() {
assert!(db_names.contains(&function_name!().to_string()));
}

#[test]
#[function_name::named]
fn default_database() {
// here we just test default database name matched, the database interactive logic
// is tested in `database`.
let _guard: RwLockReadGuard<()> = RUNTIME.block_on(async { LOCK.run_concurrently().await });

let options = CLIENT_OPTIONS.clone();
let client = Client::with_options(options).expect("client creation should succeed");
let default_db = client.default_database();
assert!(default_db.is_none());

// create client througth options.
let mut options = CLIENT_OPTIONS.clone();
options.default_database = Some("abcd".to_string());
let client = Client::with_options(options).expect("client creation should succeed");
let default_db = client
.default_database()
.expect("should have a default database.");
assert_eq!(default_db.name(), "abcd");

// create client directly through uri_str.
let client = Client::with_uri_str("mongodb://localhost:27017/abcd")
.expect("client creation should succeed");
let default_db = client
.default_database()
.expect("should have a default database.");
assert_eq!(default_db.name(), "abcd");
}

#[test]
#[function_name::named]
fn database() {
Expand Down

0 comments on commit 578954b

Please sign in to comment.