-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Assist] Remove the empty assist message #28125
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,7 +34,6 @@ import ( | |
pluginsv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/plugins/v1" | ||
"github.com/gravitational/teleport/lib/ai" | ||
"github.com/gravitational/teleport/lib/ai/model" | ||
"github.com/gravitational/teleport/lib/auth" | ||
) | ||
|
||
// MessageType is a type of the Assist message. | ||
|
@@ -59,6 +58,18 @@ const ( | |
MessageKindError MessageType = "CHAT_MESSAGE_ERROR" | ||
) | ||
|
||
type PluginGetter interface { | ||
PluginsClient() pluginsv1.PluginServiceClient | ||
} | ||
|
||
type AssistantService interface { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I changed the name, but this is not the same thing. This package is called |
||
// GetAssistantMessages returns all messages with given conversation ID. | ||
GetAssistantMessages(ctx context.Context, req *assist.GetAssistantMessagesRequest) (*assist.GetAssistantMessagesResponse, error) | ||
|
||
// CreateAssistantMessage adds the message to the backend. | ||
CreateAssistantMessage(ctx context.Context, msg *assist.CreateAssistantMessageRequest) error | ||
} | ||
|
||
// Assist is the Teleport Assist client. | ||
type Assist struct { | ||
client *ai.Client | ||
|
@@ -67,7 +78,7 @@ type Assist struct { | |
} | ||
|
||
// NewAssist creates a new Assist client. | ||
func NewAssist(ctx context.Context, proxyClient auth.ClientI, | ||
func NewAssist(ctx context.Context, proxyClient PluginGetter, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While you're here, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "Touch a line and the owner of it you become" 😆 Changed |
||
proxySettings any, openaiCfg *openai.ClientConfig) (*Assist, error) { | ||
|
||
client, err := getAssistantClient(ctx, proxyClient, proxySettings, openaiCfg) | ||
|
@@ -85,24 +96,24 @@ func NewAssist(ctx context.Context, proxyClient auth.ClientI, | |
type Chat struct { | ||
assist *Assist | ||
chat *ai.Chat | ||
// authClient is the auth server client. | ||
authClient auth.ClientI | ||
// assistService is the auth server client. | ||
assistService AssistantService | ||
// ConversationID is the ID of the conversation. | ||
ConversationID string | ||
// Username is the username of the user who started the chat. | ||
Username string | ||
} | ||
|
||
// NewChat creates a new Assist chat. | ||
func (a *Assist) NewChat(ctx context.Context, authClient auth.ClientI, | ||
func (a *Assist) NewChat(ctx context.Context, assistService AssistantService, | ||
conversationID string, username string, | ||
) (*Chat, error) { | ||
aichat := a.client.NewChat(username) | ||
|
||
chat := &Chat{ | ||
assist: a, | ||
chat: aichat, | ||
authClient: authClient, | ||
assistService: assistService, | ||
ConversationID: conversationID, | ||
Username: username, | ||
} | ||
|
@@ -122,7 +133,7 @@ func (a *Assist) GenerateSummary(ctx context.Context, message string) (string, e | |
// loadMessages loads the messages from the database. | ||
func (c *Chat) loadMessages(ctx context.Context) error { | ||
// existing conversation, retrieve old messages | ||
messages, err := c.authClient.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{ | ||
messages, err := c.assistService.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{ | ||
ConversationId: c.ConversationID, | ||
Username: c.Username, | ||
}) | ||
|
@@ -148,7 +159,7 @@ func (c *Chat) IsNewConversation() bool { | |
|
||
// getAssistantClient returns the OpenAI client created base on Teleport Plugin information | ||
// or the static token configured in YAML. | ||
func getAssistantClient(ctx context.Context, proxyClient auth.ClientI, | ||
func getAssistantClient(ctx context.Context, proxyClient PluginGetter, | ||
proxySettings any, openaiCfg *openai.ClientConfig, | ||
) (*ai.Client, error) { | ||
apiKey, err := getOpenAITokenFromDefaultPlugin(ctx, proxyClient) | ||
|
@@ -181,9 +192,11 @@ func getAssistantClient(ctx context.Context, proxyClient auth.ClientI, | |
return ai.NewClient(apiKey), nil | ||
} | ||
|
||
// onMessageFunc is a function that is called when a message is received. | ||
type onMessageFunc func(kind MessageType, payload []byte, createdTime time.Time) error | ||
|
||
// ProcessComplete processes the completion request and returns the number of tokens used. | ||
func (c *Chat) ProcessComplete(ctx context.Context, | ||
onMessage func(kind MessageType, payload []byte, createdTime time.Time) error, userInput string, | ||
func (c *Chat) ProcessComplete(ctx context.Context, onMessage onMessageFunc, userInput string, | ||
) (*model.TokensUsed, error) { | ||
var tokensUsed *model.TokensUsed | ||
|
||
|
@@ -195,16 +208,20 @@ func (c *Chat) ProcessComplete(ctx context.Context, | |
|
||
// write the user message to persistent storage and the chat structure | ||
c.chat.Insert(openai.ChatMessageRoleUser, userInput) | ||
if err := c.authClient.CreateAssistantMessage(ctx, &assist.CreateAssistantMessageRequest{ | ||
Message: &assist.AssistantMessage{ | ||
Type: string(MessageKindUserMessage), | ||
Payload: userInput, // TODO(jakule): Sanitize the payload | ||
CreatedTime: timestamppb.New(c.assist.clock.Now().UTC()), | ||
}, | ||
ConversationId: c.ConversationID, | ||
Username: c.Username, | ||
}); err != nil { | ||
return nil, trace.Wrap(err) | ||
|
||
if userInput != "" { | ||
// Do not write empty messages to the database. | ||
jakule marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if err := c.assistService.CreateAssistantMessage(ctx, &assist.CreateAssistantMessageRequest{ | ||
Message: &assist.AssistantMessage{ | ||
Type: string(MessageKindUserMessage), | ||
Payload: userInput, // TODO(jakule): Sanitize the payload | ||
CreatedTime: timestamppb.New(c.assist.clock.Now().UTC()), | ||
}, | ||
ConversationId: c.ConversationID, | ||
Username: c.Username, | ||
}); err != nil { | ||
return nil, trace.Wrap(err) | ||
} | ||
} | ||
|
||
switch message := message.(type) { | ||
|
@@ -223,7 +240,7 @@ func (c *Chat) ProcessComplete(ctx context.Context, | |
}, | ||
} | ||
|
||
if err := c.authClient.CreateAssistantMessage(ctx, protoMsg); err != nil { | ||
if err := c.assistService.CreateAssistantMessage(ctx, protoMsg); err != nil { | ||
return nil, trace.Wrap(err) | ||
} | ||
|
||
|
@@ -253,7 +270,7 @@ func (c *Chat) ProcessComplete(ctx context.Context, | |
}, | ||
} | ||
|
||
if err := c.authClient.CreateAssistantMessage(ctx, msg); err != nil { | ||
if err := c.assistService.CreateAssistantMessage(ctx, msg); err != nil { | ||
return nil, trace.Wrap(err) | ||
} | ||
|
||
|
@@ -267,7 +284,7 @@ func (c *Chat) ProcessComplete(ctx context.Context, | |
return tokensUsed, nil | ||
} | ||
|
||
func getOpenAITokenFromDefaultPlugin(ctx context.Context, proxyClient auth.ClientI) (string, error) { | ||
func getOpenAITokenFromDefaultPlugin(ctx context.Context, proxyClient PluginGetter) (string, error) { | ||
// Try retrieving credentials from the plugin resource first | ||
openaiPlugin, err := proxyClient.PluginsClient().GetPlugin(ctx, &pluginsv1.GetPluginRequest{ | ||
Name: "openai-default", | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
/* | ||
* Copyright 2023 Gravitational, Inc. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package assist | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"net/http/httptest" | ||
"testing" | ||
"time" | ||
|
||
"github.com/jonboulle/clockwork" | ||
"github.com/sashabaranov/go-openai" | ||
"github.com/stretchr/testify/require" | ||
"google.golang.org/grpc" | ||
"google.golang.org/protobuf/types/known/timestamppb" | ||
|
||
"github.com/gravitational/teleport/api/gen/proto/go/assist/v1" | ||
pluginsv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/plugins/v1" | ||
"github.com/gravitational/teleport/api/types" | ||
aitest "github.com/gravitational/teleport/lib/ai/testutils" | ||
"github.com/gravitational/teleport/lib/auth" | ||
) | ||
|
||
func TestChatComplete(t *testing.T) { | ||
t.Parallel() | ||
|
||
// Given an OpenAI server that returns a response for a chat completion request. | ||
responses := []string{ | ||
generateCommandResponse(), | ||
} | ||
|
||
server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses)) | ||
t.Cleanup(server.Close) | ||
|
||
cfg := openai.DefaultConfig("secret-test-token") | ||
cfg.BaseURL = server.URL + "/v1" | ||
|
||
// And a chat client. | ||
ctx := context.Background() | ||
client, err := NewAssist(ctx, &mockPluginGetter{}, &apiKeyMock{}, &cfg) | ||
require.NoError(t, err) | ||
|
||
// And a test auth server. | ||
authSrv, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{ | ||
Dir: t.TempDir(), | ||
Clock: clockwork.NewFakeClock(), | ||
}) | ||
require.NoError(t, err) | ||
|
||
// And created conversation. | ||
const testUser = "bob" | ||
conversationResp, err := authSrv.AuthServer.CreateAssistantConversation(ctx, &assist.CreateAssistantConversationRequest{ | ||
Username: testUser, | ||
CreatedTime: timestamppb.Now(), | ||
}) | ||
require.NoError(t, err) | ||
|
||
// When a chat is created. | ||
chat, err := client.NewChat(ctx, authSrv.AuthServer, conversationResp.Id, testUser) | ||
require.NoError(t, err) | ||
|
||
t.Run("new conversation is new", func(t *testing.T) { | ||
// Then the chat is new. | ||
require.True(t, chat.IsNewConversation()) | ||
}) | ||
|
||
t.Run("new conversation is not complete", func(t *testing.T) { | ||
// The first message is the welcome message. | ||
_, err = chat.ProcessComplete(ctx, func(kind MessageType, payload []byte, createdTime time.Time) error { | ||
require.Equal(t, MessageKindAssistantMessage, kind) | ||
require.Contains(t, string(payload), "Hey, I'm Teleport") | ||
return nil | ||
}, "") | ||
require.NoError(t, err) | ||
}) | ||
|
||
t.Run("new conversation is not complete", func(t *testing.T) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The duplicate sub test names might get a bit confusing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changed |
||
// The second message is the command response. | ||
_, err = chat.ProcessComplete(ctx, func(kind MessageType, payload []byte, createdTime time.Time) error { | ||
require.Equal(t, MessageKindCommand, kind) | ||
require.Equal(t, string(payload), `{"command":"df -h","nodes":["localhost"]}`) | ||
return nil | ||
}, "Show free disk space on localhost") | ||
require.NoError(t, err) | ||
}) | ||
|
||
t.Run("check what messages are stored in the backend", func(t *testing.T) { | ||
// backend should have 3 messages: welcome message, user message, command response. | ||
messages, err := authSrv.AuthServer.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{ | ||
Username: testUser, | ||
ConversationId: conversationResp.Id, | ||
}) | ||
require.NoError(t, err) | ||
require.Len(t, messages.Messages, 3) | ||
|
||
require.Equal(t, string(MessageKindAssistantMessage), messages.Messages[0].Type) | ||
require.Equal(t, string(MessageKindUserMessage), messages.Messages[1].Type) | ||
require.Equal(t, string(MessageKindCommand), messages.Messages[2].Type) | ||
}) | ||
} | ||
|
||
type apiKeyMock struct{} | ||
|
||
// GetOpenAIAPIKey returns a mock API key. | ||
func (m *apiKeyMock) GetOpenAIAPIKey() string { | ||
return "123" | ||
} | ||
|
||
type mockPluginGetter struct{} | ||
|
||
func (m *mockPluginGetter) PluginsClient() pluginsv1.PluginServiceClient { | ||
return &mockPluginServiceClient{} | ||
} | ||
|
||
type mockPluginServiceClient struct { | ||
pluginsv1.PluginServiceClient | ||
} | ||
|
||
// GetPlugin always returns an error, so the assist fallbacks to the default config. | ||
func (m *mockPluginServiceClient) GetPlugin(_ context.Context, _ *pluginsv1.GetPluginRequest, _ ...grpc.CallOption) (*types.PluginV1, error) { | ||
return nil, errors.New("not implemented") | ||
} | ||
|
||
// generateCommandResponse generates a response for the command "df -h" on the node "localhost" | ||
func generateCommandResponse() string { | ||
return "```" + `json | ||
{ | ||
"action": "Command Execution", | ||
"action_input": "{\"command\":\"df -h\",\"nodes\":[\"localhost\"],\"labels\":[]}" | ||
} | ||
` + "```" | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: missing docs on exported types
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added