Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions pkg/cli/update_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@ import (

"github.com/githubnext/gh-aw/pkg/console"
"github.com/githubnext/gh-aw/pkg/constants"
"github.com/githubnext/gh-aw/pkg/logger"
"github.com/githubnext/gh-aw/pkg/parser"
"github.com/githubnext/gh-aw/pkg/workflow"
"github.com/spf13/cobra"
)

var updateLog = logger.New("cli:update_command")

// NewUpdateCommand creates the update command
func NewUpdateCommand(validateEngine func(string) error) *cobra.Command {
cmd := &cobra.Command{
Expand Down Expand Up @@ -118,6 +121,8 @@ func checkExtensionUpdate(verbose bool) error {
// 2. Update workflows from source repositories (compiles each workflow after update)
// 3. Optionally create a PR
func UpdateWorkflowsWithExtensionCheck(workflowNames []string, allowMajor, force, verbose bool, engineOverride string, createPR bool, workflowsDir string, noStopAfter bool, stopAfter string) error {
updateLog.Printf("Starting update process: workflows=%v, allowMajor=%v, force=%v, createPR=%v", workflowNames, allowMajor, force, createPR)

// Step 1: Check for gh-aw extension updates
if err := checkExtensionUpdate(verbose); err != nil {
return fmt.Errorf("extension update check failed: %w", err)
Expand Down Expand Up @@ -224,6 +229,8 @@ func createUpdatePR(verbose bool) error {

// UpdateWorkflows updates workflows from their source repositories
func UpdateWorkflows(workflowNames []string, allowMajor, force, verbose bool, engineOverride string, workflowsDir string, noStopAfter bool, stopAfter string) error {
updateLog.Printf("Scanning for workflows with source field: dir=%s, filter=%v", workflowsDir, workflowNames)

// Use provided workflows directory or default
if workflowsDir == "" {
workflowsDir = getWorkflowsDir()
Expand All @@ -235,6 +242,8 @@ func UpdateWorkflows(workflowNames []string, allowMajor, force, verbose bool, en
return err
}

updateLog.Printf("Found %d workflows with source field", len(workflows))

if len(workflows) == 0 {
if len(workflowNames) > 0 {
return fmt.Errorf("no workflows found matching the specified names with source field")
Expand Down Expand Up @@ -394,12 +403,15 @@ func findWorkflowsWithSource(workflowsDir string, filterNames []string, verbose

// resolveLatestRef resolves the latest ref for a workflow source
func resolveLatestRef(repo, currentRef string, allowMajor, verbose bool) (string, error) {
updateLog.Printf("Resolving latest ref: repo=%s, currentRef=%s, allowMajor=%v", repo, currentRef, allowMajor)

if verbose {
fmt.Fprintln(os.Stderr, console.FormatVerboseMessage(fmt.Sprintf("Resolving latest ref for %s (current: %s)", repo, currentRef)))
}

// Check if current ref is a tag (looks like a semantic version)
if isSemanticVersionTag(currentRef) {
updateLog.Print("Current ref is semantic version tag, resolving latest release")
return resolveLatestRelease(repo, currentRef, allowMajor, verbose)
}

Expand All @@ -414,10 +426,12 @@ func resolveLatestRef(repo, currentRef string, allowMajor, verbose bool) (string
}

if isBranch {
updateLog.Printf("Current ref is branch: %s", currentRef)
return resolveBranchHead(repo, currentRef, verbose)
}

// Otherwise, use default branch
updateLog.Print("Using default branch for ref resolution")
return resolveDefaultBranchHead(repo, verbose)
}

Expand Down Expand Up @@ -546,6 +560,8 @@ func resolveDefaultBranchHead(repo string, verbose bool) (string, error) {

// updateWorkflow updates a single workflow from its source
func updateWorkflow(wf *workflowWithSource, allowMajor, force, verbose bool, engineOverride string, noStopAfter bool, stopAfter string) error {
updateLog.Printf("Updating workflow: name=%s, source=%s, force=%v", wf.Name, wf.SourceSpec, force)

if verbose {
fmt.Fprintln(os.Stderr, console.FormatInfoMessage(fmt.Sprintf("\nUpdating workflow: %s", wf.Name)))
fmt.Fprintln(os.Stderr, console.FormatVerboseMessage(fmt.Sprintf("Source: %s", wf.SourceSpec)))
Expand All @@ -554,6 +570,7 @@ func updateWorkflow(wf *workflowWithSource, allowMajor, force, verbose bool, eng
// Parse source spec
sourceSpec, err := parseSourceSpec(wf.SourceSpec)
if err != nil {
updateLog.Printf("Failed to parse source spec: %v", err)
return fmt.Errorf("failed to parse source spec: %w", err)
}

Expand All @@ -576,6 +593,8 @@ func updateWorkflow(wf *workflowWithSource, allowMajor, force, verbose bool, eng

// Check if update is needed
if !force && currentRef == latestRef {
updateLog.Printf("Workflow already at latest ref: %s, checking for local modifications", currentRef)

// Download the source content to check if local file has been modified
sourceContent, err := downloadWorkflowContent(sourceSpec.Repo, sourceSpec.Path, currentRef, verbose)
if err != nil {
Expand All @@ -595,11 +614,13 @@ func updateWorkflow(wf *workflowWithSource, allowMajor, force, verbose bool, eng

// Check if local file differs from source
if hasLocalModifications(string(sourceContent), string(currentContent), wf.SourceSpec, verbose) {
updateLog.Printf("Local modifications detected in workflow: %s", wf.Name)
fmt.Fprintln(os.Stderr, console.FormatInfoMessage(fmt.Sprintf("Workflow %s is already up to date (%s)", wf.Name, currentRef)))
fmt.Fprintln(os.Stderr, console.FormatWarningMessage(fmt.Sprintf("⚠️ Local copy of %s has been modified from source", wf.Name)))
return nil
}

updateLog.Printf("Workflow %s is up to date with no local modifications", wf.Name)
fmt.Fprintln(os.Stderr, console.FormatInfoMessage(fmt.Sprintf("Workflow %s is already up to date (%s)", wf.Name, currentRef)))
return nil
}
Expand Down Expand Up @@ -631,11 +652,17 @@ func updateWorkflow(wf *workflowWithSource, allowMajor, force, verbose bool, eng
}

// Perform 3-way merge using git merge-file
updateLog.Printf("Performing 3-way merge for workflow: %s", wf.Name)
mergedContent, hasConflicts, err := MergeWorkflowContent(string(baseContent), string(currentContent), string(newContent), wf.SourceSpec, latestRef, verbose)
if err != nil {
updateLog.Printf("Merge failed for workflow %s: %v", wf.Name, err)
return fmt.Errorf("failed to merge workflow content: %w", err)
}

if hasConflicts {
updateLog.Printf("Merge conflicts detected in workflow: %s", wf.Name)
}

// Handle stop-after field modifications
if noStopAfter {
// Remove stop-after field if requested
Expand Down Expand Up @@ -675,10 +702,13 @@ func updateWorkflow(wf *workflowWithSource, allowMajor, force, verbose bool, eng
return nil // Not an error, but user needs to resolve conflicts
}

updateLog.Printf("Successfully updated workflow %s from %s to %s", wf.Name, currentRef, latestRef)
fmt.Fprintln(os.Stderr, console.FormatSuccessMessage(fmt.Sprintf("Updated %s from %s to %s", wf.Name, currentRef, latestRef)))

// Compile the updated workflow with refreshStopTime enabled
updateLog.Printf("Compiling updated workflow: %s", wf.Name)
if err := compileWorkflowWithRefresh(wf.Path, verbose, engineOverride, true); err != nil {
updateLog.Printf("Compilation failed for workflow %s: %v", wf.Name, err)
return fmt.Errorf("failed to compile updated workflow: %w", err)
}

Expand Down
15 changes: 15 additions & 0 deletions pkg/parser/github_urls.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ import (
"path/filepath"
"strconv"
"strings"

"github.com/githubnext/gh-aw/pkg/logger"
)

var urlLog = logger.New("parser:github_urls")

// GitHubURLType represents the type of GitHub URL
type GitHubURLType string

Expand Down Expand Up @@ -45,9 +49,11 @@ type GitHubURLComponents struct {
// - Raw content: https://raw.githubusercontent.com/owner/repo/main/path/to/file.md
// - Enterprise URLs: https://github.example.com/owner/repo/...
func ParseGitHubURL(urlStr string) (*GitHubURLComponents, error) {
urlLog.Printf("Parsing GitHub URL: %s", urlStr)
// Parse the URL
parsedURL, err := url.Parse(urlStr)
if err != nil {
urlLog.Printf("Failed to parse URL: %v", err)
return nil, fmt.Errorf("invalid URL: %w", err)
}

Expand All @@ -57,8 +63,11 @@ func ParseGitHubURL(urlStr string) (*GitHubURLComponents, error) {
return nil, fmt.Errorf("URL must include a host")
}

urlLog.Printf("Detected host: %s", host)

// Handle raw.githubusercontent.com specially
if host == "raw.githubusercontent.com" {
urlLog.Print("Detected raw.githubusercontent.com URL")
return parseRawGitHubContentURL(parsedURL)
}

Expand All @@ -81,23 +90,27 @@ func ParseGitHubURL(urlStr string) (*GitHubURLComponents, error) {
// Determine the type based on path structure
if len(pathParts) >= 4 {
urlType := pathParts[2]
urlLog.Printf("Detected URL type segment: %s for %s/%s", urlType, owner, repo)

switch urlType {
case "actions":
// Pattern: /owner/repo/actions/runs/12345678
if len(pathParts) >= 5 && pathParts[3] == "runs" {
urlLog.Print("Parsing GitHub Actions run URL")
return parseRunURL(host, owner, repo, pathParts[4:])
}

case "runs":
// Pattern: /owner/repo/runs/12345678 (short form)
if len(pathParts) >= 4 {
urlLog.Print("Parsing GitHub Actions run URL (short form)")
return parseRunURL(host, owner, repo, pathParts[3:])
}

case "pull":
// Pattern: /owner/repo/pull/123
if len(pathParts) >= 4 {
urlLog.Print("Parsing pull request URL")
prNumber, err := strconv.ParseInt(pathParts[3], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid PR number: %s", pathParts[3])
Expand Down Expand Up @@ -130,6 +143,7 @@ func ParseGitHubURL(urlStr string) (*GitHubURLComponents, error) {
case "blob", "tree", "raw":
// Pattern: /owner/repo/{blob|tree|raw}/ref/path/to/file
if len(pathParts) >= 5 {
urlLog.Printf("Parsing file URL (type=%s)", urlType)
ref := pathParts[3]
filePath := strings.Join(pathParts[4:], "/")

Expand All @@ -143,6 +157,7 @@ func ParseGitHubURL(urlStr string) (*GitHubURLComponents, error) {
urlTypeEnum = URLTypeRaw
}

urlLog.Printf("Parsed file URL: ref=%s, path=%s", ref, filePath)
return &GitHubURLComponents{
Host: host,
Owner: owner,
Expand Down
13 changes: 13 additions & 0 deletions pkg/parser/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ func ExtractMCPConfigurations(frontmatter map[string]any, serverFilter string) (

// Check for safe-outputs configuration first (built-in MCP)
if safeOutputsSection, hasSafeOutputs := frontmatter["safe-outputs"]; hasSafeOutputs {
mcpLog.Print("Found safe-outputs configuration")
// Apply server filter if specified
if serverFilter == "" || strings.Contains(constants.SafeOutputsMCPServerID, strings.ToLower(serverFilter)) {
config := MCPServerConfig{
Expand Down Expand Up @@ -142,6 +143,7 @@ func ExtractMCPConfigurations(frontmatter map[string]any, serverFilter string) (

// Check for top-level safe-jobs configuration
if safeJobsSection, hasSafeJobs := frontmatter["safe-jobs"]; hasSafeJobs {
mcpLog.Print("Found safe-jobs configuration")
// Apply server filter if specified
if serverFilter == "" || strings.Contains(constants.SafeOutputsMCPServerID, strings.ToLower(serverFilter)) {
// Find existing safe-outputs config or create new one
Expand Down Expand Up @@ -176,6 +178,7 @@ func ExtractMCPConfigurations(frontmatter map[string]any, serverFilter string) (
// Get mcp-servers section from frontmatter
mcpServersSection, hasMCPServers := frontmatter["mcp-servers"]
if !hasMCPServers {
mcpLog.Print("No mcp-servers section found, checking for built-in tools")
// Also check tools section for built-in MCP tools (github, playwright)
toolsSection, hasTools := frontmatter["tools"]
if hasTools {
Expand All @@ -188,12 +191,14 @@ func ExtractMCPConfigurations(frontmatter map[string]any, serverFilter string) (
return nil, err
}
if config != nil {
mcpLog.Printf("Added built-in MCP tool: %s", toolName)
configs = append(configs, *config)
}
}
}
}
}
mcpLog.Printf("Extracted %d MCP configurations total", len(configs))
return configs, nil // No mcp-servers configured, but we might have safe-outputs and built-in tools
}

Expand Down Expand Up @@ -222,6 +227,7 @@ func ExtractMCPConfigurations(frontmatter map[string]any, serverFilter string) (
}

// Process custom MCP servers from mcp-servers section
mcpLog.Printf("Processing %d custom MCP servers", len(mcpServers))
for serverName, serverValue := range mcpServers {
// Apply server filter if specified
if serverFilter != "" && !strings.Contains(strings.ToLower(serverName), strings.ToLower(serverFilter)) {
Expand All @@ -239,9 +245,11 @@ func ExtractMCPConfigurations(frontmatter map[string]any, serverFilter string) (
return nil, fmt.Errorf("failed to parse MCP config for %s: %w", serverName, err)
}

mcpLog.Printf("Parsed custom MCP server: %s (type=%s)", serverName, config.Type)
configs = append(configs, config)
}

mcpLog.Printf("Extracted %d MCP configurations total", len(configs))
return configs, nil
}

Expand Down Expand Up @@ -512,10 +520,13 @@ func ParseMCPConfig(toolName string, mcpSection any, toolConfig map[string]any)
// Infer type from presence of fields
if _, hasURL := mcpConfig["url"]; hasURL {
config.Type = "http"
mcpLog.Printf("Inferred MCP type 'http' for tool %s based on url field", toolName)
} else if _, hasCommand := mcpConfig["command"]; hasCommand {
config.Type = "stdio"
mcpLog.Printf("Inferred MCP type 'stdio' for tool %s based on command field", toolName)
} else if _, hasContainer := mcpConfig["container"]; hasContainer {
config.Type = "stdio"
mcpLog.Printf("Inferred MCP type 'stdio' for tool %s based on container field", toolName)
} else {
return config, fmt.Errorf("unable to determine MCP type for tool '%s': missing type, url, command, or container", toolName)
}
Expand All @@ -531,11 +542,13 @@ func ParseMCPConfig(toolName string, mcpSection any, toolConfig map[string]any)
}

// Extract configuration based on type
mcpLog.Printf("Extracting %s configuration for tool: %s", config.Type, toolName)
switch config.Type {
case "stdio":
// Handle container field (simplified Docker run)
if container, hasContainer := mcpConfig["container"]; hasContainer {
if containerStr, ok := container.(string); ok {
mcpLog.Printf("Tool %s uses container: %s", toolName, containerStr)
config.Container = containerStr
config.Command = "docker"
config.Args = []string{"run", "--rm", "-i"}
Expand Down
14 changes: 14 additions & 0 deletions pkg/workflow/threat_detection.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@ import (
"strings"

"github.com/githubnext/gh-aw/pkg/constants"
"github.com/githubnext/gh-aw/pkg/logger"
)

var threatLog = logger.New("workflow:threat_detection")

//go:embed templates/threat_detection.md
var defaultThreatDetectionPrompt string

Expand All @@ -22,12 +25,15 @@ type ThreatDetectionConfig struct {
// parseThreatDetectionConfig handles threat-detection configuration
func (c *Compiler) parseThreatDetectionConfig(outputMap map[string]any) *ThreatDetectionConfig {
if configData, exists := outputMap["threat-detection"]; exists {
threatLog.Print("Found threat-detection configuration")
// Handle boolean values
if boolVal, ok := configData.(bool); ok {
if !boolVal {
threatLog.Print("Threat detection explicitly disabled")
// When explicitly disabled, return nil
return nil
}
threatLog.Print("Threat detection enabled with default settings")
// When enabled as boolean, return empty config
return &ThreatDetectionConfig{}
}
Expand All @@ -38,6 +44,7 @@ func (c *Compiler) parseThreatDetectionConfig(outputMap map[string]any) *ThreatD
if enabled, exists := configMap["enabled"]; exists {
if enabledBool, ok := enabled.(bool); ok {
if !enabledBool {
threatLog.Print("Threat detection disabled via enabled field")
// When explicitly disabled, return nil
return nil
}
Expand Down Expand Up @@ -66,36 +73,43 @@ func (c *Compiler) parseThreatDetectionConfig(outputMap map[string]any) *ThreatD
// Handle boolean false to disable AI engine
if engineBool, ok := engine.(bool); ok {
if !engineBool {
threatLog.Print("Threat detection AI engine disabled")
// engine: false means no AI engine steps
threatConfig.EngineConfig = nil
threatConfig.EngineDisabled = true
}
} else if engineStr, ok := engine.(string); ok {
threatLog.Printf("Threat detection engine set to: %s", engineStr)
// Handle string format
threatConfig.EngineConfig = &EngineConfig{ID: engineStr}
} else if engineObj, ok := engine.(map[string]any); ok {
threatLog.Print("Parsing threat detection engine configuration")
// Handle object format - use extractEngineConfig logic
_, engineConfig := c.ExtractEngineConfig(map[string]any{"engine": engineObj})
threatConfig.EngineConfig = engineConfig
}
}

threatLog.Printf("Threat detection configured with custom prompt: %v, custom steps: %v", threatConfig.Prompt != "", len(threatConfig.Steps) > 0)
return threatConfig
}
}

// Default behavior: enabled if any safe-outputs are configured
threatLog.Print("Using default threat detection configuration")
return &ThreatDetectionConfig{}
}

// buildThreatDetectionJob creates the detection job
func (c *Compiler) buildThreatDetectionJob(data *WorkflowData, mainJobName string) (*Job, error) {
threatLog.Printf("Building threat detection job for main job: %s", mainJobName)
if data.SafeOutputs == nil || data.SafeOutputs.ThreatDetection == nil {
return nil, fmt.Errorf("threat detection is not enabled")
}

// Build steps using a more structured approach
steps := c.buildThreatDetectionSteps(data, mainJobName)
threatLog.Printf("Generated %d steps for threat detection job", len(steps))

// Generate agent concurrency configuration (same as main agent job)
agentConcurrency := GenerateJobConcurrencyConfig(data)
Expand Down
Loading
Loading