Skip to content

Commit

Permalink
Fix returned token encoding (istio-ecosystem#45)
Browse files Browse the repository at this point in the history
* Fix returned token encoding

* fix tests
  • Loading branch information
nacx committed Feb 23, 2024
1 parent 7e217a4 commit 90319de
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 26 deletions.
12 changes: 10 additions & 2 deletions internal/authz/oidc.go
Expand Up @@ -731,18 +731,26 @@ func (o *oidcHandler) encodeTokensToHeaders(tokens *oidc.TokenResponse) map[stri
headers := make(map[string]string)

// Always add the ID token to the headers
headers[o.config.GetIdToken().GetHeader()] = o.config.IdToken.GetPreamble() + " " + oidc.EncodeToken(tokens.IDToken)
headers[o.config.GetIdToken().GetHeader()] = encodeHeaderValue(o.config.IdToken.GetPreamble(), tokens.IDToken)

if o.config.GetAccessToken() == nil || tokens.AccessToken == "" {
return headers
}

// If there is an access token and config enables it, add it to the headers
headers[o.config.GetAccessToken().GetHeader()] = o.config.GetAccessToken().GetPreamble() + " " + oidc.EncodeToken(tokens.AccessToken)
headers[o.config.GetAccessToken().GetHeader()] = encodeHeaderValue(o.config.GetAccessToken().GetPreamble(), tokens.AccessToken)

return headers
}

// encodeHeaderValue encodes the value with the given preamble, if any
func encodeHeaderValue(preamble string, value string) string {
if preamble != "" {
return preamble + " " + value
}
return value
}

// areRequiredTokensExpired checks if the required tokens are expired.
func (o *oidcHandler) areRequiredTokensExpired(tokens *oidc.TokenResponse) (bool, error) {
idToken, err := tokens.ParseIDToken()
Expand Down
37 changes: 23 additions & 14 deletions internal/authz/oidc_test.go
Expand Up @@ -18,7 +18,6 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -1152,10 +1151,8 @@ func TestMatchesLogoutPath(t *testing.T) {

func TestEncodeTokensToHeaders(t *testing.T) {
const (
idToken = "id-token"
accessToken = "access-token"
idTokenB64 = "aWQtdG9rZW4="
accessTokenB64 = "YWNjZXNzLXRva2Vu"
idToken = "id-token"
accessToken = "access-token"
)

tests := []struct {
Expand All @@ -1171,7 +1168,7 @@ func TestEncodeTokensToHeaders(t *testing.T) {
},
idToken: idToken, accessToken: "",
want: map[string]string{
"Authorization": "Bearer " + idTokenB64,
"Authorization": "Bearer " + idToken,
},
},
{
Expand All @@ -1182,8 +1179,8 @@ func TestEncodeTokensToHeaders(t *testing.T) {
},
idToken: idToken, accessToken: accessToken,
want: map[string]string{
"Authorization": "Bearer " + idTokenB64,
"X-Access-Token": "Bearer " + accessTokenB64,
"Authorization": "Bearer " + idToken,
"X-Access-Token": "Bearer " + accessToken,
},
},
{
Expand All @@ -1194,8 +1191,8 @@ func TestEncodeTokensToHeaders(t *testing.T) {
},
idToken: idToken, accessToken: accessToken,
want: map[string]string{
"X-Id-Token": "Other " + idTokenB64,
"X-Access-Token-Other": "Other " + accessTokenB64,
"X-Id-Token": "Other " + idToken,
"X-Access-Token-Other": "Other " + accessToken,
},
},
{
Expand All @@ -1206,7 +1203,7 @@ func TestEncodeTokensToHeaders(t *testing.T) {
},
idToken: idToken, accessToken: "",
want: map[string]string{
"Authorization": "Bearer " + idTokenB64,
"Authorization": "Bearer " + idToken,
},
},
{
Expand All @@ -1216,7 +1213,19 @@ func TestEncodeTokensToHeaders(t *testing.T) {
},
idToken: idToken, accessToken: accessToken,
want: map[string]string{
"Authorization": "Bearer " + idTokenB64,
"Authorization": "Bearer " + idToken,
},
},
{
name: "config with out preamble",
config: &oidcv1.OIDCConfig{
IdToken: &oidcv1.TokenConfig{Header: "X-ID-Token"},
AccessToken: &oidcv1.TokenConfig{Header: "X-Access-Token"},
},
idToken: idToken, accessToken: accessToken,
want: map[string]string{
"X-ID-Token": idToken,
"X-Access-Token": accessToken,
},
},
}
Expand Down Expand Up @@ -1487,9 +1496,9 @@ func requireTokensInResponse(t *testing.T, resp *envoy.OkHttpResponse, cfg *oidc
wantIDToken, wantAccessToken string
)

wantIDToken = cfg.GetIdToken().GetPreamble() + " " + base64.URLEncoding.EncodeToString([]byte(idToken))
wantIDToken = encodeHeaderValue(cfg.GetIdToken().GetPreamble(), idToken)
if cfg.GetAccessToken() != nil {
wantAccessToken = cfg.GetAccessToken().GetPreamble() + " " + base64.URLEncoding.EncodeToString([]byte(accessToken))
wantAccessToken = encodeHeaderValue(cfg.GetAccessToken().GetPreamble(), accessToken)
}

for _, header := range resp.GetHeaders() {
Expand Down
6 changes: 0 additions & 6 deletions internal/oidc/token.go
Expand Up @@ -15,7 +15,6 @@
package oidc

import (
"encoding/base64"
"time"

"github.com/lestrrat-go/jwx/jwt"
Expand All @@ -36,8 +35,3 @@ func (t *TokenResponse) ParseIDToken() (jwt.Token, error) { return ParseToken(t.
func ParseToken(token string) (jwt.Token, error) {
return jwt.Parse([]byte(token), jwt.WithValidate(false))
}

// EncodeToken returns the base64 encoded string representation of the token. Compatible with HTTP headers.
func EncodeToken(token string) string {
return base64.URLEncoding.EncodeToString([]byte(token))
}
4 changes: 0 additions & 4 deletions internal/oidc/token_test.go
Expand Up @@ -50,7 +50,3 @@ func newToken() string {
signed, _ := jwt.Sign(token, jwa.HS256, []byte("key"))
return string(signed)
}

func TestEncodeToken(t *testing.T) {
require.Equal(t, "dGVzdA==", EncodeToken("test"))
}

0 comments on commit 90319de

Please sign in to comment.