Skip to content

Commit

Permalink
implement sts token exchange
Browse files Browse the repository at this point in the history
Co-authored-by: Maksim Nabokikh <max.nabokih@gmail.com>
Signed-off-by: Sean Liao <sean+git@liao.dev>
  • Loading branch information
seankhliao and nabokihms committed Jun 10, 2023
1 parent fda87ac commit a80245d
Show file tree
Hide file tree
Showing 14 changed files with 411 additions and 34 deletions.
4 changes: 4 additions & 0 deletions cmd/dex/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ func (p *password) UnmarshalJSON(b []byte) error {

// OAuth2 describes enabled OAuth2 extensions.
type OAuth2 struct {
// list of allowed grant types,
// defaults to all supported types
GrantTypes []string `json:"grantTypes"`

ResponseTypes []string `json:"responseTypes"`
// If specified, do not prompt the user to approve client authorization. The
// act of logging in implies authorization.
Expand Down
7 changes: 7 additions & 0 deletions cmd/dex/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ staticClients:
oauth2:
alwaysShowLoginScreen: true
grantTypes:
- refresh_token
- "urn:ietf:params:oauth:grant-type:token-exchange"
connectors:
- type: mockCallback
Expand Down Expand Up @@ -161,6 +164,10 @@ logger:
},
OAuth2: OAuth2{
AlwaysShowLoginScreen: true,
GrantTypes: []string{
"refresh_token",
"urn:ietf:params:oauth:grant-type:token-exchange",
},
},
StaticConnectors: []Connector{
{
Expand Down
10 changes: 10 additions & 0 deletions cmd/dex/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ func runServe(options serveOptions) error {
healthChecker := gosundheit.New()

serverConfig := server.Config{
AllowedGrantTypes: c.OAuth2.GrantTypes,
SupportedResponseTypes: c.OAuth2.ResponseTypes,
SkipApprovalScreen: c.OAuth2.SkipApprovalScreen,
AlwaysShowLoginScreen: c.OAuth2.AlwaysShowLoginScreen,
Expand Down Expand Up @@ -554,6 +555,15 @@ func applyConfigOverrides(options serveOptions, config *Config) {
if config.Frontend.Dir == "" {
config.Frontend.Dir = os.Getenv("DEX_FRONTEND_DIR")
}

if len(config.OAuth2.GrantTypes) == 0 {
config.OAuth2.GrantTypes = []string{
"authorization_code",
"refresh_token",
"urn:ietf:params:oauth:grant-type:device_code",
"urn:ietf:params:oauth:grant-type:token-exchange",
}
}
}

func pprofHandler(router *http.ServeMux) {
Expand Down
4 changes: 4 additions & 0 deletions connector/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,7 @@ type RefreshConnector interface {
// changes since the token was last refreshed.
Refresh(ctx context.Context, s Scopes, identity Identity) (Identity, error)
}

type TokenIdentityConnector interface {
TokenIdentity(ctx context.Context, subjectTokenType, subjectToken string) (Identity, error)
}
4 changes: 4 additions & 0 deletions connector/mock/connectortest.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ func (m *Callback) Refresh(ctx context.Context, s connector.Scopes, identity con
return m.Identity, nil
}

func (m *Callback) TokenIdentity(ctx context.Context, subjectTokenType, subjectToken string) (connector.Identity, error) {
return m.Identity, nil
}

// CallbackConfig holds the configuration parameters for a connector which requires no interaction.
type CallbackConfig struct{}

Expand Down
21 changes: 19 additions & 2 deletions connector/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ type caller uint
const (
createCaller caller = iota
refreshCaller
exchangeCaller
)

func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
Expand Down Expand Up @@ -284,16 +285,32 @@ func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identit
return c.createIdentity(ctx, identity, token, refreshCaller)
}

func (c *oidcConnector) TokenIdentity(ctx context.Context, subjectTokenType, subjectToken string) (connector.Identity, error) {
var identity connector.Identity
token := &oauth2.Token{
AccessToken: subjectToken,
}
return c.createIdentity(ctx, identity, token, exchangeCaller)
}

func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token, caller caller) (connector.Identity, error) {
var claims map[string]interface{}

rawIDToken, ok := token.Extra("id_token").(string)
if ok {
if rawIDToken, ok := token.Extra("id_token").(string); ok {
idToken, err := c.verifier.Verify(ctx, rawIDToken)
if err != nil {
return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err)
}

if err := idToken.Claims(&claims); err != nil {
return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
}
} else if caller == exchangeCaller {
// AccessToken here could be either an id token or an access token
idToken, err := c.provider.Verifier(&oidc.Config{SkipClientIDCheck: true}).Verify(ctx, token.AccessToken)
if err != nil {
return identity, fmt.Errorf("oidc: failed to verify token: %v", err)
}
if err := idToken.Claims(&claims); err != nil {
return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
}
Expand Down
76 changes: 76 additions & 0 deletions connector/oidc/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package oidc

import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
Expand Down Expand Up @@ -428,6 +429,81 @@ func TestRefresh(t *testing.T) {
}
}

func TestTokenIdentity(t *testing.T) {
tokenTypeAccess := "urn:ietf:params:oauth:token-type:access_token"
tokenTypeID := "urn:ietf:params:oauth:token-type:id_token"
long2short := map[string]string{
tokenTypeAccess: "access_token",
tokenTypeID: "id_token",
}

tests := []struct {
name string
subjectType string
userInfo bool
}{
{
name: "id_token",
subjectType: tokenTypeID,
}, {
name: "access_token",
subjectType: tokenTypeAccess,
}, {
name: "id_token with user info",
subjectType: tokenTypeID,
userInfo: true,
}, {
name: "access_token with user info",
subjectType: tokenTypeAccess,
userInfo: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()

testServer, err := setupServer(map[string]any{
"sub": "subvalue",
"name": "namevalue",
}, true)
if err != nil {
t.Fatal("failed to setup test server", err)
}
conn, err := newConnector(Config{
Issuer: testServer.URL,
Scopes: []string{"openid", "groups"},
GetUserInfo: tc.userInfo,
})
if err != nil {
t.Fatal("failed to create new connector", err)
}

res, err := http.Get(testServer.URL + "/token")
if err != nil {
t.Fatal("failed to get initial token", err)
}
defer res.Body.Close()
var tokenResponse map[string]any
err = json.NewDecoder(res.Body).Decode(&tokenResponse)
if err != nil {
t.Fatal("failed to decode initial token", err)
}

origToken := tokenResponse[long2short[tc.subjectType]].(string)
identity, err := conn.TokenIdentity(ctx, tc.subjectType, origToken)
if err != nil {
t.Fatal("failed to get token identity", err)
}

// assert identity
expectEquals(t, identity.UserID, "subvalue")
expectEquals(t, identity.Username, "namevalue")
})
}
}

func setupServer(tok map[string]interface{}, idTokenDesired bool) (*httptest.Server, error) {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
Expand Down
126 changes: 112 additions & 14 deletions server/handlers.go
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
implicitOrHybrid = true
var err error

accessToken, err = s.newAccessToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID)
accessToken, _, err = s.newAccessToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
Expand Down Expand Up @@ -830,6 +830,11 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
}

grantType := r.PostFormValue("grant_type")
if !contains(s.supportedGrantTypes, grantType) {
s.logger.Errorf("unsupported grant type: %v", grantType)
s.tokenErrHelper(w, errUnsupportedGrantType, "", http.StatusBadRequest)
return
}
switch grantType {
case grantTypeDeviceCode:
s.handleDeviceToken(w, r)
Expand All @@ -839,6 +844,8 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
s.withClientFromStorage(w, r, s.handleRefreshToken)
case grantTypePassword:
s.withClientFromStorage(w, r, s.handlePasswordGrant)
case grantTypeTokenExchange:
s.handleTokenExchange(w, r)
default:
s.tokenErrHelper(w, errUnsupportedGrantType, "", http.StatusBadRequest)
}
Expand Down Expand Up @@ -917,7 +924,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
}

func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenResponse, error) {
accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
accessToken, _, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
Expand Down Expand Up @@ -1180,7 +1187,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
Groups: identity.Groups,
}

accessToken, err := s.newAccessToken(client.ID, claims, scopes, nonce, connID)
accessToken, _, err := s.newAccessToken(client.ID, claims, scopes, nonce, connID)
if err != nil {
s.logger.Errorf("password grant failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
Expand Down Expand Up @@ -1319,21 +1326,112 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
s.writeAccessToken(w, resp)
}

func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

// TODO: check global allowed grant types?

if err := r.ParseForm(); err != nil {
s.logger.Errorf("could not parse request body: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
return
}
q := r.Form

resource := q.Get("resource") // OPTIONAL, use for issued token audience
scopes := strings.Fields(q.Get("scope")) // OPTIONAL, map to issed token scope
requestedTokenType := q.Get("requested_token_type") // OPTIONAL, default to access token
if requestedTokenType == "" {
requestedTokenType = tokenTypeAccess
}
connID := q.Get("audience") // REQUIRED (RFC 8693 optional), use for connector ID
subjectToken := q.Get("subject_token") // REQUIRED
subjectTokenType := q.Get("subject_token_type") // REQUIRED

switch subjectTokenType {
case tokenTypeID, tokenTypeAccess: // ok, continue
default:
s.tokenErrHelper(w, errRequestNotSupported, "Invalid subject_token_type.", http.StatusBadRequest)
return
}

if subjectToken == "" {
s.tokenErrHelper(w, errInvalidRequest, "Missing subject_token", http.StatusBadRequest)
return
}

conn, err := s.getConnector(connID)
if err != nil {
s.logger.Errorf("failed to get connector: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest)
return
}
teConn, ok := conn.Connector.(connector.TokenIdentityConnector)
if !ok {
s.logger.Errorf("connector doesn't implement token exchange: %v", connID)
s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest)
return
}
identity, err := teConn.TokenIdentity(ctx, subjectTokenType, subjectToken)
if err != nil {
s.logger.Errorf("failed to verify subject token: %v", err)
s.tokenErrHelper(w, errAccessDenied, "", http.StatusUnauthorized)
return
}

claims := storage.Claims{
UserID: identity.UserID,
Username: identity.Username,
PreferredUsername: identity.PreferredUsername,
Email: identity.Email,
EmailVerified: identity.EmailVerified,
Groups: identity.Groups,
}
resp := accessTokenResponse{
IssuedTokenType: requestedTokenType,
TokenType: "bearer",
}
var expiry time.Time
switch requestedTokenType {
case tokenTypeID:
resp.AccessToken, expiry, err = s.newIDToken(resource, claims, scopes, "", "", "", connID)
case tokenTypeAccess:
resp.AccessToken, expiry, err = s.newAccessToken(resource, claims, scopes, "", connID)
default:
s.tokenErrHelper(w, errRequestNotSupported, "Invalid requested_token_type.", http.StatusBadRequest)
return
}
if err != nil {
s.logger.Errorf("token exchange failed to create new %v token: %v", requestedTokenType, err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
resp.ExpiresIn = int(time.Until(expiry).Seconds())

// Token response must include cache headers https://tools.ietf.org/html/rfc6749#section-5.1
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}

type accessTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
IssuedTokenType string `json:"issued_token_type,omitempty"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
IDToken string `json:"id_token,omitempty"`
Scope string `json:"scope,omitempty"`
}

func (s *Server) toAccessTokenResponse(idToken, accessToken, refreshToken string, expiry time.Time) *accessTokenResponse {
return &accessTokenResponse{
accessToken,
"bearer",
int(expiry.Sub(s.now()).Seconds()),
refreshToken,
idToken,
AccessToken: accessToken,
TokenType: "bearer",
ExpiresIn: int(expiry.Sub(s.now()).Seconds()),
RefreshToken: refreshToken,
IDToken: idToken,
}
}

Expand All @@ -1355,7 +1453,7 @@ func (s *Server) writeAccessToken(w http.ResponseWriter, resp *accessTokenRespon

func (s *Server) renderError(r *http.Request, w http.ResponseWriter, status int, description string) {
if err := s.templates.err(r, w, status, description); err != nil {
s.logger.Errorf("Server template error: %v", err)
s.logger.Errorf("server template error: %v", err)
}
}

Expand Down
Loading

0 comments on commit a80245d

Please sign in to comment.