-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add AWS KMS support for OAuth2 Client Credentials JWT authentication
This implementaion adds new configuration properties to "oauth2" aws_kms: AWS KMS key details aws_signing: Infomation for signing AWS requestion, similar to s3_signing References: 1) https://github.com/go-jose/go-jose/blob/v3/asymmetric.go#L501 2) https://github.com/codelittinc/gobitauth/blob/master/sign.go#L101 Signed-off-by: Prasanth Ullattil <prasanth.ullattil@dnb.no>
- Loading branch information
1 parent
f74a5f6
commit 6dc46be
Showing
7 changed files
with
709 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
package aws | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"encoding/base64" | ||
"encoding/json" | ||
"fmt" | ||
"net/http" | ||
"time" | ||
|
||
"github.com/open-policy-agent/opa/internal/version" | ||
"github.com/open-policy-agent/opa/logging" | ||
) | ||
|
||
// Values taken from | ||
// https://docs.aws.amazon.com/kms/latest/APIReference/Welcome.html | ||
// https://docs.aws.amazon.com/general/latest/gr/kms.html | ||
const ( | ||
kmsSignTarget = "TrentService.Sign" | ||
kmsEndpointFmt = "https://kms.%s.amazonaws.com/" | ||
) | ||
|
||
// KMS is used to sign payloads using AWS Key Management Service. | ||
type KMS struct { | ||
// endpoint returns the region-specifc KMS endpoint. | ||
// It can be overridden by tests. | ||
endpoint func(region string) string | ||
|
||
// client is used to send authorization tokens requests. | ||
client *http.Client | ||
|
||
logger logging.Logger | ||
} | ||
|
||
func NewKMS(logger logging.Logger) *KMS { | ||
return &KMS{ | ||
endpoint: func(region string) string { | ||
return fmt.Sprintf(kmsEndpointFmt, region) | ||
}, | ||
client: &http.Client{}, | ||
logger: logger, | ||
} | ||
} | ||
|
||
func NewKMSWithURLClient(url string, client *http.Client, logger logging.Logger) *KMS { | ||
return &KMS{ | ||
endpoint: func(string) string { return url }, | ||
client: client, | ||
logger: logger, | ||
} | ||
} | ||
|
||
type KMSSignRequest struct { | ||
KeyID string `json:"KeyId"` | ||
Message string `json:"Message"` | ||
MessageType string `json:"MessageType"` | ||
SigningAlgorithm string `json:"SigningAlgorithm"` | ||
} | ||
type KMSSignResponse struct { | ||
KeyID string `json:"KeyId"` | ||
Signature string `json:"Signature"` | ||
SigningAlgorithm string `json:"SigningAlgorithm"` | ||
} | ||
|
||
// SignDigest signs a digest using KMS. | ||
func (k *KMS) SignDigest(ctx context.Context, digest []byte, keyID string, signingAlgorithm string, creds Credentials, signatureVersion string) (string, error) { | ||
endpoint := k.endpoint(creds.RegionName) | ||
|
||
kmsRequest := KMSSignRequest{ | ||
KeyID: keyID, | ||
Message: base64.StdEncoding.EncodeToString(digest), | ||
MessageType: "DIGEST", | ||
SigningAlgorithm: signingAlgorithm, | ||
} | ||
requestJSONBytes, err := json.Marshal(kmsRequest) | ||
if err != nil { | ||
return "", fmt.Errorf("failed to marshall request: %w", err) | ||
} | ||
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(requestJSONBytes)) | ||
if err != nil { | ||
return "", fmt.Errorf("failed to create request: %w", err) | ||
} | ||
|
||
req.Header.Set("X-Amz-Target", kmsSignTarget) | ||
req.Header.Set("Accept-Encoding", "identity") | ||
req.Header.Set("Content-Type", "application/x-amz-json-1.1") | ||
req.Header.Set("User-Agent", version.UserAgent) | ||
|
||
if err := SignRequest(req, "kms", creds, time.Now(), signatureVersion); err != nil { | ||
return "", fmt.Errorf("failed to sign request: %w", err) | ||
} | ||
|
||
resp, err := DoRequestWithClient(req, k.client, "kms sign digest", k.logger) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
var data KMSSignResponse | ||
if err := json.Unmarshal(resp, &data); err != nil { | ||
return "", fmt.Errorf("failed to unmarshal response: %w", err) | ||
} | ||
|
||
return data.Signature, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
package aws | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"github.com/open-policy-agent/opa/logging" | ||
) | ||
|
||
func mockPayload(request KMSSignRequest) string { | ||
responseFmt := `{"KeyId": "%s", "Signature": "%s", "SigningAlgorithm": "%s"}` | ||
return fmt.Sprintf(responseFmt, request.KeyID, request.Message, request.SigningAlgorithm) | ||
} | ||
|
||
func TestKMS_SignDigest(t *testing.T) { | ||
type testCase struct { | ||
name string | ||
request KMSSignRequest | ||
responsePayload string | ||
responseStatus int | ||
wantSignature string | ||
wantErr bool | ||
} | ||
|
||
run := func(t *testing.T, tc testCase) { | ||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
if tc.responseStatus != 200 { | ||
w.WriteHeader(tc.responseStatus) | ||
} | ||
if _, err := io.WriteString(w, tc.responsePayload); err != nil { | ||
t.Fatalf("io.WriteString(w, payload) = %v", err) | ||
} | ||
|
||
})) | ||
defer server.Close() | ||
|
||
logger := logging.New() | ||
logger.SetLevel(logging.Debug) | ||
|
||
kms := NewKMSWithURLClient(server.URL, server.Client(), logger) | ||
|
||
creds := Credentials{} | ||
signature, err := kms.SignDigest(context.Background(), []byte(tc.request.Message), tc.request.KeyID, tc.request.SigningAlgorithm, creds, "v4") | ||
if err != nil && tc.wantErr == false { | ||
t.Fatalf("expected no error, got: %s", err) | ||
} | ||
|
||
if err == nil && tc.wantErr { | ||
t.Fatal("expected error") | ||
} | ||
|
||
if err == nil && tc.wantSignature != signature { | ||
t.Fatalf("expected %s, got %s", tc.wantSignature, signature) | ||
} | ||
|
||
} | ||
validRequest1 := KMSSignRequest{ | ||
KeyID: "Keyid1", | ||
Message: "sample", | ||
SigningAlgorithm: "ECDSA_SHA_256", | ||
} | ||
testCases := []testCase{ | ||
{ | ||
name: "valid response", | ||
request: validRequest1, | ||
responsePayload: mockPayload(validRequest1), | ||
responseStatus: 200, | ||
wantSignature: validRequest1.Message, | ||
wantErr: false, | ||
}, | ||
{ | ||
name: "error response", | ||
request: validRequest1, | ||
responsePayload: "Backend error", | ||
responseStatus: 500, | ||
wantErr: true, | ||
}, | ||
{ | ||
name: "valid error response", | ||
request: validRequest1, | ||
responsePayload: `{ "__type" :"SerializationException" }`, | ||
responseStatus: 400, | ||
wantErr: true, | ||
}, | ||
} | ||
|
||
for _, tc := range testCases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
run(t, tc) | ||
}) | ||
} | ||
} |
Oops, something went wrong.