Skip to content

Commit

Permalink
Merge pull request #314 from dpkp/keyed_producer_failover
Browse files Browse the repository at this point in the history
Handle keyed producer failover
  • Loading branch information
dpkp committed Feb 20, 2015
2 parents 60a7378 + 6ed6ad5 commit 9ad0be6
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 145 deletions.
7 changes: 3 additions & 4 deletions kafka/partitioner/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@ def __init__(self, partitions):
"""
self.partitions = partitions

def partition(self, key, partitions):
def partition(self, key, partitions=None):
"""
Takes a string key and an optional list of partitions as arguments and
returns the partition to be used for the message
Arguments:
partitions: The list of partitions is passed in every call. This
may look like an overhead, but it will be useful
(in future) when we handle cases like rebalancing
key: the key to use for partitioning
partitions: (optional) a list of partitions.
"""
raise NotImplementedError('partition function has to be implemented')
4 changes: 3 additions & 1 deletion kafka/partitioner/hashed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ class HashedPartitioner(Partitioner):
Implements a partitioner which selects the target partition based on
the hash of the key
"""
def partition(self, key, partitions):
def partition(self, key, partitions=None):
if not partitions:
partitions = self.partitions
size = len(partitions)
idx = hash(key) % size

Expand Down
4 changes: 2 additions & 2 deletions kafka/partitioner/roundrobin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ def _set_partitions(self, partitions):
self.partitions = partitions
self.iterpart = cycle(partitions)

def partition(self, key, partitions):
def partition(self, key, partitions=None):
# Refresh the partition list if necessary
if self.partitions != partitions:
if partitions and self.partitions != partitions:
self._set_partitions(partitions)

return next(self.iterpart)
2 changes: 1 addition & 1 deletion kafka/producer/keyed.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _next_partition(self, topic, key):
self.partitioners[topic] = self.partitioner_class(self.client.get_partition_ids_for_topic(topic))

partitioner = self.partitioners[topic]
return partitioner.partition(key, self.client.get_partition_ids_for_topic(topic))
return partitioner.partition(key)

def send_messages(self,topic,key,*msg):
partition = self._next_partition(topic, key)
Expand Down
77 changes: 62 additions & 15 deletions test/test_failover_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from kafka import KafkaClient, SimpleConsumer
from kafka.common import TopicAndPartition, FailedPayloadsError, ConnectionError
from kafka.producer.base import Producer
from kafka.producer import KeyedProducer

from test.fixtures import ZookeeperFixture, KafkaFixture
from test.testutil import (
Expand All @@ -17,8 +18,7 @@
class TestFailover(KafkaIntegrationTestCase):
create_client = False

@classmethod
def setUpClass(cls): # noqa
def setUp(self):
if not os.environ.get('KAFKA_VERSION'):
return

Expand All @@ -27,33 +27,41 @@ def setUpClass(cls): # noqa
partitions = 2

# mini zookeeper, 2 kafka brokers
cls.zk = ZookeeperFixture.instance()
kk_args = [cls.zk.host, cls.zk.port, zk_chroot, replicas, partitions]
cls.brokers = [KafkaFixture.instance(i, *kk_args) for i in range(replicas)]
self.zk = ZookeeperFixture.instance()
kk_args = [self.zk.host, self.zk.port, zk_chroot, replicas, partitions]
self.brokers = [KafkaFixture.instance(i, *kk_args) for i in range(replicas)]

hosts = ['%s:%d' % (b.host, b.port) for b in cls.brokers]
cls.client = KafkaClient(hosts)
hosts = ['%s:%d' % (b.host, b.port) for b in self.brokers]
self.client = KafkaClient(hosts)
super(TestFailover, self).setUp()

@classmethod
def tearDownClass(cls):
def tearDown(self):
super(TestFailover, self).tearDown()
if not os.environ.get('KAFKA_VERSION'):
return

cls.client.close()
for broker in cls.brokers:
self.client.close()
for broker in self.brokers:
broker.close()
cls.zk.close()
self.zk.close()

@kafka_versions("all")
def test_switch_leader(self):
topic = self.topic
partition = 0

# Test the base class Producer -- send_messages to a specific partition
# Testing the base Producer class here so that we can easily send
# messages to a specific partition, kill the leader for that partition
# and check that after another broker takes leadership the producer
# is able to resume sending messages

# require that the server commit messages to all in-sync replicas
# so that failover doesn't lose any messages on server-side
# and we can assert that server-side message count equals client-side
producer = Producer(self.client, async=False,
req_acks=Producer.ACK_AFTER_CLUSTER_COMMIT)

# Send 10 random messages
# Send 100 random messages to a specific partition
self._send_random_messages(producer, topic, partition, 100)

# kill leader for partition
Expand All @@ -80,7 +88,7 @@ def test_switch_leader(self):
self._send_random_messages(producer, topic, partition, 100)

# count number of messages
# Should be equal to 10 before + 1 recovery + 10 after
# Should be equal to 100 before + 1 recovery + 100 after
self.assert_message_count(topic, 201, partitions=(partition,))


Expand Down Expand Up @@ -116,6 +124,45 @@ def test_switch_leader_async(self):
# Should be equal to 10 before + 1 recovery + 10 after
self.assert_message_count(topic, 21, partitions=(partition,))

@kafka_versions("all")
def test_switch_leader_keyed_producer(self):
    """Check that a KeyedProducer keeps sending after partition 0 loses its leader.

    Steps:
      1. Send 10 messages with random keys.
      2. Kill the leader for partition 0 via ``self._kill_leader``.
      3. Keep sending random keyed messages for up to 60 seconds until one
         is sent without error for a key that the producer's partitioner
         maps to partition 0.
      4. Send 10 more messages to confirm no further exceptions are raised.

    The test fails if step 3 does not succeed before the timeout.
    """
    topic = self.topic

    # ``async`` is a reserved word from Python 3.7 onwards, so writing it as
    # a literal keyword argument is a SyntaxError there. Unpacking a dict
    # passes the very same keyword while staying parseable everywhere.
    producer = KeyedProducer(self.client, **{'async': False})

    # Send 10 random messages
    for _ in range(10):
        key = random_string(3)
        msg = random_string(10)
        producer.send_messages(topic, key, msg)

    # kill leader for partition 0
    self._kill_leader(topic, 0)

    recovered = False
    started = time.time()
    timeout = 60
    while not recovered and (time.time() - started) < timeout:
        try:
            key = random_string(3)
            msg = random_string(10)
            producer.send_messages(topic, key, msg)
            # Reaching this line means the send did not raise; only count
            # it as recovery when the key maps to the partition whose
            # leader was killed.
            if producer.partitioners[topic].partition(key) == 0:
                recovered = True
        except (FailedPayloadsError, ConnectionError):
            logging.debug("caught exception sending message -- will retry")
            continue

    # Verify we successfully sent the message
    self.assertTrue(recovered)

    # send some more messages just to make sure no more exceptions
    for _ in range(10):
        key = random_string(3)
        msg = random_string(10)
        producer.send_messages(topic, key, msg)


def _send_random_messages(self, producer, topic, partition, n):
for j in range(n):
logging.debug('_send_random_message to %s:%d -- try %d', topic, partition, j)
Expand Down
Loading

0 comments on commit 9ad0be6

Please sign in to comment.