Skip to content

Commit

Permalink
Merge pull request #558 from kmala/release-0.6
Browse files Browse the repository at this point in the history
Add default instance region in sts hostname
  • Loading branch information
k8s-ci-robot committed Feb 15, 2023
2 parents 1ba7c36 + 5ef9dfa commit 54856dd
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 52 deletions.
11 changes: 10 additions & 1 deletion cmd/aws-iam-authenticator/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import (

"sigs.k8s.io/aws-iam-authenticator/pkg/token"

"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
Expand All @@ -50,7 +52,14 @@ var verifyCmd = &cobra.Command{
os.Exit(1)
}

id, err := token.NewVerifier(clusterID, partition).Verify(tok)
sess := session.Must(session.NewSession())
ec2metadata := ec2metadata.New(sess)
instanceRegion, err := ec2metadata.Region()
if err != nil {
fmt.Printf("[Warn] Region not found in instance metadata, err: %v", err)
}

id, err := token.NewVerifier(clusterID, partition, instanceRegion).Verify(tok)
if err != nil {
fmt.Fprintf(os.Stderr, "could not verify token: %v\n", err)
os.Exit(1)
Expand Down
14 changes: 4 additions & 10 deletions pkg/ec2provider/ec2provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
Expand Down Expand Up @@ -61,7 +60,7 @@ type ec2ProviderImpl struct {
instanceIdsChannel chan string
}

func New(roleARN string, qps int, burst int) EC2Provider {
func New(roleARN, region string, qps int, burst int) EC2Provider {
dnsCache := ec2PrivateDNSCache{
cache: make(map[string]string),
lock: sync.RWMutex{},
Expand All @@ -71,7 +70,7 @@ func New(roleARN string, qps int, burst int) EC2Provider {
lock: sync.RWMutex{},
}
return &ec2ProviderImpl{
ec2: ec2.New(newSession(roleARN, qps, burst)),
ec2: ec2.New(newSession(roleARN, region, qps, burst)),
privateDNSCache: dnsCache,
ec2Requests: ec2Requests,
instanceIdsChannel: make(chan string, maxChannelSize),
Expand All @@ -82,20 +81,15 @@ func New(roleARN string, qps int, burst int) EC2Provider {
// the environment, shared credentials (~/.aws/credentials), or EC2 Instance
// Role.

func newSession(roleARN string, qps int, burst int) *session.Session {
func newSession(roleARN, region string, qps int, burst int) *session.Session {
sess := session.Must(session.NewSession())
sess.Handlers.Build.PushFrontNamed(request.NamedHandler{
Name: "authenticatorUserAgent",
Fn: request.MakeAddToUserAgentHandler(
"aws-iam-authenticator", pkg.Version),
})
if aws.StringValue(sess.Config.Region) == "" {
ec2metadata := ec2metadata.New(sess)
regionFound, err := ec2metadata.Region()
if err != nil {
logrus.WithError(err).Fatal("Region not found in shared credentials, environment variable, or instance metadata.")
}
sess.Config.Region = aws.String(regionFound)
sess.Config.Region = aws.String(region)
}

if roleARN != "" {
Expand Down
12 changes: 10 additions & 2 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"strings"
"time"

"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"sigs.k8s.io/aws-iam-authenticator/pkg/config"
"sigs.k8s.io/aws-iam-authenticator/pkg/ec2provider"
"sigs.k8s.io/aws-iam-authenticator/pkg/mapper"
Expand Down Expand Up @@ -167,10 +169,16 @@ func (c *Server) getHandler(mappers []mapper.Mapper, ec2DescribeQps int, ec2Desc
panic(fmt.Sprintf("describeinstancesrole %s is not a valid arn", c.ServerEC2DescribeInstancesRoleARN))
}
}
sess := session.Must(session.NewSession())
ec2metadata := ec2metadata.New(sess)
instanceRegion, err := ec2metadata.Region()
if err != nil {
logrus.WithError(err).Errorln("Region not found in instance metadata.")
}

h := &handler{
verifier: token.NewVerifier(c.ClusterID, c.PartitionID),
ec2Provider: ec2provider.New(c.ServerEC2DescribeInstancesRoleARN, ec2DescribeQps, ec2DescribeBurst),
verifier: token.NewVerifier(c.ClusterID, c.PartitionID, instanceRegion),
ec2Provider: ec2provider.New(c.ServerEC2DescribeInstancesRoleARN, instanceRegion, ec2DescribeQps, ec2DescribeBurst),
clusterID: c.ClusterID,
mappers: mappers,
scrubbedAccounts: c.Config.ScrubbedAWSAccounts,
Expand Down
31 changes: 26 additions & 5 deletions pkg/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ const (
dateHeaderFormat = "20060102T150405Z"
kindExecCredential = "ExecCredential"
execInfoEnvKey = "KUBERNETES_EXEC_INFO"
stsServiceID = "sts"
)

// Token is generated and used by Kubernetes client-go to authenticate with a Kubernetes cluster.
Expand Down Expand Up @@ -377,7 +378,7 @@ type tokenVerifier struct {
validSTShostnames map[string]bool
}

func stsHostsForPartition(partitionID string) map[string]bool {
func stsHostsForPartition(partitionID, region string) map[string]bool {
validSTShostnames := map[string]bool{}

var partition *endpoints.Partition
Expand All @@ -391,12 +392,14 @@ func stsHostsForPartition(partitionID string) map[string]bool {
logrus.Errorf("Partition %s not valid", partitionID)
return validSTShostnames
}
stsSvc, ok := partition.Services()["sts"]

stsSvc, ok := partition.Services()[stsServiceID]
if !ok {
logrus.Errorf("STS service not found in partition %s", partitionID)
return validSTShostnames
}
for epName, ep := range stsSvc.Endpoints() {
stsSvcEndPoints := stsSvc.Endpoints()
for epName, ep := range stsSvcEndPoints {
rep, err := ep.ResolveEndpoint(endpoints.STSRegionalEndpointOption)
if err != nil {
logrus.WithError(err).Errorf("Error resolving endpoint for %s in partition %s", epName, partitionID)
Expand All @@ -409,24 +412,42 @@ func stsHostsForPartition(partitionID string) map[string]bool {
}
validSTShostnames[parsedURL.Hostname()] = true
}

// Add the host of the current instances region if not already exists so we don't fail if the region is not
// present in the go sdk but matches the instances region.
if _, ok := stsSvcEndPoints[region]; !ok {
rep, err := partition.EndpointFor(stsServiceID, region, endpoints.STSRegionalEndpointOption)
if err != nil {
logrus.WithError(err).Errorf("Error resolving endpoint for %s in partition %s", region, partitionID)
return validSTShostnames
}
parsedURL, err := url.Parse(rep.URL)
if err != nil {
logrus.WithError(err).Errorf("Error parsing STS URL %s", rep.URL)
return validSTShostnames
}
validSTShostnames[parsedURL.Hostname()] = true
}

return validSTShostnames
}

// NewVerifier creates a Verifier that is bound to the clusterID and uses the default http client.
func NewVerifier(clusterID string, partitionID string) Verifier {
func NewVerifier(clusterID, partitionID, region string) Verifier {
// Initialize metrics if they haven't already been initialized to avoid a
// nil pointer panic when setting metric values.
if !metrics.Initialized() {
metrics.InitMetrics(prometheus.NewRegistry())
}

return tokenVerifier{
client: &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
},
clusterID: clusterID,
validSTShostnames: stsHostsForPartition(partitionID),
validSTShostnames: stsHostsForPartition(partitionID, region),
}
}

Expand Down
70 changes: 36 additions & 34 deletions pkg/token/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestMain(m *testing.M) {
func validationErrorTest(t *testing.T, partition string, token string, expectedErr string) {
t.Helper()

_, err := NewVerifier("", partition).(tokenVerifier).Verify(token)
_, err := NewVerifier("", partition, "").(tokenVerifier).Verify(token)
errorContains(t, err, expectedErr)
}

Expand Down Expand Up @@ -86,7 +86,7 @@ func newVerifier(partition string, statusCode int, body string, err error) Verif
},
},
},
validSTShostnames: stsHostsForPartition(partition),
validSTShostnames: stsHostsForPartition(partition, ""),
}
}

Expand Down Expand Up @@ -124,40 +124,42 @@ func TestSTSEndpoints(t *testing.T) {
partition string
domain string
valid bool
region string
}{
{"aws-cn", "sts.cn-northwest-1.amazonaws.com.cn", true},
{"aws-cn", "sts.cn-north-1.amazonaws.com.cn", true},
{"aws-cn", "sts.us-iso-east-1.c2s.ic.gov", false},
{"aws", "sts.amazonaws.com", true},
{"aws", "sts-fips.us-west-2.amazonaws.com", true},
{"aws", "sts-fips.us-east-1.amazonaws.com", true},
{"aws", "sts.us-east-1.amazonaws.com", true},
{"aws", "sts.us-east-2.amazonaws.com", true},
{"aws", "sts.us-west-1.amazonaws.com", true},
{"aws", "sts.us-west-2.amazonaws.com", true},
{"aws", "sts.ap-south-1.amazonaws.com", true},
{"aws", "sts.ap-northeast-1.amazonaws.com", true},
{"aws", "sts.ap-northeast-2.amazonaws.com", true},
{"aws", "sts.ap-southeast-1.amazonaws.com", true},
{"aws", "sts.ap-southeast-2.amazonaws.com", true},
{"aws", "sts.ca-central-1.amazonaws.com", true},
{"aws", "sts.eu-central-1.amazonaws.com", true},
{"aws", "sts.eu-west-1.amazonaws.com", true},
{"aws", "sts.eu-west-2.amazonaws.com", true},
{"aws", "sts.eu-west-3.amazonaws.com", true},
{"aws", "sts.eu-north-1.amazonaws.com", true},
{"aws", "sts.amazonaws.com.cn", false},
{"aws", "sts.not-a-region.amazonaws.com", false},
{"aws-iso", "sts.us-iso-east-1.c2s.ic.gov", true},
{"aws-iso", "sts.cn-north-1.amazonaws.com.cn", false},
{"aws-iso-b", "sts.cn-north-1.amazonaws.com.cn", false},
{"aws-us-gov", "sts.us-gov-east-1.amazonaws.com", true},
{"aws-us-gov", "sts.amazonaws.com", false},
{"aws-not-a-partition", "sts.amazonaws.com", false},
{"aws-cn", "sts.cn-northwest-1.amazonaws.com.cn", true, ""},
{"aws-cn", "sts.cn-north-1.amazonaws.com.cn", true, ""},
{"aws-cn", "sts.us-iso-east-1.c2s.ic.gov", false, ""},
{"aws", "sts.amazonaws.com", true, ""},
{"aws", "sts-fips.us-west-2.amazonaws.com", true, ""},
{"aws", "sts-fips.us-east-1.amazonaws.com", true, ""},
{"aws", "sts.us-east-1.amazonaws.com", true, ""},
{"aws", "sts.us-east-2.amazonaws.com", true, ""},
{"aws", "sts.us-west-1.amazonaws.com", true, ""},
{"aws", "sts.us-west-2.amazonaws.com", true, ""},
{"aws", "sts.ap-south-1.amazonaws.com", true, ""},
{"aws", "sts.ap-northeast-1.amazonaws.com", true, ""},
{"aws", "sts.ap-northeast-2.amazonaws.com", true, ""},
{"aws", "sts.ap-southeast-1.amazonaws.com", true, ""},
{"aws", "sts.ap-southeast-2.amazonaws.com", true, ""},
{"aws", "sts.ca-central-1.amazonaws.com", true, ""},
{"aws", "sts.eu-central-1.amazonaws.com", true, ""},
{"aws", "sts.eu-west-1.amazonaws.com", true, ""},
{"aws", "sts.eu-west-2.amazonaws.com", true, ""},
{"aws", "sts.eu-west-3.amazonaws.com", true, ""},
{"aws", "sts.eu-north-1.amazonaws.com", true, ""},
{"aws", "sts.amazonaws.com.cn", false, ""},
{"aws", "sts.not-a-region.amazonaws.com", false, ""},
{"aws", "sts.default-region.amazonaws.com", true, "default-region"},
{"aws-iso", "sts.us-iso-east-1.c2s.ic.gov", true, ""},
{"aws-iso", "sts.cn-north-1.amazonaws.com.cn", false, ""},
{"aws-iso-b", "sts.cn-north-1.amazonaws.com.cn", false, ""},
{"aws-us-gov", "sts.us-gov-east-1.amazonaws.com", true, ""},
{"aws-us-gov", "sts.amazonaws.com", false, ""},
{"aws-not-a-partition", "sts.amazonaws.com", false, ""},
}

for _, c := range cases {
verifier := NewVerifier("", c.partition).(tokenVerifier)
verifier := NewVerifier("", c.partition, c.region).(tokenVerifier)
if err := verifier.verifyHost(c.domain); err != nil && c.valid {
t.Errorf("%s is not valid endpoint for partition %s", c.domain, c.partition)
}
Expand Down Expand Up @@ -215,7 +217,7 @@ func TestVerifyNoRedirectsFollowed(t *testing.T) {
}))
defer ts.Close()

tokVerifier := NewVerifier("", "aws").(tokenVerifier)
tokVerifier := NewVerifier("", "aws", "").(tokenVerifier)

resp, err := tokVerifier.client.Get(ts.URL)
if err != nil {
Expand All @@ -242,7 +244,7 @@ func TestVerifyBodyReadError(t *testing.T) {
},
},
},
validSTShostnames: stsHostsForPartition("aws"),
validSTShostnames: stsHostsForPartition("aws", ""),
}
_, err := verifier.Verify(validToken)
errorContains(t, err, "error reading HTTP result")
Expand Down

0 comments on commit 54856dd

Please sign in to comment.