From 5329156885059d684951fe583b71f857c7f17214 Mon Sep 17 00:00:00 2001 From: Robert Yokota Date: Thu, 20 Nov 2025 13:02:10 -0800 Subject: [PATCH] DGS-22899 Fix support for wrapped Avro unions --- .../schema_registry/common/avro.py | 27 ++- .../_async/test_avro_serdes.py | 229 ++++++++++++++++++ .../schema_registry/_sync/test_avro_serdes.py | 229 ++++++++++++++++++ 3 files changed, 478 insertions(+), 7 deletions(-) diff --git a/src/confluent_kafka/schema_registry/common/avro.py b/src/confluent_kafka/schema_registry/common/avro.py index 199867318..f7ab32dfe 100644 --- a/src/confluent_kafka/schema_registry/common/avro.py +++ b/src/confluent_kafka/schema_registry/common/avro.py @@ -5,7 +5,7 @@ from collections import defaultdict from copy import deepcopy from io import BytesIO -from typing import Dict, Optional, Set, Union +from typing import Dict, Optional, Set, Tuple, Union from fastavro import repository, validate from fastavro.schema import load_schema @@ -42,6 +42,7 @@ bytes, # 'bytes' list, # 'array' dict, # 'map' and 'record' + tuple, # wrapped union type ] AvroSchema = Union[str, list, dict] @@ -108,10 +109,13 @@ def transform( if field_ctx is not None: field_ctx.field_type = get_type(schema) if isinstance(schema, list): - subschema = _resolve_union(schema, message) + (subschema, submessage) = _resolve_union(schema, message) if subschema is None: return message - return transform(ctx, subschema, message, field_transform) + submessage = transform(ctx, subschema, submessage, field_transform) + if isinstance(message, tuple) and len(message) == 2: + return (message[0], submessage) + return submessage elif isinstance(schema, dict): schema_type = schema.get("type") if schema_type == 'array': @@ -207,14 +211,23 @@ def _disjoint(tags1: Set[str], tags2: Set[str]) -> bool: return True -def _resolve_union(schema: AvroSchema, message: AvroMessage) -> Optional[AvroSchema]: +def _resolve_union(schema: AvroSchema, message: AvroMessage) -> Tuple[Optional[AvroSchema], AvroMessage]: + is_wrapped_union = isinstance(message, tuple) and len(message) == 2 + is_typed_union = isinstance(message, dict) and '-type' in message for subschema in schema: try: - validate(message, subschema) + if is_wrapped_union: + if isinstance(subschema, dict) and subschema["name"] == message[0]: + return (subschema, message[1]) + elif is_typed_union: + if isinstance(subschema, dict) and subschema["name"] == message['-type']: + return (subschema, message) + else: + validate(message, subschema) + return (subschema, message) except: # noqa: E722 continue - return subschema - return None + return (None, message) def get_inline_tags(schema: AvroSchema) -> Dict[str, Set[str]]: diff --git a/tests/schema_registry/_async/test_avro_serdes.py b/tests/schema_registry/_async/test_avro_serdes.py index c1f1b53ad..43548a755 100644 --- a/tests/schema_registry/_async/test_avro_serdes.py +++ b/tests/schema_registry/_async/test_avro_serdes.py @@ -1275,6 +1275,235 @@ async def test_avro_encryption_deterministic(): assert obj == obj2 +async def test_avro_encryption_wrapped_union(): + executor = FieldEncryptionExecutor.register_with_clock(FakeClock()) + + conf = {'url': _BASE_URL} + client = AsyncSchemaRegistryClient.new_client(conf) + ser_conf = {'auto.register.schemas': False, 'use.latest.version': True} + rule_conf = {'secret': 'mysecret'} + schema = { + "fields": [ + { + "name": "id", + "type": "int" + }, + { + "name": "result", + "type": [ + "null", + { + "fields": [ + { + "name": "code", + "type": "int" + }, + { + "confluent:tags": [ + "PII" + ], + "name": "secret", + "type": [ + "null", + "string" + ] + } + ], + "name": "Data", + "type": "record" + }, + { + "fields": [ + { + "name": "code", + "type": "int" + }, + { + "name": "reason", + "type": [ + "null", + "string" + ] + } + ], + "name": "Error", + "type": "record" + } + ] + } + ], + "name": "Result", + "namespace": "com.acme", + "type": "record" + } + + rule = Rule( + "test-encrypt", + "", + RuleKind.TRANSFORM, + RuleMode.WRITEREAD, + "ENCRYPT", + ["PII"], + RuleParams({ + "encrypt.kek.name": "kek1", + "encrypt.kms.type": "local-kms", + "encrypt.kms.key.id": "mykey" + }), + None, + None, + "ERROR,NONE", + False + ) + await client.register_schema(_SUBJECT, Schema( + json.dumps(schema), + "AVRO", + [], + None, + RuleSet(None, [rule]) + )) + + obj = { + 'id': 123, + 'result': ( + 'com.acme.Data', { + 'code': 456, + 'secret': 'mypii' + } + ) + } + ser = await AsyncAvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf) + dek_client = executor.executor.client + ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) + obj_bytes = await ser(obj, ser_ctx) + + # reset encrypted fields + assert obj['result'][1]['secret'] != 'mypii' + # remove union wrapper + obj['result'] = { + 'code': 456, + 'secret': 'mypii' + } + + deser = await AsyncAvroDeserializer(client, rule_conf=rule_conf) + executor.executor.client = dek_client + obj2 = await deser(obj_bytes, ser_ctx) + assert obj == obj2 + + +async def test_avro_encryption_typed_union(): + executor = FieldEncryptionExecutor.register_with_clock(FakeClock()) + + conf = {'url': _BASE_URL} + client = AsyncSchemaRegistryClient.new_client(conf) + ser_conf = {'auto.register.schemas': False, 'use.latest.version': True} + rule_conf = {'secret': 'mysecret'} + schema = { + "fields": [ + { + "name": "id", + "type": "int" + }, + { + "name": "result", + "type": [ + "null", + { + "fields": [ + { + "name": "code", + "type": "int" + }, + { + "confluent:tags": [ + "PII" + ], + "name": "secret", + "type": [ + "null", + "string" + ] + } + ], + "name": "Data", + "type": "record" + }, + { + "fields": [ + { + "name": "code", + "type": "int" + }, + { + "name": "reason", + "type": [ + "null", + "string" + ] + } + ], + "name": "Error", + "type": "record" + } + ] + } + ], + "name": "Result", + "namespace": "com.acme", + "type": "record" + } + + rule = Rule( + "test-encrypt", + "", + RuleKind.TRANSFORM, + RuleMode.WRITEREAD, + "ENCRYPT", + ["PII"], + RuleParams({ + "encrypt.kek.name": "kek1", + "encrypt.kms.type": "local-kms", + "encrypt.kms.key.id": "mykey" + }), + None, + None, + "ERROR,NONE", + False + ) + await client.register_schema(_SUBJECT, Schema( + json.dumps(schema), + "AVRO", + [], + None, + RuleSet(None, [rule]) + )) + + obj = { + 'id': 123, + 'result': { + '-type': 'com.acme.Data', + 'code': 456, + 'secret': 'mypii' + } + } + ser = await AsyncAvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf) + dek_client = executor.executor.client + ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) + obj_bytes = await ser(obj, ser_ctx) + + # reset encrypted fields + assert obj['result']['secret'] != 'mypii' + # remove union wrapper + obj['result'] = { + 'code': 456, + 'secret': 'mypii' + } + + deser = await AsyncAvroDeserializer(client, rule_conf=rule_conf) + executor.executor.client = dek_client + obj2 = await deser(obj_bytes, ser_ctx) + assert obj == obj2 + + async def test_avro_encryption_cel(): executor = FieldEncryptionExecutor.register_with_clock(FakeClock()) diff --git a/tests/schema_registry/_sync/test_avro_serdes.py b/tests/schema_registry/_sync/test_avro_serdes.py index 30e67cf3c..77e15e2b4 100644 --- a/tests/schema_registry/_sync/test_avro_serdes.py +++ b/tests/schema_registry/_sync/test_avro_serdes.py @@ -1275,6 +1275,235 @@ def test_avro_encryption_deterministic(): assert obj == obj2 +def test_avro_encryption_wrapped_union(): + executor = FieldEncryptionExecutor.register_with_clock(FakeClock()) + + conf = {'url': _BASE_URL} + client = SchemaRegistryClient.new_client(conf) + ser_conf = {'auto.register.schemas': False, 'use.latest.version': True} + rule_conf = {'secret': 'mysecret'} + schema = { + "fields": [ + { + "name": "id", + "type": "int" + }, + { + "name": "result", + "type": [ + "null", + { + "fields": [ + { + "name": "code", + "type": "int" + }, + { + "confluent:tags": [ + "PII" + ], + "name": "secret", + "type": [ + "null", + "string" + ] + } + ], + "name": "Data", + "type": "record" + }, + { + "fields": [ + { + "name": "code", + "type": "int" + }, + { + "name": "reason", + "type": [ + "null", + "string" + ] + } + ], + "name": "Error", + "type": "record" + } + ] + } + ], + "name": "Result", + "namespace": "com.acme", + "type": "record" + } + + rule = Rule( + "test-encrypt", + "", + RuleKind.TRANSFORM, + RuleMode.WRITEREAD, + "ENCRYPT", + ["PII"], + RuleParams({ + "encrypt.kek.name": "kek1", + "encrypt.kms.type": "local-kms", + "encrypt.kms.key.id": "mykey" + }), + None, + None, + "ERROR,NONE", + False + ) + client.register_schema(_SUBJECT, Schema( + json.dumps(schema), + "AVRO", + [], + None, + RuleSet(None, [rule]) + )) + + obj = { + 'id': 123, + 'result': ( + 'com.acme.Data', { + 'code': 456, + 'secret': 'mypii' + } + ) + } + ser = AvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf) + dek_client = executor.executor.client + ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) + obj_bytes = ser(obj, ser_ctx) + + # reset encrypted fields + assert obj['result'][1]['secret'] != 'mypii' + # remove union wrapper + obj['result'] = { + 'code': 456, + 'secret': 'mypii' + } + + deser = AvroDeserializer(client, rule_conf=rule_conf) + executor.executor.client = dek_client + obj2 = deser(obj_bytes, ser_ctx) + assert obj == obj2 + + +def test_avro_encryption_typed_union(): + executor = FieldEncryptionExecutor.register_with_clock(FakeClock()) + + conf = {'url': _BASE_URL} + client = SchemaRegistryClient.new_client(conf) + ser_conf = {'auto.register.schemas': False, 'use.latest.version': True} + rule_conf = {'secret': 'mysecret'} + schema = { + "fields": [ + { + "name": "id", + "type": "int" + }, + { + "name": "result", + "type": [ + "null", + { + "fields": [ + { + "name": "code", + "type": "int" + }, + { + "confluent:tags": [ + "PII" + ], + "name": "secret", + "type": [ + "null", + "string" + ] + } + ], + "name": "Data", + "type": "record" + }, + { + "fields": [ + { + "name": "code", + "type": "int" + }, + { + "name": "reason", + "type": [ + "null", + "string" + ] + } + ], + "name": "Error", + "type": "record" + } + ] + } + ], + "name": "Result", + "namespace": "com.acme", + "type": "record" + } + + rule = Rule( + "test-encrypt", + "", + RuleKind.TRANSFORM, + RuleMode.WRITEREAD, + "ENCRYPT", + ["PII"], + RuleParams({ + "encrypt.kek.name": "kek1", + "encrypt.kms.type": "local-kms", + "encrypt.kms.key.id": "mykey" + }), + None, + None, + "ERROR,NONE", + False + ) + client.register_schema(_SUBJECT, Schema( + json.dumps(schema), + "AVRO", + [], + None, + RuleSet(None, [rule]) + )) + + obj = { + 'id': 123, + 'result': { + '-type': 'com.acme.Data', + 'code': 456, + 'secret': 'mypii' + } + } + ser = AvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf) + dek_client = executor.executor.client + ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) + obj_bytes = ser(obj, ser_ctx) + + # reset encrypted fields + assert obj['result']['secret'] != 'mypii' + # remove union wrapper + obj['result'] = { + 'code': 456, + 'secret': 'mypii' + } + + deser = AvroDeserializer(client, rule_conf=rule_conf) + executor.executor.client = dek_client + obj2 = deser(obj_bytes, ser_ctx) + assert obj == obj2 + + def test_avro_encryption_cel(): executor = FieldEncryptionExecutor.register_with_clock(FakeClock())