From 50fdd19a02803cc0d0743a4e57564f312d8f25a4 Mon Sep 17 00:00:00 2001 From: Manicben Date: Tue, 26 Jul 2022 18:00:05 +0100 Subject: [PATCH 1/2] Support setting principal and SASL extensions in oauth_cb and handle token failures --- src/confluent_kafka/src/confluent_kafka.c | 103 ++++++++++++++++++++-- tests/test_misc.py | 59 +++++++++++++ 2 files changed, 156 insertions(+), 6 deletions(-) diff --git a/src/confluent_kafka/src/confluent_kafka.c b/src/confluent_kafka/src/confluent_kafka.c index a9dd8c17a..47513d843 100644 --- a/src/confluent_kafka/src/confluent_kafka.c +++ b/src/confluent_kafka/src/confluent_kafka.c @@ -1522,6 +1522,62 @@ static void log_cb (const rd_kafka_t *rk, int level, CallState_resume(cs); } +/** + * @brief Translate Python \p key and \p value to C types and set on + * provided \p extensions char* array at the provided index. + * + * @returns 1 on success or 0 if an exception was raised. + */ +static int py_extensions_to_c (char **extensions, Py_ssize_t idx, + PyObject *key, PyObject *value) { + PyObject *ks, *ks8, *vo8 = NULL; + const char *k; + const char *v; + Py_ssize_t ksize = 0; + Py_ssize_t vsize = 0; + + if (!(ks = cfl_PyObject_Unistr(key))) { + PyErr_SetString(PyExc_TypeError, + "expected extension key to be unicode " + "string"); + return 0; + } + + k = cfl_PyUnistr_AsUTF8(ks, &ks8); + ksize = (Py_ssize_t)strlen(k); + + if (cfl_PyUnistr(_Check(value))) { + /* Unicode string, translate to utf-8. */ + v = cfl_PyUnistr_AsUTF8(value, &vo8); + if (!v) { + Py_DECREF(ks); + Py_XDECREF(ks8); + return 0; + } + vsize = (Py_ssize_t)strlen(v); + } else { + PyErr_Format(PyExc_TypeError, + "expected extension value to be " + "unicode string, not %s", + ((PyTypeObject *)PyObject_Type(value))-> + tp_name); + Py_DECREF(ks); + Py_XDECREF(ks8); + return 0; + } + + extensions[idx] = (char*)malloc(ksize); + strcpy(extensions[idx], k); + extensions[idx + 1] = (char*)malloc(vsize); + strcpy(extensions[idx + 1], v); + + Py_DECREF(ks); + Py_XDECREF(ks8); + Py_XDECREF(vo8); + + return 1; +} + static void oauth_cb (rd_kafka_t *rk, const char *oauthbearer_config, void *opaque) { Handle *h = opaque; @@ -1529,6 +1585,10 @@ static void oauth_cb (rd_kafka_t *rk, const char *oauthbearer_config, CallState *cs; const char *token; double expiry; + const char *principal = ""; + PyObject *extensions = NULL; + char **rd_extensions = NULL; + Py_ssize_t rd_extensions_size = 0; char err_msg[2048]; rd_kafka_resp_err_t err_code; @@ -1539,26 +1599,57 @@ static void oauth_cb (rd_kafka_t *rk, const char *oauthbearer_config, Py_DECREF(eo); if (!result) { - goto err; + goto fail; } - if (!PyArg_ParseTuple(result, "sd", &token, &expiry)) { + if (!PyArg_ParseTuple(result, "sd|sO!", &token, &expiry, &principal, &PyDict_Type, &extensions)) { Py_DECREF(result); - PyErr_Format(PyExc_TypeError, + PyErr_SetString(PyExc_TypeError, "expect returned value from oauth_cb " "to be (token_str, expiry_time) tuple"); goto err; } + + if (extensions) { + int len = (int)PyDict_Size(extensions); + rd_extensions = (char **)malloc(2 * len * sizeof(char *)); + Py_ssize_t pos = 0; + PyObject *ko, *vo; + while (PyDict_Next(extensions, &pos, &ko, &vo)) { + if (!py_extensions_to_c(rd_extensions, rd_extensions_size, ko, vo)) { + Py_DECREF(result); + free(rd_extensions); + goto err; + } + rd_extensions_size = rd_extensions_size + 2; + } + } + err_code = rd_kafka_oauthbearer_set_token(h->rk, token, (int64_t)(expiry * 1000), - "", NULL, 0, err_msg, + principal, (const char **)rd_extensions, rd_extensions_size, err_msg, sizeof(err_msg)); Py_DECREF(result); - if (err_code) { + if (rd_extensions) { + for(int i = 0; i < rd_extensions_size; i++) { + free(rd_extensions[i]); + } + free(rd_extensions); + } + + if (err_code != RD_KAFKA_RESP_ERR_NO_ERROR) { PyErr_Format(PyExc_ValueError, "%s", err_msg); - goto err; + goto fail; } goto done; +fail: + err_code = rd_kafka_oauthbearer_set_token_failure(h->rk, "OAuth callback raised exception"); + if (err_code != RD_KAFKA_RESP_ERR_NO_ERROR) { + PyErr_SetString(PyExc_ValueError, "Failed to set token failure"); + goto err; + } + PyErr_Clear(); + goto done; err: CallState_crash(cs); rd_kafka_yield(h->rk); diff --git a/tests/test_misc.py b/tests/test_misc.py index cdf1147fe..28ef15f04 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -159,6 +159,65 @@ def oauth_cb(oauth_config): kc.close() +seen_oauth_cb = False + + +def test_oauth_cb_principal_sasl_extensions(): + """ Tests oauth_cb. """ + + def oauth_cb(oauth_config): + global seen_oauth_cb + seen_oauth_cb = True + assert oauth_config == 'oauth_cb' + return 'token', time.time() + 300.0, oauth_config, {"extone": "extoneval", "exttwo": "exttwoval"} + + conf = {'group.id': 'test', + 'security.protocol': 'sasl_plaintext', + 'sasl.mechanisms': 'OAUTHBEARER', + 'socket.timeout.ms': '100', + 'session.timeout.ms': 1000, # Avoid close() blocking too long + 'sasl.oauthbearer.config': 'oauth_cb', + 'oauth_cb': oauth_cb + } + + kc = confluent_kafka.Consumer(**conf) + + while not seen_oauth_cb: + kc.poll(timeout=1) + kc.close() + + +# global variable for oauth_cb call back function +oauth_cb_count = 0 + + +def test_oauth_cb_failure(): + """ Tests oauth_cb. """ + + def oauth_cb(oauth_config): + global oauth_cb_count + oauth_cb_count += 1 + assert oauth_config == 'oauth_cb' + if oauth_cb_count == 2: + return 'token', time.time() + 300.0, oauth_config, {"extthree": "extthreeval"} + raise Exception + + conf = {'group.id': 'test', + 'security.protocol': 'sasl_plaintext', + 'sasl.mechanisms': 'OAUTHBEARER', + 'socket.timeout.ms': '100', + 'session.timeout.ms': 1000, # Avoid close() blocking too long + 'sasl.oauthbearer.config': 'oauth_cb', + 'oauth_cb': oauth_cb + } + + kc = confluent_kafka.Consumer(**conf) + + while oauth_cb_count < 2: + kc.poll(timeout=1) + kc.close() + + def skip_interceptors(): # Run interceptor test if monitoring-interceptor is found for path in ["/usr/lib", "/usr/local/lib", "staging/libs", "."]: From 3604dc8694afb9f3ffe2fd0ec24d99a9c1840e2e Mon Sep 17 00:00:00 2001 From: Emanuele Sabellico Date: Tue, 2 Aug 2022 20:39:02 +0200 Subject: [PATCH 2/2] removed global variables --- tests/test_misc.py | 70 +++++++++++++++------------------------------- 1 file changed, 23 insertions(+), 47 deletions(-) diff --git a/tests/test_misc.py b/tests/test_misc.py index 28ef15f04..ae016a3a9 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -24,22 +24,18 @@ def test_version(): assert confluent_kafka.version()[0] == confluent_kafka.__version__ -# global variable for error_cb call back function -seen_error_cb = False - - def test_error_cb(): """ Tests error_cb. """ + seen_error_cb = False def error_cb(error_msg): - global seen_error_cb + nonlocal seen_error_cb seen_error_cb = True acceptable_error_codes = (confluent_kafka.KafkaError._TRANSPORT, confluent_kafka.KafkaError._ALL_BROKERS_DOWN) assert error_msg.code() in acceptable_error_codes conf = {'bootstrap.servers': 'localhost:65531', # Purposely cause connection refused error 'group.id': 'test', - 'socket.timeout.ms': '100', 'session.timeout.ms': 1000, # Avoid close() blocking too long 'error_cb': error_cb } @@ -47,26 +43,22 @@ def error_cb(error_msg): kc = confluent_kafka.Consumer(**conf) kc.subscribe(["test"]) while not seen_error_cb: - kc.poll(timeout=1) + kc.poll(timeout=0.1) kc.close() -# global variable for stats_cb call back function -seen_stats_cb = False - - def test_stats_cb(): """ Tests stats_cb. """ + seen_stats_cb = False def stats_cb(stats_json_str): - global seen_stats_cb + nonlocal seen_stats_cb seen_stats_cb = True stats_json = json.loads(stats_json_str) assert len(stats_json['name']) > 0 conf = {'group.id': 'test', - 'socket.timeout.ms': '100', 'session.timeout.ms': 1000, # Avoid close() blocking too long 'statistics.interval.ms': 200, 'stats_cb': stats_cb @@ -76,22 +68,20 @@ def stats_cb(stats_json_str): kc.subscribe(["test"]) while not seen_stats_cb: - kc.poll(timeout=1) + kc.poll(timeout=0.1) kc.close() -seen_stats_cb_check_no_brokers = False - - def test_conf_none(): """ Issue #133 Test that None can be passed for NULL by setting bootstrap.servers to None. If None would be converted to a string then a broker would show up in statistics. Verify that it doesnt. """ + seen_stats_cb_check_no_brokers = False def stats_cb_check_no_brokers(stats_json_str): """ Make sure no brokers are reported in stats """ - global seen_stats_cb_check_no_brokers + nonlocal seen_stats_cb_check_no_brokers stats = json.loads(stats_json_str) assert len(stats['brokers']) == 0, "expected no brokers in stats: %s" % stats_json_str seen_stats_cb_check_no_brokers = True @@ -101,9 +91,8 @@ def stats_cb_check_no_brokers(stats_json_str): 'stats_cb': stats_cb_check_no_brokers} p = confluent_kafka.Producer(conf) - p.poll(timeout=1) + p.poll(timeout=0.1) - global seen_stats_cb_check_no_brokers assert seen_stats_cb_check_no_brokers @@ -130,15 +119,12 @@ def test_throttle_event_types(): assert str(throttle_event) == "broker/0 throttled for 10000 ms" -# global variable for oauth_cb call back function -seen_oauth_cb = False - - def test_oauth_cb(): """ Tests oauth_cb. """ + seen_oauth_cb = False def oauth_cb(oauth_config): - global seen_oauth_cb + nonlocal seen_oauth_cb seen_oauth_cb = True assert oauth_config == 'oauth_cb' return 'token', time.time() + 300.0 @@ -146,7 +132,6 @@ def oauth_cb(oauth_config): conf = {'group.id': 'test', 'security.protocol': 'sasl_plaintext', 'sasl.mechanisms': 'OAUTHBEARER', - 'socket.timeout.ms': '100', 'session.timeout.ms': 1000, # Avoid close() blocking too long 'sasl.oauthbearer.config': 'oauth_cb', 'oauth_cb': oauth_cb @@ -155,18 +140,16 @@ def oauth_cb(oauth_config): kc = confluent_kafka.Consumer(**conf) while not seen_oauth_cb: - kc.poll(timeout=1) + kc.poll(timeout=0.1) kc.close() -seen_oauth_cb = False - - def test_oauth_cb_principal_sasl_extensions(): """ Tests oauth_cb. """ + seen_oauth_cb = False def oauth_cb(oauth_config): - global seen_oauth_cb + nonlocal seen_oauth_cb seen_oauth_cb = True assert oauth_config == 'oauth_cb' return 'token', time.time() + 300.0, oauth_config, {"extone": "extoneval", "exttwo": "exttwoval"} @@ -174,8 +157,7 @@ def oauth_cb(oauth_config): conf = {'group.id': 'test', 'security.protocol': 'sasl_plaintext', 'sasl.mechanisms': 'OAUTHBEARER', - 'socket.timeout.ms': '100', - 'session.timeout.ms': 1000, # Avoid close() blocking too long + 'session.timeout.ms': 100, # Avoid close() blocking too long 'sasl.oauthbearer.config': 'oauth_cb', 'oauth_cb': oauth_cb } @@ -183,29 +165,25 @@ def oauth_cb(oauth_config): kc = confluent_kafka.Consumer(**conf) while not seen_oauth_cb: - kc.poll(timeout=1) + kc.poll(timeout=0.1) kc.close() -# global variable for oauth_cb call back function -oauth_cb_count = 0 - - def test_oauth_cb_failure(): """ Tests oauth_cb. """ + oauth_cb_count = 0 def oauth_cb(oauth_config): - global oauth_cb_count + nonlocal oauth_cb_count oauth_cb_count += 1 assert oauth_config == 'oauth_cb' if oauth_cb_count == 2: - return 'token', time.time() + 300.0, oauth_config, {"extthree": "extthreeval"} + return 'token', time.time() + 100.0, oauth_config, {"extthree": "extthreeval"} raise Exception conf = {'group.id': 'test', 'security.protocol': 'sasl_plaintext', 'sasl.mechanisms': 'OAUTHBEARER', - 'socket.timeout.ms': '100', 'session.timeout.ms': 1000, # Avoid close() blocking too long 'sasl.oauthbearer.config': 'oauth_cb', 'oauth_cb': oauth_cb @@ -214,7 +192,7 @@ def oauth_cb(oauth_config): kc = confluent_kafka.Consumer(**conf) while oauth_cb_count < 2: - kc.poll(timeout=1) + kc.poll(timeout=0.1) kc.close() @@ -253,11 +231,9 @@ def test_unordered_dict(init_func): client.poll(0) -# global variable for on_delivery call back function -seen_delivery_cb = False - - def test_topic_config_update(): + seen_delivery_cb = False + # *NOTE* default.topic.config has been deprecated. # This example remains to ensure backward-compatibility until its removal. confs = [{"message.timeout.ms": 600000, "default.topic.config": {"message.timeout.ms": 1000}}, @@ -266,7 +242,7 @@ def test_topic_config_update(): def on_delivery(err, msg): # Since there is no broker, produced messages should time out. - global seen_delivery_cb + nonlocal seen_delivery_cb seen_delivery_cb = True assert err.code() == confluent_kafka.KafkaError._MSG_TIMED_OUT