Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v13] fix database dynamic labels #29373

Merged
merged 3 commits into from Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 8 additions & 7 deletions lib/srv/app/server.go
Expand Up @@ -473,7 +473,7 @@ func (s *Server) getServerInfo(app types.Application) (types.Resource, error) {
// Make sure to return a new object, because it gets cached by
// heartbeat and will always compare as equal otherwise.
s.mu.RLock()
copy := s.appWithUpdatedLabels(app)
copy := s.appWithUpdatedLabelsLocked(app)
s.mu.RUnlock()
expires := s.c.Clock.Now().UTC().Add(apidefaults.ServerAnnounceTTL)
server, err := types.NewAppServerV3(types.Metadata{
Expand Down Expand Up @@ -1005,18 +1005,19 @@ func (s *Server) getSession(ctx context.Context, identity *tlsca.Identity, app t
func (s *Server) getApp(ctx context.Context, publicAddr string) (types.Application, error) {
s.mu.RLock()
defer s.mu.RUnlock()

for _, a := range s.getApps() {
// don't call s.getApps() as this will call RLock and potentially deadlock.
for _, a := range s.apps {
if publicAddr == a.GetPublicAddr() {
return s.appWithUpdatedLabels(a), nil
return s.appWithUpdatedLabelsLocked(a), nil
}
}
return nil, trace.NotFound("no application at %v found", publicAddr)
}

// appWithUpdatedLabels will inject updated dynamic and cloud labels into an application
// object. The caller must invoke an RLock on `s.mu` before calling this function.
func (s *Server) appWithUpdatedLabels(app types.Application) *types.AppV3 {
// appWithUpdatedLabelsLocked will inject updated dynamic and cloud labels into
// an application object.
// The caller must invoke an RLock on `s.mu` before calling this function.
func (s *Server) appWithUpdatedLabelsLocked(app types.Application) *types.AppV3 {
// Create a copy of the application to modify
copy := app.Copy()

Expand Down
4 changes: 3 additions & 1 deletion lib/srv/app/server_test.go
Expand Up @@ -586,7 +586,9 @@ func TestAppWithUpdatedLabels(t *testing.T) {
require.NoError(t, test.cloudLabels.Sync(context.Background()))
}

updatedApp := s.appServer.appWithUpdatedLabels(test.app)
s.appServer.mu.RLock()
updatedApp := s.appServer.appWithUpdatedLabelsLocked(test.app)
s.appServer.mu.RUnlock()

for key, value := range test.expectedDynamicLabels {
require.Equal(t, value, updatedApp.GetDynamicLabels()[key].GetResult())
Expand Down
87 changes: 76 additions & 11 deletions lib/srv/db/access_test.go
Expand Up @@ -98,18 +98,23 @@ func TestMain(m *testing.M) {
// on the configured RBAC rules.
func TestAccessPostgres(t *testing.T) {
ctx := context.Background()
testCtx := setupTestContext(ctx, t, withSelfHostedPostgres("postgres"))
testCtx := setupTestContext(ctx, t, withSelfHostedPostgres("postgres", func(db *types.DatabaseV3) {
db.SetStaticLabels(map[string]string{"foo": "bar"})
}))
go testCtx.startHandlingConnections()

dynamicDBLabels := types.Labels{"echo": {"test"}}
staticDBLabels := types.Labels{"foo": {"bar"}}
tests := []struct {
desc string
user string
role string
allowDbNames []string
allowDbUsers []string
dbName string
dbUser string
err string
desc string
user string
role string
allowDbNames []string
allowDbUsers []string
extraRoleOpts []roleOptFn
dbName string
dbUser string
err string
}{
{
desc: "has access to all database names and users",
Expand Down Expand Up @@ -169,12 +174,55 @@ func TestAccessPostgres(t *testing.T) {
dbUser: "postgres",
err: "access to db denied",
},
{
desc: "access allowed to specific user/database by static label",
user: "alice",
role: "admin",
allowDbNames: []string{"metrics"},
allowDbUsers: []string{"alice"},
// The default test role created has wildcard labels allowed.
// This tests that specific allowed database labels matching the
// test database's static labels allows access.
extraRoleOpts: []roleOptFn{withAllowedDBLabels(staticDBLabels)},
dbName: "metrics",
dbUser: "alice",
},
{
desc: "access allowed to specific user/database by dynamic label",
user: "alice",
role: "admin",
allowDbNames: []string{"metrics"},
allowDbUsers: []string{"alice"},
// The default test role created has wildcard labels allowed.
// This tests that specific allowed database labels matching the
// test database's dynamic labels allows access, to ensure
// that RBAC checks against dynamic labels are working.
extraRoleOpts: []roleOptFn{withAllowedDBLabels(dynamicDBLabels)},
dbName: "metrics",
dbUser: "alice",
},
{
desc: "access denied by dynamic label",
user: "alice",
role: "admin",
allowDbNames: []string{"metrics"},
allowDbUsers: []string{"alice"},
// The default test role created has wildcard labels allowed.
// This tests that specific denied database labels matching the
// test database's dynamic labels denies access, to ensure
// that RBAC checks against dynamic labels are working.
extraRoleOpts: []roleOptFn{withDeniedDBLabels(dynamicDBLabels)},
dbName: "metrics",
dbUser: "alice",
err: "access to db denied",
},
}

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
// Create user/role with the requested permissions.
testCtx.createUserAndRole(ctx, t, test.user, test.role, test.allowDbUsers, test.allowDbNames)
testCtx.createUserAndRole(ctx, t, test.user, test.role,
test.allowDbUsers, test.allowDbNames, test.extraRoleOpts...)

// Try to connect to the database as this user.
pgConn, err := testCtx.postgresClient(ctx, test.user, "postgres", test.dbUser, test.dbName)
Expand Down Expand Up @@ -1837,13 +1885,30 @@ func (c *testContext) dynamodbClient(ctx context.Context, teleportUser, dbServic
return db, proxy, nil
}

type roleOptFn func(types.Role)

func withAllowedDBLabels(labels types.Labels) roleOptFn {
return func(role types.Role) {
role.SetDatabaseLabels(types.Allow, labels)
}
}

func withDeniedDBLabels(labels types.Labels) roleOptFn {
return func(role types.Role) {
role.SetDatabaseLabels(types.Deny, labels)
}
}

// createUserAndRole creates Teleport user and role with specified names
// and allowed database users/names properties.
func (c *testContext) createUserAndRole(ctx context.Context, t *testing.T, userName, roleName string, dbUsers, dbNames []string) (types.User, types.Role) {
func (c *testContext) createUserAndRole(ctx context.Context, t *testing.T, userName, roleName string, dbUsers, dbNames []string, roleOpts ...roleOptFn) (types.User, types.Role) {
user, role, err := auth.CreateUserAndRole(c.tlsServer.Auth(), userName, []string{roleName}, nil)
require.NoError(t, err)
role.SetDatabaseUsers(types.Allow, dbUsers)
role.SetDatabaseNames(types.Allow, dbNames)
for _, roleOpt := range roleOpts {
roleOpt(role)
}
err = c.tlsServer.Auth().UpsertRole(ctx, role)
require.NoError(t, err)
return user, role
Expand Down
49 changes: 32 additions & 17 deletions lib/srv/db/ca.go
Expand Up @@ -59,13 +59,42 @@ func (s *Server) startCARenewer(ctx context.Context) {
// initCACert initializes the provided server's CA certificate in case of a
// cloud hosted database instance.
func (s *Server) initCACert(ctx context.Context, database types.Database) error {
s.mu.RLock()
if !s.shouldInitCACertLocked(database) {
s.mu.RUnlock()
return nil
}
// make a copy so we can safely unlock before doing expensive CA cert
// version checking or downloading.
copy := database.Copy()
s.mu.RUnlock()
bytes, err := s.getCACerts(ctx, copy)
if err != nil {
return trace.Wrap(err)
}
// Make sure the cert we got is valid just in case.
if _, err := tlsca.ParseCertificatePEM(bytes); err != nil {
return trace.Wrap(err, "CA certificate for %v doesn't appear to be a valid x509 certificate: %s",
copy, bytes)
}
s.mu.Lock()
// update the original database under a lock, since we're mutating it.
database.SetStatusCA(string(bytes))
s.mu.Unlock()
return nil
}

// shouldInitCACertLocked returns whether a given database needs to have its
// CA cert initialized.
// The caller must call RLock on `s.mu` before calling this function.
func (s *Server) shouldInitCACertLocked(database types.Database) bool {
// To identify if the CA cert was set automatically, compare the result of
// `GetCA` (which can return user-provided CA) with `GetStatusCA`, which
// only returns the CA set by the Teleport. If both contents differ, we will
// not download CAs for the database. Both sides will be empty at the first
// pass, downloading and populating the `StatusCA`.
if database.GetCA() != database.GetStatusCA() {
return nil
return false
}
// Can only download it for cloud-hosted instances.
switch database.GetType() {
Expand All @@ -78,24 +107,10 @@ func (s *Server) initCACert(ctx context.Context, database types.Database) error
types.DatabaseTypeDynamoDB,
types.DatabaseTypeCloudSQL,
types.DatabaseTypeAzure:

return true
default:
return nil
}
// It's not set so download it or see if it's already downloaded.
// When initializing the CAs do not update the certificates, instead use the
// cached ones.
bytes, err := s.getCACerts(ctx, database)
if err != nil {
return trace.Wrap(err)
return false
}
// Make sure the cert we got is valid just in case.
if _, err := tlsca.ParseCertificatePEM(bytes); err != nil {
return trace.Wrap(err, "CA certificate for %v doesn't appear to be a valid x509 certificate: %s",
database, bytes)
}
database.SetStatusCA(string(bytes))
return nil
}

// getCACerts updates and returns automatically downloaded root certificate for
Expand Down
60 changes: 40 additions & 20 deletions lib/srv/db/server.go
Expand Up @@ -604,6 +604,42 @@ func (s *Server) getProxiedDatabases() (databases types.Databases) {
return databases
}

// getProxiedDatabase returns a proxied database by name with updated dynamic
// and cloud labels.
func (s *Server) getProxiedDatabase(name string) (types.Database, error) {
s.mu.RLock()
defer s.mu.RUnlock()
// don't call s.getProxiedDatabases() as this will call RLock and
// potentially deadlock.
for _, db := range s.proxiedDatabases {
if db.GetName() == name {
return s.copyDatabaseWithUpdatedLabelsLocked(db), nil
}
}
return nil, trace.NotFound("%q not found among registered databases: %v",
name, s.proxiedDatabases)
}

// copyDatabaseWithUpdatedLabelsLocked will inject updated dynamic and cloud labels into
// a database object.
// The caller must invoke an RLock on `s.mu` before calling this function.
func (s *Server) copyDatabaseWithUpdatedLabelsLocked(database types.Database) *types.DatabaseV3 {
// create a copy of the database to modify.
copy := database.Copy()

// Update dynamic labels if the database has them.
labels, ok := s.dynamicLabels[copy.GetName()]
if ok && labels != nil {
copy.SetDynamicLabels(labels.Get())
}

// Add in the cloud labels if the db has them.
if s.cfg.CloudLabels != nil {
s.cfg.CloudLabels.Apply(copy)
}
return copy
}

// startHeartbeat starts the registration heartbeat to the auth server.
func (s *Server) startHeartbeat(ctx context.Context, database types.Database) error {
heartbeat, err := srv.NewHeartbeat(srv.HeartbeatConfig{
Expand Down Expand Up @@ -659,16 +695,8 @@ func (s *Server) getServerInfo(database types.Database) (types.Resource, error)
// Make sure to return a new object, because it gets cached by
// heartbeat and will always compare as equal otherwise.
s.mu.RLock()
copy := database.Copy()
copy := s.copyDatabaseWithUpdatedLabelsLocked(database)
s.mu.RUnlock()
// Update dynamic labels if the database has them.
labels := s.getDynamicLabels(copy.GetName())
if labels != nil {
copy.SetDynamicLabels(labels.Get())
}
if s.cfg.CloudLabels != nil {
s.cfg.CloudLabels.Apply(copy)
}
if s.cfg.CloudIAM != nil {
s.cfg.CloudIAM.UpdateIAMStatus(copy)
}
Expand Down Expand Up @@ -1079,17 +1107,9 @@ func (s *Server) authorize(ctx context.Context) (*common.Session, error) {
s.log.Debugf("Client identity: %#v.", identity)

// Fetch the requested database server.
var database types.Database
registeredDatabases := s.getProxiedDatabases()
for _, db := range registeredDatabases {
if db.GetName() == identity.RouteToDatabase.ServiceName {
database = db
break
}
}
if database == nil {
return nil, trace.NotFound("%q not found among registered databases: %v",
identity.RouteToDatabase.ServiceName, registeredDatabases)
database, err := s.getProxiedDatabase(identity.RouteToDatabase.ServiceName)
if err != nil {
return nil, trace.Wrap(err)
}

autoCreate, databaseRoles, err := authContext.Checker.CheckDatabaseRoles(database)
Expand Down