Skip to content

Commit

Permalink
update accumulator sig to return Result<TResult> instead of TResult (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
john-z-yang committed May 17, 2024
1 parent 41457bf commit 2b0d1ee
Showing 1 changed file with 137 additions and 26 deletions.
163 changes: 137 additions & 26 deletions rust-arroyo/src/processing/strategies/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ use super::InvalidMessage;

struct BatchState<T, TResult> {
value: Option<TResult>,
accumulator: Arc<dyn Fn(TResult, T) -> TResult + Send + Sync>,
accumulator: Arc<
dyn Fn(TResult, Message<T>) -> Result<TResult, (SubmitError<T>, TResult)> + Send + Sync,
>,
offsets: BTreeMap<Partition, u64>,
batch_start_time: Deadline,
message_count: usize,
Expand All @@ -24,7 +26,9 @@ struct BatchState<T, TResult> {
impl<T, TResult> BatchState<T, TResult> {
fn new(
initial_value: TResult,
accumulator: Arc<dyn Fn(TResult, T) -> TResult + Send + Sync>,
accumulator: Arc<
dyn Fn(TResult, Message<T>) -> Result<TResult, (SubmitError<T>, TResult)> + Send + Sync,
>,
max_batch_time: Duration,
compute_batch_size: fn(&T) -> usize,
) -> BatchState<T, TResult> {
Expand All @@ -38,21 +42,33 @@ impl<T, TResult> BatchState<T, TResult> {
}
}

fn add(&mut self, message: Message<T>) {
for (partition, offset) in message.committable() {
self.offsets.insert(partition, offset);
fn add(&mut self, message: Message<T>) -> Result<(), SubmitError<T>> {
let commitable: Vec<_> = message.committable().collect();
let message_count = (self.compute_batch_size)(&message.payload());
let prev_result = self.value.take().unwrap();

match (self.accumulator)(prev_result, message) {
Ok(result) => {
self.value = Some(result);
self.message_count += message_count;
for (partition, offset) in commitable {
self.offsets.insert(partition, offset);
}
Ok(())
}
Err((submit_error, prev_result)) => {
self.value = Some(prev_result);
Err(submit_error)
}
}

let tmp = self.value.take().unwrap();
let payload = message.into_payload();
self.message_count += (self.compute_batch_size)(&payload);
self.value = Some((self.accumulator)(tmp, payload));
}
}

pub struct Reduce<T, TResult> {
next_step: Box<dyn ProcessingStrategy<TResult>>,
accumulator: Arc<dyn Fn(TResult, T) -> TResult + Send + Sync>,
accumulator: Arc<
dyn Fn(TResult, Message<T>) -> Result<TResult, (SubmitError<T>, TResult)> + Send + Sync,
>,
initial_value: Arc<dyn Fn() -> TResult + Send + Sync>,
max_batch_size: usize,
max_batch_time: Duration,
Expand All @@ -79,7 +95,7 @@ impl<T: Send + Sync, TResult: Send + Sync> ProcessingStrategy<T> for Reduce<T, T
return Err(SubmitError::MessageRejected(MessageRejected { message }));
}

self.batch_state.add(message);
self.batch_state.add(message)?;

Ok(())
}
Expand Down Expand Up @@ -123,7 +139,9 @@ impl<T: Send + Sync, TResult: Send + Sync> ProcessingStrategy<T> for Reduce<T, T
impl<T, TResult> Reduce<T, TResult> {
pub fn new<N>(
next_step: N,
accumulator: Arc<dyn Fn(TResult, T) -> TResult + Send + Sync>,
accumulator: Arc<
dyn Fn(TResult, Message<T>) -> Result<TResult, (SubmitError<T>, TResult)> + Send + Sync,
>,
initial_value: Arc<dyn Fn() -> TResult + Send + Sync>,
max_batch_size: usize,
max_batch_time: Duration,
Expand Down Expand Up @@ -217,7 +235,7 @@ impl<T, TResult> Reduce<T, TResult> {
mod tests {
use crate::processing::strategies::reduce::Reduce;
use crate::processing::strategies::{
CommitRequest, ProcessingStrategy, StrategyError, SubmitError,
CommitRequest, MessageRejected, ProcessingStrategy, StrategyError, SubmitError,
};
use crate::types::{BrokerMessage, InnerMessage, Message, Partition, Topic};
use std::sync::{Arc, Mutex};
Expand Down Expand Up @@ -255,9 +273,9 @@ mod tests {
let max_batch_time = Duration::from_secs(1);

let initial_value = Vec::new();
let accumulator = Arc::new(|mut acc: Vec<u64>, value: u64| {
acc.push(value);
acc
let accumulator = Arc::new(|mut acc: Vec<u64>, msg: Message<u64>| {
acc.push(msg.into_payload());
Ok(acc)
});
let compute_batch_size = |_: &_| -> usize { 1 };

Expand Down Expand Up @@ -300,6 +318,99 @@ mod tests {
);
}

#[test]
fn test_reduce_with_backpressure() {
let submitted_messages = Arc::new(Mutex::new(Vec::new()));
let submitted_messages_clone = submitted_messages.clone();

let partition1 = Partition::new(Topic::new("test"), 0);

let max_batch_size = 2;
let max_batch_time = Duration::from_secs(1);

#[derive(Clone, Debug, PartialEq)]
struct Buffer<T> {
data: Vec<T>,
flushed: bool,
}

let initial_value = Buffer {
data: Vec::new(),
flushed: false,
};

let accumulator = Arc::new(move |mut acc: Buffer<u64>, msg: Message<u64>| {
if acc.flushed {
acc.data.push(msg.into_payload());
acc.flushed = false;
Ok(acc)
} else {
acc.flushed = true;
Err((
SubmitError::MessageRejected(MessageRejected { message: msg }),
acc,
))
}
});
let compute_batch_size = |_: &_| -> usize { 1 };

let next_step = NextStep {
submitted: submitted_messages,
};

let mut strategy = Reduce::new(
next_step,
accumulator,
Arc::new(move || initial_value.clone()),
max_batch_size,
max_batch_time,
compute_batch_size,
);

for i in 0..3 {
let msg = Message {
inner_message: InnerMessage::BrokerMessage(BrokerMessage::new(
i,
partition1,
i,
chrono::Utc::now(),
)),
};
let res = strategy.submit(msg);
match res {
Err(SubmitError::MessageRejected(MessageRejected { message })) => {
strategy.submit(message).unwrap();
}
_ => {
unreachable!("Strategy should have backpressured")
}
};
let _ = strategy.poll();
}

// 3 messages with a max batch size of 2 means 1 batch was cleared
// and 1 message is left before next size limit.
assert_eq!(strategy.batch_state.message_count, 1);

strategy.close();
let _ = strategy.join(None);

// 2 batches were created
assert_eq!(
*submitted_messages_clone.lock().unwrap(),
vec![
Buffer {
data: vec![0, 1],
flushed: false
},
Buffer {
data: vec![2],
flushed: false
}
]
);
}

#[test]
fn test_reduce_with_custom_batch_size() {
let submitted_messages = Arc::new(Mutex::new(Vec::new()));
Expand All @@ -311,9 +422,9 @@ mod tests {
let max_batch_time = Duration::from_secs(1);

let initial_value = Vec::new();
let accumulator = Arc::new(|mut acc: Vec<u64>, value: u64| {
acc.push(value);
acc
let accumulator = Arc::new(|mut acc: Vec<u64>, msg: Message<u64>| {
acc.push(msg.into_payload());
Ok(acc)
});
let compute_batch_size = |_: &_| -> usize { 5 };

Expand Down Expand Up @@ -367,9 +478,9 @@ mod tests {
let max_batch_time = Duration::from_secs(100);

let initial_value = Vec::new();
let accumulator = Arc::new(|mut acc: Vec<u64>, value: u64| {
acc.push(value);
acc
let accumulator = Arc::new(|mut acc: Vec<u64>, msg: Message<u64>| {
acc.push(msg.into_payload());
Ok(acc)
});
let compute_batch_size = |_: &_| -> usize { 0 };

Expand Down Expand Up @@ -420,9 +531,9 @@ mod tests {
let max_batch_time = Duration::from_secs(100);

let initial_value = Vec::new();
let accumulator = Arc::new(|mut acc: Vec<u64>, value: u64| {
acc.push(value);
acc
let accumulator = Arc::new(|mut acc: Vec<u64>, msg: Message<u64>| {
acc.push(msg.into_payload());
Ok(acc)
});
let compute_batch_size = |_: &_| -> usize { 0 };

Expand Down

0 comments on commit 2b0d1ee

Please sign in to comment.