Skip to content

Commit

Permalink
signature: use region from Auth header if server's region not configu…
Browse files Browse the repository at this point in the history
…red (#4329)
  • Loading branch information
Krishna Srinivas authored and harshavardhana committed May 16, 2017
1 parent 465274c commit 5db1e9f
Show file tree
Hide file tree
Showing 13 changed files with 138 additions and 108 deletions.
59 changes: 36 additions & 23 deletions cmd/bucket-notification-utils.go
Expand Up @@ -17,6 +17,7 @@
package cmd

import (
"errors"
"strings"

"github.com/minio/minio-go/pkg/set"
Expand Down Expand Up @@ -111,20 +112,21 @@ func checkARN(arn, arnType string) APIErrorCode {
if !strings.HasPrefix(arn, arnType) {
return ErrARNNotification
}
if !strings.HasPrefix(arn, arnType+serverConfig.GetRegion()+":") {
return ErrRegionNotification
}
account := strings.SplitN(strings.TrimPrefix(arn, arnType+serverConfig.GetRegion()+":"), ":", 2)
switch len(account) {
case 1:
// This means ARN is malformed, account should have min of 2elements.
strs := strings.SplitN(arn, ":", -1)
if len(strs) != 6 {
return ErrARNNotification
case 2:
// Account topic id or topic name cannot be empty.
if account[0] == "" || account[1] == "" {
return ErrARNNotification
}
if serverConfig.GetRegion() != "" {
region := strs[3]
if region != serverConfig.GetRegion() {
return ErrRegionNotification
}
}
accountID := strs[4]
resource := strs[5]
if accountID == "" || resource == "" {
return ErrARNNotification
}
return ErrNone
}

Expand Down Expand Up @@ -258,28 +260,39 @@ func validateNotificationConfig(nConfig notificationConfig) APIErrorCode {
// - webhook
func unmarshalSqsARN(queueARN string) (mSqs arnSQS) {
mSqs = arnSQS{}
if !strings.HasPrefix(queueARN, minioSqs+serverConfig.GetRegion()+":") {
strs := strings.SplitN(queueARN, ":", -1)
if len(strs) != 6 {
return mSqs
}
sqsType := strings.TrimPrefix(queueARN, minioSqs+serverConfig.GetRegion()+":")
switch {
case hasSuffix(sqsType, queueTypeAMQP):
if serverConfig.GetRegion() != "" {
region := strs[3]
if region != serverConfig.GetRegion() {
return mSqs
}
}
sqsType := strs[5]
switch sqsType {
case queueTypeAMQP:
mSqs.Type = queueTypeAMQP
case hasSuffix(sqsType, queueTypeNATS):
case queueTypeNATS:
mSqs.Type = queueTypeNATS
case hasSuffix(sqsType, queueTypeElastic):
case queueTypeElastic:
mSqs.Type = queueTypeElastic
case hasSuffix(sqsType, queueTypeRedis):
case queueTypeRedis:
mSqs.Type = queueTypeRedis
case hasSuffix(sqsType, queueTypePostgreSQL):
case queueTypePostgreSQL:
mSqs.Type = queueTypePostgreSQL
case hasSuffix(sqsType, queueTypeMySQL):
case queueTypeMySQL:
mSqs.Type = queueTypeMySQL
case hasSuffix(sqsType, queueTypeKafka):
case queueTypeKafka:
mSqs.Type = queueTypeKafka
case hasSuffix(sqsType, queueTypeWebhook):
case queueTypeWebhook:
mSqs.Type = queueTypeWebhook
default:
errorIf(errors.New("invalid SQS type"), "SQS type: %s", sqsType)
} // Add more queues here.
mSqs.AccountID = strings.TrimSuffix(sqsType, ":"+mSqs.Type)

mSqs.AccountID = strs[4]

return mSqs
}
70 changes: 60 additions & 10 deletions cmd/bucket-notification-utils_test.go
Expand Up @@ -259,11 +259,6 @@ func TestQueueARN(t *testing.T) {
queueARN: "arn:minio:sns:us-east-1:1:listen",
errCode: ErrARNNotification,
},
// Invalid region 'us-west-1' in queue arn.
{
queueARN: "arn:minio:sqs:us-west-1:1:redis",
errCode: ErrRegionNotification,
},
// Invalid queue name empty in queue arn.
{
queueARN: "arn:minio:sqs:us-east-1:1:",
Expand Down Expand Up @@ -298,6 +293,37 @@ func TestQueueARN(t *testing.T) {
t.Errorf("Test %d: Expected \"%d\", got \"%d\"", i+1, testCase.errCode, errCode)
}
}

// Test when server region is set.
rootPath, err = newTestConfig("us-east-1")
if err != nil {
t.Fatalf("unable initialize config file, %s", err)
}
defer removeAll(rootPath)

testCases = []struct {
queueARN string
errCode APIErrorCode
}{
// Incorrect region should produce error.
{
queueARN: "arn:minio:sqs:us-west-1:1:webhook",
errCode: ErrRegionNotification,
},
// Correct region should not produce error.
{
queueARN: "arn:minio:sqs:us-east-1:1:webhook",
errCode: ErrNone,
},
}

// Validate all tests for queue arn.
for i, testCase := range testCases {
errCode := checkQueueARN(testCase.queueARN)
if testCase.errCode != errCode {
t.Errorf("Test %d: Expected \"%d\", got \"%d\"", i+1, testCase.errCode, errCode)
}
}
}

// Test unmarshal queue arn.
Expand Down Expand Up @@ -337,11 +363,6 @@ func TestUnmarshalSQSARN(t *testing.T) {
queueARN: "",
Type: "",
},
// Invalid region 'us-west-1' in queue arn.
{
queueARN: "arn:minio:sqs:us-west-1:1:redis",
Type: "",
},
// Partial queue arn.
{
queueARN: "arn:minio:sqs:",
Expand All @@ -361,4 +382,33 @@ func TestUnmarshalSQSARN(t *testing.T) {
}
}

// Test when the server region is set.
rootPath, err = newTestConfig("us-east-1")
if err != nil {
t.Fatalf("unable initialize config file, %s", err)
}
defer removeAll(rootPath)

testCases = []struct {
queueARN string
Type string
}{
// Incorrect region in ARN returns empty mSqs.Type
{
queueARN: "arn:minio:sqs:us-west-1:1:webhook",
Type: "",
},
// Correct regionin ARN returns valid mSqs.Type
{
queueARN: "arn:minio:sqs:us-east-1:1:webhook",
Type: "webhook",
},
}

for i, testCase := range testCases {
mSqs := unmarshalSqsARN(testCase.queueARN)
if testCase.Type != mSqs.Type {
t.Errorf("Test %d: Expected \"%s\", got \"%s\"", i+1, testCase.Type, mSqs.Type)
}
}
}
5 changes: 0 additions & 5 deletions cmd/config-v18.go
Expand Up @@ -261,11 +261,6 @@ func getValidConfig() (*serverConfigV18, error) {
return nil, err
}

// Validate region field
if srvCfg.Region == "" {
return nil, errors.New("Region config value cannot be empty")
}

// Validate credential fields only when
// they are not set via the environment

Expand Down
2 changes: 1 addition & 1 deletion cmd/globals.go
Expand Up @@ -29,7 +29,7 @@ import (
const (
globalMinioCertExpireWarnDays = time.Hour * 24 * 30 // 30 days.

globalMinioDefaultRegion = "us-east-1"
globalMinioDefaultRegion = ""
globalMinioDefaultOwnerID = "minio"
globalMinioDefaultStorageClass = "STANDARD"
globalWindowsOSName = "windows"
Expand Down
4 changes: 2 additions & 2 deletions cmd/handler-utils.go
Expand Up @@ -38,15 +38,15 @@ func parseLocationConstraint(r *http.Request) (location string, s3Error APIError
} // else for both err as nil or io.EOF
location = locationConstraint.Location
if location == "" {
location = globalMinioDefaultRegion
location = serverConfig.GetRegion()
}
return location, ErrNone
}

// Validates input location is same as configured region
// of Minio server.
func isValidLocation(location string) bool {
return serverConfig.GetRegion() == location
return serverConfig.GetRegion() == "" || serverConfig.GetRegion() == location
}

// Supported headers that needs to be extracted.
Expand Down
26 changes: 15 additions & 11 deletions cmd/post-policy_test.go
Expand Up @@ -246,6 +246,7 @@ func testPostPolicyBucketHandler(obj ObjectLayer, instanceType string, t TestErr
}
}

region := "us-east-1"
// Test cases for signature-V4.
testCasesV4BadData := []struct {
objectName string
Expand Down Expand Up @@ -330,7 +331,7 @@ func testPostPolicyBucketHandler(obj ObjectLayer, instanceType string, t TestErr
testCase.policy = fmt.Sprintf(testCase.policy, testCase.dates...)

req, perr := newPostRequestV4Generic("", bucketName, testCase.objectName, testCase.data, testCase.accessKey,
testCase.secretKey, curTime, []byte(testCase.policy), nil, testCase.corruptedBase64, testCase.corruptedMultipart)
testCase.secretKey, region, curTime, []byte(testCase.policy), nil, testCase.corruptedBase64, testCase.corruptedMultipart)
if perr != nil {
t.Fatalf("Test %d: %s: Failed to create HTTP request for PostPolicyHandler: <ERROR> %v", i+1, instanceType, perr)
}
Expand Down Expand Up @@ -473,9 +474,10 @@ func testPostPolicyBucketHandlerRedirect(obj ObjectLayer, instanceType string, t
// Generate the final policy document
policy = fmt.Sprintf(policy, dates...)

region := "us-east-1"
// Create a new POST request with success_action_redirect field specified
req, perr := newPostRequestV4Generic("", bucketName, keyName, []byte("objData"),
credentials.AccessKey, credentials.SecretKey, curTime,
credentials.AccessKey, credentials.SecretKey, region, curTime,
[]byte(policy), map[string]string{"success_action_redirect": redirectURL.String()}, false, false)

if perr != nil {
Expand Down Expand Up @@ -565,11 +567,11 @@ func newPostRequestV2(endPoint, bucketName, objectName string, accessKey, secret
return req, nil
}

func buildGenericPolicy(t time.Time, accessKey, bucketName, objectName string, contentLengthRange bool) []byte {
func buildGenericPolicy(t time.Time, accessKey, region, bucketName, objectName string, contentLengthRange bool) []byte {
// Expire the request five minutes from now.
expirationTime := t.Add(time.Minute * 5)

credStr := getCredentialString(accessKey, serverConfig.GetRegion(), t)
credStr := getCredentialString(accessKey, region, t)
// Create a new post policy.
policy := newPostPolicyBytesV4(credStr, bucketName, objectName, expirationTime)
if contentLengthRange {
Expand All @@ -578,10 +580,10 @@ func buildGenericPolicy(t time.Time, accessKey, bucketName, objectName string, c
return policy
}

func newPostRequestV4Generic(endPoint, bucketName, objectName string, objData []byte, accessKey, secretKey string,
func newPostRequestV4Generic(endPoint, bucketName, objectName string, objData []byte, accessKey, secretKey string, region string,
t time.Time, policy []byte, addFormData map[string]string, corruptedB64 bool, corruptedMultipart bool) (*http.Request, error) {
// Get the user credential.
credStr := getCredentialString(accessKey, serverConfig.GetRegion(), t)
credStr := getCredentialString(accessKey, region, t)

// Only need the encoding.
encodedPolicy := base64.StdEncoding.EncodeToString(policy)
Expand All @@ -591,7 +593,7 @@ func newPostRequestV4Generic(endPoint, bucketName, objectName string, objData []
}

// Presign with V4 signature based on the policy.
signature := postPresignSignatureV4(encodedPolicy, t, secretKey, serverConfig.GetRegion())
signature := postPresignSignatureV4(encodedPolicy, t, secretKey, region)

formData := map[string]string{
"bucket": bucketName,
Expand Down Expand Up @@ -645,12 +647,14 @@ func newPostRequestV4Generic(endPoint, bucketName, objectName string, objData []

func newPostRequestV4WithContentLength(endPoint, bucketName, objectName string, objData []byte, accessKey, secretKey string) (*http.Request, error) {
t := UTCNow()
policy := buildGenericPolicy(t, accessKey, bucketName, objectName, true)
return newPostRequestV4Generic(endPoint, bucketName, objectName, objData, accessKey, secretKey, t, policy, nil, false, false)
region := "us-east-1"
policy := buildGenericPolicy(t, accessKey, region, bucketName, objectName, true)
return newPostRequestV4Generic(endPoint, bucketName, objectName, objData, accessKey, secretKey, region, t, policy, nil, false, false)
}

func newPostRequestV4(endPoint, bucketName, objectName string, objData []byte, accessKey, secretKey string) (*http.Request, error) {
t := UTCNow()
policy := buildGenericPolicy(t, accessKey, bucketName, objectName, false)
return newPostRequestV4Generic(endPoint, bucketName, objectName, objData, accessKey, secretKey, t, policy, nil, false, false)
region := "us-east-1"
policy := buildGenericPolicy(t, accessKey, region, bucketName, objectName, false)
return newPostRequestV4Generic(endPoint, bucketName, objectName, objData, accessKey, secretKey, region, t, policy, nil, false, false)
}
8 changes: 5 additions & 3 deletions cmd/server-startup-msg.go
Expand Up @@ -79,7 +79,9 @@ func printServerCommonMsg(apiEndpoints []string) {
log.Println(colorBlue("\nEndpoint: ") + colorBold(fmt.Sprintf(getFormatStr(len(apiEndpointStr), 1), apiEndpointStr)))
log.Println(colorBlue("AccessKey: ") + colorBold(fmt.Sprintf("%s ", cred.AccessKey)))
log.Println(colorBlue("SecretKey: ") + colorBold(fmt.Sprintf("%s ", cred.SecretKey)))
log.Println(colorBlue("Region: ") + colorBold(fmt.Sprintf(getFormatStr(len(region), 3), region)))
if region != "" {
log.Println(colorBlue("Region: ") + colorBold(fmt.Sprintf(getFormatStr(len(region), 3), region)))
}
printEventNotifiers()

log.Println(colorBlue("\nBrowser Access:"))
Expand All @@ -92,12 +94,12 @@ func printEventNotifiers() {
// In case initEventNotifier() was not done or failed.
return
}
arnMsg := colorBlue("SQS ARNs: ")
// Get all configured external notification targets
externalTargets := globalEventNotifier.GetAllExternalTargets()
if len(externalTargets) == 0 {
arnMsg += colorBold(fmt.Sprintf(getFormatStr(len("<none>"), 1), "<none>"))
return
}
arnMsg := colorBlue("SQS ARNs: ")
for queueArn := range externalTargets {
arnMsg += colorBold(fmt.Sprintf(getFormatStr(len(queueArn), 1), queueArn))
}
Expand Down
3 changes: 0 additions & 3 deletions cmd/signature-v4-parser.go
Expand Up @@ -69,9 +69,6 @@ func parseCredentialHeader(credElement string) (credentialHeader, APIErrorCode)
if e != nil {
return credentialHeader{}, ErrMalformedCredentialDate
}
if credElements[2] == "" {
return credentialHeader{}, ErrMalformedCredentialRegion
}
cred.scope.region = credElements[2]
if credElements[3] != "s3" {
return credentialHeader{}, ErrInvalidService
Expand Down
19 changes: 3 additions & 16 deletions cmd/signature-v4-parser_test.go
Expand Up @@ -141,19 +141,6 @@ func TestParseCredentialHeader(t *testing.T) {
expectedErrCode: ErrMalformedCredentialDate,
},
// Test Case - 6.
// Test case with invalid region.
// region should a non empty string.
{
inputCredentialStr: generateCredentialStr(
"Z7IXGOO6BZ0REAN1Q26I",
UTCNow().Format(yyyymmdd),
"",
"ABCD",
"ABCD"),
expectedCredentials: credentialHeader{},
expectedErrCode: ErrMalformedCredentialRegion,
},
// Test Case - 7.
// Test case with invalid service.
// "s3" is the valid service string.
{
Expand All @@ -166,7 +153,7 @@ func TestParseCredentialHeader(t *testing.T) {
expectedCredentials: credentialHeader{},
expectedErrCode: ErrInvalidService,
},
// Test Case - 8.
// Test Case - 7.
// Test case with invalid request version.
// "aws4_request" is the valid request version.
{
Expand All @@ -179,7 +166,7 @@ func TestParseCredentialHeader(t *testing.T) {
expectedCredentials: credentialHeader{},
expectedErrCode: ErrInvalidRequestVersion,
},
// Test Case - 9.
// Test Case - 8.
// Test case with right inputs. Expected to return a valid CredentialHeader.
// "aws4_request" is the valid request version.
{
Expand All @@ -204,7 +191,7 @@ func TestParseCredentialHeader(t *testing.T) {
actualCredential, actualErrCode := parseCredentialHeader(testCase.inputCredentialStr)
// validating the credential fields.
if testCase.expectedErrCode != actualErrCode {
t.Fatalf("Test %d: Expected the APIErrCode to be %d, got %d", i+1, testCase.expectedErrCode, actualErrCode)
t.Fatalf("Test %d: Expected the APIErrCode to be %s, got %s", i+1, errorCodeResponse[testCase.expectedErrCode].Code, errorCodeResponse[actualErrCode].Code)
}
if actualErrCode == ErrNone {
validateCredentialfields(t, i+1, testCase.expectedCredentials, actualCredential)
Expand Down

0 comments on commit 5db1e9f

Please sign in to comment.