Skip to content

Commit

Permalink
Add parallel execution to assist. (#26563)
Browse files Browse the repository at this point in the history
* Add parallel execution to assist.

* Extract execution logic to a new function.

* Add test

* Switch uber to std

* Address code review comments
  • Loading branch information
jakule committed Jun 2, 2023
1 parent f5fd1a0 commit 71b1f49
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 61 deletions.
150 changes: 89 additions & 61 deletions lib/web/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,80 +188,108 @@ func (h *Handler) executeCommand(
h.log.Debugf("Found %d hosts to run Assist command %q on.", len(hosts), req.Command)

mfaCacheFn := getMFACacheFn()
interactiveCommand := strings.Split(req.Command, " ")

for _, host := range hosts {
err := func() error {
sessionData, err := h.generateCommandSession(&host, req.Login, clusterName, sessionCtx.cfg.User)
if err != nil {
h.log.WithError(err).Debug("Unable to generate new ssh session.")
return trace.Wrap(err)
}
runCmd := func(host *hostInfo) error {
sessionData, err := h.generateCommandSession(host, req.Login, clusterName, sessionCtx.cfg.User)
if err != nil {
h.log.WithError(err).Debug("Unable to generate new ssh session.")
return trace.Wrap(err)
}

h.log.Debugf("New command request for server=%s, id=%v, login=%s, sid=%s, websid=%s.",
host.hostName, host.id, req.Login, sessionData.ID, sessionCtx.GetSessionID())

commandHandlerConfig := CommandHandlerConfig{
SessionCtx: sessionCtx,
AuthProvider: clt,
SessionData: sessionData,
KeepAliveInterval: netConfig.GetKeepAliveInterval(),
ProxyHostPort: h.ProxyHostPort(),
InteractiveCommand: strings.Split(req.Command, " "),
Router: h.cfg.Router,
TracerProvider: h.cfg.TracerProvider,
LocalAuthProvider: h.auth.accessPoint,
mfaFuncCache: mfaCacheFn,
}
h.log.Debugf("New command request for server=%s, id=%v, login=%s, sid=%s, websid=%s.",
host.hostName, host.id, req.Login, sessionData.ID, sessionCtx.GetSessionID())

commandHandlerConfig := CommandHandlerConfig{
SessionCtx: sessionCtx,
AuthProvider: clt,
SessionData: sessionData,
KeepAliveInterval: keepAliveInterval,
ProxyHostPort: h.ProxyHostPort(),
InteractiveCommand: interactiveCommand,
Router: h.cfg.Router,
TracerProvider: h.cfg.TracerProvider,
LocalAuthProvider: h.auth.accessPoint,
mfaFuncCache: mfaCacheFn,
}

handler, err := newCommandHandler(ctx, commandHandlerConfig)
if err != nil {
h.log.WithError(err).Error("Unable to create terminal.")
return trace.Wrap(err)
}
handler.ws = &noopCloserWS{ws}

h.userConns.Add(1)
defer h.userConns.Add(-1)

h.log.Infof("Executing command: %#v.", req)
httplib.MakeTracingHandler(handler, teleport.ComponentProxy).ServeHTTP(w, r)

msgPayload, err := json.Marshal(struct {
NodeID string `json:"node_id"`
ExecutionID string `json:"execution_id"`
SessionID string `json:"session_id"`
}{
NodeID: host.id,
ExecutionID: req.ExecutionID,
SessionID: string(sessionData.ID),
})

if err != nil {
return trace.Wrap(err)
}
handler, err := newCommandHandler(ctx, commandHandlerConfig)
if err != nil {
h.log.WithError(err).Error("Unable to create terminal.")
return trace.Wrap(err)
}
handler.ws = &noopCloserWS{ws}

err = clt.CreateAssistantMessage(ctx, &assist.CreateAssistantMessageRequest{
ConversationId: req.ConversationID,
Username: identity.TeleportUser,
Message: &assist.AssistantMessage{
Type: string(assistlib.MessageKindCommandResult),
CreatedTime: timestamppb.New(time.Now().UTC()),
Payload: string(msgPayload),
},
})
h.userConns.Add(1)
defer h.userConns.Add(-1)

return trace.Wrap(err)
}()
h.log.Infof("Executing command: %#v.", req)
httplib.MakeTracingHandler(handler, teleport.ComponentProxy).ServeHTTP(w, r)

msgPayload, err := json.Marshal(struct {
NodeID string `json:"node_id"`
ExecutionID string `json:"execution_id"`
SessionID string `json:"session_id"`
}{
NodeID: host.id,
ExecutionID: req.ExecutionID,
SessionID: string(sessionData.ID),
})

if err != nil {
h.log.WithError(err).Warnf("Failed to start session: %v", host.hostName)
continue
return trace.Wrap(err)
}

err = clt.CreateAssistantMessage(ctx, &assist.CreateAssistantMessageRequest{
ConversationId: req.ConversationID,
Username: identity.TeleportUser,
Message: &assist.AssistantMessage{
Type: string(assistlib.MessageKindCommandResult),
CreatedTime: timestamppb.New(time.Now().UTC()),
Payload: string(msgPayload),
},
})

return trace.Wrap(err)
}

runCommands(hosts, runCmd, h.log)

return nil, nil
}

// runCommands invokes runCmd once for every entry in hosts and blocks until
// every invocation has returned.
//
// Invocations run concurrently, with at most maxConcurrentCommands in flight
// at any time. A slot is acquired before the goroutine is started, so the
// number of live goroutines is bounded by the same limit instead of growing
// with len(hosts).
//
// An error returned by runCmd is logged as a warning together with the host
// name; it does not stop the remaining hosts from being processed and is not
// reported to the caller.
func runCommands(hosts []hostInfo, runCmd func(host *hostInfo) error, log logrus.FieldLogger) {
	// maxConcurrentCommands is the maximum number of commands executed at the
	// same time. The value is arbitrary.
	const maxConcurrentCommands = 30

	// Buffered channel used as a counting semaphore: a send takes a slot, a
	// receive gives it back.
	slots := make(chan struct{}, maxConcurrentCommands)
	// WaitGroup to wait for all commands to finish.
	var wg sync.WaitGroup

	for _, host := range hosts {
		// Per-iteration copy so each goroutine gets its own host value even
		// when the loop variable is shared across iterations.
		host := host

		// Take a slot before spawning; this blocks the loop while the limit
		// is reached, so no goroutine is created just to wait for a slot.
		slots <- struct{}{}
		wg.Add(1)

		go func() {
			defer wg.Done()
			// Release the command slot once the command has returned.
			defer func() { <-slots }()

			if err := runCmd(&host); err != nil {
				log.WithError(err).Warnf("Failed to start session: %v", host.hostName)
			}
		}()
	}

	// Wait for all commands to finish.
	wg.Wait()
}

// getMFACacheFn returns a function that caches the result of the given
// get function. The cache is protected by a mutex, so it is safe to call
// the returned function from multiple goroutines.
Expand Down
28 changes: 28 additions & 0 deletions lib/web/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"net/http"
"net/url"
"strings"
"sync/atomic"
"testing"
"time"

Expand All @@ -33,6 +34,7 @@ import (
"github.com/gorilla/websocket"
"github.com/gravitational/roundtrip"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/lib/client"
Expand Down Expand Up @@ -154,3 +156,29 @@ func waitForCommandOutput(stream io.Reader, substr string) error {
}
}
}

// Test_runCommands verifies that runCommands calls the supplied function
// exactly once for each host it is given.
// The calls are expected to happen in parallel, but that is not asserted
// here: timing-based checks (sleeping and measuring the elapsed time) are
// not deterministic.
func Test_runCommands(t *testing.T) {
	const hostCount = 100

	// Build the host list up front; only the host name is populated.
	hosts := make([]hostInfo, hostCount)
	for i := range hosts {
		hosts[i].hostName = fmt.Sprintf("localhost%d", i)
	}

	// Counts invocations; atomic because runCommands calls the function
	// from multiple goroutines.
	var calls atomic.Int32
	countCall := func(_ *hostInfo) error {
		calls.Add(1)
		return nil
	}

	// Silence the logger so the test produces no output.
	quietLog := logrus.New()
	quietLog.Out = io.Discard

	runCommands(hosts, countCall, quietLog)

	require.Equal(t, int32(hostCount), calls.Load())
}

0 comments on commit 71b1f49

Please sign in to comment.