DGS-8708 Add rule config to preserve source fields #2783

Merged: 5 commits, Oct 7, 2023. Changes shown from 2 commits.
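Summary: this PR adds an encrypt.preserve.source parameter to the field-encryption rule executor. When the flag is true, the executor deep-copies the message before applying the field transform, so encrypting fields in place does not clobber the caller's original object. The flag can be supplied either as executor-level config (read in configure) or as a per-rule parameter (read in init); an executor-level value takes precedence. A minimal rule configuration, mirroring the tests in this diff (imports as in the test classes):

// Field-encryption rule that encrypts "PII"-tagged fields while
// preserving the source object; arguments follow the test code below.
Rule rule = new Rule("rule1", null, null, null,
    FieldEncryptionExecutor.TYPE, ImmutableSortedSet.of("PII"),
    ImmutableMap.of("encrypt.preserve.source", "true"), null, null, null, false);
RuleSet ruleSet = new RuleSet(Collections.emptyList(), ImmutableList.of(rule));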
@@ -66,6 +66,7 @@ public class FieldEncryptionExecutor implements FieldRuleExecutor {
public static final String ENCRYPT_KMS_KEY_ID = "encrypt.kms.key.id";
public static final String ENCRYPT_KMS_TYPE = "encrypt.kms.type";
public static final String ENCRYPT_DEK_ALGORITHM = "encrypt.dek.algorithm";
public static final String ENCRYPT_PRESERVE_SOURCE = "encrypt.preserve.source";

public static final String KMS_TYPE_SUFFIX = "://";
public static final byte[] EMPTY_AAD = new byte[0];
@@ -74,13 +75,18 @@ public class FieldEncryptionExecutor implements FieldRuleExecutor {

private Map<DekFormat, Cryptor> cryptors;
private Map<String, ?> configs;
private Boolean preserveSource;
private int cacheExpirySecs = -1;
private int cacheSize = 10000;
private DekRegistryClient client;

public FieldEncryptionExecutor() {
}

public boolean isPreserveSource() {
return Boolean.TRUE.equals(preserveSource);
}

@Override
public boolean addOriginalConfigs() {
return true;
@@ -89,6 +95,10 @@ public boolean addOriginalConfigs() {
@Override
public void configure(Map<String, ?> configs) {
this.configs = configs;
Object preserveSourceConfig = configs.get(ENCRYPT_PRESERVE_SOURCE);
if (preserveSourceConfig != null) {
this.preserveSource = Boolean.parseBoolean(preserveSourceConfig.toString());
}
Object cacheExpirySecsConfig = configs.get(CACHE_EXPIRY_SECS);
if (cacheExpirySecsConfig != null) {
try {
@@ -127,6 +137,20 @@ public FieldTransform newTransform(RuleContext ctx) throws RuleException {
return transform;
}

@Override
public Object preTransformMessage(RuleContext ctx, FieldTransform transform, Object message)
throws RuleException {
if (isPreserveSource()) {
try {
// Copy the message using the target schema, so the source message is not mutated
message = ctx.target().copyMessage(message);
} catch (IOException e) {
throw new RuleException("Could not copy source message", e);
}
}
return message;
}

private Cryptor getCryptor(RuleContext ctx) {
String algorithm = ctx.getParameter(ENCRYPT_DEK_ALGORITHM);
DekFormat dekFormat = algorithm != null
@@ -201,6 +225,12 @@ public void init(RuleContext ctx) throws RuleException {
cryptor = getCryptor(ctx);
kekName = getKekName(ctx);
kek = getKek(ctx, kekName);
if (FieldEncryptionExecutor.this.preserveSource == null) {
String preserveValueConfig = ctx.getParameter(ENCRYPT_PRESERVE_SOURCE);
if (preserveValueConfig != null) {
FieldEncryptionExecutor.this.preserveSource = Boolean.parseBoolean(preserveValueConfig);
}
}
}

protected String getKekName(RuleContext ctx) throws RuleException {
@@ -265,7 +265,7 @@ private Schema createUserSchema() {
return schema;
}

private IndexedRecord createUserRecord() {
private GenericRecord createUserRecord() {
Schema schema = createUserSchema();
GenericRecord avroRecord = new GenericData.Record(schema);
avroRecord.put("name", "testUser");
@@ -358,6 +358,31 @@ public void testKafkaAvroSerializer() throws Exception {
assertEquals("testUser", record.get("name"));
}

@Test
public void testKafkaAvroSerializerPreserveSource() throws Exception {
GenericRecord avroRecord = createUserRecord();
AvroSchema avroSchema = new AvroSchema(createUserSchema());
Rule rule = new Rule("rule1", null, null, null,
FieldEncryptionExecutor.TYPE, ImmutableSortedSet.of("PII"),
ImmutableMap.of("encrypt.preserve.source", "true"), null, null, null, false);
RuleSet ruleSet = new RuleSet(Collections.emptyList(), ImmutableList.of(rule));
Metadata metadata = getMetadata("kek1");
avroSchema = avroSchema.copy(metadata, ruleSet);
schemaRegistry.register(topic + "-value", avroSchema);

int expectedEncryptions = 1;
RecordHeaders headers = new RecordHeaders();
Cryptor cryptor = addSpyToCryptor(avroSerializer);
byte[] bytes = avroSerializer.serialize(topic, headers, avroRecord);
verify(cryptor, times(expectedEncryptions)).encrypt(any(), any(), any());
cryptor = addSpyToCryptor(avroDeserializer);
GenericRecord record = (GenericRecord) avroDeserializer.deserialize(topic, headers, bytes);
verify(cryptor, times(expectedEncryptions)).decrypt(any(), any(), any());
assertEquals("testUser", record.get("name"));
// Old value is preserved
assertEquals("testUser", avroRecord.get("name"));
}

@Test
public void testKafkaAvroSerializer2() throws Exception {
IndexedRecord avroRecord = createUserRecord();
@@ -455,6 +480,40 @@ public void testKafkaAvroSerializerReflection() throws Exception {
assertEquals("678", ((OldWidget)obj).getPiiMap().get("key2").getPii());
}

@Test
public void testKafkaAvroSerializerReflectionPreserveSource() throws Exception {
OldWidget widget = new OldWidget("alice");
widget.setSsn(ImmutableList.of("123", "456"));
widget.setPiiArray(ImmutableList.of(new OldPii("789"), new OldPii("012")));
widget.setPiiMap(ImmutableMap.of("key1", new OldPii("345"), "key2", new OldPii("678")));
Schema schema = createWidgetSchema();
AvroSchema avroSchema = new AvroSchema(schema);
Rule rule = new Rule("rule1", null, null, null,
FieldEncryptionExecutor.TYPE, ImmutableSortedSet.of("PII"),
ImmutableMap.of("encrypt.preserve.source", "true"), null, null, null, false);
RuleSet ruleSet = new RuleSet(Collections.emptyList(), ImmutableList.of(rule));
Metadata metadata = getMetadata("kek1");
avroSchema = avroSchema.copy(metadata, ruleSet);
schemaRegistry.register(topic + "-value", avroSchema);

int expectedEncryptions = 7;
RecordHeaders headers = new RecordHeaders();
Cryptor cryptor = addSpyToCryptor(reflectionAvroSerializer);
byte[] bytes = reflectionAvroSerializer.serialize(topic, headers, widget);
verify(cryptor, times(expectedEncryptions)).encrypt(any(), any(), any());
cryptor = addSpyToCryptor(reflectionAvroDeserializer);
Object obj = reflectionAvroDeserializer.deserialize(topic, headers, bytes);
verify(cryptor, times(expectedEncryptions)).decrypt(any(), any(), any());

assertTrue(
"Returned object should be a Widget",
OldWidget.class.isInstance(obj)
);
assertEquals("alice", ((OldWidget)obj).getName());
// Old value is preserved
assertEquals("alice", widget.getName());
}

@Test
public void testKafkaAvroSerializerMultipleRules() throws Exception {
IndexedRecord avroRecord = createUserRecord();
@@ -809,6 +868,48 @@ public void testKafkaJsonSchemaSerializer() throws Exception {
);
}

@Test
public void testKafkaJsonSchemaSerializerPreserveSource() throws Exception {
OldWidget widget = new OldWidget("alice");
widget.setSize(123);
widget.setSsn(ImmutableList.of("123", "456"));
widget.setPiiArray(ImmutableList.of(new OldPii("789"), new OldPii("012")));
String schemaStr = "{\"$schema\":\"http://json-schema.org/draft-07/schema#\",\"title\":\"Old Widget\",\"type\":\"object\",\"additionalProperties\":false,\"properties\":{\n"
+ "\"name\":{\"oneOf\":[{\"type\":\"null\",\"title\":\"Not included\"},{\"type\":\"string\"}],"
+ "\"confluent:tags\": [ \"PII\" ]},"
+ "\"ssn\":{\"oneOf\":[{\"type\":\"null\",\"title\":\"Not included\"},{\"type\":\"array\",\"items\":{\"type\":\"string\"}}],"
+ "\"confluent:tags\": [ \"PII\" ]},"
+ "\"piiArray\":{\"oneOf\":[{\"type\":\"null\",\"title\":\"Not included\"},{\"type\":\"array\",\"items\":{\"$ref\":\"#/definitions/OldPii\"}}]},"
+ "\"piiMap\":{\"oneOf\":[{\"type\":\"null\",\"title\":\"Not included\"},{\"type\":\"object\",\"additionalProperties\":{\"$ref\":\"#/definitions/OldPii\"}}]},"
+ "\"size\":{\"type\":\"integer\"},"
+ "\"version\":{\"type\":\"integer\"}},"
+ "\"required\":[\"size\",\"version\"],"
+ "\"definitions\":{\"OldPii\":{\"type\":\"object\",\"additionalProperties\":false,\"properties\":{"
+ "\"pii\":{\"oneOf\":[{\"type\":\"null\",\"title\":\"Not included\"},{\"type\":\"string\"}],"
+ "\"confluent:tags\": [ \"PII\" ]}}}}}";
JsonSchema jsonSchema = new JsonSchema(schemaStr);
Rule rule = new Rule("rule1", null, null, null,
FieldEncryptionExecutor.TYPE, ImmutableSortedSet.of("PII"),
ImmutableMap.of("encrypt.preserve.source", "true"), null, null, null, false);
RuleSet ruleSet = new RuleSet(Collections.emptyList(), Collections.singletonList(rule));
Metadata metadata = getMetadata("kek1");
jsonSchema = jsonSchema.copy(metadata, ruleSet);
schemaRegistry.register(topic + "-value", jsonSchema);

int expectedEncryptions = 5;
RecordHeaders headers = new RecordHeaders();
Cryptor cryptor = addSpyToCryptor(jsonSchemaSerializer);
byte[] bytes = jsonSchemaSerializer.serialize(topic, headers, widget);
verify(cryptor, times(expectedEncryptions)).encrypt(any(), any(), any());
cryptor = addSpyToCryptor(jsonSchemaDeserializer);
Object obj = jsonSchemaDeserializer.deserialize(topic, headers, bytes);
verify(cryptor, times(expectedEncryptions)).decrypt(any(), any(), any());

assertEquals("alice", ((JsonNode)obj).get("name").textValue());
// Old value is preserved
assertEquals("alice", widget.getName());
}

@Test
public void testKafkaJsonSchemaSerializerAnnotated() throws Exception {
AnnotatedOldWidget widget = new AnnotatedOldWidget("alice");
@@ -235,6 +235,10 @@ default JsonNode toJson(Object object) throws IOException {
throw new UnsupportedOperationException();
}

default Object copyMessage(Object message) throws IOException {
throw new UnsupportedOperationException();
}

default Object transformMessage(RuleContext ctx, FieldTransform transform, Object message)
throws RuleException {
throw new UnsupportedOperationException();
@@ -509,6 +509,12 @@ public JsonNode toJson(Object message) throws IOException {
return JacksonMapper.INSTANCE.readTree(AvroSchemaUtils.toJson(message));
}

@Override
public Object copyMessage(Object message) throws IOException {
GenericData data = getData(message);
return data.deepCopy(rawSchema(), message);
}

@Override
public Object transformMessage(RuleContext ctx, FieldTransform transform, Object message)
throws RuleException {
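For reference, the GenericData.deepCopy call used in copyMessage above produces a structurally independent record, so mutating the copy leaves the source untouched. A standalone sketch of that behavior (the schema and values here are illustrative, not from this PR):

import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericRecord;

Schema schema = new Schema.Parser().parse(
    "{\"type\":\"record\",\"name\":\"User\",\"fields\":"
        + "[{\"name\":\"name\",\"type\":\"string\"}]}");
GenericRecord original = new GenericData.Record(schema);
original.put("name", "testUser");
// deepCopy walks the schema and clones every field value
GenericRecord copy = GenericData.get().deepCopy(schema, original);
copy.put("name", "ciphertext");
// original.get("name") still yields "testUser"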
@@ -45,6 +45,8 @@ public class DlqAction implements RuleAction {
public static final String TYPE = "DLQ";

public static final String DLQ_TOPIC = "dlq.topic";
public static final String DLQ_AUTO_FLUSH = "dlq.auto.flush";
public static final String PRODUCER = "producer"; // for testing

public static final String HEADER_PREFIX = "__rule.";
public static final String RULE_NAME = HEADER_PREFIX + "name";
@@ -53,9 +55,6 @@ public class DlqAction implements RuleAction {
public static final String RULE_TOPIC = HEADER_PREFIX + "topic";
public static final String RULE_EXCEPTION = HEADER_PREFIX + "exception";

public static final String DLQ_AUTO_FLUSH = "dlq.auto.flush";
public static final String PRODUCER = "producer"; // for testing

private static final LongSerializer LONG_SERIALIZER = new LongSerializer();
private static final IntegerSerializer INT_SERIALIZER = new IntegerSerializer();
private static final ShortSerializer SHORT_SERIALIZER = new ShortSerializer();
@@ -63,13 +63,19 @@ default Object transform(RuleContext ctx, Object message) throws RuleException {

try (FieldTransform transform = newTransform(ctx)) {
if (transform != null) {
message = preTransformMessage(ctx, transform, message);
return ctx.target().transformMessage(ctx, transform, message);
} else {
return message;
}
}
}

default Object preTransformMessage(RuleContext ctx, FieldTransform transform, Object message)
throws RuleException {
return message;
}

static boolean areTransformsWithSameTags(Rule rule1, Rule rule2) {
return rule1.getTags().size() > 0
&& rule1.getKind() == RuleKind.TRANSFORM
@@ -45,7 +45,6 @@
import io.confluent.kafka.schemaregistry.client.rest.entities.Metadata;
import io.confluent.kafka.schemaregistry.client.rest.entities.RuleSet;
import io.confluent.kafka.schemaregistry.utils.BoundedConcurrentHashMap;
import io.confluent.kafka.schemaregistry.utils.JacksonMapper;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.HashMap;
@@ -511,7 +510,15 @@ public JsonNode toJson(Object message) throws IOException {
if (message instanceof JsonNode) {
return (JsonNode) message;
}
return JacksonMapper.INSTANCE.readTree(JsonSchemaUtils.toJson(message));
return objectMapper.readTree(JsonSchemaUtils.toJson(message));
}

@Override
public Object copyMessage(Object message) throws IOException {
if (message instanceof JsonNode) {
return ((JsonNode) message).deepCopy();
}
return toJson(message);
}

@Override
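Similarly, Jackson's JsonNode.deepCopy used above returns an independent tree, while non-JsonNode messages are converted via toJson, which also yields a fresh tree. A standalone sketch (values are illustrative, not from this PR):

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;

ObjectMapper mapper = new ObjectMapper();
ObjectNode original = mapper.createObjectNode().put("name", "alice");
ObjectNode copy = original.deepCopy();
copy.put("name", "ciphertext");
// original.get("name").asText() still yields "alice"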
@@ -2274,6 +2274,12 @@ public JsonNode toJson(Object message) throws IOException {
return JacksonMapper.INSTANCE.readTree(ProtobufSchemaUtils.toJson((Message) message));
}

@Override
public Object copyMessage(Object message) throws IOException {
// Protobuf messages are immutable, so no defensive copy is needed
return message;
}

@Override
public Object transformMessage(RuleContext ctx, FieldTransform transform, Object message)
throws RuleException {