diff --git a/cmd/handle_analyze_manifest_test.go b/cmd/handle_analyze_manifest_test.go index 9131d35..c19fb80 100644 --- a/cmd/handle_analyze_manifest_test.go +++ b/cmd/handle_analyze_manifest_test.go @@ -294,3 +294,194 @@ jobs: t.Logf("Successfully analyzed manifest with %d findings", len(insights.Findings)) } + +func TestHandleAnalyzeManifestWithAllowedRules(t *testing.T) { + ctx := context.Background() + + analyzer, err := createTestAnalyzer(ctx) + require.NoError(t, err) + + vulnerableManifest := `name: Test Workflow +on: + pull_request_target: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.sha }} + - name: Run test + run: echo "Testing ${{ github.event.pull_request.head.ref }}"` + + t.Run("without allowed_rules filter", func(t *testing.T) { + request := NewCallToolRequest("analyze_manifest", map[string]interface{}{ + "content": vulnerableManifest, + "manifest_type": "github-actions", + }) + + result, err := handleAnalyzeManifest(ctx, request, analyzer) + require.NoError(t, err) + require.NotNil(t, result) + assert.False(t, result.IsError) + + contentText := extractTextFromContent(t, result.Content[0]) + require.NotEmpty(t, contentText) + + var insights struct { + Findings []results.Finding `json:"findings"` + Rules map[string]results.Rule `json:"rules"` + } + err = json.Unmarshal([]byte(contentText), &insights) + require.NoError(t, err) + + // Should have multiple findings including injection + assert.Greater(t, len(insights.Findings), 1, "Should have multiple findings without filter") + + // Verify injection rule is present + hasInjection := false + for _, finding := range insights.Findings { + if finding.RuleId == "injection" { + hasInjection = true + break + } + } + assert.True(t, hasInjection, "Should have injection finding") + + t.Logf("Found %d findings without filter", len(insights.Findings)) + }) + + t.Run("with allowed_rules filter for injection only", func(t *testing.T) { + request := NewCallToolRequest("analyze_manifest", map[string]interface{}{ + "content": vulnerableManifest, + "manifest_type": "github-actions", + "allowed_rules": []string{"injection"}, + }) + + result, err := handleAnalyzeManifest(ctx, request, analyzer) + require.NoError(t, err) + require.NotNil(t, result) + assert.False(t, result.IsError) + + contentText := extractTextFromContent(t, result.Content[0]) + require.NotEmpty(t, contentText) + + var insights struct { + Findings []results.Finding `json:"findings"` + Rules map[string]results.Rule `json:"rules"` + } + err = json.Unmarshal([]byte(contentText), &insights) + require.NoError(t, err) + + // Should have only injection finding + assert.Len(t, insights.Findings, 1, "Should have only one finding with filter") + assert.Equal(t, "injection", insights.Findings[0].RuleId, "Should only have injection finding") + + // Should have only injection rule + assert.Len(t, insights.Rules, 1, "Should have only one rule with filter") + _, hasInjectionRule := insights.Rules["injection"] + assert.True(t, hasInjectionRule, "Should have injection rule in rules map") + + t.Logf("Found %d findings with allowed_rules filter", len(insights.Findings)) + }) + + t.Run("with allowed_rules filter for non-existent rule", func(t *testing.T) { + request := NewCallToolRequest("analyze_manifest", map[string]interface{}{ + "content": vulnerableManifest, + "manifest_type": "github-actions", + "allowed_rules": []string{"non_existent_rule"}, + }) + + result, err := handleAnalyzeManifest(ctx, request, analyzer) + require.NoError(t, err) + require.NotNil(t, result) + assert.False(t, result.IsError) + + contentText := extractTextFromContent(t, result.Content[0]) + require.NotEmpty(t, contentText) + + var insights struct { + Findings []results.Finding `json:"findings"` + Rules map[string]results.Rule `json:"rules"` + } + err = json.Unmarshal([]byte(contentText), &insights) + require.NoError(t, err) + + // Should have no findings + assert.Empty(t, insights.Findings, "Should have no findings with non-existent rule filter") + assert.Empty(t, insights.Rules, "Should have no rules with non-existent rule filter") + + t.Logf("Found %d findings with non-existent rule filter", len(insights.Findings)) + }) +} + +func TestMCPServerGlobalAllowedRules(t *testing.T) { + ctx := context.Background() + + // Save original state + originalAllowedRules := allowedRules + defer func() { + allowedRules = originalAllowedRules + }() + + // Simulate global --allowed-rules injection flag + allowedRules = []string{"injection"} + + // Create analyzer with global allowed rules (simulating startMCPServer behavior) + testConfig := *config + if len(allowedRules) > 0 { + testConfig.AllowedRules = allowedRules + } + + vulnerableManifest := `name: Test Workflow +on: + pull_request_target: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.sha }} + - name: Run test + run: echo "Testing ${{ github.event.pull_request.head.ref }}"` + + t.Run("handleAnalyzeManifest respects global allowed rules", func(t *testing.T) { + opaClient, err := newOpaWithConfig(ctx, &testConfig) + require.NoError(t, err) + manifestAnalyzer := analyze.NewAnalyzer(nil, nil, &noop.Format{}, &testConfig, opaClient) + + // Call without allowed_rules parameter (should inherit global) + request := NewCallToolRequest("analyze_manifest", map[string]interface{}{ + "content": vulnerableManifest, + "manifest_type": "github-actions", + }) + + result, err := handleAnalyzeManifest(ctx, request, manifestAnalyzer) + require.NoError(t, err) + require.NotNil(t, result) + assert.False(t, result.IsError) + + contentText := extractTextFromContent(t, result.Content[0]) + require.NotEmpty(t, contentText) + + var insights struct { + Findings []results.Finding `json:"findings"` + Rules map[string]results.Rule `json:"rules"` + } + err = json.Unmarshal([]byte(contentText), &insights) + require.NoError(t, err) + + // Should have only injection finding due to global allowed rules + assert.Len(t, insights.Findings, 1, "Should have only injection finding with global allowed rules") + assert.Equal(t, "injection", insights.Findings[0].RuleId, "Should only have injection finding") + + t.Logf("Global allowed rules test: Found %d findings", len(insights.Findings)) + }) +} diff --git a/cmd/mcp_server.go b/cmd/mcp_server.go index 46d5be4..01786e3 100644 --- a/cmd/mcp_server.go +++ b/cmd/mcp_server.go @@ -59,13 +59,18 @@ The SCM access token should be provided via the --token flag or GH_TOKEN/GL_TOKE func startMCPServer(ctx context.Context) error { Format = "noop" - defaultConfig := *config - opaClient, err := newOpaWithConfig(ctx, &defaultConfig) + // Create default config with global allowedRules applied + mcpDefaultConfig := *config + // Apply global allowedRules setting to MCP server config + if len(allowedRules) > 0 { + mcpDefaultConfig.AllowedRules = allowedRules + } + opaClient, err := newOpaWithConfig(ctx, &mcpDefaultConfig) if err != nil { log.Error().Err(err).Msg("Failed to create manifest OPA client") return fmt.Errorf("failed to create manifest opa client: %w", err) } - manifestAnalyzer := analyze.NewAnalyzer(nil, nil, &noop.Format{}, &defaultConfig, opaClient) + manifestAnalyzer := analyze.NewAnalyzer(nil, nil, &noop.Format{}, &mcpDefaultConfig, opaClient) // Create MCP server s := server.NewMCPServer( @@ -95,6 +100,10 @@ func startMCPServer(ctx context.Context) error { mcp.WithBoolean("ignore_forks", mcp.Description("Ignore forked repositories"), ), + mcp.WithArray("allowed_rules", + mcp.Description("Filter to only run specified rules (optional)"), + mcp.WithStringItems(), + ), mcp.WithTitleAnnotation("CI/CD Pipeline Security Scan - Organization"), mcp.WithReadOnlyHintAnnotation(true), mcp.WithDestructiveHintAnnotation(false), @@ -119,6 +128,10 @@ func startMCPServer(ctx context.Context) error { mcp.WithString("ref", mcp.Description("Commit or branch to analyze"), ), + mcp.WithArray("allowed_rules", + mcp.Description("Filter to only run specified rules (optional)"), + mcp.WithStringItems(), + ), mcp.WithTitleAnnotation("CI/CD Pipeline Security Scan - Repository"), mcp.WithReadOnlyHintAnnotation(true), mcp.WithDestructiveHintAnnotation(false), @@ -149,6 +162,10 @@ func startMCPServer(ctx context.Context) error { mcp.WithString("regex", mcp.Description("Regex to check if the workflow is accessible in stale branches"), ), + mcp.WithArray("allowed_rules", + mcp.Description("Filter to only run specified rules (optional)"), + mcp.WithStringItems(), + ), mcp.WithTitleAnnotation("CI/CD Pipeline Security Scan - Stale Branches"), mcp.WithReadOnlyHintAnnotation(true), mcp.WithDestructiveHintAnnotation(false), @@ -163,6 +180,10 @@ func startMCPServer(ctx context.Context) error { mcp.Required(), mcp.Description("Local file system path to the repository"), ), + mcp.WithArray("allowed_rules", + mcp.Description("Filter to only run specified rules (optional)"), + mcp.WithStringItems(), + ), mcp.WithTitleAnnotation("CI/CD Pipeline Security Scan - Local Repository"), mcp.WithReadOnlyHintAnnotation(true), mcp.WithDestructiveHintAnnotation(false), @@ -244,6 +265,10 @@ Remember: This tool exists to prevent security vulnerabilities in generated code mcp.Description("Type of CI/CD manifest: 'github-actions' for GitHub Actions workflows, 'gitlab-ci' for GitLab CI, 'azure-pipelines' for Azure Pipelines, 'tekton' for Tekton pipelines"), mcp.Enum("github-actions", "gitlab-ci", "azure-pipelines", "tekton"), ), + mcp.WithArray("allowed_rules", + mcp.Description("Filter to only run specified rules (optional)"), + mcp.WithStringItems(), + ), mcp.WithTitleAnnotation("CI/CD Pipeline Security Scan - Manifest"), mcp.WithReadOnlyHintAnnotation(true), mcp.WithDestructiveHintAnnotation(false), @@ -252,12 +277,18 @@ Remember: This tool exists to prevent security vulnerabilities in generated code ) // Add tool handlers - s.AddTool(analyzeOrgTool, handleAnalyzeOrg) - s.AddTool(analyzeRepoTool, handleAnalyzeRepo) + s.AddTool(analyzeOrgTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return handleAnalyzeOrg(ctx, request, &mcpDefaultConfig) + }) + s.AddTool(analyzeRepoTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return handleAnalyzeRepo(ctx, request, &mcpDefaultConfig) + }) s.AddTool(analyzeLocalTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return handleAnalyzeLocal(ctx, request, opaClient) + return handleAnalyzeLocal(ctx, request, opaClient, &mcpDefaultConfig) + }) + s.AddTool(analyzeStaleBranchesTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return handleAnalyzeStaleBranches(ctx, request, &mcpDefaultConfig) }) - s.AddTool(analyzeStaleBranchesTool, handleAnalyzeStaleBranches) s.AddTool(analyzeManifestTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return handleAnalyzeManifest(ctx, request, manifestAnalyzer) }) @@ -273,7 +304,7 @@ Remember: This tool exists to prevent security vulnerabilities in generated code return nil } -func handleAnalyzeOrg(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func handleAnalyzeOrg(ctx context.Context, request mcp.CallToolRequest, defaultConfig *models.Config) (*mcp.CallToolResult, error) { token := viper.GetString("token") if token == "" { return mcp.NewToolResultError("SCM access token is required. Please provide it via --token flag or GH_TOKEN/GL_TOKEN environment variable"), nil @@ -288,9 +319,13 @@ func handleAnalyzeOrg(ctx context.Context, request mcp.CallToolRequest) (*mcp.Ca scmBaseURLStr := request.GetString("scm_base_url", "") threads := int(request.GetFloat("threads", 2)) ignoreForks := request.GetBool("ignore_forks", false) + allowedRulesParam := request.GetStringSlice("allowed_rules", []string{}) - requestConfig := *config + requestConfig := *defaultConfig requestConfig.IgnoreForks = ignoreForks + if len(allowedRulesParam) > 0 { + requestConfig.AllowedRules = allowedRulesParam + } analyzer, err := GetAnalyzerWithConfig(ctx, "analyze_org", scmProvider, scmBaseURLStr, token, &requestConfig) if err != nil { @@ -319,7 +354,7 @@ func handleAnalyzeOrg(ctx context.Context, request mcp.CallToolRequest) (*mcp.Ca return mcp.NewToolResultText(string(resultData)), nil } -func handleAnalyzeRepo(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func handleAnalyzeRepo(ctx context.Context, request mcp.CallToolRequest, defaultConfig *models.Config) (*mcp.CallToolResult, error) { token := viper.GetString("token") if token == "" { return mcp.NewToolResultError("SCM access token is required. Please provide it via --token flag or GH_TOKEN/GL_TOKEN environment variable"), nil @@ -333,8 +368,12 @@ func handleAnalyzeRepo(ctx context.Context, request mcp.CallToolRequest) (*mcp.C scmProvider := request.GetString("scm_provider", "github") scmBaseURLStr := request.GetString("scm_base_url", "") ref := request.GetString("ref", "HEAD") + allowedRulesParam := request.GetStringSlice("allowed_rules", []string{}) - requestConfig := *config + requestConfig := *defaultConfig + if len(allowedRulesParam) > 0 { + requestConfig.AllowedRules = allowedRulesParam + } analyzer, err := GetAnalyzerWithConfig(ctx, "analyze_repo", scmProvider, scmBaseURLStr, token, &requestConfig) if err != nil { @@ -360,13 +399,29 @@ func handleAnalyzeRepo(ctx context.Context, request mcp.CallToolRequest) (*mcp.C return mcp.NewToolResultText(string(resultData)), nil } -func handleAnalyzeLocal(ctx context.Context, request mcp.CallToolRequest, opaClient *opa.Opa) (*mcp.CallToolResult, error) { +func handleAnalyzeLocal(ctx context.Context, request mcp.CallToolRequest, opaClient *opa.Opa, defaultConfig *models.Config) (*mcp.CallToolResult, error) { path, err := request.RequireString("path") if err != nil { return mcp.NewToolResultError("path parameter is required"), nil } - requestConfig := *config + allowedRulesParam := request.GetStringSlice("allowed_rules", []string{}) + + requestConfig := *defaultConfig + if len(allowedRulesParam) > 0 { + requestConfig.AllowedRules = allowedRulesParam + } + + // Create a new OPA client with the request-specific config if allowed_rules is specified + var requestOpaClient *opa.Opa + if len(allowedRulesParam) > 0 { + requestOpaClient, err = newOpaWithConfig(ctx, &requestConfig) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to create OPA client with allowed rules: %v", err)), nil + } + } else { + requestOpaClient = opaClient + } localScmClient, err := local.NewGitSCMClient(ctx, path, nil) if err != nil { @@ -377,7 +432,7 @@ func handleAnalyzeLocal(ctx context.Context, request mcp.CallToolRequest, opaCli formatter := &noop.Format{} - analyzer := analyze.NewAnalyzer(localScmClient, localGitClient, formatter, &requestConfig, opaClient) + analyzer := analyze.NewAnalyzer(localScmClient, localGitClient, formatter, &requestConfig, requestOpaClient) analysisResults, err := analyzer.AnalyzeLocalRepo(ctx, path) if err != nil { @@ -398,7 +453,7 @@ func handleAnalyzeLocal(ctx context.Context, request mcp.CallToolRequest, opaCli return mcp.NewToolResultText(string(resultData)), nil } -func handleAnalyzeStaleBranches(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func handleAnalyzeStaleBranches(ctx context.Context, request mcp.CallToolRequest, defaultConfig *models.Config) (*mcp.CallToolResult, error) { token := viper.GetString("token") if token == "" { return mcp.NewToolResultError("SCM access token is required. Please provide it via --token flag or GH_TOKEN/GL_TOKEN environment variable"), nil @@ -414,6 +469,7 @@ func handleAnalyzeStaleBranches(ctx context.Context, request mcp.CallToolRequest threads := int(request.GetFloat("threads", 5)) expand := request.GetBool("expand", false) regexStr := request.GetString("regex", "pull_request_target") + allowedRulesParam := request.GetStringSlice("allowed_rules", []string{}) // Compile the regex reg, err := regexp.Compile(regexStr) @@ -421,7 +477,10 @@ func handleAnalyzeStaleBranches(ctx context.Context, request mcp.CallToolRequest return mcp.NewToolResultError(fmt.Sprintf("error compiling regex: %v", err)), nil } - requestConfig := *config + requestConfig := *defaultConfig + if len(allowedRulesParam) > 0 { + requestConfig.AllowedRules = allowedRulesParam + } analyzer, err := GetAnalyzerWithConfig(ctx, "analyze_repo_stale_branches", scmProvider, scmBaseURLStr, token, &requestConfig) if err != nil { @@ -454,9 +513,26 @@ func handleAnalyzeManifest(ctx context.Context, request mcp.CallToolRequest, ana } manifestType := request.GetString("manifest_type", "github-actions") + allowedRulesParam := request.GetStringSlice("allowed_rules", []string{}) + + // Create a new analyzer with allowed rules if specified + var requestAnalyzer *analyze.Analyzer + if len(allowedRulesParam) > 0 { + requestConfig := *config + requestConfig.AllowedRules = allowedRulesParam + + requestOpaClient, err := newOpaWithConfig(ctx, &requestConfig) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to create OPA client with allowed rules: %v", err)), nil + } + + requestAnalyzer = analyze.NewAnalyzer(nil, nil, &noop.Format{}, &requestConfig, requestOpaClient) + } else { + requestAnalyzer = analyzer + } manifestReader := strings.NewReader(content) - analysisResults, err := analyzer.AnalyzeManifest(ctx, manifestReader, manifestType) + analysisResults, err := requestAnalyzer.AnalyzeManifest(ctx, manifestReader, manifestType) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to analyze manifest: %v", err)), nil }