Skip to content

Commit

Permalink
Improve context lifetime handling for OIDC provider
Browse files Browse the repository at this point in the history
  • Loading branch information
briankassouf committed Sep 18, 2018
1 parent fb9c940 commit f2c84b4
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
15 changes: 14 additions & 1 deletion backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@ type jwtAuthBackend struct {
l sync.RWMutex
provider *oidc.Provider
cachedConfig *jwtConfig

providerCtx context.Context
providerCtxCancel context.CancelFunc
}

func backend(c *logical.BackendConfig) *jwtAuthBackend {
b := new(jwtAuthBackend)
b.providerCtx, b.providerCtxCancel = context.WithCancel(context.Background())

b.Backend = &framework.Backend{
AuthRenew: b.pathLoginRenew,
Expand All @@ -55,11 +59,20 @@ func backend(c *logical.BackendConfig) *jwtAuthBackend {
pathConfig(b),
},
),
Clean: b.cleanup,
}

return b
}

func (b *jwtAuthBackend) cleanup(_ context.Context) {
b.l.Lock()
if b.providerCtxCancel != nil {
b.providerCtxCancel()
}
b.l.Unlock()
}

func (b *jwtAuthBackend) invalidate(ctx context.Context, key string) {
switch key {
case "config":
Expand Down Expand Up @@ -91,7 +104,7 @@ func (b *jwtAuthBackend) getProvider(ctx context.Context, config *jwtConfig) (*o
return b.provider, nil
}

provider, err := b.createProvider(ctx, config)
provider, err := b.createProvider(config)
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
return logical.ErrorResponse("exactly one of 'oidc_discovery_url' and 'jwt_validation_pubkeys' must be set"), nil

case config.OIDCDiscoveryURL != "":
_, err := b.createProvider(ctx, config)
_, err := b.createProvider(config)
if err != nil {
return logical.ErrorResponse(errwrap.Wrapf("error checking discovery URL: {{err}}", err).Error()), nil
}
Expand Down Expand Up @@ -150,7 +150,7 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
return nil, nil
}

func (b *jwtAuthBackend) createProvider(ctx context.Context, config *jwtConfig) (*oidc.Provider, error) {
func (b *jwtAuthBackend) createProvider(config *jwtConfig) (*oidc.Provider, error) {
var certPool *x509.CertPool
if config.OIDCDiscoveryCAPEM != "" {
certPool = x509.NewCertPool()
Expand All @@ -168,7 +168,7 @@ func (b *jwtAuthBackend) createProvider(ctx context.Context, config *jwtConfig)
tc := &http.Client{
Transport: tr,
}
oidcCtx := context.WithValue(ctx, oauth2.HTTPClient, tc)
oidcCtx := context.WithValue(b.providerCtx, oauth2.HTTPClient, tc)

provider, err := oidc.NewProvider(oidcCtx, config.OIDCDiscoveryURL)
if err != nil {
Expand Down

0 comments on commit f2c84b4

Please sign in to comment.