diff --git a/client-encryption/src/main/java/io/confluent/kafka/schemaregistry/encryption/FieldEncryptionExecutor.java b/client-encryption/src/main/java/io/confluent/kafka/schemaregistry/encryption/FieldEncryptionExecutor.java index 7e97fdcef7f..41125b0cd58 100644 --- a/client-encryption/src/main/java/io/confluent/kafka/schemaregistry/encryption/FieldEncryptionExecutor.java +++ b/client-encryption/src/main/java/io/confluent/kafka/schemaregistry/encryption/FieldEncryptionExecutor.java @@ -50,7 +50,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import org.apache.kafka.common.config.ConfigException; @@ -244,7 +243,7 @@ public void close() throws RuleException { public class FieldEncryptionExecutorTransform implements FieldTransform { private Cryptor cryptor; private String kekName; - private KekInfo kek; + private Kek kek; private int dekExpiryDays; public void init(RuleContext ctx) throws RuleException { @@ -280,14 +279,14 @@ protected String getKekName(RuleContext ctx) throws RuleException { return name; } - protected KekInfo getOrCreateKek(RuleContext ctx) throws RuleException { + protected Kek getOrCreateKek(RuleContext ctx) throws RuleException { boolean isRead = ctx.ruleMode() == RuleMode.READ; KekId kekId = new KekId(kekName, isRead); String kmsType = ctx.getParameter(ENCRYPT_KMS_TYPE); String kmsKeyId = ctx.getParameter(ENCRYPT_KMS_KEY_ID); - KekInfo kek = retrieveKekFromRegistry(ctx, kekId); + Kek kek = retrieveKekFromRegistry(ctx, kekId); if (kek == null) { if (isRead) { throw new RuleException("No kek found for " + kekName + " during consume"); @@ -298,8 +297,7 @@ protected KekInfo getOrCreateKek(RuleContext ctx) throws RuleException { if (kmsKeyId == null) { throw new RuleException("No kms key id found for " + kekName + " during produce"); } - kek = new KekInfo(kmsType, kmsKeyId, false); - kek = storeKekToRegistry(ctx, kekId, kek); + kek = storeKekToRegistry(ctx, kekId, kmsType, kmsKeyId, false); if (kek == null) { // Handle conflicts (409) kek = retrieveKekFromRegistry(ctx, kekId); @@ -336,13 +334,13 @@ private int getDekExpiryDays(RuleContext ctx) throws RuleException { return dekExpiryDays; } - private KekInfo retrieveKekFromRegistry(RuleContext ctx, KekId key) throws RuleException { + private Kek retrieveKekFromRegistry(RuleContext ctx, KekId key) throws RuleException { try { Kek kek = client.getKek(key.getName(), key.isLookupDeleted()); if (kek == null) { return null; } - return new KekInfo(kek.getKmsType(), kek.getKmsKeyId(), kek.isShared()); + return kek; } catch (RestClientException e) { if (e.getStatus() == 404) { return null; @@ -353,14 +351,15 @@ private KekInfo retrieveKekFromRegistry(RuleContext ctx, KekId key) throws RuleE } } - private KekInfo storeKekToRegistry(RuleContext ctx, KekId key, KekInfo kekInfo) + private Kek storeKekToRegistry( + RuleContext ctx, KekId key, String kmsType, String kmsKeyId, boolean shared) throws RuleException { try { Kek kek = client.createKek( - key.getName(), kekInfo.getKmsType(), kekInfo.getKmsKeyId(), - null, null, kekInfo.isShared()); + key.getName(), kmsType, kmsKeyId, + null, null, shared); log.info("Registered kek " + key.getName()); - return new KekInfo(kek.getKmsType(), kek.getKmsKeyId(), kek.isShared()); + return kek; } catch (RestClientException e) { if (e.getStatus() == 409) { return null; @@ -371,13 +370,13 @@ private KekInfo storeKekToRegistry(RuleContext ctx, KekId key, KekInfo kekInfo) } } - public DekInfo getOrCreateDek(RuleContext ctx, Integer version) + public Dek getOrCreateDek(RuleContext ctx, Integer version) throws RuleException, GeneralSecurityException { boolean isRead = ctx.ruleMode() == RuleMode.READ; DekId dekId = new DekId(kekName, ctx.subject(), version, cryptor.getDekFormat(), isRead); Aead aead = null; - DekInfo dek = retrieveDekFromRegistry(dekId); + Dek dek = retrieveDekFromRegistry(dekId); boolean isExpired = isExpired(ctx, dek); if (isExpired) { log.info("Dek with ts " + dek.getTimestamp() @@ -408,24 +407,24 @@ public DekInfo getOrCreateDek(RuleContext ctx, Integer version) throw new RuleException("No dek found for " + kekName + " during produce"); } } - if (dek.getRawDek() == null) { + if (dek.getKeyMaterialBytes() == null) { if (aead == null) { aead = getAead(configs, kek); } - byte[] rawDek = aead.decrypt(dek.getEncryptedDek(), EMPTY_AAD); - dek.setRawDek(rawDek); + byte[] rawDek = aead.decrypt(dek.getEncryptedKeyMaterialBytes(), EMPTY_AAD); + dek.setKeyMaterial(rawDek); } return dek; } - private boolean isExpired(RuleContext ctx, DekInfo dek) { + private boolean isExpired(RuleContext ctx, Dek dek) { return ctx.ruleMode() != RuleMode.READ && dekExpiryDays > 0 && dek != null && (clock.millis() - dek.getTimestamp()) / MILLIS_IN_DAY >= dekExpiryDays; } - private DekInfo retrieveDekFromRegistry(DekId key) + private Dek retrieveDekFromRegistry(DekId key) throws RuleException { try { Dek dek; @@ -440,15 +439,7 @@ private DekInfo retrieveDekFromRegistry(DekId key) if (dek == null) { return null; } - byte[] rawDek = dek.getKeyMaterial() != null - ? Base64.getDecoder().decode(toBytes(Type.STRING, dek.getKeyMaterial())) - : null; - byte[] encryptedDek = dek.getEncryptedKeyMaterial() != null - ? Base64.getDecoder().decode(toBytes(Type.STRING, dek.getEncryptedKeyMaterial())) - : null; - return encryptedDek != null - ? new DekInfo(dek.getVersion(), rawDek, encryptedDek, dek.getTimestamp()) - : null; + return dek.getEncryptedKeyMaterial() != null ? dek : null; } catch (RestClientException e) { if (e.getStatus() == 404) { return null; @@ -461,7 +452,7 @@ private DekInfo retrieveDekFromRegistry(DekId key) } } - private DekInfo storeDekToRegistry(DekId key, byte[] encryptedDek) + private Dek storeDekToRegistry(DekId key, byte[] encryptedDek) throws RuleException { try { String encryptedDekStr = encryptedDek != null @@ -476,14 +467,8 @@ private DekInfo storeDekToRegistry(DekId key, byte[] encryptedDek) dek = client.createDek( key.getKekName(), key.getSubject(), key.getDekFormat(), encryptedDekStr); } - byte[] rawDek = dek.getKeyMaterial() != null - ? Base64.getDecoder().decode(toBytes(Type.STRING, dek.getKeyMaterial())) - : null; - encryptedDek = dek.getEncryptedKeyMaterial() != null - ? Base64.getDecoder().decode(toBytes(Type.STRING, dek.getEncryptedKeyMaterial())) - : null; log.info("Registered dek for kek " + key.getKekName() + ", subject " + key.getSubject()); - return new DekInfo(dek.getVersion(), rawDek, encryptedDek, dek.getTimestamp()); + return dek; } catch (RestClientException e) { if (e.getStatus() == 409) { return null; @@ -502,7 +487,7 @@ public Object transform(RuleContext ctx, FieldContext fieldCtx, Object fieldValu if (fieldValue == null) { return null; } - DekInfo dek; + Dek dek; byte[] plaintext; byte[] ciphertext; switch (ctx.ruleMode()) { @@ -513,7 +498,7 @@ public Object transform(RuleContext ctx, FieldContext fieldCtx, Object fieldValu "Type '" + fieldCtx.getType() + "' not supported for encryption"); } dek = getOrCreateDek(ctx, isDekRotated() ? LATEST_VERSION : null); - ciphertext = cryptor.encrypt(dek.getRawDek(), plaintext, EMPTY_AAD); + ciphertext = cryptor.encrypt(dek.getKeyMaterialBytes(), plaintext, EMPTY_AAD); if (isDekRotated()) { ciphertext = prefixVersion(dek.getVersion(), ciphertext); } @@ -536,7 +521,7 @@ public Object transform(RuleContext ctx, FieldContext fieldCtx, Object fieldValu ciphertext = kv.getValue(); } dek = getOrCreateDek(ctx, version); - plaintext = cryptor.decrypt(dek.getRawDek(), ciphertext, EMPTY_AAD); + plaintext = cryptor.decrypt(dek.getKeyMaterialBytes(), ciphertext, EMPTY_AAD); return toObject(fieldCtx.getType(), plaintext); default: throw new IllegalArgumentException("Unsupported rule mode " + ctx.ruleMode()); @@ -572,7 +557,7 @@ public void close() { } } - private static Aead getAead(Map configs, KekInfo kek) + private static Aead getAead(Map configs, Kek kek) throws GeneralSecurityException, RuleException { String kekUrl = kek.getKmsType() + KMS_TYPE_SUFFIX + kek.getKmsKeyId(); KmsClient kmsClient = getKmsClient(configs, kekUrl); @@ -590,106 +575,5 @@ private static KmsClient getKmsClient(Map configs, String kekUrl) return KmsDriverManager.getDriver(kekUrl).registerKmsClient(configs, Optional.of(kekUrl)); } } - - static class KekInfo { - - private final String kmsType; - private final String kmsKeyId; - private final boolean shared; - - public KekInfo(String kmsType, String kmsKeyId, boolean shared) { - this.kmsType = kmsType; - this.kmsKeyId = kmsKeyId; - this.shared = shared; - } - - public String getKmsType() { - return kmsType; - } - - public String getKmsKeyId() { - return kmsKeyId; - } - - public boolean isShared() { - return shared; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - KekInfo kek = (KekInfo) o; - return shared == kek.shared - && Objects.equals(kmsType, kek.kmsType) - && Objects.equals(kmsKeyId, kek.kmsKeyId); - } - - @Override - public int hashCode() { - return Objects.hash(kmsType, kmsKeyId, shared); - } - } - - static class DekInfo { - - private final Integer version; - private byte[] rawDek; - private final byte[] encryptedDek; - private final Long ts; - - public DekInfo(Integer version, byte[] rawDek, byte[] encryptedDek, Long ts) { - this.version = version; - this.rawDek = rawDek; - this.encryptedDek = encryptedDek; - this.ts = ts; - } - - public Integer getVersion() { - return version; - } - - public byte[] getRawDek() { - return rawDek; - } - - public void setRawDek(byte[] rawDek) { - this.rawDek = rawDek; - } - - public byte[] getEncryptedDek() { - return encryptedDek; - } - - public Long getTimestamp() { - return ts; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - DekInfo dek = (DekInfo) o; - return Objects.equals(version, dek.version) - && Arrays.equals(rawDek, dek.rawDek) - && Arrays.equals(encryptedDek, dek.encryptedDek); - } - - @Override - public int hashCode() { - int result = Objects.hash(version); - result = 31 * result + Arrays.hashCode(rawDek); - result = 31 * result + Arrays.hashCode(encryptedDek); - return result; - } - } } diff --git a/dek-registry-client/src/main/java/io/confluent/dekregistry/client/rest/entities/Dek.java b/dek-registry-client/src/main/java/io/confluent/dekregistry/client/rest/entities/Dek.java index 349878ef5da..082a6d03af6 100644 --- a/dek-registry-client/src/main/java/io/confluent/dekregistry/client/rest/entities/Dek.java +++ b/dek-registry-client/src/main/java/io/confluent/dekregistry/client/rest/entities/Dek.java @@ -25,6 +25,8 @@ import io.confluent.kafka.schemaregistry.encryption.tink.DekFormat; import io.confluent.kafka.schemaregistry.utils.JacksonMapper; import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Base64; import java.util.Objects; @JsonInclude(Include.NON_NULL) @@ -36,10 +38,13 @@ public class Dek { private final int version; private final DekFormat algorithm; private final String encryptedKeyMaterial; - private final String keyMaterial; + private String keyMaterial; private final Long timestamp; private final Boolean deleted; + private byte[] encryptedKeyMaterialBytes; + private byte[] keyMaterialBytes; + @JsonCreator public Dek( @JsonProperty("kekName") String kekName, @@ -91,6 +96,14 @@ public String getKeyMaterial() { return this.keyMaterial; } + @JsonProperty("keyMaterial") + public void setKeyMaterial(byte[] keyMaterialBytes) { + if (keyMaterialBytes != null) { + this.keyMaterial = + new String(Base64.getEncoder().encode(keyMaterialBytes), StandardCharsets.UTF_8); + } + } + @JsonProperty("ts") public Long getTimestamp() { return this.timestamp; @@ -106,6 +119,30 @@ public boolean isDeleted() { return Boolean.TRUE.equals(this.deleted); } + @JsonIgnore + public byte[] getEncryptedKeyMaterialBytes() { + if (encryptedKeyMaterial == null) { + return null; + } + if (encryptedKeyMaterialBytes == null) { + encryptedKeyMaterialBytes = + Base64.getDecoder().decode(encryptedKeyMaterial.getBytes(StandardCharsets.UTF_8)); + } + return encryptedKeyMaterialBytes; + } + + @JsonIgnore + public byte[] getKeyMaterialBytes() { + if (keyMaterial == null) { + return null; + } + if (keyMaterialBytes == null) { + keyMaterialBytes = + Base64.getDecoder().decode(keyMaterial.getBytes(StandardCharsets.UTF_8)); + } + return keyMaterialBytes; + } + @Override public boolean equals(Object o) { if (this == o) {