[Assist] Add in SSH context Assist endpoints #30319

Merged: 5 commits, Aug 14, 2023
Changes from all commits
5 changes: 5 additions & 0 deletions lib/ai/chat.go
@@ -65,6 +65,11 @@ func (chat *Chat) Complete(ctx context.Context, userInput string, progressUpdate
}, model.NewTokenCount(), nil
}

return chat.Reply(ctx, userInput, progressUpdates)
}

// Reply replies to the user input with a message from the assistant based on the current context.
func (chat *Chat) Reply(ctx context.Context, userInput string, progressUpdates func(*model.AgentAction)) (any, *model.TokenCount, error) {
userMessage := openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: userInput,
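Worth noting: Complete keeps its early-return path (the block ending just above the new call), while every subsequent turn now funnels through the newly exported Reply. That lets a caller that manages its own transcript, such as the LightweightChat added in lib/assist below, query the model directly. A minimal sketch of that calling pattern, assuming the APIs in this diff (the runTurn wrapper itself is hypothetical):

package main

import (
    "context"
    "log"

    "github.com/gravitational/teleport/lib/ai"
    "github.com/gravitational/teleport/lib/ai/model"
    openai "github.com/sashabaranov/go-openai"
)

// runTurn drives one conversation turn through Reply and records the
// exchange by hand, mirroring what LightweightChat.ProcessComplete does.
func runTurn(ctx context.Context, chat *ai.Chat, input string) error {
    progress := func(action *model.AgentAction) {
        log.Printf("agent action: %v", action.Action)
    }

    // Reply always runs the agent loop; it does not consult the
    // first-message logic that Complete layers on top.
    out, _, err := chat.Reply(ctx, input, progress)
    if err != nil {
        return err
    }

    // The caller owns the transcript, so it appends both sides itself.
    chat.Insert(openai.ChatMessageRoleUser, input)
    if msg, ok := out.(*model.Message); ok {
        chat.Insert(openai.ChatMessageRoleAssistant, msg.Content)
    }
    return nil
}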
31 changes: 29 additions & 2 deletions lib/ai/client.go
@@ -51,7 +51,34 @@ func NewClientFromConfig(config openai.ClientConfig) *Client {
// toolsConfig contains all required clients and configuration for agent tools
// to interact with Teleport.
func (client *Client) NewChat(username string, toolsConfig model.ToolsConfig) (*Chat, error) {
- agent, err := model.NewAgent(username, toolsConfig)
tools := []model.Tool{
model.NewExecutionTool(),
}
if !toolsConfig.DisableEmbeddingsTool {
tools = append(tools, model.NewRetrievalTool(toolsConfig.EmbeddingsClient, toolsConfig.NodeClient,
toolsConfig.AccessChecker, username))
}
agent, err := model.NewAgent(tools...)
if err != nil {
return nil, trace.Wrap(err)
}
return &Chat{
client: client,
messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: model.PromptCharacter(username),
},
},
// Initialize a tokenizer for prompt token accounting.
// Cl100kBase is the encoding used by GPT-3.5 and GPT-4.
tokenizer: codec.NewCl100kBase(),
agent: agent,
}, nil
}

// NewCommand creates a new chat whose agent is limited to the command generation tool.
func (client *Client) NewCommand(username string) (*Chat, error) {
agent, err := model.NewAgent(model.NewGenerateTool())
if err != nil {
return nil, trace.Wrap(err)
}
@@ -121,7 +148,7 @@ func (client *Client) CommandSummary(ctx context.Context, messages []openai.Chat
return completion, tc, trace.Wrap(err)
}

- // ClassifyMessage takes a user message, a list of categories, and uses the AI mode as a zero shot classifier.
// ClassifyMessage takes a user message, a list of categories, and uses the AI model as a zero-shot classifier.
func (client *Client) ClassifyMessage(ctx context.Context, message string, classes map[string]string) (string, error) {
resp, err := client.svc.CreateChatCompletion(
ctx,
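As background, zero-shot classification here just means the categories are spelled out in the prompt and the model is asked to answer with one of them; no fine-tuning is involved. A standalone sketch of the idea against the go-openai client (the prompt wording, model choice, and lowercase-label convention are illustrative assumptions, not Teleport's actual implementation):

package main

import (
    "context"
    "fmt"
    "strings"

    openai "github.com/sashabaranov/go-openai"
)

// classify asks the model to label message with exactly one of the given
// classes (name -> description) and rejects anything outside that set.
func classify(ctx context.Context, client *openai.Client, message string, classes map[string]string) (string, error) {
    var list strings.Builder
    for name, description := range classes {
        fmt.Fprintf(&list, "- %s: %s\n", name, description)
    }

    resp, err := client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
        Model: openai.GPT4,
        Messages: []openai.ChatCompletionMessage{
            {
                Role: openai.ChatMessageRoleSystem,
                Content: "Classify the user message into exactly one of these categories. " +
                    "Reply with the category name only.\n" + list.String(),
            },
            {Role: openai.ChatMessageRoleUser, Content: message},
        },
    })
    if err != nil {
        return "", err
    }

    // Normalize and validate: an answer outside the class set is an error.
    label := strings.ToLower(strings.TrimSpace(resp.Choices[0].Message.Content))
    if _, ok := classes[label]; !ok {
        return "", fmt.Errorf("model returned unknown class %q", label)
    }
    return label, nil
}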
60 changes: 45 additions & 15 deletions lib/ai/model/agent.go
@@ -47,23 +47,34 @@ const (
finalResponseHeader = "<FINAL RESPONSE>"
)

- // NewAgent creates a new agent. The Assist agent which defines the model responsible for the Assist feature.
- func NewAgent(username string, config ToolsConfig) (*Agent, error) {
- 	err := config.CheckAndSetDefaults()
- 	if err != nil {
- 		return nil, trace.Wrap(err)
- 	}
// NewExecutionTool creates a new execution tool. The execution tool is responsible for executing commands.
func NewExecutionTool() Tool {
return &commandExecutionTool{}
}

tools := []Tool{&commandExecutionTool{}}
// NewGenerateTool creates a new generation tool. The generation tool is responsible for generating Bash commands.
func NewGenerateTool() Tool {
return &commandGenerationTool{}
}

- 	if !config.DisableEmbeddingsTool {
- 		tools = append(tools,
- 			&embeddingRetrievalTool{
- 				assistClient: config.EmbeddingsClient,
- 				currentUser: username,
- 				nodeClient: config.NodeClient,
- 				userAccessChecker: config.AccessChecker,
- 			})
// NewRetrievalTool creates a new retrieval tool. The retrieval tool is responsible for retrieving embeddings.
func NewRetrievalTool(assistClient assist.AssistEmbeddingServiceClient,
nodeClient NodeGetter,
userAccessChecker services.AccessChecker,
currentUser string,
) Tool {
return &embeddingRetrievalTool{
assistClient: assistClient,
currentUser: currentUser,
nodeClient: nodeClient,
userAccessChecker: userAccessChecker,
}
}

// NewAgent creates a new agent. The Assist agent defines the model responsible for the Assist feature.
func NewAgent(tools ...Tool) (*Agent, error) {
if len(tools) == 0 {
return nil, trace.BadParameter("at least one tool is required")
}

return &Agent{
@@ -264,6 +275,25 @@ func (a *Agent) takeNextStep(ctx context.Context, state *executionState, progres
return stepOutput{finish: &agentFinish{output: completion}}, nil
}

if tool, ok := tool.(*commandGenerationTool); ok {
input, err := tool.parseInput(action.Input)
if err != nil {
action := &AgentAction{
Action: actionException,
Input: observationPrefix + "Invalid or incomplete response",
Log: thoughtPrefix + err.Error(),
}

return stepOutput{action: action, observation: action.Input}, nil
}
completion := &GeneratedCommand{
Command: input.Command,
}

log.Tracef("agent decided on command generation, let's translate to an agentFinish")
return stepOutput{finish: &agentFinish{output: completion}}, nil
}

runOut, err := tool.Run(ctx, action.Input)
if err != nil {
return stepOutput{}, trace.Wrap(err)
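The net effect of the refactor above is dependency injection: NewAgent no longer builds its own tools, so each endpoint assembles exactly the tool set it needs. A sketch of the two call sites this enables, assuming the constructors introduced in this file (the buildAgents wrapper itself is hypothetical):

package main

import (
    "github.com/gravitational/teleport/lib/ai/model"
)

// buildAgents assembles the two agent flavors used by this PR: the full
// Assist chat and the command-only chat behind the new SSH endpoints.
func buildAgents(cfg model.ToolsConfig, username string) (*model.Agent, *model.Agent, error) {
    // Full chat: command execution, plus embeddings retrieval when enabled.
    tools := []model.Tool{model.NewExecutionTool()}
    if !cfg.DisableEmbeddingsTool {
        tools = append(tools, model.NewRetrievalTool(
            cfg.EmbeddingsClient, cfg.NodeClient, cfg.AccessChecker, username))
    }
    chatAgent, err := model.NewAgent(tools...)
    if err != nil {
        return nil, nil, err
    }

    // Command-only chat (what Client.NewCommand uses): a single tool
    // whose output is surfaced as a GeneratedCommand suggestion.
    commandAgent, err := model.NewAgent(model.NewGenerateTool())
    if err != nil {
        return nil, nil, err
    }
    return chatAgent, commandAgent, nil
}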
80 changes: 80 additions & 0 deletions lib/ai/model/generationtool.go
@@ -0,0 +1,80 @@
/*
* Copyright 2023 Gravitational, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package model

import (
"context"
"fmt"

"github.com/gravitational/trace"
)

type commandGenerationTool struct{}

type commandGenerationToolInput struct {
// Command is a unix command to execute.
Command string `json:"command"`
}

func (c *commandGenerationTool) Name() string {
return "Command Generation"
}

func (c *commandGenerationTool) Description() string {
    // The acknowledgement field is used to convince the LLM to return the JSON.
    // Based on my testing, the LLM ignores the JSON when the schema has only one field.
    // Adding additional "pseudo-fields" to the schema makes the LLM return the JSON.
    return fmt.Sprintf(`Generate a Bash command.
The input must be a JSON object with the following schema:
%vjson
{
	"command": string, \\ The generated command
	"acknowledgement": boolean \\ Set to true to acknowledge that you understand the formatting
}
%v
`, "```", "```")
}

Review thread on "Generate a Bash command.":

Contributor: Should we specify something other than bash here? Maybe we could feed the prompt the OS, if we have that available during an SSH session. If not, something more generic would avoid the assumption that the user's shell is always bash.

Contributor Author: I think this is something that we could explore "when needed". The WebUI doesn't know what OS the client is using, and I'm also not sure how good OpenAI is at generating commands for different shells. Bottom line, 98% of the syntax is the same between bash, sh, zsh, fish, or whatever you're using. The differences show up mainly when you write a script, which we don't do.

Review thread on lines +44 to +47 (the JSON schema in the prompt above):

Contributor: I'm not too familiar with the structure of Assist, but could this generate markdown instead? We could stream it to the UI quickly if that was the case, with the suggested command going between three backticks.

Contributor Author: OpenAI sends the markdown back most of the time.

func (c *commandGenerationTool) Run(_ context.Context, _ string) (string, error) {
// This is stubbed because commandGenerationTool is handled specially.
// This is because execution of this tool breaks the loop and returns a command suggestion to the user.
// It is still handled as a tool because testing has shown that the LLM behaves better when it is treated as a tool.
//
// In addition, treating it as a Tool interface item simplifies the display and prompt assembly logic significantly.
return "", trace.NotImplemented("not implemented")
}

// parseInput is called in a special case if the planned tool is commandGenerationTool.
// This is because commandGenerationTool is handled differently from most other tools and forcibly terminates the thought loop.
func (*commandGenerationTool) parseInput(input string) (*commandGenerationToolInput, error) {
output, err := parseJSONFromModel[commandGenerationToolInput](input)
if err != nil {
return nil, err
}

if output.Command == "" {
return nil, &invalidOutputError{
coarse: "command generation: missing command",
detail: "command must be non-empty",
}
}

// Ignore the acknowledgement field.
// We do not care about the value. Having the command is enough.

return &output, nil
}
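To make the contract concrete, here is a standalone sketch of the round trip: the blob the schema asks the model to emit, and what parseInput ultimately keeps from it. Only the command field survives; the acknowledgement pseudo-field exists purely to coax the model into emitting JSON:

package main

import (
    "encoding/json"
    "fmt"
)

// toolInput mirrors commandGenerationToolInput above. encoding/json drops
// unknown fields, so "acknowledgement" vanishes on unmarshal.
type toolInput struct {
    Command string `json:"command"`
}

func main() {
    // A well-formed model response per the schema in Description().
    raw := `{"command": "free -h", "acknowledgement": true}`

    var in toolInput
    if err := json.Unmarshal([]byte(raw), &in); err != nil {
        panic(err) // malformed JSON surfaces as an invalidOutputError upstream
    }
    if in.Command == "" {
        panic("command generation: missing command") // same check as parseInput
    }
    fmt.Println(in.Command) // prints: free -h
}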
5 changes: 5 additions & 0 deletions lib/ai/model/messages.go
Expand Up @@ -50,3 +50,8 @@ type CompletionCommand struct {
Nodes []string `json:"nodes,omitempty"`
Labels []Label `json:"labels,omitempty"`
}

// GeneratedCommand represents a Bash command generated by the LLM.
type GeneratedCommand struct {
Command string `json:"command"`
}
3 changes: 2 additions & 1 deletion lib/ai/model/prompt.go
@@ -111,7 +111,8 @@ func ConversationCommandResult(result map[string][]byte) string {
message.WriteString(string(output))
message.WriteString("\n")
}
message.WriteString("Based on the chat history, extract relevant information out of the command output and write a summary.")
message.WriteString("Based on the chat history, extract relevant information out of the command output and write a summary. " +
"For error messages suggest a solution if possible. The solution can contain a Linux command or a description.")
return message.String()
}

84 changes: 83 additions & 1 deletion lib/assist/assist.go
@@ -144,6 +144,31 @@ func (a *Assist) NewChat(ctx context.Context, assistService MessageService,
return chat, nil
}

// LightweightChat is a Teleport Assist chat that doesn't store the history
// of the conversation.
type LightweightChat struct {
assist *Assist
chat *ai.Chat
}

// NewLightweightChat creates a new Assist chat that doesn't store the history
// of the conversation.
func (a *Assist) NewLightweightChat(username string) (*LightweightChat, error) {
aichat, err := a.client.NewCommand(username) // TODO(jakule): fix this after all in-flight PRs are merged
if err != nil {
return nil, trace.Wrap(err)
}

return &LightweightChat{
assist: a,
chat: aichat,
}, nil
}

// NewSSHCommand returns a command-generation chat for use in an SSH session context.
func (a *Assist) NewSSHCommand(username string) (*ai.Chat, error) {
return a.client.NewCommand(username)
}

// GenerateSummary generates a summary for the given message.
func (a *Assist) GenerateSummary(ctx context.Context, message string) (string, error) {
return a.client.Summary(ctx, message)
@@ -179,7 +204,7 @@ func (c *Chat) reloadMessages(ctx context.Context) error {
}

// ClassifyMessage takes a user message, a list of categories, and uses the AI
- // mode as a zero shot classifier. It returns an error if the classification
// model as a zero-shot classifier. It returns an error if the classification
// result is not a valid class.
func (a *Assist) ClassifyMessage(ctx context.Context, message string, classes map[string]string) (string, error) {
category, err := a.client.ClassifyMessage(ctx, message, classes)
@@ -406,6 +431,63 @@ func (c *Chat) ProcessComplete(ctx context.Context, onMessage onMessageFunc, use
return tokenCount, nil
}

// ProcessComplete processes a user message and returns the assistant's response.
func (c *LightweightChat) ProcessComplete(ctx context.Context, onMessage onMessageFunc, userInput string,
) (*model.TokenCount, error) {
progressUpdates := func(update *model.AgentAction) {
payload, err := json.Marshal(update)
if err != nil {
log.WithError(err).Debugf("Failed to marshal progress update: %v", update)
return
}

if err := onMessage(MessageKindProgressUpdate, payload, c.assist.clock.Now().UTC()); err != nil {
log.WithError(err).Debugf("Failed to send progress update: %v", update)
return
}
}

message, tokenCount, err := c.chat.Reply(ctx, userInput, progressUpdates)
if err != nil {
return nil, trace.Wrap(err)
}

c.chat.Insert(openai.ChatMessageRoleUser, userInput)

switch message := message.(type) {
case *model.Message:
c.chat.Insert(openai.ChatMessageRoleAssistant, message.Content)
if err := onMessage(MessageKindAssistantMessage, []byte(message.Content), c.assist.clock.Now().UTC()); err != nil {
return nil, trace.Wrap(err)
}
case *model.GeneratedCommand:
c.chat.Insert(openai.ChatMessageRoleAssistant, message.Command)
if err := onMessage(MessageKindCommand, []byte(message.Command), c.assist.clock.Now().UTC()); err != nil {
return nil, trace.Wrap(err)
}
case *model.StreamingMessage:
if err := func() error {
var text strings.Builder
defer onMessage(MessageKindAssistantPartialFinalize, nil, c.assist.clock.Now().UTC())
for part := range message.Parts {
text.WriteString(part)

if err := onMessage(MessageKindAssistantPartialMessage, []byte(part), c.assist.clock.Now().UTC()); err != nil {
return trace.Wrap(err)
}
}
c.chat.Insert(openai.ChatMessageRoleAssistant, text.String())
return nil
}(); err != nil {
return nil, trace.Wrap(err)
}
default:
return nil, trace.Errorf("Unexpected message type: %T", message)
}

return tokenCount, nil
}

func getOpenAITokenFromDefaultPlugin(ctx context.Context, proxyClient PluginGetter) (string, error) {
// Try retrieving credentials from the plugin resource first
openaiPlugin, err := proxyClient.PluginsClient().GetPlugin(ctx, &pluginsv1.GetPluginRequest{
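To round out the picture, here is a sketch of an onMessage callback that forwards each event kind emitted by LightweightChat.ProcessComplete to a client. The envelope type, the plain string kind, and the gorilla/websocket transport are all assumptions for illustration; the real callback uses this package's message-kind type and whatever transport the proxy exposes:

package main

import (
    "encoding/json"
    "time"

    "github.com/gorilla/websocket"
)

// assistEvent is a hypothetical wire envelope for Assist messages.
type assistEvent struct {
    Kind      string    `json:"kind"`
    Payload   string    `json:"payload"`
    CreatedAt time.Time `json:"created_at"`
}

// newOnMessage returns a callback that serializes every chat event
// (assistant messages, partial chunks, progress updates, generated
// commands) and pushes it to the connected client.
func newOnMessage(conn *websocket.Conn) func(kind string, payload []byte, createdAt time.Time) error {
    return func(kind string, payload []byte, createdAt time.Time) error {
        data, err := json.Marshal(assistEvent{
            Kind:      kind,
            Payload:   string(payload),
            CreatedAt: createdAt,
        })
        if err != nil {
            return err
        }
        return conn.WriteMessage(websocket.TextMessage, data)
    }
}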