-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
iam_scram_client.go
182 lines (152 loc) · 4.43 KB
/
iam_scram_client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0
package awsmsk // import "github.com/open-telemetry/opentelemetry-collector-contrib/internal/kafka/awsmsk"
import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/IBM/sarama"
"github.com/aws/aws-sdk-go/aws/credentials"
sign "github.com/aws/aws-sdk-go/aws/signer/v4"
"go.uber.org/multierr"
)
// Protocol constants for the AWS MSK IAM SASL exchange.
const (
	// Mechanism is the SASL mechanism name for AWS MSK IAM authentication.
	Mechanism = "AWS_MSK_IAM"

	// service is the AWS service name handed to the request signer and
	// embedded in the credential scope.
	service = "kafka-cluster"

	// supportedVersion is the protocol version sent in the auth payload and
	// required in the broker's response.
	supportedVersion = "2020_10_22"

	// scopeFormat renders the credential scope; its three verbs are filled
	// with the access key, the yyyyMMdd date, and the region, in that order.
	scopeFormat = "%s/%s/%s/" + service + "/aws4_request"
)
// States of the authentication exchange. Numbering starts at one so that the
// zero value of IAMSASLClient.state matches none of them; a client on which
// Begin was never called therefore lands in Step's default branch.
const (
	initMessage int32 = iota + 1 // waiting to send the initial auth payload
	serverResponse               // waiting for the broker's reply
	complete                     // exchange finished successfully
	failed                       // exchange aborted after an error
)
// ErrFailedServerChallenge is wrapped by Step when the broker's response
// cannot be decoded or reports an unsupported version.
var ErrFailedServerChallenge = errors.New("failed server challenge")

// ErrBadChallenge is wrapped by Step when the challenge is non-empty for the
// initial request or empty for the server response.
var ErrBadChallenge = errors.New("invalid challenge data provided")

// ErrInvalidStateReached is wrapped by Step when it is called in a state
// that expects no further messages.
var ErrInvalidStateReached = errors.New("invalid state reached")
// IAMSASLClient implements sarama.SCRAMClient for the AWS_MSK_IAM mechanism.
// The exported fields must all be non-empty before Begin is called.
type IAMSASLClient struct {
	// MSKHostname is the broker hostname; it is signed as the "host" header
	// and reported in the auth payload.
	MSKHostname string
	// Region is the AWS region used for signing and in the credential scope.
	Region string
	// UserAgent is reported to the broker in the auth payload.
	UserAgent string

	// signer is created by Begin and produces the payload signature.
	signer *sign.StreamSigner
	// state tracks progress through the exchange (initMessage, serverResponse, ...).
	state int32
	// accessKey and secretKey hold the username and password given to Begin.
	accessKey, secretKey string
}
// payload is the JSON document produced by getAuthPayload and returned by
// Step as the initial client message. The struct tags define the key names
// that are emitted.
type payload struct {
// Version is set to supportedVersion.
Version string `json:"version"`
// BrokerHost is the MSK broker hostname.
BrokerHost string `json:"host"`
UserAgent string `json:"user-agent"`
// Action is set to "kafka-cluster:Connect".
Action string `json:"action"`
// Algorithm is set to "AWS4-HMAC-SHA256".
Algorithm string `json:"x-amz-algorithm"`
// Credentials is the credential scope rendered with scopeFormat.
Credentials string `json:"x-amz-credential"`
// Date is the signing timestamp formatted as yyyyMMdd'T'HHmmss'Z'.
Date string `json:"x-amz-date"`
// Expires is a number of seconds, encoded as a string.
Expires string `json:"x-amz-expires"`
SignedHeaders string `json:"x-amz-signedheaders"`
Signature string `json:"x-amz-signature"`
}
// response is the JSON document Step decodes from the broker's challenge.
type response struct {
// Version must equal supportedVersion for the exchange to complete.
Version string `json:"version"`
// RequestID is decoded but not otherwise read by the code in this file.
RequestID string `json:"request-id"`
}
// Compile-time check that *IAMSASLClient satisfies sarama.SCRAMClient.
var _ sarama.SCRAMClient = (*IAMSASLClient)(nil)
// NewIAMSASLClient returns a sarama.SCRAMClient backed by an IAMSASLClient
// configured with the given broker hostname, region, and user agent. The
// values are not validated here; Begin rejects empty ones.
func NewIAMSASLClient(mskhostname, region, useragent string) sarama.SCRAMClient {
	client := new(IAMSASLClient)
	client.MSKHostname = mskhostname
	client.Region = region
	client.UserAgent = useragent
	return client
}
// Begin validates the client's configuration and prepares it for the
// authentication exchange.
//
// username and password are used as the access key ID and secret access key
// of a static credentials provider, which is listed after the environment
// provider in the signer's credential chain. The third argument is ignored.
//
// An error is returned when MSKHostname, Region, or UserAgent is empty.
// On success the state is set to initMessage, so the next call to Step
// produces the initial auth payload.
func (sc *IAMSASLClient) Begin(username, password, _ string) error {
	switch {
	case sc.MSKHostname == "":
		return errors.New("missing required MSK Broker hostname")
	case sc.Region == "":
		return errors.New("missing MSK cluster region")
	case sc.UserAgent == "":
		return errors.New("missing value for MSK user agent")
	}

	// NOTE(review): the environment provider is listed first, yet the auth
	// payload's credential scope is always built from username — confirm the
	// two cannot disagree when environment credentials are present.
	staticCreds := &credentials.StaticProvider{
		Value: credentials.Value{
			AccessKeyID:     username,
			SecretAccessKey: password,
		},
	}
	chain := credentials.NewChainCredentials([]credentials.Provider{
		&credentials.EnvProvider{},
		staticCreds,
	})

	sc.signer = sign.NewStreamSigner(sc.Region, service, nil, chain)
	sc.accessKey, sc.secretKey = username, password
	sc.state = initMessage
	return nil
}
// Step advances the authentication exchange by one message.
//
// In the initMessage state the challenge must be empty; the signed auth
// payload is returned and the client moves on to await the broker's reply.
// In the serverResponse state the challenge must contain the broker's JSON
// response, whose version has to equal supportedVersion; an empty string is
// returned and the exchange is marked complete.
//
// Errors:
//   - wraps ErrBadChallenge when the challenge is unexpectedly present or absent;
//   - wraps ErrFailedServerChallenge when the response cannot be decoded or
//     carries an unsupported version;
//   - wraps ErrInvalidStateReached when called in any other state;
//   - returns the error from building the auth payload unchanged.
//
// Every failure except the invalid-state case moves the client to the failed
// state.
func (sc *IAMSASLClient) Step(challenge string) (string, error) {
	switch sc.state {
	case initMessage:
		if challenge != "" {
			sc.state = failed
			return "", fmt.Errorf("challenge must be empty for initial request: %w", ErrBadChallenge)
		}
		authPayload, err := sc.getAuthPayload()
		if err != nil {
			sc.state = failed
			return "", err
		}
		sc.state = serverResponse
		return string(authPayload), nil

	case serverResponse:
		if challenge == "" {
			sc.state = failed
			return "", fmt.Errorf("challenge must not be empty for server response: %w", ErrBadChallenge)
		}
		// A decoder (rather than json.Unmarshal) is kept so that only the
		// first JSON value in the challenge is consumed, as before.
		var brokerResp response
		if err := json.NewDecoder(strings.NewReader(challenge)).Decode(&brokerResp); err != nil {
			sc.state = failed
			return "", fmt.Errorf("unable to process msk challenge response: %w", multierr.Combine(err, ErrFailedServerChallenge))
		}
		if brokerResp.Version != supportedVersion {
			sc.state = failed
			return "", fmt.Errorf("unknown version found in response: %w", ErrFailedServerChallenge)
		}
		sc.state = complete
		return "", nil

	default:
		return "", fmt.Errorf("invalid invocation: %w", ErrInvalidStateReached)
	}
}
// Done reports whether the exchange has reached the complete state, which
// Step sets after accepting the broker's response.
func (sc *IAMSASLClient) Done() bool {
	finished := sc.state == complete
	return finished
}
// getAuthPayload builds the JSON auth message that Step sends as the initial
// client message.
//
// It signs the header "host:<MSKHostname>" with the client's signer at the
// current UTC time and marshals a payload carrying that signature together
// with the credential scope, timestamp, and fixed protocol fields. The signer
// must already have been created by Begin.
//
// It returns the marshaled JSON, or the error from the signer or from
// json.Marshal.
func (sc *IAMSASLClient) getAuthPayload() ([]byte, error) {
	ts := time.Now().UTC()

	sig, err := sc.signer.GetSignature([]byte("host:"+sc.MSKHostname), nil, ts)
	if err != nil {
		return nil, err
	}

	// Timestamp in the form yyyyMMdd'T'HHmmss'Z'; its first eight characters
	// are the date component of the credential scope.
	date := ts.Format("20060102T150405Z")

	return json.Marshal(&payload{
		Version:       supportedVersion,
		BrokerHost:    sc.MSKHostname,
		UserAgent:     sc.UserAgent,
		Action:        "kafka-cluster:Connect",
		Algorithm:     "AWS4-HMAC-SHA256",
		Credentials:   fmt.Sprintf(scopeFormat, sc.accessKey, date[:8], sc.Region),
		Date:          date,
		SignedHeaders: "host",
		Expires:       "300", // Seconds => 5 Minutes
		// The signature is a byte slice. Converting it directly to a string
		// lets json.Marshal replace any invalid UTF-8 bytes with U+FFFD,
		// corrupting it, so it is encoded as lowercase hex instead.
		Signature: fmt.Sprintf("%x", sig),
	})
}