Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions actions/review.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,16 @@ func Review(c *cli.Context) error {
ctx := c.Context
deps := deps.FromContext(ctx)

if deps.Auth == nil {
return errors.New("error loading GitHub credentials, run plz auth")
token, err := deps.Auth.Token()
if err != nil {
return err
}
gitHubRepo, err := newGitHubRepo(ctx, deps.Auth.Token())
gitHubRepo, err := newGitHubRepo(ctx, token)
if err != nil {
return err
}
graphqlClient := graphql.NewClient(deps.PlzAPIBaseURL+"/api/v1", &http.Client{
Transport: &authTransport{Token: deps.Auth.Token()},
Transport: &authTransport{Token: token},
})

if err := checkCleanWorktree(ctx, gitHubRepo); err != nil {
Expand Down
9 changes: 5 additions & 4 deletions actions/switch.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@ func (r review) String() string {
func Switch(c *cli.Context) error {
ctx := c.Context
deps := deps.FromContext(ctx)
if deps.Auth == nil {
return errors.New("error loading GitHub credentials, run plz auth")
token, err := deps.Auth.Token()
if err != nil {
return err
}
gitHubRepo, err := newGitHubRepo(ctx, deps.Auth.Token())
gitHubRepo, err := newGitHubRepo(ctx, token)
if err != nil {
return err
}
Expand Down Expand Up @@ -99,7 +100,7 @@ func Switch(c *cli.Context) error {
}

graphqlClient := graphql.NewClient(deps.PlzAPIBaseURL+"/api/v1", &http.Client{
Transport: &authTransport{Token: deps.Auth.Token()},
Transport: &authTransport{Token: token},
})
curReview, err := loadReview(ctx, graphqlClient, reviewID)
if err != nil {
Expand Down
109 changes: 61 additions & 48 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"github.com/zalando/go-keyring"
)

var ErrNoAuthCredentials = errors.New("no auth credentials")

type state struct {
Token string `json:"token"`
ExpiresAt time.Time `json:"expires_at"`
Expand All @@ -24,34 +26,12 @@ type state struct {
}

type Auth struct {
state state
plzAPIBaseURL string
*state
}

func (a *Auth) Token() string {
return a.state.Token
}

func LoadFromKeyRing(plzAPIBaseURL string) (*Auth, error) {
authInfoJSON, err := keyring.Get("plz", "authState")
if err != nil {
if errors.Is(err, keyring.ErrNotFound) {
err = nil
}
return nil, err
}
var auth Auth
err = json.Unmarshal([]byte(authInfoJSON), &auth.state)
if err != nil {
return nil, err
}
// Refresh the token if it's expired or nearly expired.
if auth.state.ExpiresAt.Before(time.Now().Add(10 * time.Minute)) {
err := auth.refresh(plzAPIBaseURL)
if err != nil {
return nil, errors.Wrap(err, "failed to refresh auth token")
}
}
return &auth, nil
func New(plzAPIBaseURL string) *Auth {
return &Auth{plzAPIBaseURL: plzAPIBaseURL}
}

func Prompt(plzAPIBaseURL string) (*Auth, error) {
Expand Down Expand Up @@ -85,50 +65,71 @@ func Prompt(plzAPIBaseURL string) (*Auth, error) {
if err != nil {
return nil, errors.WithStack(err)
}
auth := &Auth{
state: state{
Token: accessToken.Token,
RefreshToken: accessToken.RefreshToken,
Type: accessToken.Type,
Scope: accessToken.Scope,
},
}
// The device library doesn't return the expiry time, so we have to
// immediately refresh the token to get the expiry time.
err = auth.refresh(plzAPIBaseURL)
state, err := loadStateFromRefreshToken(plzAPIBaseURL, accessToken.RefreshToken)
if err != nil {
return nil, err
}
return auth, nil
return &Auth{
plzAPIBaseURL: plzAPIBaseURL,
state: state,
}, nil
}

func (a *Auth) Token() (string, error) {
if a.state == nil {
state, err := loadStateFromKeyRing(a.plzAPIBaseURL)
if err != nil {
return "", ErrNoAuthCredentials
}
a.state = state
}
// Refresh the token if it's expired or nearly expired.
if a.state.ExpiresAt.Before(time.Now().Add(10 * time.Minute)) {
if a.state.RefreshTokenExpiresAt.Before(time.Now().Add(10 * time.Minute)) {
// When refresh token is expired, we have to re-auth from scratch.
return "", ErrNoAuthCredentials
}
state, err := loadStateFromRefreshToken(a.plzAPIBaseURL, a.state.RefreshToken)
if err != nil {
return "", errors.Wrap(err, "failed to refresh auth token")
}
a.state = state
err = a.SaveToKeyRing()
if err != nil {
return "", errors.Wrap(err, "failed to save new auth token while refreshing")
}
}
return a.state.Token, nil
}

func (a *Auth) SaveToKeyRing() error {
stateJSON, err := json.Marshal(a.state)
if err != nil {
return err
return errors.WithStack(err)
}
err = keyring.Set("plz", "authState", string(stateJSON))
return err
return errors.WithStack(keyring.Set("plz", "authState", string(stateJSON)))
}

func (a *Auth) refresh(plzAPIBaseURL string) error {
params := url.Values{"refresh_token": {a.state.RefreshToken}}
func loadStateFromRefreshToken(plzAPIBaseURL, refreshToken string) (*state, error) {
params := url.Values{"refresh_token": {refreshToken}}
refreshURL := fmt.Sprintf(
"%s/auth/github/device/refresh?%s",
plzAPIBaseURL,
params.Encode(),
)
req, err := http.NewRequest("POST", refreshURL, nil)
if err != nil {
return errors.WithStack(err)
return nil, errors.WithStack(err)
}
req.Header.Add("Accept", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return errors.WithStack(err)
return nil, errors.WithStack(err)
}
if resp.StatusCode != http.StatusOK {
return errors.Errorf("failed to refresh token: %s", resp.Status)
return nil, errors.Errorf("failed to refresh token: %s", resp.Status)
}
body := struct {
AccessToken string `json:"access_token"`
Expand All @@ -140,17 +141,16 @@ func (a *Auth) refresh(plzAPIBaseURL string) error {
}{}
err = json.NewDecoder(resp.Body).Decode(&body)
if err != nil {
return errors.WithStack(err)
return nil, errors.WithStack(err)
}
a.state = state{
return &state{
Token: body.AccessToken,
ExpiresAt: time.Now().Add(time.Duration(body.ExpiresIn) * time.Second),
RefreshToken: body.RefreshToken,
RefreshTokenExpiresAt: time.Now().Add(time.Duration(body.RefreshTokenExpiresIn) * time.Second),
Type: body.TokenType,
Scope: body.Scope,
}
return nil
}, nil
}

func fetchGitHubAppClientID(client *http.Client, plzAPIBaseURL string) (string, error) {
Expand All @@ -167,3 +167,16 @@ func fetchGitHubAppClientID(client *http.Client, plzAPIBaseURL string) (string,
}
return string(clientIDBytes), nil
}

func loadStateFromKeyRing(plzAPIBaseURL string) (*state, error) {
authInfoJSON, err := keyring.Get("plz", "authState")
if err != nil {
return nil, errors.WithStack(err)
}
var state state
err = json.Unmarshal([]byte(authInfoJSON), &state)
if err != nil {
return nil, errors.WithStack(err)
}
return &state, nil
}
27 changes: 13 additions & 14 deletions cmd/plz/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,29 +59,28 @@ func main() {
debugWriter = os.Stdout
}
plzAPIBaseURL := c.String("plz-api-base-url")
auth, err := auth.LoadFromKeyRing(plzAPIBaseURL)
if err != nil {
return errors.WithStack(err)
}
baseDeps := &deps.Deps{
c.Context = deps.ContextWithDeps(c.Context, &deps.Deps{
ErrorLog: log.New(os.Stderr, "", 0),
InfoLog: log.New(os.Stdout, "", 0),
DebugLog: log.New(debugWriter, "[debug] ", log.Ldate|log.Lmicroseconds),
Auth: auth,
PlzAPIBaseURL: plzAPIBaseURL,
}
c.Context = deps.ContextWithDeps(c.Context, baseDeps)
Auth: auth.New(plzAPIBaseURL),
})
return nil
},
ExitErrHandler: func(c *cli.Context, err error) {
deps := deps.FromContext(c.Context)
if err != nil {
deps.ErrorLog.Println(err.Error())
var stackTracer interface {
StackTrace() errors.StackTrace
}
if errors.As(err, &stackTracer) {
deps.DebugLog.Printf("%+v", stackTracer.StackTrace())
if errors.Is(err, auth.ErrNoAuthCredentials) {
deps.ErrorLog.Println("no auth credentials, run plz auth")
} else {
deps.ErrorLog.Println(err.Error())
var stackTracer interface {
StackTrace() errors.StackTrace
}
if errors.As(err, &stackTracer) {
deps.DebugLog.Printf("%+v", stackTracer.StackTrace())
}
}
os.Exit(1)
}
Expand Down