Skip to content

Commit

Permalink
[Assist] Remove the empty assist message (#28125)
Browse files Browse the repository at this point in the history
* [Assist] Remove the empty assist message

Assist shows an empty message at the beginning of each conversation when reading it from DB. This PR fixes that behavior and adds a test to prevent this from happening in the future.

* Address code review comments

* Address code review comments
  • Loading branch information
jakule committed Jun 22, 2023
1 parent 9cec75f commit 3fa6463
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 26 deletions.
67 changes: 43 additions & 24 deletions lib/assist/assist.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import (
pluginsv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/plugins/v1"
"github.com/gravitational/teleport/lib/ai"
"github.com/gravitational/teleport/lib/ai/model"
"github.com/gravitational/teleport/lib/auth"
)

// MessageType is a type of the Assist message.
Expand All @@ -59,15 +58,29 @@ const (
MessageKindError MessageType = "CHAT_MESSAGE_ERROR"
)

// PluginGetter is the minimal interface used by the chat to interact with the plugin service in the backend.
type PluginGetter interface {
PluginsClient() pluginsv1.PluginServiceClient
}

// MessageService is the minimal interface used by the chat to interact with the Assist message service in the backend.
type MessageService interface {
// GetAssistantMessages returns all messages with given conversation ID.
GetAssistantMessages(ctx context.Context, req *assist.GetAssistantMessagesRequest) (*assist.GetAssistantMessagesResponse, error)

// CreateAssistantMessage adds the message to the backend.
CreateAssistantMessage(ctx context.Context, msg *assist.CreateAssistantMessageRequest) error
}

// Assist is the Teleport Assist client.
type Assist struct {
client *ai.Client
// clock is a clock used to generate timestamps.
clock clockwork.Clock
}

// NewAssist creates a new Assist client.
func NewAssist(ctx context.Context, proxyClient auth.ClientI,
// NewClient creates a new Assist client.
func NewClient(ctx context.Context, proxyClient PluginGetter,
proxySettings any, openaiCfg *openai.ClientConfig) (*Assist, error) {

client, err := getAssistantClient(ctx, proxyClient, proxySettings, openaiCfg)
Expand All @@ -85,24 +98,24 @@ func NewAssist(ctx context.Context, proxyClient auth.ClientI,
type Chat struct {
assist *Assist
chat *ai.Chat
// authClient is the auth server client.
authClient auth.ClientI
// assistService is the auth server client.
assistService MessageService
// ConversationID is the ID of the conversation.
ConversationID string
// Username is the username of the user who started the chat.
Username string
}

// NewChat creates a new Assist chat.
func (a *Assist) NewChat(ctx context.Context, authClient auth.ClientI,
func (a *Assist) NewChat(ctx context.Context, assistService MessageService,
conversationID string, username string,
) (*Chat, error) {
aichat := a.client.NewChat(username)

chat := &Chat{
assist: a,
chat: aichat,
authClient: authClient,
assistService: assistService,
ConversationID: conversationID,
Username: username,
}
Expand All @@ -122,7 +135,7 @@ func (a *Assist) GenerateSummary(ctx context.Context, message string) (string, e
// loadMessages loads the messages from the database.
func (c *Chat) loadMessages(ctx context.Context) error {
// existing conversation, retrieve old messages
messages, err := c.authClient.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{
messages, err := c.assistService.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{
ConversationId: c.ConversationID,
Username: c.Username,
})
Expand All @@ -148,7 +161,7 @@ func (c *Chat) IsNewConversation() bool {

// getAssistantClient returns the OpenAI client created base on Teleport Plugin information
// or the static token configured in YAML.
func getAssistantClient(ctx context.Context, proxyClient auth.ClientI,
func getAssistantClient(ctx context.Context, proxyClient PluginGetter,
proxySettings any, openaiCfg *openai.ClientConfig,
) (*ai.Client, error) {
apiKey, err := getOpenAITokenFromDefaultPlugin(ctx, proxyClient)
Expand Down Expand Up @@ -181,9 +194,11 @@ func getAssistantClient(ctx context.Context, proxyClient auth.ClientI,
return ai.NewClient(apiKey), nil
}

// onMessageFunc is a function that is called when a message is received.
type onMessageFunc func(kind MessageType, payload []byte, createdTime time.Time) error

// ProcessComplete processes the completion request and returns the number of tokens used.
func (c *Chat) ProcessComplete(ctx context.Context,
onMessage func(kind MessageType, payload []byte, createdTime time.Time) error, userInput string,
func (c *Chat) ProcessComplete(ctx context.Context, onMessage onMessageFunc, userInput string,
) (*model.TokensUsed, error) {
var tokensUsed *model.TokensUsed

Expand All @@ -195,16 +210,20 @@ func (c *Chat) ProcessComplete(ctx context.Context,

// write the user message to persistent storage and the chat structure
c.chat.Insert(openai.ChatMessageRoleUser, userInput)
if err := c.authClient.CreateAssistantMessage(ctx, &assist.CreateAssistantMessageRequest{
Message: &assist.AssistantMessage{
Type: string(MessageKindUserMessage),
Payload: userInput, // TODO(jakule): Sanitize the payload
CreatedTime: timestamppb.New(c.assist.clock.Now().UTC()),
},
ConversationId: c.ConversationID,
Username: c.Username,
}); err != nil {
return nil, trace.Wrap(err)

// Do not write empty messages to the database.
if userInput != "" {
if err := c.assistService.CreateAssistantMessage(ctx, &assist.CreateAssistantMessageRequest{
Message: &assist.AssistantMessage{
Type: string(MessageKindUserMessage),
Payload: userInput, // TODO(jakule): Sanitize the payload
CreatedTime: timestamppb.New(c.assist.clock.Now().UTC()),
},
ConversationId: c.ConversationID,
Username: c.Username,
}); err != nil {
return nil, trace.Wrap(err)
}
}

switch message := message.(type) {
Expand All @@ -223,7 +242,7 @@ func (c *Chat) ProcessComplete(ctx context.Context,
},
}

if err := c.authClient.CreateAssistantMessage(ctx, protoMsg); err != nil {
if err := c.assistService.CreateAssistantMessage(ctx, protoMsg); err != nil {
return nil, trace.Wrap(err)
}

Expand Down Expand Up @@ -253,7 +272,7 @@ func (c *Chat) ProcessComplete(ctx context.Context,
},
}

if err := c.authClient.CreateAssistantMessage(ctx, msg); err != nil {
if err := c.assistService.CreateAssistantMessage(ctx, msg); err != nil {
return nil, trace.Wrap(err)
}

Expand All @@ -267,7 +286,7 @@ func (c *Chat) ProcessComplete(ctx context.Context,
return tokensUsed, nil
}

func getOpenAITokenFromDefaultPlugin(ctx context.Context, proxyClient auth.ClientI) (string, error) {
func getOpenAITokenFromDefaultPlugin(ctx context.Context, proxyClient PluginGetter) (string, error) {
// Try retrieving credentials from the plugin resource first
openaiPlugin, err := proxyClient.PluginsClient().GetPlugin(ctx, &pluginsv1.GetPluginRequest{
Name: "openai-default",
Expand Down
147 changes: 147 additions & 0 deletions lib/assist/assist_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*
* Copyright 2023 Gravitational, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package assist

import (
"context"
"errors"
"net/http/httptest"
"testing"
"time"

"github.com/jonboulle/clockwork"
"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/gravitational/teleport/api/gen/proto/go/assist/v1"
pluginsv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/plugins/v1"
"github.com/gravitational/teleport/api/types"
aitest "github.com/gravitational/teleport/lib/ai/testutils"
"github.com/gravitational/teleport/lib/auth"
)

func TestChatComplete(t *testing.T) {
t.Parallel()

// Given an OpenAI server that returns a response for a chat completion request.
responses := []string{
generateCommandResponse(),
}

server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses))
t.Cleanup(server.Close)

cfg := openai.DefaultConfig("secret-test-token")
cfg.BaseURL = server.URL + "/v1"

// And a chat client.
ctx := context.Background()
client, err := NewClient(ctx, &mockPluginGetter{}, &apiKeyMock{}, &cfg)
require.NoError(t, err)

// And a test auth server.
authSrv, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{
Dir: t.TempDir(),
Clock: clockwork.NewFakeClock(),
})
require.NoError(t, err)

// And created conversation.
const testUser = "bob"
conversationResp, err := authSrv.AuthServer.CreateAssistantConversation(ctx, &assist.CreateAssistantConversationRequest{
Username: testUser,
CreatedTime: timestamppb.Now(),
})
require.NoError(t, err)

// When a chat is created.
chat, err := client.NewChat(ctx, authSrv.AuthServer, conversationResp.Id, testUser)
require.NoError(t, err)

t.Run("new conversation is new", func(t *testing.T) {
// Then the chat is new.
require.True(t, chat.IsNewConversation())
})

t.Run("the first message is the hey message", func(t *testing.T) {
// The first message is the welcome message.
_, err = chat.ProcessComplete(ctx, func(kind MessageType, payload []byte, createdTime time.Time) error {
require.Equal(t, MessageKindAssistantMessage, kind)
require.Contains(t, string(payload), "Hey, I'm Teleport")
return nil
}, "")
require.NoError(t, err)
})

t.Run("command should be returned in the response", func(t *testing.T) {
// The second message is the command response.
_, err = chat.ProcessComplete(ctx, func(kind MessageType, payload []byte, createdTime time.Time) error {
require.Equal(t, MessageKindCommand, kind)
require.Equal(t, string(payload), `{"command":"df -h","nodes":["localhost"]}`)
return nil
}, "Show free disk space on localhost")
require.NoError(t, err)
})

t.Run("check what messages are stored in the backend", func(t *testing.T) {
// backend should have 3 messages: welcome message, user message, command response.
messages, err := authSrv.AuthServer.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{
Username: testUser,
ConversationId: conversationResp.Id,
})
require.NoError(t, err)
require.Len(t, messages.Messages, 3)

require.Equal(t, string(MessageKindAssistantMessage), messages.Messages[0].Type)
require.Equal(t, string(MessageKindUserMessage), messages.Messages[1].Type)
require.Equal(t, string(MessageKindCommand), messages.Messages[2].Type)
})
}

type apiKeyMock struct{}

// GetOpenAIAPIKey returns a mock API key.
func (m *apiKeyMock) GetOpenAIAPIKey() string {
return "123"
}

type mockPluginGetter struct{}

func (m *mockPluginGetter) PluginsClient() pluginsv1.PluginServiceClient {
return &mockPluginServiceClient{}
}

type mockPluginServiceClient struct {
pluginsv1.PluginServiceClient
}

// GetPlugin always returns an error, so the assist fallbacks to the default config.
func (m *mockPluginServiceClient) GetPlugin(_ context.Context, _ *pluginsv1.GetPluginRequest, _ ...grpc.CallOption) (*types.PluginV1, error) {
return nil, errors.New("not implemented")
}

// generateCommandResponse generates a response for the command "df -h" on the node "localhost"
func generateCommandResponse() string {
return "```" + `json
{
"action": "Command Execution",
"action_input": "{\"command\":\"df -h\",\"nodes\":[\"localhost\"],\"labels\":[]}"
}
` + "```"
}
4 changes: 2 additions & 2 deletions lib/web/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func (h *Handler) generateAssistantTitle(_ http.ResponseWriter, r *http.Request,
return nil, trace.Wrap(err)
}

client, err := assist.NewAssist(r.Context(), h.cfg.ProxyClient,
client, err := assist.NewClient(r.Context(), h.cfg.ProxyClient,
h.cfg.ProxySettings, h.cfg.OpenAIConfig)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -410,7 +410,7 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request,

go startPingLoop(ctx, ws, keepAliveInterval, h.log, nil)

assistClient, err := assist.NewAssist(ctx, h.cfg.ProxyClient,
assistClient, err := assist.NewClient(ctx, h.cfg.ProxyClient,
h.cfg.ProxySettings, h.cfg.OpenAIConfig)
if err != nil {
return trace.Wrap(err)
Expand Down

0 comments on commit 3fa6463

Please sign in to comment.