From 097e99d5247727502ebd89d28d616cabd3280a46 Mon Sep 17 00:00:00 2001 From: Francis PEROT Date: Mon, 2 Mar 2020 16:19:15 +0100 Subject: [PATCH] [CLOUDTRUST-2375] Prevent access to KYC API when features are disabled --- cmd/keycloakb/keycloak_bridge.go | 18 +++++-- internal/keycloakb/kcauthclient.go | 24 ++++++++++ internal/keycloakb/kcauthclient_test.go | 63 ++++++++++++++++++------- 3 files changed, 82 insertions(+), 23 deletions(-) diff --git a/cmd/keycloakb/keycloak_bridge.go b/cmd/keycloakb/keycloak_bridge.go index 12585a8e..44481a90 100644 --- a/cmd/keycloakb/keycloak_bridge.go +++ b/cmd/keycloakb/keycloak_bridge.go @@ -680,6 +680,13 @@ func main() { } } + // Tools for endpoint middleware + var idRetriever = keycloakb.NewRealmIDRetriever(keycloakClient) + var configurationReaderDBModule *configuration.ConfigurationReaderDBModule + { + configurationReaderDBModule = configuration.NewConfigurationReaderDBModule(configurationRoDBConn, logger) + } + // Export configuration var exportModule = export.NewModule(keycloakClient, logger) var cfgStorageModue = export.NewConfigStorageModule(eventsDBConn) @@ -847,10 +854,10 @@ func main() { var createShadowUserHandler = configureManagementHandler(keycloakb.ComponentName, ComponentID, idGenerator, keycloakClient, audienceRequired, tracer, logger)(managementEndpoints.CreateShadowUser) // KYC handlers - var kycGetActionsHandler = configureKYCHandler(keycloakb.ComponentName, ComponentID, idGenerator, keycloakClient, audienceRequired, tracer, logger)(kycEndpoints.GetActions) - var kycGetUserHandler = configureKYCHandler(keycloakb.ComponentName, ComponentID, idGenerator, keycloakClient, audienceRequired, tracer, logger)(kycEndpoints.GetUser) - var kycGetUserByUsernameHandler = configureKYCHandler(keycloakb.ComponentName, ComponentID, idGenerator, keycloakClient, audienceRequired, tracer, logger)(kycEndpoints.GetUserByUsername) - var kycValidateUserHandler = configureKYCHandler(keycloakb.ComponentName, ComponentID, idGenerator, keycloakClient, audienceRequired, tracer, logger)(kycEndpoints.ValidateUser) + var kycGetActionsHandler = configureKYCHandler(keycloakb.ComponentName, ComponentID, idGenerator, keycloakClient, audienceRequired, tracer, idRetriever, configurationReaderDBModule, logger)(kycEndpoints.GetActions) + var kycGetUserHandler = configureKYCHandler(keycloakb.ComponentName, ComponentID, idGenerator, keycloakClient, audienceRequired, tracer, idRetriever, configurationReaderDBModule, logger)(kycEndpoints.GetUser) + var kycGetUserByUsernameHandler = configureKYCHandler(keycloakb.ComponentName, ComponentID, idGenerator, keycloakClient, audienceRequired, tracer, idRetriever, configurationReaderDBModule, logger)(kycEndpoints.GetUserByUsername) + var kycValidateUserHandler = configureKYCHandler(keycloakb.ComponentName, ComponentID, idGenerator, keycloakClient, audienceRequired, tracer, idRetriever, configurationReaderDBModule, logger)(kycEndpoints.ValidateUser) // actions managementSubroute.Path("/actions").Methods("GET").Handler(getManagementActionsHandler) @@ -1276,10 +1283,11 @@ func configureAccountHandler(ComponentName string, ComponentID string, idGenerat } } -func configureKYCHandler(ComponentName string, ComponentID string, idGenerator idgenerator.IDGenerator, keycloakClient *keycloak.Client, audienceRequired string, tracer tracing.OpentracingClient, logger log.Logger) func(endpoint endpoint.Endpoint) http.Handler { +func configureKYCHandler(ComponentName string, ComponentID string, idGenerator idgenerator.IDGenerator, keycloakClient *keycloak.Client, audienceRequired string, tracer tracing.OpentracingClient, idRetriever middleware.IDRetriever, configReader middleware.AdminConfigurationRetriever, logger log.Logger) func(endpoint endpoint.Endpoint) http.Handler { return func(endpoint endpoint.Endpoint) http.Handler { var handler http.Handler handler = kyc.MakeKYCHandler(endpoint, logger) + handler = middleware.MakeEndpointAvailableCheckMW(configuration.CheckKeyPhysical, idRetriever, configReader, logger)(handler) handler = middleware.MakeHTTPCorrelationIDMW(idGenerator, tracer, logger, ComponentName, ComponentID)(handler) handler = middleware.MakeHTTPOIDCTokenValidationMW(keycloakClient, audienceRequired, logger)(handler) return handler diff --git a/internal/keycloakb/kcauthclient.go b/internal/keycloakb/kcauthclient.go index 5aebb6ac..9f187ebf 100644 --- a/internal/keycloakb/kcauthclient.go +++ b/internal/keycloakb/kcauthclient.go @@ -3,6 +3,7 @@ package keycloakb import ( "context" + "github.com/cloudtrust/common-service/middleware" "github.com/cloudtrust/common-service/security" kc "github.com/cloudtrust/keycloak-client" ) @@ -11,6 +12,7 @@ import ( type KeycloakClient interface { GetGroupsOfUser(accessToken string, realmName, userID string) ([]kc.GroupRepresentation, error) GetGroup(accessToken string, realmName, groupID string) (kc.GroupRepresentation, error) + GetRealm(accessToken string, realmName string) (kc.RealmRepresentation, error) } type kcAuthClient struct { @@ -18,6 +20,10 @@ type kcAuthClient struct { logger Logger } +type idretriever struct { + kcClient KeycloakClient +} + // NewKeycloakAuthClient creates an adaptor for Authorization management to access Keycloak func NewKeycloakAuthClient(client KeycloakClient, logger Logger) security.KeycloakClient { return &kcAuthClient{ @@ -59,3 +65,21 @@ func (k *kcAuthClient) GetGroupName(ctx context.Context, accessToken string, rea return *(grp.Name), nil } + +// NewRealmIDRetriever is a tool use to convert a realm name in a realm ID +func NewRealmIDRetriever(kcClient KeycloakClient) middleware.IDRetriever { + return &idretriever{ + kcClient: kcClient, + } +} + +func (ir *idretriever) GetID(accessToken, name string) (string, error) { + var realm, err = ir.kcClient.GetRealm(accessToken, name) + if err != nil { + return "", err + } + if realm.Id == nil { + return "", nil + } + return *realm.Id, nil +} diff --git a/internal/keycloakb/kcauthclient_test.go b/internal/keycloakb/kcauthclient_test.go index 4f45f70e..2f78aa0e 100644 --- a/internal/keycloakb/kcauthclient_test.go +++ b/internal/keycloakb/kcauthclient_test.go @@ -62,29 +62,56 @@ func TestGetGroupNamesOfUserSuccess(t *testing.T) { }) } -func TestGetGroupNameError(t *testing.T) { - testKeycloakAuthClient(t, func(t *testing.T, mockKeycloak *mock.KeycloakClient, authClient security.KeycloakClient) { - mockKeycloak.EXPECT().GetGroup(accessToken, realm, groupID).Return(kc.GroupRepresentation{}, errors.New("error")).Times(1) - _, err := authClient.GetGroupName(context.TODO(), accessToken, realm, groupID) - assert.NotNil(t, err) +func TestGetGroupName(t *testing.T) { + t.Run("Error", func(t *testing.T) { + testKeycloakAuthClient(t, func(t *testing.T, mockKeycloak *mock.KeycloakClient, authClient security.KeycloakClient) { + mockKeycloak.EXPECT().GetGroup(accessToken, realm, groupID).Return(kc.GroupRepresentation{}, errors.New("error")).Times(1) + _, err := authClient.GetGroupName(context.TODO(), accessToken, realm, groupID) + assert.NotNil(t, err) + }) + }) + t.Run("Nil name", func(t *testing.T) { + testKeycloakAuthClient(t, func(t *testing.T, mockKeycloak *mock.KeycloakClient, authClient security.KeycloakClient) { + mockKeycloak.EXPECT().GetGroup(accessToken, realm, groupID).Return(kc.GroupRepresentation{Name: nil}, nil).Times(1) + res, err := authClient.GetGroupName(context.TODO(), accessToken, realm, groupID) + assert.Nil(t, err) + assert.Equal(t, "", res) + }) + }) + t.Run("Success", func(t *testing.T) { + testKeycloakAuthClient(t, func(t *testing.T, mockKeycloak *mock.KeycloakClient, authClient security.KeycloakClient) { + var groupname = "the name" + mockKeycloak.EXPECT().GetGroup(accessToken, realm, groupID).Return(kc.GroupRepresentation{Name: &groupname}, nil).Times(1) + res, err := authClient.GetGroupName(context.TODO(), accessToken, realm, groupID) + assert.Nil(t, err) + assert.Equal(t, groupname, res) + }) }) } -func TestGetGroupNameNilName(t *testing.T) { - testKeycloakAuthClient(t, func(t *testing.T, mockKeycloak *mock.KeycloakClient, authClient security.KeycloakClient) { - mockKeycloak.EXPECT().GetGroup(accessToken, realm, groupID).Return(kc.GroupRepresentation{Name: nil}, nil).Times(1) - res, err := authClient.GetGroupName(context.TODO(), accessToken, realm, groupID) +func TestGetID(t *testing.T) { + var mockCtrl = gomock.NewController(t) + defer mockCtrl.Finish() + + var mockKeycloak = mock.NewKeycloakClient(mockCtrl) + var idRetriever = NewRealmIDRetriever(mockKeycloak) + + t.Run("Error", func(t *testing.T) { + mockKeycloak.EXPECT().GetRealm(accessToken, realm).Return(kc.RealmRepresentation{}, errors.New("error")) + _, err := idRetriever.GetID(accessToken, realm) + assert.NotNil(t, err) + }) + t.Run("Nil name", func(t *testing.T) { + mockKeycloak.EXPECT().GetRealm(accessToken, realm).Return(kc.RealmRepresentation{}, nil) + id, err := idRetriever.GetID(accessToken, realm) assert.Nil(t, err) - assert.Equal(t, "", res) + assert.Equal(t, "", id) }) -} - -func TestGetGroupNameSuccess(t *testing.T) { - testKeycloakAuthClient(t, func(t *testing.T, mockKeycloak *mock.KeycloakClient, authClient security.KeycloakClient) { - var groupname = "the name" - mockKeycloak.EXPECT().GetGroup(accessToken, realm, groupID).Return(kc.GroupRepresentation{Name: &groupname}, nil).Times(1) - res, err := authClient.GetGroupName(context.TODO(), accessToken, realm, groupID) + t.Run("Success", func(t *testing.T) { + var id = "the-realm-identifier" + mockKeycloak.EXPECT().GetRealm(accessToken, realm).Return(kc.RealmRepresentation{Id: &id}, nil) + res, err := idRetriever.GetID(accessToken, realm) assert.Nil(t, err) - assert.Equal(t, groupname, res) + assert.Equal(t, id, res) }) }