diff --git a/decrypt_test.go b/decrypt_test.go index 2fd04fc..b9798fa 100644 --- a/decrypt_test.go +++ b/decrypt_test.go @@ -5,7 +5,7 @@ import ( "testing" ) -func Test_Decrypt(t *testing.T) { +func Test_Decrypt_V11(t *testing.T) { password := "asdf" content := `$ANSIBLE_VAULT;1.1;AES256 39663038636438383965366163636163376531336238346239623934393436393938656439643133 @@ -19,6 +19,20 @@ func Test_Decrypt(t *testing.T) { assert.NoError(t, err) } +func Test_Decrypt_V12(t *testing.T) { + password := "asdf" + content := `$ANSIBLE_VAULT;1.2;AES256;label +39663038636438383965366163636163376531336238346239623934393436393938656439643133 +3638363066366433666438623138373866393763373265320a366635386630336562633763323236 +61616562393964666464653532636436346535616566613434613361303734373734383930323661 +6664306264366235630a643235323438646132656337613434396338396335396439346336613062 +3766 +` + result, err := Decrypt(content, password) + assert.Equal(t, "hello", result) + assert.NoError(t, err) +} + func Test_DecryptFile(t *testing.T) { password := "asdf" filename := "./testdata/test1/secrets.yaml" diff --git a/encrypt.go b/encrypt.go index 1c20884..0656f85 100644 --- a/encrypt.go +++ b/encrypt.go @@ -7,6 +7,7 @@ import ( "crypto/rand" "crypto/sha256" "encoding/hex" + "errors" "io/ioutil" "strings" ) @@ -21,15 +22,43 @@ func GenerateRandomBytes(n int) ([]byte, error) { return b, nil } -func EncryptFile(filename, password string) (result string, err error) { +func EncryptFile(filename, password, vaultID string) (result string, err error) { data, err := ioutil.ReadFile(filename) check(err) - result, err = Encrypt(string(data), password) + if vaultID == "" { + result, err = encryptV11(string(data), password) + } else { + result, err = encryptV12(string(data), password, vaultID) + } return } +// Encrypt will vault encrypt a piece of data. +// +// If a `vaultID` is not an empty string, it will upversion to 1.2, otherwise it will +// default to using 1.1. +// +// `vaultID` must not include `;`. If it does, an error will be thrown. +func Encrypt(body, password, vaultID string) (result string, err error) { + err = checkVaultID(vaultID) + if err != nil { + return "", err + } + if vaultID == "" { + return encryptV11(body, password) + } + return encryptV12(body, password, vaultID) +} + +func checkVaultID(vaultID string) error { + if strings.Contains(vaultID, ";") { + return errors.New("vaultID cannot contain ';'") + } + return nil +} + // see https://github.com/ansible/ansible/blob/0b8011436dc7f842b78298848e298f2a57ee8d78/lib/ansible/parsing/vault/__init__.py#L710 -func Encrypt(body, password string) (result string, err error) { +func encryptV11(body, password string) (result string, err error) { salt, err := GenerateRandomBytes(32) check(err) // salt_64 := "2262970e2309d5da757af6c473b0ed3034209cc0d48a3cc3d648c0b174c22fde" @@ -42,6 +71,21 @@ func Encrypt(body, password string) (result string, err error) { return } +// see https://docs.ansible.com/ansible/latest/user_guide/vault.html#ansible-vault-payload-format-1-1-1-2 +// see https://github.com/ansible/ansible/blob/0f95371131cd41d97ad95c4e8bd983081eb29a2a/lib/ansible/parsing/vault/__init__.py#L581 +func encryptV12(body, password, vaultID string) (result string, err error) { + salt, err := GenerateRandomBytes(32) + check(err) + // salt_64 := "2262970e2309d5da757af6c473b0ed3034209cc0d48a3cc3d648c0b174c22fde" + // salt,_ = hex.DecodeString(salt_64) + key1, key2, iv := genKeyInitctr(password, salt) + ciphertext := createCipherText(body, key1, iv) + combined := combineParts(ciphertext, key2, salt) + vaultText := hex.EncodeToString([]byte(combined)) + result = formatOutputV12(vaultText, vaultID) + return +} + func createCipherText(body string, key1, iv []byte) []byte { bs := aes.BlockSize padding := (bs - len(body)%bs) @@ -77,8 +121,12 @@ func combineParts(ciphertext, key2, salt []byte) string { return combined } -// https://github.com/ansible/ansible/blob/0b8011436dc7f842b78298848e298f2a57ee8d78/lib/ansible/parsing/vault/__init__.py#L268 func formatOutput(vaultText string) string { + return formatOutputV11(vaultText) +} + +// https://github.com/ansible/ansible/blob/0b8011436dc7f842b78298848e298f2a57ee8d78/lib/ansible/parsing/vault/__init__.py#L268 +func formatOutputV11(vaultText string) string { heading := "$ANSIBLE_VAULT" version := "1.1" cipherName := "AES256" @@ -103,3 +151,30 @@ func formatOutput(vaultText string) string { whole := strings.Join(elements, "\n") return whole } + +func formatOutputV12(vaultText, vaultIDText string) string { + heading := "$ANSIBLE_VAULT" + version := "1.2" + cipherName := "AES256" + + headerElements := make([]string, 4) + headerElements[0] = heading + headerElements[1] = version + headerElements[2] = cipherName + headerElements[3] = vaultIDText + header := strings.Join(headerElements, ";") + + elements := make([]string, 1) + elements[0] = header + for i := 0; i < len(vaultText); i += 80 { + end := i + 80 + if end > len(vaultText) { + end = len(vaultText) + } + elements = append(elements, vaultText[i:end]) + } + elements = append(elements, "") + + whole := strings.Join(elements, "\n") + return whole +} diff --git a/encrypt_test.go b/encrypt_test.go index 1eaaad2..c591f2b 100644 --- a/encrypt_test.go +++ b/encrypt_test.go @@ -5,12 +5,12 @@ import ( "testing" ) -func Test_Encrypt(t *testing.T) { +func Test_Encrypt_V11(t *testing.T) { password := "asdf" body := "secret" var encrypted string var err error - encrypted, err = Encrypt(body, password) + encrypted, err = Encrypt(body, password, "") assert.NoError(t, err) var result string @@ -18,3 +18,58 @@ func Test_Encrypt(t *testing.T) { assert.NoError(t, err) assert.Equal(t, body, result) } + +func Test_Encrypt_V12(t *testing.T) { + password := "asdf" + body := "secret" + var encrypted string + var err error + encrypted, err = Encrypt(body, password, "test") + assert.NoError(t, err) + + var result string + result, err = Decrypt(encrypted, password) + assert.NoError(t, err) + assert.Equal(t, body, result) +} + +func Test_encryptV11(t *testing.T) { + password := "asdf" + body := "secret" + var encrypted string + var err error + encrypted, err = encryptV11(body, password) + assert.NoError(t, err) + + var result string + result, err = Decrypt(encrypted, password) + assert.NoError(t, err) + assert.Equal(t, body, result) +} + +func Test_encryptV12(t *testing.T) { + password := "asdf" + body := "secret" + vaultID := "label" + var encrypted string + var err error + encrypted, err = encryptV12(body, password, vaultID) + assert.NoError(t, err) + + var result string + result, err = Decrypt(encrypted, password) + assert.NoError(t, err) + assert.Equal(t, body, result) +} + +func Test_checkvaultID(t *testing.T) { + var err error + err = checkVaultID("") + assert.NoError(t, err) + + err = checkVaultID("1-)90$#98klascalkkDADQXASdasd=-=+_+_=-=") + assert.NoError(t, err) + + err = checkVaultID("a;b") + assert.Error(t, err) +}