diff --git a/pkg/cli/mcp_inspect.go b/pkg/cli/mcp_inspect.go index e1cb6ee9fc2..7928aeafbf6 100644 --- a/pkg/cli/mcp_inspect.go +++ b/pkg/cli/mcp_inspect.go @@ -151,7 +151,7 @@ func InspectWorkflowMCP(workflowFile string, serverFilter string, toolFilter str // Process imports from frontmatter to merge imported MCP servers markdownDir := filepath.Dir(workflowPath) - importsResult, err := parser.ProcessImportsFromFrontmatterWithManifest(workflowData.Frontmatter, markdownDir) + importsResult, err := parser.ProcessImportsFromFrontmatterWithManifest(workflowData.Frontmatter, markdownDir, nil) if err != nil { return fmt.Errorf("failed to process imports from frontmatter: %w", err) } @@ -295,7 +295,7 @@ func spawnMCPInspector(workflowFile string, serverFilter string, verbose bool) e // Process imports from frontmatter to merge imported MCP servers markdownDir := filepath.Dir(workflowPath) - importsResult, err := parser.ProcessImportsFromFrontmatterWithManifest(workflowData.Frontmatter, markdownDir) + importsResult, err := parser.ProcessImportsFromFrontmatterWithManifest(workflowData.Frontmatter, markdownDir, nil) if err != nil { return fmt.Errorf("failed to process imports from frontmatter: %w", err) } diff --git a/pkg/parser/frontmatter.go b/pkg/parser/frontmatter.go index 09054f8e3ad..7ab07a9d25d 100644 --- a/pkg/parser/frontmatter.go +++ b/pkg/parser/frontmatter.go @@ -371,7 +371,7 @@ func ExtractMarkdown(filePath string) (string, error) { // ProcessImportsFromFrontmatter processes imports field from frontmatter // Returns merged tools and engines from imported files func ProcessImportsFromFrontmatter(frontmatter map[string]any, baseDir string) (mergedTools string, mergedEngines []string, err error) { - result, err := ProcessImportsFromFrontmatterWithManifest(frontmatter, baseDir) + result, err := ProcessImportsFromFrontmatterWithManifest(frontmatter, baseDir, nil) if err != nil { return "", nil, err } @@ -389,7 +389,7 @@ type importQueueItem struct { // 
ProcessImportsFromFrontmatterWithManifest processes imports field from frontmatter // Returns result containing merged tools, engines, markdown content, and list of imported files // Uses BFS traversal with queues for deterministic ordering and cycle detection -func ProcessImportsFromFrontmatterWithManifest(frontmatter map[string]any, baseDir string) (*ImportsResult, error) { +func ProcessImportsFromFrontmatterWithManifest(frontmatter map[string]any, baseDir string, cache *ImportCache) (*ImportsResult, error) { // Check if imports field exists importsField, exists := frontmatter["imports"] if !exists { @@ -451,7 +451,7 @@ func ProcessImportsFromFrontmatterWithManifest(frontmatter map[string]any, baseD } // Resolve import path (supports workflowspec format) - fullPath, err := resolveIncludePath(filePath, baseDir) + fullPath, err := resolveIncludePath(filePath, baseDir, cache) if err != nil { return nil, fmt.Errorf("failed to resolve import '%s': %w", filePath, err) } @@ -558,7 +558,7 @@ func ProcessImportsFromFrontmatterWithManifest(frontmatter map[string]any, baseD } // Resolve nested import path relative to the workflows directory, not the nested file's directory - nestedFullPath, err := resolveIncludePath(nestedFilePath, baseDir) + nestedFullPath, err := resolveIncludePath(nestedFilePath, baseDir, cache) if err != nil { return nil, fmt.Errorf("failed to resolve nested import '%s' from '%s': %w", nestedFilePath, item.fullPath, err) } @@ -724,7 +724,7 @@ func processIncludesWithVisited(content, baseDir string, extractTools bool, visi } // Resolve file path first to get the canonical path - fullPath, err := resolveIncludePath(filePath, baseDir) + fullPath, err := resolveIncludePath(filePath, baseDir, nil) if err != nil { if isOptional { // For optional includes, show a friendly informational message to stdout @@ -796,12 +796,12 @@ func isUnderWorkflowsDirectory(filePath string) bool { } // resolveIncludePath resolves include path based on workflowspec format or 
relative path -func resolveIncludePath(filePath, baseDir string) (string, error) { +func resolveIncludePath(filePath, baseDir string, cache *ImportCache) (string, error) { // Check if this is a workflowspec (contains owner/repo/path format) // Format: owner/repo/path@ref or owner/repo/path@ref#section if isWorkflowSpec(filePath) { - // Download from GitHub using workflowspec - return downloadIncludeFromWorkflowSpec(filePath) + // Download from GitHub using workflowspec (with cache support) + return downloadIncludeFromWorkflowSpec(filePath, cache) } // Regular path, resolve relative to base directory @@ -850,7 +850,8 @@ func isWorkflowSpec(path string) bool { } // downloadIncludeFromWorkflowSpec downloads an include file from GitHub using workflowspec -func downloadIncludeFromWorkflowSpec(spec string) (string, error) { +// It first checks the cache, and only downloads if not cached +func downloadIncludeFromWorkflowSpec(spec string, cache *ImportCache) (string, error) { // Parse the workflowspec // Format: owner/repo/path@ref or owner/repo/path@ref#section @@ -880,13 +881,47 @@ func downloadIncludeFromWorkflowSpec(spec string) (string, error) { repo := slashParts[1] filePath := strings.Join(slashParts[2:], "/") + // Resolve ref to SHA for cache lookup + var sha string + if cache != nil { + // Only resolve SHA if we're using the cache + resolvedSHA, err := resolveRefToSHA(owner, repo, ref) + if err != nil { + // If the error is an authentication error, propagate it immediately + lowerErr := strings.ToLower(err.Error()) + if strings.Contains(lowerErr, "auth") || strings.Contains(lowerErr, "unauthoriz") || strings.Contains(lowerErr, "forbidden") || strings.Contains(lowerErr, "token") || strings.Contains(lowerErr, "permission denied") { + return "", fmt.Errorf("failed to resolve ref to SHA due to authentication error: %w", err) + } + log.Printf("Failed to resolve ref to SHA, will skip cache: %v", err) + // Continue without caching if SHA resolution fails + } else { + sha 
// resolveRefToSHA resolves a git ref (branch, tag, or abbreviated SHA) of
// owner/repo to its full 40-character commit SHA.
//
// A ref that already is a full SHA is returned unchanged without running gh.
// Any other ref is looked up with `gh api /repos/{owner}/{repo}/commits/{ref}`.
//
// An error is returned when ref is empty, when the gh invocation fails
// (authentication failures get a dedicated hint), or when gh's answer is not a
// 40-character hexadecimal string.
func resolveRefToSHA(owner, repo, ref string) (string, error) {
	// An empty ref cannot name a commit; fail before building a request URL
	// that would point at the commit list instead of a single commit.
	if ref == "" {
		return "", fmt.Errorf("empty ref for %s/%s", owner, repo)
	}

	// If ref is already a full SHA (40 hex characters), return it as-is
	if len(ref) == 40 && isHexString(ref) {
		return ref, nil
	}

	// Use gh CLI to get the commit SHA for the ref
	// This works for branches, tags, and short SHAs
	cmd := exec.Command("gh", "api", fmt.Sprintf("/repos/%s/%s/commits/%s", owner, repo, ref), "--jq", ".sha")

	// Read stdout only: anything gh prints on stderr during a successful call
	// must not end up inside the value that is validated as a SHA below.
	stdout, err := cmd.Output()
	if err != nil {
		// Output stores the captured stderr on *exec.ExitError; append it so
		// the diagnostics below still see everything gh reported.
		outputStr := string(stdout)
		if exitErr, ok := err.(*exec.ExitError); ok {
			outputStr += string(exitErr.Stderr)
		}
		if strings.Contains(outputStr, "GH_TOKEN") || strings.Contains(outputStr, "authentication") || strings.Contains(outputStr, "not logged into") {
			return "", fmt.Errorf("failed to resolve ref to SHA: GitHub authentication required. Please run 'gh auth login' or set GH_TOKEN/GITHUB_TOKEN environment variable: %w", err)
		}
		return "", fmt.Errorf("failed to resolve ref %s to SHA for %s/%s: %s: %w", ref, owner, repo, strings.TrimSpace(outputStr), err)
	}

	sha := strings.TrimSpace(string(stdout))
	if sha == "" {
		return "", fmt.Errorf("empty SHA returned for ref %s in %s/%s", ref, owner, repo)
	}

	// Validate it's a valid SHA (40 hex characters)
	if len(sha) != 40 || !isHexString(sha) {
		return "", fmt.Errorf("invalid SHA format returned: %s", sha)
	}

	return sha, nil
}

// isHexString reports whether s is non-empty and consists solely of
// hexadecimal digits (0-9, a-f, A-F).
func isHexString(s string) bool {
	if len(s) == 0 {
		return false
	}
	for _, c := range s {
		switch {
		case c >= '0' && c <= '9':
		case c >= 'a' && c <= 'f':
		case c >= 'A' && c <= 'F':
		default:
			return false
		}
	}
	return true
}
const (
	// ImportCacheDir is the directory where cached imports are stored
	ImportCacheDir = ".github/aw/imports"
)

// sanitizePath flattens a file path into a single file name: the path is
// normalized with filepath.Clean and every directory separator is replaced by
// an underscore, so the result never contains a separator.
//
// The mapping is not injective: two different inputs can flatten to the same
// name (for example "a/b.md" and "a_b.md").
func sanitizePath(path string) string {
	cleaned := filepath.Clean(path)
	return strings.ReplaceAll(cleaned, string(filepath.Separator), "_")
}

// validatePathComponents checks the parts that make up a cache location and
// returns an error describing the first offending one, or nil when all are
// acceptable.
//
// Every component must be non-empty, must not contain "..", and must not be an
// absolute path. In addition owner, repo and sha must not contain a path
// separator, because each of them is used as exactly one directory level;
// path may contain separators since it is flattened by sanitizePath.
func validatePathComponents(owner, repo, path, sha string) error {
	for _, comp := range []string{owner, repo, path, sha} {
		switch {
		case comp == "":
			return fmt.Errorf("empty component in path")
		case strings.Contains(comp, ".."):
			// Path traversal attempt
			return fmt.Errorf("component contains '..' sequence: %s", comp)
		case filepath.IsAbs(comp):
			return fmt.Errorf("component is absolute path: %s", comp)
		}
	}

	// Checked after the loop so the errors above keep their precedence.
	// Both separator styles are rejected regardless of the host OS.
	for _, name := range []string{owner, repo, sha} {
		if strings.ContainsAny(name, `/\`) {
			return fmt.Errorf("component contains path separator: %s", name)
		}
	}
	return nil
}
sequence: %s", comp) + } + // Check for absolute paths + if filepath.IsAbs(comp) { + return fmt.Errorf("component is absolute path: %s", comp) + } + } + return nil +} + +// ImportCache manages cached imported workflow files +type ImportCache struct { + baseDir string // Base directory for cache (typically repo root) +} + +// NewImportCache creates a new import cache instance +func NewImportCache(repoRoot string) *ImportCache { + importCacheLog.Printf("Creating import cache with base dir: %s", repoRoot) + return &ImportCache{ + baseDir: repoRoot, + } +} + +// Get retrieves a cached file path if it exists +// sha parameter should be the resolved commit SHA +func (c *ImportCache) Get(owner, repo, path, sha string) (string, bool) { + // Use SHA-based approach: cache files are stored by commit SHA + // Cache path: .github/aw/imports/owner/repo/sha/sanitized_path.md + sanitizedPath := sanitizePath(path) + relativeCachePath := filepath.Join(ImportCacheDir, owner, repo, sha, sanitizedPath) + fullCachePath := filepath.Join(c.baseDir, relativeCachePath) + + // Check if the cached file exists + if _, err := os.Stat(fullCachePath); err != nil { + if os.IsNotExist(err) { + importCacheLog.Printf("Cache miss: %s/%s/%s@%s", owner, repo, path, sha) + } else { + // Log other types of errors (permissions, I/O issues, etc.) 
+ importCacheLog.Printf("Cache access error for %s/%s/%s@%s: %v", owner, repo, path, sha, err) + } + return "", false + } + + importCacheLog.Printf("Cache hit: %s/%s/%s@%s -> %s", owner, repo, path, sha, fullCachePath) + return fullCachePath, true +} + +// Set stores a new cache entry by saving the content to the cache directory +// sha parameter should be the resolved commit SHA +func (c *ImportCache) Set(owner, repo, path, sha string, content []byte) (string, error) { + // Validate file size (max 10MB) + const maxFileSize = 10 * 1024 * 1024 + if len(content) > maxFileSize { + return "", fmt.Errorf("file size (%d bytes) exceeds maximum allowed size (%d bytes)", len(content), maxFileSize) + } + + // Validate path components to prevent path traversal + if err := validatePathComponents(owner, repo, path, sha); err != nil { + return "", fmt.Errorf("invalid path components: %w", err) + } + + // Use SHA in path for consistent caching + // This ensures that different refs pointing to the same commit reuse the same cache + sanitizedPath := sanitizePath(path) + relativeCachePath := filepath.Join(ImportCacheDir, owner, repo, sha, sanitizedPath) + fullCachePath := filepath.Join(c.baseDir, relativeCachePath) + + // Ensure directory exists + dir := filepath.Dir(fullCachePath) + if err := os.MkdirAll(dir, 0755); err != nil { + importCacheLog.Printf("Failed to create cache directory: %v", err) + return "", err + } + + // Ensure .gitattributes file exists in cache root + if err := c.ensureGitAttributes(); err != nil { + importCacheLog.Printf("Failed to ensure .gitattributes: %v", err) + // Non-fatal error - continue with caching + } + + // Write content to cache file + if err := os.WriteFile(fullCachePath, content, 0644); err != nil { + importCacheLog.Printf("Failed to write cache file: %v", err) + return "", err + } + + importCacheLog.Printf("Cached import: %s/%s/%s@%s -> %s", owner, repo, path, sha, fullCachePath) + return fullCachePath, nil +} + +// GetCacheDir returns the 
base cache directory path +func (c *ImportCache) GetCacheDir() string { + return filepath.Join(c.baseDir, ImportCacheDir) +} + +// ensureGitAttributes creates the .gitattributes file in the cache directory if it doesn't exist +func (c *ImportCache) ensureGitAttributes() error { + gitAttributesPath := filepath.Join(c.GetCacheDir(), ".gitattributes") + + // Check if .gitattributes already exists + if _, err := os.Stat(gitAttributesPath); err == nil { + // File already exists, nothing to do + return nil + } + + // Ensure cache root directory exists + cacheDir := c.GetCacheDir() + if err := os.MkdirAll(cacheDir, 0755); err != nil { + return err + } + + // Create .gitattributes file with content + content := `# Mark all cached import files as generated +* linguist-generated=true + +# Use 'ours' merge strategy to keep local cached versions +* merge=ours +` + + if err := os.WriteFile(gitAttributesPath, []byte(content), 0644); err != nil { + return err + } + + importCacheLog.Printf("Created .gitattributes in cache directory: %s", gitAttributesPath) + return nil +} diff --git a/pkg/parser/import_cache_integration_test.go b/pkg/parser/import_cache_integration_test.go new file mode 100644 index 00000000000..d163cd77849 --- /dev/null +++ b/pkg/parser/import_cache_integration_test.go @@ -0,0 +1,158 @@ +package parser + +import ( + "os" + "path/filepath" + "testing" +) + +// TestImportCacheIntegration tests the cache with the full import flow +func TestImportCacheIntegration(t *testing.T) { + // Create temp directories for testing + tempDir, err := os.MkdirTemp("", "import-cache-integration-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Create a cache + cache := NewImportCache(tempDir) + + // Simulate a workflow file that imports from another repo + workflowContent := `--- +imports: + - testowner/testrepo/workflows/shared.md@main +--- + +# Test Workflow + +Use shared configuration. 
+` + + workflowPath := filepath.Join(tempDir, "test-workflow.md") + if err := os.WriteFile(workflowPath, []byte(workflowContent), 0644); err != nil { + t.Fatalf("Failed to write workflow file: %v", err) + } + + // Simulate a remote file being cached + sharedContent := []byte(`--- +tools: + edit: +--- + +# Shared Configuration + +This is shared configuration. +`) + + // Cache the "remote" file + sha := "abc123" + cachedPath, err := cache.Set("testowner", "testrepo", "workflows/shared.md", sha, sharedContent) + if err != nil { + t.Fatalf("Failed to cache file: %v", err) + } + + // Verify cache can retrieve the file + retrievedPath, found := cache.Get("testowner", "testrepo", "workflows/shared.md", sha) + if !found { + t.Error("Failed to retrieve cached file") + } + if retrievedPath != cachedPath { + t.Errorf("Retrieved path mismatch. Expected %s, got %s", cachedPath, retrievedPath) + } + + // Verify the cached file contains correct content + content, err := os.ReadFile(retrievedPath) + if err != nil { + t.Fatalf("Failed to read cached file: %v", err) + } + if string(content) != string(sharedContent) { + t.Errorf("Content mismatch. Expected %q, got %q", sharedContent, content) + } + + // Test new cache instance can find the file (simulating offline scenario) + cache2 := NewImportCache(tempDir) + + // Verify we can still retrieve the file using filesystem lookup + retrievedPath2, found := cache2.Get("testowner", "testrepo", "workflows/shared.md", sha) + if !found { + t.Error("Failed to retrieve cached file from new cache instance") + } + if retrievedPath2 != cachedPath { + t.Errorf("Retrieved path mismatch from new instance. 
Expected %s, got %s", cachedPath, retrievedPath2) + } +} + +// TestImportCacheMultipleFiles tests caching multiple files from different repos +func TestImportCacheMultipleFiles(t *testing.T) { + tempDir, err := os.MkdirTemp("", "import-cache-multi-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + cache := NewImportCache(tempDir) + + // Cache multiple files + files := []struct { + owner string + repo string + path string + ref string + sha string + content string + }{ + {"owner1", "repo1", "workflows/a.md", "main", "sha1", "Content A"}, + {"owner1", "repo1", "workflows/b.md", "v1.0", "sha2", "Content B"}, + {"owner2", "repo2", "config/c.md", "main", "sha3", "Content C"}, + } + + for _, f := range files { + _, err := cache.Set(f.owner, f.repo, f.path, f.sha, []byte(f.content)) + if err != nil { + t.Fatalf("Failed to cache file %s/%s/%s@%s: %v", f.owner, f.repo, f.path, f.sha, err) + } + } + + // Verify all files are retrievable + for _, f := range files { + path, found := cache.Get(f.owner, f.repo, f.path, f.sha) + if !found { + t.Errorf("Failed to retrieve cached file %s/%s/%s@%s", f.owner, f.repo, f.path, f.sha) + continue + } + + content, err := os.ReadFile(path) + if err != nil { + t.Errorf("Failed to read cached file: %v", err) + continue + } + + if string(content) != f.content { + t.Errorf("Content mismatch for %s/%s/%s@%s. 
Expected %q, got %q", + f.owner, f.repo, f.path, f.sha, f.content, string(content)) + } + } + + // Verify from new cache instance using filesystem lookup + cache2 := NewImportCache(tempDir) + + for _, f := range files { + path, found := cache2.Get(f.owner, f.repo, f.path, f.sha) + if !found { + t.Errorf("Failed to retrieve cached file from new instance %s/%s/%s@%s", f.owner, f.repo, f.path, f.sha) + continue + } + + content, err := os.ReadFile(path) + if err != nil { + t.Errorf("Failed to read cached file: %v", err) + continue + } + + if string(content) != f.content { + t.Errorf("Content mismatch from new instance for %s/%s/%s@%s. Expected %q, got %q", + f.owner, f.repo, f.path, f.sha, f.content, string(content)) + } + } +} diff --git a/pkg/parser/import_cache_test.go b/pkg/parser/import_cache_test.go new file mode 100644 index 00000000000..7f625578866 --- /dev/null +++ b/pkg/parser/import_cache_test.go @@ -0,0 +1,159 @@ +package parser + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestImportCache(t *testing.T) { + // Create temp directory for testing + tempDir, err := os.MkdirTemp("", "import-cache-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Create a new cache + cache := NewImportCache(tempDir) + + // Test Set and Get + testContent := []byte("# Test Workflow\n\nTest content") + owner := "testowner" + repo := "testrepo" + path := "workflows/test.md" + sha := "abc123" + + cachedPath, err := cache.Set(owner, repo, path, sha, testContent) + if err != nil { + t.Fatalf("Failed to set cache entry: %v", err) + } + + // Verify file was created + if _, err := os.Stat(cachedPath); os.IsNotExist(err) { + t.Errorf("Cache file was not created: %s", cachedPath) + } + + // Verify content + content, err := os.ReadFile(cachedPath) + if err != nil { + t.Fatalf("Failed to read cached file: %v", err) + } + if string(content) != string(testContent) { + t.Errorf("Content mismatch. 
Expected %q, got %q", testContent, content) + } + + // Test Get + retrievedPath, found := cache.Get(owner, repo, path, sha) + if !found { + t.Error("Cache entry not found after Set") + } + if retrievedPath != cachedPath { + t.Errorf("Path mismatch. Expected %s, got %s", cachedPath, retrievedPath) + } + + // Test that a new cache instance can find the file + cache2 := NewImportCache(tempDir) + retrievedPath2, found := cache2.Get(owner, repo, path, sha) + if !found { + t.Error("Cache entry not found from new cache instance") + } + if retrievedPath2 != cachedPath { + t.Errorf("Path mismatch from new instance. Expected %s, got %s", cachedPath, retrievedPath2) + } +} + +func TestImportCacheDirectory(t *testing.T) { + // Create temp directory for testing + tempDir, err := os.MkdirTemp("", "import-cache-dir-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + cache := NewImportCache(tempDir) + + // Test cache directory path + expectedDir := filepath.Join(tempDir, ImportCacheDir) + if cache.GetCacheDir() != expectedDir { + t.Errorf("Cache dir mismatch. 
Expected %s, got %s", expectedDir, cache.GetCacheDir()) + } + + // Create a cache entry to trigger directory creation + testContent := []byte("test") + _, err = cache.Set("owner", "repo", "test.md", "sha1", testContent) + if err != nil { + t.Fatalf("Failed to set cache entry: %v", err) + } + + // Verify directory was created + if _, err := os.Stat(expectedDir); os.IsNotExist(err) { + t.Errorf("Cache directory was not created: %s", expectedDir) + } + + // Verify .gitattributes was auto-generated + gitAttributesPath := filepath.Join(expectedDir, ".gitattributes") + if _, err := os.Stat(gitAttributesPath); os.IsNotExist(err) { + t.Errorf(".gitattributes file was not created: %s", gitAttributesPath) + } + + // Verify .gitattributes content + content, err := os.ReadFile(gitAttributesPath) + if err != nil { + t.Fatalf("Failed to read .gitattributes: %v", err) + } + contentStr := string(content) + if !strings.Contains(contentStr, "linguist-generated=true") { + t.Error(".gitattributes missing 'linguist-generated=true'") + } + if !strings.Contains(contentStr, "merge=ours") { + t.Error(".gitattributes missing 'merge=ours'") + } +} + +func TestImportCacheMissingFile(t *testing.T) { + // Create temp directory for testing + tempDir, err := os.MkdirTemp("", "import-cache-missing-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + cache := NewImportCache(tempDir) + + // Add entry to cache + testContent := []byte("test") + cachedPath, err := cache.Set("owner", "repo", "test.md", "sha1", testContent) + if err != nil { + t.Fatalf("Failed to set cache entry: %v", err) + } + + // Delete the cached file + if err := os.Remove(cachedPath); err != nil { + t.Fatalf("Failed to remove cached file: %v", err) + } + + // Try to get the entry - should return not found since file is missing + _, found := cache.Get("owner", "repo", "test.md", "sha1") + if found { + t.Error("Expected cache miss for deleted file, but got hit") + } +} + 
+func TestImportCacheEmptyCache(t *testing.T) { + // Create temp directory for testing + tempDir, err := os.MkdirTemp("", "import-cache-empty-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + cache := NewImportCache(tempDir) + + // Try to get from empty cache - should return not found + _, found := cache.Get("owner", "repo", "test.md", "nonexistent-sha") + if found { + t.Error("Expected cache miss for empty cache, but got hit") + } +} diff --git a/pkg/workflow/compiler.go b/pkg/workflow/compiler.go index d205413d9d6..9fd2e9acac8 100644 --- a/pkg/workflow/compiler.go +++ b/pkg/workflow/compiler.go @@ -49,21 +49,22 @@ type FileTracker interface { type Compiler struct { verbose bool engineOverride string - customOutput string // If set, output will be written to this path instead of default location - version string // Version of the extension - skipValidation bool // If true, skip schema validation - noEmit bool // If true, validate without generating lock files - strictMode bool // If true, enforce strict validation requirements - trialMode bool // If true, suppress safe outputs for trial mode execution - trialLogicalRepoSlug string // If set in trial mode, the logical repository to checkout - refreshStopTime bool // If true, regenerate stop-after times instead of preserving existing ones - jobManager *JobManager // Manages jobs and dependencies - engineRegistry *EngineRegistry // Registry of available agentic engines - fileTracker FileTracker // Optional file tracker for tracking created files - warningCount int // Number of warnings encountered during compilation - stepOrderTracker *StepOrderTracker // Tracks step ordering for validation - actionCache *ActionCache // Shared cache for action pin resolutions across all workflows - actionResolver *ActionResolver // Shared resolver for action pins across all workflows + customOutput string // If set, output will be written to this path instead of default 
location + version string // Version of the extension + skipValidation bool // If true, skip schema validation + noEmit bool // If true, validate without generating lock files + strictMode bool // If true, enforce strict validation requirements + trialMode bool // If true, suppress safe outputs for trial mode execution + trialLogicalRepoSlug string // If set in trial mode, the logical repository to checkout + refreshStopTime bool // If true, regenerate stop-after times instead of preserving existing ones + jobManager *JobManager // Manages jobs and dependencies + engineRegistry *EngineRegistry // Registry of available agentic engines + fileTracker FileTracker // Optional file tracker for tracking created files + warningCount int // Number of warnings encountered during compilation + stepOrderTracker *StepOrderTracker // Tracks step ordering for validation + actionCache *ActionCache // Shared cache for action pin resolutions across all workflows + actionResolver *ActionResolver // Shared resolver for action pins across all workflows + importCache *parser.ImportCache // Shared cache for imported workflow files } // NewCompiler creates a new workflow compiler with optional configuration @@ -148,6 +149,21 @@ func (c *Compiler) getSharedActionResolver() (*ActionCache, *ActionResolver) { return c.actionCache, c.actionResolver } +// getSharedImportCache returns the shared import cache, initializing it on first use +// This ensures all workflows compiled by this compiler instance share the same import cache +func (c *Compiler) getSharedImportCache() *parser.ImportCache { + if c.importCache == nil { + // Initialize cache on first use + cwd, err := os.Getwd() + if err != nil { + cwd = "." + } + c.importCache = parser.NewImportCache(cwd) + log.Print("Initialized shared import cache for compiler") + } + return c.importCache +} + // GetSharedActionCache returns the shared action cache used by this compiler instance. 
// The cache is lazily initialized on first access and shared across all workflows. // This allows action SHA validation and other operations to reuse cached resolutions. @@ -690,7 +706,8 @@ func (c *Compiler) ParseWorkflowFile(markdownPath string) (*WorkflowData, error) } // Process imports from frontmatter first (before @include directives) - importsResult, err := parser.ProcessImportsFromFrontmatterWithManifest(result.Frontmatter, markdownDir) + importCache := c.getSharedImportCache() + importsResult, err := parser.ProcessImportsFromFrontmatterWithManifest(result.Frontmatter, markdownDir, importCache) if err != nil { return nil, fmt.Errorf("failed to process imports from frontmatter: %w", err) }