diff --git a/pkg/workflow/safe_jobs_needs_validation.go b/pkg/workflow/safe_jobs_needs_validation.go index a225b74f4dc..f7971a3ffac 100644 --- a/pkg/workflow/safe_jobs_needs_validation.go +++ b/pkg/workflow/safe_jobs_needs_validation.go @@ -135,9 +135,8 @@ func consolidatedSafeOutputsJobWillExist(safeOutputs *SafeOutputsConfig) bool { if len(safeOutputs.Scripts) > 0 || len(safeOutputs.Actions) > 0 || len(safeOutputs.Steps) > 0 { return true } - // Reuse the existing reflection-based check with the dynamic fields cleared. - // hasAnySafeOutputEnabled will then fall through to reflection over safeOutputFieldMapping, - // which covers every builtin pointer type (create-issue, add-comment, etc.). + // Reuse the direct-check function with the dynamic fields cleared. + // hasAnySafeOutputEnabled covers every builtin pointer type (create-issue, add-comment, etc.). stripped := *safeOutputs stripped.Jobs = nil stripped.Scripts = nil diff --git a/pkg/workflow/safe_outputs_state.go b/pkg/workflow/safe_outputs_state.go index 2d6343c7874..0f9d827502d 100644 --- a/pkg/workflow/safe_outputs_state.go +++ b/pkg/workflow/safe_outputs_state.go @@ -2,9 +2,6 @@ package workflow import ( "fmt" - "reflect" - - "github.com/github/gh-aw/pkg/logger" ) // ======================================== @@ -12,14 +9,21 @@ import ( // ======================================== // // This file contains functions for querying, inspecting, and validating the -// state of a SafeOutputsConfig. It uses reflection to check which tool types -// are enabled without requiring a large switch statement. - -var safeOutputReflectionLog = logger.New("workflow:safe_outputs_config_helpers_reflection") - -// safeOutputFieldMapping maps struct field names to their tool names. -// This map drives reflection-based checks across hasAnySafeOutputEnabled, -// getEnabledSafeOutputToolNamesReflection, and hasNonBuiltinSafeOutputsEnabled. +// state of a SafeOutputsConfig. hasAnySafeOutputEnabled and +// hasNonBuiltinSafeOutputsEnabled use direct nil-checks instead of reflection +// for performance (these functions are called on every compilation). +// +// NOTE: When adding a new pointer field to SafeOutputsConfig that represents +// a user-facing safe output action, add it to ALL of the following locations: +// 1. safeOutputFieldMapping (below) — drives imports, prompt generation, etc. +// 2. hasAnySafeOutputEnabled — performance-critical hot path +// 3. hasNonBuiltinSafeOutputsEnabled — if it is NOT a builtin (noop/missing-*) +// 4. hasSafeOutputType in imports.go — used for conflict detection + +// safeOutputFieldMapping maps SafeOutputsConfig struct field names to their tool names. +// This map is used by imports, prompt generation, and other metadata operations. +// It is NOT used for existence checks — see hasAnySafeOutputEnabled and +// hasNonBuiltinSafeOutputsEnabled for the performance-critical direct-field versions. var safeOutputFieldMapping = map[string]string{ "CreateIssues": "create_issue", "CreateAgentSessions": "create_agent_session", @@ -65,100 +69,128 @@ var safeOutputFieldMapping = map[string]string{ "MarkPullRequestAsReadyForReview": "mark_pull_request_as_ready_for_review", } -// hasAnySafeOutputEnabled uses reflection to check if any safe output field is non-nil. -// It checks Jobs separately (map field) before falling back to pointer fields. +// hasAnySafeOutputEnabled reports whether any safe output field is non-nil. +// It uses direct struct-field nil checks instead of reflection for performance; +// this function is called on every compilation and is on the hot path. +// +// NOTE: keep this function in sync with safeOutputFieldMapping above and +// hasNonBuiltinSafeOutputsEnabled below when adding new safe output types. func hasAnySafeOutputEnabled(safeOutputs *SafeOutputsConfig) bool { if safeOutputs == nil { return false } - safeOutputReflectionLog.Print("Checking if any safe outputs are enabled using reflection") - - // Check Jobs separately as it's a map - if len(safeOutputs.Jobs) > 0 { - safeOutputReflectionLog.Printf("Found %d custom jobs enabled", len(safeOutputs.Jobs)) - return true - } - - // Check Scripts separately as it's a map - if len(safeOutputs.Scripts) > 0 { - safeOutputReflectionLog.Printf("Found %d custom scripts enabled", len(safeOutputs.Scripts)) + // Check map fields separately + if len(safeOutputs.Jobs) > 0 || len(safeOutputs.Scripts) > 0 || len(safeOutputs.Actions) > 0 { return true } - // Check Actions separately as it's a map - if len(safeOutputs.Actions) > 0 { - safeOutputReflectionLog.Printf("Found %d custom actions enabled", len(safeOutputs.Actions)) - return true - } - - // Use reflection to check all pointer fields - val := reflect.ValueOf(safeOutputs).Elem() - for fieldName := range safeOutputFieldMapping { - field := val.FieldByName(fieldName) - if field.IsValid() && !field.IsNil() { - safeOutputReflectionLog.Printf("Found enabled safe output field: %s", fieldName) - return true - } - } - - safeOutputReflectionLog.Print("No safe outputs enabled") - return false + // Direct nil checks — no reflection, no heap allocation (43 fields matching safeOutputFieldMapping + // plus CommentMemory which is attached via tools.comment-memory and not in safeOutputFieldMapping). + return safeOutputs.CreateIssues != nil || + safeOutputs.CreateAgentSessions != nil || + safeOutputs.CreateDiscussions != nil || + safeOutputs.UpdateDiscussions != nil || + safeOutputs.CloseDiscussions != nil || + safeOutputs.CloseIssues != nil || + safeOutputs.ClosePullRequests != nil || + safeOutputs.MarkPullRequestAsReadyForReview != nil || + safeOutputs.AddComments != nil || + safeOutputs.CommentMemory != nil || + safeOutputs.CreatePullRequests != nil || + safeOutputs.CreatePullRequestReviewComments != nil || + safeOutputs.SubmitPullRequestReview != nil || + safeOutputs.ReplyToPullRequestReviewComment != nil || + safeOutputs.ResolvePullRequestReviewThread != nil || + safeOutputs.CreateCodeScanningAlerts != nil || + safeOutputs.AutofixCodeScanningAlert != nil || + safeOutputs.AddLabels != nil || + safeOutputs.RemoveLabels != nil || + safeOutputs.AddReviewer != nil || + safeOutputs.AssignMilestone != nil || + safeOutputs.AssignToAgent != nil || + safeOutputs.AssignToUser != nil || + safeOutputs.UnassignFromUser != nil || + safeOutputs.UpdateIssues != nil || + safeOutputs.UpdatePullRequests != nil || + safeOutputs.MergePullRequest != nil || + safeOutputs.PushToPullRequestBranch != nil || + safeOutputs.UploadAssets != nil || + safeOutputs.UploadArtifact != nil || + safeOutputs.UpdateRelease != nil || + safeOutputs.UpdateProjects != nil || + safeOutputs.CreateProjects != nil || + safeOutputs.CreateProjectStatusUpdates != nil || + safeOutputs.LinkSubIssue != nil || + safeOutputs.HideComment != nil || + safeOutputs.DispatchWorkflow != nil || + safeOutputs.DispatchRepository != nil || + safeOutputs.CallWorkflow != nil || + safeOutputs.MissingTool != nil || + safeOutputs.MissingData != nil || + safeOutputs.SetIssueType != nil || + safeOutputs.NoOp != nil // 43rd field } -// builtinSafeOutputFields contains the struct field names for the built-in safe output types -// that are excluded from the "non-builtin" check. These are: noop, missing-data, missing-tool. -var builtinSafeOutputFields = map[string]bool{ - "NoOp": true, - "MissingData": true, - "MissingTool": true, -} - -// nonBuiltinSafeOutputFieldNames is a pre-computed list of field names from safeOutputFieldMapping -// that are not builtins, used by hasNonBuiltinSafeOutputsEnabled to avoid repeated map iterations. -var nonBuiltinSafeOutputFieldNames = func() []string { - var fields []string - for fieldName := range safeOutputFieldMapping { - if !builtinSafeOutputFields[fieldName] { - fields = append(fields, fieldName) - } - } - return fields -}() - -// hasNonBuiltinSafeOutputsEnabled checks if any non-builtin safe outputs are configured. +// hasNonBuiltinSafeOutputsEnabled reports whether any non-builtin safe output is configured. // The builtin types (noop, missing-data, missing-tool) are excluded from this check // because they are always auto-enabled and do not represent a meaningful output action. +// +// NOTE: keep this function in sync with safeOutputFieldMapping above and +// hasAnySafeOutputEnabled above when adding new safe output types. func hasNonBuiltinSafeOutputsEnabled(safeOutputs *SafeOutputsConfig) bool { if safeOutputs == nil { return false } - // Custom safe-jobs are always non-builtin - if len(safeOutputs.Jobs) > 0 { + // Custom safe-jobs, scripts, and actions are always non-builtin + if len(safeOutputs.Jobs) > 0 || len(safeOutputs.Scripts) > 0 || len(safeOutputs.Actions) > 0 { return true } - // Custom scripts are always non-builtin - if len(safeOutputs.Scripts) > 0 { - return true - } - - // Custom actions are always non-builtin - if len(safeOutputs.Actions) > 0 { - return true - } - - // Check non-builtin pointer fields using the pre-computed list - val := reflect.ValueOf(safeOutputs).Elem() - for _, fieldName := range nonBuiltinSafeOutputFieldNames { - field := val.FieldByName(fieldName) - if field.IsValid() && !field.IsNil() { - return true - } - } - - return false + // Direct nil checks for non-builtin pointer fields (40 fields = 43 total minus 3 builtins: + // NoOp, MissingData, MissingTool). Includes CommentMemory which is attached via + // tools.comment-memory and is not in safeOutputFieldMapping. + return safeOutputs.CreateIssues != nil || + safeOutputs.CreateAgentSessions != nil || + safeOutputs.CreateDiscussions != nil || + safeOutputs.UpdateDiscussions != nil || + safeOutputs.CloseDiscussions != nil || + safeOutputs.CloseIssues != nil || + safeOutputs.ClosePullRequests != nil || + safeOutputs.MarkPullRequestAsReadyForReview != nil || + safeOutputs.AddComments != nil || + safeOutputs.CommentMemory != nil || + safeOutputs.CreatePullRequests != nil || + safeOutputs.CreatePullRequestReviewComments != nil || + safeOutputs.SubmitPullRequestReview != nil || + safeOutputs.ReplyToPullRequestReviewComment != nil || + safeOutputs.ResolvePullRequestReviewThread != nil || + safeOutputs.CreateCodeScanningAlerts != nil || + safeOutputs.AutofixCodeScanningAlert != nil || + safeOutputs.AddLabels != nil || + safeOutputs.RemoveLabels != nil || + safeOutputs.AddReviewer != nil || + safeOutputs.AssignMilestone != nil || + safeOutputs.AssignToAgent != nil || + safeOutputs.AssignToUser != nil || + safeOutputs.UnassignFromUser != nil || + safeOutputs.UpdateIssues != nil || + safeOutputs.UpdatePullRequests != nil || + safeOutputs.MergePullRequest != nil || + safeOutputs.PushToPullRequestBranch != nil || + safeOutputs.UploadAssets != nil || + safeOutputs.UploadArtifact != nil || + safeOutputs.UpdateRelease != nil || + safeOutputs.UpdateProjects != nil || + safeOutputs.CreateProjects != nil || + safeOutputs.CreateProjectStatusUpdates != nil || + safeOutputs.LinkSubIssue != nil || + safeOutputs.HideComment != nil || + safeOutputs.DispatchWorkflow != nil || + safeOutputs.DispatchRepository != nil || + safeOutputs.CallWorkflow != nil || + safeOutputs.SetIssueType != nil // 40th non-builtin field } // HasSafeOutputsEnabled checks if any safe-outputs are enabled diff --git a/pkg/workflow/safe_outputs_state_test.go b/pkg/workflow/safe_outputs_state_test.go new file mode 100644 index 00000000000..74d8f1bbea3 --- /dev/null +++ b/pkg/workflow/safe_outputs_state_test.go @@ -0,0 +1,64 @@ +//go:build !integration + +package workflow + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestSafeOutputStateFieldCoverage verifies that hasAnySafeOutputEnabled and +// hasNonBuiltinSafeOutputsEnabled cover every pointer field listed in +// safeOutputFieldMapping. This acts as a regression guard to ensure that when +// a new safe output type is added to safeOutputFieldMapping, the developer is +// reminded (via a failing test) to also update the two direct-check functions. +func TestSafeOutputStateFieldCoverage(t *testing.T) { + // builtins excluded from hasNonBuiltinSafeOutputsEnabled + builtins := map[string]bool{ + "NoOp": true, + "MissingData": true, + "MissingTool": true, + } + + for fieldName := range safeOutputFieldMapping { + t.Run(fieldName, func(t *testing.T) { + // Build a SafeOutputsConfig with only this one field set to a non-nil value. + cfg := &SafeOutputsConfig{} + val := reflect.ValueOf(cfg).Elem() + field := val.FieldByName(fieldName) + require.True(t, field.IsValid(), + "safeOutputFieldMapping references unknown struct field %q; update the mapping or the struct", fieldName) + require.Equal(t, reflect.Ptr, field.Kind(), + "safeOutputFieldMapping field %q is expected to be a pointer type", fieldName) + + field.Set(reflect.New(field.Type().Elem())) + + // hasAnySafeOutputEnabled must return true for every field in the mapping. + assert.True(t, hasAnySafeOutputEnabled(cfg), + "hasAnySafeOutputEnabled missing check for field %q; add it to the direct nil-check list", fieldName) + + // hasNonBuiltinSafeOutputsEnabled must return true for every non-builtin field. + if !builtins[fieldName] { + assert.True(t, hasNonBuiltinSafeOutputsEnabled(cfg), + "hasNonBuiltinSafeOutputsEnabled missing check for non-builtin field %q; add it to the direct nil-check list", fieldName) + } + }) + } +} + +// TestSafeOutputStateCommentMemoryCoverage explicitly tests CommentMemory, which is +// attached to SafeOutputs via tools.comment-memory (not listed in safeOutputFieldMapping) +// and must be checked by both state inspection functions. +func TestSafeOutputStateCommentMemoryCoverage(t *testing.T) { + cfg := &SafeOutputsConfig{ + CommentMemory: &CommentMemoryConfig{}, + } + + assert.True(t, hasAnySafeOutputEnabled(cfg), + "hasAnySafeOutputEnabled should return true when CommentMemory is set") + assert.True(t, hasNonBuiltinSafeOutputsEnabled(cfg), + "hasNonBuiltinSafeOutputsEnabled should return true when CommentMemory is set") +}