Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changing DEFAULT_SCOPE to DEFAULT_RESOURCE #7

Merged
merged 6 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use azeventhubs::producer::{
EventHubProducerClient, EventHubProducerClientOptions, SendEventOptions,
};
use azure_identity::DefaultAzureCredential;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let namespace = std::env::var("EVENT_HUBS_NAMESPACE")?;
let fqn = format!("{}.servicebus.windows.net", namespace);
let event_hub_name = std::env::var("EVENT_HUB_NAME")?;
let options = EventHubProducerClientOptions::default();
let default_credential = DefaultAzureCredential::default();

let mut producer_client = EventHubProducerClient::from_namespace_and_credential(
fqn,
event_hub_name,
default_credential,
options,
)
.await?;

let event = "test connect using azure identity";
let options = SendEventOptions::new().with_partition_id("0");
producer_client.send_event(event, options).await?;

producer_client.close().await?;

Ok(())
}
1 change: 1 addition & 0 deletions sdk/messaging_eventhubs/src/amqp/amqp_cbs_link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ impl AmqpCbsLinkHandle {
.send(command)
.await
.map_err(|_| AmqpCbsEventLoopStopped {})?;

result.await.map_err(|_| AmqpCbsEventLoopStopped {})
}

Expand Down
4 changes: 2 additions & 2 deletions sdk/messaging_eventhubs/src/amqp/amqp_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ impl TransportClient for AmqpClient {
let access_token = self
.connection_scope
.credential
.get_token_using_default_scope()
.get_token_using_default_resource()
.await?;
let token_value = access_token.token.secret();
loop {
Expand Down Expand Up @@ -195,7 +195,7 @@ impl TransportClient for AmqpClient {
let access_token = self
.connection_scope
.credential
.get_token_using_default_scope()
.get_token_using_default_resource()
.await?;
let token_value = access_token.token.secret();
loop {
Expand Down
7 changes: 2 additions & 5 deletions sdk/messaging_eventhubs/src/amqp/amqp_phantom_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,17 +321,14 @@ mod tests {
let message_iter = std::iter::once(event.amqp_message);

let batch = build_amqp_batch_from_messages(message_iter.clone(), None).unwrap();
let serialized_value = serialized_value_of_sendable(batch.sendable);
println!("serialized_value: {:?}", serialized_value);
let _serialized_value = serialized_value_of_sendable(batch.sendable);

let batch = build_amqp_batch_from_messages(message_iter.clone(), None).unwrap();
let serialized_bytes = serialized_bytes_of_sendable(batch.sendable);
println!("serialized_bytes: {:?}", serialized_bytes);
let _serialized_bytes = serialized_bytes_of_sendable(batch.sendable);

let batch = build_amqp_batch_from_messages(message_iter, None).unwrap();
let (phantom_size, ssize) =
phantom_size_and_serialized_size_of_sendable_envelope(batch.sendable);
println!("serialized_size: {}", ssize);
assert_eq!(phantom_size, ssize)
}

Expand Down
154 changes: 35 additions & 119 deletions sdk/messaging_eventhubs/src/amqp/cbs_token_provider.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use azure_core::auth::TokenResponse;
use fe2o3_amqp_cbs::{token::CbsToken, AsyncCbsTokenProvider};
use fe2o3_amqp_types::primitives::Timestamp;
use futures_util::{pin_mut, ready};
use std::{future::Future, sync::Arc, task::Poll};
use std::{future::Future, sync::Arc, pin::Pin};
use time::Duration as TimeSpan;
use tokio::sync::Semaphore;

use crate::authorization::event_hub_token_credential::EventHubTokenCredential;

Expand All @@ -31,9 +29,6 @@ impl CbsTokenProvider {
} else {
TokenType::JsonWebToken {
credential,
// Tokens are only cached for JWT-based credentials; no need
// to instantiate the semaphore if no caching is taking place.
semaphore: Semaphore::new(1),
cached_token: None,
}
};
Expand All @@ -49,74 +44,8 @@ fn is_nearing_expiration(token: &TokenResponse, token_expiration_buffer: TimeSpa
token.expires_on - token_expiration_buffer <= crate::util::time::now_utc()
}

pub struct CbsTokenFut<'a> {
provider: &'a mut CbsTokenProvider,
}

impl<'a> Future for CbsTokenFut<'a> {
type Output = Result<CbsToken<'a>, azure_core::error::Error>;

fn poll(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Self::Output> {
let expiration_buffer = self.provider.token_expiration_buffer;
let entity_type = self.provider.token_type.entity_type().to_string(); // TODO: reduce clone/to_string
let result = match &mut self.provider.token_type {
TokenType::SharedAccessToken { credential } => {
let fut = credential.get_token_using_default_scope();
pin_mut!(fut);
ready!(fut.poll(cx))
}
TokenType::JsonWebToken {
credential,
semaphore,
cached_token,
} => match cached_token {
Some(cached) => {
let fut = semaphore.acquire();
pin_mut!(fut);
let _permit = ready!(fut.poll(cx)).map_err(|e| {
azure_core::error::Error::new(azure_core::error::ErrorKind::Credential, e)
})?;
if is_nearing_expiration(cached, expiration_buffer) {
let fut = credential.get_token_using_default_scope();
pin_mut!(fut);
let token = ready!(fut.poll(cx))?;
*cached = token;
}
Ok(cached.clone())
}
None => {
let fut = semaphore.acquire();
pin_mut!(fut);
let _permit = ready!(fut.poll(cx)).map_err(|e| {
azure_core::error::Error::new(azure_core::error::ErrorKind::Credential, e)
})?;

// GetTokenUsingDefaultScopeAsync
let fut = credential.get_token_using_default_scope();
pin_mut!(fut);
let token = ready!(fut.poll(cx))?;
*cached_token = Some(token.clone());
Ok(token)
}
},
};

match result {
Ok(token) => Poll::Ready(Ok(CbsToken::new(
token.token.secret().to_owned(),
entity_type,
Some(Timestamp::from(token.expires_on)),
))),
Err(err) => Poll::Ready(Err(err)),
}
}
}

impl AsyncCbsTokenProvider for CbsTokenProvider {
type Fut<'a> = CbsTokenFut<'a>;
type Fut<'a> = Pin<Box<dyn Future<Output = Result<CbsToken<'a>, azure_core::error::Error>> + Send + 'a>>;
type Error = azure_core::error::Error;

fn get_token_async(
Expand All @@ -125,7 +54,39 @@ impl AsyncCbsTokenProvider for CbsTokenProvider {
_resource_id: impl AsRef<str>,
_claims: impl IntoIterator<Item = impl AsRef<str>>,
) -> Self::Fut<'_> {
CbsTokenFut { provider: self }
// CbsTokenFut { provider: self }
Box::pin(async {
let expiration_buffer = self.token_expiration_buffer;
let entity_type = self.token_type.entity_type().to_string();

let result = match &mut self.token_type {
TokenType::SharedAccessToken { credential } => {
let token = credential.get_token_using_default_resource().await?;
Ok(token)
},
TokenType::JsonWebToken { credential, cached_token } => {
match cached_token {
Some(cached) => {
if is_nearing_expiration(cached, expiration_buffer) {
let token = credential.get_token_using_default_resource().await?;
*cached = token.clone();
}
Ok(cached.clone())
},
None => {
let token = credential.get_token_using_default_resource().await?;
*cached_token = Some(token.clone());
Ok(token)
}
}
}
};
result.map(|token| CbsToken::new(
token.token.secret().to_owned(),
entity_type,
Some(Timestamp::from(token.expires_on)),
))
})
}
}

Expand Down Expand Up @@ -266,50 +227,5 @@ mod tests {
second_token.expires_at_utc().clone()
);
}

// // TODO: This cannot be mock tested right now
// #[tokio::test]
// async fn get_token_does_not_cache_shared_access_credential() {
// // var value = "TOkEn!";
// // var signature = new SharedAccessSignature("hub-name", "keyName", "key", value, DateTimeOffset.UtcNow.AddHours(4));
// let signature = SharedAccessSignature::try_from_parts(
// "sb-name",
// "keyName",
// "key",
// Some(std::time::Duration::from_secs(4 * 60 * 60)),
// ).unwrap();
// }

// // TODO: This requires dispatching token provider into tasks, so a mutex is required
// #[tokio::test]
// async fn get_token_synchronizes_multiple_refresh_attempts_for_jwt_tokens() {
// let token_value = "ValuE_oF_tHE_tokEn";
// let buffer = TimeSpan::minutes(5);
// let expires_on: OffsetDateTime =
// crate::util::time::now_utc() - buffer + TimeSpan::seconds(-10);
// let mut mock_credential = MockTokenCredential::new();

// let mut seq = Sequence::new();
// mock_credential
// .expect_get_token()
// .times(1)
// .in_sequence(&mut seq)
// .returning(move |_resource| {
// Ok(TokenResponse {
// token: AccessToken::new(token_value),
// expires_on: expires_on,
// })
// });
// mock_credential
// .expect_get_token()
// .times(1)
// .in_sequence(&mut seq)
// .returning(move |_resource| {
// Ok(TokenResponse {
// token: AccessToken::new(token_value),
// expires_on: crate::util::time::now_utc() + TimeSpan::days(1),
// })
// });
// }
}
}
5 changes: 0 additions & 5 deletions sdk/messaging_eventhubs/src/amqp/token_type.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use azure_core::auth::TokenResponse;
use std::sync::Arc;
use tokio::sync::Semaphore;

use crate::{
authorization::event_hub_token_credential::EventHubTokenCredential,
Expand All @@ -17,10 +16,6 @@ pub(crate) enum TokenType {
JsonWebToken {
credential: Arc<EventHubTokenCredential>,

/// Tokens are only cached for JWT-based credentials; no need
/// to instantiate the semaphore if no caching is taking place.
semaphore: Semaphore,

/// The JWT-based token that is currently cached for authorization.
cached_token: Option<TokenResponse>,
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ impl EventHubTokenCredential {
}

impl EventHubTokenCredential {
pub(crate) const DEFAULT_SCOPE: &str = "https://eventhubs.azure.net/.default";
// pub(crate) const DEFAULT_SCOPE: &str = "https://eventhubs.azure.net/.default";

// `azure_identity` appends "/.default" to the resource internally.
pub(crate) const DEFAULT_RESOURCE: &str = "https://eventhubs.azure.net/";

/// Gets a `TokenResponse` for the specified resource
pub(crate) async fn get_token(&self, resource: &str) -> azure_core::Result<TokenResponse> {
Expand All @@ -77,8 +80,12 @@ impl EventHubTokenCredential {
}
}

pub(crate) async fn get_token_using_default_scope(&self) -> azure_core::Result<TokenResponse> {
self.get_token(Self::DEFAULT_SCOPE).await
// pub(crate) async fn get_token_using_default_scope(&self) -> azure_core::Result<TokenResponse> {
// self.get_token(Self::DEFAULT_SCOPE).await
// }

pub(crate) async fn get_token_using_default_resource(&self) -> azure_core::Result<TokenResponse> {
self.get_token(Self::DEFAULT_RESOURCE).await
}
}

Expand Down Expand Up @@ -133,7 +140,12 @@ cfg_not_wasm32! {
use azure_identity::DefaultAzureCredential;

let default_credential = DefaultAzureCredential::default();
let _event_hub_token_credential = EventHubTokenCredential::from(default_credential);
let event_hub_token_credential = EventHubTokenCredential::from(default_credential);
let token = event_hub_token_credential
.get_token_using_default_resource()
.await
.unwrap();
assert!(!token.token.secret().is_empty())
}
}
}
23 changes: 22 additions & 1 deletion sdk/messaging_eventhubs/tests/event_hub_connection_live_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ cfg_not_wasm32! {
#[tokio::test]
async fn connection_can_connect_with_named_key_credential() {
common::setup_dotenv();
use messaging_eventhubs::authorization::{
use azeventhubs::authorization::{
SharedAccessCredential, AzureNamedKeyCredential,
build_connection_signature_authorization_resource,
};
Expand All @@ -79,4 +79,25 @@ cfg_not_wasm32! {
).await.unwrap();
connection.close().await.unwrap();
}

#[tokio::test]
async fn connection_can_connect_with_azure_identity_credential() {
common::setup_dotenv();

use azure_identity::DefaultAzureCredential;

let namespace = std::env::var("EVENT_HUBS_NAMESPACE").unwrap();
let fqn = format!("{}.servicebus.windows.net", namespace);
let event_hub_name = std::env::var("EVENT_HUB_NAME").unwrap();

let options = EventHubConnectionOptions::default();
let credential = DefaultAzureCredential::default();

let connection = EventHubConnection::from_namespace_and_credential(
fqn,
event_hub_name,
credential,
options,
).await.unwrap();
}
}
Loading
Loading