Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(core): set taken payload back in case of errors in MessageContext #3785

Merged
merged 1 commit into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
315 changes: 199 additions & 116 deletions core/src/message/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,13 @@ impl MessageContext {
Ok(())
}

fn increase_counter(counter: u32, amount: impl TryInto<u32>, limit: u32) -> Option<u32> {
TryInto::<u32>::try_into(amount)
.ok()
.and_then(|amount| counter.checked_add(amount))
.and_then(|counter| (counter <= limit).then_some(counter))
}

/// Return bool defining was reply sent within the execution.
pub fn reply_sent(&self) -> bool {
self.outcome.reply.is_some()
Expand Down Expand Up @@ -303,53 +310,51 @@ impl MessageContext {
pub fn send_commit(
&mut self,
handle: u32,
packet: HandlePacket,
mut packet: HandlePacket,
delay: u32,
reservation: Option<ReservationId>,
) -> Result<MessageId, Error> {
if let Some(payload) = self.store.outgoing.get_mut(&handle) {
if let Some(data) = payload.take() {
let Some(new_outgoing_bytes) = self
.outgoing_bytes_counter
.checked_add(packet.payload_len())
.and_then(|counter| {
(counter <= self.settings.outgoing_bytes_limit).then_some(counter)
})
else {
*payload = Some(data);
return Err(Error::OutgoingMessagesBytesLimitExceeded);
};

// TODO: set data back if error #3779
let packet = {
let mut packet = packet;
packet
.try_prepend(data)
.map_err(|_| Error::MaxMessageSizeExceed)?;
packet
};

let message_id = MessageId::generate_outgoing(self.current.id(), handle);
let message = HandleMessage::from_packet(message_id, packet);

self.outcome.handle.push((message, delay, reservation));

// Increasing `outgoing_bytes_counter`, instead of decreasing it, because
// this counter takes into account also messages, that are already committed
// during this execution.
// The message subsequent executions will recalculate this counter from
// store outgoing messages (see `Self::new`),
// so committed during this execution messages won't be taken into account
// during next executions.
self.outgoing_bytes_counter = new_outgoing_bytes;

Ok(message_id)
} else {
Err(Error::LateAccess)
}
} else {
Err(Error::OutOfBounds)
}
let outgoing = self
.store
.outgoing
.get_mut(&handle)
.ok_or(Error::OutOfBounds)?;
let data = outgoing.take().ok_or(Error::LateAccess)?;

let do_send_commit = || {
let Some(new_outgoing_bytes) = Self::increase_counter(
self.outgoing_bytes_counter,
packet.payload_len(),
self.settings.outgoing_bytes_limit,
) else {
return Err((Error::OutgoingMessagesBytesLimitExceeded, data));
};

packet
.try_prepend(data)
.map_err(|data| (Error::MaxMessageSizeExceed, data))?;

let message_id = MessageId::generate_outgoing(self.current.id(), handle);
let message = HandleMessage::from_packet(message_id, packet);

self.outcome.handle.push((message, delay, reservation));

// Increasing `outgoing_bytes_counter`, instead of decreasing it,
// because this counter takes into account also messages,
// that are already committed during this execution.
// The message subsequent executions will recalculate this counter from
// store outgoing messages (see `Self::new`),
// so committed during this execution messages won't be taken into account
// during next executions.
self.outgoing_bytes_counter = new_outgoing_bytes;

Ok(message_id)
};

do_send_commit().map_err(|(err, data)| {
*outgoing = Some(data);
err
})
}

/// Provide space for storing payload for future message creation.
Expand All @@ -369,53 +374,51 @@ impl MessageContext {

/// Pushes payload into stored payload by handle.
pub fn send_push(&mut self, handle: u32, buffer: &[u8]) -> Result<(), Error> {
match self.store.outgoing.get_mut(&handle) {
Some(Some(data)) => {
let new_outgoing_bytes = u32::try_from(buffer.len())
.ok()
.and_then(|bytes_amount| self.outgoing_bytes_counter.checked_add(bytes_amount))
.and_then(|counter| {
(counter <= self.settings.outgoing_bytes_limit).then_some(counter)
})
.ok_or(Error::OutgoingMessagesBytesLimitExceeded)?;

data.try_extend_from_slice(buffer)
.map_err(|_| Error::MaxMessageSizeExceed)?;
self.outgoing_bytes_counter = new_outgoing_bytes;
Ok(())
}
Some(None) => Err(Error::LateAccess),
None => Err(Error::OutOfBounds),
}
let data = match self.store.outgoing.get_mut(&handle) {
Some(Some(data)) => data,
Some(None) => return Err(Error::LateAccess),
None => return Err(Error::OutOfBounds),
};

let new_outgoing_bytes = Self::increase_counter(
self.outgoing_bytes_counter,
buffer.len(),
self.settings.outgoing_bytes_limit,
)
.ok_or(Error::OutgoingMessagesBytesLimitExceeded)?;

data.try_extend_from_slice(buffer)
.map_err(|_| Error::MaxMessageSizeExceed)?;

self.outgoing_bytes_counter = new_outgoing_bytes;

Ok(())
}

/// Pushes the incoming buffer/payload into stored payload by handle.
pub fn send_push_input(&mut self, handle: u32, range: CheckedRange) -> Result<(), Error> {
let data = self
.store
.outgoing
.get_mut(&handle)
.ok_or(Error::OutOfBounds)?
.as_mut()
.ok_or(Error::LateAccess)?;
let data = match self.store.outgoing.get_mut(&handle) {
Some(Some(data)) => data,
Some(None) => return Err(Error::LateAccess),
None => return Err(Error::OutOfBounds),
};

let bytes_amount = range.len();
let CheckedRange {
offset,
excluded_end,
} = range;

let bytes_amount = excluded_end.checked_sub(offset).unwrap_or_else(|| {
unreachable!("`CheckedRange` must guarantee that `excluded_end` >= `offset`")
});

let new_outgoing_bytes = u32::try_from(bytes_amount)
.ok()
.and_then(|bytes_amount| self.outgoing_bytes_counter.checked_add(bytes_amount))
.and_then(|counter| (counter <= self.settings.outgoing_bytes_limit).then_some(counter))
.ok_or(Error::OutgoingMessagesBytesLimitExceeded)?;
let new_outgoing_bytes = Self::increase_counter(
self.outgoing_bytes_counter,
bytes_amount,
self.settings.outgoing_bytes_limit,
)
.ok_or(Error::OutgoingMessagesBytesLimitExceeded)?;

data.try_extend_from_slice(&self.current.payload_bytes()[offset..excluded_end])
.map_err(|_| Error::MaxMessageSizeExceed)?;

self.outgoing_bytes_counter = new_outgoing_bytes;

Ok(())
Expand Down Expand Up @@ -451,47 +454,44 @@ impl MessageContext {
/// Returns message id.
pub fn reply_commit(
&mut self,
packet: ReplyPacket,
mut packet: ReplyPacket,
reservation: Option<ReservationId>,
) -> Result<MessageId, ExtError> {
self.check_reply_availability()?;

if !self.reply_sent() {
let data = self.store.reply.take().unwrap_or_default();
if self.reply_sent() {
return Err(Error::DuplicateReply.into());
}

// TODO: set data back if error #3779
let packet = {
let mut packet = packet;
packet
.try_prepend(data)
.map_err(|_| Error::MaxMessageSizeExceed)?;
packet
};
let data = self.store.reply.take().unwrap_or_default();

let message_id = MessageId::generate_reply(self.current.id());
let message = ReplyMessage::from_packet(message_id, packet);
if let Err(data) = packet.try_prepend(data) {
self.store.reply = Some(data);
return Err(Error::MaxMessageSizeExceed.into());
}

self.outcome.reply = Some((message, reservation));
let message_id = MessageId::generate_reply(self.current.id());
let message = ReplyMessage::from_packet(message_id, packet);

Ok(message_id)
} else {
Err(Error::DuplicateReply.into())
}
self.outcome.reply = Some((message, reservation));

Ok(message_id)
}

/// Pushes payload into stored reply payload.
pub fn reply_push(&mut self, buffer: &[u8]) -> Result<(), ExtError> {
self.check_reply_availability()?;

if !self.reply_sent() {
let data = self.store.reply.get_or_insert_with(Default::default);
data.try_extend_from_slice(buffer)
.map_err(|_| Error::MaxMessageSizeExceed)?;

Ok(())
} else {
Err(Error::LateAccess.into())
if self.reply_sent() {
return Err(Error::LateAccess.into());
}

// NOTE: it's normal to not undone `get_or_insert_with` in case of error
self.store
.reply
.get_or_insert_with(Default::default)
.try_extend_from_slice(buffer)
.map_err(|_| Error::MaxMessageSizeExceed.into())
}

/// Return reply destination.
Expand All @@ -503,20 +503,21 @@ impl MessageContext {
pub fn reply_push_input(&mut self, range: CheckedRange) -> Result<(), ExtError> {
self.check_reply_availability()?;

if !self.reply_sent() {
let CheckedRange {
offset,
excluded_end,
} = range;
if self.reply_sent() {
return Err(Error::LateAccess.into());
}

let data = self.store.reply.get_or_insert_with(Default::default);
data.try_extend_from_slice(&self.current.payload_bytes()[offset..excluded_end])
.map_err(|_| Error::MaxMessageSizeExceed)?;
let CheckedRange {
offset,
excluded_end,
} = range;

Ok(())
} else {
Err(Error::LateAccess.into())
}
// NOTE: it's normal to not undone `get_or_insert_with` in case of error
self.store
.reply
.get_or_insert_with(Default::default)
.try_extend_from_slice(&self.current.payload_bytes()[offset..excluded_end])
.map_err(|_| Error::MaxMessageSizeExceed.into())
}

/// Wake message by it's message id.
Expand Down Expand Up @@ -697,6 +698,64 @@ mod tests {
),
Error::OutgoingMessagesBytesLimitExceeded,
);

// commit 5 bytes should be ok.
assert_ok!(message_context.send_commit(
handle,
HandlePacket::new(
Default::default(),
Payload::try_from([1, 2, 3, 4, 5].to_vec()).unwrap(),
0,
),
0,
None,
));

let messages = message_context.drain().0.drain().outgoing_dispatches;
assert_eq!(
messages[0].0.payload_bytes(),
[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
);
}

#[test]
fn send_commit_message_size_limit() {
let mut message_context = MessageContext::new(
Default::default(),
Default::default(),
ContextSettings::with_outgoing_limits(1024, u32::MAX),
)
.expect("Outgoing messages bytes limit exceeded");

let handle = message_context.send_init().unwrap();

// push 1 byte
assert_ok!(message_context.send_push(handle, &[1]));

let payload = Payload::filled_with(2);
assert_err!(
message_context.send_commit(
handle,
HandlePacket::new(Default::default(), payload, 0),
0,
None
),
Error::MaxMessageSizeExceed,
);

let payload = Payload::try_from(vec![1; Payload::max_len() - 1]).unwrap();
assert_ok!(message_context.send_commit(
handle,
HandlePacket::new(Default::default(), payload, 0),
0,
None,
));

let messages = message_context.drain().0.drain().outgoing_dispatches;
assert_eq!(
Payload::try_from(messages[0].0.payload_bytes().to_vec()).unwrap(),
Payload::filled_with(1)
);
}

#[test]
Expand Down Expand Up @@ -867,6 +926,30 @@ mod tests {
);
}

#[test]
fn reply_commit_message_size_limit() {
let mut message_context =
MessageContext::new(Default::default(), Default::default(), Default::default())
.expect("Outgoing messages bytes limit exceeded");

assert_ok!(message_context.reply_push(&[1]));

let payload = Payload::filled_with(2);
assert_err!(
message_context.reply_commit(ReplyPacket::new(payload, 0), None),
Error::MaxMessageSizeExceed,
);

let payload = Payload::try_from(vec![1; Payload::max_len() - 1]).unwrap();
assert_ok!(message_context.reply_commit(ReplyPacket::new(payload, 0), None));

let messages = message_context.drain().0.drain().outgoing_dispatches;
assert_eq!(
Payload::try_from(messages[0].0.payload_bytes().to_vec()).unwrap(),
Payload::filled_with(1)
);
}

#[test]
/// Test that covers full api of `MessageContext`
fn message_context_api() {
Expand Down
Loading
Loading