Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/cli/mcp_inspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func InspectWorkflowMCP(workflowFile string, serverFilter string, toolFilter str

// Process imports from frontmatter to merge imported MCP servers
markdownDir := filepath.Dir(workflowPath)
importsResult, err := parser.ProcessImportsFromFrontmatterWithManifest(workflowData.Frontmatter, markdownDir)
importsResult, err := parser.ProcessImportsFromFrontmatterWithManifest(workflowData.Frontmatter, markdownDir, nil)
if err != nil {
return fmt.Errorf("failed to process imports from frontmatter: %w", err)
}
Expand Down Expand Up @@ -295,7 +295,7 @@ func spawnMCPInspector(workflowFile string, serverFilter string, verbose bool) e

// Process imports from frontmatter to merge imported MCP servers
markdownDir := filepath.Dir(workflowPath)
importsResult, err := parser.ProcessImportsFromFrontmatterWithManifest(workflowData.Frontmatter, markdownDir)
importsResult, err := parser.ProcessImportsFromFrontmatterWithManifest(workflowData.Frontmatter, markdownDir, nil)
if err != nil {
return fmt.Errorf("failed to process imports from frontmatter: %w", err)
}
Expand Down
104 changes: 92 additions & 12 deletions pkg/parser/frontmatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ func ExtractMarkdown(filePath string) (string, error) {
// ProcessImportsFromFrontmatter processes imports field from frontmatter
// Returns merged tools and engines from imported files
func ProcessImportsFromFrontmatter(frontmatter map[string]any, baseDir string) (mergedTools string, mergedEngines []string, err error) {
result, err := ProcessImportsFromFrontmatterWithManifest(frontmatter, baseDir)
result, err := ProcessImportsFromFrontmatterWithManifest(frontmatter, baseDir, nil)
if err != nil {
return "", nil, err
}
Expand All @@ -389,7 +389,7 @@ type importQueueItem struct {
// ProcessImportsFromFrontmatterWithManifest processes imports field from frontmatter
// Returns result containing merged tools, engines, markdown content, and list of imported files
// Uses BFS traversal with queues for deterministic ordering and cycle detection
func ProcessImportsFromFrontmatterWithManifest(frontmatter map[string]any, baseDir string) (*ImportsResult, error) {
func ProcessImportsFromFrontmatterWithManifest(frontmatter map[string]any, baseDir string, cache *ImportCache) (*ImportsResult, error) {
// Check if imports field exists
importsField, exists := frontmatter["imports"]
if !exists {
Expand Down Expand Up @@ -451,7 +451,7 @@ func ProcessImportsFromFrontmatterWithManifest(frontmatter map[string]any, baseD
}

// Resolve import path (supports workflowspec format)
fullPath, err := resolveIncludePath(filePath, baseDir)
fullPath, err := resolveIncludePath(filePath, baseDir, cache)
if err != nil {
return nil, fmt.Errorf("failed to resolve import '%s': %w", filePath, err)
}
Expand Down Expand Up @@ -558,7 +558,7 @@ func ProcessImportsFromFrontmatterWithManifest(frontmatter map[string]any, baseD
}

// Resolve nested import path relative to the workflows directory, not the nested file's directory
nestedFullPath, err := resolveIncludePath(nestedFilePath, baseDir)
nestedFullPath, err := resolveIncludePath(nestedFilePath, baseDir, cache)
if err != nil {
return nil, fmt.Errorf("failed to resolve nested import '%s' from '%s': %w", nestedFilePath, item.fullPath, err)
}
Expand Down Expand Up @@ -724,7 +724,7 @@ func processIncludesWithVisited(content, baseDir string, extractTools bool, visi
}

// Resolve file path first to get the canonical path
fullPath, err := resolveIncludePath(filePath, baseDir)
fullPath, err := resolveIncludePath(filePath, baseDir, nil)
if err != nil {
if isOptional {
// For optional includes, show a friendly informational message to stdout
Expand Down Expand Up @@ -796,12 +796,12 @@ func isUnderWorkflowsDirectory(filePath string) bool {
}

// resolveIncludePath resolves include path based on workflowspec format or relative path
func resolveIncludePath(filePath, baseDir string) (string, error) {
func resolveIncludePath(filePath, baseDir string, cache *ImportCache) (string, error) {
// Check if this is a workflowspec (contains owner/repo/path format)
// Format: owner/repo/path@ref or owner/repo/path@ref#section
if isWorkflowSpec(filePath) {
// Download from GitHub using workflowspec
return downloadIncludeFromWorkflowSpec(filePath)
// Download from GitHub using workflowspec (with cache support)
return downloadIncludeFromWorkflowSpec(filePath, cache)
}

// Regular path, resolve relative to base directory
Expand Down Expand Up @@ -850,7 +850,8 @@ func isWorkflowSpec(path string) bool {
}

// downloadIncludeFromWorkflowSpec downloads an include file from GitHub using workflowspec
func downloadIncludeFromWorkflowSpec(spec string) (string, error) {
// It first checks the cache, and only downloads if not cached
func downloadIncludeFromWorkflowSpec(spec string, cache *ImportCache) (string, error) {
// Parse the workflowspec
// Format: owner/repo/path@ref or owner/repo/path@ref#section

Expand Down Expand Up @@ -880,13 +881,47 @@ func downloadIncludeFromWorkflowSpec(spec string) (string, error) {
repo := slashParts[1]
filePath := strings.Join(slashParts[2:], "/")

// Resolve ref to SHA for cache lookup
var sha string
if cache != nil {
// Only resolve SHA if we're using the cache
resolvedSHA, err := resolveRefToSHA(owner, repo, ref)
if err != nil {
Comment thread
pelikhan marked this conversation as resolved.
// If the error is an authentication error, propagate it immediately
lowerErr := strings.ToLower(err.Error())
if strings.Contains(lowerErr, "auth") || strings.Contains(lowerErr, "unauthoriz") || strings.Contains(lowerErr, "forbidden") || strings.Contains(lowerErr, "token") || strings.Contains(lowerErr, "permission denied") {
return "", fmt.Errorf("failed to resolve ref to SHA due to authentication error: %w", err)
}
log.Printf("Failed to resolve ref to SHA, will skip cache: %v", err)
// Continue without caching if SHA resolution fails
} else {
sha = resolvedSHA
// Check cache using SHA
if cachedPath, found := cache.Get(owner, repo, filePath, sha); found {
log.Printf("Using cached import: %s/%s/%s@%s (SHA: %s)", owner, repo, filePath, ref, sha)
return cachedPath, nil
}
}
}

// Download the file content from GitHub
content, err := downloadFileFromGitHub(owner, repo, filePath, ref)
if err != nil {
return "", fmt.Errorf("failed to download include from %s: %w", spec, err)
}

// Create a temporary file to store the downloaded content
// If cache is available and we have a SHA, store in cache
if cache != nil && sha != "" {
cachedPath, err := cache.Set(owner, repo, filePath, sha, content)
if err != nil {
log.Printf("Failed to cache import: %v", err)
// Don't fail the compilation, fall back to temp file
} else {
return cachedPath, nil
}
}

// Fallback: Create a temporary file to store the downloaded content
tempFile, err := os.CreateTemp("", "gh-aw-include-*.md")
if err != nil {
return "", fmt.Errorf("failed to create temp file: %w", err)
Expand All @@ -906,7 +941,52 @@ func downloadIncludeFromWorkflowSpec(spec string) (string, error) {
return tempFile.Name(), nil
}

// downloadFileFromGitHub downloads a file from GitHub using gh CLI
// resolveRefToSHA resolves a git ref (branch, tag, or short SHA) in
// owner/repo to its full 40-character commit SHA, using the GitHub API
// via the gh CLI. A ref that is already a full SHA is returned as-is
// without a network round trip.
func resolveRefToSHA(owner, repo, ref string) (string, error) {
	// A full 40-character hex SHA needs no resolution.
	if len(ref) == 40 && isHexString(ref) {
		return ref, nil
	}

	// Ask the GitHub API for the commit the ref points at; this handles
	// branches, tags, and short SHAs alike.
	endpoint := fmt.Sprintf("/repos/%s/%s/commits/%s", owner, repo, ref)
	out, err := exec.Command("gh", "api", endpoint, "--jq", ".sha").CombinedOutput()
	if err != nil {
		msg := string(out)
		// Surface authentication problems with an actionable hint.
		if strings.Contains(msg, "GH_TOKEN") || strings.Contains(msg, "authentication") || strings.Contains(msg, "not logged into") {
			return "", fmt.Errorf("failed to resolve ref to SHA: GitHub authentication required. Please run 'gh auth login' or set GH_TOKEN/GITHUB_TOKEN environment variable: %w", err)
		}
		return "", fmt.Errorf("failed to resolve ref %s to SHA for %s/%s: %s: %w", ref, owner, repo, strings.TrimSpace(msg), err)
	}

	sha := strings.TrimSpace(string(out))
	if sha == "" {
		return "", fmt.Errorf("empty SHA returned for ref %s in %s/%s", ref, owner, repo)
	}
	// Sanity-check the API response shape before trusting it as a SHA.
	if len(sha) != 40 || !isHexString(sha) {
		return "", fmt.Errorf("invalid SHA format returned: %s", sha)
	}
	return sha, nil
}

// isHexString reports whether s is non-empty and consists solely of
// hexadecimal digits (0-9, a-f, A-F).
func isHexString(s string) bool {
	if s == "" {
		return false
	}
	for _, r := range s {
		switch {
		case r >= '0' && r <= '9':
		case r >= 'a' && r <= 'f':
		case r >= 'A' && r <= 'F':
		default:
			return false
		}
	}
	return true
}
Comment thread
pelikhan marked this conversation as resolved.

func downloadFileFromGitHub(owner, repo, path, ref string) ([]byte, error) {
// Use go-gh/v2 to download the file
stdout, stderr, err := gh.Exec("api", fmt.Sprintf("/repos/%s/%s/contents/%s?ref=%s", owner, repo, path, ref), "--jq", ".content")
Expand Down Expand Up @@ -1321,7 +1401,7 @@ func processIncludesForField(content, baseDir string, extractFunc func(string) (
}

// Resolve file path
fullPath, err := resolveIncludePath(filePath, baseDir)
fullPath, err := resolveIncludePath(filePath, baseDir, nil)
if err != nil {
if isOptional {
// For optional includes, skip extraction
Expand Down
2 changes: 1 addition & 1 deletion pkg/parser/frontmatter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ func TestResolveIncludePath(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := resolveIncludePath(tt.filePath, tt.baseDir)
result, err := resolveIncludePath(tt.filePath, tt.baseDir, nil)

if tt.wantErr {
if err == nil {
Expand Down
165 changes: 165 additions & 0 deletions pkg/parser/import_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package parser

import (
"fmt"
"os"
"path/filepath"
"strings"

"github.com/githubnext/gh-aw/pkg/logger"
)

var importCacheLog = logger.New("parser:import_cache")

const (
// ImportCacheDir is the directory where cached imports are stored
ImportCacheDir = ".github/aw/imports"
)

// sanitizePath turns a repository-relative file path into a flat, safe
// filename. The path is first normalized with filepath.Clean (stripping
// any ".." or redundant elements), then every directory separator is
// replaced with an underscore, so the result cannot escape the cache
// directory while remaining unique per source path.
func sanitizePath(path string) string {
	flat := strings.ReplaceAll(filepath.Clean(path), string(filepath.Separator), "_")
	return flat
}

// validatePathComponents rejects owner/repo/path/sha values that could be
// abused for path traversal or otherwise yield an unsafe cache location.
// Each component must be non-empty, must not contain a ".." sequence, and
// must not be an absolute path.
func validatePathComponents(owner, repo, path, sha string) error {
	for _, part := range []string{owner, repo, path, sha} {
		switch {
		case part == "":
			return fmt.Errorf("empty component in path")
		case strings.Contains(part, ".."):
			return fmt.Errorf("component contains '..' sequence: %s", part)
		case filepath.IsAbs(part):
			return fmt.Errorf("component is absolute path: %s", part)
		}
	}
	return nil
}

// ImportCache stores downloaded import files on disk (under the repo's
// .github/aw/imports directory) so repeated compilations can reuse them
// instead of re-fetching from GitHub.
type ImportCache struct {
	baseDir string // root under which ImportCacheDir is created (typically the repo root)
}

// NewImportCache returns an ImportCache rooted at repoRoot.
func NewImportCache(repoRoot string) *ImportCache {
	importCacheLog.Printf("Creating import cache with base dir: %s", repoRoot)
	cache := &ImportCache{baseDir: repoRoot}
	return cache
}

// Get retrieves a cached file path if it exists
// sha parameter should be the resolved commit SHA
func (c *ImportCache) Get(owner, repo, path, sha string) (string, bool) {
// Use SHA-based approach: cache files are stored by commit SHA
// Cache path: .github/aw/imports/owner/repo/sha/sanitized_path.md
sanitizedPath := sanitizePath(path)
relativeCachePath := filepath.Join(ImportCacheDir, owner, repo, sha, sanitizedPath)
fullCachePath := filepath.Join(c.baseDir, relativeCachePath)

// Check if the cached file exists
if _, err := os.Stat(fullCachePath); err != nil {
if os.IsNotExist(err) {
importCacheLog.Printf("Cache miss: %s/%s/%s@%s", owner, repo, path, sha)
} else {
// Log other types of errors (permissions, I/O issues, etc.)
importCacheLog.Printf("Cache access error for %s/%s/%s@%s: %v", owner, repo, path, sha, err)
}
return "", false
}
Comment on lines +72 to +81
Copy link

Copilot AI Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Get method doesn't handle other types of filesystem errors beyond os.IsNotExist. If the file exists but is inaccessible due to permissions or other I/O errors, the method will still return false (cache miss) rather than indicating an error occurred.

Consider differentiating between:

  • File doesn't exist (cache miss, expected behavior)
  • File exists but can't be accessed (error condition)

This would help with debugging cache-related issues.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot add logging about error type

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! Added detailed error logging in Get() method. Now distinguishes between cache miss (file doesn't exist) and other errors like permission issues or I/O errors. Commit: [commit hash in progress report]


importCacheLog.Printf("Cache hit: %s/%s/%s@%s -> %s", owner, repo, path, sha, fullCachePath)
return fullCachePath, true
}

// Set stores a new cache entry by saving the content to the cache directory
// sha parameter should be the resolved commit SHA
func (c *ImportCache) Set(owner, repo, path, sha string, content []byte) (string, error) {
// Validate file size (max 10MB)
const maxFileSize = 10 * 1024 * 1024
if len(content) > maxFileSize {
return "", fmt.Errorf("file size (%d bytes) exceeds maximum allowed size (%d bytes)", len(content), maxFileSize)
}

// Validate path components to prevent path traversal
if err := validatePathComponents(owner, repo, path, sha); err != nil {
return "", fmt.Errorf("invalid path components: %w", err)
}

// Use SHA in path for consistent caching
// This ensures that different refs pointing to the same commit reuse the same cache
sanitizedPath := sanitizePath(path)
relativeCachePath := filepath.Join(ImportCacheDir, owner, repo, sha, sanitizedPath)
fullCachePath := filepath.Join(c.baseDir, relativeCachePath)

// Ensure directory exists
dir := filepath.Dir(fullCachePath)
if err := os.MkdirAll(dir, 0755); err != nil {
importCacheLog.Printf("Failed to create cache directory: %v", err)
return "", err
}

// Ensure .gitattributes file exists in cache root
if err := c.ensureGitAttributes(); err != nil {
importCacheLog.Printf("Failed to ensure .gitattributes: %v", err)
// Non-fatal error - continue with caching
}

// Write content to cache file
if err := os.WriteFile(fullCachePath, content, 0644); err != nil {
importCacheLog.Printf("Failed to write cache file: %v", err)
Comment on lines +87 to +122
Copy link

Copilot AI Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR description and diff mention creating a .gitattributes file in the cache directory to mark cached files as generated and use 'ours' merge strategy. However, the actual implementation in import_cache.go (80 lines) doesn't include the ensureGitAttributes() function shown in the diff (lines 66-70 and 87-117).

The test file import_cache_test.go also has test cases checking for .gitattributes in the diff (lines 95-113), but the actual file doesn't include these tests.

This creates a discrepancy between what the PR claims to implement and what's actually in the code. The .gitattributes functionality needs to be either:

  1. Implemented as described in the PR
  2. Removed from the PR description if it's not intended to be part of this change

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The .gitattributes functionality is fully implemented. The ensureGitAttributes() method automatically creates the file when the cache directory is first used. See commit b89d7b1 for the initial implementation and the latest commit for security enhancements.

Comment on lines +87 to +122
Copy link

Copilot AI Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cache stores files from remote GitHub repositories without validating their content or size. This could lead to several issues:

  1. Disk space exhaustion: A malicious or very large file could fill up the disk
  2. Path traversal: While the path is sanitized, there's no validation that the owner/repo/sha/path components don't contain malicious sequences
  3. No integrity checking: The cached file could be modified outside the application, and there's no verification that it matches the expected SHA

Consider adding:

  • Maximum file size validation before caching
  • Path component validation (e.g., no .., no absolute paths)
  • Content verification (checksum or signature) when retrieving from cache

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback use file.Clean()

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! Added comprehensive validation:

  • File size limit (10MB max) to prevent disk exhaustion
  • Path component validation using validatePathComponents() to prevent ".." sequences and absolute paths
  • Path sanitization using filepath.Clean() for security

Commit: [commit hash in progress report]

return "", err
}

importCacheLog.Printf("Cached import: %s/%s/%s@%s -> %s", owner, repo, path, sha, fullCachePath)
return fullCachePath, nil
}

// GetCacheDir reports the absolute cache root (<baseDir>/.github/aw/imports).
func (c *ImportCache) GetCacheDir() string {
	dir := filepath.Join(c.baseDir, ImportCacheDir)
	return dir
}

// ensureGitAttributes lazily writes a .gitattributes file at the cache
// root so cached imports are marked linguist-generated and merged with
// the "ours" strategy. It is a no-op when the file already exists.
func (c *ImportCache) ensureGitAttributes() error {
	root := c.GetCacheDir()
	attrsFile := filepath.Join(root, ".gitattributes")

	// Already present: nothing to do.
	if _, err := os.Stat(attrsFile); err == nil {
		return nil
	}

	// The cache root may not exist yet on first use.
	if err := os.MkdirAll(root, 0755); err != nil {
		return err
	}

	content := `# Mark all cached import files as generated
* linguist-generated=true

# Use 'ours' merge strategy to keep local cached versions
* merge=ours
`
	if err := os.WriteFile(attrsFile, []byte(content), 0644); err != nil {
		return err
	}

	importCacheLog.Printf("Created .gitattributes in cache directory: %s", attrsFile)
	return nil
}
Loading
Loading