Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 246 additions & 0 deletions lib/tbot/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
package tbot

import (
"context"
"sync"

apiclient "github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1"
trustv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/trust/v1"
workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/api/types"
"google.golang.org/grpc"
)

type Client interface {
Close() error
GenerateHostCert(ctx context.Context, in *trustv1.GenerateHostCertRequest, opts ...grpc.CallOption) (*trustv1.GenerateHostCertResponse, error)
GenerateUserCerts(ctx context.Context, req proto.UserCertsRequest) (*proto.Certs, error)
GetAuthPreference(context.Context) (types.AuthPreference, error)
GetCertAuthorities(ctx context.Context, caType types.CertAuthType, loadKeys bool) ([]types.CertAuthority, error)
GetCertAuthority(ctx context.Context, id types.CertAuthID, includeSigningKeys bool) (types.CertAuthority, error)
GetClusterCACert(ctx context.Context) (*proto.GetClusterCACertResponse, error)
GetClusterNetworkingConfig(ctx context.Context) (types.ClusterNetworkingConfig, error)
GetRemoteClusters(ctx context.Context) ([]types.RemoteCluster, error)
GetResources(ctx context.Context, req *proto.ListResourcesRequest) (*proto.ListResourcesResponse, error)
GetRole(ctx context.Context, name string) (types.Role, error)
GetTrustCertAuthority(ctx context.Context, in *trustv1.GetCertAuthorityRequest, opts ...grpc.CallOption) (*types.CertAuthorityV2, error)
IssueWorkloadIdentities(ctx context.Context, in *workloadidentityv1pb.IssueWorkloadIdentitiesRequest, opts ...grpc.CallOption) (*workloadidentityv1pb.IssueWorkloadIdentitiesResponse, error)
IssueWorkloadIdentity(ctx context.Context, in *workloadidentityv1pb.IssueWorkloadIdentityRequest, opts ...grpc.CallOption) (*workloadidentityv1pb.IssueWorkloadIdentityResponse, error)
ListSPIFFEFederations(ctx context.Context, in *machineidv1pb.ListSPIFFEFederationsRequest, opts ...grpc.CallOption) (*machineidv1pb.ListSPIFFEFederationsResponse, error)
ListUnifiedResources(ctx context.Context, req *proto.ListUnifiedResourcesRequest) (*proto.ListUnifiedResourcesResponse, error)
NewWatcher(ctx context.Context, watch types.Watch) (types.Watcher, error)
Ping(ctx context.Context) (proto.PingResponse, error)
ResolveSSHTarget(ctx context.Context, req *proto.ResolveSSHTargetRequest) (*proto.ResolveSSHTargetResponse, error)
SignJWTSVIDs(ctx context.Context, in *machineidv1pb.SignJWTSVIDsRequest, opts ...grpc.CallOption) (*machineidv1pb.SignJWTSVIDsResponse, error)
SignX509SVIDs(ctx context.Context, in *machineidv1pb.SignX509SVIDsRequest, opts ...grpc.CallOption) (*machineidv1pb.SignX509SVIDsResponse, error)
StreamSignedCRL(ctx context.Context, in *workloadidentityv1pb.StreamSignedCRLRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[workloadidentityv1pb.StreamSignedCRLResponse], error)
SubmitHeartbeat(ctx context.Context, in *machineidv1pb.SubmitHeartbeatRequest, opts ...grpc.CallOption) (*machineidv1pb.SubmitHeartbeatResponse, error)
}

type fallableClient struct {
mu sync.Mutex
client *apiclient.Client
err error
}

var _ Client = (*fallableClient)(nil)

func (f *fallableClient) setClient(client *apiclient.Client) {
f.mu.Lock()
defer f.mu.Unlock()

f.client = client
f.err = nil
}

func (f *fallableClient) getClient() (*apiclient.Client, error) {
f.mu.Lock()
defer f.mu.Unlock()

return f.client, f.err
}

func (c *fallableClient) Close() error {
if client, _ := c.getClient(); client != nil {
return client.Close()
}
return nil
}

func (c *fallableClient) GenerateHostCert(ctx context.Context, in *trustv1.GenerateHostCertRequest, opts ...grpc.CallOption) (*trustv1.GenerateHostCertResponse, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.TrustClient().GenerateHostCert(ctx, in, opts...)
}

func (c *fallableClient) GenerateUserCerts(ctx context.Context, in proto.UserCertsRequest) (*proto.Certs, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.GenerateUserCerts(ctx, in)
}

func (c *fallableClient) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.GetAuthPreference(ctx)
}

func (c *fallableClient) GetCertAuthorities(ctx context.Context, caType types.CertAuthType, loadKeys bool) ([]types.CertAuthority, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.GetCertAuthorities(ctx, caType, loadKeys)
}

func (c *fallableClient) GetCertAuthority(ctx context.Context, id types.CertAuthID, includeSigningKeys bool) (types.CertAuthority, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.GetCertAuthority(ctx, id, includeSigningKeys)
}

func (c *fallableClient) GetClusterCACert(ctx context.Context) (*proto.GetClusterCACertResponse, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.GetClusterCACert(ctx)
}

func (c *fallableClient) GetClusterNetworkingConfig(ctx context.Context) (types.ClusterNetworkingConfig, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.GetClusterNetworkingConfig(ctx)
}

func (c *fallableClient) GetRemoteClusters(ctx context.Context) ([]types.RemoteCluster, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.GetRemoteClusters(ctx)
}

func (c *fallableClient) GetResources(ctx context.Context, req *proto.ListResourcesRequest) (*proto.ListResourcesResponse, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.GetResources(ctx, req)
}

func (c *fallableClient) GetRole(ctx context.Context, name string) (types.Role, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.GetRole(ctx, name)
}

func (c *fallableClient) GetTrustCertAuthority(ctx context.Context, in *trustv1.GetCertAuthorityRequest, opts ...grpc.CallOption) (*types.CertAuthorityV2, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.TrustClient().GetCertAuthority(ctx, in, opts...)
}

func (c *fallableClient) IssueWorkloadIdentities(ctx context.Context, in *workloadidentityv1pb.IssueWorkloadIdentitiesRequest, opts ...grpc.CallOption) (*workloadidentityv1pb.IssueWorkloadIdentitiesResponse, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.WorkloadIdentityIssuanceClient().IssueWorkloadIdentities(ctx, in, opts...)
}

func (c *fallableClient) IssueWorkloadIdentity(ctx context.Context, in *workloadidentityv1pb.IssueWorkloadIdentityRequest, opts ...grpc.CallOption) (*workloadidentityv1pb.IssueWorkloadIdentityResponse, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.WorkloadIdentityIssuanceClient().IssueWorkloadIdentity(ctx, in, opts...)
}

func (c *fallableClient) ListSPIFFEFederations(ctx context.Context, in *machineidv1pb.ListSPIFFEFederationsRequest, opts ...grpc.CallOption) (*machineidv1pb.ListSPIFFEFederationsResponse, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.SPIFFEFederationServiceClient().ListSPIFFEFederations(ctx, in, opts...)
}

func (c *fallableClient) ListUnifiedResources(ctx context.Context, req *proto.ListUnifiedResourcesRequest) (*proto.ListUnifiedResourcesResponse, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.ListUnifiedResources(ctx, req)
}

func (c *fallableClient) NewWatcher(ctx context.Context, watch types.Watch) (types.Watcher, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.NewWatcher(ctx, watch)
}

func (c *fallableClient) Ping(ctx context.Context) (proto.PingResponse, error) {
client, err := c.getClient()
if err != nil {
return proto.PingResponse{}, err
}
return client.Ping(ctx)
}

func (c *fallableClient) ResolveSSHTarget(ctx context.Context, req *proto.ResolveSSHTargetRequest) (*proto.ResolveSSHTargetResponse, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.ResolveSSHTarget(ctx, req)
}

func (c *fallableClient) SignJWTSVIDs(ctx context.Context, in *machineidv1pb.SignJWTSVIDsRequest, opts ...grpc.CallOption) (*machineidv1pb.SignJWTSVIDsResponse, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.WorkloadIdentityServiceClient().SignJWTSVIDs(ctx, in, opts...)
}

func (c *fallableClient) SignX509SVIDs(ctx context.Context, in *machineidv1pb.SignX509SVIDsRequest, opts ...grpc.CallOption) (*machineidv1pb.SignX509SVIDsResponse, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.WorkloadIdentityServiceClient().SignX509SVIDs(ctx, in, opts...)
}

func (c *fallableClient) StreamSignedCRL(ctx context.Context, in *workloadidentityv1pb.StreamSignedCRLRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[workloadidentityv1pb.StreamSignedCRLResponse], error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.WorkloadIdentityRevocationServiceClient().StreamSignedCRL(ctx, in, opts...)
}

func (c *fallableClient) SubmitHeartbeat(ctx context.Context, in *machineidv1pb.SubmitHeartbeatRequest, opts ...grpc.CallOption) (*machineidv1pb.SubmitHeartbeatResponse, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
return client.BotInstanceServiceClient().SubmitHeartbeat(ctx, in, opts...)
}
6 changes: 5 additions & 1 deletion lib/tbot/config/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@ import (
// If there's no destination in the provided yaml node, then this will return
// nil, nil.
func extractOutputDestination(node *yaml.Node) (bot.Destination, error) {
return extractDestinationField(node, "destination")
}

func extractDestinationField(node *yaml.Node, name string) (bot.Destination, error) {
for i, subNode := range node.Content {
if subNode.Value == "destination" {
if subNode.Value == name {
// Next node will be the contents
dest, err := unmarshalDestination(node.Content[i+1])
if err != nil {
Expand Down
22 changes: 20 additions & 2 deletions lib/tbot/config/service_workload_identity_aws_ra.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ type WorkloadIdentityAWSRAService struct {
// Destination is where the credentials should be written to.
Destination bot.Destination `yaml:"destination"`

Cache bot.Destination `yaml:"cache"`

// RoleARN is the ARN of the role to assume.
// Example: `arn:aws:iam::123456789012:role/example-role`
// Required.
Expand Down Expand Up @@ -96,16 +98,25 @@ type WorkloadIdentityAWSRAService struct {
EndpointOverride string `yaml:"-"`
}

// Init initializes the destination.
// Init initializes the output and cache destinations.
func (o *WorkloadIdentityAWSRAService) Init(ctx context.Context) error {
return trace.Wrap(o.Destination.Init(ctx, []string{}))
return trace.NewAggregate(
o.Destination.Init(ctx, []string{}),
o.Cache.Init(ctx, []string{}),
)
}

// CheckAndSetDefaults checks the WorkloadIdentityAWSRAService values and sets any defaults.
func (o *WorkloadIdentityAWSRAService) CheckAndSetDefaults() error {
if err := validateOutputDestination(o.Destination); err != nil {
return trace.Wrap(err)
}
if o.Cache == nil {
o.Cache = &DestinationMemory{}
}
if err := o.Cache.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err, "validating cache destination")
}
if err := o.Selector.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err, "validating selector")
}
Expand Down Expand Up @@ -149,6 +160,8 @@ func (o *WorkloadIdentityAWSRAService) CheckAndSetDefaults() error {
return nil
}

// TODO: this seems to be wrong - check it out.
//
// Describe returns the file descriptions for the WorkloadIdentityJWTService.
func (o *WorkloadIdentityAWSRAService) Describe() []FileDescription {
fds := []FileDescription{
Expand All @@ -175,12 +188,17 @@ func (o *WorkloadIdentityAWSRAService) UnmarshalYAML(node *yaml.Node) error {
if err != nil {
return trace.Wrap(err)
}
cache, err := extractDestinationField(node, "cache")
if err != nil {
return trace.Wrap(err)
}
// Alias type to remove UnmarshalYAML to avoid recursion
type raw WorkloadIdentityAWSRAService
if err := node.Decode((*raw)(o)); err != nil {
return trace.Wrap(err)
}
o.Destination = dest
o.Cache = cache
return nil
}

Expand Down
5 changes: 2 additions & 3 deletions lib/tbot/database_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,10 @@ import (
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/auth/authclient"
libdefaults "github.com/gravitational/teleport/lib/defaults"
)

func getDatabase(ctx context.Context, clt *authclient.Client, name string) (types.Database, error) {
func getDatabase(ctx context.Context, clt Client, name string) (types.Database, error) {
ctx, span := tracer.Start(ctx, "getDatabase")
defer span.End()

Expand All @@ -59,7 +58,7 @@ func getDatabase(ctx context.Context, clt *authclient.Client, name string) (type
func getRouteToDatabase(
ctx context.Context,
log *slog.Logger,
client *authclient.Client,
client Client,
service string,
username string,
database string,
Expand Down
2 changes: 1 addition & 1 deletion lib/tbot/output_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ type identityConfigurator = func(req *proto.UserCertsRequest)
// certs.
func generateIdentity(
ctx context.Context,
client *authclient.Client,
client Client,
currentIdentity *identity.Identity,
roles []string,
ttl time.Duration,
Expand Down
9 changes: 4 additions & 5 deletions lib/tbot/service_application_output.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/reversetunnelclient"
"github.com/gravitational/teleport/lib/tbot/config"
"github.com/gravitational/teleport/lib/tbot/identity"
Expand All @@ -39,7 +38,7 @@ import (
// ApplicationOutputService generates the artifacts necessary to connect to a
// HTTP or TCP application using Teleport.
type ApplicationOutputService struct {
botAuthClient *authclient.Client
botAuthClient Client
botCfg *config.BotConfig
cfg *config.ApplicationOutput
getBotIdentity getBotIdentityFn
Expand Down Expand Up @@ -115,7 +114,7 @@ func (s *ApplicationOutputService) generate(ctx context.Context) error {
// create a client that uses the impersonated identity, so that when we
// fetch information, we can ensure access rights are enforced.
facade := identity.NewFacade(s.botCfg.FIPS, s.botCfg.Insecure, id)
impersonatedClient, err := clientForFacade(ctx, s.log, s.botCfg, facade, s.resolver)
impersonatedClient, err := temporaryClient(ctx, s.log, s.botCfg, facade, s.resolver)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -210,7 +209,7 @@ func (s *ApplicationOutputService) render(
func getRouteToApp(
ctx context.Context,
botIdentity *identity.Identity,
client *authclient.Client,
client Client,
appName string,
) (proto.RouteToApp, types.Application, error) {
ctx, span := tracer.Start(ctx, "getRouteToApp")
Expand All @@ -233,7 +232,7 @@ func getRouteToApp(
return routeToApp, app, nil
}

func getApp(ctx context.Context, clt *authclient.Client, appName string) (types.Application, error) {
func getApp(ctx context.Context, clt Client, appName string) (types.Application, error) {
ctx, span := tracer.Start(ctx, "getApp")
defer span.End()

Expand Down
Loading
Loading