Skip to content

Commit

Permalink
Merge pull request #182 from s-urbaniak/fix-mon-690
Browse files Browse the repository at this point in the history
pkg/oauth2: consider session cancellation
  • Loading branch information
openshift-merge-robot committed Jun 13, 2019
2 parents ee32c53 + fe20139 commit 83d33cd
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 43 deletions.
14 changes: 13 additions & 1 deletion cmd/telemeter-server/main.go
Expand Up @@ -23,6 +23,8 @@ import (
"strings"
"time"

"github.com/prometheus/client_golang/prometheus"

oidc "github.com/coreos/go-oidc"
"github.com/oklog/run"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -250,8 +252,18 @@ func (o *Options) Run() error {
Endpoint: provider.Endpoint(),
}

grantsTotal := prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "telemeter_password_credentials_grants_total",
Help: "Tracks the number of resource owner password credential grants.",
},
[]string{"cause", "status"},
)

prometheus.MustRegister(grantsTotal)

src := telemeter_oauth2.NewPasswordCredentialsTokenSource(
ctx, &cfg,
ctx, &cfg, grantsTotal,
o.AuthorizeUsername, o.AuthorizePassword,
)

Expand Down
66 changes: 50 additions & 16 deletions pkg/oauth2/token_source.go
Expand Up @@ -3,16 +3,19 @@ package oauth2
import (
"context"
"fmt"
"net/http"
"sync"
"time"

"github.com/prometheus/client_golang/prometheus"
"golang.org/x/oauth2"
)

type passwordCredentialsTokenSource struct {
ctx context.Context
cfg *oauth2.Config
username, password string
grantsCounter *prometheus.CounterVec

mu sync.Mutex // protects the fields below
refreshToken *oauth2.Token
Expand All @@ -32,20 +35,17 @@ type passwordCredentialsTokenSource struct {
// using the given resource owner and password.
//
// It is safe for concurrent use.
func NewPasswordCredentialsTokenSource(ctx context.Context, cfg *oauth2.Config, username, password string) *passwordCredentialsTokenSource {
func NewPasswordCredentialsTokenSource(ctx context.Context, cfg *oauth2.Config, grantsCounter *prometheus.CounterVec, username, password string) *passwordCredentialsTokenSource {
return &passwordCredentialsTokenSource{
ctx: ctx,
username: username,
password: password,
cfg: cfg,
ctx: ctx,
username: username,
password: password,
cfg: cfg,
grantsCounter: grantsCounter,
}
}

func (c *passwordCredentialsTokenSource) Token() (*oauth2.Token, error) {
return c.token(time.Now)
}

func (c *passwordCredentialsTokenSource) token(now func() time.Time) (*oauth2.Token, error) {
c.mu.Lock()
defer c.mu.Unlock()

Expand All @@ -56,6 +56,12 @@ func (c *passwordCredentialsTokenSource) token(now func() time.Time) (*oauth2.To

if c.refreshToken.Valid() {
tok, err = c.accessTokenSource.Token()

rerr, ok := err.(*oauth2.RetrieveError)
if ok && rerr.Response != nil && rerr.Response.StatusCode == http.StatusBadRequest {
return c.passwordCredentialsToken("session_removed")
}

if err != nil {
return nil, fmt.Errorf("access token source failed: %v", err)
}
Expand All @@ -65,26 +71,54 @@ func (c *passwordCredentialsTokenSource) token(now func() time.Time) (*oauth2.To
if tok.RefreshToken == c.refreshToken.RefreshToken {
return tok, nil
}
} else {
tok, err = c.cfg.PasswordCredentialsToken(c.ctx, c.username, c.password)

err = c.setRefreshToken(tok)
if err != nil {
return nil, fmt.Errorf("password credentials token source failed: %v", err)
return nil, err
}

c.accessTokenSource = c.cfg.TokenSource(c.ctx, tok)
return tok, nil
}

return c.passwordCredentialsToken("token_expired")
}

func (c *passwordCredentialsTokenSource) passwordCredentialsToken(cause string) (*oauth2.Token, error) {
tok, err := c.cfg.PasswordCredentialsToken(c.ctx, c.username, c.password)

status := "success"
if err != nil {
status = "failed"
}

c.grantsCounter.WithLabelValues(cause, status).Inc()

if err != nil {
return nil, fmt.Errorf("password credentials token source failed: %v", err)
}

c.accessTokenSource = c.cfg.TokenSource(c.ctx, tok)

err = c.setRefreshToken(tok)
if err != nil {
return nil, err
}

return tok, nil
}

func (c *passwordCredentialsTokenSource) setRefreshToken(tok *oauth2.Token) error {
expires, ok := tok.Extra("refresh_expires_in").(float64)
if !ok {
return nil, fmt.Errorf("refresh_expires_in is not a float64, but %T", tok.Extra("refresh_expires_in"))
return fmt.Errorf("refresh_expires_in is not a float64, but %T", tok.Extra("refresh_expires_in"))
}

// create a dummy access token to reuse calculation logic for the Valid() method
c.refreshToken = &oauth2.Token{
AccessToken: tok.RefreshToken,
RefreshToken: tok.RefreshToken,
Expiry: now().Add(time.Duration(int64(expires)) * time.Second),
Expiry: time.Now().Add(time.Duration(int64(expires)) * time.Second),
}

return tok, nil
return nil
}

0 comments on commit 83d33cd

Please sign in to comment.