Skip to content

Commit

Permalink
Merge pull request #7 from minghuaw/eventhubs_over_amqp
Browse files Browse the repository at this point in the history
Changing DEFAULT_SCOPE to DEFAULT_RESOURCE
  • Loading branch information
minghuaw committed Aug 31, 2023
2 parents 6d7268c + c891084 commit 1126bff
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 137 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use azeventhubs::producer::{
EventHubProducerClient, EventHubProducerClientOptions, SendEventOptions,
};
use azure_identity::DefaultAzureCredential;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let namespace = std::env::var("EVENT_HUBS_NAMESPACE")?;
let fqn = format!("{}.servicebus.windows.net", namespace);
let event_hub_name = std::env::var("EVENT_HUB_NAME")?;
let options = EventHubProducerClientOptions::default();
let default_credential = DefaultAzureCredential::default();

let mut producer_client = EventHubProducerClient::from_namespace_and_credential(
fqn,
event_hub_name,
default_credential,
options,
)
.await?;

let event = "test connect using azure identity";
let options = SendEventOptions::new().with_partition_id("0");
producer_client.send_event(event, options).await?;

producer_client.close().await?;

Ok(())
}
1 change: 1 addition & 0 deletions sdk/messaging_eventhubs/src/amqp/amqp_cbs_link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ impl AmqpCbsLinkHandle {
.send(command)
.await
.map_err(|_| AmqpCbsEventLoopStopped {})?;

result.await.map_err(|_| AmqpCbsEventLoopStopped {})
}

Expand Down
4 changes: 2 additions & 2 deletions sdk/messaging_eventhubs/src/amqp/amqp_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ impl TransportClient for AmqpClient {
let access_token = self
.connection_scope
.credential
.get_token_using_default_scope()
.get_token_using_default_resource()
.await?;
let token_value = access_token.token.secret();
loop {
Expand Down Expand Up @@ -195,7 +195,7 @@ impl TransportClient for AmqpClient {
let access_token = self
.connection_scope
.credential
.get_token_using_default_scope()
.get_token_using_default_resource()
.await?;
let token_value = access_token.token.secret();
loop {
Expand Down
7 changes: 2 additions & 5 deletions sdk/messaging_eventhubs/src/amqp/amqp_phantom_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,17 +321,14 @@ mod tests {
let message_iter = std::iter::once(event.amqp_message);

let batch = build_amqp_batch_from_messages(message_iter.clone(), None).unwrap();
let serialized_value = serialized_value_of_sendable(batch.sendable);
println!("serialized_value: {:?}", serialized_value);
let _serialized_value = serialized_value_of_sendable(batch.sendable);

let batch = build_amqp_batch_from_messages(message_iter.clone(), None).unwrap();
let serialized_bytes = serialized_bytes_of_sendable(batch.sendable);
println!("serialized_bytes: {:?}", serialized_bytes);
let _serialized_bytes = serialized_bytes_of_sendable(batch.sendable);

let batch = build_amqp_batch_from_messages(message_iter, None).unwrap();
let (phantom_size, ssize) =
phantom_size_and_serialized_size_of_sendable_envelope(batch.sendable);
println!("serialized_size: {}", ssize);
assert_eq!(phantom_size, ssize)
}

Expand Down
154 changes: 35 additions & 119 deletions sdk/messaging_eventhubs/src/amqp/cbs_token_provider.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use azure_core::auth::TokenResponse;
use fe2o3_amqp_cbs::{token::CbsToken, AsyncCbsTokenProvider};
use fe2o3_amqp_types::primitives::Timestamp;
use futures_util::{pin_mut, ready};
use std::{future::Future, sync::Arc, task::Poll};
use std::{future::Future, sync::Arc, pin::Pin};
use time::Duration as TimeSpan;
use tokio::sync::Semaphore;

use crate::authorization::event_hub_token_credential::EventHubTokenCredential;

Expand All @@ -31,9 +29,6 @@ impl CbsTokenProvider {
} else {
TokenType::JsonWebToken {
credential,
// Tokens are only cached for JWT-based credentials; no need
// to instantiate the semaphore if no caching is taking place.
semaphore: Semaphore::new(1),
cached_token: None,
}
};
Expand All @@ -49,74 +44,8 @@ fn is_nearing_expiration(token: &TokenResponse, token_expiration_buffer: TimeSpa
token.expires_on - token_expiration_buffer <= crate::util::time::now_utc()
}

pub struct CbsTokenFut<'a> {
provider: &'a mut CbsTokenProvider,
}

impl<'a> Future for CbsTokenFut<'a> {
type Output = Result<CbsToken<'a>, azure_core::error::Error>;

fn poll(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Self::Output> {
let expiration_buffer = self.provider.token_expiration_buffer;
let entity_type = self.provider.token_type.entity_type().to_string(); // TODO: reduce clone/to_string
let result = match &mut self.provider.token_type {
TokenType::SharedAccessToken { credential } => {
let fut = credential.get_token_using_default_scope();
pin_mut!(fut);
ready!(fut.poll(cx))
}
TokenType::JsonWebToken {
credential,
semaphore,
cached_token,
} => match cached_token {
Some(cached) => {
let fut = semaphore.acquire();
pin_mut!(fut);
let _permit = ready!(fut.poll(cx)).map_err(|e| {
azure_core::error::Error::new(azure_core::error::ErrorKind::Credential, e)
})?;
if is_nearing_expiration(cached, expiration_buffer) {
let fut = credential.get_token_using_default_scope();
pin_mut!(fut);
let token = ready!(fut.poll(cx))?;
*cached = token;
}
Ok(cached.clone())
}
None => {
let fut = semaphore.acquire();
pin_mut!(fut);
let _permit = ready!(fut.poll(cx)).map_err(|e| {
azure_core::error::Error::new(azure_core::error::ErrorKind::Credential, e)
})?;

// GetTokenUsingDefaultScopeAsync
let fut = credential.get_token_using_default_scope();
pin_mut!(fut);
let token = ready!(fut.poll(cx))?;
*cached_token = Some(token.clone());
Ok(token)
}
},
};

match result {
Ok(token) => Poll::Ready(Ok(CbsToken::new(
token.token.secret().to_owned(),
entity_type,
Some(Timestamp::from(token.expires_on)),
))),
Err(err) => Poll::Ready(Err(err)),
}
}
}

impl AsyncCbsTokenProvider for CbsTokenProvider {
type Fut<'a> = CbsTokenFut<'a>;
type Fut<'a> = Pin<Box<dyn Future<Output = Result<CbsToken<'a>, azure_core::error::Error>> + Send + 'a>>;
type Error = azure_core::error::Error;

fn get_token_async(
Expand All @@ -125,7 +54,39 @@ impl AsyncCbsTokenProvider for CbsTokenProvider {
_resource_id: impl AsRef<str>,
_claims: impl IntoIterator<Item = impl AsRef<str>>,
) -> Self::Fut<'_> {
CbsTokenFut { provider: self }
// CbsTokenFut { provider: self }
Box::pin(async {
let expiration_buffer = self.token_expiration_buffer;
let entity_type = self.token_type.entity_type().to_string();

let result = match &mut self.token_type {
TokenType::SharedAccessToken { credential } => {
let token = credential.get_token_using_default_resource().await?;
Ok(token)
},
TokenType::JsonWebToken { credential, cached_token } => {
match cached_token {
Some(cached) => {
if is_nearing_expiration(cached, expiration_buffer) {
let token = credential.get_token_using_default_resource().await?;
*cached = token.clone();
}
Ok(cached.clone())
},
None => {
let token = credential.get_token_using_default_resource().await?;
*cached_token = Some(token.clone());
Ok(token)
}
}
}
};
result.map(|token| CbsToken::new(
token.token.secret().to_owned(),
entity_type,
Some(Timestamp::from(token.expires_on)),
))
})
}
}

Expand Down Expand Up @@ -266,50 +227,5 @@ mod tests {
second_token.expires_at_utc().clone()
);
}

// // TODO: This cannot be mock tested right now
// #[tokio::test]
// async fn get_token_does_not_cache_shared_access_credential() {
// // var value = "TOkEn!";
// // var signature = new SharedAccessSignature("hub-name", "keyName", "key", value, DateTimeOffset.UtcNow.AddHours(4));
// let signature = SharedAccessSignature::try_from_parts(
// "sb-name",
// "keyName",
// "key",
// Some(std::time::Duration::from_secs(4 * 60 * 60)),
// ).unwrap();
// }

// // TODO: This requires dispatching token provider into tasks, so a mutex is required
// #[tokio::test]
// async fn get_token_synchronizes_multiple_refresh_attempts_for_jwt_tokens() {
// let token_value = "ValuE_oF_tHE_tokEn";
// let buffer = TimeSpan::minutes(5);
// let expires_on: OffsetDateTime =
// crate::util::time::now_utc() - buffer + TimeSpan::seconds(-10);
// let mut mock_credential = MockTokenCredential::new();

// let mut seq = Sequence::new();
// mock_credential
// .expect_get_token()
// .times(1)
// .in_sequence(&mut seq)
// .returning(move |_resource| {
// Ok(TokenResponse {
// token: AccessToken::new(token_value),
// expires_on: expires_on,
// })
// });
// mock_credential
// .expect_get_token()
// .times(1)
// .in_sequence(&mut seq)
// .returning(move |_resource| {
// Ok(TokenResponse {
// token: AccessToken::new(token_value),
// expires_on: crate::util::time::now_utc() + TimeSpan::days(1),
// })
// });
// }
}
}
5 changes: 0 additions & 5 deletions sdk/messaging_eventhubs/src/amqp/token_type.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use azure_core::auth::TokenResponse;
use std::sync::Arc;
use tokio::sync::Semaphore;

use crate::{
authorization::event_hub_token_credential::EventHubTokenCredential,
Expand All @@ -17,10 +16,6 @@ pub(crate) enum TokenType {
JsonWebToken {
credential: Arc<EventHubTokenCredential>,

/// Tokens are only cached for JWT-based credentials; no need
/// to instantiate the semaphore if no caching is taking place.
semaphore: Semaphore,

/// The JWT-based token that is currently cached for authorization.
cached_token: Option<TokenResponse>,
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ impl EventHubTokenCredential {
}

impl EventHubTokenCredential {
pub(crate) const DEFAULT_SCOPE: &str = "https://eventhubs.azure.net/.default";
// pub(crate) const DEFAULT_SCOPE: &str = "https://eventhubs.azure.net/.default";

// `azure_identity` appends "/.default" to the resource internally.
pub(crate) const DEFAULT_RESOURCE: &str = "https://eventhubs.azure.net/";

/// Gets a `TokenResponse` for the specified resource
pub(crate) async fn get_token(&self, resource: &str) -> azure_core::Result<TokenResponse> {
Expand All @@ -77,8 +80,12 @@ impl EventHubTokenCredential {
}
}

pub(crate) async fn get_token_using_default_scope(&self) -> azure_core::Result<TokenResponse> {
self.get_token(Self::DEFAULT_SCOPE).await
// pub(crate) async fn get_token_using_default_scope(&self) -> azure_core::Result<TokenResponse> {
// self.get_token(Self::DEFAULT_SCOPE).await
// }

pub(crate) async fn get_token_using_default_resource(&self) -> azure_core::Result<TokenResponse> {
self.get_token(Self::DEFAULT_RESOURCE).await
}
}

Expand Down Expand Up @@ -133,7 +140,12 @@ cfg_not_wasm32! {
use azure_identity::DefaultAzureCredential;

let default_credential = DefaultAzureCredential::default();
let _event_hub_token_credential = EventHubTokenCredential::from(default_credential);
let event_hub_token_credential = EventHubTokenCredential::from(default_credential);
let token = event_hub_token_credential
.get_token_using_default_resource()
.await
.unwrap();
assert!(!token.token.secret().is_empty())
}
}
}
23 changes: 22 additions & 1 deletion sdk/messaging_eventhubs/tests/event_hub_connection_live_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ cfg_not_wasm32! {
#[tokio::test]
async fn connection_can_connect_with_named_key_credential() {
common::setup_dotenv();
use messaging_eventhubs::authorization::{
use azeventhubs::authorization::{
SharedAccessCredential, AzureNamedKeyCredential,
build_connection_signature_authorization_resource,
};
Expand All @@ -79,4 +79,25 @@ cfg_not_wasm32! {
).await.unwrap();
connection.close().await.unwrap();
}

#[tokio::test]
async fn connection_can_connect_with_azure_identity_credential() {
common::setup_dotenv();

use azure_identity::DefaultAzureCredential;

let namespace = std::env::var("EVENT_HUBS_NAMESPACE").unwrap();
let fqn = format!("{}.servicebus.windows.net", namespace);
let event_hub_name = std::env::var("EVENT_HUB_NAME").unwrap();

let options = EventHubConnectionOptions::default();
let credential = DefaultAzureCredential::default();

let connection = EventHubConnection::from_namespace_and_credential(
fqn,
event_hub_name,
credential,
options,
).await.unwrap();
}
}
Loading

0 comments on commit 1126bff

Please sign in to comment.