Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changelog/3718.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
provider: Fixes STS region resolution when using cross-region authentication
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that it'd have been better to create an associated cloudp ticket

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, thank you Leo

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created CLOUDP-347906

```
24 changes: 21 additions & 3 deletions .github/workflows/acceptance-tests-runner.yml
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,23 @@ jobs:
needs: [ change-detection, get-provider-version ]
if: ${{ needs.change-detection.outputs.assume_role == 'true' || inputs.test_group == 'assume_role' }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice test coverage!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 !!

include:
# Secret and STS Endpoint in same region
- name: same-region-us-east-1
aws_region: US_EAST_1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could we lease a comment to explain what this aws_region is used for?

sts_endpoint: https://sts.us-east-1.amazonaws.com/
# Secret and STS Endpoint in different regions(Cross-region)
- name: cross-sts-us-east-1-secret-eu-north-1
aws_region: EU_NORTH_1
sts_endpoint: https://sts.us-east-1.amazonaws.com/
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

although the test group list is "polluted" instead of having only one entry for assume_role, i think it's good use case for matrix.

also we'll need to merge with changes in SA dev branch, we might need to remove the matrix approach if we want to keep all authentication tests together, but that's ok. cc @oarbusi

no need for action about this comment.

Image

# Global STS endpoint (signs as us-east-1), secrets in eu-west-1
- name: global-sts-secret-eu-west-1
aws_region: EU_WEST_1
sts_endpoint: https://sts.amazonaws.com
name: assume_role – ${{ matrix.name }}
permissions: {}
steps:
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8
Expand All @@ -476,19 +493,20 @@ jobs:
AWS_ACCESS_KEY_ID: ${{ secrets.aws_access_key_id }}
ASSUME_ROLE_ARN: ${{ vars.ASSUME_ROLE_ARN }}
run: bash ./scripts/generate-credentials-with-sts-assume-role.sh
- name: Acceptance Tests
- name: Acceptance Tests (matrix)
env:
MONGODB_ATLAS_PUBLIC_KEY: ""
MONGODB_ATLAS_PRIVATE_KEY: ""
ASSUME_ROLE_ARN: ${{ vars.ASSUME_ROLE_ARN }}
AWS_REGION: ${{ vars.AWS_REGION }}
STS_ENDPOINT: ${{ vars.STS_ENDPOINT }}
AWS_REGION: ${{ matrix.aws_region }}
STS_ENDPOINT: ${{ matrix.sts_endpoint }}
SECRET_NAME: ${{ inputs.aws_secret_name }}
AWS_ACCESS_KEY_ID: ${{ steps.sts-assume-role.outputs.aws_access_key_id }}
AWS_SECRET_ACCESS_KEY: ${{ steps.sts-assume-role.outputs.aws_secret_access_key }}
AWS_SESSION_TOKEN: ${{ steps.sts-assume-role.outputs.AWS_SESSION_TOKEN }}
MONGODB_ATLAS_LAST_VERSION: ${{ needs.get-provider-version.outputs.provider_version }}
ACCTEST_PACKAGES: ./internal/provider
ACCTEST_REGEX_RUN: ^TestAccSTSAssumeRole_basic$
run: make testacc

autogen:
Expand Down
80 changes: 56 additions & 24 deletions internal/provider/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"encoding/json"
"fmt"
"log"
"net/url"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
Expand All @@ -12,48 +14,39 @@ import (
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/secretsmanager"
"github.com/aws/aws-sdk-go/service/sts"

"github.com/mongodb/terraform-provider-mongodbatlas/internal/config"
)

const (
endPointSTSDefault = "https://sts.amazonaws.com"
endPointSTSHostnameDefault = "sts.amazonaws.com"
DefaultRegionSTS = "us-east-1"
minSegmentsForSTSRegionalHost = 4
)

func configureCredentialsSTS(cfg *config.Config, secret, region, awsAccessKeyID, awsSecretAccessKey, awsSessionToken, endpoint string) (config.Config, error) {
ep, err := endpoints.GetSTSRegionalEndpoint("regional")
if err != nil {
log.Printf("GetSTSRegionalEndpoint error: %s", err)
return *cfg, err
}

defaultResolver := endpoints.DefaultResolver()
stsCustResolverFn := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
if service == endpoints.StsServiceID {
if endpoint == "" {
return endpoints.ResolvedEndpoint{
URL: endPointSTSDefault,
SigningRegion: region,
}, nil
stsCustResolverFn := func(service, _ string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
if service == sts.EndpointsID {
resolved, err := ResolveSTSEndpoint(endpoint, region)
if err != nil {
return endpoints.ResolvedEndpoint{}, err
}
return endpoints.ResolvedEndpoint{
URL: endpoint,
SigningRegion: region,
}, nil
return resolved, nil
}

return defaultResolver.EndpointFor(service, region, optFns...)
}

sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String(region),
Credentials: credentials.NewStaticCredentials(awsAccessKeyID, awsSecretAccessKey, awsSessionToken),
STSRegionalEndpoint: ep,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious why we don't need this/

EndpointResolver: endpoints.ResolverFunc(stsCustResolverFn),
Region: aws.String(region),
Credentials: credentials.NewStaticCredentials(awsAccessKeyID, awsSecretAccessKey, awsSessionToken),
EndpointResolver: endpoints.ResolverFunc(stsCustResolverFn),
}))

creds := stscreds.NewCredentials(sess, cfg.AssumeRole.RoleARN)

_, err = sess.Config.Credentials.Get()
_, err := sess.Config.Credentials.Get()
if err != nil {
log.Printf("Session get credentials error: %s", err)
return *cfg, err
Expand Down Expand Up @@ -87,6 +80,45 @@ func configureCredentialsSTS(cfg *config.Config, secret, region, awsAccessKeyID,
return *cfg, nil
}

func DeriveSTSRegionFromEndpoint(ep string) string {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any chance we could easily unit test this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sweet, you've done it- ignore me

if ep == "" {
return ""
}
u, err := url.Parse(ep)
if err != nil {
return DefaultRegionSTS
}
host := u.Hostname() // valid values: sts.us-west-2.amazonaws.com or sts.amazonaws.com
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This removes the port if provided, correct? Not sure if needed.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

exactly, "Hostname returns u.Host, stripping any valid port number if present.". It's not strictly necessary but with it we make sure at this point we have either sts.us-west-2.amazonaws.com or sts.amazonaws.com


if host == endPointSTSHostnameDefault {
return DefaultRegionSTS
}

parts := strings.Split(host, ".")
if len(parts) >= minSegmentsForSTSRegionalHost && parts[0] == "sts" {
return parts[1]
}
return DefaultRegionSTS
}

func ResolveSTSEndpoint(stsEndpoint, secretsRegion string) (endpoints.ResolvedEndpoint, error) {
ep := stsEndpoint
if ep == "" {
r := secretsRegion
if r == "" {
r = DefaultRegionSTS
}
ep = fmt.Sprintf("https://sts.%s.amazonaws.com/", r)
}

signingRegion := DeriveSTSRegionFromEndpoint(ep)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the endpoint is "", no need to derive the region since we know it already.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if endpoint is "" the region can be secretsRegion or default region. I think we can keep this as is to avoid having repeated logic(e.g. if r == ""). WDYT?

Copy link
Collaborator

@manupedrozo manupedrozo Sep 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking something along the lines of:

ep := stsEndpoint
var signingRegion string
if ep == "" {
	signingRegion = secretsRegion
	if signingRegion == "" {			
       signingRegion = DefaultRegionSTS
	}
	ep = fmt.Sprintf("https://sts.%s.amazonaws.com/", signingRegion)
} else {
	signingRegion = DeriveSTSRegionFromEndpoint(ep)
}

But ok either way


return endpoints.ResolvedEndpoint{
URL: ep,
SigningRegion: signingRegion,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mind to check if SigningRegion is really needed or URL is enough so we don't need to calculate the sts region?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, can we actually check how AWS does it?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, SigningRegion is needed. URL alone isn’t sufficient.

  • The AWS signer uses the client’s region when SigningRegion is not set. In this case, the client region is the Secrets Manager region, which may not match the STS endpoint’s region.

  • For the global endpoint sts.amazonaws.com, requests must be signed with us-east-1; without setting SigningRegion, signatures will be computed with the client region and can fail.

  • For regional STS endpoints, the signature must match that region as well.

}, nil
}

func secretsManagerGetSecretValue(sess *session.Session, creds *aws.Config, secret string) (string, error) {
svc := secretsmanager.New(sess, creds)
input := &secretsmanager.GetSecretValueInput{
Expand Down
94 changes: 94 additions & 0 deletions internal/provider/credentials_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package provider_test

import (
"testing"

"github.com/mongodb/terraform-provider-mongodbatlas/internal/provider"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_deriveSTSRegionFromEndpoint(t *testing.T) {
testCases := map[string]struct {
input string
expected string
}{
"empty endpoint": {
input: "",
expected: "",
},
"global endpoint": {
input: "https://sts.amazonaws.com",
expected: provider.DefaultRegionSTS,
},
"regional": {
input: "https://sts.us-east-1.amazonaws.com/",
expected: "us-east-1",
},
"regional eu-north-1": {
input: "https://sts.eu-north-1.amazonaws.com/",
expected: "eu-north-1",
},
"malformed url": {
input: "://not-a-url",
expected: provider.DefaultRegionSTS,
},
"unexpected host shape": {
input: "https://sts.something-weird",
expected: provider.DefaultRegionSTS,
},
}

for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
t.Parallel()
got := provider.DeriveSTSRegionFromEndpoint(tc.input)
if got != tc.expected {
t.Fatalf("deriveSTSRegionFromEndpoint(%q) = %q; want %q", tc.input, got, tc.expected)
}
})
}
}

func Test_resolveSTSEndpoint(t *testing.T) {
testCases := map[string]struct {
stsEndpoint string
secretsRegion string
expectedURL string
expectedSign string
}{
"explicit regional endpoint": {
stsEndpoint: "https://sts.eu-north-1.amazonaws.com/",
secretsRegion: "us-east-1",
expectedURL: "https://sts.eu-north-1.amazonaws.com/",
expectedSign: "eu-north-1",
},
"global endpoint - us-east-1 signing": {
stsEndpoint: "https://sts.amazonaws.com",
secretsRegion: "eu-west-1",
expectedURL: "https://sts.amazonaws.com",
expectedSign: provider.DefaultRegionSTS,
},
"no endpoint - uses secrets region": {
stsEndpoint: "",
secretsRegion: "us-west-2",
expectedURL: "https://sts.us-west-2.amazonaws.com/",
expectedSign: "us-west-2",
},
"no endpoint and empty region": {
stsEndpoint: "",
secretsRegion: "",
expectedURL: "https://sts.us-east-1.amazonaws.com/",
expectedSign: provider.DefaultRegionSTS,
},
}

for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
ep, err := provider.ResolveSTSEndpoint(tc.stsEndpoint, tc.secretsRegion)
require.NoError(t, err)
assert.Equal(t, tc.expectedURL, ep.URL)
assert.Equal(t, tc.expectedSign, ep.SigningRegion)
})
}
}