From 8bbed043b7e58d1d78e7f8ca0d00027867f52fc2 Mon Sep 17 00:00:00 2001 From: Ondra Kupka Date: Tue, 18 Nov 2025 16:34:25 +0100 Subject: [PATCH 1/2] external/openid: Unwrap groups encoded as string Unwrap groups array claim when encoded as a simple string. This means decoding `"groups": "[\"A\", \"B\"]"` into `["A", "B"]` groups claim value. --- go.mod | 2 +- pkg/oauth/external/openid/openid.go | 46 +++++++++--- pkg/oauth/external/openid/openid_test.go | 90 ++++++++++++++++++++++++ 3 files changed, 129 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 9bbb25290..fd32fc8ac 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/davecgh/go-spew v1.1.1 github.com/go-jose/go-jose/v3 v3.0.3 github.com/go-ldap/ldap/v3 v3.4.3 + github.com/google/go-cmp v0.6.0 github.com/gophercloud/gophercloud/v2 v2.4.0 github.com/gorilla/sessions v1.2.1 github.com/openshift/api v0.0.0-20240410131754-6e4793fb6a4b @@ -54,7 +55,6 @@ require ( github.com/golang/protobuf v1.5.3 // indirect github.com/google/cel-go v0.17.7 // indirect github.com/google/gnostic-models v0.6.8 // indirect - github.com/google/go-cmp v0.6.0 // indirect github.com/google/gofuzz v1.2.0 // indirect github.com/google/uuid v1.3.0 // indirect github.com/gorilla/securecookie v1.1.1 // indirect diff --git a/pkg/oauth/external/openid/openid.go b/pkg/oauth/external/openid/openid.go index 587f9d26f..8bcb87ce0 100644 --- a/pkg/oauth/external/openid/openid.go +++ b/pkg/oauth/external/openid/openid.go @@ -213,7 +213,11 @@ func (p provider) GetUserIdentity(data *osincli.AccessData) (authapi.UserIdentit identity.Extra[authapi.IdentityDisplayNameKey] = name } - if groups, ok := getArrayOrStringClaimValue(claims, p.GroupClaims...); ok { + groups, err := getGroupClaimValue(claims, p.GroupClaims...) + if err != nil { + return nil, fmt.Errorf("could not retrieve group claim for %#v from %#v: %w", p.GroupClaims, claims, err) + } + if len(groups) > 0 { identity.ProviderGroups = groups } @@ -232,24 +236,50 @@ func getClaimValue(data map[string]interface{}, claims ...string) (string, bool) return "", false } -func getArrayOrStringClaimValue(data map[string]interface{}, claims ...string) ([]string, bool) { +func getArrayOrStringClaimValue(data map[string]interface{}, claims ...string) ([]string, error) { for _, claim := range claims { val, ok := data[claim] if !ok { continue } + switch valTyped := val.(type) { - case []interface{}: + case []any: ret := make([]string, 0, len(valTyped)) - for _, s := range valTyped { - ret = append(ret, s.(string)) + for _, v := range valTyped { + if s, ok := v.(string); ok { + ret = append(ret, s) + } else { + return nil, fmt.Errorf("expected string array item, got %T", v) + } } - return ret, true + return ret, nil + case string: - return []string{valTyped}, true + return []string{valTyped}, nil + + default: + return nil, fmt.Errorf("expected array or string, got %T", val) + } + } + return nil, nil +} + +func getGroupClaimValue(data map[string]interface{}, claims ...string) ([]string, error) { + groups, err := getArrayOrStringClaimValue(data, claims...) + if err != nil || len(groups) == 0 { + return groups, err + } + + if len(groups) == 1 && strings.HasPrefix(groups[0], "[") && strings.HasSuffix(groups[0], "]") { + var actualGroups []string + if err := json.Unmarshal([]byte(groups[0]), &actualGroups); err != nil { + return nil, fmt.Errorf("could not unmarshal group claim array string: %w", err) } + return actualGroups, nil } - return nil, false + + return groups, nil } // fetch and decode JSON from the given UserInfo URL diff --git a/pkg/oauth/external/openid/openid_test.go b/pkg/oauth/external/openid/openid_test.go index 8b1c48bc7..2583b9769 100644 --- a/pkg/oauth/external/openid/openid_test.go +++ b/pkg/oauth/external/openid/openid_test.go @@ -4,6 +4,8 @@ import ( "reflect" "testing" + "github.com/google/go-cmp/cmp" + "github.com/openshift/oauth-server/pkg/oauth/external" ) @@ -86,6 +88,7 @@ func TestDecodeJWT(t *testing.T) { if tc.ExpectErr != (err != nil) { t.Errorf("%d: expected error %v, got %v", i, tc.ExpectErr, err) continue + } if !reflect.DeepEqual(data, tc.ExpectData) { t.Errorf("%d: expected\n\t%#v\ngot\n\t%#v", i, tc.ExpectData, data) @@ -93,3 +96,90 @@ func TestDecodeJWT(t *testing.T) { } } } + +func TestGetGroupClaimValue(t *testing.T) { + testCases := []struct { + Name string + Data map[string]any + Claims []string + ExpectedGroups []string + ExpectErr bool + }{ + { + Name: "unset", + Data: map[string]any{}, + Claims: []string{"groups"}, + ExpectedGroups: nil, + ExpectErr: false, + }, + { + Name: "string", + Data: map[string]any{"groups": "groupA"}, + Claims: []string{"groups"}, + ExpectedGroups: []string{"groupA"}, + ExpectErr: false, + }, + { + Name: "other simple type ignored", + Data: map[string]any{"groups": 10}, + Claims: []string{"groups"}, + ExpectedGroups: nil, + ExpectErr: true, + }, + { + Name: "array with string values", + Data: map[string]any{"groups": []any{"groupA", "groupB"}}, + Claims: []string{"groups"}, + ExpectedGroups: []string{"groupA", "groupB"}, + ExpectErr: false, + }, + { + Name: "array with int values", + Data: map[string]any{"groups": []any{10, 20}}, + Claims: []string{"groups"}, + ExpectedGroups: nil, + ExpectErr: true, + }, + { + Name: "string-encoded array", + Data: map[string]any{"groups": `["groupA", "groupB"]`}, + Claims: []string{"groups"}, + ExpectedGroups: []string{"groupA", "groupB"}, + ExpectErr: false, + }, + { + Name: "string-encoded array wrapped in an array", + Data: map[string]any{"groups": []any{`["groupA", "groupB"]`}}, + Claims: []string{"groups"}, + ExpectedGroups: []string{"groupA", "groupB"}, + ExpectErr: false, + }, + { + Name: "invalid string-encoded array", + Data: map[string]any{"groups": `[groupA, groupB]`}, + Claims: []string{"groups"}, + ExpectedGroups: nil, + ExpectErr: true, + }, + { + Name: "multiple string-encoded arrays not unwrapped", + Data: map[string]any{"groups": []any{`["groupA", "groupB"]`, `["groupC"]`}}, + Claims: []string{"groups"}, + ExpectedGroups: []string{`["groupA", "groupB"]`, `["groupC"]`}, + ExpectErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + groups, err := getGroupClaimValue(tc.Data, tc.Claims...) + if tc.ExpectErr != (err != nil) { + t.Fatalf("Unexpected error '%v', expected '%v'", err, tc.ExpectErr) + } + + if !cmp.Equal(groups, tc.ExpectedGroups) { + t.Fatalf("Unexpected groups:\n%s", cmp.Diff(tc.ExpectedGroups, groups)) + } + }) + } +} From a6f1284bc09d28bf6fa7be42e2f5788b55a7c595 Mon Sep 17 00:00:00 2001 From: Ondra Kupka Date: Tue, 18 Nov 2025 16:38:54 +0100 Subject: [PATCH 2/2] external/openid: Use t.Run to wrap test runs --- pkg/oauth/external/openid/openid_test.go | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/pkg/oauth/external/openid/openid_test.go b/pkg/oauth/external/openid/openid_test.go index 2583b9769..049c684ad 100644 --- a/pkg/oauth/external/openid/openid_test.go +++ b/pkg/oauth/external/openid/openid_test.go @@ -26,7 +26,7 @@ func TestOpenID(t *testing.T) { } func TestDecodeJWT(t *testing.T) { - testcases := []struct { + testCases := []struct { Name string JWT string ExpectData map[string]interface{} @@ -83,17 +83,16 @@ func TestDecodeJWT(t *testing.T) { ExpectErr: false, }, } - for i, tc := range testcases { - data, err := decodeJWT(tc.JWT) - if tc.ExpectErr != (err != nil) { - t.Errorf("%d: expected error %v, got %v", i, tc.ExpectErr, err) - continue - - } - if !reflect.DeepEqual(data, tc.ExpectData) { - t.Errorf("%d: expected\n\t%#v\ngot\n\t%#v", i, tc.ExpectData, data) - continue - } + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + data, err := decodeJWT(tc.JWT) + if tc.ExpectErr != (err != nil) { + t.Fatalf("expected error %v, got %v", tc.ExpectErr, err) + } + if !reflect.DeepEqual(data, tc.ExpectData) { + t.Fatalf("expected\n\t%#v\ngot\n\t%#v", tc.ExpectData, data) + } + }) } }