Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
github.com/davecgh/go-spew v1.1.1
github.com/go-jose/go-jose/v3 v3.0.3
github.com/go-ldap/ldap/v3 v3.4.3
github.com/google/go-cmp v0.6.0
github.com/gophercloud/gophercloud/v2 v2.4.0
github.com/gorilla/sessions v1.2.1
github.com/openshift/api v0.0.0-20240410131754-6e4793fb6a4b
Expand Down Expand Up @@ -54,7 +55,6 @@ require (
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/cel-go v0.17.7 // indirect
github.com/google/gnostic-models v0.6.8 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/gofuzz v1.2.0 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
Expand Down
46 changes: 38 additions & 8 deletions pkg/oauth/external/openid/openid.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,11 @@ func (p provider) GetUserIdentity(data *osincli.AccessData) (authapi.UserIdentit
identity.Extra[authapi.IdentityDisplayNameKey] = name
}

if groups, ok := getArrayOrStringClaimValue(claims, p.GroupClaims...); ok {
groups, err := getGroupClaimValue(claims, p.GroupClaims...)
if err != nil {
return nil, fmt.Errorf("could not retrieve group claim for %#v from %#v: %w", p.GroupClaims, claims, err)
}
if len(groups) > 0 {
identity.ProviderGroups = groups
}

Expand All @@ -232,24 +236,50 @@ func getClaimValue(data map[string]interface{}, claims ...string) (string, bool)
return "", false
}

func getArrayOrStringClaimValue(data map[string]interface{}, claims ...string) ([]string, bool) {
func getArrayOrStringClaimValue(data map[string]interface{}, claims ...string) ([]string, error) {
for _, claim := range claims {
val, ok := data[claim]
if !ok {
continue
}

switch valTyped := val.(type) {
case []interface{}:
case []any:
ret := make([]string, 0, len(valTyped))
for _, s := range valTyped {
ret = append(ret, s.(string))
for _, v := range valTyped {
if s, ok := v.(string); ok {
ret = append(ret, s)
} else {
return nil, fmt.Errorf("expected string array item, got %T", v)
}
}
return ret, true
return ret, nil

case string:
return []string{valTyped}, true
return []string{valTyped}, nil

default:
return nil, fmt.Errorf("expected array or string, got %T", val)
}
}
return nil, nil
}

func getGroupClaimValue(data map[string]interface{}, claims ...string) ([]string, error) {
groups, err := getArrayOrStringClaimValue(data, claims...)
if err != nil || len(groups) == 0 {
return groups, err
}

if len(groups) == 1 && strings.HasPrefix(groups[0], "[") && strings.HasSuffix(groups[0], "]") {
var actualGroups []string
if err := json.Unmarshal([]byte(groups[0]), &actualGroups); err != nil {
return nil, fmt.Errorf("could not unmarshal group claim array string: %w", err)
}
return actualGroups, nil
}
return nil, false

return groups, nil
}

// fetch and decode JSON from the given UserInfo URL
Expand Down
111 changes: 100 additions & 11 deletions pkg/oauth/external/openid/openid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"reflect"
"testing"

"github.com/google/go-cmp/cmp"

"github.com/openshift/oauth-server/pkg/oauth/external"
)

Expand All @@ -24,7 +26,7 @@ func TestOpenID(t *testing.T) {
}

func TestDecodeJWT(t *testing.T) {
testcases := []struct {
testCases := []struct {
Name string
JWT string
ExpectData map[string]interface{}
Expand Down Expand Up @@ -81,15 +83,102 @@ func TestDecodeJWT(t *testing.T) {
ExpectErr: false,
},
}
for i, tc := range testcases {
data, err := decodeJWT(tc.JWT)
if tc.ExpectErr != (err != nil) {
t.Errorf("%d: expected error %v, got %v", i, tc.ExpectErr, err)
continue
}
if !reflect.DeepEqual(data, tc.ExpectData) {
t.Errorf("%d: expected\n\t%#v\ngot\n\t%#v", i, tc.ExpectData, data)
continue
}
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
data, err := decodeJWT(tc.JWT)
if tc.ExpectErr != (err != nil) {
t.Fatalf("expected error %v, got %v", tc.ExpectErr, err)
}
if !reflect.DeepEqual(data, tc.ExpectData) {
t.Fatalf("expected\n\t%#v\ngot\n\t%#v", tc.ExpectData, data)
}
})
}
}

func TestGetGroupClaimValue(t *testing.T) {
testCases := []struct {
Name string
Data map[string]any
Claims []string
ExpectedGroups []string
ExpectErr bool
}{
{
Name: "unset",
Data: map[string]any{},
Claims: []string{"groups"},
ExpectedGroups: nil,
ExpectErr: false,
},
{
Name: "string",
Data: map[string]any{"groups": "groupA"},
Claims: []string{"groups"},
ExpectedGroups: []string{"groupA"},
ExpectErr: false,
},
{
Name: "other simple type ignored",
Data: map[string]any{"groups": 10},
Claims: []string{"groups"},
ExpectedGroups: nil,
ExpectErr: true,
},
{
Name: "array with string values",
Data: map[string]any{"groups": []any{"groupA", "groupB"}},
Claims: []string{"groups"},
ExpectedGroups: []string{"groupA", "groupB"},
ExpectErr: false,
},
{
Name: "array with int values",
Data: map[string]any{"groups": []any{10, 20}},
Claims: []string{"groups"},
ExpectedGroups: nil,
ExpectErr: true,
},
{
Name: "string-encoded array",
Data: map[string]any{"groups": `["groupA", "groupB"]`},
Claims: []string{"groups"},
ExpectedGroups: []string{"groupA", "groupB"},
ExpectErr: false,
},
{
Name: "string-encoded array wrapped in an array",
Data: map[string]any{"groups": []any{`["groupA", "groupB"]`}},
Claims: []string{"groups"},
ExpectedGroups: []string{"groupA", "groupB"},
ExpectErr: false,
},
{
Name: "invalid string-encoded array",
Data: map[string]any{"groups": `[groupA, groupB]`},
Claims: []string{"groups"},
ExpectedGroups: nil,
ExpectErr: true,
},
{
Name: "multiple string-encoded arrays not unwrapped",
Data: map[string]any{"groups": []any{`["groupA", "groupB"]`, `["groupC"]`}},
Claims: []string{"groups"},
ExpectedGroups: []string{`["groupA", "groupB"]`, `["groupC"]`},
ExpectErr: false,
},
}

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
groups, err := getGroupClaimValue(tc.Data, tc.Claims...)
if tc.ExpectErr != (err != nil) {
t.Fatalf("Unexpected error '%v', expected '%v'", err, tc.ExpectErr)
}

if !cmp.Equal(groups, tc.ExpectedGroups) {
t.Fatalf("Unexpected groups:\n%s", cmp.Diff(tc.ExpectedGroups, groups))
}
})
}
}