Skip to content

Commit

Permalink
Optimization: remove extraneous KekInfo/DekInfo classes
Browse files Browse the repository at this point in the history
  • Loading branch information
rayokota committed Feb 22, 2024
1 parent 2316185 commit 2656938
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 142 deletions.
Expand Up @@ -50,7 +50,6 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.kafka.common.config.ConfigException;
Expand Down Expand Up @@ -244,7 +243,7 @@ public void close() throws RuleException {
public class FieldEncryptionExecutorTransform implements FieldTransform {
private Cryptor cryptor;
private String kekName;
private KekInfo kek;
private Kek kek;
private int dekExpiryDays;

public void init(RuleContext ctx) throws RuleException {
Expand Down Expand Up @@ -280,14 +279,14 @@ protected String getKekName(RuleContext ctx) throws RuleException {
return name;
}

protected KekInfo getOrCreateKek(RuleContext ctx) throws RuleException {
protected Kek getOrCreateKek(RuleContext ctx) throws RuleException {
boolean isRead = ctx.ruleMode() == RuleMode.READ;
KekId kekId = new KekId(kekName, isRead);

String kmsType = ctx.getParameter(ENCRYPT_KMS_TYPE);
String kmsKeyId = ctx.getParameter(ENCRYPT_KMS_KEY_ID);

KekInfo kek = retrieveKekFromRegistry(ctx, kekId);
Kek kek = retrieveKekFromRegistry(ctx, kekId);
if (kek == null) {
if (isRead) {
throw new RuleException("No kek found for " + kekName + " during consume");
Expand All @@ -298,8 +297,7 @@ protected KekInfo getOrCreateKek(RuleContext ctx) throws RuleException {
if (kmsKeyId == null) {
throw new RuleException("No kms key id found for " + kekName + " during produce");
}
kek = new KekInfo(kmsType, kmsKeyId, false);
kek = storeKekToRegistry(ctx, kekId, kek);
kek = storeKekToRegistry(ctx, kekId, kmsType, kmsKeyId, false);
if (kek == null) {
// Handle conflicts (409)
kek = retrieveKekFromRegistry(ctx, kekId);
Expand Down Expand Up @@ -336,13 +334,13 @@ private int getDekExpiryDays(RuleContext ctx) throws RuleException {
return dekExpiryDays;
}

private KekInfo retrieveKekFromRegistry(RuleContext ctx, KekId key) throws RuleException {
private Kek retrieveKekFromRegistry(RuleContext ctx, KekId key) throws RuleException {
try {
Kek kek = client.getKek(key.getName(), key.isLookupDeleted());
if (kek == null) {
return null;
}
return new KekInfo(kek.getKmsType(), kek.getKmsKeyId(), kek.isShared());
return kek;
} catch (RestClientException e) {
if (e.getStatus() == 404) {
return null;
Expand All @@ -353,14 +351,15 @@ private KekInfo retrieveKekFromRegistry(RuleContext ctx, KekId key) throws RuleE
}
}

private KekInfo storeKekToRegistry(RuleContext ctx, KekId key, KekInfo kekInfo)
private Kek storeKekToRegistry(
RuleContext ctx, KekId key, String kmsType, String kmsKeyId, boolean shared)
throws RuleException {
try {
Kek kek = client.createKek(
key.getName(), kekInfo.getKmsType(), kekInfo.getKmsKeyId(),
null, null, kekInfo.isShared());
key.getName(), kmsType, kmsKeyId,
null, null, shared);
log.info("Registered kek " + key.getName());
return new KekInfo(kek.getKmsType(), kek.getKmsKeyId(), kek.isShared());
return kek;
} catch (RestClientException e) {
if (e.getStatus() == 409) {
return null;
Expand All @@ -371,13 +370,13 @@ private KekInfo storeKekToRegistry(RuleContext ctx, KekId key, KekInfo kekInfo)
}
}

public DekInfo getOrCreateDek(RuleContext ctx, Integer version)
public Dek getOrCreateDek(RuleContext ctx, Integer version)
throws RuleException, GeneralSecurityException {
boolean isRead = ctx.ruleMode() == RuleMode.READ;
DekId dekId = new DekId(kekName, ctx.subject(), version, cryptor.getDekFormat(), isRead);

Aead aead = null;
DekInfo dek = retrieveDekFromRegistry(dekId);
Dek dek = retrieveDekFromRegistry(dekId);
boolean isExpired = isExpired(ctx, dek);
if (isExpired) {
log.info("Dek with ts " + dek.getTimestamp()
Expand Down Expand Up @@ -408,24 +407,24 @@ public DekInfo getOrCreateDek(RuleContext ctx, Integer version)
throw new RuleException("No dek found for " + kekName + " during produce");
}
}
if (dek.getRawDek() == null) {
if (dek.getKeyMaterialBytes() == null) {
if (aead == null) {
aead = getAead(configs, kek);
}
byte[] rawDek = aead.decrypt(dek.getEncryptedDek(), EMPTY_AAD);
dek.setRawDek(rawDek);
byte[] rawDek = aead.decrypt(dek.getEncryptedKeyMaterialBytes(), EMPTY_AAD);
dek.setKeyMaterial(rawDek);
}
return dek;
}

private boolean isExpired(RuleContext ctx, DekInfo dek) {
private boolean isExpired(RuleContext ctx, Dek dek) {
return ctx.ruleMode() != RuleMode.READ
&& dekExpiryDays > 0
&& dek != null
&& (clock.millis() - dek.getTimestamp()) / MILLIS_IN_DAY >= dekExpiryDays;
}

private DekInfo retrieveDekFromRegistry(DekId key)
private Dek retrieveDekFromRegistry(DekId key)
throws RuleException {
try {
Dek dek;
Expand All @@ -440,15 +439,7 @@ private DekInfo retrieveDekFromRegistry(DekId key)
if (dek == null) {
return null;
}
byte[] rawDek = dek.getKeyMaterial() != null
? Base64.getDecoder().decode(toBytes(Type.STRING, dek.getKeyMaterial()))
: null;
byte[] encryptedDek = dek.getEncryptedKeyMaterial() != null
? Base64.getDecoder().decode(toBytes(Type.STRING, dek.getEncryptedKeyMaterial()))
: null;
return encryptedDek != null
? new DekInfo(dek.getVersion(), rawDek, encryptedDek, dek.getTimestamp())
: null;
return dek.getEncryptedKeyMaterial() != null ? dek : null;
} catch (RestClientException e) {
if (e.getStatus() == 404) {
return null;
Expand All @@ -461,7 +452,7 @@ private DekInfo retrieveDekFromRegistry(DekId key)
}
}

private DekInfo storeDekToRegistry(DekId key, byte[] encryptedDek)
private Dek storeDekToRegistry(DekId key, byte[] encryptedDek)
throws RuleException {
try {
String encryptedDekStr = encryptedDek != null
Expand All @@ -476,14 +467,8 @@ private DekInfo storeDekToRegistry(DekId key, byte[] encryptedDek)
dek = client.createDek(
key.getKekName(), key.getSubject(), key.getDekFormat(), encryptedDekStr);
}
byte[] rawDek = dek.getKeyMaterial() != null
? Base64.getDecoder().decode(toBytes(Type.STRING, dek.getKeyMaterial()))
: null;
encryptedDek = dek.getEncryptedKeyMaterial() != null
? Base64.getDecoder().decode(toBytes(Type.STRING, dek.getEncryptedKeyMaterial()))
: null;
log.info("Registered dek for kek " + key.getKekName() + ", subject " + key.getSubject());
return new DekInfo(dek.getVersion(), rawDek, encryptedDek, dek.getTimestamp());
return dek;
} catch (RestClientException e) {
if (e.getStatus() == 409) {
return null;
Expand All @@ -502,7 +487,7 @@ public Object transform(RuleContext ctx, FieldContext fieldCtx, Object fieldValu
if (fieldValue == null) {
return null;
}
DekInfo dek;
Dek dek;
byte[] plaintext;
byte[] ciphertext;
switch (ctx.ruleMode()) {
Expand All @@ -513,7 +498,7 @@ public Object transform(RuleContext ctx, FieldContext fieldCtx, Object fieldValu
"Type '" + fieldCtx.getType() + "' not supported for encryption");
}
dek = getOrCreateDek(ctx, isDekRotated() ? LATEST_VERSION : null);
ciphertext = cryptor.encrypt(dek.getRawDek(), plaintext, EMPTY_AAD);
ciphertext = cryptor.encrypt(dek.getKeyMaterialBytes(), plaintext, EMPTY_AAD);
if (isDekRotated()) {
ciphertext = prefixVersion(dek.getVersion(), ciphertext);
}
Expand All @@ -536,7 +521,7 @@ public Object transform(RuleContext ctx, FieldContext fieldCtx, Object fieldValu
ciphertext = kv.getValue();
}
dek = getOrCreateDek(ctx, version);
plaintext = cryptor.decrypt(dek.getRawDek(), ciphertext, EMPTY_AAD);
plaintext = cryptor.decrypt(dek.getKeyMaterialBytes(), ciphertext, EMPTY_AAD);
return toObject(fieldCtx.getType(), plaintext);
default:
throw new IllegalArgumentException("Unsupported rule mode " + ctx.ruleMode());
Expand Down Expand Up @@ -572,7 +557,7 @@ public void close() {
}
}

private static Aead getAead(Map<String, ?> configs, KekInfo kek)
private static Aead getAead(Map<String, ?> configs, Kek kek)
throws GeneralSecurityException, RuleException {
String kekUrl = kek.getKmsType() + KMS_TYPE_SUFFIX + kek.getKmsKeyId();
KmsClient kmsClient = getKmsClient(configs, kekUrl);
Expand All @@ -590,106 +575,5 @@ private static KmsClient getKmsClient(Map<String, ?> configs, String kekUrl)
return KmsDriverManager.getDriver(kekUrl).registerKmsClient(configs, Optional.of(kekUrl));
}
}

static class KekInfo {

private final String kmsType;
private final String kmsKeyId;
private final boolean shared;

public KekInfo(String kmsType, String kmsKeyId, boolean shared) {
this.kmsType = kmsType;
this.kmsKeyId = kmsKeyId;
this.shared = shared;
}

public String getKmsType() {
return kmsType;
}

public String getKmsKeyId() {
return kmsKeyId;
}

public boolean isShared() {
return shared;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
KekInfo kek = (KekInfo) o;
return shared == kek.shared
&& Objects.equals(kmsType, kek.kmsType)
&& Objects.equals(kmsKeyId, kek.kmsKeyId);
}

@Override
public int hashCode() {
return Objects.hash(kmsType, kmsKeyId, shared);
}
}

static class DekInfo {

private final Integer version;
private byte[] rawDek;
private final byte[] encryptedDek;
private final Long ts;

public DekInfo(Integer version, byte[] rawDek, byte[] encryptedDek, Long ts) {
this.version = version;
this.rawDek = rawDek;
this.encryptedDek = encryptedDek;
this.ts = ts;
}

public Integer getVersion() {
return version;
}

public byte[] getRawDek() {
return rawDek;
}

public void setRawDek(byte[] rawDek) {
this.rawDek = rawDek;
}

public byte[] getEncryptedDek() {
return encryptedDek;
}

public Long getTimestamp() {
return ts;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
DekInfo dek = (DekInfo) o;
return Objects.equals(version, dek.version)
&& Arrays.equals(rawDek, dek.rawDek)
&& Arrays.equals(encryptedDek, dek.encryptedDek);
}

@Override
public int hashCode() {
int result = Objects.hash(version);
result = 31 * result + Arrays.hashCode(rawDek);
result = 31 * result + Arrays.hashCode(encryptedDek);
return result;
}
}
}

Expand Up @@ -25,6 +25,8 @@
import io.confluent.kafka.schemaregistry.encryption.tink.DekFormat;
import io.confluent.kafka.schemaregistry.utils.JacksonMapper;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Objects;

@JsonInclude(Include.NON_NULL)
Expand All @@ -36,10 +38,13 @@ public class Dek {
private final int version;
private final DekFormat algorithm;
private final String encryptedKeyMaterial;
private final String keyMaterial;
private String keyMaterial;
private final Long timestamp;
private final Boolean deleted;

private byte[] encryptedKeyMaterialBytes;
private byte[] keyMaterialBytes;

@JsonCreator
public Dek(
@JsonProperty("kekName") String kekName,
Expand Down Expand Up @@ -91,6 +96,14 @@ public String getKeyMaterial() {
return this.keyMaterial;
}

@JsonProperty("keyMaterial")
public void setKeyMaterial(byte[] keyMaterialBytes) {
if (keyMaterialBytes != null) {
this.keyMaterial =
new String(Base64.getEncoder().encode(keyMaterialBytes), StandardCharsets.UTF_8);
}
}

@JsonProperty("ts")
public Long getTimestamp() {
return this.timestamp;
Expand All @@ -106,6 +119,30 @@ public boolean isDeleted() {
return Boolean.TRUE.equals(this.deleted);
}

@JsonIgnore
public byte[] getEncryptedKeyMaterialBytes() {
if (encryptedKeyMaterial == null) {
return null;
}
if (encryptedKeyMaterialBytes == null) {
encryptedKeyMaterialBytes =
Base64.getDecoder().decode(encryptedKeyMaterial.getBytes(StandardCharsets.UTF_8));
}
return encryptedKeyMaterialBytes;
}

@JsonIgnore
public byte[] getKeyMaterialBytes() {
if (keyMaterial == null) {
return null;
}
if (keyMaterialBytes == null) {
keyMaterialBytes =
Base64.getDecoder().decode(keyMaterial.getBytes(StandardCharsets.UTF_8));
}
return keyMaterialBytes;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand Down

0 comments on commit 2656938

Please sign in to comment.