Skip to content

Commit

Permalink
[v13] Gracefully handle errors in Assist frontend (#27669) (#27935)
Browse files Browse the repository at this point in the history
* Return errors over Assist WS (#27174)

* Return errors over Assist WS

* Add test

* Address code review comments
Add check for the close message in assist WS

* Use UTC for getting time

* Gracefully handle errors in Assist frontend (#27669)

* Bump e to include teleport.e#1634

* Fix error assertions in the test

This test is removed altogether in #27075,
but that has not been backported yet.

---------

Co-authored-by: Jakub Nyckowski <jakub.nyckowski@goteleport.com>
  • Loading branch information
justinas and jakule committed Jun 19, 2023
1 parent 69e07e6 commit 7bd0519
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 10 deletions.
2 changes: 1 addition & 1 deletion e
Submodule e updated from 181e48 to f0d46e
6 changes: 4 additions & 2 deletions lib/ai/chat_test.go
Expand Up @@ -23,6 +23,7 @@ import (
"testing"

"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tiktoken-go/tokenizer/codec"
)
Expand Down Expand Up @@ -107,11 +108,12 @@ func TestChat_Complete(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")

require.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
// Use assert as require doesn't work when called from a goroutine
assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
dataBytes := responses[0]

_, err := w.Write(dataBytes)
require.NoError(t, err, "Write error")
assert.NoError(t, err, "Write error")

responses = responses[1:]
}))
Expand Down
28 changes: 26 additions & 2 deletions lib/web/assistant.go
Expand Up @@ -312,7 +312,7 @@ func checkAssistEnabled(a auth.ClientI, ctx context.Context) error {
// runAssistant upgrades the HTTP connection to a websocket and starts a chat loop.
func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request,
sctx *SessionContext, site reversetunnel.RemoteSite,
) error {
) (err error) {
q := r.URL.Query()
conversationID := q.Get("conversation_id")
if conversationID == "" {
Expand Down Expand Up @@ -371,7 +371,31 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request,
h.log.WithError(err).Error("Error setting websocket readline")
return nil
}
defer ws.Close()
defer func() {
closureReason := websocket.CloseNormalClosure
closureMsg := ""
if err != nil {
h.log.WithError(err).Error("Error in the Assistant loop")
_ = ws.WriteJSON(&assistantMessage{
Type: assist.MessageKindError,
Payload: "An error has occurred. Please try again later.",
CreatedTime: h.clock.Now().UTC().Format(time.RFC3339),
})
// Set server error code and message: https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1
closureReason = websocket.CloseInternalServerErr
closureMsg = err.Error()
}
// Send the close message to the client and close the connection
if err := ws.WriteControl(websocket.CloseMessage,
websocket.FormatCloseMessage(closureReason, closureMsg),
time.Now().Add(time.Second),
); err != nil {
h.log.Warnf("Failed to write close message: %v", err)
}
if err := ws.Close(); err != nil {
h.log.Warnf("Failed to close websocket: %v", err)
}
}()

// Update the read deadline upon receiving a pong message.
ws.SetPongHandler(func(_ string) error {
Expand Down
91 changes: 89 additions & 2 deletions lib/web/assistant_test.go
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/gravitational/roundtrip"
"github.com/gravitational/trace"
"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"

Expand Down Expand Up @@ -149,11 +150,12 @@ func Test_runAssistant(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")

require.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
// Use assert as require doesn't work when called from a goroutine
assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
dataBytes := responses[0]

_, err := w.Write(dataBytes)
require.NoError(t, err, "Write error")
assert.NoError(t, err, "Write error")

responses = responses[1:]
}))
Expand Down Expand Up @@ -194,6 +196,91 @@ func Test_runAssistant(t *testing.T) {
}
}

// Test_runAssistError tests that the assistant returns an error message
// when the OpenAI API returns an error.
func Test_runAssistError(t *testing.T) {
t.Parallel()

readHelloMsg := func(ws *websocket.Conn) {
_, payload, err := ws.ReadMessage()
require.NoError(t, err)

var msg assistantMessage
err = json.Unmarshal(payload, &msg)
require.NoError(t, err)

// Expect "hello" message
require.Equal(t, assist.MessageKindAssistantMessage, msg.Type)
require.Contains(t, msg.Payload, "Hey, I'm Teleport")
}

readErrorMsg := func(ws *websocket.Conn) {
err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`))
require.NoError(t, err)

_, payload, err := ws.ReadMessage()
require.NoError(t, err)

var msg assistantMessage
err = json.Unmarshal(payload, &msg)
require.NoError(t, err)

// Expect a generic error message
require.Equal(t, assist.MessageKindError, msg.Type)
require.Contains(t, msg.Payload, "An error has occurred.")
}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
// Simulate rate limit error
w.WriteHeader(429)

errMsg := openai.ErrorResponse{
Error: &openai.APIError{
Code: "rate_limit_reached",
Message: "You are sending requests too quickly.",
Param: nil,
Type: "rate_limit_reached",
HTTPStatusCode: 429,
},
}

dataBytes, err := json.Marshal(errMsg)
// Use assert as require doesn't work when called from a goroutine
assert.NoError(t, err, "Marshal error")

_, err = w.Write(dataBytes)
assert.NoError(t, err, "Write error")
}))
t.Cleanup(server.Close)

openaiCfg := openai.DefaultConfig("test-token")
openaiCfg.BaseURL = server.URL
s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg})

ctx := context.Background()
authPack := s.authPack(t, "foo")
// Create the conversation
conversationID := s.makeAssistConversation(t, ctx, authPack)

// Make WS client and start the conversation
ws, err := s.makeAssistant(t, authPack, conversationID)
require.NoError(t, err)
t.Cleanup(func() {
ws.Close()
})

// verify responses
readHelloMsg(ws)
readErrorMsg(ws)

// Check for close message
_, _, err = ws.ReadMessage()
closeErr, ok := err.(*websocket.CloseError)
require.True(t, ok, "Expected close error")
require.Equal(t, websocket.CloseInternalServerErr, closeErr.Code, "Expected abnormal closure")
}

// makeAssistConversation creates a new assist conversation and returns its ID
func (s *WebSuite) makeAssistConversation(t *testing.T, ctx context.Context, authPack *authPack) string {
clt := authPack.clt
Expand Down
1 change: 1 addition & 0 deletions web/packages/teleport/src/Assist/Conversation/Message.tsx
Expand Up @@ -107,6 +107,7 @@ function createComponentForEntry(
switch (entry.type) {
case ServerMessageType.Assist:
case ServerMessageType.User:
case ServerMessageType.Error:
return <MessageEntry content={entry.message} />;

case ServerMessageType.Command:
Expand Down
38 changes: 36 additions & 2 deletions web/packages/teleport/src/Assist/context/AssistContext.tsx
Expand Up @@ -105,7 +105,7 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
});
}

function setupWebSocket(conversationId: string) {
function setupWebSocket(conversationId: string, initialMessage?: string) {
activeWebSocket.current = new WebSocket(
cfg.getAssistConversationWebSocketUrl(
getHostName(),
Expand All @@ -123,6 +123,16 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
TEN_MINUTES * 0.8
);

activeWebSocket.current.onopen = () => {
if (initialMessage) {
activeWebSocket.current.send(initialMessage);
}
};

activeWebSocket.current.onclose = () => {
dispatch({ type: AssistStateActionType.SetStreaming, streaming: false });
};

activeWebSocket.current.onmessage = async event => {
const data = JSON.parse(event.data) as ServerMessage;

Expand Down Expand Up @@ -178,6 +188,21 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {

break;
}

case ServerMessageType.Error:
dispatch({
type: AssistStateActionType.AddMessage,
messageType: ServerMessageType.Error,
message: data.payload,
conversationId,
});

dispatch({
type: AssistStateActionType.SetStreaming,
streaming: false,
});

break;
}
};
}
Expand Down Expand Up @@ -273,7 +298,16 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {

dispatch({ type: AssistStateActionType.SetStreaming, streaming: true });

activeWebSocket.current.send(JSON.stringify({ payload: message }));
const data = JSON.stringify({ payload: message });

if (
!activeWebSocket.current ||
activeWebSocket.current.readyState === WebSocket.CLOSED
) {
setupWebSocket(state.conversations.selectedId, data);
} else {
activeWebSocket.current.send(data);
}

dispatch({
type: AssistStateActionType.AddMessage,
Expand Down
5 changes: 4 additions & 1 deletion web/packages/teleport/src/Assist/context/state.ts
Expand Up @@ -90,7 +90,10 @@ export interface SetConversationMessagesAction {

export interface AddMessageAction {
type: AssistStateActionType.AddMessage;
messageType: ServerMessageType.User | ServerMessageType.Assist;
messageType:
| ServerMessageType.User
| ServerMessageType.Assist
| ServerMessageType.Error;
message: string;
conversationId: string;
}
Expand Down
1 change: 1 addition & 0 deletions web/packages/teleport/src/Assist/context/utils.ts
Expand Up @@ -32,6 +32,7 @@ function getMessageTypeAuthor(type: string) {
case ServerMessageType.Command:
case ServerMessageType.CommandResult:
case ServerMessageType.CommandResultStream:
case ServerMessageType.Error:
return Author.Teleport;
}
}
Expand Down
8 changes: 8 additions & 0 deletions web/packages/teleport/src/Assist/types.ts
Expand Up @@ -18,6 +18,7 @@ import { EventType } from 'teleport/lib/term/enums';
export enum ServerMessageType {
Assist = 'CHAT_MESSAGE_ASSISTANT',
User = 'CHAT_MESSAGE_USER',
Error = 'CHAT_MESSAGE_ERROR',
Command = 'COMMAND',
CommandResult = 'COMMAND_RESULT',
CommandResultStream = 'COMMAND_RESULT_STREAM',
Expand Down Expand Up @@ -85,6 +86,12 @@ export interface ResolvedUserServerMessage {
created: Date;
}

export interface ResolvedErrorServerMessage {
type: ServerMessageType.Error;
message: string;
created: Date;
}

export interface ResolvedCommandResultStreamServerMessage {
type: ServerMessageType.CommandResultStream;
id: number;
Expand All @@ -99,6 +106,7 @@ export type ResolvedServerMessage =
| ResolvedCommandServerMessage
| ResolvedAssistServerMessage
| ResolvedUserServerMessage
| ResolvedErrorServerMessage
| ResolvedCommandResultServerMessage
| ResolvedAssistThoughtServerMessage
| ResolvedCommandResultStreamServerMessage;
Expand Down

0 comments on commit 7bd0519

Please sign in to comment.