diff --git a/core/src/message/context.rs b/core/src/message/context.rs index 157cf214107..05b3f61e933 100644 --- a/core/src/message/context.rs +++ b/core/src/message/context.rs @@ -260,6 +260,13 @@ impl MessageContext { Ok(()) } + fn increase_counter(counter: u32, amount: impl TryInto, limit: u32) -> Option { + TryInto::::try_into(amount) + .ok() + .and_then(|amount| counter.checked_add(amount)) + .and_then(|counter| (counter <= limit).then_some(counter)) + } + /// Return bool defining was reply sent within the execution. pub fn reply_sent(&self) -> bool { self.outcome.reply.is_some() @@ -303,53 +310,51 @@ impl MessageContext { pub fn send_commit( &mut self, handle: u32, - packet: HandlePacket, + mut packet: HandlePacket, delay: u32, reservation: Option, ) -> Result { - if let Some(payload) = self.store.outgoing.get_mut(&handle) { - if let Some(data) = payload.take() { - let Some(new_outgoing_bytes) = self - .outgoing_bytes_counter - .checked_add(packet.payload_len()) - .and_then(|counter| { - (counter <= self.settings.outgoing_bytes_limit).then_some(counter) - }) - else { - *payload = Some(data); - return Err(Error::OutgoingMessagesBytesLimitExceeded); - }; - - // TODO: set data back if error #3779 - let packet = { - let mut packet = packet; - packet - .try_prepend(data) - .map_err(|_| Error::MaxMessageSizeExceed)?; - packet - }; - - let message_id = MessageId::generate_outgoing(self.current.id(), handle); - let message = HandleMessage::from_packet(message_id, packet); - - self.outcome.handle.push((message, delay, reservation)); - - // Increasing `outgoing_bytes_counter`, instead of decreasing it, because - // this counter takes into account also messages, that are already committed - // during this execution. - // The message subsequent executions will recalculate this counter from - // store outgoing messages (see `Self::new`), - // so committed during this execution messages won't be taken into account - // during next executions. - self.outgoing_bytes_counter = new_outgoing_bytes; - - Ok(message_id) - } else { - Err(Error::LateAccess) - } - } else { - Err(Error::OutOfBounds) - } + let outgoing = self + .store + .outgoing + .get_mut(&handle) + .ok_or(Error::OutOfBounds)?; + let data = outgoing.take().ok_or(Error::LateAccess)?; + + let do_send_commit = || { + let Some(new_outgoing_bytes) = Self::increase_counter( + self.outgoing_bytes_counter, + packet.payload_len(), + self.settings.outgoing_bytes_limit, + ) else { + return Err((Error::OutgoingMessagesBytesLimitExceeded, data)); + }; + + packet + .try_prepend(data) + .map_err(|data| (Error::MaxMessageSizeExceed, data))?; + + let message_id = MessageId::generate_outgoing(self.current.id(), handle); + let message = HandleMessage::from_packet(message_id, packet); + + self.outcome.handle.push((message, delay, reservation)); + + // Increasing `outgoing_bytes_counter`, instead of decreasing it, + // because this counter takes into account also messages, + // that are already committed during this execution. + // The message subsequent executions will recalculate this counter from + // store outgoing messages (see `Self::new`), + // so committed during this execution messages won't be taken into account + // during next executions. + self.outgoing_bytes_counter = new_outgoing_bytes; + + Ok(message_id) + }; + + do_send_commit().map_err(|(err, data)| { + *outgoing = Some(data); + err + }) } /// Provide space for storing payload for future message creation. @@ -369,53 +374,51 @@ impl MessageContext { /// Pushes payload into stored payload by handle. pub fn send_push(&mut self, handle: u32, buffer: &[u8]) -> Result<(), Error> { - match self.store.outgoing.get_mut(&handle) { - Some(Some(data)) => { - let new_outgoing_bytes = u32::try_from(buffer.len()) - .ok() - .and_then(|bytes_amount| self.outgoing_bytes_counter.checked_add(bytes_amount)) - .and_then(|counter| { - (counter <= self.settings.outgoing_bytes_limit).then_some(counter) - }) - .ok_or(Error::OutgoingMessagesBytesLimitExceeded)?; - - data.try_extend_from_slice(buffer) - .map_err(|_| Error::MaxMessageSizeExceed)?; - self.outgoing_bytes_counter = new_outgoing_bytes; - Ok(()) - } - Some(None) => Err(Error::LateAccess), - None => Err(Error::OutOfBounds), - } + let data = match self.store.outgoing.get_mut(&handle) { + Some(Some(data)) => data, + Some(None) => return Err(Error::LateAccess), + None => return Err(Error::OutOfBounds), + }; + + let new_outgoing_bytes = Self::increase_counter( + self.outgoing_bytes_counter, + buffer.len(), + self.settings.outgoing_bytes_limit, + ) + .ok_or(Error::OutgoingMessagesBytesLimitExceeded)?; + + data.try_extend_from_slice(buffer) + .map_err(|_| Error::MaxMessageSizeExceed)?; + + self.outgoing_bytes_counter = new_outgoing_bytes; + + Ok(()) } /// Pushes the incoming buffer/payload into stored payload by handle. pub fn send_push_input(&mut self, handle: u32, range: CheckedRange) -> Result<(), Error> { - let data = self - .store - .outgoing - .get_mut(&handle) - .ok_or(Error::OutOfBounds)? - .as_mut() - .ok_or(Error::LateAccess)?; + let data = match self.store.outgoing.get_mut(&handle) { + Some(Some(data)) => data, + Some(None) => return Err(Error::LateAccess), + None => return Err(Error::OutOfBounds), + }; + let bytes_amount = range.len(); let CheckedRange { offset, excluded_end, } = range; - let bytes_amount = excluded_end.checked_sub(offset).unwrap_or_else(|| { - unreachable!("`CheckedRange` must guarantee that `excluded_end` >= `offset`") - }); - - let new_outgoing_bytes = u32::try_from(bytes_amount) - .ok() - .and_then(|bytes_amount| self.outgoing_bytes_counter.checked_add(bytes_amount)) - .and_then(|counter| (counter <= self.settings.outgoing_bytes_limit).then_some(counter)) - .ok_or(Error::OutgoingMessagesBytesLimitExceeded)?; + let new_outgoing_bytes = Self::increase_counter( + self.outgoing_bytes_counter, + bytes_amount, + self.settings.outgoing_bytes_limit, + ) + .ok_or(Error::OutgoingMessagesBytesLimitExceeded)?; data.try_extend_from_slice(&self.current.payload_bytes()[offset..excluded_end]) .map_err(|_| Error::MaxMessageSizeExceed)?; + self.outgoing_bytes_counter = new_outgoing_bytes; Ok(()) @@ -451,47 +454,44 @@ impl MessageContext { /// Returns message id. pub fn reply_commit( &mut self, - packet: ReplyPacket, + mut packet: ReplyPacket, reservation: Option, ) -> Result { self.check_reply_availability()?; - if !self.reply_sent() { - let data = self.store.reply.take().unwrap_or_default(); + if self.reply_sent() { + return Err(Error::DuplicateReply.into()); + } - // TODO: set data back if error #3779 - let packet = { - let mut packet = packet; - packet - .try_prepend(data) - .map_err(|_| Error::MaxMessageSizeExceed)?; - packet - }; + let data = self.store.reply.take().unwrap_or_default(); - let message_id = MessageId::generate_reply(self.current.id()); - let message = ReplyMessage::from_packet(message_id, packet); + if let Err(data) = packet.try_prepend(data) { + self.store.reply = Some(data); + return Err(Error::MaxMessageSizeExceed.into()); + } - self.outcome.reply = Some((message, reservation)); + let message_id = MessageId::generate_reply(self.current.id()); + let message = ReplyMessage::from_packet(message_id, packet); - Ok(message_id) - } else { - Err(Error::DuplicateReply.into()) - } + self.outcome.reply = Some((message, reservation)); + + Ok(message_id) } /// Pushes payload into stored reply payload. pub fn reply_push(&mut self, buffer: &[u8]) -> Result<(), ExtError> { self.check_reply_availability()?; - if !self.reply_sent() { - let data = self.store.reply.get_or_insert_with(Default::default); - data.try_extend_from_slice(buffer) - .map_err(|_| Error::MaxMessageSizeExceed)?; - - Ok(()) - } else { - Err(Error::LateAccess.into()) + if self.reply_sent() { + return Err(Error::LateAccess.into()); } + + // NOTE: it's normal to not undone `get_or_insert_with` in case of error + self.store + .reply + .get_or_insert_with(Default::default) + .try_extend_from_slice(buffer) + .map_err(|_| Error::MaxMessageSizeExceed.into()) } /// Return reply destination. @@ -503,20 +503,21 @@ impl MessageContext { pub fn reply_push_input(&mut self, range: CheckedRange) -> Result<(), ExtError> { self.check_reply_availability()?; - if !self.reply_sent() { - let CheckedRange { - offset, - excluded_end, - } = range; + if self.reply_sent() { + return Err(Error::LateAccess.into()); + } - let data = self.store.reply.get_or_insert_with(Default::default); - data.try_extend_from_slice(&self.current.payload_bytes()[offset..excluded_end]) - .map_err(|_| Error::MaxMessageSizeExceed)?; + let CheckedRange { + offset, + excluded_end, + } = range; - Ok(()) - } else { - Err(Error::LateAccess.into()) - } + // NOTE: it's normal to not undone `get_or_insert_with` in case of error + self.store + .reply + .get_or_insert_with(Default::default) + .try_extend_from_slice(&self.current.payload_bytes()[offset..excluded_end]) + .map_err(|_| Error::MaxMessageSizeExceed.into()) } /// Wake message by it's message id. @@ -697,6 +698,64 @@ mod tests { ), Error::OutgoingMessagesBytesLimitExceeded, ); + + // commit 5 bytes should be ok. + assert_ok!(message_context.send_commit( + handle, + HandlePacket::new( + Default::default(), + Payload::try_from([1, 2, 3, 4, 5].to_vec()).unwrap(), + 0, + ), + 0, + None, + )); + + let messages = message_context.drain().0.drain().outgoing_dispatches; + assert_eq!( + messages[0].0.payload_bytes(), + [1, 2, 3, 4, 5, 1, 2, 3, 4, 5] + ); + } + + #[test] + fn send_commit_message_size_limit() { + let mut message_context = MessageContext::new( + Default::default(), + Default::default(), + ContextSettings::with_outgoing_limits(1024, u32::MAX), + ) + .expect("Outgoing messages bytes limit exceeded"); + + let handle = message_context.send_init().unwrap(); + + // push 1 byte + assert_ok!(message_context.send_push(handle, &[1])); + + let payload = Payload::filled_with(2); + assert_err!( + message_context.send_commit( + handle, + HandlePacket::new(Default::default(), payload, 0), + 0, + None + ), + Error::MaxMessageSizeExceed, + ); + + let payload = Payload::try_from(vec![1; Payload::max_len() - 1]).unwrap(); + assert_ok!(message_context.send_commit( + handle, + HandlePacket::new(Default::default(), payload, 0), + 0, + None, + )); + + let messages = message_context.drain().0.drain().outgoing_dispatches; + assert_eq!( + Payload::try_from(messages[0].0.payload_bytes().to_vec()).unwrap(), + Payload::filled_with(1) + ); } #[test] @@ -867,6 +926,30 @@ mod tests { ); } + #[test] + fn reply_commit_message_size_limit() { + let mut message_context = + MessageContext::new(Default::default(), Default::default(), Default::default()) + .expect("Outgoing messages bytes limit exceeded"); + + assert_ok!(message_context.reply_push(&[1])); + + let payload = Payload::filled_with(2); + assert_err!( + message_context.reply_commit(ReplyPacket::new(payload, 0), None), + Error::MaxMessageSizeExceed, + ); + + let payload = Payload::try_from(vec![1; Payload::max_len() - 1]).unwrap(); + assert_ok!(message_context.reply_commit(ReplyPacket::new(payload, 0), None)); + + let messages = message_context.drain().0.drain().outgoing_dispatches; + assert_eq!( + Payload::try_from(messages[0].0.payload_bytes().to_vec()).unwrap(), + Payload::filled_with(1) + ); + } + #[test] /// Test that covers full api of `MessageContext` fn message_context_api() { diff --git a/core/src/message/handle.rs b/core/src/message/handle.rs index cb2527a43fb..d2991a0673f 100644 --- a/core/src/message/handle.rs +++ b/core/src/message/handle.rs @@ -28,8 +28,6 @@ use scale_info::{ TypeInfo, }; -use super::PayloadSizeError; - /// Message for Handle entry point. /// Represents a standard message that sends between actors. #[derive(Clone, Default, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Decode, Encode, TypeInfo)] @@ -167,8 +165,13 @@ impl HandlePacket { } /// Prepend payload. - pub(super) fn try_prepend(&mut self, data: Payload) -> Result<(), PayloadSizeError> { - self.payload.try_prepend(data) + pub(super) fn try_prepend(&mut self, mut data: Payload) -> Result<(), Payload> { + if data.try_extend_from_slice(self.payload_bytes()).is_err() { + Err(data) + } else { + self.payload = data; + Ok(()) + } } /// Packet destination. diff --git a/core/src/message/reply.rs b/core/src/message/reply.rs index e7c61c1fa6b..238de918a52 100644 --- a/core/src/message/reply.rs +++ b/core/src/message/reply.rs @@ -16,7 +16,7 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . -use super::{common::ReplyDetails, PayloadSizeError}; +use super::common::ReplyDetails; use crate::{ ids::{MessageId, ProgramId}, message::{ @@ -222,8 +222,13 @@ impl ReplyPacket { } /// Prepend payload. - pub(super) fn try_prepend(&mut self, data: Payload) -> Result<(), PayloadSizeError> { - self.payload.try_prepend(data) + pub(super) fn try_prepend(&mut self, mut data: Payload) -> Result<(), Payload> { + if data.try_extend_from_slice(self.payload_bytes()).is_err() { + Err(data) + } else { + self.payload = data; + Ok(()) + } } /// Packet status code.