Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AD mode to Transit's AEAD ciphers #17638

Merged
merged 4 commits into from
Oct 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions builtin/logical/transit/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1717,3 +1717,115 @@ func TestTransit_AutoRotateKeys(t *testing.T) {
)
}
}

func TestTransit_AEAD(t *testing.T) {
testTransit_AEAD(t, "aes128-gcm96")
testTransit_AEAD(t, "aes256-gcm96")
testTransit_AEAD(t, "chacha20-poly1305")
}

func testTransit_AEAD(t *testing.T, keyType string) {
var resp *logical.Response
var err error
b, storage := createBackendWithStorage(t)

keyReq := &logical.Request{
Path: "keys/aead",
Operation: logical.UpdateOperation,
Data: map[string]interface{}{
"type": keyType,
},
Storage: storage,
}

resp, err = b.HandleRequest(context.Background(), keyReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}

plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" // "the quick brown fox"
associated := "U3BoaW54IG9mIGJsYWNrIHF1YXJ0eiwganVkZ2UgbXkgdm93Lgo=" // "Sphinx of black quartz, judge my vow."

// Basic encrypt/decrypt should work.
encryptReq := &logical.Request{
Path: "encrypt/aead",
Operation: logical.UpdateOperation,
Storage: storage,
Data: map[string]interface{}{
"plaintext": plaintext,
},
}

resp, err = b.HandleRequest(context.Background(), encryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}

ciphertext1 := resp.Data["ciphertext"].(string)

decryptReq := &logical.Request{
Path: "decrypt/aead",
Operation: logical.UpdateOperation,
Storage: storage,
Data: map[string]interface{}{
"ciphertext": ciphertext1,
},
}

resp, err = b.HandleRequest(context.Background(), decryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}

decryptedPlaintext := resp.Data["plaintext"]

if plaintext != decryptedPlaintext {
t.Fatalf("bad: plaintext; expected: %q\nactual: %q", plaintext, decryptedPlaintext)
}

// Using associated as ciphertext should fail.
decryptReq.Data["ciphertext"] = associated
resp, err = b.HandleRequest(context.Background(), decryptReq)
if err == nil || (resp != nil && !resp.IsError()) {
t.Fatalf("bad expected error: err: %v\nresp: %#v", err, resp)
}

// Redoing the above with additional data should work.
encryptReq.Data["associated_data"] = associated
resp, err = b.HandleRequest(context.Background(), encryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}

ciphertext2 := resp.Data["ciphertext"].(string)
decryptReq.Data["ciphertext"] = ciphertext2
decryptReq.Data["associated_data"] = associated

resp, err = b.HandleRequest(context.Background(), decryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}

decryptedPlaintext = resp.Data["plaintext"]
if plaintext != decryptedPlaintext {
t.Fatalf("bad: plaintext; expected: %q\nactual: %q", plaintext, decryptedPlaintext)
}

// Removing the associated_data should break the decryption.
decryptReq.Data = map[string]interface{}{
"ciphertext": ciphertext2,
}
resp, err = b.HandleRequest(context.Background(), decryptReq)
if err == nil || (resp != nil && !resp.IsError()) {
t.Fatalf("bad expected error: err: %v\nresp: %#v", err, resp)
}

// Using a valid ciphertext with associated_data should also break the
// decryption.
decryptReq.Data["ciphertext"] = ciphertext1
decryptReq.Data["associated_data"] = associated
resp, err = b.HandleRequest(context.Background(), decryptReq)
if err == nil || (resp != nil && !resp.IsError()) {
t.Fatalf("bad expected error: err: %v\nresp: %#v", err, resp)
}
}
31 changes: 27 additions & 4 deletions builtin/logical/transit/path_decrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Base64 encoded nonce value used during encryption. Must be provided if
convergent encryption is enabled for this key and the key was generated with
Vault 0.6.1. Not required for keys created in 0.6.2+.`,
},

"partial_failure_response_code": {
Type: framework.TypeInt,
Description: `
Expand All @@ -58,6 +59,17 @@ the HTTP response code is 400 (Bad Request). Some applications may want to trea
Providing the parameter returns the given response code integer instead of a 400 in this case. If all values fail
HTTP 400 is still returned.`,
},

"associated_data": {
Type: framework.TypeString,
Description: `
When using an AEAD cipher mode, such as AES-GCM, this parameter allows
passing associated data (AD/AAD) into the encryption function; this data
must be passed on subsequent decryption requests but can be transited in
plaintext. On successful decryption, both the ciphertext and the associated
data are attested not to have been tampered with.
`,
},
},

Callbacks: map[logical.Operation]framework.OperationFunc{
Expand Down Expand Up @@ -90,9 +102,10 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d

batchInputItems = make([]BatchRequestItem, 1)
batchInputItems[0] = BatchRequestItem{
Ciphertext: ciphertext,
Context: d.Get("context").(string),
Nonce: d.Get("nonce").(string),
Ciphertext: ciphertext,
Context: d.Get("context").(string),
Nonce: d.Get("nonce").(string),
AssociatedData: d.Get("associated_data").(string),
}
}

Expand Down Expand Up @@ -155,7 +168,17 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
continue
}

plaintext, err := p.Decrypt(item.DecodedContext, item.DecodedNonce, item.Ciphertext)
var factory interface{}
if item.AssociatedData != "" {
if !p.Type.AssociatedDataSupported() {
batchResponseItems[i].Error = fmt.Sprintf("'[%d].associated_data' provided for non-AEAD cipher suite %v", i, p.Type.String())
continue
}

factory = AssocDataFactory{item.AssociatedData}
}

plaintext, err := p.DecryptWithFactory(item.DecodedContext, item.DecodedNonce, item.Ciphertext, factory)
if err != nil {
switch err.(type) {
case errutil.InternalError:
Expand Down
53 changes: 48 additions & 5 deletions builtin/logical/transit/path_encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ type BatchRequestItem struct {

// DecodedNonce is the base64 decoded version of Nonce
DecodedNonce []byte

// Associated Data for AEAD ciphers
AssociatedData string `json:"associated_data" struct:"associated_data" mapstructure:"associated_data"`
}

// EncryptBatchResponseItem represents a response item for batch processing
Expand All @@ -55,6 +58,14 @@ type EncryptBatchResponseItem struct {
Error string `json:"error,omitempty" structs:"error" mapstructure:"error"`
}

type AssocDataFactory struct {
Encoded string
}

func (a AssocDataFactory) GetAssociatedData() ([]byte, error) {
return base64.StdEncoding.DecodeString(a.Encoded)
}

func (b *backend) pathEncrypt() *framework.Path {
return &framework.Path{
Pattern: "encrypt/" + framework.GenericNameRegex("name"),
Expand Down Expand Up @@ -113,6 +124,7 @@ will severely impact the ciphertext's security.`,
Must be 0 (for latest) or a value greater than or equal
to the min_encryption_version configured on the key.`,
},

"partial_failure_response_code": {
Type: framework.TypeInt,
Description: `
Expand All @@ -121,6 +133,17 @@ the HTTP response code is 400 (Bad Request). Some applications may want to trea
Providing the parameter returns the given response code integer instead of a 400 in this case. If all values fail
HTTP 400 is still returned.`,
},

"associated_data": {
Type: framework.TypeString,
Description: `
When using an AEAD cipher mode, such as AES-GCM, this parameter allows
passing associated data (AD/AAD) into the encryption function; this data
must be passed on subsequent decryption requests but can be transited in
plaintext. On successful decryption, both the ciphertext and the associated
data are attested not to have been tampered with.
`,
},
},

Callbacks: map[logical.Operation]framework.OperationFunc{
Expand Down Expand Up @@ -229,6 +252,15 @@ func decodeBatchRequestItems(src interface{}, requirePlaintext bool, requireCiph
errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].key_version' expected type 'int', got unconvertible type '%T'", i, item["key_version"]))
}
}

if v, has := item["associated_data"]; has {
if !reflect.ValueOf(v).IsValid() {
} else if casted, ok := v.(string); ok {
(*dst)[i].AssociatedData = casted
} else {
errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].associated_data' expected type 'string', got unconvertible type '%T'", i, item["associated_data"]))
}
}
}

if len(errs.Errors) > 0 {
Expand Down Expand Up @@ -279,10 +311,11 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d

batchInputItems = make([]BatchRequestItem, 1)
batchInputItems[0] = BatchRequestItem{
Plaintext: valueRaw.(string),
Context: d.Get("context").(string),
Nonce: d.Get("nonce").(string),
KeyVersion: d.Get("key_version").(int),
Plaintext: valueRaw.(string),
Context: d.Get("context").(string),
Nonce: d.Get("nonce").(string),
KeyVersion: d.Get("key_version").(int),
AssociatedData: d.Get("associated_data").(string),
}
}

Expand Down Expand Up @@ -393,7 +426,17 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
warnAboutNonceUsage = true
}

ciphertext, err := p.Encrypt(item.KeyVersion, item.DecodedContext, item.DecodedNonce, item.Plaintext)
var factory interface{}
if item.AssociatedData != "" {
if !p.Type.AssociatedDataSupported() {
batchResponseItems[i].Error = fmt.Sprintf("'[%d].associated_data' provided for non-AEAD cipher suite %v", i, p.Type.String())
continue
}

factory = AssocDataFactory{item.AssociatedData}
}

ciphertext, err := p.EncryptWithFactory(item.KeyVersion, item.DecodedContext, item.DecodedNonce, item.Plaintext, factory)
if err != nil {
switch err.(type) {
case errutil.InternalError:
Expand Down
3 changes: 3 additions & 0 deletions changelog/17638.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
secrets/transit: Add associated_data parameter for additional authenticated data in AEAD ciphers
```
74 changes: 61 additions & 13 deletions sdk/helper/keysutil/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ type AEADFactory interface {
GetAEAD(iv []byte) (cipher.AEAD, error)
}

type AssociatedDataFactory interface {
GetAssociatedData() ([]byte, error)
}

type RestoreInfo struct {
Time time.Time `json:"time"`
Version int `json:"version"`
Expand Down Expand Up @@ -147,6 +151,14 @@ func (kt KeyType) DerivationSupported() bool {
return false
}

func (kt KeyType) AssociatedDataSupported() bool {
switch kt {
case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
return true
}
return false
}

func (kt KeyType) String() string {
switch kt {
case KeyType_AES128_GCM96:
Expand Down Expand Up @@ -844,6 +856,10 @@ func (p *Policy) Encrypt(ver int, context, nonce []byte, value string) (string,
}

func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) {
return p.DecryptWithFactory(context, nonce, value, nil)
}

func (p *Policy) DecryptWithFactory(context, nonce []byte, value string, factories ...interface{}) (string, error) {
if !p.Type.DecryptionSupported() {
return "", errutil.UserError{Err: fmt.Sprintf("message decryption not supported for key type %v", p.Type)}
}
Expand Down Expand Up @@ -911,11 +927,28 @@ func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) {
return "", errutil.InternalError{Err: "could not derive enc key, length not correct"}
}

plain, err = p.SymmetricDecryptRaw(encKey, decoded,
SymmetricOpts{
Convergent: p.ConvergentEncryption,
ConvergentVersion: p.ConvergentVersion,
})
symopts := SymmetricOpts{
Convergent: p.ConvergentEncryption,
ConvergentVersion: p.ConvergentVersion,
}
for index, rawFactory := range factories {
if rawFactory == nil {
continue
}
switch factory := rawFactory.(type) {
case AEADFactory:
symopts.AEADFactory = factory
case AssociatedDataFactory:
symopts.AdditionalData, err = factory.GetAssociatedData()
if err != nil {
return "", errutil.InternalError{Err: fmt.Sprintf("unable to get associated_data/additional_data from factory[%d]: %v", index, err)}
}
default:
return "", errutil.InternalError{Err: fmt.Sprintf("unknown type of factory[%d]: %T", index, rawFactory)}
}
}

plain, err = p.SymmetricDecryptRaw(encKey, decoded, symopts)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -1830,7 +1863,7 @@ func (p *Policy) SymmetricDecryptRaw(encKey, ciphertext []byte, opts SymmetricOp
return plain, nil
}

func (p *Policy) EncryptWithFactory(ver int, context []byte, nonce []byte, value string, factory AEADFactory) (string, error) {
func (p *Policy) EncryptWithFactory(ver int, context []byte, nonce []byte, value string, factories ...interface{}) (string, error) {
if !p.Type.EncryptionSupported() {
return "", errutil.UserError{Err: fmt.Sprintf("message encryption not supported for key type %v", p.Type)}
}
Expand Down Expand Up @@ -1891,14 +1924,29 @@ func (p *Policy) EncryptWithFactory(ver int, context []byte, nonce []byte, value
}
}

ciphertext, err = p.SymmetricEncryptRaw(ver, encKey, plaintext,
SymmetricOpts{
Convergent: p.ConvergentEncryption,
HMACKey: hmacKey,
Nonce: nonce,
AEADFactory: factory,
})
symopts := SymmetricOpts{
Convergent: p.ConvergentEncryption,
HMACKey: hmacKey,
Nonce: nonce,
}
for index, rawFactory := range factories {
if rawFactory == nil {
continue
}
switch factory := rawFactory.(type) {
case AEADFactory:
symopts.AEADFactory = factory
case AssociatedDataFactory:
symopts.AdditionalData, err = factory.GetAssociatedData()
if err != nil {
return "", errutil.InternalError{Err: fmt.Sprintf("unable to get associated_data/additional_data from factory[%d]: %v", index, err)}
}
default:
return "", errutil.InternalError{Err: fmt.Sprintf("unknown type of factory[%d]: %T", index, rawFactory)}
}
}

ciphertext, err = p.SymmetricEncryptRaw(ver, encKey, plaintext, symopts)
if err != nil {
return "", err
}
Expand Down
Loading