Skip to content

Commit

Permalink
fix(rust): Clear all backpressure state between assignments
Browse files Browse the repository at this point in the history
This is a port of getsentry/arroyo#299

SNS-2517
  • Loading branch information
untitaker committed Nov 17, 2023
1 parent 05a0ebb commit 462a2fa
Showing 1 changed file with 45 additions and 29 deletions.
74 changes: 45 additions & 29 deletions rust_snuba/rust_arroyo/src/processing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,28 @@ pub enum RunError {
Pause(#[source] ConsumerError),
}

struct Strategies<TPayload> {
struct ConsumerState<TPayload> {
processing_factory: Box<dyn ProcessingStrategyFactory<TPayload>>,
strategy: Option<Box<dyn ProcessingStrategy<TPayload>>>,
backpressure_timestamp: Option<Instant>,
is_paused: bool,
metrics_buffer: metrics_buffer::MetricsBuffer,
}

impl<TPayload> ConsumerState<TPayload> {
fn clear_backpressure(&mut self) {
if self.backpressure_timestamp.is_some() {
self.metrics_buffer.incr_timing(
"arroyo.consumer.backpressure.time",
self.backpressure_timestamp.unwrap().elapsed(),
);
self.backpressure_timestamp = None;
}
}
}

struct Callbacks<TPayload> {
strategies: Arc<Mutex<Strategies<TPayload>>>,
strategies: Arc<Mutex<ConsumerState<TPayload>>>,
consumer: Arc<Mutex<dyn Consumer<TPayload>>>,
}

Expand Down Expand Up @@ -85,6 +100,8 @@ impl<TPayload: 'static> AssignmentCallbacks for Callbacks<TPayload> {
}
}
stg.strategy = None;
stg.is_paused = false;
stg.clear_backpressure();

metrics.timing(
"arroyo.consumer.join.time",
Expand All @@ -100,7 +117,7 @@ impl<TPayload: 'static> AssignmentCallbacks for Callbacks<TPayload> {

impl<TPayload> Callbacks<TPayload> {
pub fn new(
strategies: Arc<Mutex<Strategies<TPayload>>>,
strategies: Arc<Mutex<ConsumerState<TPayload>>>,
consumer: Arc<Mutex<dyn Consumer<TPayload>>>,
) -> Self {
Self {
Expand All @@ -116,11 +133,9 @@ impl<TPayload> Callbacks<TPayload> {
/// partition revocation.
pub struct StreamProcessor<TPayload> {
consumer: Arc<Mutex<dyn Consumer<TPayload>>>,
strategies: Arc<Mutex<Strategies<TPayload>>>,
consumer_state: Arc<Mutex<ConsumerState<TPayload>>>,
message: Option<Message<TPayload>>,
processor_handle: ProcessorHandle,
backpressure_timestamp: Option<Instant>,
is_paused: bool,
metrics_buffer: metrics_buffer::MetricsBuffer,
}

Expand All @@ -129,27 +144,28 @@ impl<TPayload: 'static> StreamProcessor<TPayload> {
consumer: Arc<Mutex<dyn Consumer<TPayload>>>,
processing_factory: Box<dyn ProcessingStrategyFactory<TPayload>>,
) -> Self {
let strategies = Arc::new(Mutex::new(Strategies {
let consumer_state = Arc::new(Mutex::new(ConsumerState {
processing_factory,
strategy: None,
backpressure_timestamp: None,
is_paused: false,
metrics_buffer: metrics_buffer::MetricsBuffer::new(),
}));

Self {
consumer,
strategies,
consumer_state,
message: None,
processor_handle: ProcessorHandle {
shutdown_requested: Arc::new(AtomicBool::new(false)),
},
backpressure_timestamp: None,
is_paused: false,
metrics_buffer: metrics_buffer::MetricsBuffer::new(),
}
}

pub fn subscribe(&mut self, topic: Topic) {
let callbacks: Box<dyn AssignmentCallbacks> = Box::new(Callbacks::new(
self.strategies.clone(),
self.consumer_state.clone(),
self.consumer.clone(),
));
self.consumer
Expand All @@ -163,7 +179,7 @@ impl<TPayload: 'static> StreamProcessor<TPayload> {
let metrics = get_metrics();
metrics.increment("arroyo.consumer.run_once", 1, None);

if self.is_paused {
if self.consumer_state.lock().unwrap().is_paused {
// If the consumer waas paused, it should not be returning any messages
// on ``poll``.
let res = self
Expand Down Expand Up @@ -202,8 +218,12 @@ impl<TPayload: 'static> StreamProcessor<TPayload> {
}
}

let mut trait_callbacks = self.strategies.lock().unwrap();
let Some(strategy) = trait_callbacks.strategy.as_mut() else {
// since we do not drive the kafka consumer at this point, it is safe to acquire the state
// lock, as we can be sure that for the rest of this function, no assignment callback will
// run.
let mut consumer_state = self.consumer_state.lock().unwrap();

let Some(strategy) = consumer_state.strategy.as_mut() else {
match self.message.as_ref() {
None => return Ok(()),
Some(_) => return Err(RunError::InvalidState),
Expand Down Expand Up @@ -231,16 +251,17 @@ impl<TPayload: 'static> StreamProcessor<TPayload> {
"arroyo.consumer.processing.time",
processing_start.elapsed(),
);

match ret {
Ok(()) => {
// Resume if we are currently in a paused state
if self.is_paused {
if consumer_state.is_paused {
let mut consumer = self.consumer.lock().unwrap();
let partitions = consumer.tell().unwrap().into_keys().collect();

match consumer.resume(partitions) {
Ok(()) => {
self.is_paused = false;
consumer_state.is_paused = false;
}
Err(error) => {
tracing::error!(%error, "pause error");
Expand All @@ -250,27 +271,22 @@ impl<TPayload: 'static> StreamProcessor<TPayload> {
}

// Clear backpressure timestamp if it is set
if self.backpressure_timestamp.is_some() {
self.metrics_buffer.incr_timing(
"arroyo.consumer.backpressure.time",
self.backpressure_timestamp.unwrap().elapsed(),
);
self.backpressure_timestamp = None;
}
consumer_state.clear_backpressure();
}
Err(SubmitError::MessageRejected(MessageRejected { message })) => {
// Put back the carried over message
self.message = Some(message);

if self.backpressure_timestamp.is_none() {
self.backpressure_timestamp = Some(Instant::now());
if consumer_state.backpressure_timestamp.is_none() {
consumer_state.backpressure_timestamp = Some(Instant::now());
}

// If we are in the backpressure state for more than 1 second,
// we pause the consumer and hold the message until it is
// accepted, at which point we can resume consuming.
if !self.is_paused && self.backpressure_timestamp.is_some() {
let backpressure_duration = self.backpressure_timestamp.unwrap().elapsed();
if !consumer_state.is_paused && consumer_state.backpressure_timestamp.is_some() {
let backpressure_duration =
consumer_state.backpressure_timestamp.unwrap().elapsed();

if backpressure_duration < Duration::from_secs(1) {
return Ok(());
Expand All @@ -285,7 +301,7 @@ impl<TPayload: 'static> StreamProcessor<TPayload> {

match consumer.pause(partitions) {
Ok(()) => {
self.is_paused = true;
consumer_state.is_paused = true;
}
Err(error) => {
tracing::error!(%error, "pause error");
Expand All @@ -310,7 +326,7 @@ impl<TPayload: 'static> StreamProcessor<TPayload> {
.load(Ordering::Relaxed)
{
if let Err(e) = self.run_once() {
let mut trait_callbacks = self.strategies.lock().unwrap();
let mut trait_callbacks = self.consumer_state.lock().unwrap();

if let Some(strategy) = trait_callbacks.strategy.as_mut() {
strategy.terminate();
Expand Down

0 comments on commit 462a2fa

Please sign in to comment.