Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion deltachat-ffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3753,7 +3753,11 @@ pub unsafe extern "C" fn dc_provider_new_from_email(

match socks5_enabled {
Ok(socks5_enabled) => {
match block_on(provider::get_provider_info(addr.as_str(), socks5_enabled)) {
match block_on(provider::get_provider_info(
ctx,
addr.as_str(),
socks5_enabled,
)) {
Some(provider) => provider,
None => ptr::null_mut(),
}
Expand Down
2 changes: 1 addition & 1 deletion examples/repl/cmdline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1185,7 +1185,7 @@ pub async fn cmdline(context: Context, line: &str, chat_id: &mut ChatId) -> Resu
let socks5_enabled = context
.get_config_bool(config::Config::Socks5Enabled)
.await?;
match provider::get_provider_info(arg1, socks5_enabled).await {
match provider::get_provider_info(&context, arg1, socks5_enabled).await {
Some(info) => {
println!("Information for provider belonging to {}:", arg1);
println!("status: {}", info.status as u32);
Expand Down
4 changes: 3 additions & 1 deletion src/configure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,9 @@ async fn configure(ctx: &Context, param: &mut LoginParam) -> Result<()> {
"checking internal provider-info for offline autoconfig"
);

if let Some(provider) = provider::get_provider_info(&param_domain, socks5_enabled).await {
if let Some(provider) =
provider::get_provider_info(ctx, &param_domain, socks5_enabled).await
{
param.provider = Some(provider);
match provider.status {
provider::Status::Ok | provider::Status::Preparation => {
Expand Down
27 changes: 15 additions & 12 deletions src/oauth2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pub async fn dc_get_oauth2_url(
redirect_uri: &str,
) -> Result<Option<String>> {
let socks5_enabled = context.get_config_bool(Config::Socks5Enabled).await?;
if let Some(oauth2) = Oauth2::from_address(addr, socks5_enabled).await {
if let Some(oauth2) = Oauth2::from_address(context, addr, socks5_enabled).await {
context
.sql
.set_raw_config("oauth2_pending_redirect_uri", Some(redirect_uri))
Expand All @@ -79,7 +79,7 @@ pub async fn dc_get_oauth2_access_token(
regenerate: bool,
) -> Result<Option<String>> {
let socks5_enabled = context.get_config_bool(Config::Socks5Enabled).await?;
if let Some(oauth2) = Oauth2::from_address(addr, socks5_enabled).await {
if let Some(oauth2) = Oauth2::from_address(context, addr, socks5_enabled).await {
let lock = context.oauth2_mutex.lock().await;

// read generated token
Expand Down Expand Up @@ -225,7 +225,7 @@ pub async fn dc_get_oauth2_addr(
code: &str,
) -> Result<Option<String>> {
let socks5_enabled = context.get_config_bool(Config::Socks5Enabled).await?;
let oauth2 = match Oauth2::from_address(addr, socks5_enabled).await {
let oauth2 = match Oauth2::from_address(context, addr, socks5_enabled).await {
Some(o) => o,
None => return Ok(None),
};
Expand Down Expand Up @@ -253,13 +253,13 @@ pub async fn dc_get_oauth2_addr(
}

impl Oauth2 {
async fn from_address(addr: &str, skip_mx: bool) -> Option<Self> {
async fn from_address(context: &Context, addr: &str, skip_mx: bool) -> Option<Self> {
let addr_normalized = normalize_addr(addr);
if let Some(domain) = addr_normalized
.find('@')
.map(|index| addr_normalized.split_at(index + 1).1)
{
if let Some(oauth2_authorizer) = provider::get_provider_info(domain, skip_mx)
if let Some(oauth2_authorizer) = provider::get_provider_info(context, domain, skip_mx)
.await
.and_then(|provider| provider.oauth2_authorizer.as_ref())
{
Expand Down Expand Up @@ -356,30 +356,33 @@ mod tests {

#[async_std::test]
async fn test_oauth_from_address() {
let t = TestContext::new().await;
assert_eq!(
Oauth2::from_address("hello@gmail.com", false).await,
Oauth2::from_address(&t, "hello@gmail.com", false).await,
Some(OAUTH2_GMAIL)
);
assert_eq!(
Oauth2::from_address("hello@googlemail.com", false).await,
Oauth2::from_address(&t, "hello@googlemail.com", false).await,
Some(OAUTH2_GMAIL)
);
assert_eq!(
Oauth2::from_address("hello@yandex.com", false).await,
Oauth2::from_address(&t, "hello@yandex.com", false).await,
Some(OAUTH2_YANDEX)
);
assert_eq!(
Oauth2::from_address("hello@yandex.ru", false).await,
Oauth2::from_address(&t, "hello@yandex.ru", false).await,
Some(OAUTH2_YANDEX)
);

assert_eq!(Oauth2::from_address("hello@web.de", false).await, None);
assert_eq!(Oauth2::from_address(&t, "hello@web.de", false).await, None);
}

#[async_std::test]
async fn test_oauth_from_mx() {
// TODO: this does not test MX lookup, google.com is in our provider-db
// does anyone know a "good" Google Workspace (former G Suite) domain we can use for testing?
let t = TestContext::new().await;
assert_eq!(
Oauth2::from_address("hello@google.com", false).await,
Oauth2::from_address(&t, "hello@google.com", false).await,
Some(OAUTH2_GMAIL)
);
}
Expand Down
50 changes: 42 additions & 8 deletions src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
mod data;

use crate::config::Config;
use crate::context::Context;
use crate::provider::data::{PROVIDER_DATA, PROVIDER_IDS, PROVIDER_UPDATED};
use async_std_resolver::resolver_from_system_conf;
use async_std_resolver::{
config, resolver, resolver_from_system_conf, AsyncStdResolver, ResolveError,
};
use chrono::{NaiveDateTime, NaiveTime};

#[derive(Debug, Display, Copy, Clone, PartialEq, FromPrimitive, ToPrimitive)]
Expand Down Expand Up @@ -81,6 +84,22 @@ pub struct Provider {
pub oauth2_authorizer: Option<Oauth2Authorizer>,
}

/// Get resolver to query MX records.
///
/// We first try resolver_from_system_conf() which reads the system's resolver from `/etc/resolv.conf`.
/// This does not work at least on some Androids, therefore we use use ResolverConfig::default()
/// which default eg. to google's 8.8.8.8 or 8.8.4.4 as a fallback.
async fn get_resolver() -> Result<AsyncStdResolver, ResolveError> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's ok to simply import anyhow::Result at the top of the module rather than the test and return it here, so anyhow::Result is used as much as possible, preferrably across the whole crate.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i also first tried that, getting "expected struct anyhow::Error, found struct ResolveError" - so i assumed i have to convert the object somehow and did not want to dive into that yesterday :)

Copy link
Collaborator

@link2xt link2xt Nov 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something like

let resolver = ...?;
Ok(resolver)

? will convert the Result automatically.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's clever! i did not thought too much about these conversions in the past :)

i did a successor pr that uses map_err() at #2853 - but you would prefer the one above?

if let Ok(resolver) = resolver_from_system_conf().await {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Log the error here?

return Ok(resolver);
}
resolver(
config::ResolverConfig::default(),
config::ResolverOpts::default(),
)
.await
}

/// Returns provider for the given domain.
///
/// This function looks up domain in offline database first. If not
Expand All @@ -89,15 +108,19 @@ pub struct Provider {
///
/// For compatibility, email address can be passed to this function
/// instead of the domain.
pub async fn get_provider_info(domain: &str, skip_mx: bool) -> Option<&'static Provider> {
pub async fn get_provider_info(
context: &Context,
domain: &str,
skip_mx: bool,
) -> Option<&'static Provider> {
let domain = domain.rsplitn(2, '@').next()?;

if let Some(provider) = get_provider_by_domain(domain) {
return Some(provider);
}

if !skip_mx {
if let Some(provider) = get_provider_by_mx(domain).await {
if let Some(provider) = get_provider_by_mx(context, domain).await {
return Some(provider);
}
}
Expand All @@ -117,8 +140,8 @@ pub fn get_provider_by_domain(domain: &str) -> Option<&'static Provider> {
/// Finds a provider based on MX record for the given domain.
///
/// For security reasons, only Gmail can be configured this way.
pub async fn get_provider_by_mx(domain: &str) -> Option<&'static Provider> {
if let Ok(resolver) = resolver_from_system_conf().await {
pub async fn get_provider_by_mx(context: &Context, domain: &str) -> Option<&'static Provider> {
if let Ok(resolver) = get_resolver().await {
let mut fqdn: String = domain.to_string();
if !fqdn.ends_with('.') {
fqdn.push('.');
Expand All @@ -143,6 +166,8 @@ pub async fn get_provider_by_mx(domain: &str) -> Option<&'static Provider> {
}
}
}
} else {
warn!(context, "cannot get a resolver to check MX records.");
}

None
Expand All @@ -169,6 +194,8 @@ mod tests {

use super::*;
use crate::dc_tools::time;
use crate::test_utils::TestContext;
use anyhow::Result;
use chrono::NaiveDate;

#[test]
Expand Down Expand Up @@ -218,12 +245,13 @@ mod tests {

#[async_std::test]
async fn test_get_provider_info() {
assert!(get_provider_info("", false).await.is_none());
assert!(get_provider_info("google.com", false).await.unwrap().id == "gmail");
let t = TestContext::new().await;
assert!(get_provider_info(&t, "", false).await.is_none());
assert!(get_provider_info(&t, "google.com", false).await.unwrap().id == "gmail");

// get_provider_info() accepts email addresses for backwards compatibility
assert!(
get_provider_info("example@google.com", false)
get_provider_info(&t, "example@google.com", false)
.await
.unwrap()
.id
Expand All @@ -242,4 +270,10 @@ mod tests {
assert!(get_provider_update_timestamp() <= time());
assert!(get_provider_update_timestamp() > timestamp_past);
}

#[async_std::test]
async fn test_get_resolver() -> Result<()> {
assert!(get_resolver().await.is_ok());
Ok(())
}
}