Skip to content

Commit

Permalink
Add web identify support
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Kropachev committed May 12, 2023
1 parent 1a1ac8e commit 92e55db
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 3 deletions.
13 changes: 13 additions & 0 deletions cmd/aws-iam-authenticator/token.go
Expand Up @@ -38,6 +38,7 @@ var tokenCmd = &cobra.Command{
tokenOnly := viper.GetBool("tokenOnly")
forwardSessionName := viper.GetBool("forwardSessionName")
sessionName := viper.GetString("sessionName")
wiToken := viper.GetString("token")
cache := viper.GetBool("cache")

if clusterID == "" {
Expand All @@ -52,6 +53,15 @@ var tokenCmd = &cobra.Command{
os.Exit(1)
}

if wiToken != "" {
_, err := os.Stat(wiToken)
if os.IsNotExist(err) {
fmt.Fprintf(os.Stderr, "Error: token path is provided, but there is no such file\n")
cmd.Usage()
os.Exit(1)
}
}

var tok token.Token
var out string
var err error
Expand All @@ -67,6 +77,7 @@ var tokenCmd = &cobra.Command{
AssumeRoleExternalID: externalID,
SessionName: sessionName,
Region: region,
Token: wiToken,
})
if err != nil {
fmt.Fprintf(os.Stderr, "could not get token: %v\n", err)
Expand All @@ -88,6 +99,7 @@ func init() {
tokenCmd.Flags().StringP("external-id", "e", "", "External ID to pass when assuming the IAM Role")
tokenCmd.Flags().StringP("session-name", "s", "", "Session name to pass when assuming the IAM Role")
tokenCmd.Flags().Bool("token-only", false, "Return only the token for use with Bearer token based tools")
tokenCmd.Flags().StringP("token", "t", "", "Path to a web identity token file or it's raw value")
tokenCmd.Flags().Bool("forward-session-name",
false,
"Enable mapping a federated sessions caller-specified-role-name attribute onto newly assumed sessions. NOTE: Only applicable when a new role is requested via --role")
Expand All @@ -98,6 +110,7 @@ func init() {
viper.BindPFlag("tokenOnly", tokenCmd.Flags().Lookup("token-only"))
viper.BindPFlag("forwardSessionName", tokenCmd.Flags().Lookup("forward-session-name"))
viper.BindPFlag("sessionName", tokenCmd.Flags().Lookup("session-name"))
viper.BindPFlag("token", tokenCmd.Flags().Lookup("token"))
viper.BindPFlag("cache", tokenCmd.Flags().Lookup("cache"))
viper.BindEnv("role", "DEFAULT_ROLE")
}
44 changes: 41 additions & 3 deletions pkg/token/token.go
Expand Up @@ -20,6 +20,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"github.com/pkg/errors"
"io/ioutil"
"net/http"
"net/url"
Expand Down Expand Up @@ -109,6 +110,7 @@ type GetTokenOptions struct {
ClusterID string
AssumeRoleARN string
AssumeRoleExternalID string
Token string
SessionName string
Session *session.Session
}
Expand Down Expand Up @@ -276,9 +278,31 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) {
// use an STS client based on the direct credentials
stsAPI := sts.New(options.Session)

// if a roleARN was specified, replace the STS client with one that uses
// temporary credentials from that role.
if options.AssumeRoleARN != "" {
if options.AssumeRoleARN != "" && options.Token != "" {
// if a roleARN and Token were specified,
// replace the STS client with one that uses web identity temporary credentials from that role.

webIdentityProvider := stscreds.NewWebIdentityRoleProviderWithOptions(
stsAPI,
options.AssumeRoleARN,
options.SessionName,
getTokenFetcher(options.Token),
)

// Check if the webIdentityProvider can successfully retrieve
// credentials (via sts:AssumeRole), and warn if there's a problem.
if _, err := webIdentityProvider.Retrieve(); err != nil {
return Token{}, errors.Wrap(err, "failed to get web identity provider")
}

stsAPI = sts.New(
options.Session,
&aws.Config{
Credentials: credentials.NewCredentials(webIdentityProvider),
})
} else if options.AssumeRoleARN != "" {
// if a roleARN was specified, replace the STS client with one that uses
// temporary credentials from that role.
var sessionSetters []func(*stscreds.AssumeRoleProvider)

if options.AssumeRoleExternalID != "" {
Expand Down Expand Up @@ -627,3 +651,17 @@ func hasSignedClusterIDHeader(paramsLower *url.Values) bool {
}
return false
}

func getTokenFetcher(token string) stscreds.TokenFetcher {
if strings.HasPrefix(token, "/") || strings.HasPrefix(token, "./") {
return stscreds.FetchTokenPath(token)
}
return FetchTokenRaw(token)
}

type FetchTokenRaw string

// FetchToken returns a token by reading from the filesystem
func (f FetchTokenRaw) FetchToken(_ credentials.Context) ([]byte, error) {
return []byte(f), nil
}

0 comments on commit 92e55db

Please sign in to comment.