Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,54 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {
}

key := promptKey(name, variant, namespace)
prompt := DefinePrompt(r, key, opts, WithPrompt(parsedPrompt.Template))

dpMessages, err := dotprompt.ToMessages(parsedPrompt.Template, &dotprompt.DataArgument{})
if err != nil {
slog.Error("Failed to convert prompt template to messages", "file", sourceFile, "error", err)
return nil
}

var systemText string
var nonSystemMessages []*Message
for _, dpMsg := range dpMessages {
parts, err := convertToPartPointers(dpMsg.Content)
if err != nil {
slog.Error("Failed to convert message parts", "file", sourceFile, "error", err)
return nil
}

role := Role(dpMsg.Role)
if role == RoleSystem {
var textParts []string
for _, part := range parts {
if part.IsText() {
textParts = append(textParts, part.Text)
}
}

if len(textParts) > 0 {
systemText = strings.Join(textParts, " ")
}
} else {
nonSystemMessages = append(nonSystemMessages, &Message{Role: role, Content: parts})
}
}

promptOpts := []PromptOption{opts}

// Add system prompt if found
if systemText != "" {
promptOpts = append(promptOpts, WithSystem(systemText))
}

// If there are non-system messages, use WithMessages, otherwise use WithPrompt for template
if len(nonSystemMessages) > 0 {
promptOpts = append(promptOpts, WithMessages(nonSystemMessages...))
} else if systemText == "" {
promptOpts = append(promptOpts, WithPrompt(parsedPrompt.Template))
}

prompt := DefinePrompt(r, key, promptOpts...)

slog.Debug("Registered Dotprompt", "name", key, "file", sourceFile)

Expand Down
61 changes: 61 additions & 0 deletions go/ai/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1148,3 +1148,64 @@ Hello!
t.Fatalf("Failed to execute prompt: %v", err)
}
}

func TestMultiMessagesRenderPrompt(t *testing.T) {
tempDir := t.TempDir()

mockPromptFile := filepath.Join(tempDir, "example.prompt")
mockPromptContent := `---
model: test/chat
description: A test prompt
---
<<<dotprompt:role:system>>>
You are a pirate!

<<<dotprompt:role:user>>>
Hello!
`

if err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0644); err != nil {
t.Fatalf("Failed to create mock prompt file: %v", err)
}

prompt := LoadPrompt(registry.New(), tempDir, "example.prompt", "multi-namespace-roles")

actionOpts, err := prompt.Render(context.Background(), map[string]any{})
if err != nil {
t.Fatalf("Failed to execute prompt: %v", err)
}

// Check that actionOpts is not nil
if actionOpts == nil {
t.Fatal("Expected actionOpts to be non-nil")
}

// Check that we have exactly 2 messages (system and user)
if len(actionOpts.Messages) != 2 {
t.Fatalf("Expected 2 messages, got %d", len(actionOpts.Messages))
}

// Check first message (system role)
systemMsg := actionOpts.Messages[0]
if systemMsg.Role != RoleSystem {
t.Errorf("Expected first message role to be 'system', got '%s'", systemMsg.Role)
}
if len(systemMsg.Content) == 0 {
t.Fatal("Expected system message to have content")
}
if strings.TrimSpace(systemMsg.Content[0].Text) != "You are a pirate!" {
t.Errorf("Expected system message text to be 'You are a pirate!', got '%s'", systemMsg.Content[0].Text)
}

// Check second message (user role)
userMsg := actionOpts.Messages[1]
if userMsg.Role != RoleUser {
t.Errorf("Expected second message role to be 'user', got '%s'", userMsg.Role)
}
if len(userMsg.Content) == 0 {
t.Fatal("Expected user message to have content")
}
if strings.TrimSpace(userMsg.Content[0].Text) != "Hello!" {
t.Errorf("Expected user message text to be 'Hello!', got '%s'", userMsg.Content[0].Text)
}
}
Loading