diff --git a/go.mod b/go.mod index 7ebfcf3..cffe51e 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,11 @@ module github.com/codeGROOVE-dev/prcost go 1.25.3 require ( + github.com/codeGROOVE-dev/ds9 v0.0.0-20251028153329-0fbc86f835ed github.com/codeGROOVE-dev/gsm v0.0.0-20251019065141-833fe2363d22 - github.com/codeGROOVE-dev/prx v0.0.0-20251027012315-7b273aabfc7d - github.com/codeGROOVE-dev/turnclient v0.0.0-20251022064427-5a712e1e10e6 + github.com/codeGROOVE-dev/prx v0.0.0-20251027204543-4e6165f046e5 + github.com/codeGROOVE-dev/turnclient v0.0.0-20251028130307-1f85c9aa43c4 golang.org/x/time v0.14.0 ) -require github.com/codeGROOVE-dev/retry v1.2.0 // indirect +require github.com/codeGROOVE-dev/retry v1.3.0 // indirect diff --git a/go.sum b/go.sum index d47fc0e..a3f8548 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,18 @@ +github.com/codeGROOVE-dev/ds9 v0.0.0-20251028153329-0fbc86f835ed h1:PNvhCROSwAybWrJwYkTBUAaydQKJ8dWD8ml7gde5nkA= +github.com/codeGROOVE-dev/ds9 v0.0.0-20251028153329-0fbc86f835ed/go.mod h1:/hZt40fp5FfuzVwiw9fgoOMBAa260rKPiWNJt8RU10Y= github.com/codeGROOVE-dev/gsm v0.0.0-20251019065141-833fe2363d22 h1:gtN3rOc6YspO646BkcOxBhPjEqKUz+jl175jIqglfDg= github.com/codeGROOVE-dev/gsm v0.0.0-20251019065141-833fe2363d22/go.mod h1:KV+w19ubP32PxZPE1hOtlCpTaNpF0Bpb32w5djO8UTg= github.com/codeGROOVE-dev/prx v0.0.0-20251027012315-7b273aabfc7d h1:kUaCKFRxWFrWEyl4fVHi+eY/D5tKhBU29a8YbQyihEk= github.com/codeGROOVE-dev/prx v0.0.0-20251027012315-7b273aabfc7d/go.mod h1:7qLbi18baOyS8yO/6/64SBIqtyzSzLFdsDST15NPH3w= +github.com/codeGROOVE-dev/prx v0.0.0-20251027204543-4e6165f046e5 h1:tjxTLJ5NXx1xhReL4M+J4LTl/JGNSZjPrznAoci06OA= +github.com/codeGROOVE-dev/prx v0.0.0-20251027204543-4e6165f046e5/go.mod h1:FEy3gz9IYDXWnKWkoDSL+pWu6rujxbBSrF4w5A8QSK0= github.com/codeGROOVE-dev/retry v1.2.0 h1:xYpYPX2PQZmdHwuiQAGGzsBm392xIMl4nfMEFApQnu8= github.com/codeGROOVE-dev/retry v1.2.0/go.mod h1:8OgefgV1XP7lzX2PdKlCXILsYKuz6b4ZpHa/20iLi8E= +github.com/codeGROOVE-dev/retry v1.3.0 h1:/+ipAWRJLL6y1R1vprYo0FSjSBvH6fE5j9LKXjpD54g= +github.com/codeGROOVE-dev/retry v1.3.0/go.mod h1:8OgefgV1XP7lzX2PdKlCXILsYKuz6b4ZpHa/20iLi8E= github.com/codeGROOVE-dev/turnclient v0.0.0-20251022064427-5a712e1e10e6 h1:7FCmaftkl362oTZHVJyUg+xhxqfQFx+JisBf7RgklL8= github.com/codeGROOVE-dev/turnclient v0.0.0-20251022064427-5a712e1e10e6/go.mod h1:fYwtN9Ql6lY8t2WvCfENx+mP5FUwjlqwXCLx9CVLY20= +github.com/codeGROOVE-dev/turnclient v0.0.0-20251028130307-1f85c9aa43c4 h1:si9tMEo5SXpDuDXGkJ1zNnnpP8TbmakrkNujAbpKlqA= +github.com/codeGROOVE-dev/turnclient v0.0.0-20251028130307-1f85c9aa43c4/go.mod h1:bFWMd0JeaJY0kSIO5AcRQdJLXF3Fo3eKclE49vmIZes= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= diff --git a/internal/server/server.go b/internal/server/server.go index 3884349..5866a44 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -20,6 +20,7 @@ import ( "sync" "time" + "github.com/codeGROOVE-dev/ds9" "github.com/codeGROOVE-dev/gsm" "github.com/codeGROOVE-dev/prcost/pkg/cost" "github.com/codeGROOVE-dev/prcost/pkg/github" @@ -54,12 +55,29 @@ var tokenPattern = regexp.MustCompile( //go:embed static/* var staticFS embed.FS -// cacheEntry holds cached data. +// cacheEntry holds cached data for in-memory cache. // No TTL needed - Cloud Run kills processes frequently, providing natural cache invalidation. type cacheEntry struct { data any } +// prDataCacheEntity represents a cached PR data entry in DataStore with TTL. 
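+// Entities carry an explicit ExpiresAt rather than relying on a Datastore TTL
+// policy; expiry is enforced when the entry is read back (see cachedPRData).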
+type prDataCacheEntity struct { + Data string `datastore:"data,noindex"` // JSON-encoded cost.PRData + CachedAt time.Time `datastore:"cached_at"` // When this was cached + ExpiresAt time.Time `datastore:"expires_at"` // When this expires (1 hour from CachedAt) + URL string `datastore:"url"` // PR URL for debugging +} + +// prQueryCacheEntity represents a cached PR query result in DataStore with TTL. +type prQueryCacheEntity struct { + Data string `datastore:"data,noindex"` // JSON-encoded []github.PRSummary + CachedAt time.Time `datastore:"cached_at"` // When this was cached + ExpiresAt time.Time `datastore:"expires_at"` // When this expires (varies by type) + QueryType string `datastore:"query_type"` // "repo" or "org" + QueryKey string `datastore:"query_key"` // Full query key for debugging +} + // Server handles HTTP requests for the PR Cost API. // //nolint:govet // fieldalignment: struct field ordering optimized for readability over memory @@ -87,6 +105,8 @@ type Server struct { prDataCache map[string]*cacheEntry prQueryCacheMu sync.RWMutex prDataCacheMu sync.RWMutex + // DataStore client for persistent caching (nil if not enabled). + dsClient *ds9.Client } // CalculateRequest represents a request to calculate PR costs. @@ -202,6 +222,21 @@ func New() *Server { // - Cloud Run instances are ephemeral and restart frequently anyway // If needed in the future, implement LRU eviction with size limits instead of time-based clearing + // Initialize DataStore client if DATASTORE_DB is set (persistent caching across restarts). + if dbID := os.Getenv("DATASTORE_DB"); dbID != "" { + dsClient, err := ds9.NewClientWithDatabase(ctx, "", dbID) + if err != nil { + logger.WarnContext(ctx, "Failed to initialize DataStore client - persistent caching disabled", + "database_id", dbID, "error", err) + } else { + server.dsClient = dsClient + logger.InfoContext(ctx, "DataStore persistent caching enabled", + "database_id", dbID) + } + } else { + logger.InfoContext(ctx, "DataStore persistent caching disabled (DATASTORE_DB not set)") + } + return server } @@ -307,52 +342,201 @@ func (s *Server) limiter(ctx context.Context, ip string) *rate.Limiter { return limiter } -// cachedPRQuery retrieves cached PR query results. -func (s *Server) cachedPRQuery(key string) ([]github.PRSummary, bool) { +// cachedPRQuery retrieves cached PR query results from memory first, then DataStore as fallback. +func (s *Server) cachedPRQuery(ctx context.Context, key string) ([]github.PRSummary, bool) { + // Check in-memory cache first (fast path). s.prQueryCacheMu.RLock() - defer s.prQueryCacheMu.RUnlock() - entry, exists := s.prQueryCache[key] - if !exists { + s.prQueryCacheMu.RUnlock() + + if exists { + prs, ok := entry.data.([]github.PRSummary) + if ok { + s.logger.DebugContext(ctx, "PR query cache hit (memory)", "key", key) + return prs, true + } + } + + // Memory miss - try DataStore if available. + if s.dsClient == nil { + return nil, false + } + + dsKey := ds9.NameKey("PRQueryCache", key, nil) + var entity prQueryCacheEntity + err := s.dsClient.Get(ctx, dsKey, &entity) + if err != nil { + if !errors.Is(err, ds9.ErrNoSuchEntity) { + s.logger.WarnContext(ctx, "DataStore cache read failed", "key", key, "error", err) + } + return nil, false + } + + // Check if expired (TTL varies by query type). 
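+	// Expired entries are not deleted here; they read as a miss and are
+	// overwritten the next time cachePRQuery stores this key.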
+ if time.Now().After(entity.ExpiresAt) { + s.logger.DebugContext(ctx, "DataStore cache entry expired", "key", key, "expires_at", entity.ExpiresAt) return nil, false } - prs, ok := entry.data.([]github.PRSummary) - return prs, ok + // Deserialize the cached data. + var prs []github.PRSummary + if err := json.Unmarshal([]byte(entity.Data), &prs); err != nil { + s.logger.WarnContext(ctx, "Failed to deserialize cached PR query", "key", key, "error", err) + return nil, false + } + + s.logger.InfoContext(ctx, "PR query cache hit (DataStore)", + "key", key, "query_type", entity.QueryType, "cached_at", entity.CachedAt, "pr_count", len(prs)) + + // Populate in-memory cache for faster subsequent access. + s.prQueryCacheMu.Lock() + s.prQueryCache[key] = &cacheEntry{data: prs} + s.prQueryCacheMu.Unlock() + + return prs, true } -// cachePRQuery stores PR query results in cache. -func (s *Server) cachePRQuery(key string, prs []github.PRSummary) { +// cachePRQuery stores PR query results in both memory and DataStore caches. +func (s *Server) cachePRQuery(ctx context.Context, key string, prs []github.PRSummary) { + // Write to in-memory cache first (fast path). s.prQueryCacheMu.Lock() - defer s.prQueryCacheMu.Unlock() + s.prQueryCache[key] = &cacheEntry{data: prs} + s.prQueryCacheMu.Unlock() + + // Write to DataStore if available (persistent cache). + if s.dsClient == nil { + return + } + + // Serialize the PR query results. + dataJSON, err := json.Marshal(prs) + if err != nil { + s.logger.WarnContext(ctx, "Failed to serialize PR query for DataStore", "key", key, "error", err) + return + } + + // Determine query type and TTL from key format. + var queryType string + var ttl time.Duration + switch { + case strings.HasPrefix(key, "repo:"): + queryType = "repo" + ttl = 72 * time.Hour // 72 hours for repo queries + case strings.HasPrefix(key, "org:"): + queryType = "org" + ttl = 72 * time.Hour // 72 hours for org queries + default: + s.logger.WarnContext(ctx, "Unknown query type for key, using default TTL", "key", key) + queryType = "unknown" + ttl = 72 * time.Hour // Default to 72 hours + } + + now := time.Now() + entity := prQueryCacheEntity{ + Data: string(dataJSON), + CachedAt: now, + ExpiresAt: now.Add(ttl), + QueryType: queryType, + QueryKey: key, + } - s.prQueryCache[key] = &cacheEntry{ - data: prs, + dsKey := ds9.NameKey("PRQueryCache", key, nil) + if _, err := s.dsClient.Put(ctx, dsKey, &entity); err != nil { + s.logger.WarnContext(ctx, "Failed to write PR query to DataStore", "key", key, "error", err) + return } + + s.logger.DebugContext(ctx, "PR query cached to DataStore", + "key", key, "query_type", queryType, "ttl", ttl, "expires_at", entity.ExpiresAt, "pr_count", len(prs)) } -// cachedPRData retrieves cached PR data. -func (s *Server) cachedPRData(key string) (cost.PRData, bool) { +// cachedPRData retrieves cached PR data from memory first, then DataStore as fallback. +func (s *Server) cachedPRData(ctx context.Context, key string) (cost.PRData, bool) { + // Check in-memory cache first (fast path). s.prDataCacheMu.RLock() - defer s.prDataCacheMu.RUnlock() - entry, exists := s.prDataCache[key] - if !exists { + s.prDataCacheMu.RUnlock() + + if exists { + prData, ok := entry.data.(cost.PRData) + if ok { + s.logger.DebugContext(ctx, "PR data cache hit (memory)", "key", key) + return prData, true + } + } + + // Memory miss - try DataStore if available. 
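+	// dsClient is non-nil only when DATASTORE_DB was set at startup (see New()).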
+ if s.dsClient == nil { + return cost.PRData{}, false + } + + dsKey := ds9.NameKey("PRDataCache", key, nil) + var entity prDataCacheEntity + err := s.dsClient.Get(ctx, dsKey, &entity) + if err != nil { + if !errors.Is(err, ds9.ErrNoSuchEntity) { + s.logger.WarnContext(ctx, "DataStore cache read failed", "key", key, "error", err) + } + return cost.PRData{}, false + } + + // Check if expired (1 hour TTL for PRs). + if time.Now().After(entity.ExpiresAt) { + s.logger.DebugContext(ctx, "DataStore cache entry expired", "key", key, "expires_at", entity.ExpiresAt) + return cost.PRData{}, false + } + + // Deserialize the cached data. + var prData cost.PRData + if err := json.Unmarshal([]byte(entity.Data), &prData); err != nil { + s.logger.WarnContext(ctx, "Failed to deserialize cached PR data", "key", key, "error", err) return cost.PRData{}, false } - prData, ok := entry.data.(cost.PRData) - return prData, ok + s.logger.InfoContext(ctx, "PR data cache hit (DataStore)", "key", key, "cached_at", entity.CachedAt) + + // Populate in-memory cache for faster subsequent access. + s.prDataCacheMu.Lock() + s.prDataCache[key] = &cacheEntry{data: prData} + s.prDataCacheMu.Unlock() + + return prData, true } -// cachePRData stores PR data in cache. -func (s *Server) cachePRData(key string, prData cost.PRData) { +// cachePRData stores PR data in both memory and DataStore caches. +func (s *Server) cachePRData(ctx context.Context, key string, prData cost.PRData) { + // Write to in-memory cache first (fast path). s.prDataCacheMu.Lock() - defer s.prDataCacheMu.Unlock() + s.prDataCache[key] = &cacheEntry{data: prData} + s.prDataCacheMu.Unlock() - s.prDataCache[key] = &cacheEntry{ - data: prData, + // Write to DataStore if available (persistent cache). + if s.dsClient == nil { + return } + + // Serialize the PR data. + dataJSON, err := json.Marshal(prData) + if err != nil { + s.logger.WarnContext(ctx, "Failed to serialize PR data for DataStore", "key", key, "error", err) + return + } + + now := time.Now() + entity := prDataCacheEntity{ + Data: string(dataJSON), + CachedAt: now, + ExpiresAt: now.Add(1 * time.Hour), // 1 hour TTL for PRs + URL: key, + } + + dsKey := ds9.NameKey("PRDataCache", key, nil) + if _, err := s.dsClient.Put(ctx, dsKey, &entity); err != nil { + s.logger.WarnContext(ctx, "Failed to write PR data to DataStore", "key", key, "error", err) + return + } + + s.logger.DebugContext(ctx, "PR data cached to DataStore", "key", key, "expires_at", entity.ExpiresAt) } // SetTokenValidation configures GitHub token validation. @@ -735,7 +919,7 @@ func (s *Server) processRequest(ctx context.Context, req *CalculateRequest, toke // Try cache first cacheKey := fmt.Sprintf("pr:%s", req.URL) - prData, cached := s.cachedPRData(cacheKey) + prData, cached := s.cachedPRData(ctx, cacheKey) if cached { s.logger.InfoContext(ctx, "[processRequest] Using cached PR data", "url", req.URL) } else { @@ -761,7 +945,7 @@ func (s *Server) processRequest(ctx context.Context, req *CalculateRequest, toke } // Cache PR data - s.cachePRData(cacheKey, prData) + s.cachePRData(ctx, cacheKey, prData) } // Calculate costs. 
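Editor's note: cachedPRQuery and cachedPRData above share one read-through shape (memory first, then Datastore, then a memory backfill on a Datastore hit). The sketch below condenses that pattern for reference. It is not part of the change: the generic helper name and signature are invented, and only the ds9 calls already shown in these hunks (NameKey, Get) are assumed to exist.

// readThrough is an illustrative condensation of the cache read paths above:
// memory first, then Datastore, with the TTL enforced at read time and a
// memory backfill on a Datastore hit. Error logging is elided; any Get error
// (including ds9.ErrNoSuchEntity) is treated as a miss.
func readThrough[T any](ctx context.Context, s *Server, kind, key string,
	mem map[string]*cacheEntry, mu *sync.RWMutex) (T, bool) {
	var zero T
	mu.RLock()
	entry, ok := mem[key]
	mu.RUnlock()
	if ok {
		if v, ok := entry.data.(T); ok {
			return v, true // memory hit
		}
	}
	if s.dsClient == nil {
		return zero, false // persistent cache disabled
	}
	var ent prDataCacheEntity // the real code uses one entity type per kind
	if err := s.dsClient.Get(ctx, ds9.NameKey(kind, key, nil), &ent); err != nil {
		return zero, false
	}
	if time.Now().After(ent.ExpiresAt) {
		return zero, false // expired: TTL is enforced on read
	}
	var v T
	if err := json.Unmarshal([]byte(ent.Data), &v); err != nil {
		return zero, false
	}
	mu.Lock()
	mem[key] = &cacheEntry{data: v} // backfill memory for the next lookup
	mu.Unlock()
	return v, true
}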
@@ -1242,7 +1426,7 @@ func (s *Server) processRepoSample(ctx context.Context, req *RepoSampleRequest, // Try cache first cacheKey := fmt.Sprintf("repo:%s/%s:days=%d", req.Owner, req.Repo, req.Days) - prs, cached := s.cachedPRQuery(cacheKey) + prs, cached := s.cachedPRQuery(ctx, cacheKey) if cached { s.logger.InfoContext(ctx, "Using cached PR query results", "owner", req.Owner, "repo", req.Repo, "total_prs", len(prs)) @@ -1258,7 +1442,7 @@ func (s *Server) processRepoSample(ctx context.Context, req *RepoSampleRequest, "owner", req.Owner, "repo", req.Repo, "total_prs", len(prs)) // Cache query results - s.cachePRQuery(cacheKey, prs) + s.cachePRQuery(ctx, cacheKey, prs) } if len(prs) == 0 { @@ -1283,7 +1467,7 @@ func (s *Server) processRepoSample(ctx context.Context, req *RepoSampleRequest, // Try cache first prCacheKey := fmt.Sprintf("pr:%s", prURL) - prData, prCached := s.cachedPRData(prCacheKey) + prData, prCached := s.cachedPRData(ctx, prCacheKey) if !prCached { var err error // Use configured data source with updatedAt for effective caching @@ -1298,7 +1482,7 @@ func (s *Server) processRepoSample(ctx context.Context, req *RepoSampleRequest, } // Cache PR data - s.cachePRData(prCacheKey, prData) + s.cachePRData(ctx, prCacheKey, prData) } breakdown := cost.Calculate(prData, cfg) @@ -1343,7 +1527,7 @@ func (s *Server) processOrgSample(ctx context.Context, req *OrgSampleRequest, to // Try cache first cacheKey := fmt.Sprintf("org:%s:days=%d", req.Org, req.Days) - prs, cached := s.cachedPRQuery(cacheKey) + prs, cached := s.cachedPRQuery(ctx, cacheKey) if cached { s.logger.InfoContext(ctx, "Using cached PR query results", "org", req.Org, "total_prs", len(prs)) @@ -1358,7 +1542,7 @@ func (s *Server) processOrgSample(ctx context.Context, req *OrgSampleRequest, to s.logger.InfoContext(ctx, "Fetched PRs from organization", "org", req.Org, "total_prs", len(prs)) // Cache query results - s.cachePRQuery(cacheKey, prs) + s.cachePRQuery(ctx, cacheKey, prs) } if len(prs) == 0 { @@ -1383,7 +1567,7 @@ func (s *Server) processOrgSample(ctx context.Context, req *OrgSampleRequest, to // Try cache first prCacheKey := fmt.Sprintf("pr:%s", prURL) - prData, prCached := s.cachedPRData(prCacheKey) + prData, prCached := s.cachedPRData(ctx, prCacheKey) if !prCached { var err error // Use configured data source with updatedAt for effective caching @@ -1398,7 +1582,7 @@ func (s *Server) processOrgSample(ctx context.Context, req *OrgSampleRequest, to } // Cache PR data - s.cachePRData(prCacheKey, prData) + s.cachePRData(ctx, prCacheKey, prData) } breakdown := cost.Calculate(prData, cfg) @@ -1715,7 +1899,7 @@ func (s *Server) processRepoSampleWithProgress(ctx context.Context, req *RepoSam // Try cache first cacheKey := fmt.Sprintf("repo:%s/%s:days=%d", req.Owner, req.Repo, req.Days) - prs, cached := s.cachedPRQuery(cacheKey) + prs, cached := s.cachedPRQuery(ctx, cacheKey) if !cached { // Send progress update before GraphQL query logSSEError(ctx, s.logger, sendSSE(writer, ProgressUpdate{ @@ -1759,7 +1943,7 @@ func (s *Server) processRepoSampleWithProgress(ctx context.Context, req *RepoSam } // Cache query results - s.cachePRQuery(cacheKey, prs) + s.cachePRQuery(ctx, cacheKey, prs) } if len(prs) == 0 { @@ -1851,7 +2035,7 @@ func (s *Server) processOrgSampleWithProgress(ctx context.Context, req *OrgSampl // Try cache first cacheKey := fmt.Sprintf("org:%s:days=%d", req.Org, req.Days) - prs, cached := s.cachedPRQuery(cacheKey) + prs, cached := s.cachedPRQuery(ctx, cacheKey) if !cached { // Send progress update before GraphQL 
query logSSEError(ctx, s.logger, sendSSE(writer, ProgressUpdate{ @@ -1893,7 +2077,7 @@ func (s *Server) processOrgSampleWithProgress(ctx context.Context, req *OrgSampl } // Cache query results - s.cachePRQuery(cacheKey, prs) + s.cachePRQuery(ctx, cacheKey, prs) } if len(prs) == 0 { @@ -2013,7 +2197,7 @@ func (s *Server) processPRsInParallel(workCtx, reqCtx context.Context, samples [ // Try cache first prCacheKey := fmt.Sprintf("pr:%s", prURL) - prData, prCached := s.cachedPRData(prCacheKey) + prData, prCached := s.cachedPRData(workCtx, prCacheKey) if !prCached { var err error // Use work context for actual API calls (not tied to client connection) @@ -2039,7 +2223,7 @@ func (s *Server) processPRsInParallel(workCtx, reqCtx context.Context, samples [ } // Cache the PR data - s.cachePRData(prCacheKey, prData) + s.cachePRData(workCtx, prCacheKey, prData) } // Send "processing" update using request context for SSE diff --git a/internal/server/server_test.go b/internal/server/server_test.go index d97c30f..e2ec71c 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -2,6 +2,7 @@ package server import ( "bytes" + "context" "encoding/json" "errors" "net/http" @@ -11,6 +12,7 @@ import ( "time" "github.com/codeGROOVE-dev/prcost/pkg/cost" + "github.com/codeGROOVE-dev/prcost/pkg/github" ) func TestNew(t *testing.T) { @@ -586,3 +588,618 @@ func TestConfigMerging(t *testing.T) { t.Errorf("Config.EventDuration = %v, want 15m", parsedReq.Config.EventDuration) } } + +// Test cache functions +func TestCachePRDataMemory(t *testing.T) { + s := New() + ctx := testContext() + + prData := cost.PRData{ + LinesAdded: 100, + LinesDeleted: 50, + Author: "testuser", + CreatedAt: time.Now(), + } + + key := "pr:https://github.com/owner/repo/pull/123" + + // Initially should not be cached + _, cached := s.cachedPRData(ctx, key) + if cached { + t.Error("PR data should not be cached initially") + } + + // Cache the data + s.cachePRData(ctx, key, prData) + + // Should now be cached + cachedData, cached := s.cachedPRData(ctx, key) + if !cached { + t.Error("PR data should be cached after caching") + } + + if cachedData.LinesAdded != prData.LinesAdded { + t.Errorf("Cached LinesAdded = %d, want %d", cachedData.LinesAdded, prData.LinesAdded) + } + if cachedData.Author != prData.Author { + t.Errorf("Cached Author = %s, want %s", cachedData.Author, prData.Author) + } +} + +func TestCachePRQueryMemory(t *testing.T) { + s := New() + ctx := testContext() + + prs := []github.PRSummary{ + {Number: 123, Owner: "owner", Repo: "repo", Author: "testuser", UpdatedAt: time.Now()}, + {Number: 456, Owner: "owner", Repo: "repo", Author: "testuser2", UpdatedAt: time.Now()}, + } + + key := "repo:owner/repo:days=30" + + // Initially should not be cached + _, cached := s.cachedPRQuery(ctx, key) + if cached { + t.Error("PR query should not be cached initially") + } + + // Cache the query results + s.cachePRQuery(ctx, key, prs) + + // Should now be cached + cachedPRs, cached := s.cachedPRQuery(ctx, key) + if !cached { + t.Error("PR query should be cached after caching") + } + + if len(cachedPRs) != len(prs) { + t.Errorf("Cached PR count = %d, want %d", len(cachedPRs), len(prs)) + } + if cachedPRs[0].Number != prs[0].Number { + t.Errorf("Cached PR number = %d, want %d", cachedPRs[0].Number, prs[0].Number) + } +} + +func TestCacheKeyPrefixes(t *testing.T) { + s := New() + ctx := testContext() + + // Test different key prefixes + tests := []struct { + name string + key string + }{ + {"PR key", 
"pr:https://github.com/owner/repo/pull/123"}, + {"Repo key", "repo:owner/repo:days=30"}, + {"Org key", "org:myorg:days=90"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prs := []github.PRSummary{{Number: 1}} + s.cachePRQuery(ctx, tt.key, prs) + + cached, ok := s.cachedPRQuery(ctx, tt.key) + if !ok { + t.Errorf("Key %s should be cached", tt.key) + } + if len(cached) != 1 { + t.Errorf("Expected 1 PR, got %d", len(cached)) + } + }) + } +} + +func TestHandleCalculateInvalidJSON(t *testing.T) { + s := New() + + req := httptest.NewRequest(http.MethodPost, "/v1/calculate", strings.NewReader("{invalid json")) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer ghp_test") + + w := httptest.NewRecorder() + s.handleCalculate(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("handleCalculate() with invalid JSON status = %d, want %d", w.Code, http.StatusBadRequest) + } +} + +func TestHandleCalculateMissingURL(t *testing.T) { + s := New() + + reqBody := CalculateRequest{} // No URL + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest(http.MethodPost, "/v1/calculate", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer ghp_test") + + w := httptest.NewRecorder() + s.handleCalculate(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("handleCalculate() with missing URL status = %d, want %d", w.Code, http.StatusBadRequest) + } +} + +func TestHandleRepoSampleInvalidJSON(t *testing.T) { + s := New() + + req := httptest.NewRequest(http.MethodPost, "/v1/repo-sample", strings.NewReader("{invalid json")) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer ghp_test") + + w := httptest.NewRecorder() + s.handleRepoSample(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("handleRepoSample() with invalid JSON status = %d, want %d", w.Code, http.StatusBadRequest) + } +} + +func TestHandleRepoSampleMissingFields(t *testing.T) { + s := New() + + tests := []struct { + name string + body RepoSampleRequest + }{ + { + name: "missing owner", + body: RepoSampleRequest{Repo: "repo", Days: 30}, + }, + { + name: "missing repo", + body: RepoSampleRequest{Owner: "owner", Days: 30}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body, _ := json.Marshal(tt.body) + req := httptest.NewRequest(http.MethodPost, "/v1/repo-sample", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer ghp_test") + + w := httptest.NewRecorder() + s.handleRepoSample(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("handleRepoSample() %s status = %d, want %d", tt.name, w.Code, http.StatusBadRequest) + } + }) + } +} + +func TestHandleOrgSampleMissingOrg(t *testing.T) { + s := New() + + reqBody := OrgSampleRequest{Days: 30} // Missing Org + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest(http.MethodPost, "/v1/org-sample", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer ghp_test") + + w := httptest.NewRecorder() + s.handleOrgSample(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("handleOrgSample() with missing org status = %d, want %d", w.Code, http.StatusBadRequest) + } +} + +func TestHandleRepoSampleStreamHeaders(t *testing.T) { + s := New() + + reqBody := RepoSampleRequest{ + Owner: "testowner", + Repo: "testrepo", + Days: 30, + } + 
body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest(http.MethodPost, "/v1/repo-sample-stream", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer ghp_test") + + w := httptest.NewRecorder() + // Note: This will fail with no token error or GitHub API error, but we're testing headers + s.handleRepoSampleStream(w, req) + + // Check SSE headers were set + contentType := w.Header().Get("Content-Type") + if contentType != "text/event-stream" { + t.Errorf("Content-Type = %s, want text/event-stream", contentType) + } + + cacheControl := w.Header().Get("Cache-Control") + if cacheControl != "no-cache" { + t.Errorf("Cache-Control = %s, want no-cache", cacheControl) + } + + connection := w.Header().Get("Connection") + if connection != "keep-alive" { + t.Errorf("Connection = %s, want keep-alive", connection) + } +} + +func TestHandleOrgSampleStreamHeaders(t *testing.T) { + s := New() + + reqBody := OrgSampleRequest{ + Org: "testorg", + Days: 30, + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest(http.MethodPost, "/v1/org-sample-stream", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer ghp_test") + + w := httptest.NewRecorder() + s.handleOrgSampleStream(w, req) + + // Check SSE headers were set + contentType := w.Header().Get("Content-Type") + if contentType != "text/event-stream" { + t.Errorf("Content-Type = %s, want text/event-stream", contentType) + } +} + +func TestMergeConfig(t *testing.T) { + s := New() + + baseConfig := cost.Config{ + AnnualSalary: 250000, + } + customConfig := &cost.Config{ + AnnualSalary: 300000, + } + + merged := s.mergeConfig(baseConfig, customConfig) + + if merged.AnnualSalary != 300000 { + t.Errorf("mergeConfig() AnnualSalary = %f, want 300000", merged.AnnualSalary) + } +} + +func TestHandleNotFound(t *testing.T) { + s := New() + + req := httptest.NewRequest(http.MethodGet, "/invalid/path", http.NoBody) + w := httptest.NewRecorder() + + s.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("Invalid path status = %d, want %d", w.Code, http.StatusNotFound) + } +} + +func TestHandleMethodNotAllowed(t *testing.T) { + s := New() + + // PATCH is not allowed on /v1/calculate + req := httptest.NewRequest(http.MethodPatch, "/v1/calculate", http.NoBody) + w := httptest.NewRecorder() + + s.ServeHTTP(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("Wrong method status = %d, want %d", w.Code, http.StatusMethodNotAllowed) + } +} + +func TestSetTokenValidationErrors(t *testing.T) { + s := New() + + // Test with invalid app ID (empty) + err := s.SetTokenValidation("", "nonexistent.pem") + if err == nil { + t.Error("SetTokenValidation() with empty app ID should return error") + } + + // Test with nonexistent key file + err = s.SetTokenValidation("12345", "/nonexistent/path/key.pem") + if err == nil { + t.Error("SetTokenValidation() with nonexistent key file should return error") + } +} + +func TestSetDataSource(t *testing.T) { + s := New() + + tests := []struct { + name string + source string + wantSource string + }{ + {"prx source", "prx", "prx"}, + {"turnserver source", "turnserver", "turnserver"}, + {"invalid source falls back to prx", "custom", "prx"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s.SetDataSource(tt.source) + if s.dataSource != tt.wantSource { + t.Errorf("SetDataSource(%s) = %s, want %s", tt.source, s.dataSource, tt.wantSource) + } + }) + } +} + 
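Editor's note: the TTL check added in server.go is a plain time comparison on the entity, so it can be unit-tested without a Datastore client. A sketch, not part of the change (the test name is invented; prQueryCacheEntity is the struct added above):

func TestEntityExpirySketch(t *testing.T) {
	now := time.Now()
	fresh := prQueryCacheEntity{CachedAt: now, ExpiresAt: now.Add(72 * time.Hour)}
	stale := prQueryCacheEntity{CachedAt: now.Add(-80 * time.Hour), ExpiresAt: now.Add(-8 * time.Hour)}

	// cachedPRQuery treats an entry as a miss once time.Now().After(ExpiresAt).
	if now.After(fresh.ExpiresAt) {
		t.Error("fresh entry should not read as expired")
	}
	if !now.After(stale.ExpiresAt) {
		t.Error("stale entry should read as expired")
	}
}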
+func TestLimiterConcurrency(t *testing.T) { + s := New() + s.SetRateLimit(10, 10) + ctx := testContext() + + // Test that same IP gets same limiter (concurrency safe) + limiter1 := s.limiter(ctx, "192.168.1.1") + limiter2 := s.limiter(ctx, "192.168.1.1") + + if limiter1 != limiter2 { + t.Error("Same IP should return same limiter instance") + } + + // Test that different IPs get different limiters + limiter3 := s.limiter(ctx, "192.168.1.2") + if limiter1 == limiter3 { + t.Error("Different IPs should return different limiters") + } +} + +func TestSanitizeErrorWithMultipleTokens(t *testing.T) { + input := errors.New("error with Bearer ghp_token1 and token ghp_token2") + result := sanitizeError(input) + + if strings.Contains(result, "ghp_") { + t.Errorf("sanitizeError() still contains token: %s", result) + } + if !strings.Contains(result, "[REDACTED_TOKEN]") { + t.Error("sanitizeError() should contain redaction marker") + } +} + +func TestAllowAllCorsFlag(t *testing.T) { + s := New() + s.SetCORSConfig("", true) // Allow all + + // Verify the allowAllCors flag is set + if !s.allowAllCors { + t.Error("SetCORSConfig with allowAll=true should set allowAllCors flag") + } + + // When allowAll is false, flag should be false + s.SetCORSConfig("https://example.com", false) + if s.allowAllCors { + t.Error("SetCORSConfig with allowAll=false should clear allowAllCors flag") + } +} + +func TestIsOriginAllowedEdgeCases(t *testing.T) { + s := New() + s.SetCORSConfig("https://example.com,https://*.test.com", false) + + tests := []struct { + name string + origin string + want bool + }{ + {"empty origin", "", false}, + {"case sensitive exact match", "https://Example.com", false}, + // Note: The wildcard matcher appears to match the base domain too + {"wildcard matches base domain", "https://test.com", true}, + {"wildcard matches subdomain", "https://sub.test.com", true}, + {"wildcard ignores path", "https://sub.test.com/path", true}, // Path is stripped before matching + {"unmatched domain", "https://other.com", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := s.isOriginAllowed(tt.origin) + if got != tt.want { + t.Errorf("isOriginAllowed(%q) = %v, want %v", tt.origin, got, tt.want) + } + }) + } +} + +func TestRateLimiterBehavior(t *testing.T) { + s := New() + s.SetRateLimit(1, 2) // 1 per second, burst of 2 + ctx := testContext() + + limiter := s.limiter(ctx, "192.168.1.100") + + // First two requests should be allowed (burst) + if !limiter.Allow() { + t.Error("First request should be allowed (within burst)") + } + if !limiter.Allow() { + t.Error("Second request should be allowed (within burst)") + } + + // Third request should be rate limited + if limiter.Allow() { + t.Error("Third request should be rate limited (burst exhausted)") + } +} + +func TestValidateGitHubPRURLEdgeCases(t *testing.T) { + s := New() + + tests := []struct { + name string + url string + wantErr bool + }{ + {"PR number zero", "https://github.com/owner/repo/pull/0", false}, + {"Large PR number", "https://github.com/owner/repo/pull/999999", false}, + {"Dashes in owner", "https://github.com/owner-name/repo/pull/123", false}, + {"Dashes in repo", "https://github.com/owner/repo-name/pull/123", false}, + {"Underscores rejected", "https://github.com/owner_name/repo_name/pull/123", true}, + {"Numbers in names", "https://github.com/owner123/repo456/pull/123", false}, + {"Dots in repo", "https://github.com/owner/repo.name/pull/123", false}, + {"Single char owner", "https://github.com/a/repo/pull/123", 
false}, + {"Single char repo", "https://github.com/owner/r/pull/123", false}, + {"Non-numeric PR number", "https://github.com/owner/repo/pull/abc", true}, + {"Negative PR number", "https://github.com/owner/repo/pull/-1", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := s.validateGitHubPRURL(tt.url) + if (err != nil) != tt.wantErr { + t.Errorf("validateGitHubPRURL(%q) error = %v, wantErr %v", tt.url, err, tt.wantErr) + } + }) + } +} + +func TestParseRequestEdgeCases(t *testing.T) { + s := New() + + tests := []struct { + name string + contentType string + body string + wantErr bool + }{ + { + name: "empty body", + contentType: "application/json", + body: "", + wantErr: true, + }, + { + name: "whitespace only", + contentType: "application/json", + body: " ", + wantErr: true, + }, + { + name: "null json", + contentType: "application/json", + body: "null", + wantErr: true, + }, + { + name: "array instead of object", + contentType: "application/json", + body: "[]", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/calculate", strings.NewReader(tt.body)) + if tt.contentType != "" { + req.Header.Set("Content-Type", tt.contentType) + } + + _, err := s.parseRequest(req.Context(), req) + if (err != nil) != tt.wantErr { + t.Errorf("parseRequest() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestCacheConcurrency(t *testing.T) { + s := New() + ctx := testContext() + + prData := cost.PRData{ + LinesAdded: 100, + Author: "testuser", + } + + key := "pr:https://github.com/owner/repo/pull/123" + + // Test concurrent writes + done := make(chan bool) + for i := 0; i < 10; i++ { + go func() { + s.cachePRData(ctx, key, prData) + done <- true + }() + } + + // Wait for all writes + for i := 0; i < 10; i++ { + <-done + } + + // Test concurrent reads + for i := 0; i < 10; i++ { + go func() { + _, _ = s.cachedPRData(ctx, key) + done <- true + }() + } + + // Wait for all reads + for i := 0; i < 10; i++ { + <-done + } + + // Verify data is still correct + cached, ok := s.cachedPRData(ctx, key) + if !ok { + t.Error("Data should still be cached after concurrent access") + } + if cached.LinesAdded != 100 { + t.Errorf("Cached data corrupted: LinesAdded = %d, want 100", cached.LinesAdded) + } +} + +func TestExtractTokenVariations(t *testing.T) { + s := New() + + tests := []struct { + name string + authHeader string + wantToken string + description string + }{ + { + name: "Bearer with single space", + authHeader: "Bearer ghp_token123", + wantToken: "ghp_token123", + description: "Standard Bearer format", + }, + { + name: "token prefix", + authHeader: "token ghp_token123", + wantToken: "ghp_token123", + description: "Lowercase token prefix", + }, + { + name: "plain token", + authHeader: "ghp_token123", + wantToken: "ghp_token123", + description: "No prefix", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/calculate", http.NoBody) + req.Header.Set("Authorization", tt.authHeader) + + got := s.extractToken(req) + if got != tt.wantToken { + t.Errorf("extractToken() = %q, want %q (%s)", got, tt.wantToken, tt.description) + } + }) + } +} + +// Helper function to create a test context +func testContext() context.Context { + return context.Background() +} diff --git a/pkg/cost/cost_test.go b/pkg/cost/cost_test.go index 01d3e99..6a27a62 100644 --- a/pkg/cost/cost_test.go +++ b/pkg/cost/cost_test.go @@ 
-1,7 +1,10 @@ package cost import ( + "context" "encoding/json" + "errors" + "log/slog" "os" "strings" "testing" @@ -497,10 +500,10 @@ func TestCalculateWithRealPR13(t *testing.T) { // 90 days absolute cap = 2160 hours // Delivery: 2160 * 0.15 = 324 hours - expectedDeliveryHours := 90.0 * 24.0 * 0.15 // 324 hours + expectedDeliveryHours := 90.0 * 24.0 * 0.20 // 432 hours (20% default delay factor) if breakdown.DelayCostDetail.DeliveryDelayHours != expectedDeliveryHours { - t.Errorf("Expected %.0f delivery delay hours (15%% of 90 day cap), got %.2f", + t.Errorf("Expected %.0f delivery delay hours (20%% of 90 day cap), got %.2f", expectedDeliveryHours, breakdown.DelayCostDetail.DeliveryDelayHours) } @@ -520,7 +523,7 @@ func TestCalculateWithRealPR13(t *testing.T) { t.Logf("PR 13 breakdown (6 year old PR):") t.Logf(" 638 LOC added") t.Logf(" Author cost: $%.2f", breakdown.Author.TotalCost) - t.Logf(" Delivery Delay (15%%): $%.2f (%.0f hrs, capped at 90 days)", + t.Logf(" Delivery Delay (20%%): $%.2f (%.0f hrs, capped at 90 days)", breakdown.DelayCostDetail.DeliveryDelayCost, breakdown.DelayCostDetail.DeliveryDelayHours) t.Logf(" Code Churn: $%.2f (%.1f%% rework, capped at 90 days drift)", breakdown.DelayCostDetail.CodeChurnCost, breakdown.DelayCostDetail.ReworkPercentage) @@ -555,11 +558,11 @@ func TestCalculateLongPRCapped(t *testing.T) { // Last event was 120 days ago, so we only count 14 days after it // Capped hours: 14 days = 336 hours - // Delivery delay: 336 * 0.15 = 50.4 hours - expectedDeliveryHours := 14.0 * 24.0 * 0.15 + // Delivery delay: 336 * 0.20 = 67.2 hours (20% default delay factor) + expectedDeliveryHours := 14.0 * 24.0 * 0.20 if breakdown.DelayCostDetail.DeliveryDelayHours != expectedDeliveryHours { - t.Errorf("Expected %.1f delivery delay hours (15%% of 14 days), got %.2f", + t.Errorf("Expected %.1f delivery delay hours (20%% of 14 days), got %.2f", expectedDeliveryHours, breakdown.DelayCostDetail.DeliveryDelayHours) } } @@ -696,3 +699,918 @@ func TestCalculateFastTurnaroundNoDelay(t *testing.T) { }) } } + +// Mock PRFetcher for testing AnalyzePRs +type mockPRFetcher struct { + data map[string]PRData + failURLs map[string]error + callCount int + maxCalls int // Fail after this many calls (0 = no limit) + fetchDelay time.Duration +} + +func (m *mockPRFetcher) FetchPRData(ctx context.Context, prURL string, updatedAt time.Time) (PRData, error) { + m.callCount++ + + // Fail after max calls if set + if m.maxCalls > 0 && m.callCount > m.maxCalls { + return PRData{}, errors.New("max calls exceeded") + } + + // Check for context cancellation + if ctx.Err() != nil { + return PRData{}, ctx.Err() + } + + // Simulate fetch delay + if m.fetchDelay > 0 { + time.Sleep(m.fetchDelay) + } + + // Check for specific URL failure + if m.failURLs != nil { + if err, ok := m.failURLs[prURL]; ok { + return PRData{}, err + } + } + + // Return mock data + if m.data != nil { + if data, ok := m.data[prURL]; ok { + return data, nil + } + } + + // Default success case + now := time.Now() + return PRData{ + LinesAdded: 100, + Author: "test-author", + Events: []ParticipantEvent{ + {Timestamp: now, Actor: "test-author", Kind: "commit"}, + }, + CreatedAt: now.Add(-1 * time.Hour), + ClosedAt: now, + }, nil +} + +func TestAnalyzePRsNoSamples(t *testing.T) { + ctx := context.Background() + fetcher := &mockPRFetcher{} + + req := &AnalysisRequest{ + Samples: []PRSummaryInfo{}, + Fetcher: fetcher, + Config: DefaultConfig(), + } + + result, err := AnalyzePRs(ctx, req) + + if err == nil { + t.Error("Expected error 
when no samples provided") + } + + if result != nil { + t.Error("Expected nil result when no samples provided") + } + + if err.Error() != "no samples provided" { + t.Errorf("Expected 'no samples provided' error, got: %v", err) + } +} + +func TestAnalyzePRsNoFetcher(t *testing.T) { + ctx := context.Background() + + req := &AnalysisRequest{ + Samples: []PRSummaryInfo{ + {Owner: "owner", Repo: "repo", Number: 1, UpdatedAt: time.Now()}, + }, + Fetcher: nil, + Config: DefaultConfig(), + } + + result, err := AnalyzePRs(ctx, req) + + if err == nil { + t.Error("Expected error when fetcher is nil") + } + + if result != nil { + t.Error("Expected nil result when fetcher is nil") + } + + if err.Error() != "fetcher is required" { + t.Errorf("Expected 'fetcher is required' error, got: %v", err) + } +} + +func TestAnalyzePRsSequentialSuccess(t *testing.T) { + ctx := context.Background() + now := time.Now() + + fetcher := &mockPRFetcher{ + data: map[string]PRData{ + "https://github.com/owner/repo/pull/1": { + LinesAdded: 50, + Author: "author1", + Events: []ParticipantEvent{ + {Timestamp: now, Actor: "author1", Kind: "commit"}, + }, + CreatedAt: now.Add(-2 * time.Hour), + ClosedAt: now, + }, + "https://github.com/owner/repo/pull/2": { + LinesAdded: 100, + Author: "author2", + Events: []ParticipantEvent{ + {Timestamp: now, Actor: "author2", Kind: "commit"}, + }, + CreatedAt: now.Add(-3 * time.Hour), + ClosedAt: now, + }, + }, + } + + req := &AnalysisRequest{ + Samples: []PRSummaryInfo{ + {Owner: "owner", Repo: "repo", Number: 1, UpdatedAt: now}, + {Owner: "owner", Repo: "repo", Number: 2, UpdatedAt: now}, + }, + Fetcher: fetcher, + Config: DefaultConfig(), + Concurrency: 0, // Sequential + } + + result, err := AnalyzePRs(ctx, req) + + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if result == nil { + t.Fatal("Expected non-nil result") + } + + if len(result.Breakdowns) != 2 { + t.Errorf("Expected 2 breakdowns, got %d", len(result.Breakdowns)) + } + + if result.Skipped != 0 { + t.Errorf("Expected 0 skipped, got %d", result.Skipped) + } + + if fetcher.callCount != 2 { + t.Errorf("Expected 2 fetcher calls, got %d", fetcher.callCount) + } +} + +func TestAnalyzePRsSequentialPartialFailure(t *testing.T) { + ctx := context.Background() + now := time.Now() + + fetcher := &mockPRFetcher{ + data: map[string]PRData{ + "https://github.com/owner/repo/pull/1": { + LinesAdded: 50, + Author: "author1", + Events: []ParticipantEvent{ + {Timestamp: now, Actor: "author1", Kind: "commit"}, + }, + CreatedAt: now.Add(-2 * time.Hour), + ClosedAt: now, + }, + }, + failURLs: map[string]error{ + "https://github.com/owner/repo/pull/2": errors.New("fetch failed"), + }, + } + + req := &AnalysisRequest{ + Samples: []PRSummaryInfo{ + {Owner: "owner", Repo: "repo", Number: 1, UpdatedAt: now}, + {Owner: "owner", Repo: "repo", Number: 2, UpdatedAt: now}, + }, + Fetcher: fetcher, + Config: DefaultConfig(), + Concurrency: 1, + } + + result, err := AnalyzePRs(ctx, req) + + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if result == nil { + t.Fatal("Expected non-nil result") + } + + if len(result.Breakdowns) != 1 { + t.Errorf("Expected 1 breakdown, got %d", len(result.Breakdowns)) + } + + if result.Skipped != 1 { + t.Errorf("Expected 1 skipped, got %d", result.Skipped) + } +} + +func TestAnalyzePRsSequentialAllFail(t *testing.T) { + ctx := context.Background() + now := time.Now() + + fetcher := &mockPRFetcher{ + failURLs: map[string]error{ + "https://github.com/owner/repo/pull/1": errors.New("fetch 
failed"), + "https://github.com/owner/repo/pull/2": errors.New("fetch failed"), + }, + } + + req := &AnalysisRequest{ + Samples: []PRSummaryInfo{ + {Owner: "owner", Repo: "repo", Number: 1, UpdatedAt: now}, + {Owner: "owner", Repo: "repo", Number: 2, UpdatedAt: now}, + }, + Fetcher: fetcher, + Config: DefaultConfig(), + Concurrency: 1, + } + + result, err := AnalyzePRs(ctx, req) + + if err == nil { + t.Error("Expected error when all fetches fail") + } + + if result != nil { + t.Error("Expected nil result when all fetches fail") + } + + expectedErrMsg := "no samples could be processed successfully (2 skipped)" + if err.Error() != expectedErrMsg { + t.Errorf("Expected error message '%s', got: %v", expectedErrMsg, err) + } +} + +func TestAnalyzePRsParallelSuccess(t *testing.T) { + ctx := context.Background() + now := time.Now() + + fetcher := &mockPRFetcher{ + data: map[string]PRData{ + "https://github.com/owner/repo/pull/1": { + LinesAdded: 50, + Author: "author1", + Events: []ParticipantEvent{ + {Timestamp: now, Actor: "author1", Kind: "commit"}, + }, + CreatedAt: now.Add(-2 * time.Hour), + ClosedAt: now, + }, + "https://github.com/owner/repo/pull/2": { + LinesAdded: 100, + Author: "author2", + Events: []ParticipantEvent{ + {Timestamp: now, Actor: "author2", Kind: "commit"}, + }, + CreatedAt: now.Add(-3 * time.Hour), + ClosedAt: now, + }, + "https://github.com/owner/repo/pull/3": { + LinesAdded: 75, + Author: "author3", + Events: []ParticipantEvent{ + {Timestamp: now, Actor: "author3", Kind: "commit"}, + }, + CreatedAt: now.Add(-4 * time.Hour), + ClosedAt: now, + }, + }, + fetchDelay: 10 * time.Millisecond, // Small delay to test concurrency + } + + req := &AnalysisRequest{ + Samples: []PRSummaryInfo{ + {Owner: "owner", Repo: "repo", Number: 1, UpdatedAt: now}, + {Owner: "owner", Repo: "repo", Number: 2, UpdatedAt: now}, + {Owner: "owner", Repo: "repo", Number: 3, UpdatedAt: now}, + }, + Fetcher: fetcher, + Config: DefaultConfig(), + Concurrency: 2, // Parallel with 2 workers + } + + result, err := AnalyzePRs(ctx, req) + + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if result == nil { + t.Fatal("Expected non-nil result") + } + + if len(result.Breakdowns) != 3 { + t.Errorf("Expected 3 breakdowns, got %d", len(result.Breakdowns)) + } + + if result.Skipped != 0 { + t.Errorf("Expected 0 skipped, got %d", result.Skipped) + } + + if fetcher.callCount != 3 { + t.Errorf("Expected 3 fetcher calls, got %d", fetcher.callCount) + } +} + +func TestAnalyzePRsParallelPartialFailure(t *testing.T) { + ctx := context.Background() + now := time.Now() + + fetcher := &mockPRFetcher{ + data: map[string]PRData{ + "https://github.com/owner/repo/pull/1": { + LinesAdded: 50, + Author: "author1", + Events: []ParticipantEvent{ + {Timestamp: now, Actor: "author1", Kind: "commit"}, + }, + CreatedAt: now.Add(-2 * time.Hour), + ClosedAt: now, + }, + }, + failURLs: map[string]error{ + "https://github.com/owner/repo/pull/2": errors.New("fetch failed"), + }, + } + + req := &AnalysisRequest{ + Samples: []PRSummaryInfo{ + {Owner: "owner", Repo: "repo", Number: 1, UpdatedAt: now}, + {Owner: "owner", Repo: "repo", Number: 2, UpdatedAt: now}, + }, + Fetcher: fetcher, + Config: DefaultConfig(), + Concurrency: 2, + } + + result, err := AnalyzePRs(ctx, req) + + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if result == nil { + t.Fatal("Expected non-nil result") + } + + if len(result.Breakdowns) != 1 { + t.Errorf("Expected 1 breakdown, got %d", len(result.Breakdowns)) + } + + if 
result.Skipped != 1 { + t.Errorf("Expected 1 skipped, got %d", result.Skipped) + } +} + +func TestAnalyzePRsParallelAllFail(t *testing.T) { + ctx := context.Background() + now := time.Now() + + fetcher := &mockPRFetcher{ + failURLs: map[string]error{ + "https://github.com/owner/repo/pull/1": errors.New("fetch failed"), + "https://github.com/owner/repo/pull/2": errors.New("fetch failed"), + }, + } + + req := &AnalysisRequest{ + Samples: []PRSummaryInfo{ + {Owner: "owner", Repo: "repo", Number: 1, UpdatedAt: now}, + {Owner: "owner", Repo: "repo", Number: 2, UpdatedAt: now}, + }, + Fetcher: fetcher, + Config: DefaultConfig(), + Concurrency: 2, + } + + result, err := AnalyzePRs(ctx, req) + + if err == nil { + t.Error("Expected error when all fetches fail") + } + + if result != nil { + t.Error("Expected nil result when all fetches fail") + } + + expectedErrMsg := "no samples could be processed successfully (2 skipped)" + if err.Error() != expectedErrMsg { + t.Errorf("Expected error message '%s', got: %v", expectedErrMsg, err) + } +} + +func TestAnalyzePRsWithLogger(t *testing.T) { + ctx := context.Background() + now := time.Now() + + // Create a logger that writes to a buffer + var logBuf strings.Builder + logger := slog.New(slog.NewTextHandler(&logBuf, &slog.HandlerOptions{ + Level: slog.LevelInfo, + })) + + fetcher := &mockPRFetcher{ + data: map[string]PRData{ + "https://github.com/owner/repo/pull/1": { + LinesAdded: 50, + Author: "author1", + Events: []ParticipantEvent{ + {Timestamp: now, Actor: "author1", Kind: "commit"}, + }, + CreatedAt: now.Add(-2 * time.Hour), + ClosedAt: now, + }, + }, + failURLs: map[string]error{ + "https://github.com/owner/repo/pull/2": errors.New("fetch failed"), + }, + } + + req := &AnalysisRequest{ + Samples: []PRSummaryInfo{ + {Owner: "owner", Repo: "repo", Number: 1, UpdatedAt: now}, + {Owner: "owner", Repo: "repo", Number: 2, UpdatedAt: now}, + }, + Fetcher: fetcher, + Logger: logger, + Config: DefaultConfig(), + Concurrency: 1, + } + + result, err := AnalyzePRs(ctx, req) + + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if result == nil { + t.Fatal("Expected non-nil result") + } + + logOutput := logBuf.String() + + // Check that processing logs are present + if !strings.Contains(logOutput, "Processing sample PR") { + t.Error("Expected 'Processing sample PR' in log output") + } + + // Check that skip warning is present + if !strings.Contains(logOutput, "Failed to fetch PR data, skipping") { + t.Error("Expected 'Failed to fetch PR data, skipping' in log output") + } + + // Check that progress is logged + if !strings.Contains(logOutput, "1/2") { + t.Error("Expected '1/2' progress in log output") + } + + if !strings.Contains(logOutput, "2/2") { + t.Error("Expected '2/2' progress in log output") + } +} + +func TestAnalyzePRsConcurrencyDefault(t *testing.T) { + ctx := context.Background() + now := time.Now() + + fetcher := &mockPRFetcher{} + + req := &AnalysisRequest{ + Samples: []PRSummaryInfo{ + {Owner: "owner", Repo: "repo", Number: 1, UpdatedAt: now}, + }, + Fetcher: fetcher, + Config: DefaultConfig(), + Concurrency: 0, // Should default to sequential (1) + } + + result, err := AnalyzePRs(ctx, req) + + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if result == nil { + t.Fatal("Expected non-nil result") + } + + if len(result.Breakdowns) != 1 { + t.Errorf("Expected 1 breakdown, got %d", len(result.Breakdowns)) + } +} + +func TestAnalyzePRsContextCancellation(t *testing.T) { + ctx, cancel := 
context.WithCancel(context.Background()) + now := time.Now() + + fetcher := &mockPRFetcher{ + fetchDelay: 100 * time.Millisecond, // Delay to allow cancellation + } + + req := &AnalysisRequest{ + Samples: []PRSummaryInfo{ + {Owner: "owner", Repo: "repo", Number: 1, UpdatedAt: now}, + {Owner: "owner", Repo: "repo", Number: 2, UpdatedAt: now}, + }, + Fetcher: fetcher, + Config: DefaultConfig(), + Concurrency: 1, + } + + // Cancel context after a short delay + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + result, err := AnalyzePRs(ctx, req) + + // Should either fail completely or have some skipped + if err == nil && result != nil && result.Skipped == 0 { + // This is acceptable if cancellation happened after all fetches + return + } + + // If we got here, either err or skipped should be non-zero + if err == nil && (result == nil || result.Skipped == 0) { + t.Error("Expected context cancellation to affect results") + } +} + +func TestExtrapolateFromSamplesEmpty(t *testing.T) { + cfg := DefaultConfig() + result := ExtrapolateFromSamples([]Breakdown{}, 100, 10, 5, 30, cfg) + + if result.TotalPRs != 100 { + t.Errorf("Expected TotalPRs=100, got %d", result.TotalPRs) + } + + if result.SampledPRs != 0 { + t.Errorf("Expected SampledPRs=0, got %d", result.SampledPRs) + } + + if result.SuccessfulSamples != 0 { + t.Errorf("Expected SuccessfulSamples=0, got %d", result.SuccessfulSamples) + } + + if result.TotalCost != 0 { + t.Errorf("Expected TotalCost=0, got $%.2f", result.TotalCost) + } +} + +func TestExtrapolateFromSamplesSingle(t *testing.T) { + now := time.Now() + cfg := DefaultConfig() + + // Create a single breakdown + breakdown := Calculate(PRData{ + LinesAdded: 100, + Author: "test-author", + Events: []ParticipantEvent{ + {Timestamp: now, Actor: "test-author", Kind: "commit"}, + {Timestamp: now.Add(10 * time.Minute), Actor: "reviewer", Kind: "review"}, + }, + CreatedAt: now.Add(-24 * time.Hour), + ClosedAt: now, + }, cfg) + + // Extrapolate from 1 sample to 10 total PRs + result := ExtrapolateFromSamples([]Breakdown{breakdown}, 10, 2, 0, 7, cfg) + + if result.TotalPRs != 10 { + t.Errorf("Expected TotalPRs=10, got %d", result.TotalPRs) + } + + if result.SampledPRs != 1 { + t.Errorf("Expected SampledPRs=1, got %d", result.SampledPRs) + } + + if result.SuccessfulSamples != 1 { + t.Errorf("Expected SuccessfulSamples=1, got %d", result.SuccessfulSamples) + } + + // Total cost should be roughly 10x the single breakdown cost + expectedTotalCost := breakdown.TotalCost * 10 + if result.TotalCost < expectedTotalCost*0.9 || result.TotalCost > expectedTotalCost*1.1 { + t.Errorf("Expected TotalCost≈$%.2f (10x single), got $%.2f", expectedTotalCost, result.TotalCost) + } + + // Check that author cost is extrapolated + if result.AuthorTotalCost <= 0 { + t.Error("Expected positive author total cost") + } + + // Check that participant cost is extrapolated + if result.ParticipantTotalCost <= 0 { + t.Error("Expected positive participant total cost") + } + + // Check unique authors count + if result.UniqueAuthors != 1 { + t.Errorf("Expected 1 unique author, got %d", result.UniqueAuthors) + } +} + +func TestExtrapolateFromSamplesMultiple(t *testing.T) { + now := time.Now() + cfg := DefaultConfig() + + // Create multiple breakdowns with different characteristics + breakdowns := []Breakdown{ + Calculate(PRData{ + LinesAdded: 100, + Author: "author1", + Events: []ParticipantEvent{ + {Timestamp: now, Actor: "author1", Kind: "commit"}, + }, + CreatedAt: now.Add(-2 * time.Hour), + ClosedAt: now, + 
}, cfg),
+		Calculate(PRData{
+			LinesAdded: 200,
+			Author:     "author2",
+			Events: []ParticipantEvent{
+				{Timestamp: now, Actor: "author2", Kind: "commit"},
+				{Timestamp: now.Add(10 * time.Minute), Actor: "reviewer", Kind: "review"},
+			},
+			CreatedAt: now.Add(-48 * time.Hour),
+			ClosedAt:  now,
+		}, cfg),
+	}
+
+	// Extrapolate from 2 samples to 20 total PRs over 14 days
+	result := ExtrapolateFromSamples(breakdowns, 20, 5, 3, 14, cfg)
+
+	if result.TotalPRs != 20 {
+		t.Errorf("Expected TotalPRs=20, got %d", result.TotalPRs)
+	}
+
+	if result.SampledPRs != 2 {
+		t.Errorf("Expected SampledPRs=2, got %d", result.SampledPRs)
+	}
+
+	if result.SuccessfulSamples != 2 {
+		t.Errorf("Expected SuccessfulSamples=2, got %d", result.SuccessfulSamples)
+	}
+
+	// Check unique authors (should be 2)
+	if result.UniqueAuthors != 2 {
+		t.Errorf("Expected 2 unique authors, got %d", result.UniqueAuthors)
+	}
+
+	// Total cost should be roughly 20x the average breakdown cost (2 samples scaled to 20 PRs)
+	avgCost := (breakdowns[0].TotalCost + breakdowns[1].TotalCost) / 2
+	expectedTotalCost := avgCost * 20
+	if result.TotalCost < expectedTotalCost*0.9 || result.TotalCost > expectedTotalCost*1.1 {
+		t.Errorf("Expected TotalCost≈$%.2f, got $%.2f", expectedTotalCost, result.TotalCost)
+	}
+
+	// Check waste per week calculations (should be > 0 for 14 day period)
+	if result.WasteHoursPerWeek <= 0 {
+		t.Error("Expected positive waste hours per week")
+	}
+
+	if result.WasteCostPerWeek <= 0 {
+		t.Error("Expected positive waste cost per week")
+	}
+
+	// Check average PR duration is calculated
+	if result.AvgPRDurationHours <= 0 {
+		t.Error("Expected positive average PR duration")
+	}
+}
+
+func TestExtrapolateFromSamplesBotVsHuman(t *testing.T) {
+	cfg := DefaultConfig()
+
+	// Create breakdowns with both human and bot PRs
+	breakdowns := []Breakdown{
+		// Human PR
+		{
+			PRAuthor:   "human-author",
+			AuthorBot:  false,
+			PRDuration: 24.0,
+			Author: AuthorCostDetail{
+				NewLines:      100,
+				ModifiedLines: 150,
+			},
+			TotalCost: 1000,
+		},
+		// Bot PR
+		{
+			PRAuthor:   "dependabot[bot]",
+			AuthorBot:  true,
+			PRDuration: 2.0,
+			Author: AuthorCostDetail{
+				NewLines:      50,
+				ModifiedLines: 60,
+			},
+			TotalCost: 100,
+		},
+	}
+
+	result := ExtrapolateFromSamples(breakdowns, 10, 5, 0, 7, cfg)
+
+	// Should have both human and bot PR counts
+	if result.HumanPRs <= 0 {
+		t.Error("Expected positive human PR count")
+	}
+
+	if result.BotPRs <= 0 {
+		t.Error("Expected positive bot PR count")
+	}
+
+	// Should have separate duration averages
+	if result.AvgHumanPRDurationHours <= 0 {
+		t.Error("Expected positive average human PR duration")
+	}
+
+	if result.AvgBotPRDurationHours <= 0 {
+		t.Error("Expected positive average bot PR duration")
+	}
+
+	// Bot LOC should be tracked separately
+	if result.BotNewLines <= 0 {
+		t.Error("Expected positive bot new lines")
+	}
+
+	if result.BotModifiedLines <= 0 {
+		t.Error("Expected positive bot modified lines")
+	}
+
+	// Human authors should only count human PRs
+	if result.UniqueAuthors != 1 {
+		t.Errorf("Expected 1 unique human author, got %d", result.UniqueAuthors)
+	}
+}
+
+func TestExtrapolateFromSamplesWasteCalculation(t *testing.T) {
+	now := time.Now()
+	cfg := DefaultConfig()
+
+	// Create a breakdown with significant delay costs
+	breakdown := Calculate(PRData{
+		LinesAdded: 100,
+		Author:     "author1",
+		Events: []ParticipantEvent{
+			{Timestamp: now.Add(-168 * time.Hour), Actor: "author1", Kind: "commit"},
+		},
+		CreatedAt: now.Add(-168 * time.Hour), // 7 days old
+		ClosedAt:  now,
+	}, cfg)
+
+	// Extrapolate over 7 days
+	
result := ExtrapolateFromSamples([]Breakdown{breakdown}, 10, 3, 0, 7, cfg) + + // Waste per week should be calculated + if result.WasteHoursPerWeek <= 0 { + t.Error("Expected positive waste hours per week") + } + + if result.WasteCostPerWeek <= 0 { + t.Error("Expected positive waste cost per week") + } + + // Per-author waste should be calculated + if result.WasteHoursPerAuthorPerWeek <= 0 { + t.Error("Expected positive waste hours per author per week") + } + + if result.WasteCostPerAuthorPerWeek <= 0 { + t.Error("Expected positive waste cost per author per week") + } + + // Waste should be roughly the delay costs + // WastePerWeek = (delay costs) / weeks + expectedWastePerWeek := breakdown.DelayCost * 10 // Extrapolated to 10 PRs, 1 week period + if result.WasteCostPerWeek < expectedWastePerWeek*0.5 || result.WasteCostPerWeek > expectedWastePerWeek*1.5 { + t.Errorf("Expected WasteCostPerWeek≈$%.2f, got $%.2f", expectedWastePerWeek, result.WasteCostPerWeek) + } +} + +func TestExtrapolateFromSamplesR2RSavings(t *testing.T) { + now := time.Now() + cfg := DefaultConfig() + + // Create breakdowns with long PR durations (high waste) + breakdowns := []Breakdown{ + Calculate(PRData{ + LinesAdded: 100, + Author: "author1", + Events: []ParticipantEvent{ + {Timestamp: now.Add(-72 * time.Hour), Actor: "author1", Kind: "commit"}, + }, + CreatedAt: now.Add(-72 * time.Hour), // 3 days old + ClosedAt: now, + }, cfg), + } + + result := ExtrapolateFromSamples(breakdowns, 100, 10, 5, 30, cfg) + + // R2R savings should be calculated + // Savings formula: baseline waste - remodeled waste - subscription cost + // Should be > 0 if current waste is high enough + if result.R2RSavings < 0 { + t.Error("R2R savings should not be negative") + } + + // For a 3-day PR, there should be significant savings + // (R2R targets 40-minute PRs, which would eliminate most delay costs) + if result.R2RSavings == 0 { + t.Error("Expected positive R2R savings for long-duration PRs") + } + + // UniqueNonBotUsers should be tracked + if result.UniqueNonBotUsers <= 0 { + t.Error("Expected positive unique non-bot users count") + } +} + +func TestExtrapolateFromSamplesOpenPRTracking(t *testing.T) { + now := time.Now() + cfg := DefaultConfig() + + breakdown := Calculate(PRData{ + LinesAdded: 50, + Author: "author1", + Events: []ParticipantEvent{ + {Timestamp: now, Actor: "author1", Kind: "commit"}, + }, + CreatedAt: now.Add(-1 * time.Hour), + ClosedAt: now, + }, cfg) + + // Test with actual open PRs + actualOpenPRs := 15 + result := ExtrapolateFromSamples([]Breakdown{breakdown}, 100, 5, actualOpenPRs, 30, cfg) + + // Open PRs should match actual count (not extrapolated) + if result.OpenPRs != actualOpenPRs { + t.Errorf("Expected OpenPRs=%d (actual), got %d", actualOpenPRs, result.OpenPRs) + } + + // PR tracking cost should be based on actual open PRs + if result.PRTrackingCost <= 0 { + t.Error("Expected positive PR tracking cost with open PRs") + } + + // PR tracking hours should scale with open PRs and user count + if result.PRTrackingHours <= 0 { + t.Error("Expected positive PR tracking hours") + } +} + +func TestExtrapolateFromSamplesParticipants(t *testing.T) { + now := time.Now() + cfg := DefaultConfig() + + // Create breakdown with multiple participants + breakdown := Calculate(PRData{ + LinesAdded: 100, + Author: "author1", + Events: []ParticipantEvent{ + {Timestamp: now, Actor: "author1", Kind: "commit"}, + {Timestamp: now.Add(10 * time.Minute), Actor: "reviewer1", Kind: "review"}, + {Timestamp: now.Add(20 * time.Minute), Actor: 
"reviewer2", Kind: "review"}, + {Timestamp: now.Add(30 * time.Minute), Actor: "commenter1", Kind: "comment"}, + }, + CreatedAt: now.Add(-2 * time.Hour), + ClosedAt: now, + }, cfg) + + result := ExtrapolateFromSamples([]Breakdown{breakdown}, 10, 5, 0, 7, cfg) + + // Participant costs should be extrapolated + if result.ParticipantReviewCost <= 0 { + t.Error("Expected positive participant review cost") + } + + if result.ParticipantTotalCost <= 0 { + t.Error("Expected positive participant total cost") + } + + // Participant metrics should be tracked + if result.ParticipantEvents <= 0 { + t.Error("Expected positive participant events count") + } + + if result.ParticipantSessions <= 0 { + t.Error("Expected positive participant sessions count") + } + + // Unique non-bot users should include both authors and participants + if result.UniqueNonBotUsers < 2 { + t.Errorf("Expected at least 2 unique non-bot users (author + reviewers), got %d", result.UniqueNonBotUsers) + } +}