diff --git a/api/types/database.go b/api/types/database.go index 173fc7458a91c..72c8019d1b013 100644 --- a/api/types/database.go +++ b/api/types/database.go @@ -314,6 +314,11 @@ func (d *DatabaseV3) SupportsAutoUsers() bool { case DatabaseTypeSelfHosted, DatabaseTypeRDS: return true } + case DatabaseProtocolMySQL: + switch d.GetType() { + case DatabaseTypeSelfHosted, DatabaseTypeRDS: + return true + } } return false } diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index 028adfb30db78..9c41e36c08b89 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -2650,13 +2650,41 @@ func withAzurePostgres(name, authToken string) withDatabaseOption { } } -func withSelfHostedMySQL(name string, opts ...mysql.TestServerOption) withDatabaseOption { +type selfHostedMySQLOptions struct { + serverOptions []mysql.TestServerOption + databaseOptions []databaseOption +} + +type selfHostedMySQLOption func(*selfHostedMySQLOptions) + +func withMySQLServerVersion(version string) selfHostedMySQLOption { + return func(opts *selfHostedMySQLOptions) { + opts.serverOptions = append(opts.serverOptions, mysql.WithServerVersion(version)) + } +} + +func withMySQLAdminUser(username string) selfHostedMySQLOption { + return func(opts *selfHostedMySQLOptions) { + opts.databaseOptions = append(opts.databaseOptions, func(db *types.DatabaseV3) { + db.Spec.AdminUser = &types.DatabaseAdminUser{ + Name: username, + } + }) + } +} + +func withSelfHostedMySQL(name string, applyOpts ...selfHostedMySQLOption) withDatabaseOption { return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { + opts := selfHostedMySQLOptions{} + for _, applyOpt := range applyOpts { + applyOpt(&opts) + } + mysqlServer, err := mysql.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, ClientAuth: tls.RequireAndVerifyClientCert, - }, opts...) + }, opts.serverOptions...) require.NoError(t, err) go mysqlServer.Serve() t.Cleanup(func() { @@ -2670,6 +2698,11 @@ func withSelfHostedMySQL(name string, opts ...mysql.TestServerOption) withDataba DynamicLabels: dynamicLabels, }) require.NoError(t, err) + + for _, applyDatabaseOpt := range opts.databaseOptions { + applyDatabaseOpt(database) + } + testCtx.mysql[name] = testMySQL{ db: mysqlServer, resource: database, diff --git a/lib/srv/db/autousers_test.go b/lib/srv/db/autousers_test.go index 6fa350ed40e74..2ec2985492933 100644 --- a/lib/srv/db/autousers_test.go +++ b/lib/srv/db/autousers_test.go @@ -75,3 +75,56 @@ func TestAutoUsersPostgres(t *testing.T) { t.Fatal("user not deactivated after 5s") } } + +func TestAutoUsersMySQL(t *testing.T) { + ctx := context.Background() + testCtx := setupTestContext(ctx, t, withSelfHostedMySQL("mysql", withMySQLAdminUser("admin"))) + go testCtx.startHandlingConnections() + + // Use a long name to test hashed name is used in database. + teleportUser := "a.very.long.name@teleport.example.com" + wantDatabaseUser := "tp-ZLhdP1FgxXsUvcVpG8ucVm/PCHg" + + // Create user with role that allows user provisioning. + _, role, err := auth.CreateUserAndRole(testCtx.tlsServer.Auth(), teleportUser, []string{"auto"}, nil) + require.NoError(t, err) + options := role.GetOptions() + options.CreateDatabaseUser = types.NewBoolOption(true) + role.SetOptions(options) + role.SetDatabaseRoles(types.Allow, []string{"reader", "writer"}) + role.SetDatabaseNames(types.Allow, []string{"*"}) + err = testCtx.tlsServer.Auth().UpsertRole(ctx, role) + require.NoError(t, err) + + // DatabaseUser must match identity. + _, err = testCtx.mysqlClient(teleportUser, "mysql", "user1") + require.Error(t, err) + + // Try to connect to the database as this user. + mysqlConn, err := testCtx.mysqlClient(teleportUser, "mysql", teleportUser) + require.NoError(t, err) + + select { + case e := <-testCtx.mysql["mysql"].db.UserEventsCh(): + require.Equal(t, teleportUser, e.TeleportUser) + require.Equal(t, wantDatabaseUser, e.DatabaseUser) + require.Equal(t, []string{"reader", "writer"}, e.Roles) + require.True(t, e.Active) + case <-time.After(5 * time.Second): + t.Fatal("user not activated after 5s") + } + + // Disconnect. + err = mysqlConn.Close() + require.NoError(t, err) + + // Verify user was deactivated. + select { + case e := <-testCtx.mysql["mysql"].db.UserEventsCh(): + require.Equal(t, teleportUser, e.TeleportUser) + require.Equal(t, wantDatabaseUser, e.DatabaseUser) + require.False(t, e.Active) + case <-time.After(5 * time.Second): + t.Fatal("user not deactivated after 5s") + } +} diff --git a/lib/srv/db/common/autousers.go b/lib/srv/db/common/autousers.go index 440cbb6280f94..1ff700c34edec 100644 --- a/lib/srv/db/common/autousers.go +++ b/lib/srv/db/common/autousers.go @@ -25,6 +25,7 @@ import ( "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/auth" @@ -68,6 +69,9 @@ func (a *UserProvisioner) Activate(ctx context.Context, sessionCtx *Session) (fu "your Teleport administrator") } + // Observe. + defer methodCallMetrics("UserProvisioner:Activate", teleport.ComponentDatabase, sessionCtx.Database)() + retryCtx, cancel := context.WithTimeout(ctx, defaults.DatabaseConnectTimeout) defer cancel() @@ -101,6 +105,9 @@ func (a *UserProvisioner) Deactivate(ctx context.Context, sessionCtx *Session) e return nil } + // Observe. + defer methodCallMetrics("UserProvisioner:Deactivate", teleport.ComponentDatabase, sessionCtx.Database)() + retryCtx, cancel := context.WithTimeout(ctx, defaults.DatabaseConnectTimeout) defer cancel() diff --git a/lib/srv/db/common/session.go b/lib/srv/db/common/session.go index 95f9e8afd951b..cb9c821b37c80 100644 --- a/lib/srv/db/common/session.go +++ b/lib/srv/db/common/session.go @@ -84,3 +84,11 @@ func (c *Session) WithUser(user string) *Session { copy.DatabaseUser = user return © } + +// WithUserAndDatabase returns a shallow copy of the session with overridden +// database user and overridden database name. +func (c *Session) WithUserAndDatabase(user string, defaultDatabase string) *Session { + copy := c.WithUser(user) + copy.DatabaseName = defaultDatabase + return copy +} diff --git a/lib/srv/db/mysql/autousers.go b/lib/srv/db/mysql/autousers.go new file mode 100644 index 0000000000000..8007fc18fe16c --- /dev/null +++ b/lib/srv/db/mysql/autousers.go @@ -0,0 +1,489 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mysql + +import ( + "context" + "crypto/sha1" + _ "embed" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strings" + + "github.com/coreos/go-semver/semver" + "github.com/go-mysql-org/go-mysql/client" + "github.com/go-mysql-org/go-mysql/mysql" + "github.com/gravitational/trace" + "github.com/sirupsen/logrus" + "golang.org/x/exp/slices" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/srv/db/common" +) + +// activateUserDetails contains details about the user activation request and +// will be marshaled into a JSON parameter for the stored procedure. +type activateUserDetails struct { + // Roles is a list of roles to be assigned to the user. + // + // MySQL stored procedure does not accept array of strings thus using a + // JSON to bypass this limit. + Roles []string `json:"roles"` + // AuthOptions specifies auth options like "IDENTIFIED xxx" used when + // creating a new user. + // + // Using a JSON string can bypass VARCHAR character limit. + AuthOptions string `json:"auth_options"` + // Attributes specifies user attributes used when creating a new user. + // + // User attributes is a MySQL JSON in MySQL databases. + // + // To check current user's attribute: + // SELECT * FROM INFORMATION_SCHEMA.USER_ATTRIBUTES WHERE CONCAT(USER, '@', HOST) = current_user() + // + // Reference: + // https://dev.mysql.com/doc/refman/8.0/en/information-schema-user-attributes-table.html + Attributes struct { + // User is the original Teleport user name. + // + // Find a Teleport user (with "admin" privilege): + // SELECT * FROM INFORMATION_SCHEMA.USER_ATTRIBUTES WHERE ATTRIBUTE->"$.user" = "teleport-user-name"; + User string `json:"user"` + } `json:"attributes"` +} + +// clientConn is a wrapper of client.Conn. +type clientConn struct { + *client.Conn +} + +func (c *clientConn) executeAndCloseResult(command string, args ...any) error { + result, err := c.Execute(command, args...) + if result != nil { + result.Close() + } + return trace.Wrap(err) +} + +func (c *clientConn) isMariaDB() bool { + return strings.Contains(strings.ToLower(c.GetServerVersion()), "mariadb") +} + +// maxUsernameLength returns the username/role character limit. +func (c *clientConn) maxUsernameLength() int { + if c.isMariaDB() { + return mariadbMaxUsernameLength + } + return mysqlMaxUsernameLength +} + +// ActivateUser creates or enables the database user. +func (e *Engine) ActivateUser(ctx context.Context, sessionCtx *common.Session) error { + if sessionCtx.Database.GetAdminUser() == "" { + return trace.BadParameter("Teleport does not have admin user configured for this database") + } + + conn, err := e.connectAsAdminUser(ctx, sessionCtx) + if err != nil { + return trace.Wrap(err) + } + defer conn.Close() + + // Ensure version is supported. + if err := checkSupportedVersion(conn); err != nil { + return trace.Wrap(err) + } + + // Ensure the roles meet spec. + if err := checkRoles(conn, sessionCtx.DatabaseRoles); err != nil { + return trace.Wrap(err) + } + + // Setup "teleport-auto-user" and stored procedures. + if err := e.setupDatabaseForAutoUsers(conn, sessionCtx); err != nil { + return trace.Wrap(err) + } + + // Use "tp-" in case DatabaseUser is over max username length. + sessionCtx.DatabaseUser = maybeHashUsername(sessionCtx.DatabaseUser, conn.maxUsernameLength()) + e.Log.Infof("Activating MySQL user %q with roles %v for %v.", sessionCtx.DatabaseUser, sessionCtx.DatabaseRoles, sessionCtx.Identity.Username) + + // Prep JSON. + details, err := makeActivateUserDetails(sessionCtx, sessionCtx.Identity.Username) + if err != nil { + return trace.Wrap(err) + } + + // Call activate. + err = conn.executeAndCloseResult( + fmt.Sprintf("CALL %s(?, ?)", activateUserProcedureName), + sessionCtx.DatabaseUser, + details, + ) + if err == nil { + return nil + } + + e.Log.Debugf("Call teleport_activate_user failed: %v", err) + return trace.Wrap(convertActivateError(sessionCtx, err)) +} + +// DeactivateUser disables the database user. +func (e *Engine) DeactivateUser(ctx context.Context, sessionCtx *common.Session) error { + if sessionCtx.Database.GetAdminUser() == "" { + return trace.BadParameter("Teleport does not have admin user configured for this database") + } + + conn, err := e.connectAsAdminUser(ctx, sessionCtx) + if err != nil { + return trace.Wrap(err) + } + defer conn.Close() + + e.Log.Infof("Deactivating MySQL user %q for %v.", sessionCtx.DatabaseUser, sessionCtx.Identity.Username) + + err = conn.executeAndCloseResult( + fmt.Sprintf("CALL %s(?)", deactivateUserProcedureName), + sessionCtx.DatabaseUser, + ) + + if getSQLState(err) == sqlStateActiveUser { + e.Log.Debugf("Failed to deactivate user %q: %v.", sessionCtx.DatabaseUser, err) + return nil + } + return trace.Wrap(err) +} + +func (e *Engine) connectAsAdminUser(ctx context.Context, sessionCtx *common.Session) (*clientConn, error) { + adminSessionCtx := sessionCtx.WithUserAndDatabase( + sessionCtx.Database.GetAdminUser(), + defaultSchema(sessionCtx), + ) + conn, err := e.connect(ctx, adminSessionCtx) + if err != nil { + return nil, trace.Wrap(err) + } + return &clientConn{ + Conn: conn, + }, nil +} + +func (e *Engine) setupDatabaseForAutoUsers(conn *clientConn, sessionCtx *common.Session) error { + // TODO MariaDB requires separate stored procedures to handle auto user: + // - Max user length is different. + // - MariaDB uses mysql.roles_mapping instead of mysql.role_edges. + // - MariaDB cannot set all roles as default role at the same time. + // - MariaDB does not have user attributes. Will need another way for + // saving original Teleport user names. For example, a separate table can + // be used to track User -> JSON attribute mapping (protected view with + // row level security can be used in addition so each user can only read + // their own attributes, if needed). + if conn.isMariaDB() { + return trace.NotImplemented("auto user provisioning is not supported for MariaDB yet") + } + + // Create "teleport-auto-user". + err := conn.executeAndCloseResult(fmt.Sprintf("CREATE ROLE IF NOT EXISTS %q", teleportAutoUserRole)) + if err != nil { + return trace.Wrap(err) + } + + // There is no single command in MySQL to "CREATE OR REPLACE". Instead, + // have to DROP first before CREATE. + // + // To speed up the setup, the procedure "version" is stored as the + // procedure comment. So check if an update is necessary first by checking + // these comments. + // + // To force an update, drop one of the procedures or update the comment: + // ALTER PROCEDURE teleport_activate_user COMMENT 'need update' + if required, err := isProcedureUpdateRequired(conn, defaultSchema(sessionCtx), procedureVersion); err != nil { + return trace.Wrap(err) + } else if !required { + return nil + } + + // If update is necessary, do a transaction. + e.Log.Debugf("Updating stored procedures for MySQL server %s.", sessionCtx.Database.GetName()) + return trace.Wrap(doTransaction(conn, func() error { + for _, procedure := range allProcedures { + dropCommand := fmt.Sprintf("DROP PROCEDURE IF EXISTS %s", procedure.name) + updateCommand := fmt.Sprintf("ALTER PROCEDURE %s COMMENT %q", procedure.name, procedureVersion) + + if err := conn.executeAndCloseResult(dropCommand); err != nil { + return trace.Wrap(err) + } + if err := conn.executeAndCloseResult(procedure.createCommand); err != nil { + return trace.Wrap(err) + } + if err := conn.executeAndCloseResult(updateCommand); err != nil { + return trace.Wrap(err) + } + } + return nil + })) +} + +func getSQLState(err error) string { + var mysqlError *mysql.MyError + if !errors.As(err, &mysqlError) { + return "" + } + return mysqlError.State +} + +func convertActivateError(sessionCtx *common.Session, err error) error { + // This operation-failed message usually appear when the user already + // exists. A different error would be raised if the admin user has no + // permission to "CREATE USER". + if strings.Contains(err.Error(), "Operation CREATE USER failed") { + return trace.AlreadyExists("user %q already exists in this MySQL database and is not managed by Teleport", sessionCtx.DatabaseUser) + } + + switch getSQLState(err) { + case sqlStateUsernameDoesNotMatch: + return trace.AlreadyExists("username %q (Teleport user %q) already exists in this MySQL database and is used for another Teleport user.", sessionCtx.Identity.Username, sessionCtx.DatabaseUser) + + case sqlStateRolesChanged: + return trace.CompareFailed("roles for user %q has changed. Please quit all active connections and try again.", sessionCtx.Identity.Username) + + default: + return trace.Wrap(err) + } +} + +// defaultSchema returns the default database to log into as the admin user. +// +// Use a default database/schema to make sure procedures are always created and +// called from there (and possibly store other data there in the future). +// +// This also avoids "No database selected" errors if client doesn't provide +// one. +func defaultSchema(_ *common.Session) string { + // Aurora MySQL does not allow procedures on built-in "mysql" database. + // Technically we can use another built-in database like "sys". However, + // AWS (or database admins for self-hosted) may restrict permissions on + // these built-in databases eventually. Well, each built-in database has + // its own purpose. + // + // Thus lets use a teleport-specific database. This database should be + // created when configuring the admin user. The admin user should be + // granted the following permissions for this database: + // GRANT ALTER ROUTINE, CREATE ROUTINE, EXECUTE ON teleport.* TO '' + // + // TODO consider allowing user to specify the default database through database + // definition/labels. + return "teleport" +} + +func checkRoles(conn *clientConn, roles []string) error { + maxRoleLength := conn.maxUsernameLength() + for _, role := range roles { + if len(role) > maxRoleLength { + return trace.BadParameter("role %q exceeds maximum length limit of %d", role, maxRoleLength) + } + } + return nil +} + +func checkSupportedVersion(conn *clientConn) error { + if conn.isMariaDB() { + return trace.NotImplemented("auto user provisioning is not supported for MariaDB yet") + } + return trace.Wrap(checkMySQLSupportedVersion(conn.GetServerVersion())) +} + +func checkMySQLSupportedVersion(serverVersion string) error { + ver, err := semver.NewVersion(serverVersion) + switch { + case err != nil: + logrus.Debugf("Invalid MySQL server version %q. Assuming role management is supported.", serverVersion) + return nil + + // Reference: + // https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-0.html#mysqld-8-0-0-account-management + case ver.Major < 8: + return trace.BadParameter("role management is not supported for MySQL servers older than 8.0") + + default: + return nil + } +} + +func maybeHashUsername(teleportUser string, maxUsernameLength int) string { + if len(teleportUser) <= maxUsernameLength { + return teleportUser + } + + // Use sha1 to reduce chance of collision. + hash := sha1.New() + hash.Write([]byte(teleportUser)) + + // Use a prefix to identify the user is managed by Teleport. + return "tp-" + base64.RawStdEncoding.EncodeToString(hash.Sum(nil)) +} + +func authOptions(sessionCtx *common.Session) string { + switch sessionCtx.Database.GetType() { + case types.DatabaseTypeRDS: + return `IDENTIFIED WITH AWSAuthenticationPlugin AS "RDS"` + + case types.DatabaseTypeSelfHosted: + return fmt.Sprintf(`REQUIRE SUBJECT "/CN=%s"`, sessionCtx.DatabaseUser) + + default: + return "" + } +} + +func makeActivateUserDetails(sessionCtx *common.Session, teleportUser string) (json.RawMessage, error) { + details := activateUserDetails{ + Roles: sessionCtx.DatabaseRoles, + AuthOptions: authOptions(sessionCtx), + } + + // Save original username as user attributes in case the name is hashed. + details.Attributes.User = teleportUser + + data, err := json.Marshal(details) + if err != nil { + return nil, trace.Wrap(err) + } + return json.RawMessage(data), nil +} + +func isProcedureUpdateRequired(conn *clientConn, wantSchema, wantVersion string) (bool, error) { + // information_schema.routines is accessible for users/roles with EXECUTE + // permission. + result, err := conn.Execute(fmt.Sprintf( + "SELECT ROUTINE_NAME FROM information_schema.routines WHERE ROUTINE_SCHEMA = %q AND ROUTINE_COMMENT = %q", + wantSchema, + wantVersion, + )) + if err != nil { + return false, trace.Wrap(err) + } + defer result.Close() + + if result.RowNumber() < len(allProcedures) { + return true, nil + } + + // Paranoia, make sure the names match. + foundProcedures := make([]string, 0, result.RowNumber()) + for row := range result.Values { + procedure, err := result.GetString(row, 0) + if err != nil { + return false, trace.Wrap(err) + } + + foundProcedures = append(foundProcedures, procedure) + } + return !allProceduresFound(foundProcedures), nil +} + +func allProceduresFound(foundProcedures []string) bool { + for _, wantProcedure := range allProcedures { + if !slices.Contains(foundProcedures, wantProcedure.name) { + return false + } + } + return true +} + +func doTransaction(conn *clientConn, do func() error) error { + if err := conn.Begin(); err != nil { + return trace.Wrap(err) + } + + if err := do(); err != nil { + return trace.NewAggregate(err, conn.Rollback()) + } + + return trace.Wrap(conn.Commit()) +} + +const ( + // procedureVersion is a hard-coded string that is set as procedure + // comments to indicate the procedure version. + procedureVersion = "teleport-auto-user-v1" + + // mysqlMaxUsernameLength is the maximum username length for MySQL. + // + // https://dev.mysql.com/doc/refman/8.0/en/user-names.html + mysqlMaxUsernameLength = 32 + // mariadbMaxUsernameLength is the maximum username length for MariaDB. + // + // https://mariadb.com/kb/en/identifier-names/#maximum-length + mariadbMaxUsernameLength = 80 + + // teleportAutoUserRole is the name of a MySQL role that all Teleport + // managed users will be a part of. + // + // To find all users that assigned this role: + // SELECT TO_USER AS 'Teleport Managed Users' FROM mysql.role_edges WHERE FROM_USER = 'teleport-auto-user' + teleportAutoUserRole = "teleport-auto-user" + + // sqlStateActiveUser is the SQLSTATE raised by deactivation procedure when + // user has active connections. + // + // SQLSTATE reference: + // https://en.wikipedia.org/wiki/SQLSTATE + sqlStateActiveUser = "TP000" + // sqlStateUsernameDoesNotMatch is the SQLSTATE raised by activation + // procedure when the Teleport username does not match user's attributes. + // + // Possibly there is a hash collision, or someone manually updated the user + // attributes. + sqlStateUsernameDoesNotMatch = "TP001" + // sqlStateRolesChanged is the SQLSTATE raised by activation procedure when + // the user has active connections but roles has changed. + sqlStateRolesChanged = "TP002" + + revokeRolesProcedureName = "teleport_revoke_roles" + activateUserProcedureName = "teleport_activate_user" + deactivateUserProcedureName = "teleport_deactivate_user" +) + +var ( + //go:embed mysql_activate_user.sql + activateUserProcedure string + //go:embed mysql_deactivate_user.sql + deactivateUserProcedure string + //go:embed mysql_revoke_roles.sql + revokeRolesProcedure string + + allProcedures = []struct { + name string + createCommand string + }{ + { + name: revokeRolesProcedureName, + createCommand: revokeRolesProcedure, + }, + { + name: activateUserProcedureName, + createCommand: activateUserProcedure, + }, + { + name: deactivateUserProcedureName, + createCommand: deactivateUserProcedure, + }, + } +) diff --git a/lib/srv/db/mysql/autousers_test.go b/lib/srv/db/mysql/autousers_test.go new file mode 100644 index 0000000000000..614c468d5c019 --- /dev/null +++ b/lib/srv/db/mysql/autousers_test.go @@ -0,0 +1,182 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mysql + +import ( + "testing" + + "github.com/go-mysql-org/go-mysql/mysql" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/srv/db/common" + "github.com/gravitational/teleport/lib/tlsca" +) + +func Test_maybeHashUsername(t *testing.T) { + tests := []struct { + input string + wantOutput string + }{ + { + input: "short-name", + wantOutput: "short-name", + }, + { + input: "a-very-very-very-long-name-that-is-over-32", + wantOutput: "tp-XnfKd0MysfJ/xaR/b3OgoQvoTuo", + }, + } + + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + output := maybeHashUsername(test.input, mysqlMaxUsernameLength) + require.Equal(t, test.wantOutput, output) + require.Less(t, len(output), mysqlMaxUsernameLength) + }) + } +} + +func Test_makeActivateUserDetails(t *testing.T) { + rds, err := types.NewDatabaseV3(types.Metadata{ + Name: "RDS", + }, types.DatabaseSpecV3{ + Protocol: defaults.ProtocolMySQL, + URI: "aurora-instance-1.abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432", + }) + require.NoError(t, err) + + teleportUsername := "a-very-very-very-long-name-that-is-over-32" + details, err := makeActivateUserDetails( + &common.Session{ + Database: rds, + DatabaseUser: maybeHashUsername(teleportUsername, mysqlMaxUsernameLength), + DatabaseRoles: []string{"role", "role2"}, + }, + teleportUsername, + ) + require.NoError(t, err) + + wantOutput := `{"roles":["role","role2"],"auth_options":"IDENTIFIED WITH AWSAuthenticationPlugin AS \"RDS\"","attributes":{"user":"a-very-very-very-long-name-that-is-over-32"}}` + require.Equal(t, wantOutput, string(details)) +} + +func Test_convertActivateError(t *testing.T) { + sessionCtx := &common.Session{ + DatabaseUser: "user1", + Identity: tlsca.Identity{ + Username: "user1", + }, + } + + createUserFailedError := &mysql.MyError{ + Code: mysql.ER_CANNOT_USER, + State: "HY000", + Message: `Operation CREATE USER failed for 'user1'@'%'`, + } + usernameDoesNotMatchError := &mysql.MyError{ + Code: mysql.ER_SIGNAL_EXCEPTION, + State: sqlStateUsernameDoesNotMatch, + Message: `Teleport username does not match user attributes`, + } + rolesChangedError := &mysql.MyError{ + Code: mysql.ER_SIGNAL_EXCEPTION, + State: sqlStateRolesChanged, + Message: `user has active connections and roles have changed`, + } + // Currently not converted to trace.AccessDeined as it may conflict with + // common.ConvertConnectError. + permissionError := &mysql.MyError{ + Code: mysql.ER_SPECIFIC_ACCESS_DENIED_ERROR, + State: "42000", + Message: `Access denied; you need (at least one of) the CREATE USER privilege(s) for this operation`, + } + + tests := []struct { + name string + input error + errorIs func(error) bool + errorContains string + }{ + { + name: "create user failed", + input: createUserFailedError, + errorIs: trace.IsAlreadyExists, + errorContains: "is not managed by Teleport", + }, + { + name: "username does not match", + input: usernameDoesNotMatchError, + errorIs: trace.IsAlreadyExists, + errorContains: "used for another Teleport user", + }, + { + name: "roles changed", + input: trace.Wrap(rolesChangedError), + errorIs: trace.IsCompareFailed, + errorContains: "quit all active connections", + }, + { + name: "no permission", + input: trace.Wrap(permissionError), + errorIs: func(err error) bool { + // Not converted. + return trace.Unwrap(err) == permissionError + }, + errorContains: permissionError.Message, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + converted := convertActivateError(sessionCtx, test.input) + require.True(t, test.errorIs(converted)) + require.Contains(t, converted.Error(), test.errorContains) + }) + } +} + +func Test_checkMySQLSupportedVersion(t *testing.T) { + tests := []struct { + input string + checkError require.ErrorAssertionFunc + }{ + { + input: "invalid-server-version", + checkError: require.NoError, + }, + { + input: "8.0.28", + checkError: require.NoError, + }, + { + input: "9.0.0", + checkError: require.NoError, + }, + { + input: "5.7.42", + checkError: require.Error, + }, + } + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + test.checkError(t, checkMySQLSupportedVersion(test.input)) + }) + } +} diff --git a/lib/srv/db/mysql/engine.go b/lib/srv/db/mysql/engine.go index b7fdc7751c392..249f8254bb467 100644 --- a/lib/srv/db/mysql/engine.go +++ b/lib/srv/db/mysql/engine.go @@ -89,9 +89,23 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio if err != nil { return trace.Wrap(err) } + + // Automatically create the database user if needed. + cancelAutoUserLease, err := e.GetUserProvisioner(e).Activate(ctx, sessionCtx) + if err != nil { + return trace.Wrap(err) + } + defer func() { + err := e.GetUserProvisioner(e).Deactivate(ctx, sessionCtx) + if err != nil { + e.Log.WithError(err).Error("Failed to deactivate the user.") + } + }() + // Establish connection to the MySQL server. serverConn, err := e.connect(ctx, sessionCtx) if err != nil { + defer cancelAutoUserLease() if trace.IsLimitExceeded(err) { return trace.LimitExceeded("could not connect to the database, please try again later") } @@ -104,11 +118,15 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio } }() + // Release the auto-users semaphore now that we've successfully connected. + cancelAutoUserLease() + // Internally, updateServerVersion() updates databases only when database version // is not set, or it has changed since previous call. if err := e.updateServerVersion(sessionCtx, serverConn); err != nil { // Log but do not fail connection if the version update fails. e.Log.WithError(err).Warnf("Failed to update the MySQL server version.") + } // Send back OK packet to indicate auth/connect success. At this point @@ -154,17 +172,28 @@ func (e *Engine) updateServerVersion(sessionCtx *common.Session, serverConn *cli // checkAccess does authorization check for MySQL connection about to be established. func (e *Engine) checkAccess(ctx context.Context, sessionCtx *common.Session) error { + // When using auto-provisioning, force the database username to be same + // as Teleport username. If it's not provided explicitly, some database + // clients get confused and display incorrect username. + if sessionCtx.AutoCreateUser { + if sessionCtx.DatabaseUser != sessionCtx.Identity.Username { + return trace.AccessDenied("please use your Teleport username (%q) to connect instead of %q", + sessionCtx.Identity.Username, sessionCtx.DatabaseUser) + } + } + authPref, err := e.Auth.GetAuthPreference(ctx) if err != nil { return trace.Wrap(err) } state := sessionCtx.GetAccessState(authPref) - dbRoleMatchers := role.DatabaseRoleMatchers( - sessionCtx.Database, - sessionCtx.DatabaseUser, - sessionCtx.DatabaseName, - ) + dbRoleMatchers := role.GetDatabaseRoleMatchers(role.RoleMatchersConfig{ + Database: sessionCtx.Database, + DatabaseUser: sessionCtx.DatabaseUser, + DatabaseName: sessionCtx.DatabaseName, + AutoCreateUser: sessionCtx.AutoCreateUser, + }) err = sessionCtx.Checker.CheckAccess( sessionCtx.Database, state, diff --git a/lib/srv/db/mysql/mysql_activate_user.sql b/lib/srv/db/mysql/mysql_activate_user.sql new file mode 100644 index 0000000000000..1df03c72b2e27 --- /dev/null +++ b/lib/srv/db/mysql/mysql_activate_user.sql @@ -0,0 +1,75 @@ +CREATE PROCEDURE teleport_activate_user(IN username VARCHAR(32), IN details JSON) +proc_label:BEGIN + DECLARE is_auto_user INT DEFAULT 0; + DECLARE is_active INT DEFAULT 0; + DECLARE is_same_user INT DEFAULT 0; + DECLARE are_roles_same INT DEFAULT 0; + DECLARE role_index INT DEFAULT 0; + DECLARE role VARCHAR(32) DEFAULT ''; + DECLARE cur_roles TEXT DEFAULT ''; + SET @roles = details->"$.roles"; + SET @teleport_user = details->>"$.attributes.user"; + + -- If the user already exists and was provisioned by Teleport, reactivate + -- it, otherwise provision a new one. + SELECT COUNT(TO_USER) INTO is_auto_user FROM mysql.role_edges WHERE FROM_USER = 'teleport-auto-user' AND TO_USER = username; + IF is_auto_user = 1 THEN + SELECT COUNT(USER) INTO is_same_user FROM INFORMATION_SCHEMA.USER_ATTRIBUTES WHERE USER = username AND ATTRIBUTE->"$.user" = @teleport_user; + IF is_same_user = 0 THEN + SIGNAL SQLSTATE 'TP001' SET MESSAGE_TEXT = 'Teleport username does not match user attributes'; + END IF; + + SELECT COUNT(USER) INTO is_active FROM information_schema.processlist WHERE USER = username; + + -- If the user has active connections, make sure the provided roles + -- match what the user currently has. + IF is_active = 1 THEN + SELECT json_arrayagg(FROM_USER) INTO cur_roles FROM mysql.role_edges WHERE FROM_USER != 'teleport-auto-user' AND TO_USER = username; + SELECT @roles = cur_roles INTO are_roles_same; + IF are_roles_same = 1 THEN + LEAVE proc_label; + ELSE + SIGNAL SQLSTATE 'TP002' SET MESSAGE_TEXT = 'user has active connections and roles have changed'; + END IF; + END IF; + + -- Otherwise reactivate the user, but first strip if of all roles to + -- account for scenarios with left-over roles if database agent crashed + -- and failed to cleanup upon session termination. + CALL teleport_revoke_roles(username); + + -- Ensure the user is unlocked. User is locked at deactivation. + SET @sql := CONCAT_WS(' ', 'ALTER USER', QUOTE(username), 'ACCOUNT UNLOCK'); + PREPARE stmt FROM @sql; + EXECUTE stmt; + DEALLOCATE PREPARE stmt; + ELSE + SET @sql := CONCAT_WS(' ', 'CREATE USER', QUOTE(username), details->>"$.auth_options", 'ATTRIBUTE', QUOTE(details->"$.attributes")); + PREPARE stmt FROM @sql; + EXECUTE stmt; + DEALLOCATE PREPARE stmt; + + SET @sql := CONCAT_WS(' ', 'GRANT', QUOTE('teleport-auto-user'), 'TO', QUOTE(username)); + PREPARE stmt FROM @sql; + EXECUTE stmt; + DEALLOCATE PREPARE stmt; + END IF; + + -- Assign roles. + WHILE role_index < JSON_LENGTH(@roles) DO + SELECT JSON_EXTRACT(@roles, CONCAT('$[',role_index,']')) INTO role; + SELECT role_index + 1 INTO role_index; + + -- role extracted from JSON already has double quotes. + SET @sql := CONCAT_WS(' ', 'GRANT', role, 'TO', QUOTE(username)); + PREPARE stmt FROM @sql; + EXECUTE stmt; + DEALLOCATE PREPARE stmt; + END WHILE; + + -- Ensure all assigned roles are available to use right after connection. + SET @sql := CONCAT('SET DEFAULT ROLE ALL TO ', QUOTE(username)); + PREPARE stmt FROM @sql; + EXECUTE stmt; + DEALLOCATE PREPARE stmt; +END diff --git a/lib/srv/db/mysql/mysql_deactivate_user.sql b/lib/srv/db/mysql/mysql_deactivate_user.sql new file mode 100644 index 0000000000000..30a3aae3d7f99 --- /dev/null +++ b/lib/srv/db/mysql/mysql_deactivate_user.sql @@ -0,0 +1,17 @@ +CREATE PROCEDURE teleport_deactivate_user(IN username VARCHAR(32)) +BEGIN + DECLARE is_active INT DEFAULT 0; + SELECT COUNT(USER) INTO is_active FROM information_schema.processlist WHERE USER = username; + IF is_active = 1 THEN + -- Throw a custom error code when user is still active from other sessions. + SIGNAL SQLSTATE 'TP000' SET MESSAGE_TEXT = 'User has active connections'; + ELSE + -- Lock the user then revoke all the roles. + SET @sql := CONCAT_WS(' ', 'ALTER USER', QUOTE(username), 'ACCOUNT LOCK'); + PREPARE stmt FROM @sql; + EXECUTE stmt; + DEALLOCATE PREPARE stmt; + + CALL teleport_revoke_roles(username); + END IF; +END diff --git a/lib/srv/db/mysql/mysql_revoke_roles.sql b/lib/srv/db/mysql/mysql_revoke_roles.sql new file mode 100644 index 0000000000000..276bdb08bc927 --- /dev/null +++ b/lib/srv/db/mysql/mysql_revoke_roles.sql @@ -0,0 +1,22 @@ +CREATE PROCEDURE teleport_revoke_roles(IN username VARCHAR(32)) +BEGIN + DECLARE role VARCHAR(32) DEFAULT ''; + DECLARE done INT DEFAULT 0; + DECLARE role_cursor CURSOR FOR select FROM_USER from mysql.role_edges where FROM_USER != 'teleport-auto-user' AND TO_USER = username; + DECLARE CONTINUE HANDLER FOR NOT FOUND SET done = 1; + OPEN role_cursor; + + revoke_roles: LOOP + FETCH role_cursor INTO role; + IF done = 1 THEN + LEAVE revoke_roles; + END IF; + + SET @sql := CONCAT_WS(' ', 'REVOKE', QUOTE(role), 'FROM', QUOTE(username)); + PREPARE stmt FROM @sql; + EXECUTE stmt; + DEALLOCATE PREPARE stmt; + END LOOP revoke_roles; + + CLOSE role_cursor; +END diff --git a/lib/srv/db/mysql/test.go b/lib/srv/db/mysql/test.go index d939154d52eb8..bb2a307923b42 100644 --- a/lib/srv/db/mysql/test.go +++ b/lib/srv/db/mysql/test.go @@ -18,7 +18,9 @@ package mysql import ( "crypto/tls" + "encoding/json" "net" + "strings" "sync" "sync/atomic" @@ -68,6 +70,18 @@ func MakeTestClientWithoutTLS(addr string, routeToDatabase tlsca.RouteToDatabase return conn, nil } +// UserEvent represents a user activation/deactivation event. +type UserEvent struct { + // TeleportUser is the Teleport username. + TeleportUser string + // DatabaseUser is the in-database username. + DatabaseUser string + // Roles are the user Roles. + Roles []string + // Active is whether user activated or deactivated. + Active bool +} + // TestServer is a test MySQL server used in functional database // access tests. type TestServer struct { @@ -127,7 +141,11 @@ func NewTestServer(config common.TestServerConfig, opts ...TestServerOption) (sv listener: listener, port: port, log: log, - handler: &testHandler{log: log}, + handler: &testHandler{ + log: log, + userEventsCh: make(chan UserEvent, 100), + usersMapping: make(map[string]string), + }, } if !config.ListenTLS { @@ -251,16 +269,27 @@ func (s *TestServer) ConnsClosed() bool { return true } +// UserEventsCh returns channel that receives user activate/deactivate events. +func (s *TestServer) UserEventsCh() <-chan UserEvent { + return s.handler.userEventsCh +} + type testHandler struct { server.EmptyHandler log logrus.FieldLogger // queryCount keeps track of the number of queries the server has received. queryCount uint32 + + userEventsCh chan UserEvent + // usersMapping maps in-database username to Teleport username. + usersMapping map[string]string + usersMappingMu sync.Mutex } func (h *testHandler) HandleQuery(query string) (*mysql.Result, error) { h.log.Debugf("Received query %q.", query) atomic.AddUint32(&h.queryCount, 1) + // When getting a "show tables" query, construct the response in a way // which previously caused server packets parsing logic to fail. if query == "show tables" { @@ -279,6 +308,80 @@ func (h *testHandler) HandleQuery(query string) (*mysql.Result, error) { Resultset: resultSet, }, nil } + + return TestQueryResponse, nil +} + +func (h *testHandler) HandleStmtPrepare(prepare string) (int, int, interface{}, error) { + params := strings.Count(prepare, "?") + return params, 0, nil, nil +} +func (h *testHandler) HandleStmtExecute(_ interface{}, query string, args []interface{}) (*mysql.Result, error) { + h.log.Debugf("Received execute %q with args %+v.", args) + if strings.HasPrefix(query, "CALL ") { + return h.handleCallProcedure(query, args) + } + return TestQueryResponse, nil +} + +func (h *testHandler) handleCallProcedure(query string, args []interface{}) (*mysql.Result, error) { + query = strings.TrimSpace(strings.TrimPrefix(query, "CALL")) + openBracketIndex := strings.IndexByte(query, '(') + endBracketIndex := strings.LastIndexByte(query, ')') + if openBracketIndex < 0 || endBracketIndex < 0 { + return nil, trace.BadParameter("invalid query: %v", query) + } + + procedureName := query[:openBracketIndex] + switch procedureName { + case activateUserProcedureName: + if len(args) != 2 { + return nil, trace.BadParameter("invalid number of parameters: %v", args) + } + databaseUserBytes, ok := args[0].([]byte) + if !ok { + return nil, trace.BadParameter("invalid database user: %v", args[0]) + } + detailsBytes, ok := args[1].([]byte) + if !ok { + return nil, trace.BadParameter("invalid details: %v", args[1]) + } + details := activateUserDetails{} + err := json.Unmarshal(detailsBytes, &details) + if err != nil { + return nil, trace.BadParameter("invalid JSON: %v", err) + } + + // Update mapping and send event. + databaseUser := string(databaseUserBytes) + h.usersMappingMu.Lock() + defer h.usersMappingMu.Unlock() + h.usersMapping[databaseUser] = details.Attributes.User + h.userEventsCh <- UserEvent{ + DatabaseUser: databaseUser, + TeleportUser: h.usersMapping[databaseUser], + Roles: details.Roles, + Active: true, + } + + case deactivateUserProcedureName: + if len(args) != 1 { + return nil, trace.BadParameter("invalid number of parameters: %v", args) + } + databaseUserBytes, ok := args[0].([]byte) + if !ok { + return nil, trace.BadParameter("invalid database user: %v", args[0]) + } + + // Send event. + h.usersMappingMu.Lock() + defer h.usersMappingMu.Unlock() + h.userEventsCh <- UserEvent{ + DatabaseUser: string(databaseUserBytes), + TeleportUser: h.usersMapping[string(databaseUserBytes)], + Active: false, + } + } return TestQueryResponse, nil } diff --git a/lib/srv/db/proxy_test.go b/lib/srv/db/proxy_test.go index ea2215dbed0c3..74553a7b8053b 100644 --- a/lib/srv/db/proxy_test.go +++ b/lib/srv/db/proxy_test.go @@ -487,7 +487,7 @@ func setConfigClientIdleTimoutAndDisconnectExpiredCert(ctx context.Context, t *t func TestExtractMySQLVersion(t *testing.T) { t.Parallel() ctx := context.Background() - testCtx := setupTestContext(ctx, t, withSelfHostedMySQL("mysql", mysql.WithServerVersion("8.0.25"))) + testCtx := setupTestContext(ctx, t, withSelfHostedMySQL("mysql", withMySQLServerVersion("8.0.25"))) go testCtx.startHandlingConnections() testCtx.createUserAndRole(ctx, t, "alice", "admin", []string{"root"}, []string{types.Wildcard}) diff --git a/lib/srv/db/proxyserver_test.go b/lib/srv/db/proxyserver_test.go index 769f8c973b792..5a23a7f8525a4 100644 --- a/lib/srv/db/proxyserver_test.go +++ b/lib/srv/db/proxyserver_test.go @@ -25,7 +25,6 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/limiter" - "github.com/gravitational/teleport/lib/srv/db/mysql" ) func TestProxyConnectionLimiting(t *testing.T) { @@ -247,7 +246,7 @@ func TestProxyRateLimiting(t *testing.T) { func TestProxyMySQLVersion(t *testing.T) { ctx := context.Background() testCtx := setupTestContext(ctx, t, - withSelfHostedMySQL("mysql", mysql.WithServerVersion("8.0.12")), + withSelfHostedMySQL("mysql", withMySQLServerVersion("8.0.12")), ) go testCtx.startHandlingConnections()