Skip to content

Commit

Permalink
[Assist] Prevent creating messages without conversation
Browse files Browse the repository at this point in the history
  • Loading branch information
jakule committed May 24, 2023
1 parent 3e63df4 commit 33fb7bb
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 4 deletions.
9 changes: 9 additions & 0 deletions lib/services/local/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,15 @@ func (s *AssistService) CreateAssistantMessage(ctx context.Context, req *assist.
return trace.BadParameter("missing conversation ID")
}

// Check if the conversation exists.
conversationKey := backend.Key(assistantConversationPrefix, req.Username, req.ConversationId)
if _, err := s.Get(ctx, conversationKey); err != nil {
if trace.IsNotFound(err) {
return trace.NotFound("conversation %q not found", req.ConversationId)
}
return trace.Wrap(err)
}

msg := req.GetMessage()
value, err := json.Marshal(msg)
if err != nil {
Expand Down
15 changes: 15 additions & 0 deletions lib/services/local/assistant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"testing"
"time"

"github.com/google/uuid"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
Expand Down Expand Up @@ -140,4 +141,18 @@ func TestAssistantCRUD(t *testing.T) {
require.Equal(t, conversationID, conversations.Conversations[0].Id)
require.Equal(t, conversationResp.Id, conversations.Conversations[1].Id)
})

t.Run("refuse to add messages if conversion does not exist", func(t *testing.T) {
msg := &assist.CreateAssistantMessageRequest{
Username: username,
ConversationId: uuid.New().String(),
Message: &assist.AssistantMessage{
CreatedTime: timestamppb.New(time.Now()),
Payload: "foo",
Type: "USER_MSG",
},
}
err := identity.CreateAssistantMessage(ctx, msg)
require.Error(t, err)
})
}
30 changes: 26 additions & 4 deletions lib/web/assistant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package web

import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
Expand All @@ -26,7 +27,6 @@ import (
"net/url"
"testing"

"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/gravitational/roundtrip"
"github.com/gravitational/trace"
Expand Down Expand Up @@ -168,7 +168,13 @@ func Test_runAssistant(t *testing.T) {
tc.setup(t, s)
}

ws, err := s.makeAssistant(t, s.authPack(t, "foo"))
ctx := context.Background()
authPack := s.authPack(t, "foo")
// Create the conversation
conversationID := s.makeAssistConversation(t, ctx, authPack)

// Make WS client and start the conversation
ws, err := s.makeAssistant(t, authPack, conversationID)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, ws.Close()) })

Expand All @@ -186,18 +192,34 @@ func Test_runAssistant(t *testing.T) {
tc.act(t, ws)
})
}
}

// makeAssistConversation creates a new assist conversation and returns its ID
func (s *WebSuite) makeAssistConversation(t *testing.T, ctx context.Context, authPack *authPack) string {
clt := authPack.clt

resp, err := clt.PostJSON(ctx, clt.Endpoint("webapi", "assistant", "conversations"), nil)
require.NoError(t, err)

convResp := struct {
ConversationID string `json:"id"`
}{}
err = json.Unmarshal(resp.Bytes(), &convResp)
require.NoError(t, err)

return convResp.ConversationID
}

func (s *WebSuite) makeAssistant(t *testing.T, pack *authPack) (*websocket.Conn, error) {
// makeAssistant creates a new assistant websocket connection.
func (s *WebSuite) makeAssistant(t *testing.T, pack *authPack, conversationID string) (*websocket.Conn, error) {
u := url.URL{
Host: s.url().Host,
Scheme: client.WSS,
Path: fmt.Sprintf("/v1/webapi/sites/%s/assistant", currentSiteShortcut),
}

q := u.Query()
q.Set("conversation_id", uuid.New().String())
q.Set("conversation_id", conversationID)
q.Set(roundtrip.AccessTokenQueryParam, pack.session.Token)
u.RawQuery = q.Encode()

Expand Down

0 comments on commit 33fb7bb

Please sign in to comment.