Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 230 additions & 0 deletions google/internal/externalaccount/aws.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package externalaccount

import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"path"
"sort"
"strings"
"time"
)

// RequestSigner is a utility class to sign http requests using a AWS V4 signature.
type awsRequestSigner struct {
RegionName string
AwsSecurityCredentials map[string]string
}

const (
// AWS Signature Version 4 signing algorithm identifier.
awsAlgorithm = "AWS4-HMAC-SHA256"

// The termination string for the AWS credential scope value as defined in
// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
awsRequestType = "aws4_request"

// The AWS authorization header name for the security session token if available.
awsSecurityTokenHeader = "x-amz-security-token"

// The AWS authorization header name for the auto-generated date.
awsDateHeader = "x-amz-date"

awsTimeFormatLong = "20060102T150405Z"
awsTimeFormatShort = "20060102"
)

func getSha256(input []byte) (string, error) {
hash := sha256.New()
if _, err := hash.Write(input); err != nil {
return "", err
}
return hex.EncodeToString(hash.Sum(nil)), nil
}

func getHmacSha256(key, input []byte) ([]byte, error) {
hash := hmac.New(sha256.New, key)
if _, err := hash.Write(input); err != nil {
return nil, err
}
return hash.Sum(nil), nil
}

func cloneRequest(r *http.Request) *http.Request {
r2 := new(http.Request)
*r2 = *r
if r.Header != nil {
r2.Header = make(http.Header, len(r.Header))

// Find total number of values.
headerCount := 0
for _, headerValues := range r.Header {
headerCount += len(headerValues)
}
copiedHeaders := make([]string, headerCount) // shared backing array for headers' values

for headerKey, headerValues := range r.Header {
headerCount = copy(copiedHeaders, headerValues)
r2.Header[headerKey] = copiedHeaders[:headerCount:headerCount]
copiedHeaders = copiedHeaders[headerCount:]
}
}
return r2
}

func canonicalPath(req *http.Request) string {
result := req.URL.EscapedPath()
if result == "" {
return "/"
}
return path.Clean(result)
}

func canonicalQuery(req *http.Request) string {
queryValues := req.URL.Query()
for queryKey := range queryValues {
sort.Strings(queryValues[queryKey])
}
return queryValues.Encode()
}

func canonicalHeaders(req *http.Request) (string, string) {
// Header keys need to be sorted alphabetically.
var headers []string
lowerCaseHeaders := make(http.Header)
for k, v := range req.Header {
k := strings.ToLower(k)
if _, ok := lowerCaseHeaders[k]; ok {
// include additional values
lowerCaseHeaders[k] = append(lowerCaseHeaders[k], v...)
} else {
headers = append(headers, k)
lowerCaseHeaders[k] = v
}
}
sort.Strings(headers)

var fullHeaders strings.Builder
for _, header := range headers {
headerValue := strings.Join(lowerCaseHeaders[header], ",")
fullHeaders.WriteString(header)
fullHeaders.WriteRune(':')
fullHeaders.WriteString(headerValue)
fullHeaders.WriteRune('\n')
}

return strings.Join(headers, ";"), fullHeaders.String()
}

func requestDataHash(req *http.Request) (string, error) {
var requestData []byte
if req.Body != nil {
requestBody, err := req.GetBody()
if err != nil {
return "", err
}
defer requestBody.Close()

requestData, err = ioutil.ReadAll(io.LimitReader(requestBody, 1<<20))
if err != nil {
return "", err
}
}

return getSha256(requestData)
}

func requestHost(req *http.Request) string {
if req.Host != "" {
return req.Host
}
return req.URL.Host
}

func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) (string, error) {
dataHash, err := requestDataHash(req)
if err != nil {
return "", err
}

return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", req.Method, canonicalPath(req), canonicalQuery(req), canonicalHeaderData, canonicalHeaderColumns, dataHash), nil
}

// SignRequest adds the appropriate headers to an http.Request
// or returns an error if something prevented this.
func (rs *awsRequestSigner) SignRequest(req *http.Request) error {
signedRequest := cloneRequest(req)
timestamp := now()

signedRequest.Header.Add("host", requestHost(req))

if securityToken, ok := rs.AwsSecurityCredentials["security_token"]; ok {
signedRequest.Header.Add(awsSecurityTokenHeader, securityToken)
}

if signedRequest.Header.Get("date") == "" {
signedRequest.Header.Add(awsDateHeader, timestamp.Format(awsTimeFormatLong))
}

authorizationCode, err := rs.generateAuthentication(signedRequest, timestamp)
if err != nil {
return err
}
signedRequest.Header.Set("Authorization", authorizationCode)

req.Header = signedRequest.Header
return nil
}

func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
secretAccessKey, ok := rs.AwsSecurityCredentials["secret_access_key"]
if !ok {
return "", errors.New("oauth2/google: missing secret_access_key header")
}
accessKeyId, ok := rs.AwsSecurityCredentials["access_key_id"]
if !ok {
return "", errors.New("oauth2/google: missing access_key_id header")
}

canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)

dateStamp := timestamp.Format(awsTimeFormatShort)
serviceName := ""
if splitHost := strings.Split(requestHost(req), "."); len(splitHost) > 0 {
serviceName = splitHost[0]
}

credentialScope := fmt.Sprintf("%s/%s/%s/%s",dateStamp, rs.RegionName, serviceName, awsRequestType)

requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
if err != nil {
return "", err
}
requestHash, err := getSha256([]byte(requestString))
if err != nil{
return "", err
}

stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash)

signingKey := []byte("AWS4" + secretAccessKey)
for _, signingInput := range []string{
dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
} {
signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
if err != nil{
return "", err
}
}

return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, accessKeyId, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
}
Loading