Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,8 @@ def trace_kafka_poll(

try:
res = wrapped(*args, **kwargs)
create_span("poll", res.topic(), res.headers())
if res:
create_span("poll", res.topic(), res.headers())
return res
except Exception as exc:
exception = exc
Expand Down
303 changes: 302 additions & 1 deletion tests/clients/kafka/test_confluent_kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@


import os
import threading
import time
from typing import Generator
from typing import Generator, List

import pytest
from confluent_kafka import Consumer, KafkaException, Producer
Expand Down Expand Up @@ -775,3 +776,303 @@ def test_trace_kafka_close_exception_handling(self, span: "InstanaSpan") -> None

# Verify span was ended
assert not span.is_recording()

def test_confluent_kafka_poll_returns_none(self) -> None:
consumer_config = self.kafka_config.copy()
consumer_config["group.id"] = "test-empty-poll-group"
consumer_config["auto.offset.reset"] = "earliest"

consumer = Consumer(consumer_config)
consumer.subscribe([testenv["kafka_topic"] + "_3"])

with self.tracer.start_as_current_span("test"):
msg = consumer.poll(timeout=0.1)

assert msg is None

consumer.close()

spans = self.recorder.queued_spans()

assert len(spans) == 1
test_span = spans[0]
assert test_span.n == "sdk"
assert test_span.data["sdk"]["name"] == "test"

def test_confluent_kafka_poll_returns_none_with_context_cleanup(self) -> None:
consumer_config = self.kafka_config.copy()
consumer_config["group.id"] = "test-context-cleanup-group"
consumer_config["auto.offset.reset"] = "earliest"

consumer = Consumer(consumer_config)
consumer.subscribe([testenv["kafka_topic"] + "_3"])

# Consume any existing messages to ensure topic is empty
while True:
msg = consumer.poll(timeout=0.5)
if msg is None:
break

# Clear any spans created during cleanup
self.recorder.clear_spans()

with self.tracer.start_as_current_span("test"):
for _ in range(3):
msg = consumer.poll(timeout=0.1)
assert msg is None

consumer.close()

spans = self.recorder.queued_spans()
assert len(spans) == 1
test_span = spans[0]
assert test_span.n == "sdk"

def test_confluent_kafka_poll_none_then_message(self) -> None:
# First, create a temporary consumer to clean up any existing messages
cleanup_config = self.kafka_config.copy()
cleanup_config["group.id"] = "test-none-then-message-cleanup"
cleanup_config["auto.offset.reset"] = "earliest"

cleanup_consumer = Consumer(cleanup_config)
cleanup_consumer.subscribe([testenv["kafka_topic"] + "_3"])

# Consume any existing messages
while True:
msg = cleanup_consumer.poll(timeout=0.5)
if msg is None:
break

cleanup_consumer.close()

# Clear any spans created during cleanup
self.recorder.clear_spans()

# Now run the actual test with a fresh consumer
consumer_config = self.kafka_config.copy()
consumer_config["group.id"] = "test-none-then-message-group"
consumer_config["auto.offset.reset"] = "earliest"

consumer = Consumer(consumer_config)
consumer.subscribe([testenv["kafka_topic"] + "_3"])

with self.tracer.start_as_current_span("test"):
msg1 = consumer.poll(timeout=0.1)
assert msg1 is None

self.producer.produce(testenv["kafka_topic"] + "_3", b"test_message")
self.producer.flush(timeout=10)

msg2 = consumer.poll(timeout=5)
assert msg2 is not None
assert msg2.value() == b"test_message"

consumer.close()

spans = self.recorder.queued_spans()
assert len(spans) == 3

kafka_span = get_first_span_by_filter(
spans,
lambda span: span.n == "kafka" and span.data["kafka"]["access"] == "poll",
)
assert kafka_span is not None
assert kafka_span.data["kafka"]["service"] == testenv["kafka_topic"] + "_3"

kafka_span = get_first_span_by_filter(
spans,
lambda span: span.n == "kafka"
and span.data["kafka"]["access"] == "produce",
)
assert kafka_span is not None
assert kafka_span.data["kafka"]["service"] == testenv["kafka_topic"] + "_3"

def test_confluent_kafka_poll_multithreaded_context_isolation(self) -> None:
agent.options.allow_exit_as_root = True
agent.options.set_trace_configurations()

# Produce messages to multiple topics
num_threads = 3
messages_per_topic = 2

for i in range(num_threads):
topic = f"{testenv['kafka_topic']}_thread_{i}"
# Create topic
try:
self.kafka_client.create_topics(
[NewTopic(topic, num_partitions=1, replication_factor=1)]
)
except KafkaException:
pass

# Produce messages
for j in range(messages_per_topic):
self.producer.produce(topic, f"message_{j}".encode())

self.producer.flush(timeout=10)
time.sleep(1) # Allow messages to be available

# Track results from each thread
thread_results: List[dict] = []
thread_errors: List[Exception] = []
lock = threading.Lock()

def consume_from_topic(thread_id: int) -> None:
try:
topic = f"{testenv['kafka_topic']}_thread_{thread_id}"
consumer_config = self.kafka_config.copy()
consumer_config["group.id"] = f"test-multithread-group-{thread_id}"
consumer_config["auto.offset.reset"] = "earliest"

consumer = Consumer(consumer_config)
consumer.subscribe([topic])

messages_consumed = 0
none_polls = 0
max_polls = 10

with self.tracer.start_as_current_span(f"thread-{thread_id}"):
for _ in range(max_polls):
msg = consumer.poll(timeout=1.0)

if msg is None:
none_polls += 1
_ = consumer_span.get(None)
else:
if msg.error():
continue
messages_consumed += 1

assert msg.topic() == topic

if messages_consumed >= messages_per_topic:
break

consumer.close()

with lock:
thread_results.append(
{
"thread_id": thread_id,
"topic": topic,
"messages_consumed": messages_consumed,
"none_polls": none_polls,
"success": True,
}
)

except Exception as e:
with lock:
thread_errors.append(e)
thread_results.append(
{"thread_id": thread_id, "success": False, "error": str(e)}
)

threads = []
for i in range(num_threads):
thread = threading.Thread(target=consume_from_topic, args=(i,))
threads.append(thread)
thread.start()

for thread in threads:
thread.join(timeout=30)

assert len(thread_errors) == 0, f"Errors in threads: {thread_errors}"

assert len(thread_results) == num_threads
for result in thread_results:
assert result[
"success"
], f"Thread {result['thread_id']} failed: {result.get('error')}"
assert (
result["messages_consumed"] == messages_per_topic
), f"Thread {result['thread_id']} consumed {result['messages_consumed']} messages, expected {messages_per_topic}"

spans = self.recorder.queued_spans()

expected_min_spans = num_threads * (1 + messages_per_topic * 2)
assert (
len(spans) >= expected_min_spans
), f"Expected at least {expected_min_spans} spans, got {len(spans)}"

for i in range(num_threads):
topic = f"{testenv['kafka_topic']}_thread_{i}"

poll_spans = [
s
for s in spans
if s.n == "kafka"
and s.data.get("kafka", {}).get("access") == "poll"
and s.data.get("kafka", {}).get("service") == topic
]

assert (
len(poll_spans) >= 1
), f"Expected poll spans for topic {topic}, got {len(poll_spans)}"

topics_to_delete = [
f"{testenv['kafka_topic']}_thread_{i}" for i in range(num_threads)
]
self.kafka_client.delete_topics(topics_to_delete)
time.sleep(1)

def test_confluent_kafka_poll_multithreaded_with_none_returns(self) -> None:
num_threads = 5

thread_errors: List[Exception] = []
lock = threading.Lock()

def poll_empty_topic(thread_id: int) -> None:
try:
consumer_config = self.kafka_config.copy()
consumer_config["group.id"] = f"test-empty-poll-{thread_id}"
consumer_config["auto.offset.reset"] = "earliest"

consumer = Consumer(consumer_config)
consumer.subscribe([testenv["kafka_topic"] + "_3"])

# Consume any existing messages to ensure topic is empty
while True:
msg = consumer.poll(timeout=0.5)
if msg is None:
break

with self.tracer.start_as_current_span(
f"empty-poll-thread-{thread_id}"
):
for _ in range(5):
msg = consumer.poll(timeout=0.1)
assert msg is None, "Expected None from empty topic"

time.sleep(0.01)

consumer.close()

except Exception as e:
with lock:
thread_errors.append(e)

threads = []
for i in range(num_threads):
thread = threading.Thread(target=poll_empty_topic, args=(i,))
threads.append(thread)
thread.start()

for thread in threads:
thread.join(timeout=10)

assert (
len(thread_errors) == 0
), f"Context errors in threads: {[str(e) for e in thread_errors]}"

spans = self.recorder.queued_spans()

test_spans = [s for s in spans if s.n == "sdk"]
assert (
len(test_spans) == num_threads
), f"Expected {num_threads} test spans, got {len(test_spans)}"

kafka_spans = [s for s in spans if s.n == "kafka"]
assert (
len(kafka_spans) == 0
), f"Expected no kafka spans for None polls, got {len(kafka_spans)}"
Loading