From ab2395725b4da4ab7019ac0dab91bead0c545cbe Mon Sep 17 00:00:00 2001 From: Kai Hudalla Date: Fri, 24 Oct 2025 10:32:40 +0200 Subject: [PATCH] Remove obsolete message validation Removed all incoming message validation logic and instead rely on the fact that the transport layer only forwards valid uProtocol messages. --- src/communication/in_memory_rpc_client.rs | 78 +++----- src/communication/in_memory_rpc_server.rs | 206 +++------------------- 2 files changed, 47 insertions(+), 237 deletions(-) diff --git a/src/communication/in_memory_rpc_client.rs b/src/communication/in_memory_rpc_client.rs index f3d8e7a..1b47cee 100644 --- a/src/communication/in_memory_rpc_client.rs +++ b/src/communication/in_memory_rpc_client.rs @@ -24,30 +24,22 @@ use tokio::time::timeout; use tracing::{debug, info}; use crate::{ - LocalUriProvider, UCode, UListener, UMessage, UMessageBuilder, UMessageType, UStatus, - UTransport, UUri, UUID, + LocalUriProvider, UCode, UListener, UMessage, UMessageBuilder, UStatus, UTransport, UUri, UUID, }; use super::{ build_message, CallOptions, RegistrationError, RpcClient, ServiceInvocationError, UPayload, }; +/// Handles an RPC Response message received from the transport layer. fn handle_response_message(response: UMessage) -> Result, ServiceInvocationError> { - let Some(attribs) = response.attributes.as_ref() else { - return Err(ServiceInvocationError::InvalidArgument( - "response message does not contain attributes".to_string(), - )); - }; - - match attribs.commstatus.map(|v| v.enum_value_or_default()) { + match response.commstatus() { Some(UCode::OK) | None => { // successful invocation - response.payload.map_or(Ok(None), |payload| { - Ok(Some(UPayload::new( - payload, - attribs.payload_format.enum_value_or_default(), - ))) - }) + let payload_format = response.payload_format().unwrap_or_default(); + Ok(response + .payload + .map(|payload| UPayload::new(payload, payload_format))) } Some(code) => { // try to extract UStatus from response payload @@ -86,15 +78,20 @@ impl ResponseListener { } } - fn handle_response(&self, reqid: &UUID, response_message: UMessage) { - let Ok(mut pending_requests) = self.pending_requests.lock() else { - info!( - request_id = reqid.to_hyphenated_string(), - "failed to process response message, cannot acquire lock for pending requests map" - ); - return; + fn handle_response(&self, response_message: UMessage) { + let reqid = response_message.request_id_unchecked().clone(); + let response_sender = { + // drop lock as soon as possible + let Ok(mut pending_requests) = self.pending_requests.lock() else { + info!( + request_id = reqid.to_hyphenated_string(), + "failed to process response message, cannot acquire lock for pending requests map" + ); + return; + }; + pending_requests.remove(&reqid) }; - if let Some(sender) = pending_requests.remove(reqid) { + if let Some(sender) = response_sender { if let Err(_e) = sender.send(response_message) { // channel seems to be closed already debug!( @@ -133,27 +130,15 @@ impl ResponseListener { #[async_trait] impl UListener for ResponseListener { async fn on_receive(&self, msg: UMessage) { - let message_type = msg - .attributes - .get_or_default() - .type_ - .enum_value_or_default(); - if message_type != UMessageType::UMESSAGE_TYPE_RESPONSE { + // it is sufficient to check if the message is a response + // because the transport implementation forwards valid UMessages only + if msg.is_response() { + self.handle_response(msg); + } else { debug!( - message_type = message_type.to_cloudevent_type(), - "service provider replied with message that is not an RPC Response" + message_type = msg.type_unchecked().to_cloudevent_type(), + "ignoring non-response message received by RPC client" ); - return; - } - - if let Some(reqid) = msg - .attributes - .as_ref() - .and_then(|attribs| attribs.reqid.clone().into_option()) - { - self.handle_response(&reqid, msg); - } else { - debug!("ignoring malformed response message not containing request ID"); } } } @@ -617,13 +602,4 @@ mod tests { assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::DeadlineExceeded) })); assert!(!client.contains_pending_request(&message_id)); } - - #[test] - fn test_handle_response_message_fails_for_missing_attributes() { - let response_msg = UMessage { - ..Default::default() - }; - let result = handle_response_message(response_msg); - assert!(result.is_err_and(|e| matches!(e, ServiceInvocationError::InvalidArgument(_)))); - } } diff --git a/src/communication/in_memory_rpc_server.rs b/src/communication/in_memory_rpc_server.rs index 31ef8e9..e829298 100644 --- a/src/communication/in_memory_rpc_server.rs +++ b/src/communication/in_memory_rpc_server.rs @@ -22,8 +22,8 @@ use async_trait::async_trait; use tracing::{debug, info}; use crate::{ - communication::build_message, LocalUriProvider, UAttributes, UAttributesError, - UAttributesValidators, UCode, UListener, UMessage, UMessageBuilder, UStatus, UTransport, UUri, + communication::build_message, LocalUriProvider, UListener, UMessage, UMessageBuilder, UStatus, + UTransport, UUri, }; use super::{RegistrationError, RequestHandler, RpcServer, ServiceInvocationError, UPayload}; @@ -37,26 +37,16 @@ impl RequestListener { async fn process_valid_request(&self, resource_id: u16, request_message: UMessage) { let transport_clone = self.transport.clone(); let request_handler_clone = self.request_handler.clone(); + let mut response_builder = + UMessageBuilder::response_for_request(request_message.attributes_unchecked()); - let request_id = request_message - .attributes - .get_or_default() - .id - .get_or_default(); - let request_timeout = request_message - .attributes - .get_or_default() - .ttl - .unwrap_or(10_000); + let request_message_id = request_message.id_unchecked().clone(); + let request_timeout = request_message.ttl_unchecked(); + let payload_format = request_message.payload_format().unwrap_or_default(); let payload = request_message.payload; - let payload_format = request_message - .attributes - .get_or_default() - .payload_format - .enum_value_or_default(); let request_payload = payload.map(|data| UPayload::new(data, payload_format)); - debug!(ttl = request_timeout, id = %request_id, "processing RPC request"); + debug!(ttl = request_timeout, id = %request_message_id, "processing RPC request"); let invocation_result_future = request_handler_clone.handle_request( resource_id, @@ -75,15 +65,10 @@ impl RequestListener { .and_then(|v| v); let response = match outcome { - Ok(response_payload) => { - let mut builder = UMessageBuilder::response_for_request( - request_message.attributes.get_or_default(), - ); - build_message(&mut builder, response_payload) - } + Ok(response_payload) => build_message(&mut response_builder, response_payload), Err(e) => { let error = UStatus::from(e); - UMessageBuilder::response_for_request(request_message.attributes.get_or_default()) + response_builder .with_comm_status(error.get_code()) .build_with_protobuf_payload(&error) } @@ -100,64 +85,20 @@ impl RequestListener { } } } - - async fn process_invalid_request( - &self, - validation_error: UAttributesError, - request_attributes: &UAttributes, - ) { - // all we need is a valid source address and a message ID to be able to send back an error message - let (Some(id), Some(source_address)) = ( - request_attributes.id.to_owned().into_option(), - request_attributes - .source - .to_owned() - .into_option() - .filter(|uri| uri.is_rpc_response()), - ) else { - debug!("invalid request message does not contain enough data to create response"); - return; - }; - - debug!(id = %id, "processing invalid request message"); - - let response_payload = - UStatus::fail_with_code(UCode::INVALID_ARGUMENT, validation_error.to_string()); - let Ok(response_message) = UMessageBuilder::response( - source_address, - id, - request_attributes.sink.get_or_default().to_owned(), - ) - .with_comm_status(response_payload.get_code()) - .build_with_protobuf_payload(&response_payload) else { - info!("failed to create error message"); - return; - }; - - if let Err(e) = self.transport.send(response_message).await { - info!(ucode = e.code.value(), "failed to send error response"); - } - } } #[async_trait] impl UListener for RequestListener { async fn on_receive(&self, msg: UMessage) { - let Some(attributes) = msg.attributes.as_ref() else { - debug!("ignoring invalid message having no attributes"); - return; - }; - - let validator = UAttributesValidators::Request.validator(); - if let Err(e) = validator.validate(attributes) { - self.process_invalid_request(e, attributes).await; - } else if let Some(resource_id) = attributes - .sink - .as_ref() - .and_then(|uri| u16::try_from(uri.resource_id).ok()) - { - // the conversion cannot fail because request message validation has succeeded - self.process_valid_request(resource_id, msg).await; + if msg.is_request() { + // cannot fail because inbound messages are validated at the transport layer already + let method_id = msg.sink_unchecked().resource_id(); + self.process_valid_request(method_id, msg).await; + } else { + debug!( + message_type = msg.type_unchecked().to_cloudevent_type(), + "ignoring non-request message received by RPC server" + ); } } } @@ -295,7 +236,7 @@ mod tests { use crate::{ communication::rpc::MockRequestHandler, utransport::MockTransport, StaticUriProvider, - UAttributes, UMessageType, UPriority, UUri, UUID, + UAttributes, UCode, UUri, UUID, }; fn new_uri_provider() -> Arc { @@ -426,113 +367,6 @@ mod tests { assert!(result.is_err_and(|e| matches!(e, RegistrationError::NoSuchListener))); } - #[tokio::test] - async fn test_request_listener_returns_response_for_invalid_request() { - // GIVEN an RpcServer for a transport - let mut request_handler = MockRequestHandler::new(); - let mut transport = MockTransport::new(); - let notify = Arc::new(Notify::new()); - let notify_clone = notify.clone(); - let message_id = UUID::build(); - let request_id = message_id.clone(); - - request_handler.expect_handle_request().never(); - transport - .expect_do_send() - .once() - .withf(move |response_message| { - if !response_message.is_response() { - return false; - } - if response_message.request_id_unchecked() != &request_id { - return false; - } - let error: UStatus = response_message.extract_protobuf().unwrap(); - error.get_code() == UCode::INVALID_ARGUMENT - && response_message.commstatus_unchecked() == error.get_code() - }) - .returning(move |_msg| { - notify_clone.notify_one(); - Ok(()) - }); - - // WHEN the server receives a message on an endpoint which is not a - // valid RPC Request message but contains enough information to - // create a response - let invalid_request_attributes = UAttributes { - type_: UMessageType::UMESSAGE_TYPE_REQUEST.into(), - sink: UUri::try_from("up://localhost/A200/1/7000").ok().into(), - source: UUri::try_from("up://localhost/A100/1/0").ok().into(), - id: Some(message_id.clone()).into(), - priority: UPriority::UPRIORITY_CS5.into(), - ..Default::default() - }; - assert!( - UAttributesValidators::Request - .validator() - .validate(&invalid_request_attributes) - .is_err(), - "request message attributes are supposed to be invalid (no TTL)" - ); - let invalid_request_message = UMessage { - attributes: Some(invalid_request_attributes).into(), - ..Default::default() - }; - - let request_listener = RequestListener { - request_handler: Arc::new(request_handler), - transport: Arc::new(transport), - }; - request_listener.on_receive(invalid_request_message).await; - - // THEN the listener sends an error message in response to the invalid request - let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await; - assert!(result.is_ok()); - } - - #[tokio::test] - async fn test_request_listener_ignores_invalid_request() { - // GIVEN an RpcServer for a transport - let mut request_handler = MockRequestHandler::new(); - request_handler.expect_handle_request().never(); - let mut transport = MockTransport::new(); - transport.expect_do_send().never(); - - // WHEN the server receives a message on an endpoint which is not a - // valid RPC Request message which does not contain enough information to - // create a response - let invalid_request_attributes = UAttributes { - type_: UMessageType::UMESSAGE_TYPE_REQUEST.into(), - sink: UUri::try_from("up://localhost/A200/1/7000").ok().into(), - source: UUri::try_from("up://localhost/A100/1/0").ok().into(), - ttl: Some(5_000), - id: None.into(), - priority: UPriority::UPRIORITY_CS5.into(), - ..Default::default() - }; - assert!( - UAttributesValidators::Request - .validator() - .validate(&invalid_request_attributes) - .is_err(), - "request message attributes are supposed to be invalid (no ID)" - ); - let invalid_request_message = UMessage { - attributes: Some(invalid_request_attributes).into(), - ..Default::default() - }; - - let request_listener = RequestListener { - request_handler: Arc::new(request_handler), - transport: Arc::new(transport), - }; - request_listener.on_receive(invalid_request_message).await; - - // THEN the listener ignores the invalid request - // let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await; - // assert!(result.is_ok()); - } - #[tokio::test] async fn test_request_listener_invokes_operation_successfully() { let mut request_handler = MockRequestHandler::new();