Skip to content

Commit

Permalink
chore(tests): improve coverage and refactor some error handling
Browse files Browse the repository at this point in the history
Signed-off-by: Derek Smith <drsmith.phys@gmail.com>
  • Loading branch information
clok committed Dec 21, 2021
1 parent a125cf1 commit f3085c3
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 17 deletions.
29 changes: 14 additions & 15 deletions decrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func Decrypt(opts *DecryptOptions) (result string, err error) {
return
}

// in order to support vault files with windows line endings
// replaceCarriageReturn in order to support vault files with windows line endings
func replaceCarriageReturn(data string) string {
return strings.ReplaceAll(data, "\r", "")
}
Expand All @@ -79,7 +79,7 @@ func splitHeader(data []byte) string {
header := strings.Split(lines[0], ";")
cipherName := strings.TrimSpace(header[2])
if cipherName != "AES256" {
panic("unsupported cipher: " + cipherName)
panic(fmt.Errorf("unsupported cipher: %s", cipherName))
}
body := strings.Join(lines[1:], "")
return body
Expand All @@ -92,18 +92,17 @@ https://github.com/ansible/ansible/blob/0b8011436dc7f842b78298848e298f2a57ee8d78
func decodeData(body string) (salt, cryptedHmac, ciphertext []byte) {
decoded, _ := hex.DecodeString(body)
elements := strings.SplitN(string(decoded), "\n", 3)
salt, err1 := hex.DecodeString(elements[0])
if err1 != nil {
panic(err1)
}
cryptedHmac, err2 := hex.DecodeString(elements[1])
if err2 != nil {
panic(err2)
}
ciphertext, err3 := hex.DecodeString(elements[2])
if err3 != nil {
panic(err3)
}

var err error
salt, err = hex.DecodeString(elements[0])
check(err)

cryptedHmac, err = hex.DecodeString(elements[1])
check(err)

ciphertext, err = hex.DecodeString(elements[2])
check(err)

return
}

Expand Down Expand Up @@ -131,6 +130,6 @@ func checkDigest(key2, cryptedHmac, ciphertext []byte) {
check(err)
expectedMAC := hmacDecrypt.Sum(nil)
if !hmac.Equal(cryptedHmac, expectedMAC) {
panic("digests do not match - exiting")
log.Fatal("digests do not match - exiting")
}
}
22 changes: 22 additions & 0 deletions decrypt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,25 @@ func Test_DecryptFile(t *testing.T) {
})
assert.NoError(t, err)
}

func Test_splitHeader(t *testing.T) {
content := []byte(`$ANSIBLE_VAULT;1.2;AES256;label
39663038636438383965366163636163376531336238346239623934393436393938656439643133
3638363066366433666438623138373866393763373265320a366635386630336562633763323236
61616562393964666464653532636436346535616566613434613361303734373734383930323661
6664306264366235630a643235323438646132656337613434396338396335396439346336613062
3766
`)
expected := "396630386364383839653661636361633765313362383462396239343934363939386564396431333638363066366433666438623138373866393763373265320a366635386630336562633763323236616165623939646664646535326364363465356165666134346133613037343737343839303236616664306264366235630a6432353234386461326563376134343963383963353964393463366130623766"
body := splitHeader(content)
assert.Equal(t, expected, body)
}

func Test_splitHeader_unsupported_cipher(t *testing.T) {
content := []byte(`$ANSIBLE_VAULT;1.2;AES128;label
NOOP
`)
assert.PanicsWithError(t, "unsupported cipher: AES128", func() {
_ = splitHeader(content)
})
}
4 changes: 2 additions & 2 deletions encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io/ioutil"
"strings"
)
Expand Down Expand Up @@ -75,7 +75,7 @@ func Encrypt(opts *EncryptOptions) (result string, err error) {

func checkVaultID(vaultID string) error {
if strings.Contains(vaultID, ";") {
return errors.New("vaultID cannot contain ';'")
return fmt.Errorf("vaultID (%s) cannot contain ';'", vaultID)
}
return nil
}
Expand Down
64 changes: 64 additions & 0 deletions encrypt_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package avtool

import (
"errors"
"github.com/stretchr/testify/assert"
"testing"
)
Expand Down Expand Up @@ -48,6 +49,20 @@ func Test_Encrypt_V12(t *testing.T) {
assert.Equal(t, string(body), result)
}

func Test_Encrypt_V12_Bad_VaultID(t *testing.T) {
password := []byte("asdf")
body := []byte("secret")
var err error
_, err = Encrypt(&EncryptOptions{
Body: &body,
Password: &password,
VaultID: "A;Bad;Label",
})
if assert.Error(t, err) {
assert.Equal(t, errors.New("vaultID (A;Bad;Label) cannot contain ';'"), err)
}
}

func Test_encryptV11(t *testing.T) {
password := []byte("asdf")
body := []byte("secret")
Expand Down Expand Up @@ -103,3 +118,52 @@ func Test_checkVaultID(t *testing.T) {
err = checkVaultID("a;b")
assert.Error(t, err)
}

func Test_EncryptFile_V11(t *testing.T) {
password := []byte("asdf")
encrypted, err := EncryptFile(&EncryptFileOptions{
Filename: "./testdata/encrypt_file.log",
Password: &password,
})
assert.NoError(t, err)
assert.Contains(t, encrypted, "$ANSIBLE_VAULT;1.1;AES256")

var result string
data := []byte(encrypted)
result, err = Decrypt(&DecryptOptions{
Data: &data,
Password: &password,
})
assert.NoError(t, err)

expected := `This is a test.
I have data.
`
assert.Equal(t, expected, result)
}

func Test_EncryptFile_V12(t *testing.T) {
password := []byte("asdf")
encrypted, err := EncryptFile(&EncryptFileOptions{
Filename: "./testdata/encrypt_file.log",
Password: &password,
VaultID: "test-label",
})
assert.NoError(t, err)
assert.Contains(t, encrypted, "$ANSIBLE_VAULT;1.2;AES256;test-label")

var result string
data := []byte(encrypted)
result, err = Decrypt(&DecryptOptions{
Data: &data,
Password: &password,
})
assert.NoError(t, err)

expected := `This is a test.
I have data.
`
assert.Equal(t, expected, result)
}
3 changes: 3 additions & 0 deletions testdata/encrypt_file.log
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
This is a test.

I have data.

0 comments on commit f3085c3

Please sign in to comment.