@@ -251,6 +251,13 @@ def delivery_timeout_ms(self):
251251 def next_expiry_time_ms (self ):
252252 return self ._next_batch_expiry_time_ms
253253
254+ def _tp_lock (self , tp ):
255+ if tp not in self ._tp_locks :
256+ with self ._tp_locks [None ]:
257+ if tp not in self ._tp_locks :
258+ self ._tp_locks [tp ] = threading .Lock ()
259+ return self ._tp_locks [tp ]
260+
254261 def append (self , tp , timestamp_ms , key , value , headers , now = None ):
255262 """Add a record to the accumulator, return the append result.
256263
@@ -275,12 +282,7 @@ def append(self, tp, timestamp_ms, key, value, headers, now=None):
275282 # not miss batches in abortIncompleteBatches().
276283 self ._appends_in_progress .increment ()
277284 try :
278- if tp not in self ._tp_locks :
279- with self ._tp_locks [None ]:
280- if tp not in self ._tp_locks :
281- self ._tp_locks [tp ] = threading .Lock ()
282-
283- with self ._tp_locks [tp ]:
285+ with self ._tp_lock (tp ):
284286 # check if we have an in-progress batch
285287 dq = self ._batches [tp ]
286288 if dq :
@@ -290,7 +292,7 @@ def append(self, tp, timestamp_ms, key, value, headers, now=None):
290292 batch_is_full = len (dq ) > 1 or last .records .is_full ()
291293 return future , batch_is_full , False
292294
293- with self ._tp_locks [ tp ] :
295+ with self ._tp_lock ( tp ) :
294296 # Need to check if producer is closed again after grabbing the
295297 # dequeue lock.
296298 assert not self ._closed , 'RecordAccumulator is closed'
@@ -333,8 +335,7 @@ def expired_batches(self, now=None):
333335 """Get a list of batches which have been sitting in the accumulator too long and need to be expired."""
334336 expired_batches = []
335337 for tp in list (self ._batches .keys ()):
336- assert tp in self ._tp_locks , 'TopicPartition not in locks dict'
337- with self ._tp_locks [tp ]:
338+ with self ._tp_lock (tp ):
338339 # iterate over the batches and expire them if they have stayed
339340 # in accumulator for more than request_timeout_ms
340341 dq = self ._batches [tp ]
@@ -352,14 +353,12 @@ def expired_batches(self, now=None):
352353
353354 def reenqueue (self , batch , now = None ):
354355 """
355- Re-enqueue the given record batch in the accumulator. In Sender.completeBatch method, we check
356- whether the batch has reached deliveryTimeoutMs or not. Hence we do not do the delivery timeout check here.
356+ Re-enqueue the given record batch in the accumulator. In Sender._complete_batch method, we check
357+ whether the batch has reached delivery_timeout_ms or not. Hence we do not do the delivery timeout check here.
357358 """
358359 batch .retry (now = now )
359- assert batch .topic_partition in self ._tp_locks , 'TopicPartition not in locks dict'
360- assert batch .topic_partition in self ._batches , 'TopicPartition not in batches'
361- dq = self ._batches [batch .topic_partition ]
362- with self ._tp_locks [batch .topic_partition ]:
360+ with self ._tp_lock (batch .topic_partition ):
361+ dq = self ._batches [batch .topic_partition ]
363362 dq .appendleft (batch )
364363
365364 def ready (self , cluster , now = None ):
@@ -412,7 +411,7 @@ def ready(self, cluster, now=None):
412411 elif tp in self .muted :
413412 continue
414413
415- with self ._tp_locks [ tp ] :
414+ with self ._tp_lock ( tp ) :
416415 dq = self ._batches [tp ]
417416 if not dq :
418417 continue
@@ -445,7 +444,7 @@ def ready(self, cluster, now=None):
445444 def has_undrained (self ):
446445 """Check whether there are any batches which haven't been drained"""
447446 for tp in list (self ._batches .keys ()):
448- with self ._tp_locks [ tp ] :
447+ with self ._tp_lock ( tp ) :
449448 dq = self ._batches [tp ]
450449 if len (dq ):
451450 return True
@@ -485,7 +484,7 @@ def drain_batches_for_one_node(self, cluster, node_id, max_size, now=None):
485484 if tp not in self ._batches :
486485 continue
487486
488- with self ._tp_locks [ tp ] :
487+ with self ._tp_lock ( tp ) :
489488 dq = self ._batches [tp ]
490489 if len (dq ) == 0 :
491490 continue
@@ -619,7 +618,7 @@ def _abort_batches(self, error):
619618 for batch in self ._incomplete .all ():
620619 tp = batch .topic_partition
621620 # Close the batch before aborting
622- with self ._tp_locks [ tp ] :
621+ with self ._tp_lock ( tp ) :
623622 batch .records .close ()
624623 self ._batches [tp ].remove (batch )
625624 batch .abort (error )
@@ -628,7 +627,7 @@ def _abort_batches(self, error):
628627 def abort_undrained_batches (self , error ):
629628 for batch in self ._incomplete .all ():
630629 tp = batch .topic_partition
631- with self ._tp_locks [ tp ] :
630+ with self ._tp_lock ( tp ) :
632631 aborted = False
633632 if not batch .is_done :
634633 aborted = True
0 commit comments