Skip to content

Commit

Permalink
Add tests for DEK generation (#2798)
Browse files Browse the repository at this point in the history
* Add tests for generating bad deks

* Minor cleanup

* Add good dek gen test

* Minor cleanup

* Clean up imports

* Minor cleanup
  • Loading branch information
rayokota committed Oct 18, 2023
1 parent ba924aa commit 498018e
Show file tree
Hide file tree
Showing 21 changed files with 220 additions and 54 deletions.
Expand Up @@ -27,7 +27,8 @@ public AwsFieldEncryptionExecutorTest() throws Exception {
}

@Override
protected FieldEncryptionProperties getFieldEncryptionProperties(List<String> ruleNames) {
return new AwsFieldEncryptionProperties(ruleNames);
protected FieldEncryptionProperties getFieldEncryptionProperties(
List<String> ruleNames, Class<?> ruleExecutor) {
return new AwsFieldEncryptionProperties(ruleNames, ruleExecutor);
}
}
Expand Up @@ -30,6 +30,10 @@ public AwsFieldEncryptionProperties(List<String> ruleNames) {
super(ruleNames);
}

public AwsFieldEncryptionProperties(List<String> ruleNames, Class<?> ruleExecutor) {
super(ruleNames, ruleExecutor);
}

@Override
public String getKmsType() {
return "aws-kms";
Expand All @@ -52,7 +56,7 @@ public Map<String, Object> getClientProperties(String baseUrls)
props.put(AbstractKafkaSchemaSerDeConfig.RULE_EXECUTORS, String.join(",", ruleNames));
for (String ruleName : ruleNames) {
props.put(AbstractKafkaSchemaSerDeConfig.RULE_EXECUTORS + "." + ruleName + ".class",
FieldEncryptionExecutor.class.getName());
getRuleExecutor().getName());
props.put(AbstractKafkaSchemaSerDeConfig.RULE_EXECUTORS + "." + ruleName
+ ".param." + TEST_CLIENT,
getTestClient());
Expand Down
Expand Up @@ -19,18 +19,19 @@
import static io.confluent.kafka.schemaregistry.rules.RuleBase.DEFAULT_NAME;

import com.google.common.collect.ImmutableList;
import io.confluent.kafka.schemaregistry.encryption.FieldEncryptionExecutorTest;
import io.confluent.kafka.schemaregistry.encryption.FieldEncryptionProperties;
import io.confluent.kafka.schemaregistry.encryption.FieldEncryptionServiceLoaderTest;
import java.util.List;

public class AwsFieldEncryptionServiceLoaderTest extends FieldEncryptionExecutorTest {
public class AwsFieldEncryptionServiceLoaderTest extends FieldEncryptionServiceLoaderTest {

public AwsFieldEncryptionServiceLoaderTest() throws Exception {
super();
}

@Override
protected FieldEncryptionProperties getFieldEncryptionProperties(List<String> ruleNames) {
return new AwsFieldEncryptionProperties(ImmutableList.of(DEFAULT_NAME));
protected FieldEncryptionProperties getFieldEncryptionProperties(
List<String> ruleNames, Class<?> ruleExecutor) {
return new AwsFieldEncryptionProperties(ImmutableList.of(DEFAULT_NAME), ruleExecutor);
}
}
Expand Up @@ -27,8 +27,9 @@ public AzureFieldEncryptionExecutorTest() throws Exception {
}

@Override
protected FieldEncryptionProperties getFieldEncryptionProperties(List<String> ruleNames) {
return new AzureFieldEncryptionProperties(ruleNames);
protected FieldEncryptionProperties getFieldEncryptionProperties(
List<String> ruleNames, Class<?> ruleExecutor) {
return new AzureFieldEncryptionProperties(ruleNames, ruleExecutor);
}
}

Expand Up @@ -40,6 +40,10 @@ public AzureFieldEncryptionProperties(List<String> ruleNames) {
super(ruleNames);
}

public AzureFieldEncryptionProperties(List<String> ruleNames, Class<?> ruleExecutor) {
super(ruleNames, ruleExecutor);
}

@Override
public String getKmsType() {
return "azure-kms";
Expand All @@ -62,7 +66,7 @@ public Map<String, Object> getClientProperties(String baseUrls)
props.put(AbstractKafkaSchemaSerDeConfig.RULE_EXECUTORS, String.join(",", ruleNames));
for (String ruleName : ruleNames) {
props.put(AbstractKafkaSchemaSerDeConfig.RULE_EXECUTORS + "." + ruleName + ".class",
FieldEncryptionExecutor.class.getName());
getRuleExecutor().getName());
props.put(AbstractKafkaSchemaSerDeConfig.RULE_EXECUTORS + "." + ruleName
+ ".param." + TEST_CLIENT,
getTestClient());
Expand Down
Expand Up @@ -19,18 +19,19 @@
import static io.confluent.kafka.schemaregistry.rules.RuleBase.DEFAULT_NAME;

import com.google.common.collect.ImmutableList;
import io.confluent.kafka.schemaregistry.encryption.FieldEncryptionExecutorTest;
import io.confluent.kafka.schemaregistry.encryption.FieldEncryptionProperties;
import io.confluent.kafka.schemaregistry.encryption.FieldEncryptionServiceLoaderTest;
import java.util.List;

public class AzureFieldEncryptionServiceLoaderTest extends FieldEncryptionExecutorTest {
public class AzureFieldEncryptionServiceLoaderTest extends FieldEncryptionServiceLoaderTest {

public AzureFieldEncryptionServiceLoaderTest() throws Exception {
super();
}

@Override
protected FieldEncryptionProperties getFieldEncryptionProperties(List<String> ruleNames) {
return new AzureFieldEncryptionProperties(ImmutableList.of(DEFAULT_NAME));
protected FieldEncryptionProperties getFieldEncryptionProperties(
List<String> ruleNames, Class<?> ruleExecutor) {
return new AzureFieldEncryptionProperties(ImmutableList.of(DEFAULT_NAME), ruleExecutor);
}
}
Expand Up @@ -27,8 +27,9 @@ public GcpFieldEncryptionExecutorTest() throws Exception {
}

@Override
protected FieldEncryptionProperties getFieldEncryptionProperties(List<String> ruleNames) {
return new GcpFieldEncryptionProperties(ruleNames);
protected FieldEncryptionProperties getFieldEncryptionProperties(
List<String> ruleNames, Class<?> ruleExecutor) {
return new GcpFieldEncryptionProperties(ruleNames, ruleExecutor);
}
}

Expand Up @@ -20,7 +20,6 @@
import static io.confluent.kafka.schemaregistry.encryption.gcp.GcpKmsDriver.PRIVATE_KEY_ID;
import static io.confluent.kafka.schemaregistry.encryption.tink.KmsDriver.TEST_CLIENT;

import io.confluent.kafka.schemaregistry.encryption.FieldEncryptionExecutor;
import io.confluent.kafka.schemaregistry.encryption.FieldEncryptionProperties;
import io.confluent.kafka.serializers.AbstractKafkaSchemaSerDeConfig;
import java.util.Collections;
Expand All @@ -34,6 +33,10 @@ public GcpFieldEncryptionProperties(List<String> ruleNames) {
super(ruleNames);
}

public GcpFieldEncryptionProperties(List<String> ruleNames, Class<?> ruleExecutor) {
super(ruleNames, ruleExecutor);
}

@Override
public String getKmsType() {
return "gcp-kms";
Expand Down Expand Up @@ -61,7 +64,7 @@ public Map<String, Object> getClientProperties(String baseUrls)
props.put(AbstractKafkaSchemaSerDeConfig.RULE_EXECUTORS, String.join(",", ruleNames));
for (String ruleName : ruleNames) {
props.put(AbstractKafkaSchemaSerDeConfig.RULE_EXECUTORS + "." + ruleName + ".class",
FieldEncryptionExecutor.class.getName());
getRuleExecutor().getName());
props.put(AbstractKafkaSchemaSerDeConfig.RULE_EXECUTORS + "." + ruleName
+ ".param." + CLIENT_ID,
clientId);
Expand Down
Expand Up @@ -19,18 +19,19 @@
import static io.confluent.kafka.schemaregistry.rules.RuleBase.DEFAULT_NAME;

import com.google.common.collect.ImmutableList;
import io.confluent.kafka.schemaregistry.encryption.FieldEncryptionExecutorTest;
import io.confluent.kafka.schemaregistry.encryption.FieldEncryptionProperties;
import io.confluent.kafka.schemaregistry.encryption.FieldEncryptionServiceLoaderTest;
import java.util.List;

public class GcpFieldEncryptionServiceLoaderTest extends FieldEncryptionExecutorTest {
public class GcpFieldEncryptionServiceLoaderTest extends FieldEncryptionServiceLoaderTest {

public GcpFieldEncryptionServiceLoaderTest() throws Exception {
super();
}

@Override
protected FieldEncryptionProperties getFieldEncryptionProperties(List<String> ruleNames) {
return new GcpFieldEncryptionProperties(ImmutableList.of(DEFAULT_NAME));
protected FieldEncryptionProperties getFieldEncryptionProperties(
List<String> ruleNames, Class<?> ruleExecutor) {
return new GcpFieldEncryptionProperties(ImmutableList.of(DEFAULT_NAME), ruleExecutor);
}
}
Expand Up @@ -27,8 +27,9 @@ public HcVaultFieldEncryptionExecutorTest() throws Exception {
}

@Override
protected FieldEncryptionProperties getFieldEncryptionProperties(List<String> ruleNames) {
return new HcVaultFieldEncryptionProperties(ruleNames);
protected FieldEncryptionProperties getFieldEncryptionProperties(
List<String> ruleNames, Class<?> ruleExecutor) {
return new HcVaultFieldEncryptionProperties(ruleNames, ruleExecutor);
}
}

Expand Up @@ -40,6 +40,10 @@ public HcVaultFieldEncryptionProperties(List<String> ruleNames) {
super(ruleNames);
}

public HcVaultFieldEncryptionProperties(List<String> ruleNames, Class<?> ruleExecutor) {
super(ruleNames, ruleExecutor);
}

@Override
public String getKmsType() {
return "hcvault";
Expand All @@ -61,7 +65,7 @@ public Map<String, Object> getClientProperties(String baseUrls) throws Exception
props.put(AbstractKafkaSchemaSerDeConfig.RULE_EXECUTORS, String.join(",", ruleNames));
for (String ruleName : ruleNames) {
props.put(AbstractKafkaSchemaSerDeConfig.RULE_EXECUTORS + "." + ruleName + ".class",
FieldEncryptionExecutor.class.getName());
getRuleExecutor().getName());
props.put(AbstractKafkaSchemaSerDeConfig.RULE_EXECUTORS + "." + ruleName
+ ".param." + TOKEN_ID,
"dev-only-token");
Expand Down
Expand Up @@ -19,18 +19,19 @@
import static io.confluent.kafka.schemaregistry.rules.RuleBase.DEFAULT_NAME;

import com.google.common.collect.ImmutableList;
import io.confluent.kafka.schemaregistry.encryption.FieldEncryptionExecutorTest;
import io.confluent.kafka.schemaregistry.encryption.FieldEncryptionProperties;
import io.confluent.kafka.schemaregistry.encryption.FieldEncryptionServiceLoaderTest;
import java.util.List;

public class HcVaultFieldEncryptionServiceLoaderTest extends FieldEncryptionExecutorTest {
public class HcVaultFieldEncryptionServiceLoaderTest extends FieldEncryptionServiceLoaderTest {

public HcVaultFieldEncryptionServiceLoaderTest() throws Exception {
super();
}

@Override
protected FieldEncryptionProperties getFieldEncryptionProperties(List<String> ruleNames) {
return new HcVaultFieldEncryptionProperties(ImmutableList.of(DEFAULT_NAME));
protected FieldEncryptionProperties getFieldEncryptionProperties(
List<String> ruleNames, Class<?> ruleExecutor) {
return new HcVaultFieldEncryptionProperties(ImmutableList.of(DEFAULT_NAME), ruleExecutor);
}
}
Expand Up @@ -19,6 +19,8 @@
import com.google.common.base.Ticker;
import com.google.crypto.tink.Aead;
import com.google.crypto.tink.KmsClient;
import com.google.crypto.tink.proto.AesGcmKey;
import com.google.crypto.tink.proto.AesSivKey;
import com.google.protobuf.ByteString;
import io.confluent.dekregistry.client.CachedDekRegistryClient.DekId;
import io.confluent.dekregistry.client.CachedDekRegistryClient.KekId;
Expand Down Expand Up @@ -165,8 +167,32 @@ public Map<DekFormat, Cryptor> getCryptors() {
return cryptors;
}

protected byte[] generateKey(DekFormat dekFormat) throws GeneralSecurityException {
return getCryptor(dekFormat).generateKey();
private byte[] generateKey(DekFormat dekFormat) throws GeneralSecurityException {
byte[] dek = generateDek(dekFormat);
if (dek != null) {
switch (dekFormat) {
case AES128_GCM:
case AES256_GCM:
return AesGcmKey.newBuilder()
.setKeyValue(ByteString.copyFrom(dek))
.build()
.toByteArray();
case AES256_SIV:
return AesSivKey.newBuilder()
.setKeyValue(ByteString.copyFrom(dek))
.build()
.toByteArray();
default:
throw new IllegalArgumentException("Invalid format " + dekFormat);
}
} else {
return getCryptor(dekFormat).generateKey();
}
}

// Can be overridden to generate a custom dek
protected byte[] generateDek(DekFormat dekFormat) throws GeneralSecurityException {
return null;
}

private static byte[] toBytes(Type type, Object obj) {
Expand Down Expand Up @@ -457,7 +483,6 @@ public Object transform(RuleContext ctx, FieldContext fieldCtx, Object fieldValu
DekInfo dek;
byte[] plaintext;
byte[] ciphertext;
Object result;
switch (ctx.ruleMode()) {
case WRITE:
plaintext = toBytes(fieldCtx.getType(), fieldValue);
Expand Down

0 comments on commit 498018e

Please sign in to comment.