forked from Versent/saml2aws
/
exec.go
143 lines (119 loc) · 3.99 KB
/
exec.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package commands
import (
"fmt"
"log"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/gngj/saml2aws/v2/pkg/awsconfig"
"github.com/gngj/saml2aws/v2/pkg/flags"
"github.com/gngj/saml2aws/v2/pkg/shell"
"github.com/pkg/errors"
)
// Exec executes the supplied command after seeding its environment with AWS
// credentials for the configured profile.
//
// It loads the shared credentials for the account's profile and validates them
// with STS (via checkToken). When the token is not accepted it runs Login and
// then reloads the shared credentials so the command never receives the
// rejected ones. When execFlags.ExecProfile is set, that profile's role chain
// is assumed and its credentials are exported instead.
//
// If the credentials file does not exist yet, a message is logged and nil is
// returned without running the command.
func Exec(execFlags *flags.LoginExecFlags, cmdline []string) error {
	if len(cmdline) < 1 {
		return fmt.Errorf("Command to execute required")
	}

	account, err := buildIdpAccount(execFlags)
	if err != nil {
		return errors.Wrap(err, "error building login details")
	}

	sharedCreds := awsconfig.NewSharedCredentials(account.Profile, account.CredentialsFile)

	// this checks if the credentials file has been created yet
	// can only really be triggered if saml2aws exec is run on a new
	// system prior to creating $HOME/.aws
	exist, err := sharedCreds.CredsExists()
	if err != nil {
		return errors.Wrap(err, "error loading credentials")
	}
	if !exist {
		log.Println("unable to load credentials, login required to create them")
		return nil
	}

	awsCreds, err := sharedCreds.Load()
	if err != nil {
		return errors.Wrap(err, "error loading credentials")
	}

	if time.Until(awsCreds.Expires) < 0 {
		return errors.New("error aws credentials have expired")
	}

	ok, err := checkToken(account.Profile)
	if err != nil {
		return errors.Wrap(err, "error validating token")
	}

	if !ok {
		if err = Login(execFlags); err != nil {
			return errors.Wrap(err, "error logging in")
		}
		// The credentials loaded above were just rejected by STS. Reload them so
		// the command gets whatever Login stored rather than the stale values.
		// NOTE(review): assumes Login writes to the same profile/credentials file
		// that sharedCreds points at — confirm; if it does not, this reload is a
		// harmless re-read.
		awsCreds, err = sharedCreds.Load()
		if err != nil {
			return errors.Wrap(err, "error loading credentials")
		}
	}

	if execFlags.ExecProfile != "" {
		// Assume the desired role before generating env vars
		awsCreds, err = assumeRoleWithProfile(execFlags.ExecProfile, execFlags.CommonFlags.SessionDuration)
		if err != nil {
			return errors.Wrapf(err, "error acquiring credentials for profile: %s", execFlags.ExecProfile)
		}
	}

	return shell.ExecShellCmd(cmdline, shell.BuildEnvVars(awsCreds, account, execFlags))
}
// assumeRoleWithProfile uses an AWS profile (via ~/.aws/config) and performs (multiple levels of) role assumption
// This is extremely useful in the case of a central "authentication account" which then requires secondary, and
// often tertiary, role assumptions to acquire credentials for the target role.
//
// sessionDuration is the requested assumed-role session duration in seconds; it
// is passed to the SDK as session.Options.AssumeRoleDuration. Any failure while
// building the session, calling STS, or reading the resulting credentials is
// returned as an error rather than causing a panic.
func assumeRoleWithProfile(targetProfile string, sessionDuration int) (*awsconfig.AWSCredentials, error) {
	// AWS session config with verbose errors on chained credential errors
	config := *aws.NewConfig().WithCredentialsChainVerboseErrors(true)

	duration, err := time.ParseDuration(strconv.Itoa(sessionDuration) + "s")
	if err != nil {
		return nil, errors.Wrap(err, "error parsing session duration")
	}

	// a session forcing usage of the aws config file, sets the target profile which will be found in the config
	sess, err := session.NewSessionWithOptions(session.Options{
		Config:             config,
		Profile:            targetProfile,
		SharedConfigState:  session.SharedConfigEnable,
		AssumeRoleDuration: duration,
	})
	if err != nil {
		return nil, errors.Wrap(err, "error creating aws session")
	}

	// use an STS client to perform the multiple role assumptions
	stsClient := sts.New(sess)
	if _, err := stsClient.GetCallerIdentity(&sts.GetCallerIdentityInput{}); err != nil {
		return nil, err
	}

	creds, err := sess.Config.Credentials.Get()
	if err != nil {
		return nil, err
	}

	expiredAt, err := sess.Config.Credentials.ExpiresAt()
	if err != nil {
		return nil, err
	}

	return &awsconfig.AWSCredentials{
		AWSAccessKey:    creds.AccessKeyID,
		AWSSecretKey:    creds.SecretAccessKey,
		AWSSessionToken: creds.SessionToken,
		Expires:         expiredAt,
	}, nil
}
// checkToken reports whether the credentials stored for profile are still
// accepted by STS, using a GetCallerIdentity call as the probe.
//
// It returns (true, nil) when the call succeeds. It returns (false, nil) when
// STS answers with an "ExpiredToken" or "NoCredentialProviders" error code,
// meaning a fresh login is needed. Any other failure, including failure to
// build the session, is returned as (false, err).
func checkToken(profile string) (bool, error) {
	sess, sessErr := session.NewSessionWithOptions(session.Options{Profile: profile})
	if sessErr != nil {
		return false, sessErr
	}

	_, callErr := sts.New(sess).GetCallerIdentity(&sts.GetCallerIdentityInput{})
	if callErr == nil {
		return true, nil
	}

	awsErr, isAWSErr := callErr.(awserr.Error)
	if !isAWSErr {
		return false, callErr
	}

	switch awsErr.Code() {
	case "ExpiredToken", "NoCredentialProviders":
		// Expected "please log in again" conditions, not failures.
		return false, nil
	default:
		return false, callErr
	}
}