Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 27 additions & 51 deletions src/communication/in_memory_rpc_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,30 +24,22 @@ use tokio::time::timeout;
use tracing::{debug, info};

use crate::{
LocalUriProvider, UCode, UListener, UMessage, UMessageBuilder, UMessageType, UStatus,
UTransport, UUri, UUID,
LocalUriProvider, UCode, UListener, UMessage, UMessageBuilder, UStatus, UTransport, UUri, UUID,
};

use super::{
build_message, CallOptions, RegistrationError, RpcClient, ServiceInvocationError, UPayload,
};

/// Handles an RPC Response message received from the transport layer.
fn handle_response_message(response: UMessage) -> Result<Option<UPayload>, ServiceInvocationError> {
let Some(attribs) = response.attributes.as_ref() else {
return Err(ServiceInvocationError::InvalidArgument(
"response message does not contain attributes".to_string(),
));
};

match attribs.commstatus.map(|v| v.enum_value_or_default()) {
match response.commstatus() {
Some(UCode::OK) | None => {
// successful invocation
response.payload.map_or(Ok(None), |payload| {
Ok(Some(UPayload::new(
payload,
attribs.payload_format.enum_value_or_default(),
)))
})
let payload_format = response.payload_format().unwrap_or_default();
Ok(response
.payload
.map(|payload| UPayload::new(payload, payload_format)))
}
Some(code) => {
// try to extract UStatus from response payload
Expand Down Expand Up @@ -86,15 +78,20 @@ impl ResponseListener {
}
}

fn handle_response(&self, reqid: &UUID, response_message: UMessage) {
let Ok(mut pending_requests) = self.pending_requests.lock() else {
info!(
request_id = reqid.to_hyphenated_string(),
"failed to process response message, cannot acquire lock for pending requests map"
);
return;
fn handle_response(&self, response_message: UMessage) {
let reqid = response_message.request_id_unchecked().clone();
let response_sender = {
// drop lock as soon as possible
let Ok(mut pending_requests) = self.pending_requests.lock() else {
info!(
request_id = reqid.to_hyphenated_string(),
"failed to process response message, cannot acquire lock for pending requests map"
);
return;
};
pending_requests.remove(&reqid)
};
if let Some(sender) = pending_requests.remove(reqid) {
if let Some(sender) = response_sender {
if let Err(_e) = sender.send(response_message) {
// channel seems to be closed already
debug!(
Expand Down Expand Up @@ -133,27 +130,15 @@ impl ResponseListener {
#[async_trait]
impl UListener for ResponseListener {
async fn on_receive(&self, msg: UMessage) {
let message_type = msg
.attributes
.get_or_default()
.type_
.enum_value_or_default();
if message_type != UMessageType::UMESSAGE_TYPE_RESPONSE {
// it is sufficient to check if the message is a response
// because the transport implementation forwards valid UMessages only
if msg.is_response() {
self.handle_response(msg);
} else {
debug!(
message_type = message_type.to_cloudevent_type(),
"service provider replied with message that is not an RPC Response"
message_type = msg.type_unchecked().to_cloudevent_type(),
"ignoring non-response message received by RPC client"
);
return;
}

if let Some(reqid) = msg
.attributes
.as_ref()
.and_then(|attribs| attribs.reqid.clone().into_option())
{
self.handle_response(&reqid, msg);
} else {
debug!("ignoring malformed response message not containing request ID");
}
}
}
Expand Down Expand Up @@ -617,13 +602,4 @@ mod tests {
assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::DeadlineExceeded) }));
assert!(!client.contains_pending_request(&message_id));
}

#[test]
fn test_handle_response_message_fails_for_missing_attributes() {
let response_msg = UMessage {
..Default::default()
};
let result = handle_response_message(response_msg);
assert!(result.is_err_and(|e| matches!(e, ServiceInvocationError::InvalidArgument(_))));
}
}
206 changes: 20 additions & 186 deletions src/communication/in_memory_rpc_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ use async_trait::async_trait;
use tracing::{debug, info};

use crate::{
communication::build_message, LocalUriProvider, UAttributes, UAttributesError,
UAttributesValidators, UCode, UListener, UMessage, UMessageBuilder, UStatus, UTransport, UUri,
communication::build_message, LocalUriProvider, UListener, UMessage, UMessageBuilder, UStatus,
UTransport, UUri,
};

use super::{RegistrationError, RequestHandler, RpcServer, ServiceInvocationError, UPayload};
Expand All @@ -37,26 +37,16 @@ impl RequestListener {
async fn process_valid_request(&self, resource_id: u16, request_message: UMessage) {
let transport_clone = self.transport.clone();
let request_handler_clone = self.request_handler.clone();
let mut response_builder =
UMessageBuilder::response_for_request(request_message.attributes_unchecked());

let request_id = request_message
.attributes
.get_or_default()
.id
.get_or_default();
let request_timeout = request_message
.attributes
.get_or_default()
.ttl
.unwrap_or(10_000);
let request_message_id = request_message.id_unchecked().clone();
let request_timeout = request_message.ttl_unchecked();
let payload_format = request_message.payload_format().unwrap_or_default();
let payload = request_message.payload;
let payload_format = request_message
.attributes
.get_or_default()
.payload_format
.enum_value_or_default();
let request_payload = payload.map(|data| UPayload::new(data, payload_format));

debug!(ttl = request_timeout, id = %request_id, "processing RPC request");
debug!(ttl = request_timeout, id = %request_message_id, "processing RPC request");

let invocation_result_future = request_handler_clone.handle_request(
resource_id,
Expand All @@ -75,15 +65,10 @@ impl RequestListener {
.and_then(|v| v);

let response = match outcome {
Ok(response_payload) => {
let mut builder = UMessageBuilder::response_for_request(
request_message.attributes.get_or_default(),
);
build_message(&mut builder, response_payload)
}
Ok(response_payload) => build_message(&mut response_builder, response_payload),
Err(e) => {
let error = UStatus::from(e);
UMessageBuilder::response_for_request(request_message.attributes.get_or_default())
response_builder
.with_comm_status(error.get_code())
.build_with_protobuf_payload(&error)
}
Expand All @@ -100,64 +85,20 @@ impl RequestListener {
}
}
}

async fn process_invalid_request(
&self,
validation_error: UAttributesError,
request_attributes: &UAttributes,
) {
// all we need is a valid source address and a message ID to be able to send back an error message
let (Some(id), Some(source_address)) = (
request_attributes.id.to_owned().into_option(),
request_attributes
.source
.to_owned()
.into_option()
.filter(|uri| uri.is_rpc_response()),
) else {
debug!("invalid request message does not contain enough data to create response");
return;
};

debug!(id = %id, "processing invalid request message");

let response_payload =
UStatus::fail_with_code(UCode::INVALID_ARGUMENT, validation_error.to_string());
let Ok(response_message) = UMessageBuilder::response(
source_address,
id,
request_attributes.sink.get_or_default().to_owned(),
)
.with_comm_status(response_payload.get_code())
.build_with_protobuf_payload(&response_payload) else {
info!("failed to create error message");
return;
};

if let Err(e) = self.transport.send(response_message).await {
info!(ucode = e.code.value(), "failed to send error response");
}
}
}

#[async_trait]
impl UListener for RequestListener {
async fn on_receive(&self, msg: UMessage) {
let Some(attributes) = msg.attributes.as_ref() else {
debug!("ignoring invalid message having no attributes");
return;
};

let validator = UAttributesValidators::Request.validator();
if let Err(e) = validator.validate(attributes) {
self.process_invalid_request(e, attributes).await;
} else if let Some(resource_id) = attributes
.sink
.as_ref()
.and_then(|uri| u16::try_from(uri.resource_id).ok())
{
// the conversion cannot fail because request message validation has succeeded
self.process_valid_request(resource_id, msg).await;
if msg.is_request() {
// cannot fail because inbound messages are validated at the transport layer already
let method_id = msg.sink_unchecked().resource_id();
self.process_valid_request(method_id, msg).await;
} else {
debug!(
message_type = msg.type_unchecked().to_cloudevent_type(),
"ignoring non-request message received by RPC server"
);
}
}
}
Expand Down Expand Up @@ -295,7 +236,7 @@ mod tests {

use crate::{
communication::rpc::MockRequestHandler, utransport::MockTransport, StaticUriProvider,
UAttributes, UMessageType, UPriority, UUri, UUID,
UAttributes, UCode, UUri, UUID,
};

fn new_uri_provider() -> Arc<dyn LocalUriProvider> {
Expand Down Expand Up @@ -426,113 +367,6 @@ mod tests {
assert!(result.is_err_and(|e| matches!(e, RegistrationError::NoSuchListener)));
}

#[tokio::test]
async fn test_request_listener_returns_response_for_invalid_request() {
// GIVEN an RpcServer for a transport
let mut request_handler = MockRequestHandler::new();
let mut transport = MockTransport::new();
let notify = Arc::new(Notify::new());
let notify_clone = notify.clone();
let message_id = UUID::build();
let request_id = message_id.clone();

request_handler.expect_handle_request().never();
transport
.expect_do_send()
.once()
.withf(move |response_message| {
if !response_message.is_response() {
return false;
}
if response_message.request_id_unchecked() != &request_id {
return false;
}
let error: UStatus = response_message.extract_protobuf().unwrap();
error.get_code() == UCode::INVALID_ARGUMENT
&& response_message.commstatus_unchecked() == error.get_code()
})
.returning(move |_msg| {
notify_clone.notify_one();
Ok(())
});

// WHEN the server receives a message on an endpoint which is not a
// valid RPC Request message but contains enough information to
// create a response
let invalid_request_attributes = UAttributes {
type_: UMessageType::UMESSAGE_TYPE_REQUEST.into(),
sink: UUri::try_from("up://localhost/A200/1/7000").ok().into(),
source: UUri::try_from("up://localhost/A100/1/0").ok().into(),
id: Some(message_id.clone()).into(),
priority: UPriority::UPRIORITY_CS5.into(),
..Default::default()
};
assert!(
UAttributesValidators::Request
.validator()
.validate(&invalid_request_attributes)
.is_err(),
"request message attributes are supposed to be invalid (no TTL)"
);
let invalid_request_message = UMessage {
attributes: Some(invalid_request_attributes).into(),
..Default::default()
};

let request_listener = RequestListener {
request_handler: Arc::new(request_handler),
transport: Arc::new(transport),
};
request_listener.on_receive(invalid_request_message).await;

// THEN the listener sends an error message in response to the invalid request
let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
assert!(result.is_ok());
}

#[tokio::test]
async fn test_request_listener_ignores_invalid_request() {
// GIVEN an RpcServer for a transport
let mut request_handler = MockRequestHandler::new();
request_handler.expect_handle_request().never();
let mut transport = MockTransport::new();
transport.expect_do_send().never();

// WHEN the server receives a message on an endpoint which is not a
// valid RPC Request message which does not contain enough information to
// create a response
let invalid_request_attributes = UAttributes {
type_: UMessageType::UMESSAGE_TYPE_REQUEST.into(),
sink: UUri::try_from("up://localhost/A200/1/7000").ok().into(),
source: UUri::try_from("up://localhost/A100/1/0").ok().into(),
ttl: Some(5_000),
id: None.into(),
priority: UPriority::UPRIORITY_CS5.into(),
..Default::default()
};
assert!(
UAttributesValidators::Request
.validator()
.validate(&invalid_request_attributes)
.is_err(),
"request message attributes are supposed to be invalid (no ID)"
);
let invalid_request_message = UMessage {
attributes: Some(invalid_request_attributes).into(),
..Default::default()
};

let request_listener = RequestListener {
request_handler: Arc::new(request_handler),
transport: Arc::new(transport),
};
request_listener.on_receive(invalid_request_message).await;

// THEN the listener ignores the invalid request
// let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
// assert!(result.is_ok());
}

#[tokio::test]
async fn test_request_listener_invokes_operation_successfully() {
let mut request_handler = MockRequestHandler::new();
Expand Down