forked from viant/toolbox
/
service.go
133 lines (123 loc) · 3.18 KB
/
service.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package aws
import (
"bytes"
"context"
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
akms "github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/pkg/errors"
"strings"
"github.com/viant/toolbox"
"github.com/viant/toolbox/kms"
)
// service implements kms.Service on AWS. The embedded clients' API methods
// (PutParameter, GetParameter, ListAliases, ...) are promoted onto service and
// called directly by the methods below.
type service struct {
	// SSM client: stores and reads the parameters holding encrypted data.
	*ssm.SSM
	// KMS client: used to resolve key aliases to key IDs.
	*akms.KMS
}
// Encrypt stores request.Data in the SSM parameter named request.Parameter,
// using the KMS key (or alias) named by request.Key. It then reads the
// parameter back with WithDecryption=false and returns that value as both
// EncryptedText and EncryptedData.
//
// On any error the returned response is nil. ctx is currently not propagated
// to the AWS calls.
func (s *service) Encrypt(ctx context.Context, request *kms.EncryptRequest) (*kms.EncryptResponse, error) {
	if err := request.Validate(); err != nil {
		return nil, errors.Wrap(err, "invalid encrypt request")
	}
	if request.Parameter == "" {
		return nil, errors.New("parameter was empty")
	}
	if err := s.putParameters(request.Key, request.Parameter, string(request.Data)); err != nil {
		return nil, err
	}
	// Read back without decryption so the caller receives the stored form.
	parameter, err := s.getParameters(request.Parameter, false)
	if err != nil {
		return nil, err
	}
	if parameter == nil || parameter.Value == nil {
		return nil, fmt.Errorf("parameter %v has no value", request.Parameter)
	}
	text := *parameter.Value
	return &kms.EncryptResponse{
		EncryptedText: text,
		EncryptedData: []byte(text),
	}, nil
}
// Decrypt reads the SSM parameter named request.Parameter with
// WithDecryption=true and returns its value as both Text and Data.
//
// On any error the returned response is nil. ctx is currently not propagated
// to the AWS calls.
func (s *service) Decrypt(ctx context.Context, request *kms.DecryptRequest) (*kms.DecryptResponse, error) {
	if err := request.Validate(); err != nil {
		return nil, errors.Wrap(err, "invalid decrypt request")
	}
	if request.Parameter == "" {
		return nil, errors.New("parameter was empty")
	}
	parameter, err := s.getParameters(request.Parameter, true)
	if err != nil {
		return nil, err
	}
	// Guard against a missing parameter/value instead of panicking on a nil dereference.
	if parameter == nil || parameter.Value == nil {
		return nil, fmt.Errorf("parameter %v has no value", request.Parameter)
	}
	text := *parameter.Value
	return &kms.DecryptResponse{
		Text: text,
		Data: []byte(text),
	}, nil
}
// Decode decrypts the parameter described by decryptRequest and decodes the
// resulting bytes into target with a decoder obtained from factory.
// Errors from Decrypt or from the decoder are returned unchanged.
func (s *service) Decode(ctx context.Context, decryptRequest *kms.DecryptRequest, factory toolbox.DecoderFactory, target interface{}) error {
	decrypted, err := s.Decrypt(ctx, decryptRequest)
	if err != nil {
		return err
	}
	decoder := factory.Create(bytes.NewReader(decrypted.Data))
	return decoder.Decode(target)
}
// putParameters resolves keyOrAlias to a KMS key ID via getKeyByAlias and
// writes value to the SSM parameter name as a SecureString encrypted with that
// key, replacing any existing value of the parameter.
func (s *service) putParameters(keyOrAlias, name, value string) error {
	targetKeyID, err := s.getKeyByAlias(keyOrAlias)
	if err != nil {
		return err
	}
	// Per the SSM PutParameter API, KeyId only applies to SecureString
	// parameters, and a Type is required when creating a parameter, so the type
	// is set explicitly. Overwrite lets a repeated Encrypt for the same name
	// replace the old value instead of failing with ParameterAlreadyExists.
	_, err = s.PutParameter(&ssm.PutParameterInput{
		Name:      aws.String(name),
		Type:      aws.String(ssm.ParameterTypeSecureString),
		KeyId:     aws.String(targetKeyID),
		Value:     aws.String(value),
		Overwrite: aws.Bool(true),
	})
	return err
}
// getKeyByAlias resolves keyOrAlias to a KMS key ID.
//
// A value containing ":" (for example a key or alias ARN) is returned
// unchanged. Otherwise the KMS aliases are paged through with ListAliases, and
// the target key ID of the alias whose name equals keyOrAlias is returned.
// An error is returned when no alias matches, when the matching alias has no
// target key, or when ListAliases fails.
func (s *service) getKeyByAlias(keyOrAlias string) (string, error) {
	if strings.Contains(keyOrAlias, ":") {
		return keyOrAlias, nil
	}
	input := &akms.ListAliasesInput{}
	for {
		output, err := s.ListAliases(input)
		if err != nil {
			return "", err
		}
		// Stop on an empty page even if a marker is present, to avoid looping forever.
		if len(output.Aliases) == 0 {
			break
		}
		for _, candidate := range output.Aliases {
			if aws.StringValue(candidate.AliasName) != keyOrAlias {
				continue
			}
			// KMS may list predefined aliases that are not yet associated with a
			// key; those have no TargetKeyId.
			if candidate.TargetKeyId == nil {
				return "", fmt.Errorf("alias %v has no target key", keyOrAlias)
			}
			return *candidate.TargetKeyId, nil
		}
		if output.NextMarker == nil {
			break
		}
		input.Marker = output.NextMarker
	}
	return "", fmt.Errorf("key for alias %v not found", keyOrAlias)
}
// getParameters fetches the SSM parameter called name, passing withDecryption
// through as the request's WithDecryption flag. On error the parameter is nil.
func (s *service) getParameters(name string, withDecryption bool) (*ssm.Parameter, error) {
	input := &ssm.GetParameterInput{
		Name:           aws.String(name),
		WithDecryption: aws.Bool(withDecryption),
	}
	output, err := s.GetParameter(input)
	if err != nil {
		return nil, err
	}
	return output.Parameter, nil
}
// New returns a kms.Service backed by AWS SSM Parameter Store and AWS KMS.
// Both clients share one session created by session.NewSession() with no
// explicit options; any session creation error is returned as-is.
func New() (kms.Service, error) {
	sess, err := session.NewSession()
	if err != nil {
		return nil, err
	}
	svc := &service{
		SSM: ssm.New(sess),
		KMS: akms.New(sess),
	}
	return svc, nil
}