Add parallel execution to assist. #26563

Merged · 5 commits · May 26, 2023

Changes from all commits

150 changes: 89 additions & 61 deletions lib/web/command.go
@@ -188,80 +188,108 @@ func (h *Handler) executeCommand(
h.log.Debugf("Found %d hosts to run Assist command %q on.", len(hosts), req.Command)

mfaCacheFn := getMFACacheFn()
interactiveCommand := strings.Split(req.Command, " ")

runCmd := func(host *hostInfo) error {
Contributor:

can we move this into a separate function at all? executeCommand is already pretty unwieldy to look at.

Contributor Author:

I've tried, but this function uses 11 variables from the outer scope. I tried moving them into a new struct, but IMO it's not worth it at that point. If you have a strong opinion about it, I can do it.

Contributor:

👍 fine then
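For illustration, the extraction discussed here would mean bundling the captured state into a struct, roughly like the hypothetical sketch below (the field set and types are invented for illustration, not taken from the PR):

// commandRunner is a hypothetical struct holding the ~11 values that
// runCmd captures from executeCommand's scope.
type commandRunner struct {
	h           *Handler
	sessionCtx  *SessionContext
	clt         auth.ClientI
	req         *CommandRequest
	clusterName string
	// ... plus the remaining captured values (ws, w, r, identity, ...).
}

// run is runCmd as a method: every outer-scope reference becomes a field
// access, which is the ceremony that makes the extraction feel not worth it.
func (c *commandRunner) run(host *hostInfo) error {
	// same body as runCmd, using c.h, c.req, c.sessionCtx, ...
	return nil
}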
sessionData, err := h.generateCommandSession(host, req.Login, clusterName, sessionCtx.cfg.User)
if err != nil {
h.log.WithError(err).Debug("Unable to generate new ssh session.")
return trace.Wrap(err)
}

h.log.Debugf("New command request for server=%s, id=%v, login=%s, sid=%s, websid=%s.",
host.hostName, host.id, req.Login, sessionData.ID, sessionCtx.GetSessionID())

commandHandlerConfig := CommandHandlerConfig{
SessionCtx: sessionCtx,
AuthProvider: clt,
SessionData: sessionData,
KeepAliveInterval: netConfig.GetKeepAliveInterval(),
ProxyHostPort: h.ProxyHostPort(),
InteractiveCommand: strings.Split(req.Command, " "),
Router: h.cfg.Router,
TracerProvider: h.cfg.TracerProvider,
LocalAuthProvider: h.auth.accessPoint,
mfaFuncCache: mfaCacheFn,
}
h.log.Debugf("New command request for server=%s, id=%v, login=%s, sid=%s, websid=%s.",
host.hostName, host.id, req.Login, sessionData.ID, sessionCtx.GetSessionID())

commandHandlerConfig := CommandHandlerConfig{
SessionCtx: sessionCtx,
AuthProvider: clt,
SessionData: sessionData,
KeepAliveInterval: keepAliveInterval,
ProxyHostPort: h.ProxyHostPort(),
InteractiveCommand: interactiveCommand,
Router: h.cfg.Router,
TracerProvider: h.cfg.TracerProvider,
LocalAuthProvider: h.auth.accessPoint,
mfaFuncCache: mfaCacheFn,
}

handler, err := newCommandHandler(ctx, commandHandlerConfig)
if err != nil {
h.log.WithError(err).Error("Unable to create terminal.")
return trace.Wrap(err)
}
handler.ws = &noopCloserWS{ws}

h.userConns.Add(1)
defer h.userConns.Add(-1)

h.log.Infof("Executing command: %#v.", req)
httplib.MakeTracingHandler(handler, teleport.ComponentProxy).ServeHTTP(w, r)

msgPayload, err := json.Marshal(struct {
NodeID string `json:"node_id"`
ExecutionID string `json:"execution_id"`
SessionID string `json:"session_id"`
}{
NodeID: host.id,
ExecutionID: req.ExecutionID,
SessionID: string(sessionData.ID),
})

if err != nil {
return trace.Wrap(err)
}

err = clt.CreateAssistantMessage(ctx, &assist.CreateAssistantMessageRequest{
ConversationId: req.ConversationID,
Username: identity.TeleportUser,
Message: &assist.AssistantMessage{
Type: string(assistlib.MessageKindCommandResult),
CreatedTime: timestamppb.New(time.Now().UTC()),
Payload: string(msgPayload),
},
})

return trace.Wrap(err)
}

runCommands(hosts, runCmd, h.log)

return nil, nil
}

// runCommands runs the given command on the given hosts.
func runCommands(hosts []hostInfo, runCmd func(host *hostInfo) error, log logrus.FieldLogger) {
// Create a synchronization channel to limit the number of concurrent commands.
// The maximum number of concurrent commands is 30; the value is arbitrary.
syncChan := make(chan struct{}, 30)
Contributor:

can/should we put this in a configuration resource somewhere? I can imagine someone wanting to play with/adjust this.

Contributor Author:

I was also thinking about it, but our "configuration management" looks a bit messy right now.
@justinas Any concerns with adding this value to the assist section of teleport.yaml?

Contributor:

No concerns, we'll likely leave it as the default on Cloud.

Contributor Author:

@xacrimon I created a new issue to address this comment, as I don't want to introduce more changes after people have already reviewed it: https://github.com/gravitational/teleport.e/issues/1516

Contributor:

splendid

Contributor (@xacrimon, May 24, 2023):

should this use sync/semaphore?

Contributor Author:

I was also thinking about it, but Go has divided opinions about semaphores in general. I remember a proposal to remove semaphores in Go 2, which had many examples of why you should not use them. Now even the stdlib "kind of" recommends against them: https://pkg.go.dev/sync#Cond

Contributor:

that's, uh, certainly a way to design a language. anyhow, in that case, fine by me.

// WaitGroup to wait for all commands to finish.
wg := sync.WaitGroup{}

for _, host := range hosts {
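// Capture a per-iteration copy of the loop variable so each goroutine gets
// its own host (necessary before Go 1.22's loop-variable semantics change).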
host := host
wg.Add(1)

go func() {
defer wg.Done()

// Limit the number of concurrent commands.
syncChan <- struct{}{}
defer func() {
// Release the command slot.
<-syncChan
}()

if err := runCmd(&host); err != nil {
log.WithError(err).Warnf("Failed to start session: %v", host.hostName)
}
}()
}

// Wait for all commands to finish.
wg.Wait()
}
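For reference on the sync/semaphore question from the thread above, a weighted-semaphore variant of runCommands might look roughly like this (a sketch assuming golang.org/x/sync/semaphore; not what the PR merged):

// runCommandsSem is a hypothetical semaphore-based variant of runCommands.
func runCommandsSem(hosts []hostInfo, runCmd func(host *hostInfo) error, log logrus.FieldLogger) {
	// Weighted semaphore with 30 slots, mirroring the buffered channel above.
	sem := semaphore.NewWeighted(30)
	var wg sync.WaitGroup

	for _, host := range hosts {
		host := host
		wg.Add(1)

		go func() {
			defer wg.Done()

			// Acquire a slot; with context.Background() this cannot fail.
			if err := sem.Acquire(context.Background(), 1); err != nil {
				return
			}
			defer sem.Release(1)

			if err := runCmd(&host); err != nil {
				log.WithError(err).Warnf("Failed to start session: %v", host.hostName)
			}
		}()
	}

	wg.Wait()
}

In practice the two are equivalent here; semaphore.Weighted mainly adds context-aware cancellation and weighted (n > 1) acquisition, neither of which this code needs.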

// getMFACacheFn returns a function that caches the result of the given
// get function. The cache is protected by a mutex, so it is safe to call
// the returned function from multiple goroutines.
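The function body is collapsed in the diff view above, but the pattern the doc comment describes is plain mutex-protected memoization. A generic sketch of that pattern (illustrative only; not the PR's actual implementation):

// cacheFn wraps get so that it runs at most once on success; later calls
// return the cached value. The mutex makes the returned function safe to
// call from multiple goroutines.
func cacheFn[T any](get func() (T, error)) func() (T, error) {
	var (
		mu     sync.Mutex
		cached *T
	)
	return func() (T, error) {
		mu.Lock()
		defer mu.Unlock()
		if cached != nil {
			return *cached, nil
		}
		v, err := get()
		if err != nil {
			return v, err // errors are not cached, so the next call retries
		}
		cached = &v
		return v, nil
	}
}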
28 changes: 28 additions & 0 deletions lib/web/command_test.go
@@ -25,6 +25,7 @@ import (
"net/http"
"net/url"
"strings"
"sync/atomic"
"testing"
"time"

@@ -33,6 +34,7 @@ import (
"github.com/gorilla/websocket"
"github.com/gravitational/roundtrip"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/lib/client"
@@ -154,3 +156,29 @@ func waitForCommandOutput(stream io.Reader, substr string) error {
}
}
}

// Test_runCommands tests that runCommands runs the given command on all hosts.
// The commands should run in parallel, but we don't have a deterministic way to
// test that (sleeping and checking the execution time is not deterministic).
func Test_runCommands(t *testing.T) {
counter := atomic.Int32{}

runCmd := func(host *hostInfo) error {
counter.Add(1)
return nil
}

hosts := make([]hostInfo, 0, 100)
for i := 0; i < 100; i++ {
hosts = append(hosts, hostInfo{
hostName: fmt.Sprintf("localhost%d", i),
})
}

logger := logrus.New()
logger.Out = io.Discard

runCommands(hosts, runCmd, logger)

require.Equal(t, int32(100), counter.Load())
}
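One way to make the parallelism itself testable deterministically is a barrier that only opens once every concurrency slot is occupied; a serial implementation would deadlock rather than pass. A hypothetical sketch, assuming the limit of 30 hard-coded in runCommands (a real test would add a timeout guard against the deadlock case):

// Test_runCommandsParallelism is a sketch, not part of the PR: it passes
// only if `limit` commands are in flight at the same time.
func Test_runCommandsParallelism(t *testing.T) {
	const limit = 30 // must match the channel capacity in runCommands

	barrier := make(chan struct{})
	arrived := atomic.Int32{}

	runCmd := func(host *hostInfo) error {
		if arrived.Add(1) == limit {
			close(barrier) // the last arrival releases everyone
		}
		<-barrier // blocks until `limit` commands run concurrently
		return nil
	}

	logger := logrus.New()
	logger.Out = io.Discard

	runCommands(make([]hostInfo, limit), runCmd, logger)
}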