Improve HA behavior of database agents in leaf clusters #10641

Merged: 3 commits merged on Mar 2, 2022
116 changes: 116 additions & 0 deletions integration/db_integration_test.go
@@ -31,6 +31,7 @@ import (
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/service"
"github.com/gravitational/teleport/lib/srv/db"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/mongodb"
"github.com/gravitational/teleport/lib/srv/db/mysql"
@@ -387,6 +388,121 @@ func TestDatabaseAccessPostgresSeparateListener(t *testing.T) {
require.NoError(t, err)
}

func init() {
// Override the database agents' shuffle behavior to ensure they're always
// tried in the same order during tests. Used for HA tests.
db.SetShuffleFunc(db.ShuffleSort)
}

// TestDatabaseAccessHARootCluster verifies that, in a root cluster, the proxy
// falls back to a healthy database agent when multiple agents serve the same
// database and one of them is down.
func TestDatabaseAccessHARootCluster(t *testing.T) {
pack := setupDatabaseTest(t)

// Insert a database server entry not backed by an actual running agent
// to simulate a scenario when an agent is down but the resource hasn't
// expired from the backend yet.
dbServer, err := types.NewDatabaseServerV3(types.Metadata{
Name: pack.root.postgresService.Name,
}, types.DatabaseServerSpecV3{
Protocol: defaults.ProtocolPostgres,
URI: pack.root.postgresAddr,
// To make sure the unhealthy server is always picked first, give it a
// host ID that always compares as "smaller", since the tests sort
// agents instead of shuffling them.
HostID: "0000",
Hostname: "test",
})
require.NoError(t, err)

_, err = pack.root.cluster.Process.GetAuthServer().UpsertDatabaseServer(
context.Background(), dbServer)
require.NoError(t, err)

// Connect to the database service in root cluster.
client, err := postgres.MakeTestClient(context.Background(), common.TestClientConfig{
AuthClient: pack.root.cluster.GetSiteAPI(pack.root.cluster.Secrets.SiteName),
AuthServer: pack.root.cluster.Process.GetAuthServer(),
Address: net.JoinHostPort(Loopback, pack.root.cluster.GetPortWeb()),
Cluster: pack.root.cluster.Secrets.SiteName,
Username: pack.root.user.GetName(),
RouteToDatabase: tlsca.RouteToDatabase{
ServiceName: pack.root.postgresService.Name,
Protocol: pack.root.postgresService.Protocol,
Username: "postgres",
Database: "test",
},
})
require.NoError(t, err)

// Execute a query.
result, err := client.Exec(context.Background(), "select 1").ReadAll()
require.NoError(t, err)
require.Equal(t, []*pgconn.Result{postgres.TestQueryResponse}, result)
require.Equal(t, uint32(1), pack.root.postgres.QueryCount())
require.Equal(t, uint32(0), pack.leaf.postgres.QueryCount())

// Disconnect.
err = client.Close(context.Background())
require.NoError(t, err)
}

// TestDatabaseAccessHALeafCluster verifies that, in a leaf cluster, the proxy
// falls back to a healthy database agent when multiple agents serve the same
// database and one of them is down.
func TestDatabaseAccessHALeafCluster(t *testing.T) {
pack := setupDatabaseTest(t)
pack.waitForLeaf(t)

// Insert a database server entry not backed by an actual running agent
// to simulate a scenario when an agent is down but the resource hasn't
// expired from the backend yet.
dbServer, err := types.NewDatabaseServerV3(types.Metadata{
Name: pack.leaf.postgresService.Name,
}, types.DatabaseServerSpecV3{
Protocol: defaults.ProtocolPostgres,
URI: pack.leaf.postgresAddr,
// To make sure the unhealthy server is always picked first, give it a
// host ID that always compares as "smaller", since the tests sort
// agents instead of shuffling them.
HostID: "0000",
Hostname: "test",
})
require.NoError(t, err)

_, err = pack.leaf.cluster.Process.GetAuthServer().UpsertDatabaseServer(
context.Background(), dbServer)
require.NoError(t, err)

// Connect to the database service in leaf cluster via root cluster.
client, err := postgres.MakeTestClient(context.Background(), common.TestClientConfig{
AuthClient: pack.root.cluster.GetSiteAPI(pack.root.cluster.Secrets.SiteName),
AuthServer: pack.root.cluster.Process.GetAuthServer(),
Address: net.JoinHostPort(Loopback, pack.root.cluster.GetPortWeb()), // Connecting via root cluster.
Cluster: pack.leaf.cluster.Secrets.SiteName,
Username: pack.root.user.GetName(),
RouteToDatabase: tlsca.RouteToDatabase{
ServiceName: pack.leaf.postgresService.Name,
Protocol: pack.leaf.postgresService.Protocol,
Username: "postgres",
Database: "test",
},
})
require.NoError(t, err)

// Execute a query.
result, err := client.Exec(context.Background(), "select 1").ReadAll()
require.NoError(t, err)
require.Equal(t, []*pgconn.Result{postgres.TestQueryResponse}, result)
require.Equal(t, uint32(1), pack.leaf.postgres.QueryCount())
require.Equal(t, uint32(0), pack.root.postgres.QueryCount())

// Disconnect.
err = client.Close(context.Background())
require.NoError(t, err)
}

// TestDatabaseAccessMongoSeparateListener tests mongo proxy listener running on separate port.
func TestDatabaseAccessMongoSeparateListener(t *testing.T) {
pack := setupDatabaseTest(t,
15 changes: 15 additions & 0 deletions lib/reversetunnel/api.go
@@ -123,3 +123,18 @@ type Server interface {
// Wait waits for server to close all outstanding operations
Wait()
}

const (
// NoApplicationTunnel is the error message returned when application
// reverse tunnel cannot be found.
//
// It usually happens when an app agent has shut down (or crashed) but
// hasn't expired from the backend yet.
NoApplicationTunnel = "could not find reverse tunnel, check that Application Service agent proxying this application is up and running"
// NoDatabaseTunnel is the error message returned when database reverse
// tunnel cannot be found.
//
// It usually happens when a database agent has shut down (or crashed) but
// hasn't expired from the backend yet.
NoDatabaseTunnel = "could not find reverse tunnel, check that Database Service agent proxying this database is up and running"
)
10 changes: 7 additions & 3 deletions lib/reversetunnel/transport.go
@@ -375,9 +375,13 @@ func (p *transport) getConn(servers []string, r *sshutils.DialReq) (net.Conn, bo
return nil, false, trace.Wrap(err)
}

// Connections to applications should never occur over a direct dial, return right away.
if r.ConnType == types.AppTunnel {
return nil, false, trace.ConnectionProblem(err, "failed to connect to application")
// Connections to applications and databases should never occur over
// a direct dial, return right away.
switch r.ConnType {
case types.AppTunnel:
return nil, false, trace.ConnectionProblem(err, NoApplicationTunnel)
case types.DatabaseTunnel:
return nil, false, trace.ConnectionProblem(err, NoDatabaseTunnel)
Comment on lines +378 to +384

Contributor: nit: can we somehow unify this switch so we have the same common logic for remoteSite and localSite? https://github.com/gravitational/teleport/blob/roman/leafdb/lib/reversetunnel/localsite.go#L327

Collaborator (author): We could, but to be honest it doesn't feel like moving it out is really worth it (and it would probably complicate the implementation a bit too).
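
A rough sketch of the shared helper the reviewer has in mind (hypothetical; not part of this PR, and it assumes the reversetunnel package's existing trace and types imports). It could live in package reversetunnel and be reused by both localSite and transport:

// tunnelNotFoundError maps a tunnel type to the corresponding "agent is down"
// connection problem error. The name and placement are assumptions for
// illustration only.
func tunnelNotFoundError(err error, connType types.TunnelType) error {
	switch connType {
	case types.AppTunnel:
		return trace.ConnectionProblem(err, NoApplicationTunnel)
	case types.DatabaseTunnel:
		return trace.ConnectionProblem(err, NoDatabaseTunnel)
	default:
		return trace.ConnectionProblem(err, "failed to connect to agent over reverse tunnel")
	}
}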

}

errTun := err
12 changes: 6 additions & 6 deletions lib/srv/db/access_test.go
@@ -23,7 +23,6 @@ import (
"fmt"
"net"
"os"
"sort"
"sync"
"testing"
"time"
@@ -1518,6 +1517,12 @@ func (c *testContext) Close() error {
return trace.NewAggregate(errors...)
}

func init() {
// Override the database agents' shuffle behavior to ensure they're always
// tried in the same order during tests. Used for HA tests.
SetShuffleFunc(ShuffleSort)
}

func setupTestContext(ctx context.Context, t *testing.T, withDatabases ...withDatabaseOption) *testContext {
testCtx := &testContext{
clusterName: "root.example.com",
@@ -1630,11 +1635,6 @@ func setupTestContext(ctx context.Context, t *testing.T, withDatabases ...withDa
Emitter: testCtx.emitter,
Clock: testCtx.clock,
ServerID: "proxy-server",
Shuffle: func(servers []types.DatabaseServer) []types.DatabaseServer {
// To ensure predictability in tests, sort servers instead of shuffling.
sort.Sort(types.DatabaseServers(servers))
return servers
},
LockWatcher: proxyLockWatcher,
})
require.NoError(t, err)
70 changes: 55 additions & 15 deletions lib/srv/db/proxyserver.go
@@ -24,6 +24,9 @@ import (
"io"
"math/rand"
"net"
"sort"
"strings"
"sync"
"time"

"github.com/gravitational/teleport"
@@ -85,12 +88,51 @@ type ProxyServerConfig struct {
Clock clockwork.Clock
// ServerID is the ID of the audit log server.
ServerID string
// Shuffle allows to override shuffle logic in tests.
Shuffle func([]types.DatabaseServer) []types.DatabaseServer
// LockWatcher is a lock watcher.
LockWatcher *services.LockWatcher
}

// ShuffleFunc defines a function that shuffles a list of database servers.
type ShuffleFunc func([]types.DatabaseServer) []types.DatabaseServer

// ShuffleRandom is a ShuffleFunc that randomizes the order of database servers.
// Used to provide load balancing behavior when proxying to multiple agents.
func ShuffleRandom(servers []types.DatabaseServer) []types.DatabaseServer {
rand.New(rand.NewSource(time.Now().UnixNano())).Shuffle(
len(servers), func(i, j int) {
servers[i], servers[j] = servers[j], servers[i]
})
return servers
}

// ShuffleSort is a ShuffleFunc that sorts database servers by name and host ID.
// Used to provide predictable behavior in tests.
func ShuffleSort(servers []types.DatabaseServer) []types.DatabaseServer {
sort.Sort(types.DatabaseServers(servers))
return servers
}

var (
// mu protects the shuffleFunc global access.
mu sync.RWMutex
// shuffleFunc provides shuffle behavior for multiple database agents.
shuffleFunc ShuffleFunc = ShuffleRandom
Contributor: I'm not a big fan of global variables, but I assume this logic is hard to test without one. Though I wonder if the database server order could be enforced in integration tests by setting a specific point in time on the clock used by the proxy database service:

clock := NewFakeClockAt(time.Date(2021, time.April, 4, 0, 0, 0, 0, time.UTC))
pack := setupDatabaseTest(t, withClock(clock))

https://github.com/gravitational/teleport/pull/10641/files#diff-3ed9752f98ccf36539f32e345ea5cb26a13d71552d6b9a9be5117308d26b00cbL119

Collaborator (author): Yeah, I thought about the same thing but didn't know how to ensure a stable shuffle order even with frozen time. I think in that case the shuffle will always produce the same order, but there's no way of telling which order. So while I'm not a big fan of this either, it felt like the most reliable way to guarantee the order.

Contributor: Technically we could always shuffle and in tests add an extra pass to sort them.

Collaborator (author): If you mean adding some sort of "test mode", I was trying to avoid that. I generally try not to put "if isTestMode()" switches in the code.

Contributor: I was about to write NVM, but you were faster. Initially I thought about passing an extra sort function in tests, but that introduces the same issue we already have. I agree on the "test mode" too.

)
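
A small standalone illustration of the author's point in the thread above (an assumed example, not code from the PR): seeding the shuffle with a frozen clock makes the permutation repeatable across runs, but the order is still arbitrary, so a test cannot know in advance which agent will be dialed first.

package main

import (
	"fmt"
	"math/rand"
	"time"
)

func main() {
	// Frozen test clock, as suggested in the review thread.
	frozen := time.Date(2021, time.April, 4, 0, 0, 0, 0, time.UTC)
	servers := []string{"healthy-agent", "down-agent"}
	// The same seed produces the same permutation on every run, but whether
	// "down-agent" ends up first depends on the seed value, not on anything
	// the test controls.
	rand.New(rand.NewSource(frozen.UnixNano())).Shuffle(len(servers), func(i, j int) {
		servers[i], servers[j] = servers[j], servers[i]
	})
	fmt.Println(servers)
}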

// SetShuffleFunc sets the shuffle behavior when proxying to multiple agents.
func SetShuffleFunc(fn ShuffleFunc) {
mu.Lock()
defer mu.Unlock()
shuffleFunc = fn
}

// getShuffleFunc returns the configured function used to shuffle agents.
func getShuffleFunc() ShuffleFunc {
mu.RLock()
defer mu.RUnlock()
return shuffleFunc
}

// CheckAndSetDefaults validates the config and sets default values.
func (c *ProxyServerConfig) CheckAndSetDefaults() error {
if c.AccessPoint == nil {
@@ -114,15 +156,6 @@ func (c *ProxyServerConfig) CheckAndSetDefaults() error {
if c.ServerID == "" {
return trace.BadParameter("missing ServerID")
}
if c.Shuffle == nil {
c.Shuffle = func(servers []types.DatabaseServer) []types.DatabaseServer {
rand.New(rand.NewSource(c.Clock.Now().UnixNano())).Shuffle(
len(servers), func(i, j int) {
servers[i], servers[j] = servers[j], servers[i]
})
return servers
}
}
if c.LockWatcher == nil {
return trace.BadParameter("missing LockWatcher")
}
@@ -351,7 +384,7 @@ func (s *ProxyServer) Connect(ctx context.Context, proxyCtx *common.ProxyContext
// There may be multiple database servers proxying the same database. If
// we get a connection problem error trying to dial one of them, likely
// the database server is down so try the next one.
for _, server := range s.cfg.Shuffle(proxyCtx.Servers) {
for _, server := range getShuffleFunc()(proxyCtx.Servers) {
s.log.Debugf("Dialing to %v.", server)
tlsConfig, err := s.getConfigForServer(ctx, proxyCtx.Identity, server)
if err != nil {
@@ -364,9 +397,9 @@ func (s *ProxyServer) Connect(ctx context.Context, proxyCtx *common.ProxyContext
ConnType: types.DatabaseTunnel,
})
if err != nil {
// Connection problem indicates reverse tunnel to this server is down.
if trace.IsConnectionProblem(err) {
s.log.WithError(err).Warnf("Failed to dial %v.", server)
// If an agent is down, we'll retry on the next one (if available).
if isReverseTunnelDownError(err) {
Contributor: Why do we need to check for the error type anyway? Shouldn't we just try to connect to all servers in case of any error?

Collaborator (author): Yeah, I thought about that originally and decided to only retry reverse tunnel errors specifically, to be on the safe side and not retry errors that shouldn't be retried (target database connection errors, RBAC, etc.). We do it the same way for kube and apps.

s.log.WithError(err).Warnf("Failed to dial database %v.", server)
continue
}
return nil, trace.Wrap(err)
@@ -380,6 +413,13 @@ func (s *ProxyServer) Connect(ctx context.Context, proxyCtx *common.ProxyContext
return nil, trace.BadParameter("failed to connect to any of the database servers")
}

// isReverseTunnelDownError returns true if the provided error indicates that
// the reverse tunnel connection is down e.g. because the agent is down.
func isReverseTunnelDownError(err error) bool {
return trace.IsConnectionProblem(err) ||
strings.Contains(err.Error(), reversetunnel.NoDatabaseTunnel)
Contributor: After adding the missing case types.DatabaseTunnel to the switch on r.ConnType and returning trace.ConnectionProblem(err, ...), the check based on the NoDatabaseTunnel error message text seems to be unnecessary.

Collaborator (author): I don't think that's correct - previously, we would also return trace.ConnectionProblem there (if you look a few lines below, after directDial) and it wasn't detected in the db proxy. To confirm, I also tried removing the error message matching, and the integration test I wrote failed with the error this PR fixes.

Contributor: nit: Any reason not to create a var ErrNoDatabaseTunnel and use err == ErrNoDatabaseTunnel here instead of comparing strings? That seems cleaner and safer to me, since the current implementation panics when err == nil - not very likely to happen, but still.

Collaborator (author): As far as I looked at it, we just get a generic trace error from reversetunnel/transport in this case, so I'm not sure introducing another error type is going to help. Otherwise, the check for trace.ConnectionProblem, which we already had here, would have worked too.

}
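
For reference, a sketch of the sentinel-error alternative discussed in the thread above (hypothetical; not the approach this PR takes). It assumes reversetunnel would wrap the sentinel when the tunnel lookup fails and that the wrapped error survives to this point unchanged, which the author notes is not currently the case:

// Hypothetical sentinel in package reversetunnel:
var ErrNoDatabaseTunnel = errors.New(NoDatabaseTunnel)

// The proxy-side check could then use errors.Is instead of comparing
// message strings (which also avoids a panic when err is nil):
func isReverseTunnelDownError(err error) bool {
	return trace.IsConnectionProblem(err) ||
		errors.Is(err, reversetunnel.ErrNoDatabaseTunnel)
}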

// Proxy starts proxying all traffic received from database client between
// this proxy and Teleport database service over reverse tunnel.
//