Skip to content

Commit

Permalink
Update credshelper package (#572)
Browse files Browse the repository at this point in the history
- Made externalTokenSource implement PerRPC creds interface as well
- Added headers and tests related to this
  • Loading branch information
banikharbanda committed May 30, 2024
1 parent 89b6d6b commit 47bbbf3
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 51 deletions.
111 changes: 69 additions & 42 deletions go/pkg/credshelper/credshelper.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package credshelper

import (
"bytes"
"context"
"encoding/json"
"fmt"
"os"
Expand Down Expand Up @@ -83,9 +84,13 @@ type Credentials struct {

// externaltokenSource uses a credentialsHelper to obtain gcp oauth tokens.
// This should be wrapped in a "golang.org/x/oauth2".ReuseTokenSource
// to avoid obtaining new tokens each time.
// to avoid obtaining new tokens each time. It implements both the
// oauth2.TokenSource and credentials.PerRPCCredentials interfaces.
type externalTokenSource struct {
credsHelperCmd *reusableCmd
headers map[string]string
expiry time.Time
headersLock sync.RWMutex
}

func buildExternalCredentials(baseCreds cachedCredentials, credsFile string, credsHelperCmd *reusableCmd) *Credentials {
Expand Down Expand Up @@ -166,11 +171,32 @@ func (ts *externalTokenSource) Token() (*oauth2.Token, error) {
if ts == nil {
return nil, fmt.Errorf("empty tokensource")
}
tk, _, err := runCredsHelperCmd(ts.credsHelperCmd)
if err == nil {
log.Infof("'%s' credentials refreshed at %v, expires at %v", ts.credsHelperCmd, time.Now(), tk.Expiry)
credsOut, err := runCredsHelperCmd(ts.credsHelperCmd)
if err != nil {
return nil, err
}
log.Infof("'%s' credentials refreshed at %v, expires at %v", ts.credsHelperCmd, time.Now(), credsOut.tk.Expiry)
return credsOut.tk, err
}

// GetRequestMetadata gets the current request metadata, refreshing tokens if required.
func (ts *externalTokenSource) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
ts.headersLock.RLock()
defer ts.headersLock.RUnlock()
if ts.expiry.Before(nowFn()) {
credsOut, err := runCredsHelperCmd(ts.credsHelperCmd)
if err != nil {
return nil, err
}
ts.expiry = credsOut.tk.Expiry
ts.headers = credsOut.hdrs
}
return tk, err
return ts.headers, nil
}

// RequireTransportSecurity indicates whether the credentials require transport security.
func (ts *externalTokenSource) RequireTransportSecurity() bool {
return true
}

// NewExternalCredentials creates credentials obtained from a credshelper.
Expand All @@ -190,14 +216,20 @@ func NewExternalCredentials(credshelper string, credshelperArgs []string, credsF
}
log.Warningf("Failed to use cached credentials: %v", err)
}
tk, rexp, err := runCredsHelperCmd(credsHelperCmd)
credsOut, err := runCredsHelperCmd(credsHelperCmd)
if err != nil {
return nil, err
}
return buildExternalCredentials(cachedCredentials{token: tk, refreshExp: rexp}, credsFile, credsHelperCmd), nil
return buildExternalCredentials(cachedCredentials{token: credsOut.tk, refreshExp: credsOut.rexp}, credsFile, credsHelperCmd), nil
}

type credshelperOutput struct {
hdrs map[string]string
tk *oauth2.Token
rexp time.Time
}

func runCredsHelperCmd(credsHelperCmd *reusableCmd) (*oauth2.Token, time.Time, error) {
func runCredsHelperCmd(credsHelperCmd *reusableCmd) (*credshelperOutput, error) {
log.V(2).Infof("Running %v", credsHelperCmd)
var stdout, stderr bytes.Buffer
cmd := credsHelperCmd.Cmd()
Expand All @@ -209,53 +241,48 @@ func runCredsHelperCmd(credsHelperCmd *reusableCmd) (*oauth2.Token, time.Time, e
log.Errorf("Credentials helper warnings and errors: %v", stderr.String())
}
if err != nil {
return nil, time.Time{}, err
return nil, err
}
token, expiry, refreshExpiry, err := parseTokenExpiryFromOutput(out)
return &oauth2.Token{
AccessToken: token,
Expiry: expiry,
}, refreshExpiry, err
return parseTokenExpiryFromOutput(out)
}

// CredsHelperOut is the struct to record the json output from the credshelper.
type CredsHelperOut struct {
Token string `json:"token"`
Expiry string `json:"expiry"`
RefreshExpiry string `json:"refresh_expiry"`
// JSONOut is the struct to record the json output from the credshelper.
type JSONOut struct {
Token string `json:"token"`
Headers map[string]string `json:"headers"`
Expiry string `json:"expiry"`
RefreshExpiry string `json:"refresh_expiry"`
}

func parseTokenExpiryFromOutput(out string) (string, time.Time, time.Time, error) {
var (
tk string
exp, rexp time.Time
chOut CredsHelperOut
)
if err := json.Unmarshal([]byte(out), &chOut); err != nil {
return tk, exp, rexp,
fmt.Errorf("error while decoding credshelper output:%v", err)
func parseTokenExpiryFromOutput(out string) (*credshelperOutput, error) {
credsOut := &credshelperOutput{}
var jsonOut JSONOut
if err := json.Unmarshal([]byte(out), &jsonOut); err != nil {
return nil, fmt.Errorf("error while decoding credshelper output:%v", err)
}
if jsonOut.Token == "" {
return nil, fmt.Errorf("no token was printed by the credentials helper")
}
tk = chOut.Token
if tk == "" {
return tk, exp, rexp,
fmt.Errorf("no token was printed by the credentials helper")
credsOut.tk = &oauth2.Token{AccessToken: jsonOut.Token}
if len(jsonOut.Headers) == 0 {
return nil, fmt.Errorf("no headers were printed by the credentials helper")
}
if chOut.Expiry != "" {
expiry, err := time.Parse(time.UnixDate, chOut.Expiry)
credsOut.hdrs = jsonOut.Headers
if jsonOut.Expiry != "" {
expiry, err := time.Parse(time.UnixDate, jsonOut.Expiry)
if err != nil {
return tk, exp, rexp, fmt.Errorf("invalid expiry format: %v (Expected time.UnixDate format)", chOut.Expiry)
return nil, fmt.Errorf("invalid expiry format: %v (Expected time.UnixDate format)", jsonOut.Expiry)
}
exp = expiry
rexp = expiry
credsOut.tk.Expiry = expiry
}
if chOut.RefreshExpiry != "" {
rexpiry, err := time.Parse(time.UnixDate, chOut.RefreshExpiry)
if jsonOut.RefreshExpiry != "" {
rexpiry, err := time.Parse(time.UnixDate, jsonOut.RefreshExpiry)
if err != nil {
return tk, exp, rexp, fmt.Errorf("invalid refresh expiry format: %v (Expected time.UnixDate format)", chOut.RefreshExpiry)
return nil, fmt.Errorf("invalid refresh expiry format: %v (Expected time.UnixDate format)", jsonOut.RefreshExpiry)
}
rexp = rexpiry
credsOut.rexp = rexpiry
}
return tk, exp, rexp, nil
return credsOut, nil
}

// binaryRelToAbs converts a path that is relative to the current executable
Expand Down
111 changes: 102 additions & 9 deletions go/pkg/credshelper/credshelper_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package credshelper

import (
"context"
"fmt"
"os"
"path/filepath"
"reflect"
"runtime"
"sync"
"testing"
"time"

Expand All @@ -23,7 +26,7 @@ func TestCredentialsHelperCache(t *testing.T) {
if err != nil {
t.Errorf("failed to create dir for credentials file %q: %v", cf, err)
}
credsHelperCmd := newReusableCmd("echo", []string{`{"token":"testToken", "expiry":"", "refresh_expiry":""}`})
credsHelperCmd := newReusableCmd("echo", []string{`{"headers":{"hdr":"val"},"token":"testToken", "expiry":""}`})
ts := &grpcOauth.TokenSource{
TokenSource: oauth2.ReuseTokenSourceWithExpiry(
&oauth2.Token{},
Expand Down Expand Up @@ -57,7 +60,7 @@ func TestExternalToken(t *testing.T) {
if err != nil {
t.Fatalf("Unable to create temporary file: %v", err)
}
chJSON := fmt.Sprintf(`{"token":"%v","expiry":"%s","refresh_expiry":""}`, tk, exp)
chJSON := fmt.Sprintf(`{"headers":{"hdr":"val"},"token":"%v","expiry":"%s","refresh_expiry":""}`, tk, exp)
if _, err := tf.Write([]byte(chJSON)); err != nil {
t.Fatalf("Unable to write to file %v: %v", tf.Name(), err)
}
Expand All @@ -69,7 +72,7 @@ func TestExternalToken(t *testing.T) {
}
} else {
credshelper = "echo"
credshelperArgs = []string{fmt.Sprintf(`{"token":"%v","expiry":"%s","refresh_expiry":""}`, tk, exp)}
credshelperArgs = []string{fmt.Sprintf(`{"headers":{"hdr":"val"},"token":"%v","expiry":"%s","refresh_expiry":""}`, tk, exp)}
}

credsHelperCmd := newReusableCmd(credshelper, credshelperArgs)
Expand Down Expand Up @@ -136,7 +139,7 @@ func writeTokenFile(t *testing.T, path, token string, expiry time.Time) {
t.Fatalf("Unable to open file %v: %v", path, err)
}
defer f.Close()
chJSON := fmt.Sprintf(`{"token":"%v","expiry":"%s","refresh_expiry":""}`, token, expiry.Format(time.UnixDate))
chJSON := fmt.Sprintf(`{"headers":{"hdr":"val"},"token":"%v","expiry":"%s","refresh_expiry":""}`, token, expiry.Format(time.UnixDate))
if _, err := f.Write([]byte(chJSON)); err != nil {
t.Fatalf("Unable to write to file %v: %v", f.Name(), err)
}
Expand All @@ -153,24 +156,28 @@ func TestNewExternalCredentials(t *testing.T) {
checkExp bool
credshelperOut string
}{{
name: "No Headers",
wantErr: true,
credshelperOut: `{"headers":"","token":"","expiry":"","refresh_expiry":""}`,
}, {
name: "No Token",
wantErr: true,
credshelperOut: `{"token":"","expiry":"","refresh_expiry":""}`,
credshelperOut: `{"headers":{"hdr":"val"},"token":"","expiry":"","refresh_expiry":""}`,
}, {
name: "Credshelper Command Passed - No Expiry",
credshelperOut: fmt.Sprintf(`{"token":"%v","expiry":"","refresh_expiry":""}`, testToken),
credshelperOut: fmt.Sprintf(`{"headers":{"hdr":"val"},"token":"%v","expiry":"","refresh_expiry":""}`, testToken),
}, {
name: "Credshelper Command Passed - Expiry",
checkExp: true,
credshelperOut: fmt.Sprintf(`{"token":"%v","expiry":"%v","refresh_expiry":""}`, testToken, unixExp),
credshelperOut: fmt.Sprintf(`{"headers":{"hdr":"val"},"token":"%v","expiry":"%v","refresh_expiry":""}`, testToken, unixExp),
}, {
name: "Credshelper Command Passed - Refresh Expiry",
checkExp: true,
credshelperOut: fmt.Sprintf(`{"token":"%v","expiry":"%v","refresh_expiry":"%v"}`, testToken, unixExp, unixExp),
credshelperOut: fmt.Sprintf(`{"headers":{"hdr":"val"},"token":"%v","expiry":"%v","refresh_expiry":"%v"}`, testToken, unixExp, unixExp),
}, {
name: "Wrong Expiry Format",
wantErr: true,
credshelperOut: fmt.Sprintf(`{"token":"%v","expiry":"%v","refresh_expiry":"%v"}`, testToken, expStr, expStr),
credshelperOut: fmt.Sprintf(`{"headers":{"hdr":"val"},"token":"%v","expiry":"%v", "refresh_expiry":"%v"}`, testToken, expStr, expStr),
}}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
Expand Down Expand Up @@ -225,6 +232,92 @@ func TestNewExternalCredentials(t *testing.T) {
}
}

func TestGetRequestMetadata(t *testing.T) {
testToken := "token"
expiredHdrs := map[string]string{"expired": "true"}
testHdrs := map[string]string{"expired": "false"}
expiredExp := time.Now().Add(-time.Hour).Truncate(time.Second)
exp := time.Now().Add(time.Hour).Truncate(time.Second)
unixExp := exp.Format(time.UnixDate)
tests := []struct {
name string
tsExp time.Time
tsHeaders map[string]string
wantErr bool
wantExpired bool
credshelperOut string
}{{
name: "Creds Not Expired",
tsExp: exp,
tsHeaders: testHdrs,
}, {
name: "Creds Expired: Credshelper Successful",
tsExp: expiredExp,
tsHeaders: expiredHdrs,
wantExpired: true,
credshelperOut: fmt.Sprintf(`{"headers":{"expired":"false"},"token":"%v","expiry":"%v"}`, testToken, unixExp),
}, {
name: "Creds Expired: Credshelper Failed",
wantErr: true,
wantExpired: true,
credshelperOut: fmt.Sprintf(`{"headers":"","token":"%v","expiry":""`, testToken),
}}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var (
credshelper string
credshelperArgs []string
)
if runtime.GOOS == "windows" {
tf, err := os.CreateTemp("", "testnewexternalcreds.json")
if err != nil {
t.Fatalf("Unable to create temporary file: %v", err)
}
if _, err := tf.Write([]byte(test.credshelperOut)); err != nil {
t.Fatalf("Unable to write to file %v: %v", tf.Name(), err)
}
credshelper = "cmd"
credshelperArgs = []string{
"/c",
"cat",
tf.Name(),
}
} else {
credshelper = "echo"
credshelperArgs = []string{test.credshelperOut}
}
credsHelperCmd := newReusableCmd(credshelper, credshelperArgs)
exTs := externalTokenSource{
credsHelperCmd: credsHelperCmd,
expiry: test.tsExp,
headers: test.tsHeaders,
headersLock: sync.RWMutex{},
}
hdrs, err := exTs.GetRequestMetadata(context.Background(), "uri")
if test.wantErr && err == nil {
t.Fatalf("GetRequestMetadata did not return an error.")
}
if !test.wantErr {
if err != nil {
t.Fatalf("GetRequestMetadata returned an error: %v", err)
}
if !reflect.DeepEqual(hdrs, exTs.headers) {
t.Errorf("GetRequestMetadata did not update headers in the tokensource: returned hdrs: %v, tokensource headers: %v", hdrs, exTs.headers)
}
if !exp.Equal(exTs.expiry) {
t.Errorf("GetRequestMetadata did not update expiry in the tokensource")
}
if !test.wantExpired && !reflect.DeepEqual(hdrs, testHdrs) {
t.Errorf("GetRequestMetadata returned headers: %v, but want headers: %v", hdrs, testHdrs)
}
if test.wantExpired && reflect.DeepEqual(hdrs, expiredHdrs) {
t.Errorf("GetRequestMetadata returned expired headers")
}
}
})
}
}

func TestReusableCmd(t *testing.T) {
binary := "echo"
args := []string{"hello"}
Expand Down

0 comments on commit 47bbbf3

Please sign in to comment.