diff --git a/actions/review.go b/actions/review.go index c18eec2..193b802 100644 --- a/actions/review.go +++ b/actions/review.go @@ -47,15 +47,16 @@ func Review(c *cli.Context) error { ctx := c.Context deps := deps.FromContext(ctx) - if deps.Auth == nil { - return errors.New("error loading GitHub credentials, run plz auth") + token, err := deps.Auth.Token() + if err != nil { + return err } - gitHubRepo, err := newGitHubRepo(ctx, deps.Auth.Token()) + gitHubRepo, err := newGitHubRepo(ctx, token) if err != nil { return err } graphqlClient := graphql.NewClient(deps.PlzAPIBaseURL+"/api/v1", &http.Client{ - Transport: &authTransport{Token: deps.Auth.Token()}, + Transport: &authTransport{Token: token}, }) if err := checkCleanWorktree(ctx, gitHubRepo); err != nil { diff --git a/actions/switch.go b/actions/switch.go index 5abf967..7f18959 100644 --- a/actions/switch.go +++ b/actions/switch.go @@ -51,10 +51,11 @@ func (r review) String() string { func Switch(c *cli.Context) error { ctx := c.Context deps := deps.FromContext(ctx) - if deps.Auth == nil { - return errors.New("error loading GitHub credentials, run plz auth") + token, err := deps.Auth.Token() + if err != nil { + return err } - gitHubRepo, err := newGitHubRepo(ctx, deps.Auth.Token()) + gitHubRepo, err := newGitHubRepo(ctx, token) if err != nil { return err } @@ -99,7 +100,7 @@ func Switch(c *cli.Context) error { } graphqlClient := graphql.NewClient(deps.PlzAPIBaseURL+"/api/v1", &http.Client{ - Transport: &authTransport{Token: deps.Auth.Token()}, + Transport: &authTransport{Token: token}, }) curReview, err := loadReview(ctx, graphqlClient, reviewID) if err != nil { diff --git a/auth/auth.go b/auth/auth.go index cc9c5fe..b378c01 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -14,6 +14,8 @@ import ( "github.com/zalando/go-keyring" ) +var ErrNoAuthCredentials = errors.New("no auth credentials") + type state struct { Token string `json:"token"` ExpiresAt time.Time `json:"expires_at"` @@ -24,34 +26,12 @@ type state struct { } type Auth struct { - state state + plzAPIBaseURL string + *state } -func (a *Auth) Token() string { - return a.state.Token -} - -func LoadFromKeyRing(plzAPIBaseURL string) (*Auth, error) { - authInfoJSON, err := keyring.Get("plz", "authState") - if err != nil { - if errors.Is(err, keyring.ErrNotFound) { - err = nil - } - return nil, err - } - var auth Auth - err = json.Unmarshal([]byte(authInfoJSON), &auth.state) - if err != nil { - return nil, err - } - // Refresh the token if it's expired or nearly expired. - if auth.state.ExpiresAt.Before(time.Now().Add(10 * time.Minute)) { - err := auth.refresh(plzAPIBaseURL) - if err != nil { - return nil, errors.Wrap(err, "failed to refresh auth token") - } - } - return &auth, nil +func New(plzAPIBaseURL string) *Auth { + return &Auth{plzAPIBaseURL: plzAPIBaseURL} } func Prompt(plzAPIBaseURL string) (*Auth, error) { @@ -85,34 +65,55 @@ func Prompt(plzAPIBaseURL string) (*Auth, error) { if err != nil { return nil, errors.WithStack(err) } - auth := &Auth{ - state: state{ - Token: accessToken.Token, - RefreshToken: accessToken.RefreshToken, - Type: accessToken.Type, - Scope: accessToken.Scope, - }, - } // The device library doesn't return the expiry time, so we have to // immediately refresh the token to get the expiry time. - err = auth.refresh(plzAPIBaseURL) + state, err := loadStateFromRefreshToken(plzAPIBaseURL, accessToken.RefreshToken) if err != nil { return nil, err } - return auth, nil + return &Auth{ + plzAPIBaseURL: plzAPIBaseURL, + state: state, + }, nil +} + +func (a *Auth) Token() (string, error) { + if a.state == nil { + state, err := loadStateFromKeyRing(a.plzAPIBaseURL) + if err != nil { + return "", ErrNoAuthCredentials + } + a.state = state + } + // Refresh the token if it's expired or nearly expired. + if a.state.ExpiresAt.Before(time.Now().Add(10 * time.Minute)) { + if a.state.RefreshTokenExpiresAt.Before(time.Now().Add(10 * time.Minute)) { + // When refresh token is expired, we have to re-auth from scratch. + return "", ErrNoAuthCredentials + } + state, err := loadStateFromRefreshToken(a.plzAPIBaseURL, a.state.RefreshToken) + if err != nil { + return "", errors.Wrap(err, "failed to refresh auth token") + } + a.state = state + err = a.SaveToKeyRing() + if err != nil { + return "", errors.Wrap(err, "failed to save new auth token while refreshing") + } + } + return a.state.Token, nil } func (a *Auth) SaveToKeyRing() error { stateJSON, err := json.Marshal(a.state) if err != nil { - return err + return errors.WithStack(err) } - err = keyring.Set("plz", "authState", string(stateJSON)) - return err + return errors.WithStack(keyring.Set("plz", "authState", string(stateJSON))) } -func (a *Auth) refresh(plzAPIBaseURL string) error { - params := url.Values{"refresh_token": {a.state.RefreshToken}} +func loadStateFromRefreshToken(plzAPIBaseURL, refreshToken string) (*state, error) { + params := url.Values{"refresh_token": {refreshToken}} refreshURL := fmt.Sprintf( "%s/auth/github/device/refresh?%s", plzAPIBaseURL, @@ -120,15 +121,15 @@ func (a *Auth) refresh(plzAPIBaseURL string) error { ) req, err := http.NewRequest("POST", refreshURL, nil) if err != nil { - return errors.WithStack(err) + return nil, errors.WithStack(err) } req.Header.Add("Accept", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { - return errors.WithStack(err) + return nil, errors.WithStack(err) } if resp.StatusCode != http.StatusOK { - return errors.Errorf("failed to refresh token: %s", resp.Status) + return nil, errors.Errorf("failed to refresh token: %s", resp.Status) } body := struct { AccessToken string `json:"access_token"` @@ -140,17 +141,16 @@ func (a *Auth) refresh(plzAPIBaseURL string) error { }{} err = json.NewDecoder(resp.Body).Decode(&body) if err != nil { - return errors.WithStack(err) + return nil, errors.WithStack(err) } - a.state = state{ + return &state{ Token: body.AccessToken, ExpiresAt: time.Now().Add(time.Duration(body.ExpiresIn) * time.Second), RefreshToken: body.RefreshToken, RefreshTokenExpiresAt: time.Now().Add(time.Duration(body.RefreshTokenExpiresIn) * time.Second), Type: body.TokenType, Scope: body.Scope, - } - return nil + }, nil } func fetchGitHubAppClientID(client *http.Client, plzAPIBaseURL string) (string, error) { @@ -167,3 +167,16 @@ func fetchGitHubAppClientID(client *http.Client, plzAPIBaseURL string) (string, } return string(clientIDBytes), nil } + +func loadStateFromKeyRing(plzAPIBaseURL string) (*state, error) { + authInfoJSON, err := keyring.Get("plz", "authState") + if err != nil { + return nil, errors.WithStack(err) + } + var state state + err = json.Unmarshal([]byte(authInfoJSON), &state) + if err != nil { + return nil, errors.WithStack(err) + } + return &state, nil +} diff --git a/cmd/plz/main.go b/cmd/plz/main.go index 90b38db..7bf65ed 100644 --- a/cmd/plz/main.go +++ b/cmd/plz/main.go @@ -59,29 +59,28 @@ func main() { debugWriter = os.Stdout } plzAPIBaseURL := c.String("plz-api-base-url") - auth, err := auth.LoadFromKeyRing(plzAPIBaseURL) - if err != nil { - return errors.WithStack(err) - } - baseDeps := &deps.Deps{ + c.Context = deps.ContextWithDeps(c.Context, &deps.Deps{ ErrorLog: log.New(os.Stderr, "", 0), InfoLog: log.New(os.Stdout, "", 0), DebugLog: log.New(debugWriter, "[debug] ", log.Ldate|log.Lmicroseconds), - Auth: auth, PlzAPIBaseURL: plzAPIBaseURL, - } - c.Context = deps.ContextWithDeps(c.Context, baseDeps) + Auth: auth.New(plzAPIBaseURL), + }) return nil }, ExitErrHandler: func(c *cli.Context, err error) { deps := deps.FromContext(c.Context) if err != nil { - deps.ErrorLog.Println(err.Error()) - var stackTracer interface { - StackTrace() errors.StackTrace - } - if errors.As(err, &stackTracer) { - deps.DebugLog.Printf("%+v", stackTracer.StackTrace()) + if errors.Is(err, auth.ErrNoAuthCredentials) { + deps.ErrorLog.Println("no auth credentials, run plz auth") + } else { + deps.ErrorLog.Println(err.Error()) + var stackTracer interface { + StackTrace() errors.StackTrace + } + if errors.As(err, &stackTracer) { + deps.DebugLog.Printf("%+v", stackTracer.StackTrace()) + } } os.Exit(1) }