diff --git a/go/rtl/auth.go b/go/rtl/auth.go index d993bec7f..ce7810fda 100644 --- a/go/rtl/auth.go +++ b/go/rtl/auth.go @@ -10,6 +10,7 @@ import ( "net/url" "os" "reflect" + "sync" "time" "golang.org/x/oauth2" @@ -39,6 +40,47 @@ func (t *transportWithHeaders) RoundTrip(req *http.Request) (*http.Response, err type AuthSession struct { Config ApiSettings Client http.Client + token *oauth2.Token + source oauth2.TokenSource + mu sync.RWMutex +} + +const tokenLeeway = 10 * time.Second + +func (s *AuthSession) IsActive() bool { + s.mu.RLock() + defer s.mu.RUnlock() + if s.token == nil { + return false + } + return s.token.Expiry.After(time.Now().Add(tokenLeeway)) +} + +func (s *AuthSession) Login() (*oauth2.Token, error) { + // First, check with a read lock to avoid contention if the token is valid. + s.mu.RLock() + if s.token != nil && s.token.Expiry.After(time.Now().Add(tokenLeeway)) { + token := s.token + s.mu.RUnlock() + return token, nil + } + s.mu.RUnlock() + + // If the token is invalid or nil, acquire a write lock to refresh it. + s.mu.Lock() + defer s.mu.Unlock() + + // Re-check after obtaining the write lock, in case another goroutine refreshed it. + if s.token != nil && s.token.Expiry.After(time.Now().Add(tokenLeeway)) { + return s.token, nil + } + + token, err := s.source.Token() + if err != nil { + return nil, err + } + s.token = token + return token, nil } func NewAuthSession(config ApiSettings) *AuthSession { @@ -74,9 +116,11 @@ func NewAuthSessionWithTransport(config ApiSettings, transport http.RoundTripper &http.Client{Transport: appIdHeaderTransport}, ) + source := oauthConfig.TokenSource(ctx) + // Make use of oauth2 transport to handle token management oauthTransport := &oauth2.Transport{ - Source: oauthConfig.TokenSource(ctx), + Source: source, // Will set "x-looker-appid" Header on all other requests Base: appIdHeaderTransport, } @@ -84,11 +128,15 @@ func NewAuthSessionWithTransport(config ApiSettings, transport http.RoundTripper return &AuthSession{ Config: config, Client: http.Client{Transport: oauthTransport}, + source: source, } } func (s *AuthSession) Do(result interface{}, method, ver, path string, reqPars map[string]interface{}, body interface{}, options *ApiSettings) error { - + _, err := s.Login() + if err != nil { + return err + } // prepare URL u := fmt.Sprintf("%s/api%s%s", s.Config.BaseUrl, ver, path)