diff --git a/Gopkg.lock b/Gopkg.lock index aaff75097..cb97f915d 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -19,11 +19,11 @@ [[projects]] branch = "master" - digest = "1:2276e1aa87fc57d4af1091b22aad776635973a1e0764c91c88e1a28edad426ef" + digest = "1:add9addd1e9f693f228f2818f3a08df0df2d3ec7c453614c8a4d113e2cc2fbd0" name = "github.com/cloudtrust/keycloak-client" packages = ["."] pruneopts = "" - revision = "0d5f98aceb4c163562b8474fb3665e0e2c400ce5" + revision = "cf331b0198778ec8dea897dddf99a00720f5f273" source = "github.com/cloudtrust/keycloak-client" [[projects]] @@ -123,7 +123,7 @@ version = "v1.3.1" [[projects]] - digest = "1:da2d72f6fdcdcebe27602c9c516636f654eaae67ae09e8a9be10820a4ba42911" + digest = "1:73b360c5b463f59fdc570e9188efc398b63bc7c20a4498a4f076eaa8574adef6" name = "github.com/google/flatbuffers" packages = ["go"] pruneopts = "" @@ -158,7 +158,7 @@ version = "v1.0.0" [[projects]] - digest = "1:0e6a4d206be26596e3ad0127ec2da0f2952169768f33d7010517ba6f0fafd47c" + digest = "1:17f258c3c2cd12980479112a20afb3024019d5b62a670878ff84742fec621115" name = "github.com/influxdata/influxdb" packages = [ "client/v2", @@ -347,7 +347,7 @@ [[projects]] branch = "master" - digest = "1:958ad9932fc5ac9fb5c794f97580ed123ddfed1d965e1de0f98e2a590d6e9e3e" + digest = "1:ed1a3a3847549ea4be3720fd005148e57702e953f83aa337bf1776eb54d2b910" name = "golang.org/x/crypto" packages = [ "ed25519", @@ -355,7 +355,7 @@ "pbkdf2", ] pruneopts = "" - revision = "88737f569e3a9c7ab309cdc09a07fe7fc87233c3" + revision = "f416ebab96af27ca70b6e5c23d6a0747530da626" [[projects]] branch = "master" diff --git a/api/management/api.go b/api/management/api.go index 025e1b750..307e44c63 100644 --- a/api/management/api.go +++ b/api/management/api.go @@ -58,6 +58,11 @@ type PasswordRepresentation struct { Value *string `json:"value,omitempty"` } +type RealmCustomConfiguration struct { + DefaultClientId *string `json:"default_client_id,omitempty"` + DefaultRedirectUri *string `json:"default_redirect_uri,omitempty"` +} + // ConvertCredential creates an API credential from a KC credential func ConvertCredential(credKc *kc.CredentialRepresentation) CredentialRepresentation { var cred CredentialRepresentation diff --git a/cmd/keycloakb/keycloak_bridge.go b/cmd/keycloakb/keycloak_bridge.go index 179409482..501593fee 100644 --- a/cmd/keycloakb/keycloak_bridge.go +++ b/cmd/keycloakb/keycloak_bridge.go @@ -97,6 +97,7 @@ func main() { // Enabled units eventsDBEnabled = c.GetBool("events-db") + configDBEnabled = c.GetBool("config-db") influxEnabled = c.GetBool("influx") jaegerEnabled = c.GetBool("jaeger") sentryEnabled = c.GetBool("sentry") @@ -143,6 +144,11 @@ func main() { dbMaxIdleConns = c.GetInt("db-max-idle-conns") dbConnMaxLifetime = c.GetInt("db-conn-max-lifetime") + // DB for custom configuration + dbConfigUsername = c.GetString("db-config-username") + dbConfigPassword = c.GetString("db-config-password") + dbConfigDatabase = c.GetString("db-config-database") + // Rate limiting rateLimit = map[string]int{ "event": c.GetInt("rate-event"), @@ -260,7 +266,7 @@ func main() { } // Audit events DB. - type EventsDB interface { + type CloudtrustDB interface { Exec(query string, args ...interface{}) (sql.Result, error) QueryRow(query string, args ...interface{}) *sql.Row SetMaxOpenConns(n int) @@ -268,7 +274,7 @@ func main() { SetConnMaxLifetime(d time.Duration) } - var eventsDBConn EventsDB = keycloakb.NoopEventsDB{} + var eventsDBConn CloudtrustDB = keycloakb.NoopDB{} if eventsDBEnabled { var err error eventsDBConn, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s(%s)/%s", dbUsername, dbPassword, dbProtocol, dbHostPort, dbDatabase)) @@ -284,6 +290,21 @@ func main() { } + var configurationDBConn CloudtrustDB = keycloakb.NoopDB{} + if configDBEnabled { + var err error + configurationDBConn, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s(%s)/%s", dbConfigUsername, dbConfigPassword, dbProtocol, dbHostPort, dbConfigDatabase)) + + if err != nil { + logger.Log("msg", "could not create DB connection for configuration storage", "error", err) + return + } + // the config of the DB should have a max_connections > SetMaxOpenConns + configurationDBConn.SetMaxOpenConns(dbMaxOpenConns) + configurationDBConn.SetMaxIdleConns(dbMaxIdleConns) + configurationDBConn.SetConnMaxLifetime(time.Duration(dbConnMaxLifetime) * time.Second) + } + // Event service. var eventEndpoints = event.Endpoints{} { @@ -312,7 +333,6 @@ func main() { eventsDBModule = event.MakeEventsDBModuleInstrumentingMW(influxMetrics.NewHistogram("eventsDB_module"))(eventsDBModule) eventsDBModule = event.MakeEventsDBModuleLoggingMW(log.With(eventLogger, "mw", "module", "unit", "eventsDB"))(eventsDBModule) eventsDBModule = event.MakeEventsDBModuleTracingMW(tracer)(eventsDBModule) - } var eventAdminComponent event.AdminComponent @@ -372,37 +392,47 @@ func main() { eventsDBModule = event.MakeEventsDBModuleInstrumentingMW(influxMetrics.NewHistogram("eventsDB_module"))(eventsDBModule) eventsDBModule = event.MakeEventsDBModuleLoggingMW(log.With(managementLogger, "mw", "module", "unit", "eventsDB"))(eventsDBModule) eventsDBModule = event.MakeEventsDBModuleTracingMW(tracer)(eventsDBModule) - + } + + // module for storing and retrieving the custom configuration + var configDBModule management.ConfigurationDBModule + { + configDBModule = management.NewConfigurationDBModule(configurationDBConn) + configDBModule = management.MakeConfigurationDBModuleInstrumentingMW(influxMetrics.NewHistogram("configDB_module"))(configDBModule) + configDBModule = management.MakeConfigurationDBModuleLoggingMW(log.With(managementLogger, "mw", "module", "unit", "configDB"))(configDBModule) + configDBModule = management.MakeConfigurationDBModuleTracingMW(tracer)(configDBModule) } var keycloakComponent management.Component { - keycloakComponent = management.NewComponent(keycloakClient, eventsDBModule) + keycloakComponent = management.NewComponent(keycloakClient, eventsDBModule, configDBModule) keycloakComponent = management.MakeAuthorizationManagementComponentMW(log.With(managementLogger, "mw", "endpoint"), keycloakClient, authorizationManager)(keycloakComponent) } managementEndpoints = management.Endpoints{ - GetRealms: prepareEndpoint(management.MakeGetRealmsEndpoint(keycloakComponent), "realms_endpoint", influxMetrics, managementLogger, tracer, rateLimit), - GetRealm: prepareEndpoint(management.MakeGetRealmEndpoint(keycloakComponent), "realm_endpoint", influxMetrics, managementLogger, tracer, rateLimit), - GetClients: prepareEndpoint(management.MakeGetClientsEndpoint(keycloakComponent), "get_clients_endpoint", influxMetrics, managementLogger, tracer, rateLimit), - GetClient: prepareEndpoint(management.MakeGetClientEndpoint(keycloakComponent), "get_client_endpoint", influxMetrics, managementLogger, tracer, rateLimit), - CreateUser: prepareEndpoint(management.MakeCreateUserEndpoint(keycloakComponent), "create_user_endpoint", influxMetrics, managementLogger, tracer, rateLimit), - GetUser: prepareEndpoint(management.MakeGetUserEndpoint(keycloakComponent), "get_user_endpoint", influxMetrics, managementLogger, tracer, rateLimit), - UpdateUser: prepareEndpoint(management.MakeUpdateUserEndpoint(keycloakComponent), "update_user_endpoint", influxMetrics, managementLogger, tracer, rateLimit), - DeleteUser: prepareEndpoint(management.MakeDeleteUserEndpoint(keycloakComponent), "delete_user_endpoint", influxMetrics, managementLogger, tracer, rateLimit), - GetUsers: prepareEndpoint(management.MakeGetUsersEndpoint(keycloakComponent), "get_users_endpoint", influxMetrics, managementLogger, tracer, rateLimit), - GetUserAccountStatus: prepareEndpoint(management.MakeGetUserAccountStatusEndpoint(keycloakComponent), "get_user_accountstatus", influxMetrics, managementLogger, tracer, rateLimit), - GetRoles: prepareEndpoint(management.MakeGetRolesEndpoint(keycloakComponent), "get_roles_endpoint", influxMetrics, managementLogger, tracer, rateLimit), - GetRole: prepareEndpoint(management.MakeGetRoleEndpoint(keycloakComponent), "get_role_endpoint", influxMetrics, managementLogger, tracer, rateLimit), - GetClientRoles: prepareEndpoint(management.MakeGetClientRolesEndpoint(keycloakComponent), "get_client_roles_endpoint", influxMetrics, managementLogger, tracer, rateLimit), - CreateClientRole: prepareEndpoint(management.MakeCreateClientRoleEndpoint(keycloakComponent), "create_client_role_endpoint", influxMetrics, managementLogger, tracer, rateLimit), - GetClientRoleForUser: prepareEndpoint(management.MakeGetClientRolesForUserEndpoint(keycloakComponent), "get_client_roles_for_user_endpoint", influxMetrics, managementLogger, tracer, rateLimit), - AddClientRoleToUser: prepareEndpoint(management.MakeAddClientRolesToUserEndpoint(keycloakComponent), "get_client_roles_for_user_endpoint", influxMetrics, managementLogger, tracer, rateLimit), - GetRealmRoleForUser: prepareEndpoint(management.MakeGetRealmRolesForUserEndpoint(keycloakComponent), "get_realm_roles_for_user_endpoint", influxMetrics, managementLogger, tracer, rateLimit), - ResetPassword: prepareEndpoint(management.MakeResetPasswordEndpoint(keycloakComponent), "reset_password_endpoint", influxMetrics, managementLogger, tracer, rateLimit), - SendVerifyEmail: prepareEndpoint(management.MakeSendVerifyEmailEndpoint(keycloakComponent), "send_verify_email_endpoint", influxMetrics, managementLogger, tracer, rateLimit), - GetCredentialsForUser: prepareEndpoint(management.MakeGetCredentialsForUserEndpoint(keycloakComponent), "get_credentials_for_user_endpoint", influxMetrics, managementLogger, tracer, rateLimit), - DeleteCredentialsForUser: prepareEndpoint(management.MakeDeleteCredentialsForUserEndpoint(keycloakComponent), "delete_credentials_for_user_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + GetRealms: prepareEndpoint(management.MakeGetRealmEndpoint(keycloakComponent), "realms_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + GetRealm: prepareEndpoint(management.MakeGetRealmEndpoint(keycloakComponent), "realm_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + GetClients: prepareEndpoint(management.MakeGetClientsEndpoint(keycloakComponent), "get_clients_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + GetClient: prepareEndpoint(management.MakeGetClientEndpoint(keycloakComponent), "get_client_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + CreateUser: prepareEndpoint(management.MakeCreateUserEndpoint(keycloakComponent), "create_user_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + GetUser: prepareEndpoint(management.MakeGetUserEndpoint(keycloakComponent), "get_user_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + UpdateUser: prepareEndpoint(management.MakeUpdateUserEndpoint(keycloakComponent), "update_user_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + DeleteUser: prepareEndpoint(management.MakeDeleteUserEndpoint(keycloakComponent), "delete_user_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + GetUsers: prepareEndpoint(management.MakeGetUsersEndpoint(keycloakComponent), "get_users_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + GetUserAccountStatus: prepareEndpoint(management.MakeGetUserAccountStatusEndpoint(keycloakComponent), "get_user_accountstatus", influxMetrics, managementLogger, tracer, rateLimit), + GetRoles: prepareEndpoint(management.MakeGetRolesEndpoint(keycloakComponent), "get_roles_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + GetRole: prepareEndpoint(management.MakeGetRoleEndpoint(keycloakComponent), "get_role_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + GetClientRoles: prepareEndpoint(management.MakeGetClientRolesEndpoint(keycloakComponent), "get_client_roles_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + CreateClientRole: prepareEndpoint(management.MakeCreateClientRoleEndpoint(keycloakComponent), "create_client_role_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + GetClientRoleForUser: prepareEndpoint(management.MakeGetClientRolesForUserEndpoint(keycloakComponent), "get_client_roles_for_user_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + AddClientRoleToUser: prepareEndpoint(management.MakeAddClientRolesToUserEndpoint(keycloakComponent), "get_client_roles_for_user_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + GetRealmRoleForUser: prepareEndpoint(management.MakeGetRealmRolesForUserEndpoint(keycloakComponent), "get_realm_roles_for_user_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + ResetPassword: prepareEndpoint(management.MakeResetPasswordEndpoint(keycloakComponent), "reset_password_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + SendVerifyEmail: prepareEndpoint(management.MakeSendVerifyEmailEndpoint(keycloakComponent), "send_verify_email_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + GetCredentialsForUser: prepareEndpoint(management.MakeGetCredentialsForUserEndpoint(keycloakComponent), "get_credentials_for_user_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + DeleteCredentialsForUser: prepareEndpoint(management.MakeDeleteCredentialsForUserEndpoint(keycloakComponent), "delete_credentials_for_user_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + GetRealmCustomConfiguration: prepareEndpoint(management.MakeGetRealmCustomConfigurationEndpoint(keycloakComponent), "get_realm_custom_config_endpoint", influxMetrics, managementLogger, tracer, rateLimit), + UpdateRealmCustomConfiguration: prepareEndpoint(management.MakeUpdateRealmCustomConfigurationEndpoint(keycloakComponent), "update_realm_custom_config_endpoint", influxMetrics, managementLogger, tracer, rateLimit), } } @@ -467,6 +497,9 @@ func main() { var getCredentialsForUserHandler = ConfigureManagementHandler(ComponentName, ComponentID, idGenerator, keycloakClient, tracer, logger)(managementEndpoints.GetCredentialsForUser) var deleteCredentialsForUserHandler = ConfigureManagementHandler(ComponentName, ComponentID, idGenerator, keycloakClient, tracer, logger)(managementEndpoints.DeleteCredentialsForUser) + var getRealmCustomConfigurationHandler = ConfigureManagementHandler(ComponentName, ComponentID, idGenerator, keycloakClient, tracer, logger)(managementEndpoints.GetRealmCustomConfiguration) + var updateRealmCustomConfigurationHandler = ConfigureManagementHandler(ComponentName, ComponentID, idGenerator, keycloakClient, tracer, logger)(managementEndpoints.UpdateRealmCustomConfiguration) + //realms managementSubroute.Path("/realms").Methods("GET").Handler(getRealmsHandler) managementSubroute.Path("/realms/{realm}").Methods("GET").Handler(getRealmHandler) @@ -498,6 +531,10 @@ func main() { managementSubroute.Path("/realms/{realm}/clients/{clientID}/roles").Methods("GET").Handler(getClientRolesHandler) managementSubroute.Path("/realms/{realm}/clients/{clientID}/roles").Methods("POST").Handler(createClientRolesHandler) + // custom configuration par realm + managementSubroute.Path("/realms/{realm}/configuration").Methods("GET").Handler(getRealmCustomConfigurationHandler) + managementSubroute.Path("/realms/{realm}/configuration").Methods("PUT").Handler(updateRealmCustomConfigurationHandler) + // Export. route.Handle("/export", export.MakeHTTPExportHandler(exportEndpoint)).Methods("GET") route.Handle("/export", export.MakeHTTPExportHandler(exportSaveAndExportEndpoint)).Methods("POST") @@ -595,6 +632,12 @@ func config(logger log.Logger) *viper.Viper { v.SetDefault("db-max-idle-conns", 2) v.SetDefault("db-conn-max-lifetime", 3600) + //Storage custom configuration in DB + v.SetDefault("config-db", true) + v.SetDefault("db-config-username", "") + v.SetDefault("db-config-password", "") + v.SetDefault("db-config-database", "") + // Rate limiting (in requests/second) v.SetDefault("rate-event", 1000) v.SetDefault("rate-management", 1000) diff --git a/configs/authorization.json b/configs/authorization.json index c4675e4fc..ada309d9b 100644 --- a/configs/authorization.json +++ b/configs/authorization.json @@ -111,7 +111,17 @@ "DEP": { "*": {} } - } + }, + "GetRealmCustomConfiguration": { + "DEP": { + "*": {} + } + }, + "UpdateRealmCustomConfiguration": { + "DEP": { + "*": {} + } + } }, "integrator_agent":{ "GetRealms": { @@ -151,7 +161,17 @@ "DEP": { "*": {} } - } + }, + "GetRealmCustomConfiguration": { + "DEP": { + "*": {} + } + }, + "UpdateRealmCustomConfiguration": { + "DEP": { + "*": {} + } + } }, "l2_support_manager":{ "GetRealms": { diff --git a/configs/keycloak_bridge.yml b/configs/keycloak_bridge.yml index 187582dac..c8aff3a24 100644 --- a/configs/keycloak_bridge.yml +++ b/configs/keycloak_bridge.yml @@ -37,6 +37,11 @@ db-max-open-conns: 10 db-max-idle-conns: 2 db-conn-max-lifetime: 3600 +# Configuration DB +db-config-username: bridge +db-config-password: bridge-password +db-config-database: cloudtrust_configuration + # audit events events-db: true diff --git a/internal/keycloakb/eventsdb.go b/internal/keycloakb/cloudtrustdb.go similarity index 50% rename from internal/keycloakb/eventsdb.go rename to internal/keycloakb/cloudtrustdb.go index 73bf639e1..d8d9923f0 100644 --- a/internal/keycloakb/eventsdb.go +++ b/internal/keycloakb/cloudtrustdb.go @@ -5,32 +5,32 @@ import ( "time" ) -// NoopEventsDB is a eventsDB client that does nothing. -type NoopEventsDB struct{} +// NoopDB is a database client that does nothing. +type NoopDB struct{} // Exec does nothing. -func (NoopEventsDB) Exec(query string, args ...interface{}) (sql.Result, error) { +func (NoopDB) Exec(query string, args ...interface{}) (sql.Result, error) { return NoopResult{}, nil } // Query does nothing. -func (NoopEventsDB) Query(query string, args ...interface{}) (*sql.Rows, error) { +func (NoopDB) Query(query string, args ...interface{}) (*sql.Rows, error) { return nil, nil } // QueryRow does nothing. -func (NoopEventsDB) QueryRow(query string, args ...interface{}) *sql.Row { +func (NoopDB) QueryRow(query string, args ...interface{}) *sql.Row { return nil } -func (NoopEventsDB) SetMaxOpenConns(n int) { +func (NoopDB) SetMaxOpenConns(n int) { } -func (NoopEventsDB) SetMaxIdleConns(n int) { +func (NoopDB) SetMaxIdleConns(n int) { } -func (NoopEventsDB) SetConnMaxLifetime(d time.Duration) { +func (NoopDB) SetConnMaxLifetime(d time.Duration) { } diff --git a/internal/security/authorization.go b/internal/security/authorization.go index d01fba2fc..1ea49ea78 100644 --- a/internal/security/authorization.go +++ b/internal/security/authorization.go @@ -155,7 +155,7 @@ type AuthorizationManager interface { // // Note: // '*' can be used to express all target realms -// '-' can be used to express all non master realms +// '/' can be used to express all non master realms // '*' can be used to express all target groups are allowed func NewAuthorizationManager(keycloakClient KeycloakClient, jsonAuthz string) (AuthorizationManager, error) { matrix, err := loadAuthorizations(jsonAuthz) diff --git a/pkg/management/authorization.go b/pkg/management/authorization.go index 6b2ef295b..fed5f6049 100644 --- a/pkg/management/authorization.go +++ b/pkg/management/authorization.go @@ -259,3 +259,25 @@ func (c *authorizationComponentMW) CreateClientRole(ctx context.Context, realmNa return c.next.CreateClientRole(ctx, realmName, clientID, role) } + +func (c *authorizationComponentMW) GetRealmCustomConfiguration(ctx context.Context, realmName string) (api.RealmCustomConfiguration, error) { + var action = "GetRealmCustomConfiguration" + var targetRealm = realmName + + if err := c.authManager.CheckAuthorizationOnTargetRealm(ctx, action, targetRealm); err != nil { + return api.RealmCustomConfiguration{}, err + } + + return c.next.GetRealmCustomConfiguration(ctx, realmName) +} + +func (c *authorizationComponentMW) UpdateRealmCustomConfiguration(ctx context.Context, realmName string, customConfig api.RealmCustomConfiguration) error { + var action = "UpdateRealmCustomConfiguration" + var targetRealm = realmName + + if err := c.authManager.CheckAuthorizationOnTargetRealm(ctx, action, targetRealm); err != nil { + return err + } + + return c.next.UpdateRealmCustomConfiguration(ctx, realmName, customConfig) +} \ No newline at end of file diff --git a/pkg/management/authorization_test.go b/pkg/management/authorization_test.go index 9b67cd7c4..c20b6579f 100644 --- a/pkg/management/authorization_test.go +++ b/pkg/management/authorization_test.go @@ -37,6 +37,8 @@ func TestDeny(t *testing.T) { var pass = "P@ssw0rd" + var clientURI = "https://wwww.cloudtrust.io" + mockKeycloakClient.EXPECT().GetUser(accessToken, realmName, userID).Return(kc.UserRepresentation{ Id: &userID, Username: &userUsername, @@ -60,6 +62,11 @@ func TestDeny(t *testing.T) { Value: &pass, } + var customConfig = api.RealmCustomConfiguration{ + DefaultClientId: &clientID, + DefaultRedirectUri: &clientURI, + } + // Nothing allowed { var authorizations, err = security.NewAuthorizationManager(mockKeycloakClient, `{}`) @@ -133,6 +140,12 @@ func TestDeny(t *testing.T) { _, err = authorizationMW.CreateClientRole(ctx, realmName, clientID, role) assert.Equal(t, security.ForbiddenError{}, err) + + _, err = authorizationMW.GetRealmCustomConfiguration(ctx, realmName) + assert.Equal(t, security.ForbiddenError{}, err) + + err = authorizationMW.UpdateRealmCustomConfiguration(ctx, realmName, customConfig) + assert.Equal(t, security.ForbiddenError{}, err) } } @@ -160,6 +173,8 @@ func TestAllowed(t *testing.T) { var pass = "P@ssw0rd" + var clientURI = "https://wwww.cloudtrust.io" + mockKeycloakClient.EXPECT().GetUser(accessToken, realmName, userID).Return(kc.UserRepresentation{ Id: &userID, Username: &userUsername, @@ -183,7 +198,12 @@ func TestAllowed(t *testing.T) { Value: &pass, } - // Nothing allowed + var customConfig = api.RealmCustomConfiguration{ + DefaultClientId: &clientID, + DefaultRedirectUri: &clientURI, + } + + // Anything allowed { var authorizations, err = security.NewAuthorizationManager(mockKeycloakClient, `{"master": { @@ -208,7 +228,9 @@ func TestAllowed(t *testing.T) { "GetRoles": {"*": {"*": {} }}, "GetRole": {"*": {"*": {} }}, "GetClientRoles": {"*": {"*": {} }}, - "CreateClientRole": {"*": {"*": {} }} + "CreateClientRole": {"*": {"*": {} }}, + "GetRealmCustomConfiguration": {"*": {"*": {} }}, + "UpdateRealmCustomConfiguration": {"*": {"*": {} }} } } }`) @@ -303,5 +325,13 @@ func TestAllowed(t *testing.T) { mockManagementComponent.EXPECT().CreateClientRole(ctx, realmName, clientID, role).Return("", nil).Times(1) _, err = authorizationMW.CreateClientRole(ctx, realmName, clientID, role) assert.Nil(t, err) + + mockManagementComponent.EXPECT().GetRealmCustomConfiguration(ctx, realmName).Return(customConfig, nil).Times(1) + _, err = authorizationMW.GetRealmCustomConfiguration(ctx, realmName) + assert.Nil(t, err) + + mockManagementComponent.EXPECT().UpdateRealmCustomConfiguration(ctx, realmName, customConfig).Return(nil).Times(1) + err = authorizationMW.UpdateRealmCustomConfiguration(ctx, realmName, customConfig) + assert.Nil(t, err) } } diff --git a/pkg/management/component.go b/pkg/management/component.go index 6ed776ac2..92a529e30 100644 --- a/pkg/management/component.go +++ b/pkg/management/component.go @@ -2,7 +2,9 @@ package management import ( "context" + "encoding/json" "regexp" + "strings" "time" api "github.com/cloudtrust/keycloak-bridge/api/management" @@ -33,7 +35,7 @@ type KeycloakClient interface { CreateClientRole(accessToken string, realmName, clientID string, role kc.RoleRepresentation) (string, error) } -// Component is the event component interface. +// Component is the management component interface. type Component interface { GetRealms(ctx context.Context) ([]api.RealmRepresentation, error) GetRealm(ctx context.Context, realmName string) (api.RealmRepresentation, error) @@ -56,12 +58,15 @@ type Component interface { GetRole(ctx context.Context, realmName string, roleID string) (api.RoleRepresentation, error) GetClientRoles(ctx context.Context, realmName, idClient string) ([]api.RoleRepresentation, error) CreateClientRole(ctx context.Context, realmName, clientID string, role api.RoleRepresentation) (string, error) + GetRealmCustomConfiguration(ctx context.Context, realmName string) (api.RealmCustomConfiguration, error) + UpdateRealmCustomConfiguration(ctx context.Context, realmID string, customConfig api.RealmCustomConfiguration) error } // Component is the management component. type component struct { keycloakClient KeycloakClient eventDBModule event.EventsDBModule + configDBModule ConfigurationDBModule } const ( @@ -69,10 +74,12 @@ const ( ) // NewComponent returns the management component. -func NewComponent(keycloakClient KeycloakClient, eventDBModule event.EventsDBModule) Component { + +func NewComponent(keycloakClient KeycloakClient, eventDBModule event.EventsDBModule, configDBModule ConfigurationDBModule) Component { return &component{ keycloakClient: keycloakClient, eventDBModule: eventDBModule, + configDBModule: configDBModule, } } @@ -677,3 +684,78 @@ func (c *component) CreateClientRole(ctx context.Context, realmName, clientID st return locationURL, nil } + +// Retrieve the configuration from the database +func (c *component) GetRealmCustomConfiguration(ctx context.Context, realmName string) (api.RealmCustomConfiguration, error) { + var accessToken = ctx.Value("access_token").(string) + + var customConfig = api.RealmCustomConfiguration{ + DefaultClientId: new(string), + DefaultRedirectUri: new(string), + } + // get the realm config from Keycloak + realmConfig, err := c.keycloakClient.GetRealm(accessToken, realmName) + if err != nil { + return customConfig, err + } + // from the realm ID, fetch the custom configuration + realmID := realmConfig.Id + customConfigJSON, err := c.configDBModule.GetConfiguration(ctx, *realmID) + if customConfigJSON == "" { + // database is empty + return customConfig, nil + } + // transform json string into + err = json.Unmarshal([]byte(customConfigJSON), &customConfig) + if err != nil { + return customConfig, err + } + return customConfig, nil +} + +// Update the configuration in the database; verify that the content of the configuration is coherent with Keycloak configuration +func (c *component) UpdateRealmCustomConfiguration(ctx context.Context, realmName string, customConfig api.RealmCustomConfiguration) error { + var accessToken = ctx.Value("access_token").(string) + + // get the realm config from Keycloak + realmConfig, err := c.keycloakClient.GetRealm(accessToken, realmName) + if err != nil { + return err + } + // get the desired client (from its ID) + clients, err := c.keycloakClient.GetClients(accessToken, realmName) + if err != nil { + return err + } + var match = false + for _, client := range clients { + if *client.ClientId != *customConfig.DefaultClientId { + continue + } + for _, redirectURI := range *client.RedirectUris { + // escape the regex-specific characters (dots for intance)... + matcher := regexp.QuoteMeta(redirectURI) + // ... but keep the stars + matcher = strings.Replace(matcher, "\\*", "*", -1) + match, _ = regexp.MatchString(matcher, *customConfig.DefaultRedirectUri) + if match { + break + } + } + } + if !match { + return HTTPError{ + Status: 400, + Message: "Invalid client ID or redirect URI", + } + } + // transform customConfig object into JSON string + configJSON, err := json.Marshal(customConfig) + if err != nil { + return err + } + // from the realm ID, update the custom configuration in the DB + realmID := realmConfig.Id + err = c.configDBModule.StoreOrUpdate(ctx, *realmID, string(configJSON)) + return err +} diff --git a/pkg/management/component_test.go b/pkg/management/component_test.go index 39de6b984..25e250b96 100644 --- a/pkg/management/component_test.go +++ b/pkg/management/component_test.go @@ -1,9 +1,11 @@ package management //go:generate mockgen -destination=./mock/keycloak_client.go -package=mock -mock_names=KeycloakClient=KeycloakClient github.com/cloudtrust/keycloak-bridge/pkg/management KeycloakClient +//go:generate mockgen -destination=./mock/module.go -package=mock -mock_names=ConfigurationDBModule=ConfigurationDBModule github.com/cloudtrust/keycloak-bridge/pkg/management ConfigurationDBModule import ( "context" + "errors" "fmt" "testing" "time" @@ -20,8 +22,9 @@ func TestGetRealms(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" @@ -82,8 +85,9 @@ func TestGetRealm(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var realmName = "master" @@ -145,8 +149,9 @@ func TestGetClient(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var realmName = "master" @@ -211,8 +216,9 @@ func TestGetClients(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var realmName = "master" @@ -277,8 +283,9 @@ func TestCreateUser(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var username = "test" @@ -386,8 +393,9 @@ func TestDeleteUser(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var userID = "1234-7558-7645" @@ -426,8 +434,9 @@ func TestGetUser(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var realmName = "master" @@ -507,8 +516,9 @@ func TestUpdateUser(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var realmName = "master" @@ -627,8 +637,9 @@ func TestGetUsers(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var realmName = "master" @@ -707,8 +718,9 @@ func TestGetUserAccountStatus(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var realmReq = "master" @@ -784,8 +796,9 @@ func TestGetClientRolesForUser(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var realmName = "master" @@ -846,8 +859,9 @@ func TestAddClientRolesToUser(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var realmName = "master" @@ -922,8 +936,9 @@ func TestGetRealmRolesForUser(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var realmName = "master" @@ -983,8 +998,9 @@ func TestResetPassword(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var realmName = "master" @@ -1038,8 +1054,9 @@ func TestSendVerifyEmail(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var realmName = "master" @@ -1079,8 +1096,9 @@ func TestGetCredentialsForUser(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var realmReq = "master" var realmName = "otherRealm" @@ -1104,9 +1122,9 @@ func TestDeleteCredentialsForUser(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) - + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var realmReq = "master" var realmName = "master" @@ -1131,8 +1149,9 @@ func TestGetRoles(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var realmName = "master" @@ -1191,8 +1210,9 @@ func TestGetRole(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var realmName = "master" @@ -1248,8 +1268,9 @@ func TestGetClientRoles(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var realmName = "master" @@ -1309,8 +1330,9 @@ func TestCreateClientRole(t *testing.T) { defer mockCtrl.Finish() var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) - var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule) + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) var accessToken = "TOKEN==" var realmName = "master" @@ -1366,3 +1388,272 @@ func TestCreateClientRole(t *testing.T) { assert.NotNil(t, err) } } + +func TestGetRealmCustomConfiguration(t *testing.T) { + var mockCtrl = gomock.NewController(t) + defer mockCtrl.Finish() + var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) + var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) + + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) + + var accessToken = "TOKEN==" + var realmID = "master_id" + + // Get existing config + { + var id = realmID + var keycloakVersion = "4.8.3" + var realm = "master" + var displayName = "Master" + var enabled = true + + var kcRealmRep = kc.RealmRepresentation{ + Id: &id, + KeycloakVersion: &keycloakVersion, + Realm: &realm, + DisplayName: &displayName, + Enabled: &enabled, + } + + mockKeycloakClient.EXPECT().GetRealm(accessToken, realmID).Return(kcRealmRep, nil).Times(1) + + var clientID = "ClientID" + var redirectURI = "http://redirect.url.com/test" + + var customRealmConfigStr = `{ + "default_client_id": "` + clientID + `", + "default_redirect_uri": "` + redirectURI + `" + }` + var configInit = api.RealmCustomConfiguration{ + DefaultClientId: &clientID, + DefaultRedirectUri: &redirectURI, + } + + var ctx = context.WithValue(context.Background(), "access_token", accessToken) + mockConfigurationDBModule.EXPECT().GetConfiguration(ctx, realmID).Return(customRealmConfigStr, nil).Times(1) + + configJSON, err := managementComponent.GetRealmCustomConfiguration(ctx, realmID) + + assert.Nil(t, err) + assert.Equal(t, *configJSON.DefaultClientId, *configInit.DefaultClientId) + assert.Equal(t, *configJSON.DefaultRedirectUri, *configInit.DefaultRedirectUri) + } + + // Get empty config + { + var id = realmID + var keycloakVersion = "4.8.3" + var realm = "master" + var displayName = "Master" + var enabled = true + + var kcRealmRep = kc.RealmRepresentation{ + Id: &id, + KeycloakVersion: &keycloakVersion, + Realm: &realm, + DisplayName: &displayName, + Enabled: &enabled, + } + + mockKeycloakClient.EXPECT().GetRealm(accessToken, realmID).Return(kcRealmRep, nil).Times(1) + + var ctx = context.WithValue(context.Background(), "access_token", accessToken) + mockConfigurationDBModule.EXPECT().GetConfiguration(ctx, realmID).Return("", nil).Times(1) + + configJSON, err := managementComponent.GetRealmCustomConfiguration(ctx, realmID) + + assert.Nil(t, err) + assert.Equal(t, *configJSON.DefaultClientId, *new(string)) + assert.Equal(t, *configJSON.DefaultRedirectUri, *new(string)) + } + + // Invalid structure in DB + { + var id = realmID + var keycloakVersion = "4.8.3" + var realm = "master" + var displayName = "Master" + var enabled = true + + var kcRealmRep = kc.RealmRepresentation{ + Id: &id, + KeycloakVersion: &keycloakVersion, + Realm: &realm, + DisplayName: &displayName, + Enabled: &enabled, + } + + mockKeycloakClient.EXPECT().GetRealm(accessToken, realmID).Return(kcRealmRep, nil).Times(1) + + var ctx = context.WithValue(context.Background(), "access_token", accessToken) + mockConfigurationDBModule.EXPECT().GetConfiguration(ctx, realmID).Return("928743", nil).Times(1) + + _, err := managementComponent.GetRealmCustomConfiguration(ctx, realmID) + + assert.NotNil(t, err) + } + + // Unknown realm + { + var id = realmID + var keycloakVersion = "4.8.3" + var realm = "master" + var displayName = "Master" + var enabled = true + + var kcRealmRep = kc.RealmRepresentation{ + Id: &id, + KeycloakVersion: &keycloakVersion, + Realm: &realm, + DisplayName: &displayName, + Enabled: &enabled, + } + + mockKeycloakClient.EXPECT().GetRealm(accessToken, realmID).Return(kcRealmRep, errors.New("error")).Times(1) + + var ctx = context.WithValue(context.Background(), "access_token", accessToken) + + _, err := managementComponent.GetRealmCustomConfiguration(ctx, realmID) + + assert.NotNil(t, err) + } +} + +func TestUpdateRealmCustomConfiguration(t *testing.T) { + var mockCtrl = gomock.NewController(t) + defer mockCtrl.Finish() + var mockKeycloakClient = mock.NewKeycloakClient(mockCtrl) + var mockEventDBModule = mock.NewEventsDBModule(mockCtrl) + var mockConfigurationDBModule = mock.NewConfigurationDBModule(mockCtrl) + + var managementComponent = NewComponent(mockKeycloakClient, mockEventDBModule, mockConfigurationDBModule) + + var accessToken = "TOKEN==" + var realmID = "master_id" + + var id = realmID + var keycloakVersion = "4.8.3" + var realm = "master" + var displayName = "Master" + var enabled = true + + var kcRealmRep = kc.RealmRepresentation{ + Id: &id, + KeycloakVersion: &keycloakVersion, + Realm: &realm, + DisplayName: &displayName, + Enabled: &enabled, + } + + var clients = make([]kc.ClientRepresentation, 2) + var clientID1 = "clientID1" + var clientName1 = "clientName1" + var redirectURIs1 = []string{"https://www.cloudtrust.io/*", "https://www.cloudtrust-old.com/*"} + var clientID2 = "clientID2" + var clientName2 = "clientName2" + var redirectURIs2 = []string{"https://www.cloudtrust2.io/*", "https://www.cloudtrust2-old.com/*"} + clients[0] = kc.ClientRepresentation{ + ClientId: &clientID1, + Name: &clientName1, + RedirectUris: &redirectURIs1, + } + clients[1] = kc.ClientRepresentation{ + ClientId: &clientID2, + Name: &clientName2, + RedirectUris: &redirectURIs2, + } + + var ctx = context.WithValue(context.Background(), "access_token", accessToken) + var clientID = "clientID1" + var redirectURI = "https://www.cloudtrust.io/test" + var configInit = api.RealmCustomConfiguration{ + DefaultClientId: &clientID, + DefaultRedirectUri: &redirectURI, + } + + // Update config with appropriate values + { + mockKeycloakClient.EXPECT().GetRealm(accessToken, realmID).Return(kcRealmRep, nil).Times(1) + mockKeycloakClient.EXPECT().GetClients(accessToken, realmID).Return(clients, nil).Times(1) + mockConfigurationDBModule.EXPECT().StoreOrUpdate(ctx, realmID, gomock.Any()).Return(nil).Times(1) + err := managementComponent.UpdateRealmCustomConfiguration(ctx, realmID, configInit) + + assert.Nil(t, err) + } + + // Update config with unknown client ID + { + mockKeycloakClient.EXPECT().GetRealm(accessToken, realmID).Return(kcRealmRep, nil).Times(1) + mockKeycloakClient.EXPECT().GetClients(accessToken, realmID).Return(clients, nil).Times(1) + + var clientID = "clientID1Nok" + var redirectURI = "https://www.cloudtrust.io/test" + var configInit = api.RealmCustomConfiguration{ + DefaultClientId: &clientID, + DefaultRedirectUri: &redirectURI, + } + err := managementComponent.UpdateRealmCustomConfiguration(ctx, realmID, configInit) + + assert.NotNil(t, err) + assert.IsType(t, HTTPError{}, err) + e := err.(HTTPError) + assert.Equal(t, 400, e.Status) + } + + // Update config with invalid redirect URI + { + mockKeycloakClient.EXPECT().GetRealm(accessToken, realmID).Return(kcRealmRep, nil).Times(1) + mockKeycloakClient.EXPECT().GetClients(accessToken, realmID).Return(clients, nil).Times(1) + + var clientID = "clientID1" + var redirectURI = "https://www.cloudtrustnok.io/test" + var configInit = api.RealmCustomConfiguration{ + DefaultClientId: &clientID, + DefaultRedirectUri: &redirectURI, + } + err := managementComponent.UpdateRealmCustomConfiguration(ctx, realmID, configInit) + + assert.NotNil(t, err) + assert.IsType(t, HTTPError{}, err) + e := err.(HTTPError) + assert.Equal(t, 400, e.Status) + } + + // Update config with invalid redirect URI (trying to take advantage of the dots in the url) + { + mockKeycloakClient.EXPECT().GetRealm(accessToken, realmID).Return(kcRealmRep, nil).Times(1) + mockKeycloakClient.EXPECT().GetClients(accessToken, realmID).Return(clients, nil).Times(1) + + var clientID = "clientID1" + var redirectURI = "https://wwwacloudtrust.io/test" + var configInit = api.RealmCustomConfiguration{ + DefaultClientId: &clientID, + DefaultRedirectUri: &redirectURI, + } + err := managementComponent.UpdateRealmCustomConfiguration(ctx, realmID, configInit) + + assert.NotNil(t, err) + assert.IsType(t, HTTPError{}, err) + e := err.(HTTPError) + assert.Equal(t, 400, e.Status) + } + + // error while calling GetClients + { + mockKeycloakClient.EXPECT().GetRealm(accessToken, realmID).Return(kcRealmRep, nil).Times(1) + mockKeycloakClient.EXPECT().GetClients(accessToken, realmID).Return([]kc.ClientRepresentation{}, errors.New("error")).Times(1) + err := managementComponent.UpdateRealmCustomConfiguration(ctx, realmID, configInit) + + assert.NotNil(t, err) + } + + // error while calling GetRealm + { + mockKeycloakClient.EXPECT().GetRealm(accessToken, realmID).Return(kc.RealmRepresentation{}, errors.New("error")).Times(1) + err := managementComponent.UpdateRealmCustomConfiguration(ctx, realmID, configInit) + + assert.NotNil(t, err) + } +} diff --git a/pkg/management/endpoint.go b/pkg/management/endpoint.go index c18f0f739..7a8c26a33 100644 --- a/pkg/management/endpoint.go +++ b/pkg/management/endpoint.go @@ -11,27 +11,29 @@ import ( // Endpoints wraps a service behind a set of endpoints. type Endpoints struct { - GetRealms endpoint.Endpoint - GetRealm endpoint.Endpoint - GetClient endpoint.Endpoint - GetClients endpoint.Endpoint - DeleteUser endpoint.Endpoint - GetUser endpoint.Endpoint - UpdateUser endpoint.Endpoint - GetUsers endpoint.Endpoint - CreateUser endpoint.Endpoint - GetUserAccountStatus endpoint.Endpoint - GetClientRoleForUser endpoint.Endpoint - AddClientRoleToUser endpoint.Endpoint - GetRealmRoleForUser endpoint.Endpoint - ResetPassword endpoint.Endpoint - SendVerifyEmail endpoint.Endpoint - GetCredentialsForUser endpoint.Endpoint - DeleteCredentialsForUser endpoint.Endpoint - GetRoles endpoint.Endpoint - GetRole endpoint.Endpoint - GetClientRoles endpoint.Endpoint - CreateClientRole endpoint.Endpoint + GetRealms endpoint.Endpoint + GetRealm endpoint.Endpoint + GetClient endpoint.Endpoint + GetClients endpoint.Endpoint + DeleteUser endpoint.Endpoint + GetUser endpoint.Endpoint + UpdateUser endpoint.Endpoint + GetUsers endpoint.Endpoint + CreateUser endpoint.Endpoint + GetUserAccountStatus endpoint.Endpoint + GetClientRoleForUser endpoint.Endpoint + AddClientRoleToUser endpoint.Endpoint + GetRealmRoleForUser endpoint.Endpoint + ResetPassword endpoint.Endpoint + SendVerifyEmail endpoint.Endpoint + GetCredentialsForUser endpoint.Endpoint + DeleteCredentialsForUser endpoint.Endpoint + GetRoles endpoint.Endpoint + GetRole endpoint.Endpoint + GetClientRoles endpoint.Endpoint + CreateClientRole endpoint.Endpoint + GetRealmCustomConfiguration endpoint.Endpoint + UpdateRealmCustomConfiguration endpoint.Endpoint } // ManagementComponent is the interface of the component to send a query to Keycloak. @@ -57,6 +59,8 @@ type ManagementComponent interface { GetRole(ctx context.Context, realmName string, roleID string) (api.RoleRepresentation, error) GetClientRoles(ctx context.Context, realmName, idClient string) ([]api.RoleRepresentation, error) CreateClientRole(ctx context.Context, realmName, clientID string, role api.RoleRepresentation) (string, error) + GetRealmCustomConfiguration(ctx context.Context, realmID string) (api.RealmCustomConfiguration, error) + UpdateRealmCustomConfiguration(ctx context.Context, realmID string, customConfig api.RealmCustomConfiguration) error } // MakeRealmsEndpoint makes the Realms endpoint to retrieve all available realms. @@ -310,6 +314,29 @@ func MakeCreateClientRoleEndpoint(managementComponent ManagementComponent) endpo } } +func MakeGetRealmCustomConfigurationEndpoint(managementComponent ManagementComponent) endpoint.Endpoint { + return func(ctx context.Context, req interface{}) (interface{}, error) { + var m = req.(map[string]string) + + return managementComponent.GetRealmCustomConfiguration(ctx, m["realm"]) + } +} + +func MakeUpdateRealmCustomConfigurationEndpoint(managementComponent ManagementComponent) endpoint.Endpoint { + return func(ctx context.Context, req interface{}) (interface{}, error) { + var m = req.(map[string]string) + + configJson := []byte(m["body"]) + + var customConfig api.RealmCustomConfiguration + err := json.Unmarshal(configJson, &customConfig) + if err != nil { + return nil, err + } + return nil, managementComponent.UpdateRealmCustomConfiguration(ctx, m["realm"], customConfig) + } +} + type LocationHeader struct { URL string } diff --git a/pkg/management/endpoint_test.go b/pkg/management/endpoint_test.go index 2935f3e83..c6801c104 100644 --- a/pkg/management/endpoint_test.go +++ b/pkg/management/endpoint_test.go @@ -655,3 +655,70 @@ func TestCreateClientRoleEndpoint(t *testing.T) { assert.NotNil(t, err) } } + +func TestGetRealmCustomConfigurationEndpoint(t *testing.T) { + var mockCtrl = gomock.NewController(t) + defer mockCtrl.Finish() + + var mockManagementComponent = mock.NewManagementComponent(mockCtrl) + + var e = MakeGetRealmCustomConfigurationEndpoint(mockManagementComponent) + + // No error + { + var realmName = "master" + var clientID = "123456" + var ctx = context.Background() + var req = make(map[string]string) + req["realm"] = realmName + req["clientID"] = clientID + + mockManagementComponent.EXPECT().GetRealmCustomConfiguration(ctx, realmName).Return(api.RealmCustomConfiguration{}, nil).Times(1) + var res, err = e(ctx, req) + assert.Nil(t, err) + assert.NotNil(t, res) + } +} + +func TestUpdateRealmCustomConfigurationEndpoint(t *testing.T) { + var mockCtrl = gomock.NewController(t) + defer mockCtrl.Finish() + + var mockManagementComponent = mock.NewManagementComponent(mockCtrl) + + var e = MakeUpdateRealmCustomConfigurationEndpoint(mockManagementComponent) + + // No error + { + var realmName = "master" + var clientID = "123456" + var configJSON = "{\"DefaultClientId\":\"clientId\", \"DefaultRedirectUri\":\"http://cloudtrust.io\"}" + var ctx = context.Background() + var req = make(map[string]string) + req["realm"] = realmName + req["clientID"] = clientID + req["body"] = configJSON + + mockManagementComponent.EXPECT().UpdateRealmCustomConfiguration(ctx, realmName, gomock.Any()).Return(nil).Times(1) + var res, err = e(ctx, req) + assert.Nil(t, err) + assert.Nil(t, res) + } + + // JSON error + { + var realmName = "master" + var clientID = "123456" + var configJSON = "{\"DefaultClientId\":\"clientId\", \"DefaultRedirectUri\":\"http://cloudtrust.io\"" + var ctx = context.Background() + var req = make(map[string]string) + req["realm"] = realmName + req["clientID"] = clientID + req["body"] = configJSON + + mockManagementComponent.EXPECT().UpdateRealmCustomConfiguration(ctx, realmName, gomock.Any()).Return(nil).Times(0) + var res, err = e(ctx, req) + assert.NotNil(t, err) + assert.Nil(t, res) + } +} diff --git a/pkg/management/instrumenting.go b/pkg/management/instrumenting.go new file mode 100644 index 000000000..29c14a80c --- /dev/null +++ b/pkg/management/instrumenting.go @@ -0,0 +1,42 @@ +package management + +//go:generate mockgen -destination=./mock/instrumenting.go -package=mock -mock_names=Histogram=Histogram github.com/go-kit/kit/metrics Histogram + +import ( + "context" + "time" + + "github.com/go-kit/kit/metrics" +) + +// Instrumenting middleware at module level. +type configDBModuleInstrumentingMW struct { + h metrics.Histogram + next ConfigurationDBModule +} + +// MakeConfigurationDBModuleInstrumentingMW makes an instrumenting middleware at module level. +func MakeConfigurationDBModuleInstrumentingMW(h metrics.Histogram) func(ConfigurationDBModule) ConfigurationDBModule { + return func(next ConfigurationDBModule) ConfigurationDBModule { + return &configDBModuleInstrumentingMW{ + h: h, + next: next, + } + } +} + +// configDBModuleInstrumentingMW implements Module. +func (m *configDBModuleInstrumentingMW) StoreOrUpdate(ctx context.Context, realmName string, configJSON string) error { + defer func(begin time.Time) { + m.h.With("correlation_id", ctx.Value("correlation_id").(string)).Observe(time.Since(begin).Seconds()) + }(time.Now()) + return m.next.StoreOrUpdate(ctx, realmName, configJSON) +} + +// configDBModuleInstrumentingMW implements Module. +func (m *configDBModuleInstrumentingMW) GetConfiguration(ctx context.Context, realmName string) (string, error) { + defer func(begin time.Time) { + m.h.With("correlation_id", ctx.Value("correlation_id").(string)).Observe(time.Since(begin).Seconds()) + }(time.Now()) + return m.next.GetConfiguration(ctx, realmName) +} diff --git a/pkg/management/instrumenting_test.go b/pkg/management/instrumenting_test.go new file mode 100644 index 000000000..7c48ae454 --- /dev/null +++ b/pkg/management/instrumenting_test.go @@ -0,0 +1,52 @@ +package management + +import ( + "context" + "math/rand" + "strconv" + "testing" + "time" + + "github.com/cloudtrust/keycloak-bridge/pkg/management/mock" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func TestComponentInstrumentingMW(t *testing.T) { + var mockCtrl = gomock.NewController(t) + defer mockCtrl.Finish() + var mockComponent = mock.NewConfigurationDBModule(mockCtrl) + var mockHistogram = mock.NewHistogram(mockCtrl) + + var m = MakeConfigurationDBModuleInstrumentingMW(mockHistogram)(mockComponent) + + rand.Seed(time.Now().UnixNano()) + var corrID = strconv.FormatUint(rand.Uint64(), 10) + var ctx = context.WithValue(context.Background(), "correlation_id", corrID) + + // Get configuration. + mockComponent.EXPECT().GetConfiguration(ctx, "realmID").Return("", nil).Times(1) + mockHistogram.EXPECT().With("correlation_id", corrID).Return(mockHistogram).Times(1) + mockHistogram.EXPECT().Observe(gomock.Any()).Return().Times(1) + m.GetConfiguration(ctx, "realmID") + + // Get configuration without correlation ID. + mockComponent.EXPECT().GetConfiguration(context.Background(), "realmID").Return("", nil).Times(1) + var f = func() { + m.GetConfiguration(context.Background(), "realmID") + } + assert.Panics(t, f) + + // Update configuration. + mockComponent.EXPECT().StoreOrUpdate(ctx, "realmID", "{}").Return(nil).Times(1) + mockHistogram.EXPECT().With("correlation_id", corrID).Return(mockHistogram).Times(1) + mockHistogram.EXPECT().Observe(gomock.Any()).Return().Times(1) + m.StoreOrUpdate(ctx, "realmID", "{}") + + // Update configuration without correlation ID. + mockComponent.EXPECT().StoreOrUpdate(context.Background(), "realmID", "{}").Return(nil).Times(1) + f = func() { + m.StoreOrUpdate(context.Background(), "realmID", "{}") + } + assert.Panics(t, f) +} diff --git a/pkg/management/logging.go b/pkg/management/logging.go new file mode 100644 index 000000000..1738b1e63 --- /dev/null +++ b/pkg/management/logging.go @@ -0,0 +1,42 @@ +package management + +//go:generate mockgen -destination=./mock/logging.go -package=mock -mock_names=Logger=Logger github.com/go-kit/kit/log Logger + +import ( + "context" + "time" + + "github.com/go-kit/kit/log" +) + +// Logging middleware for the statistic module. +type configDBModuleLoggingMW struct { + logger log.Logger + next ConfigurationDBModule +} + +// MakeConfigurationDBModuleLoggingMW makes a logging middleware for the statistic module. +func MakeConfigurationDBModuleLoggingMW(log log.Logger) func(ConfigurationDBModule) ConfigurationDBModule { + return func(next ConfigurationDBModule) ConfigurationDBModule { + return &configDBModuleLoggingMW{ + logger: log, + next: next, + } + } +} + +// configDBModuleLoggingMW implements ConfigurationDBModule. +func (m *configDBModuleLoggingMW) StoreOrUpdate(ctx context.Context, realmName string, configJSON string) error { + defer func(begin time.Time) { + m.logger.Log("method", "StoreOrUpdate", "args", realmName, configJSON, "took", time.Since(begin)) + }(time.Now()) + return m.next.StoreOrUpdate(ctx, realmName, configJSON) +} + +// configDBModuleLoggingMW implements ConfigurationDBModule. +func (m *configDBModuleLoggingMW) GetConfiguration(ctx context.Context, realmName string) (string, error) { + defer func(begin time.Time) { + m.logger.Log("method", "GetConfiguration", "args", realmName, "took", time.Since(begin)) + }(time.Now()) + return m.next.GetConfiguration(ctx, realmName) +} diff --git a/pkg/management/logging_test.go b/pkg/management/logging_test.go new file mode 100644 index 000000000..3a658b6f8 --- /dev/null +++ b/pkg/management/logging_test.go @@ -0,0 +1,37 @@ +package management + +//go:generate mockgen -destination=./mock/logging.go -package=mock -mock_names=Logger=Logger github.com/go-kit/kit/log Logger + +import ( + "context" + "math/rand" + "strconv" + "testing" + "time" + + "github.com/cloudtrust/keycloak-bridge/pkg/management/mock" + "github.com/golang/mock/gomock" +) + +func TestComponentLoggingMW(t *testing.T) { + var mockCtrl = gomock.NewController(t) + defer mockCtrl.Finish() + var mockComponent = mock.NewConfigurationDBModule(mockCtrl) + var mockLogger = mock.NewLogger(mockCtrl) + + var m = MakeConfigurationDBModuleLoggingMW(mockLogger)(mockComponent) + + rand.Seed(time.Now().UnixNano()) + var corrID = strconv.FormatUint(rand.Uint64(), 10) + var ctx = context.WithValue(context.Background(), "correlation_id", corrID) + + // Get configuration. + mockComponent.EXPECT().GetConfiguration(ctx, "realmID").Return("", nil).Times(1) + mockLogger.EXPECT().Log("method", "GetConfiguration", "args", "realmID", "took", gomock.Any()).Return(nil).Times(1) + m.GetConfiguration(ctx, "realmID") + + // Update configuration. + mockComponent.EXPECT().StoreOrUpdate(ctx, "realmID", "{}").Return(nil).Times(1) + mockLogger.EXPECT().Log("method", "StoreOrUpdate", "args", "realmID", "{}", "took", gomock.Any()).Return(nil).Times(1) + m.StoreOrUpdate(ctx, "realmID", "{}") +} diff --git a/pkg/management/mock/component.go b/pkg/management/mock/component.go index e7c135036..455cfacab 100644 --- a/pkg/management/mock/component.go +++ b/pkg/management/mock/component.go @@ -174,6 +174,19 @@ func (mr *ManagementComponentMockRecorder) GetRealm(arg0, arg1 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRealm", reflect.TypeOf((*ManagementComponent)(nil).GetRealm), arg0, arg1) } +// GetRealmCustomConfiguration mocks base method +func (m *ManagementComponent) GetRealmCustomConfiguration(arg0 context.Context, arg1 string) (management.RealmCustomConfiguration, error) { + ret := m.ctrl.Call(m, "GetRealmCustomConfiguration", arg0, arg1) + ret0, _ := ret[0].(management.RealmCustomConfiguration) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRealmCustomConfiguration indicates an expected call of GetRealmCustomConfiguration +func (mr *ManagementComponentMockRecorder) GetRealmCustomConfiguration(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRealmCustomConfiguration", reflect.TypeOf((*ManagementComponent)(nil).GetRealmCustomConfiguration), arg0, arg1) +} + // GetRealmRolesForUser mocks base method func (m *ManagementComponent) GetRealmRolesForUser(arg0 context.Context, arg1, arg2 string) ([]management.RoleRepresentation, error) { ret := m.ctrl.Call(m, "GetRealmRolesForUser", arg0, arg1, arg2) @@ -299,6 +312,18 @@ func (mr *ManagementComponentMockRecorder) SendVerifyEmail(arg0, arg1, arg2 inte return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendVerifyEmail", reflect.TypeOf((*ManagementComponent)(nil).SendVerifyEmail), varargs...) } +// UpdateRealmCustomConfiguration mocks base method +func (m *ManagementComponent) UpdateRealmCustomConfiguration(arg0 context.Context, arg1 string, arg2 management.RealmCustomConfiguration) error { + ret := m.ctrl.Call(m, "UpdateRealmCustomConfiguration", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateRealmCustomConfiguration indicates an expected call of UpdateRealmCustomConfiguration +func (mr *ManagementComponentMockRecorder) UpdateRealmCustomConfiguration(arg0, arg1, arg2 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateRealmCustomConfiguration", reflect.TypeOf((*ManagementComponent)(nil).UpdateRealmCustomConfiguration), arg0, arg1, arg2) +} + // UpdateUser mocks base method func (m *ManagementComponent) UpdateUser(arg0 context.Context, arg1, arg2 string, arg3 management.UserRepresentation) error { ret := m.ctrl.Call(m, "UpdateUser", arg0, arg1, arg2, arg3) diff --git a/pkg/management/mock/configuration_db.go b/pkg/management/mock/configuration_db.go new file mode 100644 index 000000000..c72162445 --- /dev/null +++ b/pkg/management/mock/configuration_db.go @@ -0,0 +1,69 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/cloudtrust/keycloak-bridge/pkg/management (interfaces: DBConfiguration) + +// Package mock is a generated GoMock package. +package mock + +import ( + sql "database/sql" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// DBConfiguration is a mock of DBConfiguration interface +type DBConfiguration struct { + ctrl *gomock.Controller + recorder *DBConfigurationMockRecorder +} + +// DBConfigurationMockRecorder is the mock recorder for DBConfiguration +type DBConfigurationMockRecorder struct { + mock *DBConfiguration +} + +// NewDBConfiguration creates a new mock instance +func NewDBConfiguration(ctrl *gomock.Controller) *DBConfiguration { + mock := &DBConfiguration{ctrl: ctrl} + mock.recorder = &DBConfigurationMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *DBConfiguration) EXPECT() *DBConfigurationMockRecorder { + return m.recorder +} + +// Exec mocks base method +func (m *DBConfiguration) Exec(arg0 string, arg1 ...interface{}) (sql.Result, error) { + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec +func (mr *DBConfigurationMockRecorder) Exec(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*DBConfiguration)(nil).Exec), varargs...) +} + +// QueryRow mocks base method +func (m *DBConfiguration) QueryRow(arg0 string, arg1 ...interface{}) *sql.Row { + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryRow", varargs...) + ret0, _ := ret[0].(*sql.Row) + return ret0 +} + +// QueryRow indicates an expected call of QueryRow +func (mr *DBConfigurationMockRecorder) QueryRow(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*DBConfiguration)(nil).QueryRow), varargs...) +} diff --git a/pkg/management/mock/instrumenting.go b/pkg/management/mock/instrumenting.go new file mode 100644 index 000000000..4101e6a69 --- /dev/null +++ b/pkg/management/mock/instrumenting.go @@ -0,0 +1,60 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/go-kit/kit/metrics (interfaces: Histogram) + +// Package mock is a generated GoMock package. +package mock + +import ( + metrics "github.com/go-kit/kit/metrics" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// Histogram is a mock of Histogram interface +type Histogram struct { + ctrl *gomock.Controller + recorder *HistogramMockRecorder +} + +// HistogramMockRecorder is the mock recorder for Histogram +type HistogramMockRecorder struct { + mock *Histogram +} + +// NewHistogram creates a new mock instance +func NewHistogram(ctrl *gomock.Controller) *Histogram { + mock := &Histogram{ctrl: ctrl} + mock.recorder = &HistogramMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *Histogram) EXPECT() *HistogramMockRecorder { + return m.recorder +} + +// Observe mocks base method +func (m *Histogram) Observe(arg0 float64) { + m.ctrl.Call(m, "Observe", arg0) +} + +// Observe indicates an expected call of Observe +func (mr *HistogramMockRecorder) Observe(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Observe", reflect.TypeOf((*Histogram)(nil).Observe), arg0) +} + +// With mocks base method +func (m *Histogram) With(arg0 ...string) metrics.Histogram { + varargs := []interface{}{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "With", varargs...) + ret0, _ := ret[0].(metrics.Histogram) + return ret0 +} + +// With indicates an expected call of With +func (mr *HistogramMockRecorder) With(arg0 ...interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "With", reflect.TypeOf((*Histogram)(nil).With), arg0...) +} diff --git a/pkg/management/mock/module.go b/pkg/management/mock/module.go new file mode 100644 index 000000000..035a8b8d9 --- /dev/null +++ b/pkg/management/mock/module.go @@ -0,0 +1,59 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/cloudtrust/keycloak-bridge/pkg/management (interfaces: ConfigurationDBModule) + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// ConfigurationDBModule is a mock of ConfigurationDBModule interface +type ConfigurationDBModule struct { + ctrl *gomock.Controller + recorder *ConfigurationDBModuleMockRecorder +} + +// ConfigurationDBModuleMockRecorder is the mock recorder for ConfigurationDBModule +type ConfigurationDBModuleMockRecorder struct { + mock *ConfigurationDBModule +} + +// NewConfigurationDBModule creates a new mock instance +func NewConfigurationDBModule(ctrl *gomock.Controller) *ConfigurationDBModule { + mock := &ConfigurationDBModule{ctrl: ctrl} + mock.recorder = &ConfigurationDBModuleMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *ConfigurationDBModule) EXPECT() *ConfigurationDBModuleMockRecorder { + return m.recorder +} + +// GetConfiguration mocks base method +func (m *ConfigurationDBModule) GetConfiguration(arg0 context.Context, arg1 string) (string, error) { + ret := m.ctrl.Call(m, "GetConfiguration", arg0, arg1) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetConfiguration indicates an expected call of GetConfiguration +func (mr *ConfigurationDBModuleMockRecorder) GetConfiguration(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConfiguration", reflect.TypeOf((*ConfigurationDBModule)(nil).GetConfiguration), arg0, arg1) +} + +// StoreOrUpdate mocks base method +func (m *ConfigurationDBModule) StoreOrUpdate(arg0 context.Context, arg1, arg2 string) error { + ret := m.ctrl.Call(m, "StoreOrUpdate", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// StoreOrUpdate indicates an expected call of StoreOrUpdate +func (mr *ConfigurationDBModuleMockRecorder) StoreOrUpdate(arg0, arg1, arg2 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StoreOrUpdate", reflect.TypeOf((*ConfigurationDBModule)(nil).StoreOrUpdate), arg0, arg1, arg2) +} diff --git a/pkg/management/mock/tracing.go b/pkg/management/mock/tracing.go new file mode 100644 index 000000000..bc4b09f88 --- /dev/null +++ b/pkg/management/mock/tracing.go @@ -0,0 +1,283 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/opentracing/opentracing-go (interfaces: Tracer,Span,SpanContext) + +// Package mock is a generated GoMock package. +package mock + +import ( + gomock "github.com/golang/mock/gomock" + opentracing_go "github.com/opentracing/opentracing-go" + log "github.com/opentracing/opentracing-go/log" + reflect "reflect" +) + +// Tracer is a mock of Tracer interface +type Tracer struct { + ctrl *gomock.Controller + recorder *TracerMockRecorder +} + +// TracerMockRecorder is the mock recorder for Tracer +type TracerMockRecorder struct { + mock *Tracer +} + +// NewTracer creates a new mock instance +func NewTracer(ctrl *gomock.Controller) *Tracer { + mock := &Tracer{ctrl: ctrl} + mock.recorder = &TracerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *Tracer) EXPECT() *TracerMockRecorder { + return m.recorder +} + +// Extract mocks base method +func (m *Tracer) Extract(arg0, arg1 interface{}) (opentracing_go.SpanContext, error) { + ret := m.ctrl.Call(m, "Extract", arg0, arg1) + ret0, _ := ret[0].(opentracing_go.SpanContext) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Extract indicates an expected call of Extract +func (mr *TracerMockRecorder) Extract(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Extract", reflect.TypeOf((*Tracer)(nil).Extract), arg0, arg1) +} + +// Inject mocks base method +func (m *Tracer) Inject(arg0 opentracing_go.SpanContext, arg1, arg2 interface{}) error { + ret := m.ctrl.Call(m, "Inject", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// Inject indicates an expected call of Inject +func (mr *TracerMockRecorder) Inject(arg0, arg1, arg2 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Inject", reflect.TypeOf((*Tracer)(nil).Inject), arg0, arg1, arg2) +} + +// StartSpan mocks base method +func (m *Tracer) StartSpan(arg0 string, arg1 ...opentracing_go.StartSpanOption) opentracing_go.Span { + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "StartSpan", varargs...) + ret0, _ := ret[0].(opentracing_go.Span) + return ret0 +} + +// StartSpan indicates an expected call of StartSpan +func (mr *TracerMockRecorder) StartSpan(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartSpan", reflect.TypeOf((*Tracer)(nil).StartSpan), varargs...) +} + +// Span is a mock of Span interface +type Span struct { + ctrl *gomock.Controller + recorder *SpanMockRecorder +} + +// SpanMockRecorder is the mock recorder for Span +type SpanMockRecorder struct { + mock *Span +} + +// NewSpan creates a new mock instance +func NewSpan(ctrl *gomock.Controller) *Span { + mock := &Span{ctrl: ctrl} + mock.recorder = &SpanMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *Span) EXPECT() *SpanMockRecorder { + return m.recorder +} + +// BaggageItem mocks base method +func (m *Span) BaggageItem(arg0 string) string { + ret := m.ctrl.Call(m, "BaggageItem", arg0) + ret0, _ := ret[0].(string) + return ret0 +} + +// BaggageItem indicates an expected call of BaggageItem +func (mr *SpanMockRecorder) BaggageItem(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BaggageItem", reflect.TypeOf((*Span)(nil).BaggageItem), arg0) +} + +// Context mocks base method +func (m *Span) Context() opentracing_go.SpanContext { + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(opentracing_go.SpanContext) + return ret0 +} + +// Context indicates an expected call of Context +func (mr *SpanMockRecorder) Context() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*Span)(nil).Context)) +} + +// Finish mocks base method +func (m *Span) Finish() { + m.ctrl.Call(m, "Finish") +} + +// Finish indicates an expected call of Finish +func (mr *SpanMockRecorder) Finish() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Finish", reflect.TypeOf((*Span)(nil).Finish)) +} + +// FinishWithOptions mocks base method +func (m *Span) FinishWithOptions(arg0 opentracing_go.FinishOptions) { + m.ctrl.Call(m, "FinishWithOptions", arg0) +} + +// FinishWithOptions indicates an expected call of FinishWithOptions +func (mr *SpanMockRecorder) FinishWithOptions(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FinishWithOptions", reflect.TypeOf((*Span)(nil).FinishWithOptions), arg0) +} + +// Log mocks base method +func (m *Span) Log(arg0 opentracing_go.LogData) { + m.ctrl.Call(m, "Log", arg0) +} + +// Log indicates an expected call of Log +func (mr *SpanMockRecorder) Log(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Log", reflect.TypeOf((*Span)(nil).Log), arg0) +} + +// LogEvent mocks base method +func (m *Span) LogEvent(arg0 string) { + m.ctrl.Call(m, "LogEvent", arg0) +} + +// LogEvent indicates an expected call of LogEvent +func (mr *SpanMockRecorder) LogEvent(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LogEvent", reflect.TypeOf((*Span)(nil).LogEvent), arg0) +} + +// LogEventWithPayload mocks base method +func (m *Span) LogEventWithPayload(arg0 string, arg1 interface{}) { + m.ctrl.Call(m, "LogEventWithPayload", arg0, arg1) +} + +// LogEventWithPayload indicates an expected call of LogEventWithPayload +func (mr *SpanMockRecorder) LogEventWithPayload(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LogEventWithPayload", reflect.TypeOf((*Span)(nil).LogEventWithPayload), arg0, arg1) +} + +// LogFields mocks base method +func (m *Span) LogFields(arg0 ...log.Field) { + varargs := []interface{}{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "LogFields", varargs...) +} + +// LogFields indicates an expected call of LogFields +func (mr *SpanMockRecorder) LogFields(arg0 ...interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LogFields", reflect.TypeOf((*Span)(nil).LogFields), arg0...) +} + +// LogKV mocks base method +func (m *Span) LogKV(arg0 ...interface{}) { + varargs := []interface{}{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "LogKV", varargs...) +} + +// LogKV indicates an expected call of LogKV +func (mr *SpanMockRecorder) LogKV(arg0 ...interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LogKV", reflect.TypeOf((*Span)(nil).LogKV), arg0...) +} + +// SetBaggageItem mocks base method +func (m *Span) SetBaggageItem(arg0, arg1 string) opentracing_go.Span { + ret := m.ctrl.Call(m, "SetBaggageItem", arg0, arg1) + ret0, _ := ret[0].(opentracing_go.Span) + return ret0 +} + +// SetBaggageItem indicates an expected call of SetBaggageItem +func (mr *SpanMockRecorder) SetBaggageItem(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBaggageItem", reflect.TypeOf((*Span)(nil).SetBaggageItem), arg0, arg1) +} + +// SetOperationName mocks base method +func (m *Span) SetOperationName(arg0 string) opentracing_go.Span { + ret := m.ctrl.Call(m, "SetOperationName", arg0) + ret0, _ := ret[0].(opentracing_go.Span) + return ret0 +} + +// SetOperationName indicates an expected call of SetOperationName +func (mr *SpanMockRecorder) SetOperationName(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetOperationName", reflect.TypeOf((*Span)(nil).SetOperationName), arg0) +} + +// SetTag mocks base method +func (m *Span) SetTag(arg0 string, arg1 interface{}) opentracing_go.Span { + ret := m.ctrl.Call(m, "SetTag", arg0, arg1) + ret0, _ := ret[0].(opentracing_go.Span) + return ret0 +} + +// SetTag indicates an expected call of SetTag +func (mr *SpanMockRecorder) SetTag(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTag", reflect.TypeOf((*Span)(nil).SetTag), arg0, arg1) +} + +// Tracer mocks base method +func (m *Span) Tracer() opentracing_go.Tracer { + ret := m.ctrl.Call(m, "Tracer") + ret0, _ := ret[0].(opentracing_go.Tracer) + return ret0 +} + +// Tracer indicates an expected call of Tracer +func (mr *SpanMockRecorder) Tracer() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Tracer", reflect.TypeOf((*Span)(nil).Tracer)) +} + +// SpanContext is a mock of SpanContext interface +type SpanContext struct { + ctrl *gomock.Controller + recorder *SpanContextMockRecorder +} + +// SpanContextMockRecorder is the mock recorder for SpanContext +type SpanContextMockRecorder struct { + mock *SpanContext +} + +// NewSpanContext creates a new mock instance +func NewSpanContext(ctrl *gomock.Controller) *SpanContext { + mock := &SpanContext{ctrl: ctrl} + mock.recorder = &SpanContextMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *SpanContext) EXPECT() *SpanContextMockRecorder { + return m.recorder +} + +// ForeachBaggageItem mocks base method +func (m *SpanContext) ForeachBaggageItem(arg0 func(string, string) bool) { + m.ctrl.Call(m, "ForeachBaggageItem", arg0) +} + +// ForeachBaggageItem indicates an expected call of ForeachBaggageItem +func (mr *SpanContextMockRecorder) ForeachBaggageItem(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ForeachBaggageItem", reflect.TypeOf((*SpanContext)(nil).ForeachBaggageItem), arg0) +} diff --git a/pkg/management/module.go b/pkg/management/module.go new file mode 100644 index 000000000..5d8f4f1da --- /dev/null +++ b/pkg/management/module.go @@ -0,0 +1,61 @@ +package management + +import ( + "context" + "database/sql" +) + +const ( + createConfigTableStmt = `CREATE TABLE IF NOT EXISTS realm_configuration( + id INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT, + realm_id VARCHAR(255) NOT NULL, + configuration JSON, + CHECK (configuration IS NULL OR JSON_VALID(configuration)) + ); + CREATE UNIQUE INDEX IF NOT EXISTS realm_id_idx ON realm_configuration(realm_id);` + updateConfigStmt = `INSERT INTO realm_configuration (realm_id, configuration) + VALUES (?, ?) + ON DUPLICATE KEY UPDATE configuration = ?;` + selectConfigStmt = `SELECT configuration FROM realm_configuration WHERE (realm_id = ?)` +) + +type DBConfiguration interface { + Exec(query string, args ...interface{}) (sql.Result, error) + QueryRow(query string, args ...interface{}) *sql.Row +} + +// ConfigurationDBModule is the interface of the configuration module. +type ConfigurationDBModule interface { + StoreOrUpdate(context.Context, string, string) error + GetConfiguration(context.Context, string) (string, error) +} + +type configurationDBModule struct { + db DBConfiguration +} + +// NewConfigurationDBModule returns a ConfigurationDB module. +func NewConfigurationDBModule(db DBConfiguration) ConfigurationDBModule { + db.Exec(createConfigTableStmt) + return &configurationDBModule{ + db: db, + } +} + +func (c *configurationDBModule) StoreOrUpdate(context context.Context, realmID string, configJSON string) error { + // update value in DB + _, err := c.db.Exec(updateConfigStmt, realmID, configJSON, configJSON) + return err +} + +func (c *configurationDBModule) GetConfiguration(context context.Context, realmID string) (string, error) { + var configJSON string + row := c.db.QueryRow(selectConfigStmt, realmID) + + switch err := row.Scan(&configJSON); err { + case sql.ErrNoRows: + return configJSON, nil + default: + return configJSON, err + } +} diff --git a/pkg/management/module_test.go b/pkg/management/module_test.go new file mode 100644 index 000000000..b2887f2e9 --- /dev/null +++ b/pkg/management/module_test.go @@ -0,0 +1,24 @@ +package management + +//go:generate mockgen -destination=./mock/configuration_db.go -package=mock -mock_names=DBConfiguration=DBConfiguration github.com/cloudtrust/keycloak-bridge/pkg/management DBConfiguration + +import ( + "context" + "testing" + + "github.com/cloudtrust/keycloak-bridge/pkg/management/mock" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func TestConfigurationDBModule(t *testing.T) { + var mockCtrl = gomock.NewController(t) + defer mockCtrl.Finish() + var mockDB = mock.NewDBConfiguration(mockCtrl) + + mockDB.EXPECT().Exec(gomock.Any()).Return(nil, nil).Times(1) + mockDB.EXPECT().Exec(gomock.Any(), "realmId", gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + var configDBModule = NewConfigurationDBModule(mockDB) + var err = configDBModule.StoreOrUpdate(context.Background(), "realmId", "{}") + assert.Nil(t, err) +} diff --git a/pkg/management/tracing.go b/pkg/management/tracing.go new file mode 100644 index 000000000..c2b6ffcb8 --- /dev/null +++ b/pkg/management/tracing.go @@ -0,0 +1,51 @@ +package management + +//go:generate mockgen -destination=./mock/tracing.go -package=mock -mock_names=Tracer=Tracer,Span=Span,SpanContext=SpanContext github.com/opentracing/opentracing-go Tracer,Span,SpanContext + +import ( + "context" + + opentracing "github.com/opentracing/opentracing-go" +) + +// Tracing middleware at module level. +type configDBModuleTracingMW struct { + tracer opentracing.Tracer + next ConfigurationDBModule +} + +// MakeConfigurationDBModuleTracingMW makes a tracing middleware at component level. +func MakeConfigurationDBModuleTracingMW(tracer opentracing.Tracer) func(ConfigurationDBModule) ConfigurationDBModule { + return func(next ConfigurationDBModule) ConfigurationDBModule { + return &configDBModuleTracingMW{ + tracer: tracer, + next: next, + } + } +} + +// configDBModuleTracingMW implements StatisticModule. +func (m *configDBModuleTracingMW) StoreOrUpdate(ctx context.Context, realmName string, configJSON string) error { + if span := opentracing.SpanFromContext(ctx); span != nil { + span = m.tracer.StartSpan("configurationDB_module", opentracing.ChildOf(span.Context())) + defer span.Finish() + span.SetTag("correlation_id", ctx.Value("correlation_id").(string)) + + ctx = opentracing.ContextWithSpan(ctx, span) + } + + return m.next.StoreOrUpdate(ctx, realmName, configJSON) +} + +// configDBModuleTracingMW implements StatisticModule. +func (m *configDBModuleTracingMW) GetConfiguration(ctx context.Context, realmName string) (string, error) { + if span := opentracing.SpanFromContext(ctx); span != nil { + span = m.tracer.StartSpan("configurationDB_module", opentracing.ChildOf(span.Context())) + defer span.Finish() + span.SetTag("correlation_id", ctx.Value("correlation_id").(string)) + + ctx = opentracing.ContextWithSpan(ctx, span) + } + + return m.next.GetConfiguration(ctx, realmName) +} diff --git a/pkg/management/tracing_test.go b/pkg/management/tracing_test.go new file mode 100644 index 000000000..e4c2f94b7 --- /dev/null +++ b/pkg/management/tracing_test.go @@ -0,0 +1,89 @@ +package management + +import ( + "context" + "fmt" + "math/rand" + "strconv" + "testing" + "time" + + "github.com/cloudtrust/keycloak-bridge/pkg/management/mock" + "github.com/golang/mock/gomock" + opentracing "github.com/opentracing/opentracing-go" + "github.com/stretchr/testify/assert" +) + +func TestConfigurationDBModuleMW(t *testing.T) { + var mockCtrl = gomock.NewController(t) + defer mockCtrl.Finish() + var mockConfigDBModule = mock.NewConfigurationDBModule(mockCtrl) + var mockTracer = mock.NewTracer(mockCtrl) + var mockSpan = mock.NewSpan(mockCtrl) + var mockSpanContext = mock.NewSpanContext(mockCtrl) + + var m = MakeConfigurationDBModuleTracingMW(mockTracer)(mockConfigDBModule) + + rand.Seed(time.Now().UnixNano()) + var corrID = strconv.FormatUint(rand.Uint64(), 10) + var ctx = context.WithValue(context.Background(), "correlation_id", corrID) + ctx = opentracing.ContextWithSpan(ctx, mockSpan) + + // Get configuration. + mockConfigDBModule.EXPECT().GetConfiguration(gomock.Any(), "realmID").Return("", nil).Times(1) + mockTracer.EXPECT().StartSpan("configurationDB_module", gomock.Any()).Return(mockSpan).Times(1) + mockSpan.EXPECT().Context().Return(mockSpanContext).Times(1) + mockSpan.EXPECT().Finish().Return().Times(1) + mockSpan.EXPECT().SetTag("correlation_id", corrID).Return(mockSpan).Times(1) + m.GetConfiguration(ctx, "realmID") + + // Get configuration error + mockConfigDBModule.EXPECT().GetConfiguration(gomock.Any(), "realmID").Return("", fmt.Errorf("fail")).Times(1) + mockTracer.EXPECT().StartSpan("configurationDB_module", gomock.Any()).Return(mockSpan).Times(1) + mockSpan.EXPECT().Context().Return(mockSpanContext).Times(1) + mockSpan.EXPECT().Finish().Return().Times(1) + mockSpan.EXPECT().SetTag("correlation_id", corrID).Return(mockSpan).Times(1) + m.GetConfiguration(ctx, "realmID") + + // Get configuration without tracer. + mockConfigDBModule.EXPECT().GetConfiguration(gomock.Any(), "realmID").Return("", nil).Times(1) + m.GetConfiguration(context.Background(), "realmID") + + // Get configuration without correlation ID. + mockTracer.EXPECT().StartSpan("configurationDB_module", gomock.Any()).Return(mockSpan).Times(1) + mockSpan.EXPECT().Context().Return(mockSpanContext).Times(1) + mockSpan.EXPECT().Finish().Return().Times(1) + var f = func() { + m.GetConfiguration(opentracing.ContextWithSpan(context.Background(), mockSpan), "realmID") + } + assert.Panics(t, f) + + // Store configuration. + mockConfigDBModule.EXPECT().StoreOrUpdate(gomock.Any(), "realmID", "{}").Return(nil).Times(1) + mockTracer.EXPECT().StartSpan("configurationDB_module", gomock.Any()).Return(mockSpan).Times(1) + mockSpan.EXPECT().Context().Return(mockSpanContext).Times(1) + mockSpan.EXPECT().Finish().Return().Times(1) + mockSpan.EXPECT().SetTag("correlation_id", corrID).Return(mockSpan).Times(1) + m.StoreOrUpdate(ctx, "realmID", "{}") + + // Store configuration error + mockConfigDBModule.EXPECT().StoreOrUpdate(gomock.Any(), "realmID", "{}").Return(fmt.Errorf("fail")).Times(1) + mockTracer.EXPECT().StartSpan("configurationDB_module", gomock.Any()).Return(mockSpan).Times(1) + mockSpan.EXPECT().Context().Return(mockSpanContext).Times(1) + mockSpan.EXPECT().Finish().Return().Times(1) + mockSpan.EXPECT().SetTag("correlation_id", corrID).Return(mockSpan).Times(1) + m.StoreOrUpdate(ctx, "realmID", "{}") + + // Get configuration without tracer. + mockConfigDBModule.EXPECT().StoreOrUpdate(gomock.Any(), "realmID", "{}").Return(nil).Times(1) + m.StoreOrUpdate(context.Background(), "realmID", "{}") + + // Get configuration without correlation ID. + mockTracer.EXPECT().StartSpan("configurationDB_module", gomock.Any()).Return(mockSpan).Times(1) + mockSpan.EXPECT().Context().Return(mockSpanContext).Times(1) + mockSpan.EXPECT().Finish().Return().Times(1) + f = func() { + m.StoreOrUpdate(opentracing.ContextWithSpan(context.Background(), mockSpan), "realmID", "{}") + } + assert.Panics(t, f) +} diff --git a/pkg/middleware/mock/management_component.go b/pkg/middleware/mock/management_component.go index e7c135036..455cfacab 100644 --- a/pkg/middleware/mock/management_component.go +++ b/pkg/middleware/mock/management_component.go @@ -174,6 +174,19 @@ func (mr *ManagementComponentMockRecorder) GetRealm(arg0, arg1 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRealm", reflect.TypeOf((*ManagementComponent)(nil).GetRealm), arg0, arg1) } +// GetRealmCustomConfiguration mocks base method +func (m *ManagementComponent) GetRealmCustomConfiguration(arg0 context.Context, arg1 string) (management.RealmCustomConfiguration, error) { + ret := m.ctrl.Call(m, "GetRealmCustomConfiguration", arg0, arg1) + ret0, _ := ret[0].(management.RealmCustomConfiguration) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRealmCustomConfiguration indicates an expected call of GetRealmCustomConfiguration +func (mr *ManagementComponentMockRecorder) GetRealmCustomConfiguration(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRealmCustomConfiguration", reflect.TypeOf((*ManagementComponent)(nil).GetRealmCustomConfiguration), arg0, arg1) +} + // GetRealmRolesForUser mocks base method func (m *ManagementComponent) GetRealmRolesForUser(arg0 context.Context, arg1, arg2 string) ([]management.RoleRepresentation, error) { ret := m.ctrl.Call(m, "GetRealmRolesForUser", arg0, arg1, arg2) @@ -299,6 +312,18 @@ func (mr *ManagementComponentMockRecorder) SendVerifyEmail(arg0, arg1, arg2 inte return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendVerifyEmail", reflect.TypeOf((*ManagementComponent)(nil).SendVerifyEmail), varargs...) } +// UpdateRealmCustomConfiguration mocks base method +func (m *ManagementComponent) UpdateRealmCustomConfiguration(arg0 context.Context, arg1 string, arg2 management.RealmCustomConfiguration) error { + ret := m.ctrl.Call(m, "UpdateRealmCustomConfiguration", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateRealmCustomConfiguration indicates an expected call of UpdateRealmCustomConfiguration +func (mr *ManagementComponentMockRecorder) UpdateRealmCustomConfiguration(arg0, arg1, arg2 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateRealmCustomConfiguration", reflect.TypeOf((*ManagementComponent)(nil).UpdateRealmCustomConfiguration), arg0, arg1, arg2) +} + // UpdateUser mocks base method func (m *ManagementComponent) UpdateUser(arg0 context.Context, arg1, arg2 string, arg3 management.UserRepresentation) error { ret := m.ctrl.Call(m, "UpdateUser", arg0, arg1, arg2, arg3)