Skip to content
This repository has been archived by the owner on Mar 4, 2024. It is now read-only.

Added support for v4 signing to S3. #27

Merged
merged 3 commits into from
Mar 1, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package aws
import (
"errors"
"os"
"strings"
)

// Region defines the URLs where AWS services may be accessed.
Expand All @@ -27,7 +28,13 @@ type Region struct {
SNSEndpoint string
SQSEndpoint string
IAMEndpoint string
Sign Signer // Method which will be used to sign requests.
}

func (r Region) ResolveS3BucketEndpoint(bucketName string) string {
if r.S3BucketEndpoint != "" {
return strings.ToLower(strings.Replace(r.S3BucketEndpoint, "${bucket}", bucketName, -1))
}
return strings.ToLower(r.S3Endpoint + "/" + bucketName + "/")
}

var USEast = Region{
Expand All @@ -41,7 +48,6 @@ var USEast = Region{
"https://sns.us-east-1.amazonaws.com",
"https://sqs.us-east-1.amazonaws.com",
"https://iam.amazonaws.com",
SignV2,
}

var USWest = Region{
Expand All @@ -55,7 +61,6 @@ var USWest = Region{
"https://sns.us-west-1.amazonaws.com",
"https://sqs.us-west-1.amazonaws.com",
"https://iam.amazonaws.com",
SignV2,
}

var USWest2 = Region{
Expand All @@ -69,7 +74,6 @@ var USWest2 = Region{
"https://sns.us-west-2.amazonaws.com",
"https://sqs.us-west-2.amazonaws.com",
"https://iam.amazonaws.com",
SignV2,
}

var EUWest = Region{
Expand All @@ -83,7 +87,6 @@ var EUWest = Region{
"https://sns.eu-west-1.amazonaws.com",
"https://sqs.eu-west-1.amazonaws.com",
"https://iam.amazonaws.com",
SignV2,
}

var APSoutheast = Region{
Expand All @@ -97,7 +100,6 @@ var APSoutheast = Region{
"https://sns.ap-southeast-1.amazonaws.com",
"https://sqs.ap-southeast-1.amazonaws.com",
"https://iam.amazonaws.com",
SignV2,
}

var APSoutheast2 = Region{
Expand All @@ -111,7 +113,6 @@ var APSoutheast2 = Region{
"https://sns.ap-southeast-2.amazonaws.com",
"https://sqs.ap-southeast-2.amazonaws.com",
"https://iam.amazonaws.com",
SignV2,
}

var APNortheast = Region{
Expand All @@ -125,7 +126,6 @@ var APNortheast = Region{
"https://sns.ap-northeast-1.amazonaws.com",
"https://sqs.ap-northeast-1.amazonaws.com",
"https://iam.amazonaws.com",
SignV2,
}

var SAEast = Region{
Expand All @@ -139,7 +139,6 @@ var SAEast = Region{
"https://sns.sa-east-1.amazonaws.com",
"https://sqs.sa-east-1.amazonaws.com",
"https://iam.amazonaws.com",
SignV2,
}

var CNNorth = Region{
Expand All @@ -153,7 +152,6 @@ var CNNorth = Region{
"https://sns.cn-north-1.amazonaws.com.cn",
"https://sqs.cn-north-1.amazonaws.com.cn",
"https://iam.amazonaws.com.cn",
SignV4Factory("cn-north-1"),
}

var Regions = map[string]Region{
Expand Down
97 changes: 60 additions & 37 deletions aws/sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,26 @@ import (
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"net/url"
"sort"
"strings"
"time"
)

var debug = log.New(
// Remove the c-style comment header to front of line to debug information.
/*os.Stdout, //*/ ioutil.Discard,
"DEBUG: ",
log.LstdFlags,
)

type Signer func(*http.Request, Auth) error

// Ensure our signers meet the interface
var _ Signer = SignV2
var _ Signer = SignV4Factory("")
var _ Signer = SignV4Factory("", "")

type hasher func([]byte) string

Expand Down Expand Up @@ -71,30 +79,30 @@ func SignV2(req *http.Request, auth Auth) (err error) {

// SignV4Factory returns a version 4 Signer which will utilize the
// given region name.
func SignV4Factory(regionName string) Signer {
func SignV4Factory(regionName, serviceName string) Signer {
return func(req *http.Request, auth Auth) error {
return SignV4(req, auth, regionName)
return SignV4(req, auth, regionName, serviceName)
}
}

// SignV4 signs an HTTP request utilizing version 4 of the AWS
// signature, and the given credentials.
func SignV4(req *http.Request, auth Auth, regionName string) (err error) {
func SignV4(req *http.Request, auth Auth, regionName, svcName string) (err error) {

var reqTime time.Time
if reqTime, err = requestTime(req); err != nil {
return err
}

svcName := inferServiceName(req.URL)
credScope := credentialScope(reqTime, regionName, svcName)
// Remove any existing authorization headers as they will corrupt
// the signing.
delete(req.Header, "Authorization")
delete(req.Header, "authorization")

// There are several places in the algorithm that call for
// processing the headers sorted by name.
sortedHdrNames := sortHeaderNames(req.Header)
credScope := credentialScope(reqTime, regionName, svcName)

var canonReqHash string
if _, canonReqHash, err = canonicalRequest(req, sortedHdrNames, sha256Hasher); err != nil {
_, canonReqHash, sortedHdrNames, err := canonicalRequest(req, sha256Hasher)
if err != nil {
return err
}

Expand All @@ -106,6 +114,8 @@ func SignV4(req *http.Request, auth Auth, regionName string) (err error) {
key := signingKey(reqTime, auth.SecretKey, regionName, svcName)
signature := fmt.Sprintf("%x", hmacHasher(key, strToSign))

debug.Printf("strToSign:\n\"\"\"\n%s\n\"\"\"", strToSign)

var authHdrVal string
if authHdrVal, err = authHeaderString(
req.Header,
Expand All @@ -126,20 +136,24 @@ func SignV4(req *http.Request, auth Auth, regionName string) (err error) {
// Returns the canonical request, and its hash.
func canonicalRequest(
req *http.Request,
sortedHdrNames []string,
hasher hasher,
) (canReq, canReqHash string, err error) {
) (canReq, canReqHash string, sortedHdrNames []string, err error) {

var canHdr string
if canHdr, err = canonicalHeaders(sortedHdrNames, req.Header); err != nil {
var payHash string
if payHash, err = payloadHash(req, hasher); err != nil {
return
}
req.Header.Set("x-amz-content-sha256", payHash)

var payHash string
if payHash, err = payloadHash(req, hasher); err != nil {
sortedHdrNames = sortHeaderNames(req.Header, "host")
var canHdr string
if canHdr, err = canonicalHeaders(sortedHdrNames, req.Host, req.Header); err != nil {
return
}

debug.Printf("canHdr:\n\"\"\"\n%s\n\"\"\"", canHdr)
debug.Printf("signedHeader: %s\n\n", strings.Join(sortedHdrNames, ";"))

var queryStr string
if queryStr, err = canonicalQueryString(req.URL.Query()); err != nil {
return
Expand All @@ -148,17 +162,18 @@ func canonicalRequest(
c := new(bytes.Buffer)
if err := errorCollector(
fprintfWrapper(c, "%s\n", requestMethodVerb(req.Method)),
fprintfWrapper(c, "%s\n", req.URL.RequestURI()),
fprintfWrapper(c, "%s\n", req.URL.Path),
fprintfWrapper(c, "%s\n", queryStr),
fprintfWrapper(c, "%s\n", canHdr),
fprintfWrapper(c, "%s\n", strings.Join(sortedHdrNames, ";")),
fprintfWrapper(c, "%s", payHash),
); err != nil {
return "", "", err
return "", "", nil, err
}

canReq = c.String()
return canReq, hasher([]byte(canReq)), nil
debug.Printf("canReq:\n\"\"\"\n%s\n\"\"\"", canReq)
return canReq, hasher([]byte(canReq)), sortedHdrNames, nil
}

// Task 2: Create a string to Sign
Expand Down Expand Up @@ -206,8 +221,8 @@ func authHeaderString(
w := new(bytes.Buffer)
if err := errorCollector(
fprintfWrapper(w, "AWS4-HMAC-SHA256 "),
fprintfWrapper(w, "Credential=%s/%s, ", accessKey, credScope),
fprintfWrapper(w, "SignedHeaders=%s, ", strings.Join(sortedHeaderNames, ";")),
fprintfWrapper(w, "Credential=%s/%s,", accessKey, credScope),
fprintfWrapper(w, "SignedHeaders=%s,", strings.Join(sortedHeaderNames, ";")),
fprintfWrapper(w, "Signature=%s", signature),
); err != nil {
return "", err
Expand All @@ -224,37 +239,49 @@ func canonicalQueryString(queryVals url.Values) (string, error) {
return strings.Replace(queryVals.Encode(), "+", "%20", -1), nil
}

func canonicalHeaders(sortedHeaderNames []string, hdr http.Header) (string, error) {
func canonicalHeaders(sortedHeaderNames []string, host string, hdr http.Header) (string, error) {
buffer := new(bytes.Buffer)

for _, hName := range sortedHeaderNames {
canonHdrKey := http.CanonicalHeaderKey(hName)
sortedHdrVals := hdr[canonHdrKey]
sort.Strings(sortedHdrVals)
hdrVals := strings.Join(sortedHdrVals, ",")

hdrVals := host
if hName != "host" {
canonHdrKey := http.CanonicalHeaderKey(hName)
sortedHdrVals := hdr[canonHdrKey]
sort.Strings(sortedHdrVals)
hdrVals = strings.Join(sortedHdrVals, ",")
}

if _, err := fmt.Fprintf(buffer, "%s:%s\n", hName, hdrVals); err != nil {
return "", err
}
}

// There is intentionally a hanging newline at the end of the
// header list.
return buffer.String(), nil
}

// Returns a SHA256 checksum of the request body. Represented as a
// lowercase hexadecimal string.
func payloadHash(req *http.Request, hasher hasher) (string, error) {
if b, err := ioutil.ReadAll(req.Body); err != nil {
if req.Body == nil {
return hasher([]byte("")), nil
}

b, err := ioutil.ReadAll(req.Body)
if err != nil {
return "", err
} else {
req.Body = ioutil.NopCloser(bytes.NewBuffer(b))
return hasher(b), nil
}

req.Body = ioutil.NopCloser(bytes.NewBuffer(b))
return hasher(b), nil
}

// Retrieve the header names, lower-case them, and sort them.
func sortHeaderNames(header http.Header) []string {
func sortHeaderNames(header http.Header, injectedNames ...string) []string {

var sortedNames []string
sortedNames := injectedNames
for hName, _ := range header {
sortedNames = append(sortedNames, strings.ToLower(hName))
}
Expand All @@ -270,10 +297,6 @@ func hmacHasher(key []byte, value string) []byte {
return h.Sum(nil)
}

func inferServiceName(url *url.URL) string {
return strings.Split(url.Host, ".")[0]
}

func sha256Hasher(payload []byte) string {
return fmt.Sprintf("%x", sha256.Sum256(payload))
}
Expand Down
Loading