Skip to content

Commit

Permalink
Share a single client as much as possible in tbot
Browse files Browse the repository at this point in the history
  • Loading branch information
strideynet committed Mar 4, 2024
1 parent f713fa4 commit 6621b16
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 74 deletions.
58 changes: 34 additions & 24 deletions lib/tbot/service_bot_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,23 +59,24 @@ type identityService struct {
resolver reversetunnelclient.Resolver

mu sync.Mutex
_ident *identity.Identity
client auth.ClientI
facade *identity.Facade
}

func (s *identityService) String() string {
return "identity"
}

func (s *identityService) setIdent(i *identity.Identity) {
func (s *identityService) GetIdentity() *identity.Identity {
s.mu.Lock()
defer s.mu.Unlock()
s._ident = i
return s.facade.Get()
}

func (s *identityService) ident() *identity.Identity {
func (s *identityService) GetClient() auth.ClientI {
s.mu.Lock()
defer s.mu.Unlock()
return s._ident
return s.client
}

func (s *identityService) String() string {
return "identity"
}

func hasTokenChanged(configTokenBytes, identityBytes []byte) bool {
Expand Down Expand Up @@ -130,7 +131,7 @@ func (s *identityService) loadIdentityFromStore(ctx context.Context, store bot.D
return loadedIdent, nil
}

// Initialize attempts to loaad an existing identity from the bot's storage.
// Initialize attempts to load an existing identity from the bot's storage.
// If an identity is found, it is checked against the configured onboarding
// settings. It is then renewed and persisted.
//
Expand Down Expand Up @@ -160,7 +161,8 @@ func (s *identityService) Initialize(ctx context.Context) error {
if err := checkIdentity(s.log, loadedIdent); err != nil {
return trace.Wrap(err)
}
authClient, err := clientForIdentity(ctx, s.log, s.cfg, loadedIdent, s.resolver)
facade := identity.NewFacade(s.cfg.FIPS, s.cfg.Insecure, loadedIdent)
authClient, err := clientForFacade(ctx, s.log, s.cfg, facade, s.resolver)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -190,22 +192,27 @@ func (s *identityService) Initialize(ctx context.Context) error {
return trace.Wrap(err)
}

testClient, err := clientForIdentity(ctx, s.log, s.cfg, newIdentity, s.resolver)
// Create the facaded client we can share with other components of tbot.
facade := identity.NewFacade(s.cfg.FIPS, s.cfg.Insecure, newIdentity)
c, err := clientForFacade(ctx, s.log, s.cfg, facade, s.resolver)
if err != nil {
return trace.Wrap(err)
}
defer testClient.Close()
s.mu.Lock()
s.client = c
s.facade = facade
s.mu.Unlock()

s.setIdent(newIdentity)
s.log.Info("Identity initialized successfully")
return nil
}

// Attempt a request to make sure our client works so we can exit early if
// we are in a bad state.
if _, err := testClient.Ping(ctx); err != nil {
return trace.Wrap(err, "unable to communicate with auth server")
func (s *identityService) Close() error {
c := s.GetClient()
if c == nil {
return nil
}
s.log.Info("Identity initialized successfully.")

return nil
return trace.Wrap(c.Close())
}

func (s *identityService) Run(ctx context.Context) error {
Expand Down Expand Up @@ -281,7 +288,7 @@ func (s *identityService) renew(
ctx, span := tracer.Start(ctx, "identityService/renew")
defer span.End()

currentIdentity := s.ident()
currentIdentity := s.facade.Get()
// Make sure we can still write to the bot's destination.
if err := identity.VerifyWrite(ctx, botDestination); err != nil {
return trace.Wrap(err, "Cannot write to destination %s, aborting.", botDestination)
Expand All @@ -292,7 +299,10 @@ func (s *identityService) renew(
if s.cfg.Onboarding.RenewableJoinMethod() {
// When using a renewable join method, we use GenerateUserCerts to
// request a new certificate using our current identity.
authClient, err := clientForIdentity(ctx, s.log, s.cfg, currentIdentity, s.resolver)
// We explicitly create a new client here to ensure that the latest
// identity is being used!
facade := identity.NewFacade(s.cfg.FIPS, s.cfg.Insecure, currentIdentity)
authClient, err := clientForFacade(ctx, s.log, s.cfg, facade, s.resolver)
if err != nil {
return trace.Wrap(err, "creating auth client")
}
Expand All @@ -313,7 +323,7 @@ func (s *identityService) renew(
}

s.log.WithField("identity", describeTLSIdentity(s.log, newIdentity)).Info("Fetched new bot identity.")
s.setIdent(newIdentity)
s.facade.Set(newIdentity)

if err := identity.SaveIdentity(ctx, newIdentity, botDestination, identity.BotKinds()...); err != nil {
return trace.Wrap(err, "saving new identity")
Expand Down
18 changes: 5 additions & 13 deletions lib/tbot/service_ca_rotation.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ import (

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/lib/reversetunnelclient"
"github.com/gravitational/teleport/lib/tbot/config"
"github.com/gravitational/teleport/lib/auth"
)

// debouncer accepts a duration, and a function. When `attempt` is called on
Expand Down Expand Up @@ -130,9 +129,8 @@ const caRotationRetryBackoff = time.Second * 2
type caRotationService struct {
log logrus.FieldLogger
reloadBroadcaster *channelBroadcaster
cfg *config.BotConfig
botIdentitySrc botIdentitySrc
resolver reversetunnelclient.Resolver
botClient auth.ClientI
getBotIdentity getBotIdentityFn
}

func (s *caRotationService) String() string {
Expand Down Expand Up @@ -189,15 +187,9 @@ func (s *caRotationService) Run(ctx context.Context) error {
func (s *caRotationService) watchCARotations(ctx context.Context, queueReload func()) error {
s.log.Debugf("Attempting to establish watch for CA events")

ident := s.botIdentitySrc.BotIdentity()
client, err := clientForIdentity(ctx, s.log, s.cfg, ident, s.resolver)
if err != nil {
return trace.Wrap(err, "creating client for ca watcher")
}
defer client.Close()

ident := s.getBotIdentity()
clusterName := ident.ClusterName
watcher, err := client.NewWatcher(ctx, types.Watch{
watcher, err := s.botClient.NewWatcher(ctx, types.Watch{
Kinds: []types.WatchKind{{
Kind: types.KindCertAuthority,
Filter: types.CertAuthorityFilter{
Expand Down
10 changes: 6 additions & 4 deletions lib/tbot/service_ca_rotation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ func TestBot_Run_CARotation(t *testing.T) {
// Allow time for bot to start running and watching for CA rotations
// TODO: We should modify the bot to emit events that may be useful...
time.Sleep(10 * time.Second)
facade := b.botIdentitySvc.facade

// fetch initial host cert
require.Len(t, b.BotIdentity().TLSCACertsBytes, 2)
Expand All @@ -331,24 +332,25 @@ func TestBot_Run_CARotation(t *testing.T) {
// TODO: These sleeps allow the client time to rotate. They could be
// replaced if tbot emitted a CA rotation/renewal event.
time.Sleep(time.Second * 30)
_, err = clientForIdentity(ctx, log, botConfig, b.BotIdentity(), resolver)

_, err = clientForFacade(ctx, log, botConfig, facade, resolver)
require.NoError(t, err)

rotate(ctx, t, log, teleportProcess(), types.RotationPhaseUpdateClients)
time.Sleep(time.Second * 30)
// Ensure both sets of CA certificates are now available locally
require.Len(t, b.BotIdentity().TLSCACertsBytes, 3)
_, err = clientForIdentity(ctx, log, botConfig, b.BotIdentity(), resolver)
_, err = clientForFacade(ctx, log, botConfig, facade, resolver)
require.NoError(t, err)

rotate(ctx, t, log, teleportProcess(), types.RotationPhaseUpdateServers)
time.Sleep(time.Second * 30)
_, err = clientForIdentity(ctx, log, botConfig, b.BotIdentity(), resolver)
_, err = clientForFacade(ctx, log, botConfig, facade, resolver)
require.NoError(t, err)

rotate(ctx, t, log, teleportProcess(), types.RotationStateStandby)
time.Sleep(time.Second * 30)
_, err = clientForIdentity(ctx, log, botConfig, b.BotIdentity(), resolver)
_, err = clientForFacade(ctx, log, botConfig, facade, resolver)
require.NoError(t, err)

require.Len(t, b.BotIdentity().TLSCACertsBytes, 2)
Expand Down
20 changes: 8 additions & 12 deletions lib/tbot/service_outputs.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ const renewalRetryLimit = 5
type outputsService struct {
log logrus.FieldLogger
reloadBroadcaster *channelBroadcaster
botIdentitySrc botIdentitySrc
botClient auth.ClientI
getBotIdentity getBotIdentityFn
cfg *config.BotConfig
resolver reversetunnelclient.Resolver
}
Expand All @@ -85,25 +86,19 @@ func (s *outputsService) renewOutputs(
ctx, span := tracer.Start(ctx, "outputsService/renewOutputs")
defer span.End()

botIdentity := s.botIdentitySrc.BotIdentity()
client, err := clientForIdentity(ctx, s.log, s.cfg, botIdentity, s.resolver)
if err != nil {
return trace.Wrap(err)
}
defer client.Close()

// create a cache shared across outputs so they don't hammer the auth
// server with similar requests
drc := &outputRenewalCache{
client: client,
client: s.botClient,
cfg: s.cfg,
}

// Determine the default role list based on the bot role. The role's
// name should match the certificate's Key ID (user and role names
// should all match bot-$name)
botIdentity := s.getBotIdentity()
botResourceName := botIdentity.X509Cert.Subject.CommonName
defaultRoles, err := fetchDefaultRoles(ctx, client, botResourceName)
defaultRoles, err := fetchDefaultRoles(ctx, s.botClient, botResourceName)
if err != nil {
s.log.WithError(err).Warnf("Unable to determine default roles, no roles will be requested if unspecified")
defaultRoles = []string{}
Expand Down Expand Up @@ -132,7 +127,7 @@ func (s *outputsService) renewOutputs(
}

impersonatedIdentity, impersonatedClient, err := s.generateImpersonatedIdentity(
ctx, client, botIdentity, output, defaultRoles,
ctx, s.botClient, botIdentity, output, defaultRoles,
)
if err != nil {
return trace.Wrap(err, "generating impersonated certs for output: %s", output)
Expand Down Expand Up @@ -551,7 +546,8 @@ func (s *outputsService) generateImpersonatedIdentity(

// create a client that uses the impersonated identity, so that when we
// fetch information, we can ensure access rights are enforced.
impersonatedClient, err = clientForIdentity(ctx, s.log, s.cfg, impersonatedIdentity, s.resolver)
facade := identity.NewFacade(s.cfg.FIPS, s.cfg.Insecure, impersonatedIdentity)
impersonatedClient, err = clientForFacade(ctx, s.log, s.cfg, facade, s.resolver)
if err != nil {
return nil, nil, trace.Wrap(err)
}
Expand Down
37 changes: 16 additions & 21 deletions lib/tbot/tbot.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,12 @@ func (b *Bot) markStarted() error {
return nil
}

type botIdentitySrc interface {
BotIdentity() *identity.Identity
}
type getBotIdentityFn func() *identity.Identity

// BotIdentity returns the bot's own identity. This will return nil if the bot
// has not been started.
func (b *Bot) BotIdentity() *identity.Identity {
return b.botIdentitySvc.ident()
return b.botIdentitySvc.GetIdentity()
}

func (b *Bot) Run(ctx context.Context) error {
Expand Down Expand Up @@ -163,6 +161,11 @@ func (b *Bot) Run(ctx context.Context) error {
if err := b.botIdentitySvc.Initialize(ctx); err != nil {
return trace.Wrap(err)
}
defer func() {
if err := b.botIdentitySvc.Close(); err != nil {
b.log.WithError(err).Error("Failed to close bot identity service")
}
}()
services = append(services, b.botIdentitySvc)

// Setup all other services
Expand All @@ -176,7 +179,8 @@ func (b *Bot) Run(ctx context.Context) error {
})
}
services = append(services, &outputsService{
botIdentitySrc: b,
getBotIdentity: b.botIdentitySvc.GetIdentity,
botClient: b.botIdentitySvc.GetClient(),
cfg: b.cfg,
resolver: resolver,
log: b.log.WithField(
Expand All @@ -185,9 +189,8 @@ func (b *Bot) Run(ctx context.Context) error {
reloadBroadcaster: reloadBroadcaster,
})
services = append(services, &caRotationService{
botIdentitySrc: b,
cfg: b.cfg,
resolver: resolver,
getBotIdentity: b.botIdentitySvc.GetIdentity,
botClient: b.botIdentitySvc.GetClient(),
log: b.log.WithField(
trace.Component, teleport.Component(componentTBot, "ca-rotation"),
),
Expand Down Expand Up @@ -348,28 +351,20 @@ func checkIdentity(log logrus.FieldLogger, ident *identity.Identity) error {
return nil
}

// clientForIdentity creates a new auth client from the given
// identity. Note that depending on the connection address given, this may
// clientForFacade creates a new auth client from the given
// facade. Note that depending on the connection address given, this may
// attempt to connect via the proxy and therefore requires both SSH and TLS
// credentials.
func clientForIdentity(
func clientForFacade(
ctx context.Context,
log logrus.FieldLogger,
cfg *config.BotConfig,
id *identity.Identity,
facade *identity.Facade,
resolver reversetunnelclient.Resolver,
) (auth.ClientI, error) {
ctx, span := tracer.Start(ctx, "clientForIdentity")
ctx, span := tracer.Start(ctx, "clientForFacade")
defer span.End()

if id.SSHCert == nil || id.X509Cert == nil {
return nil, trace.BadParameter("auth client requires a fully formed identity")
}

// TODO(noah): Eventually we'll want to reuse this facade across the bot
// rather than recreating it. Right now the blocker to that is handling the
// generation field on the certificate.
facade := identity.NewFacade(cfg.FIPS, cfg.Insecure, id)
tlsConfig, err := facade.TLSConfig()
if err != nil {
return nil, trace.Wrap(err)
Expand Down

0 comments on commit 6621b16

Please sign in to comment.