Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CLOUDTRUST-2375] Prevent access to KYC API when features are disabled #192

Merged
merged 1 commit into from
Mar 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 13 additions & 5 deletions cmd/keycloakb/keycloak_bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,13 @@ func main() {
}
}

// Tools for endpoint middleware
var idRetriever = keycloakb.NewRealmIDRetriever(keycloakClient)
var configurationReaderDBModule *configuration.ConfigurationReaderDBModule
{
configurationReaderDBModule = configuration.NewConfigurationReaderDBModule(configurationRoDBConn, logger)
}

// Export configuration
var exportModule = export.NewModule(keycloakClient, logger)
var cfgStorageModue = export.NewConfigStorageModule(eventsDBConn)
Expand Down Expand Up @@ -847,10 +854,10 @@ func main() {
var createShadowUserHandler = configureManagementHandler(keycloakb.ComponentName, ComponentID, idGenerator, keycloakClient, audienceRequired, tracer, logger)(managementEndpoints.CreateShadowUser)

// KYC handlers
var kycGetActionsHandler = configureKYCHandler(keycloakb.ComponentName, ComponentID, idGenerator, keycloakClient, audienceRequired, tracer, logger)(kycEndpoints.GetActions)
var kycGetUserHandler = configureKYCHandler(keycloakb.ComponentName, ComponentID, idGenerator, keycloakClient, audienceRequired, tracer, logger)(kycEndpoints.GetUser)
var kycGetUserByUsernameHandler = configureKYCHandler(keycloakb.ComponentName, ComponentID, idGenerator, keycloakClient, audienceRequired, tracer, logger)(kycEndpoints.GetUserByUsername)
var kycValidateUserHandler = configureKYCHandler(keycloakb.ComponentName, ComponentID, idGenerator, keycloakClient, audienceRequired, tracer, logger)(kycEndpoints.ValidateUser)
var kycGetActionsHandler = configureKYCHandler(keycloakb.ComponentName, ComponentID, idGenerator, keycloakClient, audienceRequired, tracer, idRetriever, configurationReaderDBModule, logger)(kycEndpoints.GetActions)
var kycGetUserHandler = configureKYCHandler(keycloakb.ComponentName, ComponentID, idGenerator, keycloakClient, audienceRequired, tracer, idRetriever, configurationReaderDBModule, logger)(kycEndpoints.GetUser)
var kycGetUserByUsernameHandler = configureKYCHandler(keycloakb.ComponentName, ComponentID, idGenerator, keycloakClient, audienceRequired, tracer, idRetriever, configurationReaderDBModule, logger)(kycEndpoints.GetUserByUsername)
var kycValidateUserHandler = configureKYCHandler(keycloakb.ComponentName, ComponentID, idGenerator, keycloakClient, audienceRequired, tracer, idRetriever, configurationReaderDBModule, logger)(kycEndpoints.ValidateUser)

// actions
managementSubroute.Path("/actions").Methods("GET").Handler(getManagementActionsHandler)
Expand Down Expand Up @@ -1276,10 +1283,11 @@ func configureAccountHandler(ComponentName string, ComponentID string, idGenerat
}
}

func configureKYCHandler(ComponentName string, ComponentID string, idGenerator idgenerator.IDGenerator, keycloakClient *keycloak.Client, audienceRequired string, tracer tracing.OpentracingClient, logger log.Logger) func(endpoint endpoint.Endpoint) http.Handler {
func configureKYCHandler(ComponentName string, ComponentID string, idGenerator idgenerator.IDGenerator, keycloakClient *keycloak.Client, audienceRequired string, tracer tracing.OpentracingClient, idRetriever middleware.IDRetriever, configReader middleware.AdminConfigurationRetriever, logger log.Logger) func(endpoint endpoint.Endpoint) http.Handler {
return func(endpoint endpoint.Endpoint) http.Handler {
var handler http.Handler
handler = kyc.MakeKYCHandler(endpoint, logger)
handler = middleware.MakeEndpointAvailableCheckMW(configuration.CheckKeyPhysical, idRetriever, configReader, logger)(handler)
handler = middleware.MakeHTTPCorrelationIDMW(idGenerator, tracer, logger, ComponentName, ComponentID)(handler)
handler = middleware.MakeHTTPOIDCTokenValidationMW(keycloakClient, audienceRequired, logger)(handler)
return handler
Expand Down
24 changes: 24 additions & 0 deletions internal/keycloakb/kcauthclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package keycloakb
import (
"context"

"github.com/cloudtrust/common-service/middleware"
"github.com/cloudtrust/common-service/security"
kc "github.com/cloudtrust/keycloak-client"
)
Expand All @@ -11,13 +12,18 @@ import (
type KeycloakClient interface {
GetGroupsOfUser(accessToken string, realmName, userID string) ([]kc.GroupRepresentation, error)
GetGroup(accessToken string, realmName, groupID string) (kc.GroupRepresentation, error)
GetRealm(accessToken string, realmName string) (kc.RealmRepresentation, error)
}

type kcAuthClient struct {
keycloak KeycloakClient
logger Logger
}

type idretriever struct {
kcClient KeycloakClient
}

// NewKeycloakAuthClient creates an adaptor for Authorization management to access Keycloak
func NewKeycloakAuthClient(client KeycloakClient, logger Logger) security.KeycloakClient {
return &kcAuthClient{
Expand Down Expand Up @@ -59,3 +65,21 @@ func (k *kcAuthClient) GetGroupName(ctx context.Context, accessToken string, rea

return *(grp.Name), nil
}

// NewRealmIDRetriever is a tool use to convert a realm name in a realm ID
func NewRealmIDRetriever(kcClient KeycloakClient) middleware.IDRetriever {
return &idretriever{
kcClient: kcClient,
}
}

func (ir *idretriever) GetID(accessToken, name string) (string, error) {
var realm, err = ir.kcClient.GetRealm(accessToken, name)
if err != nil {
return "", err
}
if realm.Id == nil {
return "", nil
}
return *realm.Id, nil
}
63 changes: 45 additions & 18 deletions internal/keycloakb/kcauthclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,29 +62,56 @@ func TestGetGroupNamesOfUserSuccess(t *testing.T) {
})
}

func TestGetGroupNameError(t *testing.T) {
testKeycloakAuthClient(t, func(t *testing.T, mockKeycloak *mock.KeycloakClient, authClient security.KeycloakClient) {
mockKeycloak.EXPECT().GetGroup(accessToken, realm, groupID).Return(kc.GroupRepresentation{}, errors.New("error")).Times(1)
_, err := authClient.GetGroupName(context.TODO(), accessToken, realm, groupID)
assert.NotNil(t, err)
func TestGetGroupName(t *testing.T) {
t.Run("Error", func(t *testing.T) {
testKeycloakAuthClient(t, func(t *testing.T, mockKeycloak *mock.KeycloakClient, authClient security.KeycloakClient) {
mockKeycloak.EXPECT().GetGroup(accessToken, realm, groupID).Return(kc.GroupRepresentation{}, errors.New("error")).Times(1)
_, err := authClient.GetGroupName(context.TODO(), accessToken, realm, groupID)
assert.NotNil(t, err)
})
})
t.Run("Nil name", func(t *testing.T) {
testKeycloakAuthClient(t, func(t *testing.T, mockKeycloak *mock.KeycloakClient, authClient security.KeycloakClient) {
mockKeycloak.EXPECT().GetGroup(accessToken, realm, groupID).Return(kc.GroupRepresentation{Name: nil}, nil).Times(1)
res, err := authClient.GetGroupName(context.TODO(), accessToken, realm, groupID)
assert.Nil(t, err)
assert.Equal(t, "", res)
})
})
t.Run("Success", func(t *testing.T) {
testKeycloakAuthClient(t, func(t *testing.T, mockKeycloak *mock.KeycloakClient, authClient security.KeycloakClient) {
var groupname = "the name"
mockKeycloak.EXPECT().GetGroup(accessToken, realm, groupID).Return(kc.GroupRepresentation{Name: &groupname}, nil).Times(1)
res, err := authClient.GetGroupName(context.TODO(), accessToken, realm, groupID)
assert.Nil(t, err)
assert.Equal(t, groupname, res)
})
})
}

func TestGetGroupNameNilName(t *testing.T) {
testKeycloakAuthClient(t, func(t *testing.T, mockKeycloak *mock.KeycloakClient, authClient security.KeycloakClient) {
mockKeycloak.EXPECT().GetGroup(accessToken, realm, groupID).Return(kc.GroupRepresentation{Name: nil}, nil).Times(1)
res, err := authClient.GetGroupName(context.TODO(), accessToken, realm, groupID)
func TestGetID(t *testing.T) {
var mockCtrl = gomock.NewController(t)
defer mockCtrl.Finish()

var mockKeycloak = mock.NewKeycloakClient(mockCtrl)
var idRetriever = NewRealmIDRetriever(mockKeycloak)

t.Run("Error", func(t *testing.T) {
mockKeycloak.EXPECT().GetRealm(accessToken, realm).Return(kc.RealmRepresentation{}, errors.New("error"))
_, err := idRetriever.GetID(accessToken, realm)
assert.NotNil(t, err)
})
t.Run("Nil name", func(t *testing.T) {
mockKeycloak.EXPECT().GetRealm(accessToken, realm).Return(kc.RealmRepresentation{}, nil)
id, err := idRetriever.GetID(accessToken, realm)
assert.Nil(t, err)
assert.Equal(t, "", res)
assert.Equal(t, "", id)
})
}

func TestGetGroupNameSuccess(t *testing.T) {
testKeycloakAuthClient(t, func(t *testing.T, mockKeycloak *mock.KeycloakClient, authClient security.KeycloakClient) {
var groupname = "the name"
mockKeycloak.EXPECT().GetGroup(accessToken, realm, groupID).Return(kc.GroupRepresentation{Name: &groupname}, nil).Times(1)
res, err := authClient.GetGroupName(context.TODO(), accessToken, realm, groupID)
t.Run("Success", func(t *testing.T) {
var id = "the-realm-identifier"
mockKeycloak.EXPECT().GetRealm(accessToken, realm).Return(kc.RealmRepresentation{Id: &id}, nil)
res, err := idRetriever.GetID(accessToken, realm)
assert.Nil(t, err)
assert.Equal(t, groupname, res)
assert.Equal(t, id, res)
})
}