Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 258 additions & 24 deletions google/internal/externalaccount/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,54 @@
package externalaccount

import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"golang.org/x/oauth2"
"io"
"io/ioutil"
"net/http"
"os"
"path"
"sort"
"strings"
"time"
)

// RequestSigner is a utility class to sign http requests using a AWS V4 signature.
type awsSecurityCredentials struct {
AccessKeyID string `json:"AccessKeyID"`
SecretAccessKey string `json:"SecretAccessKey"`
SecurityToken string `json:"Token"`
}

// awsRequestSigner is a utility class to sign http requests using a AWS V4 signature.
type awsRequestSigner struct {
RegionName string
AwsSecurityCredentials map[string]string
AwsSecurityCredentials awsSecurityCredentials
}

// getenv aliases os.Getenv for testing
var getenv = os.Getenv

const (
// AWS Signature Version 4 signing algorithm identifier.
// AWS Signature Version 4 signing algorithm identifier.
awsAlgorithm = "AWS4-HMAC-SHA256"

// The termination string for the AWS credential scope value as defined in
// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
// The termination string for the AWS credential scope value as defined in
// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
awsRequestType = "aws4_request"

// The AWS authorization header name for the security session token if available.
// The AWS authorization header name for the security session token if available.
awsSecurityTokenHeader = "x-amz-security-token"

// The AWS authorization header name for the auto-generated date.
// The AWS authorization header name for the auto-generated date.
awsDateHeader = "x-amz-date"

awsTimeFormatLong = "20060102T150405Z"
awsTimeFormatLong = "20060102T150405Z"
awsTimeFormatShort = "20060102"
)

Expand Down Expand Up @@ -167,8 +180,8 @@ func (rs *awsRequestSigner) SignRequest(req *http.Request) error {

signedRequest.Header.Add("host", requestHost(req))

if securityToken, ok := rs.AwsSecurityCredentials["security_token"]; ok {
signedRequest.Header.Add(awsSecurityTokenHeader, securityToken)
if rs.AwsSecurityCredentials.SecurityToken != "" {
signedRequest.Header.Add(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SecurityToken)
}

if signedRequest.Header.Get("date") == "" {
Expand All @@ -186,15 +199,6 @@ func (rs *awsRequestSigner) SignRequest(req *http.Request) error {
}

func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
secretAccessKey, ok := rs.AwsSecurityCredentials["secret_access_key"]
if !ok {
return "", errors.New("oauth2/google: missing secret_access_key header")
}
accessKeyId, ok := rs.AwsSecurityCredentials["access_key_id"]
if !ok {
return "", errors.New("oauth2/google: missing access_key_id header")
}

canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)

dateStamp := timestamp.Format(awsTimeFormatShort)
Expand All @@ -203,28 +207,258 @@ func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp
serviceName = splitHost[0]
}

credentialScope := fmt.Sprintf("%s/%s/%s/%s",dateStamp, rs.RegionName, serviceName, awsRequestType)
credentialScope := fmt.Sprintf("%s/%s/%s/%s", dateStamp, rs.RegionName, serviceName, awsRequestType)

requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
if err != nil {
return "", err
}
requestHash, err := getSha256([]byte(requestString))
if err != nil{
if err != nil {
return "", err
}

stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash)

signingKey := []byte("AWS4" + secretAccessKey)
signingKey := []byte("AWS4" + rs.AwsSecurityCredentials.SecretAccessKey)
for _, signingInput := range []string{
dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
} {
signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
if err != nil{
if err != nil {
return "", err
}
}

return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
}

type awsCredentialSource struct {
EnvironmentID string
RegionURL string
RegionalCredVerificationURL string
CredVerificationURL string
TargetResource string
requestSigner *awsRequestSigner
region string
ctx context.Context
client *http.Client
}

type awsRequestHeader struct {
Key string `json:"key"`
Value string `json:"value"`
}

type awsRequest struct {
URL string `json:"url"`
Method string `json:"method"`
Headers []awsRequestHeader `json:"headers"`
}

func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) {
if cs.client == nil {
cs.client = oauth2.NewClient(cs.ctx, nil)
}
return cs.client.Do(req.WithContext(cs.ctx))
}

func (cs awsCredentialSource) subjectToken() (string, error) {
if cs.requestSigner == nil {
awsSecurityCredentials, err := cs.getSecurityCredentials()
if err != nil {
return "", err
}

if cs.region, err = cs.getRegion(); err != nil {
return "", err
}

cs.requestSigner = &awsRequestSigner{
RegionName: cs.region,
AwsSecurityCredentials: awsSecurityCredentials,
}
}

// Generate the signed request to AWS STS GetCallerIdentity API.
// Use the required regional endpoint. Otherwise, the request will fail.
req, err := http.NewRequest("POST", strings.Replace(cs.RegionalCredVerificationURL, "{region}", cs.region, 1), nil)
if err != nil {
return "", err
}
// The full, canonical resource name of the workload identity pool
// provider, with or without the HTTPS prefix.
// Including this header as part of the signature is recommended to
// ensure data integrity.
if cs.TargetResource != "" {
req.Header.Add("x-goog-cloud-target-resource", cs.TargetResource)
}
cs.requestSigner.SignRequest(req)

/*
The GCP STS endpoint expects the headers to be formatted as:
# [
# {key: 'x-amz-date', value: '...'},
# {key: 'Authorization', value: '...'},
# ...
# ]
# And then serialized as:
# quote(json.dumps({
# url: '...',
# method: 'POST',
# headers: [{key: 'x-amz-date', value: '...'}, ...]
# }))
*/

awsSignedReq := awsRequest{
URL: req.URL.String(),
Method: "POST",
}
for headerKey, headerList := range req.Header {
for _, headerValue := range headerList {
awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{
Key: headerKey,
Value: headerValue,
})
}
}
sort.Slice(awsSignedReq.Headers, func(i, j int) bool {
headerCompare := strings.Compare(awsSignedReq.Headers[i].Key, awsSignedReq.Headers[j].Key)
if headerCompare == 0 {
return strings.Compare(awsSignedReq.Headers[i].Value, awsSignedReq.Headers[j].Value) < 0
}
return headerCompare < 0
})

result, err := json.Marshal(awsSignedReq)
if err != nil {
return "", err
}
return string(result), nil
}

func (cs *awsCredentialSource) getRegion() (string, error) {
if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" {
return envAwsRegion, nil
}

if cs.RegionURL == "" {
return "", errors.New("oauth2/google: unable to determine AWS region")
}

req, err := http.NewRequest("GET", cs.RegionURL, nil)
if err != nil {
return "", err
}

resp, err := cs.doRequest(req)
if err != nil {
return "", err
}
defer resp.Body.Close()

respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return "", err
}

if resp.StatusCode != 200 {
return "", fmt.Errorf("oauth2/google: unable to retrieve AWS region - %s", string(respBody))
}

// This endpoint will return the region in format: us-east-2b.
// Only the us-east-2 part should be used.
respBodyEnd := 0
if len(respBody) > 1 {
respBodyEnd = len(respBody) - 1
}
return string(respBody[:respBodyEnd]), nil
}

func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCredentials, err error) {
if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" {
if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" {
return awsSecurityCredentials{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SecurityToken: getenv("AWS_SESSION_TOKEN"),
}, nil
}
}

roleName, err := cs.getMetadataRoleName()
if err != nil {
return
}

credentials, err := cs.getMetadataSecurityCredentials(roleName)
if err != nil {
return
}

if credentials.AccessKeyID == "" {
return result, errors.New("oauth2/google: missing AccessKeyId credential")
}

if credentials.SecretAccessKey == "" {
return result, errors.New("oauth2/google: missing SecretAccessKey credential")
}

return credentials, nil
}

func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (awsSecurityCredentials, error) {
var result awsSecurityCredentials

req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil)
if err != nil {
return result, err
}
req.Header.Add("Content-Type", "application/json")

resp, err := cs.doRequest(req)
if err != nil {
return result, err
}
defer resp.Body.Close()

respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return result, err
}

if resp.StatusCode != 200 {
return result, fmt.Errorf("oauth2/google: unable to retrieve AWS security credentials - %s", string(respBody))
}

err = json.Unmarshal(respBody, &result)
return result, err
}

func (cs *awsCredentialSource) getMetadataRoleName() (string, error) {
if cs.CredVerificationURL == "" {
return "", errors.New("oauth2/google: unable to determine the AWS metadata server security credentials endpoint")
}

req, err := http.NewRequest("GET", cs.CredVerificationURL, nil)
if err != nil {
return "", err
}

resp, err := cs.doRequest(req)
if err != nil {
return "", err
}
defer resp.Body.Close()

respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return "", err
}

if resp.StatusCode != 200 {
return "", fmt.Errorf("oauth2/google: unable to retrieve AWS role name - %s", string(respBody))
}

return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, accessKeyId, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
return string(respBody), nil
}
Loading