diff --git a/go/ai/prompt.go b/go/ai/prompt.go index c381443d9b..c76b014216 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -182,10 +182,26 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod }, } - actionOpts.Messages, err = renderMessages(ctx, tempOpts, actionOpts.Messages, m, execOpts.Input, p.registry.Dotprompt()) + execMsgs, err := renderMessages(ctx, tempOpts, []*Message{}, m, execOpts.Input, p.registry.Dotprompt()) if err != nil { return nil, err } + + var systemMsgs []*Message + var msgs []*Message + foundNonSystem := false + + for _, msg := range actionOpts.Messages { + if msg.Role == RoleSystem && !foundNonSystem { + systemMsgs = append(systemMsgs, msg) + } else { + foundNonSystem = true + msgs = append(msgs, msg) + } + } + + actionOpts.Messages = append(systemMsgs, execMsgs...) + actionOpts.Messages = append(actionOpts.Messages, msgs...) } toolRefs := execOpts.Tools diff --git a/go/ai/prompt_test.go b/go/ai/prompt_test.go index 9282193a0e..7390d41335 100644 --- a/go/ai/prompt_test.go +++ b/go/ai/prompt_test.go @@ -511,7 +511,7 @@ func TestValidPrompt(t *testing.T) { WithInput(HelloPromptInput{Name: "foo"}), WithMessages(NewModelTextMessage("I remember you said your name is {{Name}}")), }, - wantTextOutput: "Echo: system: say hello; my name is foo; I remember you said your name is foo; config: {\n \"temperature\": 11\n}; context: null", + wantTextOutput: "Echo: system: say hello; I remember you said your name is foo; my name is foo; config: {\n \"temperature\": 11\n}; context: null", wantGenerated: &ModelRequest{ Config: &GenerationCommonConfig{ Temperature: 11, @@ -525,14 +525,14 @@ func TestValidPrompt(t *testing.T) { Role: RoleSystem, Content: []*Part{NewTextPart("say hello")}, }, - { - Role: RoleUser, - Content: []*Part{NewTextPart("my name is foo")}, - }, { Role: RoleModel, Content: []*Part{NewTextPart("I remember you said your name is foo")}, }, + { + Role: RoleUser, + Content: []*Part{NewTextPart("my name is foo")}, + }, }, }, },