Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v13] Gracefully handle errors in Assist frontend (#27669) #27935

Merged
merged 4 commits into from Jun 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion e
Submodule e updated from 181e48 to f0d46e
6 changes: 4 additions & 2 deletions lib/ai/chat_test.go
Expand Up @@ -23,6 +23,7 @@ import (
"testing"

"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tiktoken-go/tokenizer/codec"
)
Expand Down Expand Up @@ -107,11 +108,12 @@ func TestChat_Complete(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")

require.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
// Use assert as require doesn't work when called from a goroutine
assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
dataBytes := responses[0]

_, err := w.Write(dataBytes)
require.NoError(t, err, "Write error")
assert.NoError(t, err, "Write error")

responses = responses[1:]
}))
Expand Down
28 changes: 26 additions & 2 deletions lib/web/assistant.go
Expand Up @@ -312,7 +312,7 @@ func checkAssistEnabled(a auth.ClientI, ctx context.Context) error {
// runAssistant upgrades the HTTP connection to a websocket and starts a chat loop.
func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request,
sctx *SessionContext, site reversetunnel.RemoteSite,
) error {
) (err error) {
q := r.URL.Query()
conversationID := q.Get("conversation_id")
if conversationID == "" {
Expand Down Expand Up @@ -371,7 +371,31 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request,
h.log.WithError(err).Error("Error setting websocket readline")
return nil
}
defer ws.Close()
defer func() {
closureReason := websocket.CloseNormalClosure
closureMsg := ""
if err != nil {
h.log.WithError(err).Error("Error in the Assistant loop")
_ = ws.WriteJSON(&assistantMessage{
Type: assist.MessageKindError,
Payload: "An error has occurred. Please try again later.",
CreatedTime: h.clock.Now().UTC().Format(time.RFC3339),
})
// Set server error code and message: https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1
closureReason = websocket.CloseInternalServerErr
closureMsg = err.Error()
}
// Send the close message to the client and close the connection
if err := ws.WriteControl(websocket.CloseMessage,
websocket.FormatCloseMessage(closureReason, closureMsg),
time.Now().Add(time.Second),
); err != nil {
h.log.Warnf("Failed to write close message: %v", err)
}
if err := ws.Close(); err != nil {
h.log.Warnf("Failed to close websocket: %v", err)
}
}()

// Update the read deadline upon receiving a pong message.
ws.SetPongHandler(func(_ string) error {
Expand Down
91 changes: 89 additions & 2 deletions lib/web/assistant_test.go
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/gravitational/roundtrip"
"github.com/gravitational/trace"
"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"

Expand Down Expand Up @@ -149,11 +150,12 @@ func Test_runAssistant(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")

require.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
// Use assert as require doesn't work when called from a goroutine
assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
dataBytes := responses[0]

_, err := w.Write(dataBytes)
require.NoError(t, err, "Write error")
assert.NoError(t, err, "Write error")

responses = responses[1:]
}))
Expand Down Expand Up @@ -194,6 +196,91 @@ func Test_runAssistant(t *testing.T) {
}
}

// Test_runAssistError tests that the assistant returns an error message
// when the OpenAI API returns an error.
func Test_runAssistError(t *testing.T) {
t.Parallel()

readHelloMsg := func(ws *websocket.Conn) {
_, payload, err := ws.ReadMessage()
require.NoError(t, err)

var msg assistantMessage
err = json.Unmarshal(payload, &msg)
require.NoError(t, err)

// Expect "hello" message
require.Equal(t, assist.MessageKindAssistantMessage, msg.Type)
require.Contains(t, msg.Payload, "Hey, I'm Teleport")
}

readErrorMsg := func(ws *websocket.Conn) {
err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`))
require.NoError(t, err)

_, payload, err := ws.ReadMessage()
require.NoError(t, err)

var msg assistantMessage
err = json.Unmarshal(payload, &msg)
require.NoError(t, err)

// Expect a generic error message
require.Equal(t, assist.MessageKindError, msg.Type)
require.Contains(t, msg.Payload, "An error has occurred.")
}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
// Simulate rate limit error
w.WriteHeader(429)

errMsg := openai.ErrorResponse{
Error: &openai.APIError{
Code: "rate_limit_reached",
Message: "You are sending requests too quickly.",
Param: nil,
Type: "rate_limit_reached",
HTTPStatusCode: 429,
},
}

dataBytes, err := json.Marshal(errMsg)
// Use assert as require doesn't work when called from a goroutine
assert.NoError(t, err, "Marshal error")

_, err = w.Write(dataBytes)
assert.NoError(t, err, "Write error")
}))
t.Cleanup(server.Close)

openaiCfg := openai.DefaultConfig("test-token")
openaiCfg.BaseURL = server.URL
s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg})

ctx := context.Background()
authPack := s.authPack(t, "foo")
// Create the conversation
conversationID := s.makeAssistConversation(t, ctx, authPack)

// Make WS client and start the conversation
ws, err := s.makeAssistant(t, authPack, conversationID)
require.NoError(t, err)
t.Cleanup(func() {
ws.Close()
})

// verify responses
readHelloMsg(ws)
readErrorMsg(ws)

// Check for close message
_, _, err = ws.ReadMessage()
closeErr, ok := err.(*websocket.CloseError)
require.True(t, ok, "Expected close error")
require.Equal(t, websocket.CloseInternalServerErr, closeErr.Code, "Expected abnormal closure")
}

// makeAssistConversation creates a new assist conversation and returns its ID
func (s *WebSuite) makeAssistConversation(t *testing.T, ctx context.Context, authPack *authPack) string {
clt := authPack.clt
Expand Down
1 change: 1 addition & 0 deletions web/packages/teleport/src/Assist/Conversation/Message.tsx
Expand Up @@ -107,6 +107,7 @@ function createComponentForEntry(
switch (entry.type) {
case ServerMessageType.Assist:
case ServerMessageType.User:
case ServerMessageType.Error:
return <MessageEntry content={entry.message} />;

case ServerMessageType.Command:
Expand Down
38 changes: 36 additions & 2 deletions web/packages/teleport/src/Assist/context/AssistContext.tsx
Expand Up @@ -105,7 +105,7 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
});
}

function setupWebSocket(conversationId: string) {
function setupWebSocket(conversationId: string, initialMessage?: string) {
activeWebSocket.current = new WebSocket(
cfg.getAssistConversationWebSocketUrl(
getHostName(),
Expand All @@ -123,6 +123,16 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
TEN_MINUTES * 0.8
);

activeWebSocket.current.onopen = () => {
if (initialMessage) {
activeWebSocket.current.send(initialMessage);
}
};

activeWebSocket.current.onclose = () => {
dispatch({ type: AssistStateActionType.SetStreaming, streaming: false });
};

activeWebSocket.current.onmessage = async event => {
const data = JSON.parse(event.data) as ServerMessage;

Expand Down Expand Up @@ -178,6 +188,21 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {

break;
}

case ServerMessageType.Error:
dispatch({
type: AssistStateActionType.AddMessage,
messageType: ServerMessageType.Error,
message: data.payload,
conversationId,
});

dispatch({
type: AssistStateActionType.SetStreaming,
streaming: false,
});

break;
}
};
}
Expand Down Expand Up @@ -273,7 +298,16 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {

dispatch({ type: AssistStateActionType.SetStreaming, streaming: true });

activeWebSocket.current.send(JSON.stringify({ payload: message }));
const data = JSON.stringify({ payload: message });

if (
!activeWebSocket.current ||
activeWebSocket.current.readyState === WebSocket.CLOSED
) {
setupWebSocket(state.conversations.selectedId, data);
} else {
activeWebSocket.current.send(data);
}

dispatch({
type: AssistStateActionType.AddMessage,
Expand Down
5 changes: 4 additions & 1 deletion web/packages/teleport/src/Assist/context/state.ts
Expand Up @@ -90,7 +90,10 @@ export interface SetConversationMessagesAction {

export interface AddMessageAction {
type: AssistStateActionType.AddMessage;
messageType: ServerMessageType.User | ServerMessageType.Assist;
messageType:
| ServerMessageType.User
| ServerMessageType.Assist
| ServerMessageType.Error;
message: string;
conversationId: string;
}
Expand Down
1 change: 1 addition & 0 deletions web/packages/teleport/src/Assist/context/utils.ts
Expand Up @@ -32,6 +32,7 @@ function getMessageTypeAuthor(type: string) {
case ServerMessageType.Command:
case ServerMessageType.CommandResult:
case ServerMessageType.CommandResultStream:
case ServerMessageType.Error:
return Author.Teleport;
}
}
Expand Down
8 changes: 8 additions & 0 deletions web/packages/teleport/src/Assist/types.ts
Expand Up @@ -18,6 +18,7 @@ import { EventType } from 'teleport/lib/term/enums';
export enum ServerMessageType {
Assist = 'CHAT_MESSAGE_ASSISTANT',
User = 'CHAT_MESSAGE_USER',
Error = 'CHAT_MESSAGE_ERROR',
Command = 'COMMAND',
CommandResult = 'COMMAND_RESULT',
CommandResultStream = 'COMMAND_RESULT_STREAM',
Expand Down Expand Up @@ -85,6 +86,12 @@ export interface ResolvedUserServerMessage {
created: Date;
}

export interface ResolvedErrorServerMessage {
type: ServerMessageType.Error;
message: string;
created: Date;
}

export interface ResolvedCommandResultStreamServerMessage {
type: ServerMessageType.CommandResultStream;
id: number;
Expand All @@ -99,6 +106,7 @@ export type ResolvedServerMessage =
| ResolvedCommandServerMessage
| ResolvedAssistServerMessage
| ResolvedUserServerMessage
| ResolvedErrorServerMessage
| ResolvedCommandResultServerMessage
| ResolvedAssistThoughtServerMessage
| ResolvedCommandResultStreamServerMessage;
Expand Down