Skip to content

Commit

Permalink
Fix oauthServiceAccount client interface to include contexts
Browse files Browse the repository at this point in the history
  • Loading branch information
stlaz authored and damemi committed Mar 20, 2020
1 parent da5c711 commit 5807a73
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
22 changes: 11 additions & 11 deletions pkg/oauth/oauthserviceaccountclient/oauthclientregistry.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,15 @@ func init() {
// Based on the namespace and names provided, it builds a map of resource name to redirect URIs.
// The redirect URIs represent the default values as specified by the resource.
// These values can be overridden by user specified data. Errors returned are informative and non-fatal.
type namesToObjMapperFunc func(namespace string, names sets.String) (map[string]redirectURIList, []error)
type namesToObjMapperFunc func(ctx context.Context, namespace string, names sets.String) (map[string]redirectURIList, []error)

// TODO add ingress support
// var ingressGroupKind = routeapi.SchemeGroupVersion.WithKind(IngressKind).GroupKind()

// OAuthClientGetter exposes a way to get a specific client. This is useful for other registries to get scope limitations
// on particular clients. This interface will make its easier to write a future cache on it
type OAuthClientGetter interface {
Get(name string, options metav1.GetOptions) (*oauthv1.OAuthClient, error)
Get(ctx context.Context, name string, options metav1.GetOptions) (*oauthv1.OAuthClient, error)
}

type saOAuthClientAdapter struct {
Expand Down Expand Up @@ -228,14 +228,14 @@ func NewServiceAccountOAuthClientGetter(
}
}

func (a *saOAuthClientAdapter) Get(name string, options metav1.GetOptions) (*oauthv1.OAuthClient, error) {
func (a *saOAuthClientAdapter) Get(ctx context.Context, name string, options metav1.GetOptions) (*oauthv1.OAuthClient, error) {
var err error
saNamespace, saName, err := apiserverserviceaccount.SplitUsername(name)
if err != nil {
return a.delegate.Get(name, options)
return a.delegate.Get(ctx, name, options)
}

sa, err := a.saClient.ServiceAccounts(saNamespace).Get(context.TODO(), saName, metav1.GetOptions{})
sa, err := a.saClient.ServiceAccounts(saNamespace).Get(ctx, saName, metav1.GetOptions{})
if err != nil {
return nil, err
}
Expand All @@ -256,7 +256,7 @@ func (a *saOAuthClientAdapter) Get(name string, options metav1.GetOptions) (*oau
}

if len(modelsMap) > 0 {
uris, extractErrors := a.extractRedirectURIs(modelsMap, saNamespace)
uris, extractErrors := a.extractRedirectURIs(ctx, modelsMap, saNamespace)
if len(uris) > 0 {
redirectURIs = append(redirectURIs, uris.extractValidRedirectURIStrings()...)
}
Expand Down Expand Up @@ -347,7 +347,7 @@ func parseModelPrefixName(key string) (string, string, bool) {

// extractRedirectURIs builds redirect URIs using the given models and namespace.
// The returned redirect URIs may contain duplicates and invalid entries. Errors returned are informative and non-fatal.
func (a *saOAuthClientAdapter) extractRedirectURIs(modelsMap map[string]model, namespace string) (redirectURIList, []error) {
func (a *saOAuthClientAdapter) extractRedirectURIs(ctx context.Context, modelsMap map[string]model, namespace string) (redirectURIList, []error) {
var data redirectURIList
routeErrors := []error{}
groupKindModelListMapper := map[schema.GroupKind]modelList{} // map of GroupKind to all models belonging to it
Expand All @@ -373,7 +373,7 @@ func (a *saOAuthClientAdapter) extractRedirectURIs(modelsMap map[string]model, n

for gk, models := range groupKindModelListMapper {
if names := models.getNames(); names.Len() > 0 {
objMapper, errs := groupKindModelToURI[gk](namespace, names)
objMapper, errs := groupKindModelToURI[gk](ctx, namespace, names)
if len(objMapper) > 0 {
data = append(data, models.getRedirectURIs(objMapper)...)
}
Expand All @@ -389,18 +389,18 @@ func (a *saOAuthClientAdapter) extractRedirectURIs(modelsMap map[string]model, n
// redirectURIsFromRoutes is the namesToObjMapperFunc specific to Routes.
// Returns a map of route name to redirect URIs that contain the default data as specified by the route's ingresses.
// Errors returned are informative and non-fatal.
func (a *saOAuthClientAdapter) redirectURIsFromRoutes(namespace string, osRouteNames sets.String) (map[string]redirectURIList, []error) {
func (a *saOAuthClientAdapter) redirectURIsFromRoutes(ctx context.Context, namespace string, osRouteNames sets.String) (map[string]redirectURIList, []error) {
var routes []routev1.Route
routeErrors := []error{}
routeInterface := a.routeClient.Routes(namespace)
if osRouteNames.Len() > 1 {
if r, err := routeInterface.List(context.TODO(), metav1.ListOptions{}); err == nil {
if r, err := routeInterface.List(ctx, metav1.ListOptions{}); err == nil {
routes = r.Items
} else {
routeErrors = append(routeErrors, err)
}
} else {
if r, err := routeInterface.Get(context.TODO(), osRouteNames.List()[0], metav1.GetOptions{}); err == nil {
if r, err := routeInterface.Get(ctx, osRouteNames.List()[0], metav1.GetOptions{}); err == nil {
routes = append(routes, *r)
} else {
routeErrors = append(routeErrors, err)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package oauthserviceaccountclient

import (
"context"
"reflect"
"strings"
"testing"
Expand Down Expand Up @@ -578,7 +579,7 @@ func TestGetClient(t *testing.T) {
grantMethod: oauthv1.GrantHandlerPrompt,
decoder: codecFactory.UniversalDecoder(),
}
client, err := getter.Get(tc.clientName, metav1.GetOptions{})
client, err := getter.Get(context.TODO(), tc.clientName, metav1.GetOptions{})
switch {
case len(tc.expectedErr) == 0 && err == nil:
case len(tc.expectedErr) == 0 && err != nil,
Expand Down Expand Up @@ -625,7 +626,7 @@ type fakeDelegate struct {
called bool
}

func (d *fakeDelegate) Get(name string, options metav1.GetOptions) (*oauthv1.OAuthClient, error) {
func (d *fakeDelegate) Get(ctx context.Context, name string, options metav1.GetOptions) (*oauthv1.OAuthClient, error) {
d.called = true
return nil, nil
}
Expand Down Expand Up @@ -1049,7 +1050,7 @@ func TestGetRedirectURIs(t *testing.T) {
},
} {
a := buildRouteClient(test.routes)
uris, errs := a.redirectURIsFromRoutes(test.namespace, test.models.getNames())
uris, errs := a.redirectURIsFromRoutes(context.TODO(), test.namespace, test.models.getNames())
if len(errs) > 0 {
t.Errorf("%s: unexpected redirectURIsFromRoutes errors %v", test.name, errs)
}
Expand Down Expand Up @@ -1220,7 +1221,7 @@ func TestRedirectURIsFromRoutes(t *testing.T) {
},
} {
a := buildRouteClient(test.routes)
uris, errs := a.redirectURIsFromRoutes(test.namespace, test.names)
uris, errs := a.redirectURIsFromRoutes(context.TODO(), test.namespace, test.names)
if len(errs) > 0 {
t.Errorf("%s: unexpected redirectURIsFromRoutes errors %v", test.name, errs)
}
Expand Down

0 comments on commit 5807a73

Please sign in to comment.