Skip to content

Commit 532e5ed

Browse files
author
mpowers5
committed
feat: Add selectors to frontmatter for tasks.
For Example: ```markdown --- task_name: my_task selectors: language: go env: prod --- ``` Would be the same as: ``` -s language=go -s env=prod ```
1 parent c67575b commit 532e5ed

File tree

2 files changed

+195
-17
lines changed

2 files changed

+195
-17
lines changed

main.go

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ type codingContext struct {
2626

2727
downloadedDirs []string
2828
matchingTaskFile string
29+
taskFrontmatter frontMatter // Parsed task frontmatter
30+
taskContent string // Parsed task content (before parameter expansion)
2931
totalTokens int
3032
output io.Writer
3133
logOut io.Writer
@@ -101,12 +103,17 @@ func (cc *codingContext) run(ctx context.Context, args []string) error {
101103
return fmt.Errorf("failed to find task file: %w", err)
102104
}
103105

106+
// Parse task file early to extract selector labels for filtering rules and tools
107+
if err := cc.parseTaskFile(); err != nil {
108+
return fmt.Errorf("failed to parse task file: %w", err)
109+
}
110+
104111
if err := cc.findExecuteRuleFiles(ctx, homeDir); err != nil {
105112
return fmt.Errorf("failed to find and execute rule files: %w", err)
106113
}
107114

108-
if err := cc.writeTaskFileContent(); err != nil {
109-
return fmt.Errorf("failed to write task file content: %w", err)
115+
if err := cc.emitTaskFileContent(); err != nil {
116+
return fmt.Errorf("failed to emit task file content: %w", err)
110117
}
111118

112119
return nil
@@ -284,15 +291,48 @@ func (cc *codingContext) runBootstrapScript(ctx context.Context, path, ext strin
284291
return nil
285292
}
286293

287-
func (cc *codingContext) writeTaskFileContent() error {
288-
taskMatter := make(map[string]any)
294+
// parseTaskFile parses the task file and extracts selector labels from frontmatter.
295+
// The selectors are added to cc.includes for filtering rules and tools.
296+
// The parsed frontmatter and content are stored in cc.taskFrontmatter and cc.taskContent.
297+
func (cc *codingContext) parseTaskFile() error {
298+
cc.taskFrontmatter = make(frontMatter)
289299

290-
content, err := parseMarkdownFile(cc.matchingTaskFile, &taskMatter)
300+
content, err := parseMarkdownFile(cc.matchingTaskFile, &cc.taskFrontmatter)
291301
if err != nil {
292-
return fmt.Errorf("failed to parse prompt file %s: %w", cc.matchingTaskFile, err)
302+
return fmt.Errorf("failed to parse task file %s: %w", cc.matchingTaskFile, err)
303+
}
304+
305+
cc.taskContent = content
306+
307+
// Extract selector labels from frontmatter
308+
// Look for a "selectors" field that contains a map of key-value pairs
309+
if selectorsRaw, ok := cc.taskFrontmatter["selectors"]; ok {
310+
selectorsMap, ok := selectorsRaw.(map[string]any)
311+
if !ok {
312+
// Try to handle it as a map[interface{}]interface{} (common YAML unmarshal result)
313+
if selectorsMapAny, ok := selectorsRaw.(map[any]any); ok {
314+
selectorsMap = make(map[string]any)
315+
for k, v := range selectorsMapAny {
316+
selectorsMap[fmt.Sprint(k)] = v
317+
}
318+
} else {
319+
return fmt.Errorf("task file %s has invalid 'selectors' field: expected map, got %T", cc.matchingTaskFile, selectorsRaw)
320+
}
321+
}
322+
323+
// Add selectors to includes
324+
for key, value := range selectorsMap {
325+
cc.includes[key] = fmt.Sprint(value)
326+
}
293327
}
294328

295-
expanded := os.Expand(content, func(key string) string {
329+
return nil
330+
}
331+
332+
// emitTaskFileContent emits the parsed task content to the output.
333+
// It expands parameters, estimates tokens, and optionally includes frontmatter.
334+
func (cc *codingContext) emitTaskFileContent() error {
335+
expanded := os.Expand(cc.taskContent, func(key string) string {
296336
if val, ok := cc.params[key]; ok {
297337
return val
298338
}
@@ -307,7 +347,7 @@ func (cc *codingContext) writeTaskFileContent() error {
307347

308348
if cc.emitTaskFrontmatter {
309349
fmt.Fprintln(cc.output, "---")
310-
if err := yaml.NewEncoder(cc.output).Encode(taskMatter); err != nil {
350+
if err := yaml.NewEncoder(cc.output).Encode(cc.taskFrontmatter); err != nil {
311351
return fmt.Errorf("failed to encode task matter: %w", err)
312352
}
313353
fmt.Fprintln(cc.output, "---")

main_test.go

Lines changed: 147 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -689,45 +689,183 @@ func TestWriteTaskFileContent(t *testing.T) {
689689
emitTaskFrontmatter: tt.emitTaskFrontmatter,
690690
output: &output,
691691
logOut: &logOut,
692+
includes: make(selectorMap),
692693
}
693694

694-
err := cc.writeTaskFileContent()
695+
// Parse task file first
696+
if err := cc.parseTaskFile(); err != nil {
697+
if !tt.wantErr {
698+
t.Errorf("parseTaskFile() unexpected error: %v", err)
699+
}
700+
return
701+
}
702+
703+
// Then emit the content
704+
err := cc.emitTaskFileContent()
695705

696706
if tt.wantErr {
697707
if err == nil {
698-
t.Errorf("writeTaskFileContent() expected error, got nil")
708+
t.Errorf("emitTaskFileContent() expected error, got nil")
699709
}
700710
} else {
701711
if err != nil {
702-
t.Errorf("writeTaskFileContent() unexpected error: %v", err)
712+
t.Errorf("emitTaskFileContent() unexpected error: %v", err)
703713
}
704714
}
705715

706716
outputStr := output.String()
707717
if tt.expectInOutput != "" {
708718
if !strings.Contains(outputStr, tt.expectInOutput) {
709-
t.Errorf("writeTaskFileContent() output should contain %q, got:\n%s", tt.expectInOutput, outputStr)
719+
t.Errorf("emitTaskFileContent() output should contain %q, got:\n%s", tt.expectInOutput, outputStr)
710720
}
711721
}
712722

713723
// Additional checks for frontmatter emission
714724
if tt.emitTaskFrontmatter {
715725
// Verify frontmatter delimiters are present
716726
if !strings.Contains(outputStr, "---") {
717-
t.Errorf("writeTaskFileContent() with emitTaskFrontmatter=true should contain '---' delimiters, got:\n%s", outputStr)
727+
t.Errorf("emitTaskFileContent() with emitTaskFrontmatter=true should contain '---' delimiters, got:\n%s", outputStr)
718728
}
719729
// Verify YAML frontmatter structure
720730
if !strings.Contains(outputStr, "task_name:") {
721-
t.Errorf("writeTaskFileContent() with emitTaskFrontmatter=true should contain 'task_name:' field, got:\n%s", outputStr)
731+
t.Errorf("emitTaskFileContent() with emitTaskFrontmatter=true should contain 'task_name:' field, got:\n%s", outputStr)
722732
}
723733
// Verify task content is still present
724734
if !strings.Contains(outputStr, "# Task with Frontmatter") {
725-
t.Errorf("writeTaskFileContent() should contain task content, got:\n%s", outputStr)
735+
t.Errorf("emitTaskFileContent() should contain task content, got:\n%s", outputStr)
726736
}
727737
}
728738

729739
if !tt.wantErr && cc.totalTokens <= 0 {
730-
t.Errorf("writeTaskFileContent() expected tokens > 0, got %d", cc.totalTokens)
740+
t.Errorf("emitTaskFileContent() expected tokens > 0, got %d", cc.totalTokens)
741+
}
742+
})
743+
}
744+
}
745+
746+
func TestParseTaskFile(t *testing.T) {
747+
tests := []struct {
748+
name string
749+
taskFile string
750+
setupFiles func(t *testing.T, tmpDir string) string // returns task file path
751+
initialIncludes selectorMap
752+
expectedIncludes selectorMap // expected includes after parsing
753+
wantErr bool
754+
errContains string
755+
}{
756+
{
757+
name: "task without selectors field",
758+
taskFile: "task.md",
759+
initialIncludes: make(selectorMap),
760+
expectedIncludes: make(selectorMap),
761+
setupFiles: func(t *testing.T, tmpDir string) string {
762+
taskPath := filepath.Join(tmpDir, "task.md")
763+
createMarkdownFile(t, taskPath,
764+
"task_name: test",
765+
"# Simple Task")
766+
return taskPath
767+
},
768+
wantErr: false,
769+
},
770+
{
771+
name: "task with selectors field",
772+
taskFile: "task.md",
773+
initialIncludes: make(selectorMap),
774+
expectedIncludes: selectorMap{
775+
"language": "Go",
776+
"env": "prod",
777+
},
778+
setupFiles: func(t *testing.T, tmpDir string) string {
779+
taskPath := filepath.Join(tmpDir, "task.md")
780+
createMarkdownFile(t, taskPath,
781+
"task_name: test\nselectors:\n language: Go\n env: prod",
782+
"# Task with Selectors")
783+
return taskPath
784+
},
785+
wantErr: false,
786+
},
787+
{
788+
name: "task with selectors merges with existing includes",
789+
taskFile: "task.md",
790+
initialIncludes: selectorMap{"existing": "value"},
791+
expectedIncludes: selectorMap{
792+
"existing": "value",
793+
"language": "Python",
794+
},
795+
setupFiles: func(t *testing.T, tmpDir string) string {
796+
taskPath := filepath.Join(tmpDir, "task.md")
797+
createMarkdownFile(t, taskPath,
798+
"task_name: test\nselectors:\n language: Python",
799+
"# Task with Selectors")
800+
return taskPath
801+
},
802+
wantErr: false,
803+
},
804+
{
805+
name: "task with invalid selectors field type",
806+
taskFile: "task.md",
807+
initialIncludes: make(selectorMap),
808+
setupFiles: func(t *testing.T, tmpDir string) string {
809+
taskPath := filepath.Join(tmpDir, "task.md")
810+
createMarkdownFile(t, taskPath,
811+
"task_name: test\nselectors: invalid",
812+
"# Task with Invalid Selectors")
813+
return taskPath
814+
},
815+
wantErr: true,
816+
errContains: "invalid 'selectors' field",
817+
},
818+
}
819+
820+
for _, tt := range tests {
821+
t.Run(tt.name, func(t *testing.T) {
822+
tmpDir := t.TempDir()
823+
taskPath := tt.setupFiles(t, tmpDir)
824+
825+
cc := &codingContext{
826+
matchingTaskFile: taskPath,
827+
includes: tt.initialIncludes,
828+
}
829+
if cc.includes == nil {
830+
cc.includes = make(selectorMap)
831+
}
832+
833+
err := cc.parseTaskFile()
834+
835+
if tt.wantErr {
836+
if err == nil {
837+
t.Errorf("parseTaskFile() expected error, got nil")
838+
} else if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
839+
t.Errorf("parseTaskFile() error = %v, should contain %q", err, tt.errContains)
840+
}
841+
} else {
842+
if err != nil {
843+
t.Errorf("parseTaskFile() unexpected error: %v", err)
844+
}
845+
846+
// Verify selectors were extracted correctly
847+
for key, expectedValue := range tt.expectedIncludes {
848+
if actualValue, ok := cc.includes[key]; !ok {
849+
t.Errorf("parseTaskFile() expected includes[%q] = %q, but key not found", key, expectedValue)
850+
} else if actualValue != expectedValue {
851+
t.Errorf("parseTaskFile() includes[%q] = %q, want %q", key, actualValue, expectedValue)
852+
}
853+
}
854+
855+
// Verify all includes match expected (including initial includes)
856+
if len(cc.includes) != len(tt.expectedIncludes) {
857+
t.Errorf("parseTaskFile() includes length = %d, want %d. Includes: %v", len(cc.includes), len(tt.expectedIncludes), cc.includes)
858+
}
859+
860+
// Verify task content was stored
861+
if cc.taskContent == "" {
862+
t.Errorf("parseTaskFile() expected taskContent to be set, got empty string")
863+
}
864+
865+
// Verify task frontmatter was stored
866+
if cc.taskFrontmatter == nil {
867+
t.Errorf("parseTaskFile() expected taskFrontmatter to be set, got nil")
868+
}
731869
}
732870
})
733871
}
@@ -979,4 +1117,4 @@ func (f *fileInfoMock) Size() int64 { return 0 }
9791117
func (f *fileInfoMock) Mode() os.FileMode { return 0o644 }
9801118
func (f *fileInfoMock) ModTime() time.Time { return time.Time{} }
9811119
func (f *fileInfoMock) IsDir() bool { return f.isDir }
982-
func (f *fileInfoMock) Sys() interface{} { return nil }
1120+
func (f *fileInfoMock) Sys() any { return nil }

0 commit comments

Comments
 (0)