Skip to content

Commit

Permalink
add dex config flag for enabling client secret encryption
Browse files Browse the repository at this point in the history
* if enabled, it will make sure client secret is bcrypted correctly
* if not, it falls back to old behaviour that allowing empty client
secret and comparing plain text, though now it will do
ConstantTimeCompare to avoid a timing attack.

So in either way it should provide more secure of client secret
verification.

Co-authored-by: Alex Surraci <suraci.alex@gmail.com>
Signed-off-by: Rui Yang <ruiya@vmware.com>
  • Loading branch information
2 people authored and CI Bot committed Mar 20, 2021
1 parent ec6f3a2 commit d658c24
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 15 deletions.
30 changes: 15 additions & 15 deletions server/handlers.go
Expand Up @@ -2,6 +2,7 @@ package server

import (
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/json"
"errors"
Expand Down Expand Up @@ -681,22 +682,21 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
return
}

if client.Secret != clientSecret {
if clientSecret == "" {
s.logger.Infof("missing client_secret on token request for client: %s", client.ID)
} else {
s.logger.Infof("invalid client_secret on token request for client: %s", client.ID)
if s.hashClientSecret {
if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)); err != nil {
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
return
}
} else {
if subtle.ConstantTimeCompare([]byte(client.Secret), []byte(clientSecret)) != 1 {
if clientSecret == "" {
s.logger.Infof("missing client_secret on token request for client: %s", client.ID)
} else {
s.logger.Infof("invalid client_secret on token request for client: %s", client.ID)
}
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
return
}
}

if err := checkCost([]byte(client.Secret)); err != nil {
s.logger.Errorf("failed to check cost of client secret: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)); err != nil {
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
return
}

grantType := r.PostFormValue("grant_type")
Expand Down
25 changes: 25 additions & 0 deletions server/server.go
Expand Up @@ -77,6 +77,9 @@ type Config struct {
// If enabled, the connectors selection page will always be shown even if there's only one
AlwaysShowLoginScreen bool

// If enabled, the client secret is expected to be encrypted
HashClientSecret bool

RotateKeysAfter time.Duration // Defaults to 6 hours.
IDTokensValidFor time.Duration // Defaults to 24 hours
AuthRequestsValidFor time.Duration // Defaults to 24 hours
Expand Down Expand Up @@ -151,6 +154,9 @@ type Server struct {
// If enabled, show the connector selection screen even if there's only one
alwaysShowLogin bool

// If enabled, the client secret is expected to be encrypted
hashClientSecret bool

// Used for password grant
passwordConnector string

Expand Down Expand Up @@ -189,6 +195,24 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
if c.Storage == nil {
return nil, errors.New("server: storage cannot be nil")
}

if c.HashClientSecret {
clients, err := c.Storage.ListClients()
if err != nil {
return nil, fmt.Errorf("server: failed to list clients")
}

for _, client := range clients {
if client.Secret == "" {
return nil, fmt.Errorf("server: client secret can't be empty")
}

if err = checkCost([]byte(client.Secret)); err != nil {
return nil, fmt.Errorf("server: failed to check cost of client secret: %v", err)
}
}
}

if len(c.SupportedResponseTypes) == 0 {
c.SupportedResponseTypes = []string{responseTypeCode}
}
Expand Down Expand Up @@ -232,6 +256,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),
skipApproval: c.SkipApprovalScreen,
alwaysShowLogin: c.AlwaysShowLoginScreen,
hashClientSecret: c.HashClientSecret,
now: now,
templates: tmpls,
passwordConnector: c.PasswordConnector,
Expand Down
160 changes: 160 additions & 0 deletions server/server_test.go
Expand Up @@ -1637,3 +1637,163 @@ func TestOAuth2DeviceFlow(t *testing.T) {
}()
}
}

func TestClientSecretEncryption(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

httpServer, s := newTestServer(ctx, t, func(c *Config) {
c.HashClientSecret = true
})
defer httpServer.Close()

clientID := "testclient"
clientSecret := "testclientsecret"
hash, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost)
if err != nil {
t.Fatalf("failed to bcrypt: %s", err)
}

// Query server's provider metadata.
p, err := oidc.NewProvider(ctx, httpServer.URL)
if err != nil {
t.Fatalf("failed to get provider: %v", err)
}

var (
// If the OAuth2 client didn't get a response, we need
// to print the requests the user saw.
gotCode bool
reqDump, respDump []byte // Auth step, not token.
state = "a_state"
)
defer func() {
if !gotCode {
t.Errorf("never got a code in callback\n%s\n%s", reqDump, respDump)
}
}()

// Setup OAuth2 client.
var oauth2Config *oauth2.Config

requestedScopes := []string{oidc.ScopeOpenID, "email", "profile", "groups", "offline_access"}

// Create the OAuth2 config.
oauth2Config = &oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
Endpoint: p.Endpoint(),
Scopes: requestedScopes,
}

oauth2Client := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/callback" {
// User is visiting app first time. Redirect to dex.
http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusSeeOther)
return
}

// User is at '/callback' so they were just redirected _from_ dex.
q := r.URL.Query()

// Grab code, exchange for token.
if code := q.Get("code"); code != "" {
gotCode = true
token, err := oauth2Config.Exchange(ctx, code)
if err != nil {
t.Errorf("failed to exchange code for token: %v", err)
return
}

oidcConfig := &oidc.Config{SkipClientIDCheck: true}

idToken, ok := token.Extra("id_token").(string)
if !ok {
t.Errorf("no id token found")
return
}
if _, err := p.Verifier(oidcConfig).Verify(ctx, idToken); err != nil {
t.Errorf("failed to verify id token: %v", err)
return
}
}

w.WriteHeader(http.StatusOK)
}))

oauth2Config.RedirectURL = oauth2Client.URL + "/callback"

defer oauth2Client.Close()

// Regester the client above with dex.
client := storage.Client{
ID: clientID,
Secret: string(hash),
RedirectURIs: []string{oauth2Client.URL + "/callback"},
}
if err := s.storage.CreateClient(client); err != nil {
t.Fatalf("failed to create client: %v", err)
}

// Login!
//
// 1. First request to client, redirects to dex.
// 2. Dex "logs in" the user, redirects to client with "code".
// 3. Client exchanges "code" for "token" (id_token, refresh_token, etc.).
// 4. Test is run with OAuth2 token response.
//
resp, err := http.Get(oauth2Client.URL + "/login")
if err != nil {
t.Fatalf("get failed: %v", err)
}
defer resp.Body.Close()

if reqDump, err = httputil.DumpRequest(resp.Request, false); err != nil {
t.Fatal(err)
}
if respDump, err = httputil.DumpResponse(resp, true); err != nil {
t.Fatal(err)
}
}

func TestClientSecretEncryptionCost(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

clientID := "testclient"
clientSecret := "testclientsecret"
hash, err := bcrypt.GenerateFromPassword([]byte(clientSecret), 5)
if err != nil {
t.Fatalf("failed to bcrypt: %s", err)
}

// Register the client above with dex.
client := storage.Client{
ID: clientID,
Secret: string(hash),
}

config := Config{
Storage: memory.New(logger),
Web: WebConfig{
Dir: "../web",
},
Logger: logger,
PrometheusRegistry: prometheus.NewRegistry(),
HashClientSecret: true,
}

err = config.Storage.CreateClient(client)
if err != nil {
t.Fatalf("failed to create client: %v", err)
}

_, err = newServer(ctx, config, staticRotationStrategy(testKey))
if err == nil {
t.Error("constructing server should have failed")
}

if !strings.Contains(err.Error(), "failed to check cost") {
t.Error("should have failed with cost error")
}
}

0 comments on commit d658c24

Please sign in to comment.