Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC: Allow OpenAI to execute arbitrary external actions + no-code implementation with N8N #182

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion server/ai/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ func toolsToOpenAITools(tools []ai.Tool) []openaiClient.Tool {
Anonymous: true,
ExpandedStruct: true,
}

for _, tool := range tools {
schema := schemaMaker.Reflect(tool.Schema)
result = append(result, openaiClient.Tool{
Expand Down Expand Up @@ -186,6 +185,21 @@ func createFunctionArrgmentResolver(jsonArgs string) ai.ToolArgumentGetter {
}
}

func (s *OpenAI) handleStreamFunctionCall(request openaiClient.ChatCompletionRequest, conversation ai.BotConversation, name, arguments string) (openaiClient.ChatCompletionRequest, error) {
fmt.Println("TOOL SELECTED", name, arguments)
toolResult, err := conversation.Tools.ResolveTool(name, createFunctionArrgmentResolver(arguments), conversation.Context)
if err != nil {
fmt.Println("Error resolving function: ", err)
}
request.Messages = append(request.Messages, openaiClient.ChatCompletionMessage{
Role: openaiClient.ChatMessageRoleFunction,
Name: name,
Content: toolResult,
})

return request, nil
}

type ToolBufferElement struct {
id strings.Builder
name strings.Builder
Expand Down
12 changes: 9 additions & 3 deletions server/ai/prompts.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@ import (
"errors"
)

type BuiltInToolsFunc func(isDM bool) []Tool

type Prompts struct {
templates *template.Template
templates *template.Template
getBuiltInTools BuiltInToolsFunc
getThirdPartyTools BuiltInToolsFunc
}

const PromptExtension = "tmpl"
Expand All @@ -37,14 +41,16 @@ const (
PromptFindOpenQuestionsSince = "find_open_questions_since"
)

func NewPrompts(input fs.FS) (*Prompts, error) {
func NewPrompts(input fs.FS, getBuiltInTools, getThirdPartyTools BuiltInToolsFunc) (*Prompts, error) {
templates, err := template.ParseFS(input, "ai/prompts/*")
if err != nil {
return nil, fmt.Errorf("unable to parse prompt templates: %w", err)
}

return &Prompts{
templates: templates,
templates: templates,
getBuiltInTools: getBuiltInTools,
getThirdPartyTools: getThirdPartyTools,
}, nil
}

Expand Down
17 changes: 11 additions & 6 deletions server/ai/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import (
)

type Tool struct {
Name string
Description string
Schema any
Resolver func(context ConversationContext, argsGetter ToolArgumentGetter) (string, error)
Name string
Description string
Schema any
IsRawMessage bool
HTTPMethod string
Resolver func(name string, context ConversationContext, argsGetter ToolArgumentGetter) (string, error)
}

type ToolArgumentGetter func(args any) error
Expand Down Expand Up @@ -52,8 +54,11 @@ func (s *ToolStore) ResolveTool(name string, argsGetter ToolArgumentGetter, cont
s.TraceUnknown(name, argsGetter)
return "", errors.New("unknown tool " + name)
}
results, err := tool.Resolver(context, argsGetter)
s.TraceResolved(name, argsGetter, results)
if tool.Resolver == nil {
return "", errors.New("Tool resolver IS NIL")
}
results, err := tool.Resolver(name, context, argsGetter)
s.TraceResolved(name, argsGetter, results)
return results, err
}

Expand Down
8 changes: 4 additions & 4 deletions server/built_in_tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type LookupMattermostUserArgs struct {
Username string `jsonschema_description:"The username of the user to lookup without a leading '@'. Example: 'firstname.lastname'"`
}

func (p *Plugin) toolResolveLookupMattermostUser(context ai.ConversationContext, argsGetter ai.ToolArgumentGetter) (string, error) {
func (p *Plugin) toolResolveLookupMattermostUser(_ string, context ai.ConversationContext, argsGetter ai.ToolArgumentGetter) (string, error) {
var args LookupMattermostUserArgs
err := argsGetter(&args)
if err != nil {
Expand Down Expand Up @@ -85,7 +85,7 @@ type GetChannelPosts struct {
NumberPosts int `jsonschema_description:"The number of most recent posts to get. Example: '30'"`
}

func (p *Plugin) toolResolveGetChannelPosts(context ai.ConversationContext, argsGetter ai.ToolArgumentGetter) (string, error) {
func (p *Plugin) toolResolveGetChannelPosts(_ string, context ai.ConversationContext, argsGetter ai.ToolArgumentGetter) (string, error) {
var args GetChannelPosts
err := argsGetter(&args)
if err != nil {
Expand Down Expand Up @@ -143,7 +143,7 @@ func formatGithubIssue(issue *github.Issue) string {

var validGithubRepoName = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`)

func (p *Plugin) toolGetGithubIssue(context ai.ConversationContext, argsGetter ai.ToolArgumentGetter) (string, error) {
func (p *Plugin) toolGetGithubIssue(_ string, context ai.ConversationContext, argsGetter ai.ToolArgumentGetter) (string, error) {
var args GetGithubIssueArgs
err := argsGetter(&args)
if err != nil {
Expand Down Expand Up @@ -359,7 +359,7 @@ func (p *Plugin) getPublicJiraIssues(instanceURL string, issueKeys []string) ([]
return &issue, nil
}*/

func (p *Plugin) toolGetJiraIssue(context ai.ConversationContext, argsGetter ai.ToolArgumentGetter) (string, error) {
func (p *Plugin) toolGetJiraIssue(_ string, context ai.ConversationContext, argsGetter ai.ToolArgumentGetter) (string, error) {
var args GetJiraIssueArgs
err := argsGetter(&args)
if err != nil {
Expand Down
10 changes: 6 additions & 4 deletions server/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"reflect"

"github.com/mattermost/mattermost-plugin-ai/server/ai"
"github.com/mattermost/mattermost-plugin-ai/server/tools"
)

type Config struct {
Expand All @@ -14,10 +15,11 @@ type Config struct {
TranscriptGenerator string `json:"transcriptBackend"`
EnableLLMTrace bool `json:"enableLLMTrace"`

EnableUseRestrictions bool `json:"enableUserRestrictions"`
AllowPrivateChannels bool `json:"allowPrivateChannels"`
AllowedTeamIDs string `json:"allowedTeamIDs"`
OnlyUsersOnTeam string `json:"onlyUsersOnTeam"`
EnableUseRestrictions bool `json:"enableUserRestrictions"`
AllowPrivateChannels bool `json:"allowPrivateChannels"`
AllowedTeamIDs string `json:"allowedTeamIDs"`
OnlyUsersOnTeam string `json:"onlyUsersOnTeam"`
ExternalTools []tools.ToolGetterConfig `json:"ExternalTools"`
}

// configuration captures the plugin's external configuration as exposed in the Mattermost server
Expand Down
52 changes: 52 additions & 0 deletions server/external_tools.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package main

import (
"fmt"

"github.com/mattermost/mattermost-plugin-ai/server/ai"
"github.com/mattermost/mattermost-plugin-ai/server/tools/n8n"
"github.com/mattermost/mattermost-plugin-ai/server/tools/superface"
"github.com/mattermost/mattermost-plugin-ai/server/tools/zapier"
)

func (p *Plugin) getThirdPartyTools(isDM bool) []ai.Tool {
thirdPartyTools := []ai.Tool{}

config := p.getConfiguration()

if len(config.ExternalTools) == 0 {
return thirdPartyTools
}

for _, tool := range config.ExternalTools {
switch tool.Provider {
case "superface":
getter := superface.New(tool.URL, tool.AuthToken)
tools, err := getter.ListTools("")
if err != nil {
// handle
fmt.Println(fmt.Errorf("error occurred fetching tools from superface: %w", err))
}
thirdPartyTools = append(thirdPartyTools, tools...)
case "zapier":
// Haven't actually gotten this one working yet
getter := zapier.New(tool.URL, tool.AuthToken)
tools, err := getter.ListTools("")
if err != nil {
// handle
fmt.Println(fmt.Errorf("error occurred fetching tools from zapier", err))

Check failure on line 37 in server/external_tools.go

View workflow job for this annotation

GitHub Actions / plugin-ci / test

fmt.Errorf call has arguments but no formatting directives

Check failure on line 37 in server/external_tools.go

View workflow job for this annotation

GitHub Actions / plugin-ci / test

fmt.Errorf call has arguments but no formatting directives
}
thirdPartyTools = append(thirdPartyTools, tools...)
case "n8n":
getter := n8n.New(tool.URL, tool.AuthToken)
tools, err := getter.ListTools("")
if err != nil {
fmt.Println(fmt.Errorf("error occurred fetching tools from n8n %w", err))
}
thirdPartyTools = append(thirdPartyTools, tools...)
}

}

return thirdPartyTools
}
4 changes: 2 additions & 2 deletions server/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (p *Plugin) OnActivate() error {

p.metricsService = metrics.NewMetrics(metrics.InstanceInfo{
InstallationID: os.Getenv("MM_CLOUD_INSTALLATION_ID"),
PluginVersion: manifest.Version,
PluginVersion: "1.0.0",
})
p.metricsHandler = metrics.NewMetricsHandler(p.GetMetrics())

Expand All @@ -112,7 +112,7 @@ func (p *Plugin) OnActivate() error {
}

var err error
p.prompts, err = ai.NewPrompts(promptsFolder)
p.prompts, err = ai.NewPrompts(promptsFolder, p.getBuiltInTools, p.getThirdPartyTools)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion server/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func SetupTestEnvironment(t *testing.T) *TestEnvironment {
}

var promptErr error
p.prompts, promptErr = ai.NewPrompts(promptsFolder)
p.prompts, promptErr = ai.NewPrompts(promptsFolder, p.getBuiltInTools, p.getThirdPartyTools)
require.NoError(t, promptErr)

p.ffmpegPath = ""
Expand Down
14 changes: 14 additions & 0 deletions server/tools/interfaces.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package tools

import "github.com/mattermost/mattermost-plugin-ai/server/ai"

type ToolGetterConfig struct {
Provider string
URL string
AuthToken string
}

type ToolGetter interface {
ListTools(userID string) ([]ai.Tool, error)
Perform(userID string, functionName string, arguments any) (any, error)
}
135 changes: 135 additions & 0 deletions server/tools/n8n/n8n.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package n8n

import (
"bytes"
"encoding/json"
"fmt"
"net/http"

"github.com/mattermost/mattermost-plugin-ai/server/ai"
)

type N8N struct {
n8nURL string
authToken string
HTTPClient *http.Client
}

type PerformResponse struct {
Status string `json:"status"`
AssistantHint string `json:"assistant_hint"`
Result any `json:"result"`
CopilotSpecificData any `json:"copilot_specific_data"`
}

func New(url, authToken string) *N8N {
return &N8N{
n8nURL: url,
authToken: authToken,
HTTPClient: &http.Client{},
}
}

func (n *N8N) Resolver(httpVerb string, functionName string, context ai.ConversationContext, argsGetter ai.ToolArgumentGetter) (string, error) {
var args any
performResolverErr := argsGetter(&args)
if performResolverErr != nil {
return "", performResolverErr
}
toolPerformResult, performResolverErr := n.Perform(httpVerb, context.RequestingUser.Id, functionName, args)
if performResolverErr != nil {
return "", performResolverErr
}
return toolPerformResult.ToString()
}

func (n *N8N) ListTools(userID string) ([]ai.Tool, error) {
result := N8NListResponse{}

resp, err := n.do(http.MethodGet, "/api/v1/workflows?active=true", userID, nil)
if err != nil {
return nil, err
}

defer resp.Body.Close()

err = json.NewDecoder(resp.Body).Decode(&result)
if err != nil {
return nil, err
}

if len(result.Tools) == 0 {
return nil, nil
}

n8nTools := []ai.Tool{}
for _, tool := range result.Tools {
if !tool.Active {
continue
}
n8nTool := tool.ToMattermostAITool()

n8nTool.Resolver = func(functionName string, context ai.ConversationContext, argsGetter ai.ToolArgumentGetter) (string, error) {
return n.Resolver(n8nTool.HTTPMethod, functionName, context, argsGetter)
}
n8nTools = append(n8nTools, n8nTool)
}

return n8nTools, nil
}

func (n *N8N) Perform(httpVerb, userID, functionName string, arguments any) (*PerformResponse, error) {
var result PerformResponse

fmt.Println("PERFORMING N8N ACTION", functionName)
resp, err := n.do(httpVerb, fmt.Sprintf("/webhook/%s", functionName), userID, arguments)
if err != nil {
return nil, err
}

defer resp.Body.Close()

err = json.NewDecoder(resp.Body).Decode(&result)
if err != nil {
return nil, err
}

fmt.Println("N8N RESULT OBJECT:", result)
result.CopilotSpecificData = nil

return &result, nil
}

func (s *N8N) do(method, path, userID string, body interface{}) (*http.Response, error) {
var req *http.Request
fullPath := s.n8nURL + path
if body != nil {
jsonBody, err := json.Marshal(body)
if err != nil {
return nil, err
}
fmt.Println("SENDING BODY:", string(jsonBody))
bodyBuffer := bytes.NewBuffer(jsonBody)

req, err = http.NewRequest(method, fullPath, bodyBuffer)
if err != nil {
return nil, err
}
} else {
var err error
req, err = http.NewRequest(method, fullPath, nil)
if err != nil {
return nil, err
}
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-N8N-API-KEY", s.authToken)

resp, err := s.HTTPClient.Do(req)
if err != nil {
return nil, err
}

return resp, nil
}
Loading
Loading