From 0d58c99c8b33d3f4af0b9647694cb06604963788 Mon Sep 17 00:00:00 2001 From: Jakub Nyckowski Date: Wed, 7 Jun 2023 18:07:15 -0400 Subject: [PATCH 1/4] Return errors over Assist WS (#27174) * Return errors over Assist WS * Add test * Address code review comments Add check for the close message in assist WS * Use UTC for getting time --- lib/ai/chat_test.go | 6 ++- lib/web/assistant.go | 27 +++++++++++- lib/web/assistant_test.go | 92 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 119 insertions(+), 6 deletions(-) diff --git a/lib/ai/chat_test.go b/lib/ai/chat_test.go index 19aacd5be2368..bb9aaec20dd4f 100644 --- a/lib/ai/chat_test.go +++ b/lib/ai/chat_test.go @@ -23,6 +23,7 @@ import ( "testing" "github.com/sashabaranov/go-openai" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tiktoken-go/tokenizer/codec" ) @@ -107,11 +108,12 @@ func TestChat_Complete(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") - require.GreaterOrEqual(t, len(responses), 1, "Unexpected request") + // Use assert as require doesn't work when called from a goroutine + assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request") dataBytes := responses[0] _, err := w.Write(dataBytes) - require.NoError(t, err, "Write error") + assert.NoError(t, err, "Write error") responses = responses[1:] })) diff --git a/lib/web/assistant.go b/lib/web/assistant.go index 4f9edfe298aab..d7f41d9ca53b2 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -312,7 +312,7 @@ func checkAssistEnabled(a auth.ClientI, ctx context.Context) error { // runAssistant upgrades the HTTP connection to a websocket and starts a chat loop. func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, sctx *SessionContext, site reversetunnel.RemoteSite, -) error { +) (err error) { q := r.URL.Query() conversationID := q.Get("conversation_id") if conversationID == "" { @@ -371,7 +371,30 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, h.log.WithError(err).Error("Error setting websocket readline") return nil } - defer ws.Close() + defer func() { + closureReason := websocket.CloseNormalClosure + closureMsg := "" + if err != nil { + _ = ws.WriteJSON(&assistantMessage{ + Type: assist.MessageKindError, + Payload: err.Error(), + CreatedTime: h.clock.Now().UTC().Format(time.RFC3339), + }) + // Set server error code and message: https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1 + closureReason = websocket.CloseInternalServerErr + closureMsg = err.Error() + } + // Send the close message to the client and close the connection + if err := ws.WriteControl(websocket.CloseMessage, + websocket.FormatCloseMessage(closureReason, closureMsg), + time.Now().Add(time.Second), + ); err != nil { + h.log.Warnf("Failed to write close message: %v", err) + } + if err := ws.Close(); err != nil { + h.log.Warnf("Failed to close websocket: %v", err) + } + }() // Update the read deadline upon receiving a pong message. ws.SetPongHandler(func(_ string) error { diff --git a/lib/web/assistant_test.go b/lib/web/assistant_test.go index 87969a39f8621..292517c7234e1 100644 --- a/lib/web/assistant_test.go +++ b/lib/web/assistant_test.go @@ -31,6 +31,7 @@ import ( "github.com/gravitational/roundtrip" "github.com/gravitational/trace" "github.com/sashabaranov/go-openai" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/time/rate" @@ -149,11 +150,12 @@ func Test_runAssistant(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") - require.GreaterOrEqual(t, len(responses), 1, "Unexpected request") + // Use assert as require doesn't work when called from a goroutine + assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request") dataBytes := responses[0] _, err := w.Write(dataBytes) - require.NoError(t, err, "Write error") + assert.NoError(t, err, "Write error") responses = responses[1:] })) @@ -194,6 +196,92 @@ func Test_runAssistant(t *testing.T) { } } +// Test_runAssistError tests that the assistant returns an error message +// when the OpenAI API returns an error. +func Test_runAssistError(t *testing.T) { + t.Parallel() + + readHelloMsg := func(ws *websocket.Conn) { + _, payload, err := ws.ReadMessage() + require.NoError(t, err) + + var msg assistantMessage + err = json.Unmarshal(payload, &msg) + require.NoError(t, err) + + // Expect "hello" message + require.Equal(t, assist.MessageKindAssistantMessage, msg.Type) + require.Contains(t, msg.Payload, "Hey, I'm Teleport") + } + + readErrorMsg := func(ws *websocket.Conn) { + err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`)) + require.NoError(t, err) + + _, payload, err := ws.ReadMessage() + require.NoError(t, err) + + var msg assistantMessage + err = json.Unmarshal(payload, &msg) + require.NoError(t, err) + + // Expect OpenAI error message + require.Equal(t, assist.MessageKindError, msg.Type) + require.Contains(t, msg.Payload, "You are sending requests too quickly") + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + // Simulate rate limit error + w.WriteHeader(429) + + errMsg := openai.ErrorResponse{ + Error: &openai.APIError{ + Code: "rate_limit_reached", + Message: "You are sending requests too quickly.", + Param: nil, + Type: "rate_limit_reached", + HTTPStatusCode: 429, + }, + } + + dataBytes, err := json.Marshal(errMsg) + // Use assert as require doesn't work when called from a goroutine + assert.NoError(t, err, "Marshal error") + + _, err = w.Write(dataBytes) + assert.NoError(t, err, "Write error") + })) + t.Cleanup(server.Close) + + openaiCfg := openai.DefaultConfig("test-token") + openaiCfg.BaseURL = server.URL + s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg}) + + ctx := context.Background() + authPack := s.authPack(t, "foo") + // Create the conversation + conversationID := s.makeAssistConversation(t, ctx, authPack) + + // Make WS client and start the conversation + ws, err := s.makeAssistant(t, authPack, conversationID) + require.NoError(t, err) + t.Cleanup(func() { + // Close should yield an error as the server closes the connection + require.Error(t, ws.Close()) + }) + + // verify responses + readHelloMsg(ws) + readErrorMsg(ws) + + // Check for close message + _, _, err = ws.ReadMessage() + closeErr, ok := err.(*websocket.CloseError) + require.True(t, ok, "Expected close error") + require.Equal(t, websocket.CloseInternalServerErr, closeErr.Code, "Expected abnormal closure") +} + // makeAssistConversation creates a new assist conversation and returns its ID func (s *WebSuite) makeAssistConversation(t *testing.T, ctx context.Context, authPack *authPack) string { clt := authPack.clt From 09cf91a4c5b4f1a649308d3ca9cd86d7b7391ddf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Justinas=20Stankevi=C4=8Dius?= Date: Tue, 13 Jun 2023 23:14:32 +0300 Subject: [PATCH 2/4] Gracefully handle errors in Assist frontend (#27669) --- lib/web/assistant.go | 3 +- .../src/Assist/Conversation/Message.tsx | 1 + .../src/Assist/context/AssistContext.tsx | 38 ++++++++++++++++++- .../teleport/src/Assist/context/state.ts | 5 ++- .../teleport/src/Assist/context/utils.ts | 1 + web/packages/teleport/src/Assist/types.ts | 8 ++++ 6 files changed, 52 insertions(+), 4 deletions(-) diff --git a/lib/web/assistant.go b/lib/web/assistant.go index d7f41d9ca53b2..7e76a81bfd521 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -375,9 +375,10 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, closureReason := websocket.CloseNormalClosure closureMsg := "" if err != nil { + h.log.WithError(err).Error("Error in the Assistant loop") _ = ws.WriteJSON(&assistantMessage{ Type: assist.MessageKindError, - Payload: err.Error(), + Payload: "An error has occurred. Please try again later.", CreatedTime: h.clock.Now().UTC().Format(time.RFC3339), }) // Set server error code and message: https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1 diff --git a/web/packages/teleport/src/Assist/Conversation/Message.tsx b/web/packages/teleport/src/Assist/Conversation/Message.tsx index 31f60bef5714e..4cb5f2fad5e18 100644 --- a/web/packages/teleport/src/Assist/Conversation/Message.tsx +++ b/web/packages/teleport/src/Assist/Conversation/Message.tsx @@ -107,6 +107,7 @@ function createComponentForEntry( switch (entry.type) { case ServerMessageType.Assist: case ServerMessageType.User: + case ServerMessageType.Error: return ; case ServerMessageType.Command: diff --git a/web/packages/teleport/src/Assist/context/AssistContext.tsx b/web/packages/teleport/src/Assist/context/AssistContext.tsx index b69c702c758c7..65719bef7e85e 100644 --- a/web/packages/teleport/src/Assist/context/AssistContext.tsx +++ b/web/packages/teleport/src/Assist/context/AssistContext.tsx @@ -105,7 +105,7 @@ export function AssistContextProvider(props: PropsWithChildren) { }); } - function setupWebSocket(conversationId: string) { + function setupWebSocket(conversationId: string, initialMessage?: string) { activeWebSocket.current = new WebSocket( cfg.getAssistConversationWebSocketUrl( getHostName(), @@ -123,6 +123,16 @@ export function AssistContextProvider(props: PropsWithChildren) { TEN_MINUTES * 0.8 ); + activeWebSocket.current.onopen = () => { + if (initialMessage) { + activeWebSocket.current.send(initialMessage); + } + }; + + activeWebSocket.current.onclose = () => { + dispatch({ type: AssistStateActionType.SetStreaming, streaming: false }); + }; + activeWebSocket.current.onmessage = async event => { const data = JSON.parse(event.data) as ServerMessage; @@ -178,6 +188,21 @@ export function AssistContextProvider(props: PropsWithChildren) { break; } + + case ServerMessageType.Error: + dispatch({ + type: AssistStateActionType.AddMessage, + messageType: ServerMessageType.Error, + message: data.payload, + conversationId, + }); + + dispatch({ + type: AssistStateActionType.SetStreaming, + streaming: false, + }); + + break; } }; } @@ -273,7 +298,16 @@ export function AssistContextProvider(props: PropsWithChildren) { dispatch({ type: AssistStateActionType.SetStreaming, streaming: true }); - activeWebSocket.current.send(JSON.stringify({ payload: message })); + const data = JSON.stringify({ payload: message }); + + if ( + !activeWebSocket.current || + activeWebSocket.current.readyState === WebSocket.CLOSED + ) { + setupWebSocket(state.conversations.selectedId, data); + } else { + activeWebSocket.current.send(data); + } dispatch({ type: AssistStateActionType.AddMessage, diff --git a/web/packages/teleport/src/Assist/context/state.ts b/web/packages/teleport/src/Assist/context/state.ts index c72222db77878..41dc5312ff367 100644 --- a/web/packages/teleport/src/Assist/context/state.ts +++ b/web/packages/teleport/src/Assist/context/state.ts @@ -90,7 +90,10 @@ export interface SetConversationMessagesAction { export interface AddMessageAction { type: AssistStateActionType.AddMessage; - messageType: ServerMessageType.User | ServerMessageType.Assist; + messageType: + | ServerMessageType.User + | ServerMessageType.Assist + | ServerMessageType.Error; message: string; conversationId: string; } diff --git a/web/packages/teleport/src/Assist/context/utils.ts b/web/packages/teleport/src/Assist/context/utils.ts index 08f41ebdd4fd4..e7473f7003ea7 100644 --- a/web/packages/teleport/src/Assist/context/utils.ts +++ b/web/packages/teleport/src/Assist/context/utils.ts @@ -32,6 +32,7 @@ function getMessageTypeAuthor(type: string) { case ServerMessageType.Command: case ServerMessageType.CommandResult: case ServerMessageType.CommandResultStream: + case ServerMessageType.Error: return Author.Teleport; } } diff --git a/web/packages/teleport/src/Assist/types.ts b/web/packages/teleport/src/Assist/types.ts index d9b2579c116b4..ba73068e0c80e 100644 --- a/web/packages/teleport/src/Assist/types.ts +++ b/web/packages/teleport/src/Assist/types.ts @@ -18,6 +18,7 @@ import { EventType } from 'teleport/lib/term/enums'; export enum ServerMessageType { Assist = 'CHAT_MESSAGE_ASSISTANT', User = 'CHAT_MESSAGE_USER', + Error = 'CHAT_MESSAGE_ERROR', Command = 'COMMAND', CommandResult = 'COMMAND_RESULT', CommandResultStream = 'COMMAND_RESULT_STREAM', @@ -85,6 +86,12 @@ export interface ResolvedUserServerMessage { created: Date; } +export interface ResolvedErrorServerMessage { + type: ServerMessageType.Error; + message: string; + created: Date; +} + export interface ResolvedCommandResultStreamServerMessage { type: ServerMessageType.CommandResultStream; id: number; @@ -99,6 +106,7 @@ export type ResolvedServerMessage = | ResolvedCommandServerMessage | ResolvedAssistServerMessage | ResolvedUserServerMessage + | ResolvedErrorServerMessage | ResolvedCommandResultServerMessage | ResolvedAssistThoughtServerMessage | ResolvedCommandResultStreamServerMessage; From e8b83736016ecfa89283ee3b57b33f376c85e5a5 Mon Sep 17 00:00:00 2001 From: Justinas Stankevicius Date: Fri, 16 Jun 2023 16:26:40 +0300 Subject: [PATCH 3/4] Bump e to include teleport.e#1634 --- e | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/e b/e index 181e48f84123c..f0d46edbf462b 160000 --- a/e +++ b/e @@ -1 +1 @@ -Subproject commit 181e48f84123c43f92f31edc5d53fce9c0f8a12b +Subproject commit f0d46edbf462be9ea2f113f5792f9abb97a206ef From 6c4518c41c347fc19446f0de5b945dbe27f96a91 Mon Sep 17 00:00:00 2001 From: Justinas Stankevicius Date: Fri, 16 Jun 2023 16:48:56 +0300 Subject: [PATCH 4/4] Fix error assertions in the test This test is removed altogether in #27075, but that has not been backported yet. --- lib/web/assistant_test.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/lib/web/assistant_test.go b/lib/web/assistant_test.go index 292517c7234e1..add4852eb95aa 100644 --- a/lib/web/assistant_test.go +++ b/lib/web/assistant_test.go @@ -225,9 +225,9 @@ func Test_runAssistError(t *testing.T) { err = json.Unmarshal(payload, &msg) require.NoError(t, err) - // Expect OpenAI error message + // Expect a generic error message require.Equal(t, assist.MessageKindError, msg.Type) - require.Contains(t, msg.Payload, "You are sending requests too quickly") + require.Contains(t, msg.Payload, "An error has occurred.") } server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -267,8 +267,7 @@ func Test_runAssistError(t *testing.T) { ws, err := s.makeAssistant(t, authPack, conversationID) require.NoError(t, err) t.Cleanup(func() { - // Close should yield an error as the server closes the connection - require.Error(t, ws.Close()) + ws.Close() }) // verify responses