Skip to content

Commit

Permalink
[v13] Allow configuring number of parallel execution workers (#29061)
Browse files Browse the repository at this point in the history
  • Loading branch information
justinas committed Jul 14, 2023
1 parent 8e1526b commit c08e4dc
Show file tree
Hide file tree
Showing 8 changed files with 1,110 additions and 1,037 deletions.
4 changes: 4 additions & 0 deletions api/defaults/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ const (
// BreakerRatioMinExecutions is the minimum number of requests before the ratio tripper
// will consider examining the request pass rate
BreakerRatioMinExecutions = 10

// AssistCommandExecutionWorkers is the number of workers that will
// execute arbitrary remote commands on servers in parallel
AssistCommandExecutionWorkers = 30
)

var (
Expand Down
4 changes: 4 additions & 0 deletions api/proto/teleport/legacy/types/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1577,6 +1577,10 @@ message ClusterNetworkingConfigSpecV2 {
(gogoproto.jsontag) = "proxy_ping_interval,omitempty",
(gogoproto.casttype) = "Duration"
];

// AssistCommandExecutionWorkers determines the number of workers that will
// execute arbitrary Assist commands on servers in parallel
int32 AssistCommandExecutionWorkers = 11 [(gogoproto.jsontag) = "assist_command_execution_workers,omitempty"];
}

// TunnelStrategyV1 defines possible tunnel strategy types.
Expand Down
22 changes: 22 additions & 0 deletions api/types/networking.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ type ClusterNetworkingConfig interface {

// SetProxyPingInterval sets the proxy ping interval.
SetProxyPingInterval(time.Duration)

// GetAssistCommandExecutionWorkers gets the number of parallel command execution workers for Assist
GetAssistCommandExecutionWorkers() int32

// SetAssistCommandExecutionWorkers sets the number of parallel command execution workers for Assist
SetAssistCommandExecutionWorkers(n int32)
}

// NewClusterNetworkingConfigFromConfigFile is a convenience method to create
Expand Down Expand Up @@ -356,6 +362,12 @@ func (c *ClusterNetworkingConfigV2) CheckAndSetDefaults() error {
return trace.Wrap(err)
}

if c.Spec.AssistCommandExecutionWorkers < 0 {
return trace.BadParameter("command_execution_workers must be non-negative")
} else if c.Spec.AssistCommandExecutionWorkers == 0 {
c.Spec.AssistCommandExecutionWorkers = defaults.AssistCommandExecutionWorkers
}

return nil
}

Expand All @@ -369,6 +381,16 @@ func (c *ClusterNetworkingConfigV2) SetProxyPingInterval(interval time.Duration)
c.Spec.ProxyPingInterval = Duration(interval)
}

// GetAssistCommandExecutionWorkers returns the number of workers that execute
// Assist commands on servers in parallel, exactly as stored in
// Spec.AssistCommandExecutionWorkers.
func (c *ClusterNetworkingConfigV2) GetAssistCommandExecutionWorkers() int32 {
return c.Spec.AssistCommandExecutionWorkers
}

// SetAssistCommandExecutionWorkers sets the number of workers that execute
// Assist commands on servers in parallel. The value is stored as given and is
// not validated here; CheckAndSetDefaults rejects negative values and replaces
// zero with defaults.AssistCommandExecutionWorkers.
func (c *ClusterNetworkingConfigV2) SetAssistCommandExecutionWorkers(n int32) {
c.Spec.AssistCommandExecutionWorkers = n
}

// MarshalYAML defines how a proxy listener mode should be marshaled to a string
func (p ProxyListenerMode) MarshalYAML() (interface{}, error) {
return strings.ToLower(p.String()), nil
Expand Down
2,025 changes: 1,029 additions & 996 deletions api/types/types.pb.go

Large diffs are not rendered by default.

32 changes: 20 additions & 12 deletions lib/config/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -722,8 +722,11 @@ func applyAuthConfig(fc *FileConfig, cfg *servicecfg.Config) error {
key, err := os.ReadFile(keyPath)
if err != nil {
return trace.Errorf("failed to read OpenAI API key file: %w", err)
} else {
cfg.Auth.AssistAPIKey = strings.TrimSpace(string(key))
}
cfg.Auth.AssistAPIKey = strings.TrimSpace(string(key))

if fc.Auth.Assist.CommandExecutionWorkers < 0 {
return trace.BadParameter("command_execution_workers must not be negative")
}
}

Expand All @@ -741,17 +744,22 @@ func applyAuthConfig(fc *FileConfig, cfg *servicecfg.Config) error {
// Only override networking configuration if some of its fields are
// specified in file configuration.
if fc.Auth.hasCustomNetworkingConfig() {
var assistCommandExecutionWorkers int32
if fc.Auth.Assist != nil {
assistCommandExecutionWorkers = fc.Auth.Assist.CommandExecutionWorkers
}
cfg.Auth.NetworkingConfig, err = types.NewClusterNetworkingConfigFromConfigFile(types.ClusterNetworkingConfigSpecV2{
ClientIdleTimeout: fc.Auth.ClientIdleTimeout,
ClientIdleTimeoutMessage: fc.Auth.ClientIdleTimeoutMessage,
WebIdleTimeout: fc.Auth.WebIdleTimeout,
KeepAliveInterval: fc.Auth.KeepAliveInterval,
KeepAliveCountMax: fc.Auth.KeepAliveCountMax,
SessionControlTimeout: fc.Auth.SessionControlTimeout,
ProxyListenerMode: fc.Auth.ProxyListenerMode,
RoutingStrategy: fc.Auth.RoutingStrategy,
TunnelStrategy: fc.Auth.TunnelStrategy,
ProxyPingInterval: fc.Auth.ProxyPingInterval,
ClientIdleTimeout: fc.Auth.ClientIdleTimeout,
ClientIdleTimeoutMessage: fc.Auth.ClientIdleTimeoutMessage,
WebIdleTimeout: fc.Auth.WebIdleTimeout,
KeepAliveInterval: fc.Auth.KeepAliveInterval,
KeepAliveCountMax: fc.Auth.KeepAliveCountMax,
SessionControlTimeout: fc.Auth.SessionControlTimeout,
ProxyListenerMode: fc.Auth.ProxyListenerMode,
RoutingStrategy: fc.Auth.RoutingStrategy,
TunnelStrategy: fc.Auth.TunnelStrategy,
ProxyPingInterval: fc.Auth.ProxyPingInterval,
AssistCommandExecutionWorkers: assistCommandExecutionWorkers,
})
if err != nil {
return trace.Wrap(err)
Expand Down
22 changes: 18 additions & 4 deletions lib/config/fileconf.go
Original file line number Diff line number Diff line change
Expand Up @@ -1015,7 +1015,7 @@ type Auth struct {
HostedPlugins HostedPlugins `yaml:"hosted_plugins,omitempty"`

// Assist is a set of options related to the Teleport Assist feature.
Assist *AssistOptions `yaml:"assist,omitempty"`
Assist *AuthAssistOptions `yaml:"assist,omitempty"`
}

// PluginService represents the configuration for the plugin service.
Expand Down Expand Up @@ -1044,7 +1044,8 @@ func (a *Auth) hasCustomNetworkingConfig() bool {
a.ProxyListenerMode != empty.ProxyListenerMode ||
a.RoutingStrategy != empty.RoutingStrategy ||
a.TunnelStrategy != empty.TunnelStrategy ||
a.ProxyPingInterval != empty.ProxyPingInterval
a.ProxyPingInterval != empty.ProxyPingInterval ||
(a.Assist != nil && a.Assist.CommandExecutionWorkers != 0)
}

// hasCustomSessionRecording returns true if any of the session recording
Expand Down Expand Up @@ -1379,12 +1380,25 @@ func (dt *DeviceTrust) Parse() (*types.DeviceTrust, error) {
}, nil
}

// AssistOptions is a set of options related to the Teleport Assist feature.
// AssistOptions is a set of options common to both Auth and Proxy related to the Teleport Assist feature.
type AssistOptions struct {
// OpenAI is a set of options related to the OpenAI assist backend.
OpenAI *OpenAIOptions `yaml:"openai,omitempty"`
}

// ProxyAssistOptions is the set of proxy service options related to the
// Teleport Assist feature.
type ProxyAssistOptions struct {
// AssistOptions holds the options common to both Auth and Proxy; its fields
// are inlined into the same YAML section.
AssistOptions `yaml:",inline"`
}

// AuthAssistOptions is the set of auth service options related to the
// Teleport Assist feature.
type AuthAssistOptions struct {
// AssistOptions holds the options common to both Auth and Proxy; its fields
// are inlined into the same YAML section.
AssistOptions `yaml:",inline"`
// CommandExecutionWorkers determines the number of workers that will
// execute arbitrary remote commands on servers (e.g. through Assist) in parallel.
// It must not be negative.
CommandExecutionWorkers int32 `yaml:"command_execution_workers,omitempty"`
}

// OpenAIOptions stores options related to the OpenAI assist backend.
type OpenAIOptions struct {
// APITokenPath is the path to a file with OpenAI API key.
Expand Down Expand Up @@ -2120,7 +2134,7 @@ type Proxy struct {
UI *UIConfig `yaml:"ui,omitempty"`

// Assist is a set of options related to the Teleport Assist feature.
Assist *AssistOptions `yaml:"assist,omitempty"`
Assist *ProxyAssistOptions `yaml:"assist,omitempty"`

// TrustXForwardedFor enables the service to take client source IPs from
// the "X-Forwarded-For" headers for web APIs received from layer 7 load
Expand Down
35 changes: 11 additions & 24 deletions lib/web/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/sirupsen/logrus"
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/gravitational/teleport"
Expand Down Expand Up @@ -284,7 +285,7 @@ func (h *Handler) executeCommand(
return trace.Wrap(err)
}

runCommands(hosts, runCmd, h.log)
runCommands(hosts, runCmd, int(netConfig.GetAssistCommandExecutionWorkers()), h.log)

// Optionally, try to compute the command summary.
if output, valid := buffer.Export(); valid {
Expand Down Expand Up @@ -395,35 +396,21 @@ func outputByName(hosts []hostInfo, output map[string][]byte) map[string][]byte
}

// runCommands runs the given command on the given hosts.
func runCommands(hosts []hostInfo, runCmd func(host *hostInfo) error, log logrus.FieldLogger) {
// Create a synchronization channel to limit the number of concurrent commands.
// The maximum number of concurrent commands is 30 - it is arbitrary.
syncChan := make(chan struct{}, 30)
// WaitGroup to wait for all commands to finish.
wg := sync.WaitGroup{}
func runCommands(hosts []hostInfo, runCmd func(host *hostInfo) error, numParallel int, log logrus.FieldLogger) {
var group errgroup.Group
group.SetLimit(numParallel)

for _, host := range hosts {
host := host
wg.Add(1)

go func() {
defer wg.Done()

// Limit the number of concurrent commands.
syncChan <- struct{}{}
defer func() {
// Release the command slot.
<-syncChan
}()

if err := runCmd(&host); err != nil {
log.WithError(err).Warnf("Failed to start session: %v", host.hostName)
}
}()
group.Go(func() error {
return trace.Wrap(runCmd(&host), "failed to start session on %v", host.hostName)
})
}

// Wait for all commands to finish.
wg.Wait()
if err := group.Wait(); err != nil {
log.WithError(err).Debug("Assist command execution failed")
}
}

// getMFACacheFn returns a function that caches the result of the given
Expand Down
3 changes: 2 additions & 1 deletion lib/web/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ func waitForCommandOutput(stream io.Reader, substr string) error {
// The commands should run in parallel, but we don't have a deterministic way to
// test that (sleeping and checking the execution time is not deterministic).
func Test_runCommands(t *testing.T) {
const numWorkers = 30
counter := atomic.Int32{}

runCmd := func(host *hostInfo) error {
Expand All @@ -339,7 +340,7 @@ func Test_runCommands(t *testing.T) {
logger := logrus.New()
logger.Out = io.Discard

runCommands(hosts, runCmd, logger)
runCommands(hosts, runCmd, numWorkers, logger)

require.Equal(t, int32(100), counter.Load())
}
Expand Down

0 comments on commit c08e4dc

Please sign in to comment.