Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Assist] Remove the empty assist message #28125

Merged
merged 3 commits into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
63 changes: 40 additions & 23 deletions lib/assist/assist.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import (
pluginsv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/plugins/v1"
"github.com/gravitational/teleport/lib/ai"
"github.com/gravitational/teleport/lib/ai/model"
"github.com/gravitational/teleport/lib/auth"
)

// MessageType is a type of the Assist message.
Expand All @@ -59,6 +58,18 @@ const (
MessageKindError MessageType = "CHAT_MESSAGE_ERROR"
)

type PluginGetter interface {
PluginsClient() pluginsv1.PluginServiceClient
}

type AssistantService interface {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: missing docs on exported types

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: assist.AssistantService also stutters

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the name, but this is not the same thing. This package is called assist, but the AssistantService doesn't refer to this package, but the Assist service that lives in the backend. I renamed it to MessageService as all receivers are related to messages.

// GetAssistantMessages returns all messages with given conversation ID.
GetAssistantMessages(ctx context.Context, req *assist.GetAssistantMessagesRequest) (*assist.GetAssistantMessagesResponse, error)

// CreateAssistantMessage adds the message to the backend.
CreateAssistantMessage(ctx context.Context, msg *assist.CreateAssistantMessageRequest) error
}

// Assist is the Teleport Assist client.
type Assist struct {
client *ai.Client
Expand All @@ -67,7 +78,7 @@ type Assist struct {
}

// NewAssist creates a new Assist client.
func NewAssist(ctx context.Context, proxyClient auth.ClientI,
func NewAssist(ctx context.Context, proxyClient PluginGetter,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While you're here, assist.NewAssist() stutters.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Touch a line and the owner of it you become" 😆 Changed

proxySettings any, openaiCfg *openai.ClientConfig) (*Assist, error) {

client, err := getAssistantClient(ctx, proxyClient, proxySettings, openaiCfg)
Expand All @@ -85,24 +96,24 @@ func NewAssist(ctx context.Context, proxyClient auth.ClientI,
type Chat struct {
assist *Assist
chat *ai.Chat
// authClient is the auth server client.
authClient auth.ClientI
// assistService is the auth server client.
assistService AssistantService
// ConversationID is the ID of the conversation.
ConversationID string
// Username is the username of the user who started the chat.
Username string
}

// NewChat creates a new Assist chat.
func (a *Assist) NewChat(ctx context.Context, authClient auth.ClientI,
func (a *Assist) NewChat(ctx context.Context, assistService AssistantService,
conversationID string, username string,
) (*Chat, error) {
aichat := a.client.NewChat(username)

chat := &Chat{
assist: a,
chat: aichat,
authClient: authClient,
assistService: assistService,
ConversationID: conversationID,
Username: username,
}
Expand All @@ -122,7 +133,7 @@ func (a *Assist) GenerateSummary(ctx context.Context, message string) (string, e
// loadMessages loads the messages from the database.
func (c *Chat) loadMessages(ctx context.Context) error {
// existing conversation, retrieve old messages
messages, err := c.authClient.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{
messages, err := c.assistService.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{
ConversationId: c.ConversationID,
Username: c.Username,
})
Expand All @@ -148,7 +159,7 @@ func (c *Chat) IsNewConversation() bool {

// getAssistantClient returns the OpenAI client created base on Teleport Plugin information
// or the static token configured in YAML.
func getAssistantClient(ctx context.Context, proxyClient auth.ClientI,
func getAssistantClient(ctx context.Context, proxyClient PluginGetter,
proxySettings any, openaiCfg *openai.ClientConfig,
) (*ai.Client, error) {
apiKey, err := getOpenAITokenFromDefaultPlugin(ctx, proxyClient)
Expand Down Expand Up @@ -181,9 +192,11 @@ func getAssistantClient(ctx context.Context, proxyClient auth.ClientI,
return ai.NewClient(apiKey), nil
}

// onMessageFunc is a function that is called when a message is received.
type onMessageFunc func(kind MessageType, payload []byte, createdTime time.Time) error

// ProcessComplete processes the completion request and returns the number of tokens used.
func (c *Chat) ProcessComplete(ctx context.Context,
onMessage func(kind MessageType, payload []byte, createdTime time.Time) error, userInput string,
func (c *Chat) ProcessComplete(ctx context.Context, onMessage onMessageFunc, userInput string,
) (*model.TokensUsed, error) {
var tokensUsed *model.TokensUsed

Expand All @@ -195,16 +208,20 @@ func (c *Chat) ProcessComplete(ctx context.Context,

// write the user message to persistent storage and the chat structure
c.chat.Insert(openai.ChatMessageRoleUser, userInput)
if err := c.authClient.CreateAssistantMessage(ctx, &assist.CreateAssistantMessageRequest{
Message: &assist.AssistantMessage{
Type: string(MessageKindUserMessage),
Payload: userInput, // TODO(jakule): Sanitize the payload
CreatedTime: timestamppb.New(c.assist.clock.Now().UTC()),
},
ConversationId: c.ConversationID,
Username: c.Username,
}); err != nil {
return nil, trace.Wrap(err)

if userInput != "" {
// Do not write empty messages to the database.
jakule marked this conversation as resolved.
Show resolved Hide resolved
if err := c.assistService.CreateAssistantMessage(ctx, &assist.CreateAssistantMessageRequest{
Message: &assist.AssistantMessage{
Type: string(MessageKindUserMessage),
Payload: userInput, // TODO(jakule): Sanitize the payload
CreatedTime: timestamppb.New(c.assist.clock.Now().UTC()),
},
ConversationId: c.ConversationID,
Username: c.Username,
}); err != nil {
return nil, trace.Wrap(err)
}
}

switch message := message.(type) {
Expand All @@ -223,7 +240,7 @@ func (c *Chat) ProcessComplete(ctx context.Context,
},
}

if err := c.authClient.CreateAssistantMessage(ctx, protoMsg); err != nil {
if err := c.assistService.CreateAssistantMessage(ctx, protoMsg); err != nil {
return nil, trace.Wrap(err)
}

Expand Down Expand Up @@ -253,7 +270,7 @@ func (c *Chat) ProcessComplete(ctx context.Context,
},
}

if err := c.authClient.CreateAssistantMessage(ctx, msg); err != nil {
if err := c.assistService.CreateAssistantMessage(ctx, msg); err != nil {
return nil, trace.Wrap(err)
}

Expand All @@ -267,7 +284,7 @@ func (c *Chat) ProcessComplete(ctx context.Context,
return tokensUsed, nil
}

func getOpenAITokenFromDefaultPlugin(ctx context.Context, proxyClient auth.ClientI) (string, error) {
func getOpenAITokenFromDefaultPlugin(ctx context.Context, proxyClient PluginGetter) (string, error) {
// Try retrieving credentials from the plugin resource first
openaiPlugin, err := proxyClient.PluginsClient().GetPlugin(ctx, &pluginsv1.GetPluginRequest{
Name: "openai-default",
Expand Down
147 changes: 147 additions & 0 deletions lib/assist/assist_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*
* Copyright 2023 Gravitational, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package assist

import (
"context"
"errors"
"net/http/httptest"
"testing"
"time"

"github.com/jonboulle/clockwork"
"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/gravitational/teleport/api/gen/proto/go/assist/v1"
pluginsv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/plugins/v1"
"github.com/gravitational/teleport/api/types"
aitest "github.com/gravitational/teleport/lib/ai/testutils"
"github.com/gravitational/teleport/lib/auth"
)

func TestChatComplete(t *testing.T) {
t.Parallel()

// Given an OpenAI server that returns a response for a chat completion request.
responses := []string{
generateCommandResponse(),
}

server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses))
t.Cleanup(server.Close)

cfg := openai.DefaultConfig("secret-test-token")
cfg.BaseURL = server.URL + "/v1"

// And a chat client.
ctx := context.Background()
client, err := NewAssist(ctx, &mockPluginGetter{}, &apiKeyMock{}, &cfg)
require.NoError(t, err)

// And a test auth server.
authSrv, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{
Dir: t.TempDir(),
Clock: clockwork.NewFakeClock(),
})
require.NoError(t, err)

// And created conversation.
const testUser = "bob"
conversationResp, err := authSrv.AuthServer.CreateAssistantConversation(ctx, &assist.CreateAssistantConversationRequest{
Username: testUser,
CreatedTime: timestamppb.Now(),
})
require.NoError(t, err)

// When a chat is created.
chat, err := client.NewChat(ctx, authSrv.AuthServer, conversationResp.Id, testUser)
require.NoError(t, err)

t.Run("new conversation is new", func(t *testing.T) {
// Then the chat is new.
require.True(t, chat.IsNewConversation())
})

t.Run("new conversation is not complete", func(t *testing.T) {
// The first message is the welcome message.
_, err = chat.ProcessComplete(ctx, func(kind MessageType, payload []byte, createdTime time.Time) error {
require.Equal(t, MessageKindAssistantMessage, kind)
require.Contains(t, string(payload), "Hey, I'm Teleport")
return nil
}, "")
require.NoError(t, err)
})

t.Run("new conversation is not complete", func(t *testing.T) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The duplicate sub test names might get a bit confusing

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed

// The second message is the command response.
_, err = chat.ProcessComplete(ctx, func(kind MessageType, payload []byte, createdTime time.Time) error {
require.Equal(t, MessageKindCommand, kind)
require.Equal(t, string(payload), `{"command":"df -h","nodes":["localhost"]}`)
return nil
}, "Show free disk space on localhost")
require.NoError(t, err)
})

t.Run("check what messages are stored in the backend", func(t *testing.T) {
// backend should have 3 messages: welcome message, user message, command response.
messages, err := authSrv.AuthServer.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{
Username: testUser,
ConversationId: conversationResp.Id,
})
require.NoError(t, err)
require.Len(t, messages.Messages, 3)

require.Equal(t, string(MessageKindAssistantMessage), messages.Messages[0].Type)
require.Equal(t, string(MessageKindUserMessage), messages.Messages[1].Type)
require.Equal(t, string(MessageKindCommand), messages.Messages[2].Type)
})
}

type apiKeyMock struct{}

// GetOpenAIAPIKey returns a mock API key.
func (m *apiKeyMock) GetOpenAIAPIKey() string {
return "123"
}

type mockPluginGetter struct{}

func (m *mockPluginGetter) PluginsClient() pluginsv1.PluginServiceClient {
return &mockPluginServiceClient{}
}

type mockPluginServiceClient struct {
pluginsv1.PluginServiceClient
}

// GetPlugin always returns an error, so the assist fallbacks to the default config.
func (m *mockPluginServiceClient) GetPlugin(_ context.Context, _ *pluginsv1.GetPluginRequest, _ ...grpc.CallOption) (*types.PluginV1, error) {
return nil, errors.New("not implemented")
}

// generateCommandResponse generates a response for the command "df -h" on the node "localhost"
func generateCommandResponse() string {
return "```" + `json
{
"action": "Command Execution",
"action_input": "{\"command\":\"df -h\",\"nodes\":[\"localhost\"],\"labels\":[]}"
}
` + "```"
}