New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
pkg/oauth2: consider session cancellation #182
Changes from 4 commits
7c219eb
b4f1fa8
36c72d7
b10ffcf
71867d0
fe20139
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ package oauth2 | |
import ( | ||
"context" | ||
"fmt" | ||
"net/http" | ||
"sync" | ||
"time" | ||
|
||
|
@@ -13,6 +14,7 @@ type passwordCredentialsTokenSource struct { | |
ctx context.Context | ||
cfg *oauth2.Config | ||
username, password string | ||
grantsCounter func() | ||
|
||
mu sync.Mutex // protects the fields below | ||
refreshToken *oauth2.Token | ||
|
@@ -32,20 +34,17 @@ type passwordCredentialsTokenSource struct { | |
// using the given resource owner and password. | ||
// | ||
// It is safe for concurrent use. | ||
func NewPasswordCredentialsTokenSource(ctx context.Context, cfg *oauth2.Config, username, password string) *passwordCredentialsTokenSource { | ||
func NewPasswordCredentialsTokenSource(ctx context.Context, cfg *oauth2.Config, grantsCounter func(), username, password string) *passwordCredentialsTokenSource { | ||
return &passwordCredentialsTokenSource{ | ||
ctx: ctx, | ||
username: username, | ||
password: password, | ||
cfg: cfg, | ||
ctx: ctx, | ||
username: username, | ||
password: password, | ||
cfg: cfg, | ||
grantsCounter: grantsCounter, | ||
} | ||
} | ||
|
||
func (c *passwordCredentialsTokenSource) Token() (*oauth2.Token, error) { | ||
return c.token(time.Now) | ||
} | ||
|
||
func (c *passwordCredentialsTokenSource) token(now func() time.Time) (*oauth2.Token, error) { | ||
c.mu.Lock() | ||
defer c.mu.Unlock() | ||
|
||
|
@@ -56,6 +55,12 @@ func (c *passwordCredentialsTokenSource) token(now func() time.Time) (*oauth2.To | |
|
||
if c.refreshToken.Valid() { | ||
tok, err = c.accessTokenSource.Token() | ||
|
||
rerr, ok := err.(*oauth2.RetrieveError) | ||
if ok && rerr.Response != nil && rerr.Response.StatusCode == http.StatusBadRequest { | ||
return c.passwordCredentialsToken() | ||
} | ||
|
||
if err != nil { | ||
return nil, fmt.Errorf("access token source failed: %v", err) | ||
} | ||
|
@@ -65,26 +70,48 @@ func (c *passwordCredentialsTokenSource) token(now func() time.Time) (*oauth2.To | |
if tok.RefreshToken == c.refreshToken.RefreshToken { | ||
return tok, nil | ||
} | ||
} else { | ||
tok, err = c.cfg.PasswordCredentialsToken(c.ctx, c.username, c.password) | ||
|
||
err = c.setRefreshToken(tok) | ||
if err != nil { | ||
return nil, fmt.Errorf("password credentials token source failed: %v", err) | ||
return nil, err | ||
} | ||
|
||
c.accessTokenSource = c.cfg.TokenSource(c.ctx, tok) | ||
return tok, nil | ||
} | ||
|
||
return c.passwordCredentialsToken() | ||
} | ||
|
||
func (c *passwordCredentialsTokenSource) passwordCredentialsToken() (*oauth2.Token, error) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This shouldn’t happen too often, but it would be good to know when it does. Let’s add a counter metric to understand how often this happens. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. agreed 👍 added a metric, PTAL |
||
c.grantsCounter() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it seems to me we may want to differentiate in a successful and errored token retrieval no? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have that indication (status code returned by tollbooth) already in the Would it be idiomatic to add a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. but it's differentiated in terms of refresh token use and credentials flow? if so lgtm There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As discussed oob:
ptal |
||
|
||
tok, err := c.cfg.PasswordCredentialsToken(c.ctx, c.username, c.password) | ||
if err != nil { | ||
return nil, fmt.Errorf("password credentials token source failed: %v", err) | ||
} | ||
|
||
c.accessTokenSource = c.cfg.TokenSource(c.ctx, tok) | ||
|
||
err = c.setRefreshToken(tok) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return tok, nil | ||
} | ||
|
||
func (c *passwordCredentialsTokenSource) setRefreshToken(tok *oauth2.Token) error { | ||
expires, ok := tok.Extra("refresh_expires_in").(float64) | ||
if !ok { | ||
return nil, fmt.Errorf("refresh_expires_in is not a float64, but %T", tok.Extra("refresh_expires_in")) | ||
return fmt.Errorf("refresh_expires_in is not a float64, but %T", tok.Extra("refresh_expires_in")) | ||
} | ||
|
||
// create a dummy access token to reuse calculation logic for the Valid() method | ||
c.refreshToken = &oauth2.Token{ | ||
AccessToken: tok.RefreshToken, | ||
RefreshToken: tok.RefreshToken, | ||
Expiry: now().Add(time.Duration(int64(expires)) * time.Second), | ||
Expiry: time.Now().Add(time.Duration(int64(expires)) * time.Second), | ||
} | ||
|
||
return tok, nil | ||
return nil | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typically the actual metric is passed in here, or even grouped in a metrics struct
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That was just to ease testing, but I can pass a metrics struct here (or the counter interface) too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
adressed