Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions pkg/workflow/compiler_jobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,9 @@ func (c *Compiler) buildCustomJobs(data *WorkflowData, activationJobCreated bool
// don't inherit GITHUB_ENV from the agent job, so the gh CLI won't
// know which host to target without this step.
job.Steps = append(job.Steps, generateGHESHostConfigurationStep())
if shouldInjectNodeSetupForGPUCustomJob(configMap) {
job.Steps = append(job.Steps, generateNodeSetupStepForCustomJob(data))
}
job.Steps = append(job.Steps, preSteps...)
job.Steps = append(job.Steps, regularSteps...)
}
Expand All @@ -824,6 +827,91 @@ func (c *Compiler) buildCustomJobs(data *WorkflowData, activationJobCreated bool
return nil
}

func shouldInjectNodeSetupForGPUCustomJob(configMap map[string]any) bool {
if configMap == nil {
return false
}
runsOn, hasRunsOn := configMap["runs-on"]
if !hasRunsOn || !containsRunnerLabel(runsOn, "aw-gpu-runner-T4") {
return false
}
return !jobStepsContainSetupNode(configMap)
}

func containsRunnerLabel(value any, target string) bool {
switch v := value.(type) {
case string:
return strings.EqualFold(strings.TrimSpace(v), target)
case []any:
for _, item := range v {
if containsRunnerLabel(item, target) {
return true
}
}
case map[string]any:
// Support object-form runs-on (e.g. {group: "...", labels: ["aw-gpu-runner-T4"]}).
// GitHub Actions allows object form for larger/self-hosted runners.
for _, item := range v {
if containsRunnerLabel(item, target) {
return true
}
}
}
return false
}

func jobStepsContainSetupNode(configMap map[string]any) bool {
for _, fieldName := range []string{"pre-steps", "steps"} {
fieldValue, hasField := configMap[fieldName]
if !hasField {
continue
}
steps, ok := fieldValue.([]any)
if !ok {
continue
}
for _, step := range steps {
stepMap, ok := step.(map[string]any)
if !ok {
continue
}
usesValue, hasUses := stepMap["uses"]
if !hasUses {
continue
}
usesStr, ok := usesValue.(string)
if !ok {
continue
}
if strings.Contains(strings.ToLower(usesStr), "setup-node") {
return true
}
}
}
return false
}

func generateNodeSetupStepForCustomJob(data *WorkflowData) string {
requirements := map[string]*RuntimeRequirement{}
nodeRuntime := findRuntimeByID("node")
if nodeRuntime == nil {
compilerJobsLog.Print("Node runtime definition not found; skipping Node setup step injection for GPU custom job")
return ""
}

updateRequiredRuntime(nodeRuntime, "", requirements)
if data != nil && data.Runtimes != nil {
applyRuntimeOverrides(data.Runtimes, requirements)
}
nodeRequirement, exists := requirements["node"]
if !exists {
compilerJobsLog.Print("Node runtime requirement missing after overrides; skipping Node setup step injection for GPU custom job")
return ""
}

return strings.Join(generateSetupStep(nodeRequirement), "\n") + "\n"
}

func (c *Compiler) applyBuiltinJobPreSteps(data *WorkflowData) error {
if data == nil || data.Jobs == nil {
return nil
Expand Down
117 changes: 117 additions & 0 deletions pkg/workflow/compiler_jobs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2576,6 +2576,123 @@ func TestBuildCustomJobsRunsOnForms(t *testing.T) {
}
}

func TestBuildCustomJobsAddsNodeSetupForAWGPURunner(t *testing.T) {
compiler := NewCompiler()
compiler.jobManager = NewJobManager()

data := &WorkflowData{
Name: "Test Workflow",
AI: "copilot",
Jobs: map[string]any{
"gpu_job": map[string]any{
"runs-on": "aw-gpu-runner-T4",
"steps": []any{
map[string]any{"name": "Work", "run": "echo hi"},
},
},
},
}

err := compiler.buildCustomJobs(data, false)
if err != nil {
t.Fatalf("buildCustomJobs() returned unexpected error: %v", err)
}

job, exists := compiler.jobManager.GetJob("gpu_job")
if !exists {
t.Fatal("Expected gpu_job to be added")
}

stepsContent := strings.Join(job.Steps, "")
if !strings.Contains(stepsContent, "name: Setup Node.js") {
t.Fatalf("Expected custom GPU job to include Node setup step, got:\n%s", stepsContent)
}
if !strings.Contains(stepsContent, "node-version: '24'") {
t.Fatalf("Expected default Node 24 for custom GPU job, got:\n%s", stepsContent)
}
}

func TestBuildCustomJobsAddsNodeSetupForAWGPURunnerWithRuntimeOverride(t *testing.T) {
compiler := NewCompiler()
compiler.jobManager = NewJobManager()

data := &WorkflowData{
Name: "Test Workflow",
AI: "copilot",
Runtimes: map[string]any{
"node": map[string]any{
"version": "20",
},
},
Jobs: map[string]any{
"gpu_job": map[string]any{
"runs-on": "aw-gpu-runner-T4",
"steps": []any{
map[string]any{"name": "Work", "run": "echo hi"},
},
},
},
}

err := compiler.buildCustomJobs(data, false)
if err != nil {
t.Fatalf("buildCustomJobs() returned unexpected error: %v", err)
}

job, exists := compiler.jobManager.GetJob("gpu_job")
if !exists {
t.Fatal("Expected gpu_job to be added")
}

stepsContent := strings.Join(job.Steps, "")
if !strings.Contains(stepsContent, "name: Setup Node.js") {
t.Fatalf("Expected custom GPU job to include Node setup step, got:\n%s", stepsContent)
}
if !strings.Contains(stepsContent, "node-version: '20'") {
t.Fatalf("Expected Node version to respect runtime override for custom GPU job, got:\n%s", stepsContent)
}
}

func TestBuildCustomJobsSkipsNodeSetupForAWGPURunnerWhenAlreadyPresent(t *testing.T) {
compiler := NewCompiler()
compiler.jobManager = NewJobManager()

data := &WorkflowData{
Name: "Test Workflow",
AI: "copilot",
Jobs: map[string]any{
"gpu_job": map[string]any{
"runs-on": "aw-gpu-runner-T4",
"steps": []any{
map[string]any{
"name": "Setup Node.js (manual)",
"uses": "actions/setup-node@v6",
"with": map[string]any{
"node-version": "24",
},
},
map[string]any{"name": "Work", "run": "echo hi"},
},
},
},
}

err := compiler.buildCustomJobs(data, false)
if err != nil {
t.Fatalf("buildCustomJobs() returned unexpected error: %v", err)
}

job, exists := compiler.jobManager.GetJob("gpu_job")
if !exists {
t.Fatal("Expected gpu_job to be added")
}

stepsContent := strings.Join(job.Steps, "")
if strings.Count(stepsContent, "uses: actions/setup-node@") != 1 {
t.Fatalf("Expected exactly one setup-node action step when already present, got:\n%s", stepsContent)
}
}

// TestBuildCustomJobsNewSimpleFields tests extraction of simple job fields via CompileWorkflow
func TestBuildCustomJobsNewSimpleFields(t *testing.T) {
tmpDir := testutil.TempDir(t, "new-simple-fields-test")
Expand Down