diff --git a/e b/e
index 181e48f84123c..f0d46edbf462b 160000
--- a/e
+++ b/e
@@ -1 +1 @@
-Subproject commit 181e48f84123c43f92f31edc5d53fce9c0f8a12b
+Subproject commit f0d46edbf462be9ea2f113f5792f9abb97a206ef
diff --git a/lib/ai/chat_test.go b/lib/ai/chat_test.go
index 19aacd5be2368..bb9aaec20dd4f 100644
--- a/lib/ai/chat_test.go
+++ b/lib/ai/chat_test.go
@@ -23,6 +23,7 @@ import (
"testing"
"github.com/sashabaranov/go-openai"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tiktoken-go/tokenizer/codec"
)
@@ -107,11 +108,12 @@ func TestChat_Complete(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
- require.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
+ // Use assert as require doesn't work when called from a goroutine
+ assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
dataBytes := responses[0]
_, err := w.Write(dataBytes)
- require.NoError(t, err, "Write error")
+ assert.NoError(t, err, "Write error")
responses = responses[1:]
}))
diff --git a/lib/web/assistant.go b/lib/web/assistant.go
index 4f9edfe298aab..7e76a81bfd521 100644
--- a/lib/web/assistant.go
+++ b/lib/web/assistant.go
@@ -312,7 +312,7 @@ func checkAssistEnabled(a auth.ClientI, ctx context.Context) error {
// runAssistant upgrades the HTTP connection to a websocket and starts a chat loop.
func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request,
sctx *SessionContext, site reversetunnel.RemoteSite,
-) error {
+) (err error) {
q := r.URL.Query()
conversationID := q.Get("conversation_id")
if conversationID == "" {
@@ -371,7 +371,31 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request,
h.log.WithError(err).Error("Error setting websocket readline")
return nil
}
- defer ws.Close()
+ defer func() {
+ closureReason := websocket.CloseNormalClosure
+ closureMsg := ""
+ if err != nil {
+ h.log.WithError(err).Error("Error in the Assistant loop")
+ _ = ws.WriteJSON(&assistantMessage{
+ Type: assist.MessageKindError,
+ Payload: "An error has occurred. Please try again later.",
+ CreatedTime: h.clock.Now().UTC().Format(time.RFC3339),
+ })
+ // Set server error code and message: https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1
+ closureReason = websocket.CloseInternalServerErr
+ closureMsg = err.Error()
+ }
+ // Send the close message to the client and close the connection
+ if err := ws.WriteControl(websocket.CloseMessage,
+ websocket.FormatCloseMessage(closureReason, closureMsg),
+ time.Now().Add(time.Second),
+ ); err != nil {
+ h.log.Warnf("Failed to write close message: %v", err)
+ }
+ if err := ws.Close(); err != nil {
+ h.log.Warnf("Failed to close websocket: %v", err)
+ }
+ }()
// Update the read deadline upon receiving a pong message.
ws.SetPongHandler(func(_ string) error {
diff --git a/lib/web/assistant_test.go b/lib/web/assistant_test.go
index 87969a39f8621..add4852eb95aa 100644
--- a/lib/web/assistant_test.go
+++ b/lib/web/assistant_test.go
@@ -31,6 +31,7 @@ import (
"github.com/gravitational/roundtrip"
"github.com/gravitational/trace"
"github.com/sashabaranov/go-openai"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
@@ -149,11 +150,12 @@ func Test_runAssistant(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
- require.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
+ // Use assert as require doesn't work when called from a goroutine
+ assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
dataBytes := responses[0]
_, err := w.Write(dataBytes)
- require.NoError(t, err, "Write error")
+ assert.NoError(t, err, "Write error")
responses = responses[1:]
}))
@@ -194,6 +196,91 @@ func Test_runAssistant(t *testing.T) {
}
}
+// Test_runAssistError tests that the assistant returns an error message
+// when the OpenAI API returns an error.
+func Test_runAssistError(t *testing.T) {
+ t.Parallel()
+
+ readHelloMsg := func(ws *websocket.Conn) {
+ _, payload, err := ws.ReadMessage()
+ require.NoError(t, err)
+
+ var msg assistantMessage
+ err = json.Unmarshal(payload, &msg)
+ require.NoError(t, err)
+
+ // Expect "hello" message
+ require.Equal(t, assist.MessageKindAssistantMessage, msg.Type)
+ require.Contains(t, msg.Payload, "Hey, I'm Teleport")
+ }
+
+ readErrorMsg := func(ws *websocket.Conn) {
+ err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`))
+ require.NoError(t, err)
+
+ _, payload, err := ws.ReadMessage()
+ require.NoError(t, err)
+
+ var msg assistantMessage
+ err = json.Unmarshal(payload, &msg)
+ require.NoError(t, err)
+
+ // Expect a generic error message
+ require.Equal(t, assist.MessageKindError, msg.Type)
+ require.Contains(t, msg.Payload, "An error has occurred.")
+ }
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ // Simulate rate limit error
+ w.WriteHeader(429)
+
+ errMsg := openai.ErrorResponse{
+ Error: &openai.APIError{
+ Code: "rate_limit_reached",
+ Message: "You are sending requests too quickly.",
+ Param: nil,
+ Type: "rate_limit_reached",
+ HTTPStatusCode: 429,
+ },
+ }
+
+ dataBytes, err := json.Marshal(errMsg)
+ // Use assert as require doesn't work when called from a goroutine
+ assert.NoError(t, err, "Marshal error")
+
+ _, err = w.Write(dataBytes)
+ assert.NoError(t, err, "Write error")
+ }))
+ t.Cleanup(server.Close)
+
+ openaiCfg := openai.DefaultConfig("test-token")
+ openaiCfg.BaseURL = server.URL
+ s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg})
+
+ ctx := context.Background()
+ authPack := s.authPack(t, "foo")
+ // Create the conversation
+ conversationID := s.makeAssistConversation(t, ctx, authPack)
+
+ // Make WS client and start the conversation
+ ws, err := s.makeAssistant(t, authPack, conversationID)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ ws.Close()
+ })
+
+ // verify responses
+ readHelloMsg(ws)
+ readErrorMsg(ws)
+
+ // Check for close message
+ _, _, err = ws.ReadMessage()
+ closeErr, ok := err.(*websocket.CloseError)
+ require.True(t, ok, "Expected close error")
+ require.Equal(t, websocket.CloseInternalServerErr, closeErr.Code, "Expected abnormal closure")
+}
+
// makeAssistConversation creates a new assist conversation and returns its ID
func (s *WebSuite) makeAssistConversation(t *testing.T, ctx context.Context, authPack *authPack) string {
clt := authPack.clt
diff --git a/web/packages/teleport/src/Assist/Conversation/Message.tsx b/web/packages/teleport/src/Assist/Conversation/Message.tsx
index 31f60bef5714e..4cb5f2fad5e18 100644
--- a/web/packages/teleport/src/Assist/Conversation/Message.tsx
+++ b/web/packages/teleport/src/Assist/Conversation/Message.tsx
@@ -107,6 +107,7 @@ function createComponentForEntry(
switch (entry.type) {
case ServerMessageType.Assist:
case ServerMessageType.User:
+ case ServerMessageType.Error:
return ;
case ServerMessageType.Command:
diff --git a/web/packages/teleport/src/Assist/context/AssistContext.tsx b/web/packages/teleport/src/Assist/context/AssistContext.tsx
index b69c702c758c7..65719bef7e85e 100644
--- a/web/packages/teleport/src/Assist/context/AssistContext.tsx
+++ b/web/packages/teleport/src/Assist/context/AssistContext.tsx
@@ -105,7 +105,7 @@ export function AssistContextProvider(props: PropsWithChildren) {
});
}
- function setupWebSocket(conversationId: string) {
+ function setupWebSocket(conversationId: string, initialMessage?: string) {
activeWebSocket.current = new WebSocket(
cfg.getAssistConversationWebSocketUrl(
getHostName(),
@@ -123,6 +123,16 @@ export function AssistContextProvider(props: PropsWithChildren) {
TEN_MINUTES * 0.8
);
+ activeWebSocket.current.onopen = () => {
+ if (initialMessage) {
+ activeWebSocket.current.send(initialMessage);
+ }
+ };
+
+ activeWebSocket.current.onclose = () => {
+ dispatch({ type: AssistStateActionType.SetStreaming, streaming: false });
+ };
+
activeWebSocket.current.onmessage = async event => {
const data = JSON.parse(event.data) as ServerMessage;
@@ -178,6 +188,21 @@ export function AssistContextProvider(props: PropsWithChildren) {
break;
}
+
+ case ServerMessageType.Error:
+ dispatch({
+ type: AssistStateActionType.AddMessage,
+ messageType: ServerMessageType.Error,
+ message: data.payload,
+ conversationId,
+ });
+
+ dispatch({
+ type: AssistStateActionType.SetStreaming,
+ streaming: false,
+ });
+
+ break;
}
};
}
@@ -273,7 +298,16 @@ export function AssistContextProvider(props: PropsWithChildren) {
dispatch({ type: AssistStateActionType.SetStreaming, streaming: true });
- activeWebSocket.current.send(JSON.stringify({ payload: message }));
+ const data = JSON.stringify({ payload: message });
+
+ if (
+ !activeWebSocket.current ||
+ activeWebSocket.current.readyState === WebSocket.CLOSED
+ ) {
+ setupWebSocket(state.conversations.selectedId, data);
+ } else {
+ activeWebSocket.current.send(data);
+ }
dispatch({
type: AssistStateActionType.AddMessage,
diff --git a/web/packages/teleport/src/Assist/context/state.ts b/web/packages/teleport/src/Assist/context/state.ts
index c72222db77878..41dc5312ff367 100644
--- a/web/packages/teleport/src/Assist/context/state.ts
+++ b/web/packages/teleport/src/Assist/context/state.ts
@@ -90,7 +90,10 @@ export interface SetConversationMessagesAction {
export interface AddMessageAction {
type: AssistStateActionType.AddMessage;
- messageType: ServerMessageType.User | ServerMessageType.Assist;
+ messageType:
+ | ServerMessageType.User
+ | ServerMessageType.Assist
+ | ServerMessageType.Error;
message: string;
conversationId: string;
}
diff --git a/web/packages/teleport/src/Assist/context/utils.ts b/web/packages/teleport/src/Assist/context/utils.ts
index 08f41ebdd4fd4..e7473f7003ea7 100644
--- a/web/packages/teleport/src/Assist/context/utils.ts
+++ b/web/packages/teleport/src/Assist/context/utils.ts
@@ -32,6 +32,7 @@ function getMessageTypeAuthor(type: string) {
case ServerMessageType.Command:
case ServerMessageType.CommandResult:
case ServerMessageType.CommandResultStream:
+ case ServerMessageType.Error:
return Author.Teleport;
}
}
diff --git a/web/packages/teleport/src/Assist/types.ts b/web/packages/teleport/src/Assist/types.ts
index d9b2579c116b4..ba73068e0c80e 100644
--- a/web/packages/teleport/src/Assist/types.ts
+++ b/web/packages/teleport/src/Assist/types.ts
@@ -18,6 +18,7 @@ import { EventType } from 'teleport/lib/term/enums';
export enum ServerMessageType {
Assist = 'CHAT_MESSAGE_ASSISTANT',
User = 'CHAT_MESSAGE_USER',
+ Error = 'CHAT_MESSAGE_ERROR',
Command = 'COMMAND',
CommandResult = 'COMMAND_RESULT',
CommandResultStream = 'COMMAND_RESULT_STREAM',
@@ -85,6 +86,12 @@ export interface ResolvedUserServerMessage {
created: Date;
}
+export interface ResolvedErrorServerMessage {
+ type: ServerMessageType.Error;
+ message: string;
+ created: Date;
+}
+
export interface ResolvedCommandResultStreamServerMessage {
type: ServerMessageType.CommandResultStream;
id: number;
@@ -99,6 +106,7 @@ export type ResolvedServerMessage =
| ResolvedCommandServerMessage
| ResolvedAssistServerMessage
| ResolvedUserServerMessage
+ | ResolvedErrorServerMessage
| ResolvedCommandResultServerMessage
| ResolvedAssistThoughtServerMessage
| ResolvedCommandResultStreamServerMessage;