Skip to content

Commit

Permalink
Adding abiliy to add multiple AI bots. (#175)
Browse files Browse the repository at this point in the history
* Adding abiliy to add multiple AI bots.

* Fix tests

* More lint

* More lint

* update chevron placement

* Remember last selected bot

* Add some system console validation.

* UX fixes

* Replies count

* More UX feedback

* Simplify
  • Loading branch information
crspeller committed May 21, 2024
1 parent ffeefd8 commit bc89035
Show file tree
Hide file tree
Showing 42 changed files with 1,884 additions and 528 deletions.
2 changes: 2 additions & 0 deletions e2e/helpers/mmcontainer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ export default class MattermostContainer {
"MM_FILESETTINGS_MAXFILESIZE": "256000000",
"MM_LOGSETTINGS_CONSOLELEVEL": "DEBUG",
"MM_LOGSETTINGS_FILELEVEL": "DEBUG",
"MM_SERVICESETTINGS_ENABLEDEVELOPER": "true",
"MM_SERVICESETTINGS_ENABLETESTING": "true",
};
this.email = defaultEmail;
this.username = defaultUsername;
Expand Down
9 changes: 7 additions & 2 deletions e2e/helpers/openai-mock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ data: {"id":"chatcmpl-8t1WLFfcSfmK0sfBcFbj8VEhOqNYd","object":"chat.completion.c
data: [DONE]
`

export const responseTestText = "Hello! How can I assist you today?"

export const responseTest2 = `
data: {"id":"chatcmpl-8t1WLFfcSfmK0sfBcFbj8VEhOqNYd","object":"chat.completion.chunk","created":1708124577,"model":"gpt-3.5-turbo-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}
Expand All @@ -48,6 +50,8 @@ data: {"id":"chatcmpl-8t1WLFfcSfmK0sfBcFbj8VEhOqNYd","object":"chat.completion.c
data: [DONE]
`

export const responseTest2Text = "Hello! This is a second message."


export class OpenAIMockContainer {
container: StartedTestContainer;
Expand Down Expand Up @@ -79,11 +83,12 @@ export class OpenAIMockContainer {
})
}

addCompletionMock = async (response: string) => {
addCompletionMock = async (response: string, botPrefix?: string) => {
const prefix = botPrefix ? ("/"+botPrefix) : ""
return this.addMock({
request: {
method: "POST",
path: "/chat/completions",
path: prefix + "/chat/completions",
},
context: {
times: 1,
Expand Down
29 changes: 22 additions & 7 deletions e2e/helpers/plugincontainer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,30 @@ const RunContainer = async (): Promise<MattermostContainer> => {
"disableFunctionCalls": false,
"enableLLMTrace": true,
"enableUserRestrictions": false,
"llmBackend": "Mock",
"services": [
"defaultBotName": "mock",
"bots": [
{
"id": "y6fcxh0xc",
"name": "Mock",
"apiKey": "mock",
"serviceName": "openaicompatible",
"url": "http://openai:8080",
}
"name": "mock",
"displayName": "Mock Bot",
"customInstructions": "",
"service": {
"type": "openaicompatible",
"apiKey": "mock",
"apiURL": "http://openai:8080",
},
},
{
"id": "oawiejfoj",
"name": "second",
"displayName": "Second Bot",
"customInstructions": "",
"service": {
"type": "openaicompatible",
"apiKey": "ohno",
"apiURL": "http://openai:8080/second",
},
},
],
}
}
Expand Down
26 changes: 23 additions & 3 deletions e2e/tests/basic.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import RunContainer from 'helpers/plugincontainer';
import MattermostContainer from 'helpers/mmcontainer';
import {login} from 'helpers/mm';
import {openRHS} from 'helpers/ai-plugin';
import { OpenAIMockContainer, RunOpenAIMocks, responseTest, responseTest2 } from 'helpers/openai-mock';
import { OpenAIMockContainer, RunOpenAIMocks, responseTest, responseTest2, responseTest2Text, responseTestText } from 'helpers/openai-mock';

let mattermost: MattermostContainer;
let openAIMock: OpenAIMockContainer;
Expand Down Expand Up @@ -67,10 +67,30 @@ test ('regenerate button', async ({ page }) => {
await page.getByTestId('reply_textbox').click();
await page.getByTestId('reply_textbox').fill('Hello!');
await page.getByTestId('reply_textbox').press('Enter');
await expect(page.getByText("Hello! How can I assist you today?")).toBeVisible();
await expect(page.getByText(responseTestText)).toBeVisible();

await openAIMock.addCompletionMock(responseTest2);

await page.getByRole('button', { name: 'Regenerate' }).click();
await expect(page.getByText("Hello! This is a second message.")).toBeVisible();
await expect(page.getByText(responseTest2Text)).toBeVisible();
})

test ('switching bots', async ({ page }) => {
const url = mattermost.url()
await login(page, url, "regularuser", "regularuser");;
await openRHS(page);
await openAIMock.addCompletionMock(responseTest, "second");

// Switch to second bot
await page.getByTestId('menuButtonMock Bot').click();
await page.getByRole('button', { name: 'Second Bot' }).click();

await page.getByTestId('reply_textbox').click();
await page.getByTestId('reply_textbox').fill('Hello!');
await page.getByTestId('reply_textbox').press('Enter');

// Second bot responds
await expect(page.getByRole('button', { name: 'second', exact: true })).toBeVisible();
// With correct message
await expect(page.getByText(responseTestText)).toBeVisible();
})
21 changes: 19 additions & 2 deletions server/ai/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,30 @@ package ai

type ServiceConfig struct {
Name string `json:"name"`
ServiceName string `json:"serviceName"`
Type string `json:"type"`
APIKey string `json:"apiKey"`
OrgID string `json:"orgId"`
DefaultModel string `json:"defaultModel"`
URL string `json:"url"`
APIURL string `json:"apiURL"`
Username string `json:"username"`
Password string `json:"password"`
TokenLimit int `json:"tokenLimit"`
StreamingTimeoutSeconds int `json:"streamingTimeoutSeconds"`
}

type BotConfig struct {
ID string `json:"id"`
Name string `json:"name"`
DisplayName string `json:"displayName"`
CustomInstructions string `json:"customInstructions"`
Service ServiceConfig `json:"service"`
}

func (c *BotConfig) IsValid() bool {
isInvalid := c.Name == "" ||
c.DisplayName == "" ||
c.Service.Type == "" ||
(c.Service.Type == "openaicompatable" && c.Service.APIURL == "") ||
(c.Service.Type != "asksage" && c.Service.APIKey == "")
return !isInvalid
}
19 changes: 10 additions & 9 deletions server/ai/conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,16 @@ type Post struct {
}

type ConversationContext struct {
BotID string
Time string
ServerName string
CompanyName string
RequestingUser *model.User
Channel *model.Channel
Team *model.Team
Post *model.Post
PromptParameters map[string]string
BotID string
Time string
ServerName string
CompanyName string
RequestingUser *model.User
Channel *model.Channel
Team *model.Team
Post *model.Post
PromptParameters map[string]string
CustomInstructions string
}

func NewConversationContext(botID string, requestingUser *model.User, channel *model.Channel, post *model.Post) ConversationContext {
Expand Down
2 changes: 1 addition & 1 deletion server/ai/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ var ErrStreamingTimeout = errors.New("timeout streaming")

func NewCompatible(llmService ai.ServiceConfig) *OpenAI {
apiKey := llmService.APIKey
endpointURL := strings.TrimSuffix(llmService.URL, "/")
endpointURL := strings.TrimSuffix(llmService.APIURL, "/")
defaultModel := llmService.DefaultModel
config := openaiClient.DefaultConfig(apiKey)
config.BaseURL = endpointURL
Expand Down
4 changes: 3 additions & 1 deletion server/ai/prompts/standard_personality.tmpl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
You are a helpful assistant called "AI Copilot" that responds on a Mattermost chat server called {{.ServerName}} owned by {{.CompanyName}}.

Current time and date in the user's location is {{.Time}}

{{if .CustomInstructions}}
{{.CustomInstructions}}
{{end}}
The following is the personal information of the user. This information is given with every request to you. You can use this information to taylor the request to the specific user however most of the time it will not be relavent. Only acknowledge the information when the request is directly related to the information provided. Never repeat it as written.
The user making the request username is '{{.RequestingUser.Username}}'.
{{if .RequestingUser.FirstName}}
Expand Down
89 changes: 77 additions & 12 deletions server/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ import (
"github.com/gin-gonic/gin"
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/plugin"
"github.com/mattermost/mattermost/server/public/pluginapi"
)

const (
ContextPostKey = "post"
ContextChannelKey = "channel"
ContextBotKey = "bot"
)

func (p *Plugin) ServeHTTP(c *plugin.Context, w http.ResponseWriter, r *http.Request) {
Expand All @@ -22,8 +24,12 @@ func (p *Plugin) ServeHTTP(c *plugin.Context, w http.ResponseWriter, r *http.Req
router.Use(p.MattermostAuthorizationRequired)

router.GET("/ai_threads", p.handleGetAIThreads)
router.GET("/ai_bots", p.handleGetAIBots)

postRouter := router.Group("/post/:postid")
botRequriedRouter := router.Group("")
botRequriedRouter.Use(p.aiBotRequired)

postRouter := botRequriedRouter.Group("/post/:postid")
postRouter.Use(p.postAuthorizationRequired)
postRouter.POST("/react", p.handleReact)
postRouter.POST("/summarize", p.handleSummarize)
Expand All @@ -32,7 +38,7 @@ func (p *Plugin) ServeHTTP(c *plugin.Context, w http.ResponseWriter, r *http.Req
postRouter.POST("/stop", p.handleStop)
postRouter.POST("/regenerate", p.handleRegenerate)

channelRouter := router.Group("/channel/:channelid")
channelRouter := botRequriedRouter.Group("/channel/:channelid")
channelRouter.Use(p.channelAuthorizationRequired)
channelRouter.POST("/since", p.handleSince)

Expand All @@ -42,6 +48,16 @@ func (p *Plugin) ServeHTTP(c *plugin.Context, w http.ResponseWriter, r *http.Req
router.ServeHTTP(w, r)
}

func (p *Plugin) aiBotRequired(c *gin.Context) {
botUsername := c.DefaultQuery("botUsername", p.getConfiguration().DefaultBotName)
bot := p.GetBotByUsername(botUsername)
if bot == nil {
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to get bot: %s", botUsername))
return
}
c.Set(ContextBotKey, bot)
}

func (p *Plugin) ginlogger(c *gin.Context) {
c.Next()

Expand All @@ -61,23 +77,72 @@ func (p *Plugin) MattermostAuthorizationRequired(c *gin.Context) {
func (p *Plugin) handleGetAIThreads(c *gin.Context) {
userID := c.GetHeader("Mattermost-User-Id")

botDMChannel, err := p.pluginAPI.Channel.GetDirect(userID, p.botid)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("unable to get DM with AI bot: %w", err))
return
p.botsLock.RLock()
defer p.botsLock.RUnlock()
dmChannelIDs := []string{}
for _, bot := range p.bots {
botDMChannel, err := p.pluginAPI.Channel.GetDirect(userID, bot.mmBot.UserId)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("unable to get DM with AI bot: %w", err))
return
}

// Extra permissions checks are not totally nessiary since a user should always have permission to read their own DMs
if !p.pluginAPI.User.HasPermissionToChannel(userID, botDMChannel.Id, model.PermissionReadChannel) {
c.AbortWithError(http.StatusForbidden, errors.New("user doesn't have permission to read channel"))
return
}

dmChannelIDs = append(dmChannelIDs, botDMChannel.Id)
}

// Extra permissions checks are not totally nessiary since a user should always have permission to read their own DMs
if !p.pluginAPI.User.HasPermissionToChannel(userID, botDMChannel.Id, model.PermissionReadChannel) {
c.AbortWithError(http.StatusForbidden, errors.New("user doesn't have permission to read channel"))
threads, err := p.getAIThreads(dmChannelIDs)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to get posts for bot DM: %w", err))
return
}

posts, err := p.getAIThreads(botDMChannel.Id)
c.JSON(http.StatusOK, threads)
}

type AIBotInfo struct {
ID string `json:"id"`
DisplayName string `json:"displayName"`
Username string `json:"username"`
LastIconUpdate int64 `json:"lastIconUpdate"`
DMChannelID string `json:"dmChannelID"`
}

func (p *Plugin) handleGetAIBots(c *gin.Context) {
userID := c.GetHeader("Mattermost-User-Id")

ownedBots, err := p.pluginAPI.Bot.List(0, 1000, pluginapi.BotOwner("mattermost-ai"))
if err != nil {
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to get posts for bot DM: %w", err))
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to get bots: %w", err))
return
}

c.JSON(http.StatusOK, posts)
// Get the info from all the bots.
// Put the default bot first.
bots := make([]AIBotInfo, len(ownedBots))
defaultBotName := p.getConfiguration().DefaultBotName
for i, bot := range ownedBots {
direct, err := p.pluginAPI.Channel.GetDirect(userID, bot.UserId)
if err != nil {
p.API.LogError("unable to get direct channel for bot", "error", err)
continue
}
bots[i] = AIBotInfo{
ID: bot.UserId,
DisplayName: bot.DisplayName,
Username: bot.Username,
LastIconUpdate: bot.LastIconUpdate,
DMChannelID: direct.Id,
}
if bot.Username == defaultBotName {
bots[0], bots[i] = bots[i], bots[0]
}
}

c.JSON(http.StatusOK, bots)
}
7 changes: 4 additions & 3 deletions server/api_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func (p *Plugin) channelAuthorizationRequired(c *gin.Context) {
func (p *Plugin) handleSince(c *gin.Context) {
userID := c.GetHeader("Mattermost-User-Id")
channel := c.MustGet(ContextChannelKey).(*model.Channel)
bot := c.MustGet(ContextBotKey).(*Bot)

if !p.licenseChecker.IsBasicsLicensed() {
c.AbortWithError(http.StatusForbidden, enterprise.ErrNotLicensed)
Expand Down Expand Up @@ -82,7 +83,7 @@ func (p *Plugin) handleSince(c *gin.Context) {

formattedThread := formatThread(threadData)

context := p.MakeConversationContext(user, channel, nil)
context := p.MakeConversationContext(bot, user, channel, nil)
context.PromptParameters = map[string]string{
"Posts": formattedThread,
}
Expand All @@ -108,15 +109,15 @@ func (p *Plugin) handleSince(c *gin.Context) {
return
}

resultStream, err := p.getLLM().ChatCompletion(prompt)
resultStream, err := p.getLLM(bot.cfg.Service).ChatCompletion(prompt)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
return
}

post := &model.Post{}
post.AddProp(NoRegen, "true")
if err := p.streamResultToNewDM(resultStream, user.Id, post); err != nil {
if err := p.streamResultToNewDM(bot.mmBot.UserId, resultStream, user.Id, post); err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
return
}
Expand Down
Loading

0 comments on commit bc89035

Please sign in to comment.