diff --git a/Makefile b/Makefile index befbc55..8e4de09 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,6 @@ +.PHONY: test +test: + go test -race -cover ./... # BEGIN: lint-install . # http://github.com/codeGROOVE-dev/lint-install diff --git a/cmd/prcost/repository.go b/cmd/prcost/repository.go index 898ae7b..4c21cbf 100644 --- a/cmd/prcost/repository.go +++ b/cmd/prcost/repository.go @@ -208,29 +208,24 @@ func analyzeOrganization(ctx context.Context, org string, sampleSize, days int, return nil } -// Ledger formatting functions - all output must use these for consistency +// Ledger formatting functions - all output must use these for consistency. // formatItemLine formats a cost breakdown line item with 4-space indent. -func formatItemLine(label string, cost float64, timeUnit string, detail string) string { - if cost == 0 { +func formatItemLine(label string, amount float64, timeUnit string, detail string) string { + if amount == 0 { return fmt.Sprintf(" %-30s %15s %-6s %s\n", label, "—", timeUnit, detail) } - return fmt.Sprintf(" %-30s $%14s %-6s %s\n", label, formatWithCommas(cost), timeUnit, detail) + return fmt.Sprintf(" %-30s $%14s %-6s %s\n", label, formatWithCommas(amount), timeUnit, detail) } // formatSubtotalLine formats a subtotal line with 4-space indent. -func formatSubtotalLine(label string, cost float64, timeUnit string, detail string) string { - return fmt.Sprintf(" %-30s $%14s %-6s %s\n", label, formatWithCommas(cost), timeUnit, detail) +func formatSubtotalLine(amount float64, timeUnit string, detail string) string { + return fmt.Sprintf(" %-30s $%14s %-6s %s\n", "Subtotal", formatWithCommas(amount), timeUnit, detail) } // formatSummaryLine formats a summary line (like Preventable Loss Total) with 2-space indent. 
-func formatSummaryLine(label string, cost float64, timeUnit string, detail string) string { - return fmt.Sprintf(" %-30s $%14s %-6s %s\n", label, formatWithCommas(cost), timeUnit, detail) -} - -// formatTotalLine formats a total line with 2-space indent. -func formatTotalLine(label string, cost float64, timeUnit string) string { - return fmt.Sprintf(" %-30s $%14s %-6s\n", label, formatWithCommas(cost), timeUnit) +func formatSummaryLine(label string, amount float64, timeUnit string, detail string) string { + return fmt.Sprintf(" %-30s $%14s %-6s %s\n", label, formatWithCommas(amount), timeUnit, detail) } // formatSectionDivider formats the divider line under subtotals (4-space indent, 32 chars + 14 dashes). @@ -372,7 +367,7 @@ func printExtrapolatedResults(title string, days int, ext *cost.ExtrapolatedBrea } fmt.Print(formatSectionDivider()) pct := (avgAuthorTotalCost / avgTotalCost) * 100 - fmt.Print(formatSubtotalLine("Subtotal", avgAuthorTotalCost, formatTimeUnit(avgAuthorTotalHours), fmt.Sprintf("(%.1f%%)", pct))) + fmt.Print(formatSubtotalLine(avgAuthorTotalCost, formatTimeUnit(avgAuthorTotalHours), fmt.Sprintf("(%.1f%%)", pct))) fmt.Println() // Participants section (if any participants) @@ -393,7 +388,7 @@ func printExtrapolatedResults(title string, days int, ext *cost.ExtrapolatedBrea fmt.Print(formatItemLine("Context Switching", avgParticipantContextCost, formatTimeUnit(avgParticipantContextHours), fmt.Sprintf("(%.1f sessions)", avgParticipantSessions))) fmt.Print(formatSectionDivider()) participantPct := (avgParticipantTotalCost / avgTotalCost) * 100 - fmt.Print(formatSubtotalLine("Subtotal", avgParticipantTotalCost, formatTimeUnit(avgParticipantTotalHours), fmt.Sprintf("(%.1f%%)", participantPct))) + fmt.Print(formatSubtotalLine(avgParticipantTotalCost, formatTimeUnit(avgParticipantTotalHours), fmt.Sprintf("(%.1f%%)", participantPct))) fmt.Println() } @@ -420,7 +415,7 @@ func printExtrapolatedResults(title string, days int, ext *cost.ExtrapolatedBrea 
avgMergeDelayHours := avgDeliveryDelayHours + avgAutomatedUpdatesHours + avgPRTrackingHours fmt.Print(formatSectionDivider()) pct = (avgMergeDelayCost / avgTotalCost) * 100 - fmt.Print(formatSubtotalLine("Subtotal", avgMergeDelayCost, formatTimeUnit(avgMergeDelayHours), fmt.Sprintf("(%.1f%%)", pct))) + fmt.Print(formatSubtotalLine(avgMergeDelayCost, formatTimeUnit(avgMergeDelayHours), fmt.Sprintf("(%.1f%%)", pct))) fmt.Println() // Future Costs section @@ -454,7 +449,7 @@ func printExtrapolatedResults(title string, days int, ext *cost.ExtrapolatedBrea avgFutureHours := avgCodeChurnHours + avgFutureReviewHours + avgFutureMergeHours + avgFutureContextHours fmt.Print(formatSectionDivider()) pct = (avgFutureCost / avgTotalCost) * 100 - fmt.Print(formatSubtotalLine("Subtotal", avgFutureCost, formatTimeUnit(avgFutureHours), fmt.Sprintf("(%.1f%%)", pct))) + fmt.Print(formatSubtotalLine(avgFutureCost, formatTimeUnit(avgFutureHours), fmt.Sprintf("(%.1f%%)", pct))) fmt.Println() } @@ -510,7 +505,7 @@ func printExtrapolatedResults(title string, days int, ext *cost.ExtrapolatedBrea } fmt.Print(formatSectionDivider()) pct = (ext.AuthorTotalCost / ext.TotalCost) * 100 - fmt.Print(formatSubtotalLine("Subtotal", ext.AuthorTotalCost, formatTimeUnit(ext.AuthorTotalHours), fmt.Sprintf("(%.1f%%)", pct))) + fmt.Print(formatSubtotalLine(ext.AuthorTotalCost, formatTimeUnit(ext.AuthorTotalHours), fmt.Sprintf("(%.1f%%)", pct))) fmt.Println() // Participants section (extrapolated, if any participants) @@ -526,7 +521,7 @@ func printExtrapolatedResults(title string, days int, ext *cost.ExtrapolatedBrea fmt.Print(formatItemLine("Context Switching", ext.ParticipantContextCost, formatTimeUnit(ext.ParticipantContextHours), fmt.Sprintf("(%d sessions)", ext.ParticipantSessions))) fmt.Print(formatSectionDivider()) pct = (ext.ParticipantTotalCost / ext.TotalCost) * 100 - fmt.Print(formatSubtotalLine("Subtotal", ext.ParticipantTotalCost, formatTimeUnit(ext.ParticipantTotalHours), 
fmt.Sprintf("(%.1f%%)", pct))) + fmt.Print(formatSubtotalLine(ext.ParticipantTotalCost, formatTimeUnit(ext.ParticipantTotalHours), fmt.Sprintf("(%.1f%%)", pct))) fmt.Println() } @@ -554,7 +549,7 @@ func printExtrapolatedResults(title string, days int, ext *cost.ExtrapolatedBrea extMergeDelayHours := ext.DeliveryDelayHours + ext.CodeChurnHours + ext.AutomatedUpdatesHours + ext.PRTrackingHours fmt.Print(formatSectionDivider()) pct = (extMergeDelayCost / ext.TotalCost) * 100 - fmt.Print(formatSubtotalLine("Subtotal", extMergeDelayCost, formatTimeUnit(extMergeDelayHours), fmt.Sprintf("(%.1f%%)", pct))) + fmt.Print(formatSubtotalLine(extMergeDelayCost, formatTimeUnit(extMergeDelayHours), fmt.Sprintf("(%.1f%%)", pct))) fmt.Println() // Future Costs section (extrapolated) @@ -582,7 +577,7 @@ func printExtrapolatedResults(title string, days int, ext *cost.ExtrapolatedBrea extFutureHours := ext.CodeChurnHours + ext.FutureReviewHours + ext.FutureMergeHours + ext.FutureContextHours fmt.Print(formatSectionDivider()) pct = (extFutureCost / ext.TotalCost) * 100 - fmt.Print(formatSubtotalLine("Subtotal", extFutureCost, formatTimeUnit(extFutureHours), fmt.Sprintf("(%.1f%%)", pct))) + fmt.Print(formatSubtotalLine(extFutureCost, formatTimeUnit(extFutureHours), fmt.Sprintf("(%.1f%%)", pct))) fmt.Println() } diff --git a/internal/server/server.go b/internal/server/server.go index 5866a44..7614745 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -78,6 +78,15 @@ type prQueryCacheEntity struct { QueryKey string `datastore:"query_key"` // Full query key for debugging } +// calcResultCacheEntity represents a cached calculation result in DataStore with TTL. 
+type calcResultCacheEntity struct { + Data string `datastore:"data,noindex"` // JSON-encoded cost.Breakdown + CachedAt time.Time `datastore:"cached_at"` // When this was cached + ExpiresAt time.Time `datastore:"expires_at"` // When this expires + URL string `datastore:"url"` // PR URL for debugging + ConfigKey string `datastore:"config_key"` // Config hash for debugging +} + // Server handles HTTP requests for the PR Cost API. // //nolint:govet // fieldalignment: struct field ordering optimized for readability over memory @@ -101,10 +110,12 @@ type Server struct { validateTokens bool r2rCallout bool // In-memory caching for PR queries and data. - prQueryCache map[string]*cacheEntry - prDataCache map[string]*cacheEntry - prQueryCacheMu sync.RWMutex - prDataCacheMu sync.RWMutex + prQueryCache map[string]*cacheEntry + prDataCache map[string]*cacheEntry + calcResultCache map[string]*cacheEntry + prQueryCacheMu sync.RWMutex + prDataCacheMu sync.RWMutex + calcResultCacheMu sync.RWMutex // DataStore client for persistent caching (nil if not enabled). 
dsClient *ds9.Client } @@ -195,16 +206,17 @@ func New() *Server { logger.InfoContext(ctx, "Server initialized with CSRF protection enabled") server := &Server{ - logger: logger, - serverCommit: "", // Will be set via build flags - dataSource: "turnserver", - httpClient: httpClient, - csrfProtection: csrfProtection, - ipLimiters: make(map[string]*rate.Limiter), - rateLimit: DefaultRateLimit, - rateBurst: DefaultRateBurst, - prQueryCache: make(map[string]*cacheEntry), - prDataCache: make(map[string]*cacheEntry), + logger: logger, + serverCommit: "", // Will be set via build flags + dataSource: "turnserver", + httpClient: httpClient, + csrfProtection: csrfProtection, + ipLimiters: make(map[string]*rate.Limiter), + rateLimit: DefaultRateLimit, + rateBurst: DefaultRateBurst, + prQueryCache: make(map[string]*cacheEntry), + prDataCache: make(map[string]*cacheEntry), + calcResultCache: make(map[string]*cacheEntry), } // Load GitHub token at startup and cache in memory for performance and billing. @@ -421,14 +433,14 @@ func (s *Server) cachePRQuery(ctx context.Context, key string, prs []github.PRSu switch { case strings.HasPrefix(key, "repo:"): queryType = "repo" - ttl = 72 * time.Hour // 72 hours for repo queries + ttl = 60 * time.Hour // 60 hours for repo queries case strings.HasPrefix(key, "org:"): queryType = "org" - ttl = 72 * time.Hour // 72 hours for org queries + ttl = 60 * time.Hour // 60 hours for org queries default: s.logger.WarnContext(ctx, "Unknown query type for key, using default TTL", "key", key) queryType = "unknown" - ttl = 72 * time.Hour // Default to 72 hours + ttl = 60 * time.Hour // Default to 60 hours } now := time.Now() @@ -539,6 +551,110 @@ func (s *Server) cachePRData(ctx context.Context, key string, prData cost.PRData s.logger.DebugContext(ctx, "PR data cached to DataStore", "key", key, "expires_at", entity.ExpiresAt) } +// configHash creates a deterministic hash key for a cost.Config. 
+// Returns a short hash string suitable for use in cache keys. +func configHash(cfg cost.Config) string { + // Create a deterministic string representation of the config + // Fixed-precision formatting keeps float keys deterministic: %.0f for the dollar/minute fields, %.2f for the delay factor + return fmt.Sprintf("s%.0f_e%.0f_ci%.0f_co%.0f_g%.0f_d%.2f", + cfg.AnnualSalary, + cfg.EventDuration.Minutes(), + cfg.ContextSwitchInDuration.Minutes(), + cfg.ContextSwitchOutDuration.Minutes(), + cfg.SessionGapThreshold.Minutes(), + cfg.DeliveryDelayFactor) +} + +// cachedCalcResult retrieves cached calculation result from memory first, then DataStore as fallback. +func (s *Server) cachedCalcResult(ctx context.Context, prURL string, cfg cost.Config) (cost.Breakdown, bool) { + key := fmt.Sprintf("calc:%s:%s", prURL, configHash(cfg)) + + // Check in-memory cache first (fast path). + s.calcResultCacheMu.RLock() + entry, exists := s.calcResultCache[key] + s.calcResultCacheMu.RUnlock() + + if exists { + breakdown, ok := entry.data.(cost.Breakdown) + if ok { + return breakdown, true + } + } + + // Memory miss - try DataStore if available. + if s.dsClient == nil { + return cost.Breakdown{}, false + } + + dsKey := ds9.NameKey("CalcResultCache", key, nil) + var entity calcResultCacheEntity + err := s.dsClient.Get(ctx, dsKey, &entity) + if err != nil { + if !errors.Is(err, ds9.ErrNoSuchEntity) { + s.logger.WarnContext(ctx, "DataStore calc cache read failed", "key", key, "error", err) + } + return cost.Breakdown{}, false + } + + // Check if expired. + if time.Now().After(entity.ExpiresAt) { + return cost.Breakdown{}, false + } + + // Deserialize the cached data. + var breakdown cost.Breakdown + if err := json.Unmarshal([]byte(entity.Data), &breakdown); err != nil { + s.logger.WarnContext(ctx, "Failed to deserialize cached calc result", "key", key, "error", err) + return cost.Breakdown{}, false + } + + // Populate in-memory cache for faster subsequent access.
+ s.calcResultCacheMu.Lock() + s.calcResultCache[key] = &cacheEntry{data: breakdown} + s.calcResultCacheMu.Unlock() + + return breakdown, true +} + +// cacheCalcResult stores calculation result in both memory and DataStore caches. +func (s *Server) cacheCalcResult(ctx context.Context, prURL string, cfg cost.Config, b *cost.Breakdown, ttl time.Duration) { + key := fmt.Sprintf("calc:%s:%s", prURL, configHash(cfg)) + + // Write to in-memory cache first (fast path). + s.calcResultCacheMu.Lock() + s.calcResultCache[key] = &cacheEntry{data: *b} + s.calcResultCacheMu.Unlock() + + // Write to DataStore if available (persistent cache). + if s.dsClient == nil { + return + } + + // Serialize the calculation result. + dataJSON, err := json.Marshal(b) + if err != nil { + s.logger.WarnContext(ctx, "Failed to serialize calc result for DataStore", "key", key, "error", err) + return + } + + now := time.Now() + entity := calcResultCacheEntity{ + Data: string(dataJSON), + CachedAt: now, + ExpiresAt: now.Add(ttl), + URL: prURL, + ConfigKey: configHash(cfg), + } + + dsKey := ds9.NameKey("CalcResultCache", key, nil) + if _, err := s.dsClient.Put(ctx, dsKey, &entity); err != nil { + s.logger.WarnContext(ctx, "Failed to write calc result to DataStore", "key", key, "error", err) + return + } + + s.logger.DebugContext(ctx, "Calc result cached to DataStore", "key", key, "ttl", ttl, "expires_at", entity.ExpiresAt) +} + // SetTokenValidation configures GitHub token validation. 
func (s *Server) SetTokenValidation(appID string, keyFile string) error { keyData, err := os.ReadFile(keyFile) @@ -917,12 +1033,20 @@ func (s *Server) processRequest(ctx context.Context, req *CalculateRequest, toke cfg = s.mergeConfig(cfg, req.Config) } - // Try cache first + // Try calculation result cache first (includes both PR data + calculation) + breakdown, calcCached := s.cachedCalcResult(ctx, req.URL, cfg) + if calcCached { + return &CalculateResponse{ + Breakdown: breakdown, + Timestamp: time.Now(), + Commit: s.serverCommit, + }, nil + } + + // Cache miss - need to fetch PR data and calculate cacheKey := fmt.Sprintf("pr:%s", req.URL) - prData, cached := s.cachedPRData(ctx, cacheKey) - if cached { - s.logger.InfoContext(ctx, "[processRequest] Using cached PR data", "url", req.URL) - } else { + prData, prCached := s.cachedPRData(ctx, cacheKey) + if !prCached { // Fetch PR data using configured data source var err error // For single PR requests, use 1 hour ago as reference time to enable reasonable caching @@ -944,12 +1068,16 @@ func (s *Server) processRequest(ctx context.Context, req *CalculateRequest, toke return nil, fmt.Errorf("failed to fetch PR data: %w", err) } - // Cache PR data + s.logger.InfoContext(ctx, "[processRequest] PR data cache miss - fetched from GitHub", "url", req.URL) + // Cache PR data with 1 hour TTL for direct PR requests s.cachePRData(ctx, cacheKey, prData) } // Calculate costs. - breakdown := cost.Calculate(prData, cfg) + breakdown = cost.Calculate(prData, cfg) + + // Cache the calculation result with 1 hour TTL for direct PR requests + s.cacheCalcResult(ctx, req.URL, cfg, &breakdown, 1*time.Hour) return &CalculateResponse{ Breakdown: breakdown, @@ -1616,6 +1744,9 @@ func (s *Server) processOrgSample(ctx context.Context, req *OrgSampleRequest, to // mergeConfig merges a provided config with defaults. 
func (*Server) mergeConfig(base cost.Config, override *cost.Config) cost.Config { + if override == nil { + return base + } if override.AnnualSalary > 0 { base.AnnualSalary = override.AnnualSalary } @@ -2195,7 +2326,28 @@ func (s *Server) processPRsInParallel(workCtx, reqCtx context.Context, samples [ prURL := fmt.Sprintf("https://github.com/%s/%s/pull/%d", owner, repo, prSummary.Number) - // Try cache first + // Try calculation result cache first (includes both PR data + calculation) + breakdown, calcCached := s.cachedCalcResult(workCtx, prURL, cfg) + if calcCached { + // Already have the full calculation result + mu.Lock() + breakdowns = append(breakdowns, breakdown) + mu.Unlock() + + // Send "complete" update using request context for SSE + sseMu.Lock() + logSSEError(reqCtx, s.logger, sendSSE(writer, ProgressUpdate{ + Type: "complete", + PR: prSummary.Number, + Owner: owner, + Repo: repo, + Progress: progress, + })) + sseMu.Unlock() + return + } + + // Cache miss - need to fetch PR data and calculate prCacheKey := fmt.Sprintf("pr:%s", prURL) prData, prCached := s.cachedPRData(workCtx, prCacheKey) if !prCached { @@ -2222,6 +2374,8 @@ func (s *Server) processPRsInParallel(workCtx, reqCtx context.Context, samples [ return } + s.logger.InfoContext(reqCtx, "PR data cache miss - fetched from GitHub", + "pr_number", prSummary.Number, "owner", owner, "repo", repo) // Cache the PR data s.cachePRData(workCtx, prCacheKey, prData) } @@ -2237,7 +2391,10 @@ func (s *Server) processPRsInParallel(workCtx, reqCtx context.Context, samples [ })) sseMu.Unlock() - breakdown := cost.Calculate(prData, cfg) + breakdown = cost.Calculate(prData, cfg) + + // Cache the calculation result with 1 week TTL for PRs from queries + s.cacheCalcResult(workCtx, prURL, cfg, &breakdown, 7*24*time.Hour) // Add to results mu.Lock() diff --git a/internal/server/server_test.go b/internal/server/server_test.go index e2ec71c..dc47b60 100644 --- a/internal/server/server_test.go +++ 
b/internal/server/server_test.go @@ -5,8 +5,10 @@ import ( "context" "encoding/json" "errors" + "fmt" "net/http" "net/http/httptest" + "net/url" "strings" "testing" "time" @@ -709,7 +711,10 @@ func TestHandleCalculateMissingURL(t *testing.T) { s := New() reqBody := CalculateRequest{} // No URL - body, _ := json.Marshal(reqBody) + body, err := json.Marshal(reqBody) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } req := httptest.NewRequest(http.MethodPost, "/v1/calculate", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") @@ -757,7 +762,10 @@ func TestHandleRepoSampleMissingFields(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - body, _ := json.Marshal(tt.body) + body, err := json.Marshal(tt.body) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } req := httptest.NewRequest(http.MethodPost, "/v1/repo-sample", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer ghp_test") @@ -776,7 +784,10 @@ func TestHandleOrgSampleMissingOrg(t *testing.T) { s := New() reqBody := OrgSampleRequest{Days: 30} // Missing Org - body, _ := json.Marshal(reqBody) + body, err := json.Marshal(reqBody) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } req := httptest.NewRequest(http.MethodPost, "/v1/org-sample", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") @@ -798,7 +809,10 @@ func TestHandleRepoSampleStreamHeaders(t *testing.T) { Repo: "testrepo", Days: 30, } - body, _ := json.Marshal(reqBody) + body, err := json.Marshal(reqBody) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } req := httptest.NewRequest(http.MethodPost, "/v1/repo-sample-stream", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") @@ -832,7 +846,10 @@ func TestHandleOrgSampleStreamHeaders(t *testing.T) { Org: "testorg", Days: 30, } - body, _ := json.Marshal(reqBody) + body, 
err := json.Marshal(reqBody) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } req := httptest.NewRequest(http.MethodPost, "/v1/org-sample-stream", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") @@ -1122,7 +1139,7 @@ func TestCacheConcurrency(t *testing.T) { // Test concurrent writes done := make(chan bool) - for i := 0; i < 10; i++ { + for range 10 { go func() { s.cachePRData(ctx, key, prData) done <- true @@ -1130,12 +1147,12 @@ func TestCacheConcurrency(t *testing.T) { } // Wait for all writes - for i := 0; i < 10; i++ { + for range 10 { <-done } // Test concurrent reads - for i := 0; i < 10; i++ { + for range 10 { go func() { _, _ = s.cachedPRData(ctx, key) done <- true @@ -1143,7 +1160,7 @@ func TestCacheConcurrency(t *testing.T) { } // Wait for all reads - for i := 0; i < 10; i++ { + for range 10 { <-done } @@ -1203,3 +1220,1803 @@ func TestExtractTokenVariations(t *testing.T) { func testContext() context.Context { return context.Background() } + +func TestParseConfigFromQuery(t *testing.T) { + tests := []struct { + name string + queryString string + wantNil bool + wantSalary float64 + wantBenefits float64 + }{ + { + name: "both salary and benefits", + queryString: "salary=300000&benefits=1.5", + wantNil: false, + wantSalary: 300000, + wantBenefits: 1.5, + }, + { + name: "only salary", + queryString: "salary=250000", + wantNil: false, + wantSalary: 250000, + wantBenefits: 0, + }, + { + name: "only benefits", + queryString: "benefits=1.3", + wantNil: false, + wantSalary: 0, + wantBenefits: 1.3, + }, + { + name: "no config params", + queryString: "other=value", + wantNil: true, + }, + { + name: "empty query", + queryString: "", + wantNil: true, + }, + { + name: "invalid salary value", + queryString: "salary=invalid&benefits=1.2", + wantNil: false, + wantSalary: 0, + wantBenefits: 1.2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, 
"/test?"+tt.queryString, http.NoBody) + query := req.URL.Query() + + cfg := parseConfigFromQuery(query) + + if tt.wantNil { + if cfg != nil { + t.Errorf("parseConfigFromQuery() = %v, want nil", cfg) + } + return + } + + if cfg == nil { + t.Fatal("parseConfigFromQuery() = nil, want non-nil") + } + + if cfg.AnnualSalary != tt.wantSalary { + t.Errorf("AnnualSalary = %v, want %v", cfg.AnnualSalary, tt.wantSalary) + } + if cfg.BenefitsMultiplier != tt.wantBenefits { + t.Errorf("BenefitsMultiplier = %v, want %v", cfg.BenefitsMultiplier, tt.wantBenefits) + } + }) + } +} + +func TestSetR2RCallout(t *testing.T) { + s := New() + + // Test enabling + s.SetR2RCallout(true) + if !s.r2rCallout { + t.Error("SetR2RCallout(true) did not enable r2rCallout") + } + + // Test disabling + s.SetR2RCallout(false) + if s.r2rCallout { + t.Error("SetR2RCallout(false) did not disable r2rCallout") + } +} + +func TestShutdown(t *testing.T) { + s := New() + + // Shutdown should not panic + s.Shutdown() +} + +func TestErrorTypes(t *testing.T) { + t.Run("AccessError", func(t *testing.T) { + err := NewAccessError(http.StatusForbidden, "test error") + if err == nil { + t.Fatal("NewAccessError() returned nil") + } + + expectedMsg := "access error (403): test error" + if err.Error() != expectedMsg { + t.Errorf("Error() = %q, want %q", err.Error(), expectedMsg) + } + + if !IsAccessError(err) { + t.Error("IsAccessError() = false, want true") + } + + // Test with non-access error + regularErr := errors.New("regular error") + if IsAccessError(regularErr) { + t.Error("IsAccessError(regularErr) = true, want false") + } + }) + + t.Run("IsAccessError with different status codes", func(t *testing.T) { + testCases := []struct { + name string + statusCode int + want bool + }{ + {"Forbidden", http.StatusForbidden, true}, + {"Unauthorized", http.StatusUnauthorized, true}, + {"NotFound", http.StatusNotFound, true}, + {"BadRequest", http.StatusBadRequest, false}, + } + + for _, tc := range testCases { + 
t.Run(tc.name, func(t *testing.T) { + err := NewAccessError(tc.statusCode, "test") + got := IsAccessError(err) + if got != tc.want { + t.Errorf("IsAccessError() = %v, want %v", got, tc.want) + } + }) + } + }) + + t.Run("IsAccessError with error strings", func(t *testing.T) { + testCases := []struct { + name string + errMsg string + want bool + }{ + {"Resource not accessible", "Resource not accessible by integration", true}, + {"Not Found", "Not Found", true}, + {"Rate limit", "API rate limit exceeded", true}, + {"Other error", "some other error", false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := errors.New(tc.errMsg) + got := IsAccessError(err) + if got != tc.want { + t.Errorf("IsAccessError(%q) = %v, want %v", tc.errMsg, got, tc.want) + } + }) + } + }) +} + +func TestHandleWebUI(t *testing.T) { + s := New() + + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + w := httptest.NewRecorder() + + s.handleWebUI(w, req) + + resp := w.Result() + defer func() { + if err := resp.Body.Close(); err != nil { + t.Errorf("Failed to close response body: %v", err) + } + }() + + if resp.StatusCode != http.StatusOK { + t.Errorf("handleWebUI() status = %d, want %d", resp.StatusCode, http.StatusOK) + } + + contentType := resp.Header.Get("Content-Type") + if contentType != "text/html; charset=utf-8" { + t.Errorf("Content-Type = %q, want %q", contentType, "text/html; charset=utf-8") + } +} + +func TestHandleStatic(t *testing.T) { + s := New() + + tests := []struct { + name string + path string + wantStatus int + }{ + { + name: "root path", + path: "/", + wantStatus: http.StatusNotFound, // Static handler doesn't serve root + }, + { + name: "js file", + path: "/static/app.js", + wantStatus: http.StatusNotFound, // File doesn't exist in test + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tt.path, http.NoBody) + w := httptest.NewRecorder() + + 
s.handleStatic(w, req) + + resp := w.Result() + defer func() { + if err := resp.Body.Close(); err != nil { + t.Errorf("Failed to close response body: %v", err) + } + }() + + if resp.StatusCode != tt.wantStatus { + t.Errorf("handleStatic(%q) status = %d, want %d", tt.path, resp.StatusCode, tt.wantStatus) + } + }) + } +} + +func TestParseRepoSampleRequest(t *testing.T) { + s := New() + + tests := []struct { + name string + body string + wantErr bool + wantOwner string + wantRepo string + wantDays int + wantSampleSize int + }{ + { + name: "valid request with all fields", + body: `{"owner":"testowner","repo":"testrepo","days":30,"sample_size":10}`, + wantErr: false, + wantOwner: "testowner", + wantRepo: "testrepo", + wantDays: 30, + wantSampleSize: 10, + }, + { + name: "valid request with defaults", + body: `{"owner":"testowner","repo":"testrepo"}`, + wantErr: false, + wantOwner: "testowner", + wantRepo: "testrepo", + wantDays: 90, + wantSampleSize: 30, + }, + { + name: "missing owner", + body: `{"repo":"testrepo"}`, + wantErr: true, + }, + { + name: "missing repo", + body: `{"owner":"testowner"}`, + wantErr: true, + }, + { + name: "invalid json", + body: `{invalid}`, + wantErr: true, + }, + { + name: "custom days and samples", + body: `{"owner":"owner","repo":"repo","days":60,"sample_size":20}`, + wantErr: false, + wantOwner: "owner", + wantRepo: "repo", + wantDays: 60, + wantSampleSize: 20, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/repo-sample", strings.NewReader(tt.body)) + req.Header.Set("Content-Type", "application/json") + + result, err := s.parseRepoSampleRequest(req.Context(), req) + + if tt.wantErr { + if err == nil { + t.Error("parseRepoSampleRequest() expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("parseRepoSampleRequest() unexpected error: %v", err) + } + + if result.Owner != tt.wantOwner { + t.Errorf("Owner = %q, want %q", result.Owner, 
tt.wantOwner) + } + if result.Repo != tt.wantRepo { + t.Errorf("Repo = %q, want %q", result.Repo, tt.wantRepo) + } + if result.Days != tt.wantDays { + t.Errorf("Days = %d, want %d", result.Days, tt.wantDays) + } + if result.SampleSize != tt.wantSampleSize { + t.Errorf("SampleSize = %d, want %d", result.SampleSize, tt.wantSampleSize) + } + }) + } +} + +func TestParseOrgSampleRequest(t *testing.T) { + s := New() + + tests := []struct { + name string + body string + wantErr bool + wantOrg string + wantDays int + wantSampleSize int + }{ + { + name: "valid request with all fields", + body: `{"org":"testorg","days":30,"sample_size":10}`, + wantErr: false, + wantOrg: "testorg", + wantDays: 30, + wantSampleSize: 10, + }, + { + name: "valid request with defaults", + body: `{"org":"testorg"}`, + wantErr: false, + wantOrg: "testorg", + wantDays: 90, + wantSampleSize: 30, + }, + { + name: "missing org", + body: `{"days":30}`, + wantErr: true, + }, + { + name: "invalid json", + body: `{invalid}`, + wantErr: true, + }, + { + name: "custom days and samples", + body: `{"org":"myorg","days":60,"sample_size":20}`, + wantErr: false, + wantOrg: "myorg", + wantDays: 60, + wantSampleSize: 20, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/org-sample", strings.NewReader(tt.body)) + req.Header.Set("Content-Type", "application/json") + + result, err := s.parseOrgSampleRequest(req.Context(), req) + + if tt.wantErr { + if err == nil { + t.Error("parseOrgSampleRequest() expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("parseOrgSampleRequest() unexpected error: %v", err) + } + + if result.Org != tt.wantOrg { + t.Errorf("Org = %q, want %q", result.Org, tt.wantOrg) + } + if result.Days != tt.wantDays { + t.Errorf("Days = %d, want %d", result.Days, tt.wantDays) + } + if result.SampleSize != tt.wantSampleSize { + t.Errorf("SampleSize = %d, want %d", result.SampleSize, tt.wantSampleSize) + } + 
}) + } +} + +func TestHandleStaticWithValidFile(t *testing.T) { + s := New() + + // Test with a path that might exist in the embedded FS + req := httptest.NewRequest(http.MethodGet, "/static/index.html", http.NoBody) + w := httptest.NewRecorder() + + s.handleStatic(w, req) + + resp := w.Result() + defer func() { + if err := resp.Body.Close(); err != nil { + t.Errorf("Failed to close response body: %v", err) + } + }() + + // We expect either 200 (file exists) or 404 (file doesn't exist in test) + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNotFound { + t.Errorf("handleStatic() status = %d, want 200 or 404", resp.StatusCode) + } +} + +func TestMergeConfigEdgeCases(t *testing.T) { + s := New() + + tests := []struct { + name string + base cost.Config + override *cost.Config + want cost.Config + }{ + { + name: "override with nil", + base: cost.Config{ + AnnualSalary: 250000, + BenefitsMultiplier: 1.3, + }, + override: nil, + want: cost.Config{ + AnnualSalary: 250000, + BenefitsMultiplier: 1.3, + }, + }, + { + name: "override with zero values", + base: cost.Config{ + AnnualSalary: 250000, + BenefitsMultiplier: 1.3, + }, + override: &cost.Config{}, + want: cost.Config{ + AnnualSalary: 250000, + BenefitsMultiplier: 1.3, + }, + }, + { + name: "override salary only", + base: cost.Config{ + AnnualSalary: 250000, + BenefitsMultiplier: 1.3, + }, + override: &cost.Config{ + AnnualSalary: 300000, + }, + want: cost.Config{ + AnnualSalary: 300000, + BenefitsMultiplier: 1.3, + }, + }, + { + name: "override benefits only", + base: cost.Config{ + AnnualSalary: 250000, + BenefitsMultiplier: 1.3, + }, + override: &cost.Config{ + BenefitsMultiplier: 1.5, + }, + want: cost.Config{ + AnnualSalary: 250000, + BenefitsMultiplier: 1.5, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := s.mergeConfig(tt.base, tt.override) + if got.AnnualSalary != tt.want.AnnualSalary { + t.Errorf("AnnualSalary = %v, want %v", got.AnnualSalary, 
tt.want.AnnualSalary) + } + if got.BenefitsMultiplier != tt.want.BenefitsMultiplier { + t.Errorf("BenefitsMultiplier = %v, want %v", got.BenefitsMultiplier, tt.want.BenefitsMultiplier) + } + }) + } +} + +func TestParseRequestPOST(t *testing.T) { + s := New() + + tests := []struct { + name string + body string + wantURL string + wantErr bool + }{ + { + name: "valid JSON", + body: `{"url":"https://github.com/owner/repo/pull/123"}`, + wantURL: "https://github.com/owner/repo/pull/123", + wantErr: false, + }, + { + name: "invalid JSON", + body: `{invalid json}`, + wantErr: true, + }, + { + name: "missing url field", + body: `{"config":{"salary":300000}}`, + wantErr: true, + }, + { + name: "empty url", + body: `{"url":""}`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/api/calculate", strings.NewReader(tt.body)) + req.Header.Set("Content-Type", "application/json") + + result, err := s.parseRequest(req.Context(), req) + + if (err != nil) != tt.wantErr { + t.Errorf("parseRequest() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && result.URL != tt.wantURL { + t.Errorf("parseRequest() URL = %v, want %v", result.URL, tt.wantURL) + } + }) + } +} + +func TestHandleHealthErrorPath(t *testing.T) { + s := New() + + // Create a response writer that fails on write + req := httptest.NewRequest(http.MethodGet, "/health", http.NoBody) + + // Use a normal recorder - we're testing the encode path + w := httptest.NewRecorder() + + s.handleHealth(w, req) + + // Should succeed normally + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } +} + +func TestHandleWebUIErrorPaths(t *testing.T) { + s := New() + + tests := []struct { + name string + path string + wantStatus int + }{ + { + name: "root path", + path: "/", + wantStatus: http.StatusOK, + }, + { + name: "web ui path", + path: "/web", + wantStatus: http.StatusOK, + }, + } + + for _, 
tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tt.path, http.NoBody) + w := httptest.NewRecorder() + + s.handleWebUI(w, req) + + if w.Code != tt.wantStatus { + t.Errorf("Expected status %d, got %d", tt.wantStatus, w.Code) + } + }) + } +} + +func TestTokenFunction(t *testing.T) { + s := New() + ctx := context.Background() + + // Test when fallbackToken is empty + token := s.token(ctx) + + // Token might be from gh CLI or empty + // Just verify the function doesn't crash + _ = token +} + +func TestLimiterCleanup(t *testing.T) { + s := New() + ctx := context.Background() + + // Create many limiters to trigger cleanup + // The cleanup happens at 10001 limiters + for i := range 10005 { + ip := fmt.Sprintf("192.168.1.%d", i) + _ = s.limiter(ctx, ip) + } + + // Should have cleaned up to half + s.ipLimitersMu.RLock() + count := len(s.ipLimiters) + s.ipLimitersMu.RUnlock() + + if count > 10001 { + t.Errorf("Expected limiter cleanup, got %d limiters", count) + } +} + +func TestNewWithDatastoreEnv(t *testing.T) { + // Test with DATASTORE_DB set + t.Setenv("DATASTORE_DB", "test-db-id") + s := New() + if s == nil { + t.Fatal("Expected server to be created") + } + // Note: dsClient might be nil if the client creation fails, but the server should still be created +} + +func TestLogSSEError(t *testing.T) { + s := New() + ctx := context.Background() + + // Test with non-nil error + logSSEError(ctx, s.logger, fmt.Errorf("test error")) + + // Test with nil error + logSSEError(ctx, s.logger, nil) +} + +func TestStartKeepAliveCompletesCoverage(t *testing.T) { + w := httptest.NewRecorder() + + // Start keep alive + stop, errChan := startKeepAlive(w) + + // Let it run briefly + time.Sleep(100 * time.Millisecond) + + // Stop it + close(stop) + + // Check for errors + select { + case err := <-errChan: + if err != nil { + t.Errorf("Unexpected error from startKeepAlive: %v", err) + } + case <-time.After(1 * time.Second): + // No error - 
success + } +} + +func TestSendSSECoverage(t *testing.T) { + tests := []struct { + name string + update ProgressUpdate + }{ + { + name: "error message", + update: ProgressUpdate{ + Type: "error", + Error: "test error", + }, + }, + { + name: "complete message", + update: ProgressUpdate{ + Type: "complete", + PR: 123, + Owner: "owner", + Repo: "repo", + Result: &cost.ExtrapolatedBreakdown{}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + err := sendSSE(w, tt.update) + if err != nil { + t.Errorf("sendSSE() error = %v", err) + } + + if w.Body.Len() == 0 { + t.Error("Expected SSE message to be written") + } + }) + } +} + +func TestMergeConfigAllFields(t *testing.T) { + s := New() + + base := cost.Config{ + AnnualSalary: 250000, + BenefitsMultiplier: 1.3, + HoursPerYear: 2080, + EventDuration: 20 * time.Minute, + ContextSwitchInDuration: 20 * time.Minute, + ContextSwitchOutDuration: 20 * time.Minute, + SessionGapThreshold: 60 * time.Minute, + DeliveryDelayFactor: 0.25, + MaxDelayAfterLastEvent: 30 * 24 * time.Hour, + MaxProjectDelay: 90 * 24 * time.Hour, + MaxCodeDrift: 180 * 24 * time.Hour, + ReviewInspectionRate: 200, + ModificationCostFactor: 1.0, + } + + override := &cost.Config{ + AnnualSalary: 300000, + BenefitsMultiplier: 1.5, + HoursPerYear: 2000, + EventDuration: 30 * time.Minute, + ContextSwitchInDuration: 15 * time.Minute, + ContextSwitchOutDuration: 15 * time.Minute, + SessionGapThreshold: 45 * time.Minute, + DeliveryDelayFactor: 0.3, + MaxDelayAfterLastEvent: 20 * 24 * time.Hour, + MaxProjectDelay: 60 * 24 * time.Hour, + MaxCodeDrift: 120 * 24 * time.Hour, + ReviewInspectionRate: 250, + ModificationCostFactor: 1.2, + } + + result := s.mergeConfig(base, override) + + // Verify all fields were overridden + if result.AnnualSalary != 300000 { + t.Errorf("Expected AnnualSalary 300000, got %v", result.AnnualSalary) + } + if result.BenefitsMultiplier != 1.5 { + t.Errorf("Expected BenefitsMultiplier 
1.5, got %v", result.BenefitsMultiplier) + } + if result.HoursPerYear != 2000 { + t.Errorf("Expected HoursPerYear 2000, got %v", result.HoursPerYear) + } + if result.EventDuration != 30*time.Minute { + t.Errorf("Expected EventDuration 30m, got %v", result.EventDuration) + } + if result.ReviewInspectionRate != 250 { + t.Errorf("Expected ReviewInspectionRate 250, got %v", result.ReviewInspectionRate) + } + if result.ModificationCostFactor != 1.2 { + t.Errorf("Expected ModificationCostFactor 1.2, got %v", result.ModificationCostFactor) + } +} + +func TestProcessRequestWithMock(t *testing.T) { + s := New() + ctx := context.Background() + + // Create a mock PR data + mockData := newMockPRData("test-author", 150, 3) + + // Store in cache to simulate successful fetch (cache stores the value type, not the pointer) + s.prDataCacheMu.Lock() + s.prDataCache["https://github.com/test/repo/pull/123"] = &cacheEntry{ + data: *mockData, + } + s.prDataCacheMu.Unlock() + + req := &CalculateRequest{ + URL: "https://github.com/test/repo/pull/123", + } + + // This will fail because we can't fully mock the GitHub client + // but it will exercise more code paths + _, err := s.processRequest(ctx, req, "fake-token") + + // We expect an error because we don't have a real GitHub client + // but this still exercises the code + _ = err +} + +func TestCachedPRQueryHit(t *testing.T) { + s := New() + ctx := context.Background() + + // Pre-populate cache + testPRs := newMockPRSummaries(5) + key := "repo:owner/repo:30" + + s.prQueryCacheMu.Lock() + s.prQueryCache[key] = &cacheEntry{data: testPRs} + s.prQueryCacheMu.Unlock() + + // Test cache hit + prs, found := s.cachedPRQuery(ctx, key) + if !found { + t.Error("Expected cache hit") + } + if len(prs) != 5 { + t.Errorf("Expected 5 PRs, got %d", len(prs)) + } +} + +func TestCachedPRQueryMiss(t *testing.T) { + s := New() + ctx := context.Background() + + // Test cache miss + _, found := s.cachedPRQuery(ctx, "nonexistent-key") + if found { + t.Error("Expected cache miss") + } +} + +func 
TestCachedPRDataHit(t *testing.T) { + s := New() + ctx := context.Background() + + // Pre-populate cache + testData := newMockPRData("test-author", 100, 5) + key := "https://github.com/owner/repo/pull/123" + + s.prDataCacheMu.Lock() + s.prDataCache[key] = &cacheEntry{data: *testData} + s.prDataCacheMu.Unlock() + + // Test cache hit + data, found := s.cachedPRData(ctx, key) + if !found { + t.Error("Expected cache hit") + } + if data.Author != "test-author" { + t.Errorf("Expected author test-author, got %s", data.Author) + } +} + +func TestCachePRQuery(t *testing.T) { + s := New() + ctx := context.Background() + + testPRs := newMockPRSummaries(3) + key := "repo:owner/repo:30" + + // Cache the data + s.cachePRQuery(ctx, key, testPRs) + + // Verify it was cached + s.prQueryCacheMu.RLock() + entry, exists := s.prQueryCache[key] + s.prQueryCacheMu.RUnlock() + + if !exists { + t.Error("Expected data to be cached") + } + + if cached, ok := entry.data.([]github.PRSummary); ok { + if len(cached) != 3 { + t.Errorf("Expected 3 cached PRs, got %d", len(cached)) + } + } else { + t.Error("Cached data is not []github.PRSummary") + } +} + +func TestCachePRData(t *testing.T) { + s := New() + ctx := context.Background() + + testData := newMockPRData("author", 200, 4) + key := "https://github.com/owner/repo/pull/456" + + // Cache the data + s.cachePRData(ctx, key, *testData) + + // Verify it was cached + s.prDataCacheMu.RLock() + entry, exists := s.prDataCache[key] + s.prDataCacheMu.RUnlock() + + if !exists { + t.Error("Expected data to be cached") + } + + if cached, ok := entry.data.(cost.PRData); ok { + if cached.Author != "author" { + t.Errorf("Expected author 'author', got %s", cached.Author) + } + if cached.LinesAdded != 200 { + t.Errorf("Expected 200 lines, got %d", cached.LinesAdded) + } + } else { + t.Error("Cached data is not cost.PRData") + } +} + +func TestHandleRepoSampleRateLimitExceeded(t *testing.T) { + s := New() + s.SetRateLimit(1, 1) // Very low rate limit + + // 
Consume the rate limit + ctx := context.Background() + limiter := s.limiter(ctx, "test-ip") + limiter.Allow() + + // Make second request that should be rate limited + req2 := httptest.NewRequest(http.MethodPost, "/api/sample/repo", strings.NewReader(`{"owner":"test","repo":"repo","days":30}`)) + req2.Header.Set("Content-Type", "application/json") + req2.Header.Set("X-Forwarded-For", "test-ip") + w2 := httptest.NewRecorder() + + s.handleRepoSample(w2, req2) + + if w2.Code != http.StatusTooManyRequests { + t.Errorf("Expected status 429, got %d", w2.Code) + } +} + +func TestHandleOrgSampleBadRequest(t *testing.T) { + s := New() + + req := httptest.NewRequest(http.MethodPost, "/api/sample/org", strings.NewReader(`{invalid json}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + s.handleOrgSample(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", w.Code) + } +} + +func TestParseRepoSampleRequestMissingFields(t *testing.T) { + s := New() + + tests := []struct { + name string + body string + }{ + { + name: "missing owner", + body: `{"repo":"test","days":30}`, + }, + { + name: "missing repo", + body: `{"owner":"test","days":30}`, + }, + { + name: "negative days", + body: `{"owner":"test","repo":"repo","days":-1}`, + }, + { + name: "days too large", + body: `{"owner":"test","repo":"repo","days":400}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(tt.body)) + req.Header.Set("Content-Type", "application/json") + + _, err := s.parseRepoSampleRequest(req.Context(), req) + if err == nil { + t.Error("Expected error for invalid request") + } + }) + } +} + +func TestParseOrgSampleRequestValidation(t *testing.T) { + s := New() + + tests := []struct { + name string + body string + wantErr bool + }{ + { + name: "valid request", + body: `{"org":"test-org","days":30,"sample_size":10}`, + wantErr: false, + }, + 
{ + name: "missing org", + body: `{"days":30}`, + wantErr: true, + }, + { + name: "sample size zero", + body: `{"org":"test","days":30,"sample_size":0}`, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(tt.body)) + req.Header.Set("Content-Type", "application/json") + + _, err := s.parseOrgSampleRequest(req.Context(), req) + if (err != nil) != tt.wantErr { + t.Errorf("parseOrgSampleRequest() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestServeHTTPCORSPreflightSameOrigin(t *testing.T) { + s := New() + + // Test preflight without Sec-Fetch-Site (same-origin request) + req := httptest.NewRequest(http.MethodOptions, "/api/calculate", http.NoBody) + w := httptest.NewRecorder() + + s.ServeHTTP(w, req) + + // Should allow same-origin + if w.Code == http.StatusForbidden { + t.Error("Should allow same-origin preflight") + } +} + +func TestServeHTTPCSRFProtection(t *testing.T) { + s := New() + + // Test POST without proper CSRF headers + req := httptest.NewRequest(http.MethodPost, "/api/calculate", strings.NewReader(`{"url":"https://github.com/owner/repo/pull/123"}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Sec-Fetch-Site", "cross-site") + w := httptest.NewRecorder() + + s.ServeHTTP(w, req) + + // Should be blocked by CSRF protection + if w.Code != http.StatusForbidden { + t.Errorf("Expected CSRF protection to block request, got status %d", w.Code) + } +} + +func TestHandleCalculateWithXForwardedFor(t *testing.T) { + s := New() + + req := httptest.NewRequest(http.MethodGet, "/api/calculate?url=https://github.com/owner/repo/pull/123", http.NoBody) + req.Header.Set("X-Forwarded-For", "1.2.3.4, 5.6.7.8") + w := httptest.NewRecorder() + + s.handleCalculate(w, req) + + // Just verify it doesn't crash with X-Forwarded-For header parsing + _ = w.Code +} + +func TestValidateGitHubTokenSuccess(t *testing.T) { + s := 
New() + ctx := context.Background() + + // This will test the token validation logic + // It will likely fail without a valid token, but exercises the code + err := s.validateGitHubToken(ctx, "fake-token") + // We don't assert on the result since we can't mock the GitHub API easily + _ = err +} + +func TestHandleRepoSampleWithAuth(t *testing.T) { + s := New() + + reqBody := `{"owner":"test","repo":"test","days":30,"sample_size":10}` + req := httptest.NewRequest(http.MethodPost, "/api/sample/repo", strings.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + + s.handleRepoSample(w, req) + + // Will fail without real GitHub access, but exercises auth extraction + _ = w.Code +} + +func TestHandleOrgSampleWithAuth(t *testing.T) { + s := New() + + reqBody := `{"org":"test-org","days":30,"sample_size":10}` + req := httptest.NewRequest(http.MethodPost, "/api/sample/org", strings.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + + s.handleOrgSample(w, req) + + // Will fail without real GitHub access, but exercises auth extraction + _ = w.Code +} + +func TestServeHTTPRoutingCalculate(t *testing.T) { + s := New() + + req := httptest.NewRequest(http.MethodGet, "/api/calculate?url=https://github.com/owner/repo/pull/1", http.NoBody) + w := httptest.NewRecorder() + + s.ServeHTTP(w, req) + + // Just verify routing works (will fail on actual processing) + _ = w.Code +} + +func TestServeHTTPRoutingHealth(t *testing.T) { + s := New() + + req := httptest.NewRequest(http.MethodGet, "/health", http.NoBody) + w := httptest.NewRecorder() + + s.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200 for /health, got %d", w.Code) + } +} + +func TestServeHTTPRoutingWebUI(t *testing.T) { + s := New() + + req := httptest.NewRequest(http.MethodGet, "/", 
http.NoBody) + w := httptest.NewRecorder() + + s.ServeHTTP(w, req) + + // Should return web UI + _ = w.Code +} + +func TestServeHTTPRoutingStatic(t *testing.T) { + s := New() + + req := httptest.NewRequest(http.MethodGet, "/static/test.css", http.NoBody) + w := httptest.NewRecorder() + + s.ServeHTTP(w, req) + + // Will 404 if file doesn't exist, but exercises routing + _ = w.Code +} + +func TestHandleRepoSampleStreamWithInvalidRequest(t *testing.T) { + s := New() + + req := httptest.NewRequest(http.MethodPost, "/api/sample/repo/stream", strings.NewReader(`{invalid}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + s.handleRepoSampleStream(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected 400 for invalid request, got %d", w.Code) + } +} + +func TestHandleOrgSampleStreamWithInvalidRequest(t *testing.T) { + s := New() + + req := httptest.NewRequest(http.MethodPost, "/api/sample/org/stream", strings.NewReader(`{invalid}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + s.handleOrgSampleStream(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected 400 for invalid request, got %d", w.Code) + } +} + +func TestNewWithAllDefaults(t *testing.T) { + s := New() + + // Verify defaults are set correctly + if s.rateLimit != DefaultRateLimit { + t.Errorf("Expected rate limit %d, got %d", DefaultRateLimit, s.rateLimit) + } + if s.rateBurst != DefaultRateBurst { + t.Errorf("Expected rate burst %d, got %d", DefaultRateBurst, s.rateBurst) + } + if s.dataSource != "turnserver" { + t.Errorf("Expected data source 'turnserver', got %s", s.dataSource) + } +} + +func TestIsOriginAllowedWithWildcard(t *testing.T) { + s := New() + + // Set allowed origins with wildcard + s.SetCORSConfig("https://*.example.com", false) + + tests := []struct { + name string + origin string + want bool + }{ + { + name: "wildcard subdomain match", + origin: "https://app.example.com", + want: true, + 
}, + { + name: "wildcard deep subdomain match", + origin: "https://api.app.example.com", + want: true, + }, + { + name: "no match different domain", + origin: "https://evil.com", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := s.isOriginAllowed(tt.origin) + if got != tt.want { + t.Errorf("isOriginAllowed(%q) = %v, want %v", tt.origin, got, tt.want) + } + }) + } +} + +func TestParseRequestWithConfigOverride(t *testing.T) { + s := New() + + reqBody := `{ + "url": "https://github.com/owner/repo/pull/123", + "config": { + "AnnualSalary": 300000, + "BenefitsMultiplier": 1.4 + } + }` + req := httptest.NewRequest(http.MethodPost, "/api/calculate", strings.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + + result, err := s.parseRequest(req.Context(), req) + if err != nil { + t.Fatalf("parseRequest() error = %v", err) + } + + if result.Config == nil { + t.Fatal("Expected config to be set") + } + if result.Config.AnnualSalary != 300000 { + t.Errorf("Expected salary 300000, got %v", result.Config.AnnualSalary) + } + if result.Config.BenefitsMultiplier != 1.4 { + t.Errorf("Expected benefits 1.4, got %v", result.Config.BenefitsMultiplier) + } +} + +func TestHandleStaticNotFound(t *testing.T) { + s := New() + + req := httptest.NewRequest(http.MethodGet, "/static/nonexistent.js", http.NoBody) + w := httptest.NewRecorder() + + s.handleStatic(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("Expected 404 for nonexistent file, got %d", w.Code) + } +} + +func TestHandleStaticEmptyPath(t *testing.T) { + s := New() + + req := httptest.NewRequest(http.MethodGet, "/static/", http.NoBody) + w := httptest.NewRecorder() + + s.handleStatic(w, req) + + // Should handle empty path gracefully + _ = w.Code +} + +func TestSanitizeErrorWithTokens(t *testing.T) { + tests := []struct { + name string + input error + want string + }{ + { + name: "error with ghp token", + input: fmt.Errorf("failed with 
ghp_123456789012345678901234567890123456"), + want: "failed with [REDACTED_TOKEN]", + }, + { + name: "error with gho token", + input: fmt.Errorf("auth failed: gho_abcdef123456abcdef123456abcdef123456"), + want: "auth failed: [REDACTED_TOKEN]", + }, + { + name: "error without token", + input: fmt.Errorf("regular error message"), + want: "regular error message", + }, + { + name: "nil error", + input: nil, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := sanitizeError(tt.input) + if got != tt.want { + t.Errorf("sanitizeError() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestCachePRQueryMemoryWrite(t *testing.T) { + s := New() + ctx := context.Background() + testPRs := newMockPRSummaries(3) + key := "test-cache-key" + + s.cachePRQuery(ctx, key, testPRs) + + // Verify it was cached in memory + prs, found := s.cachedPRQuery(ctx, key) + if !found { + t.Fatal("Expected cache entry to be found") + } + if len(prs) != 3 { + t.Errorf("Expected 3 PRs, got %d", len(prs)) + } +} + +func TestCachePRDataMemoryWrite(t *testing.T) { + s := New() + ctx := context.Background() + testData := newMockPRData("test-author", 100, 5) + key := "pr:owner/repo:123" + + s.cachePRData(ctx, key, *testData) + + // Verify it was cached + data, found := s.cachedPRData(ctx, key) + if !found { + t.Fatal("Expected cache entry to be found") + } + if data.Author != "test-author" { + t.Errorf("Expected author 'test-author', got %s", data.Author) + } + if data.LinesAdded != 100 { + t.Errorf("Expected 100 lines added, got %d", data.LinesAdded) + } +} + +func TestCachedPRDataMissCache(t *testing.T) { + s := New() + ctx := context.Background() + + _, found := s.cachedPRData(ctx, "nonexistent-key") + if found { + t.Error("Expected cache miss for nonexistent key") + } +} + +func TestCachedPRQueryBadType(t *testing.T) { + s := New() + ctx := context.Background() + key := "bad-type-key" + + // Store wrong type in cache + s.prQueryCacheMu.Lock() + 
s.prQueryCache[key] = &cacheEntry{data: "not a PR summary slice"} + s.prQueryCacheMu.Unlock() + + _, found := s.cachedPRQuery(ctx, key) + if found { + t.Error("Expected cache miss for wrong type") + } +} + +func TestCachedPRDataBadType(t *testing.T) { + s := New() + ctx := context.Background() + key := "bad-type-key" + + // Store wrong type in cache + s.prDataCacheMu.Lock() + s.prDataCache[key] = &cacheEntry{data: "not a PRData"} + s.prDataCacheMu.Unlock() + + _, found := s.cachedPRData(ctx, key) + if found { + t.Error("Expected cache miss for wrong type") + } +} + +func TestLimiterCleanupLarge(t *testing.T) { + s := New() + ctx := context.Background() + + // Add more than 10000 limiters to trigger cleanup + for i := range 10500 { + ip := fmt.Sprintf("192.168.1.%d", i%256) + fmt.Sprintf(".%d", i/256) + _ = s.limiter(ctx, ip) + } + + // Should have triggered cleanup + s.ipLimitersMu.RLock() + count := len(s.ipLimiters) + s.ipLimitersMu.RUnlock() + + if count > 10000 { + t.Errorf("Expected limiter cleanup, but have %d limiters", count) + } +} + +func TestTokenFallbackReturns(t *testing.T) { + s := New() + ctx := context.Background() + + // token() returns the fallback token if set (may be set by gh auth token at startup) + token := s.token(ctx) + // Just verify it returns without error - gh auth token may have set a fallback + _ = token +} + +func TestTokenFallbackExplicitSet(t *testing.T) { + s := New() + ctx := context.Background() + + // Manually set a fallback token + s.fallbackTokenMu.Lock() + s.fallbackToken = "ghp_abcdefghijklmnopqrstuvwxyz1234567890" + s.fallbackTokenMu.Unlock() + + token := s.token(ctx) + if len(token) != 40 { + t.Errorf("Expected 40-char token, got %d chars: %s", len(token), token) + } +} + +func TestSetTokenValidationWithInvalidKeyFile(t *testing.T) { + s := New() + + err := s.SetTokenValidation("test-app-id", "/nonexistent/key/file.pem") + if err == nil { + t.Error("Expected error for nonexistent key file") + } +} + +func 
TestParseRepoSampleRequestValidDays(t *testing.T) { + s := New() + ctx := context.Background() + req := httptest.NewRequest(http.MethodGet, "/api/repo/sample?owner=test&repo=test&days=30", http.NoBody) + + result, err := s.parseRepoSampleRequest(ctx, req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if result.Days != 30 { + t.Errorf("Expected days=30, got %d", result.Days) + } +} + +func TestParseRepoSampleRequestMissingOwner(t *testing.T) { + s := New() + ctx := context.Background() + req := httptest.NewRequest(http.MethodGet, "/api/repo/sample?repo=test", http.NoBody) + + _, err := s.parseRepoSampleRequest(ctx, req) + if err == nil { + t.Error("Expected error for missing owner parameter") + } +} + +func TestParseRepoSampleRequestMissingRepo(t *testing.T) { + s := New() + ctx := context.Background() + req := httptest.NewRequest(http.MethodGet, "/api/repo/sample?owner=test", http.NoBody) + + _, err := s.parseRepoSampleRequest(ctx, req) + if err == nil { + t.Error("Expected error for missing repo parameter") + } +} + +func TestParseOrgSampleRequestValidDays(t *testing.T) { + s := New() + ctx := context.Background() + req := httptest.NewRequest(http.MethodGet, "/api/org/sample?org=test&days=30", http.NoBody) + + result, err := s.parseOrgSampleRequest(ctx, req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if result.Days != 30 { + t.Errorf("Expected days=30, got %d", result.Days) + } +} + +func TestParseOrgSampleRequestMissingOrg(t *testing.T) { + s := New() + ctx := context.Background() + req := httptest.NewRequest(http.MethodGet, "/api/org/sample", http.NoBody) + + _, err := s.parseOrgSampleRequest(ctx, req) + if err == nil { + t.Error("Expected error for missing org parameter") + } +} + +func TestHandleCalculateWithMalformedJSON(t *testing.T) { + s := New() + + malformedJSON := `{"pr_url": "invalid json` + req := httptest.NewRequest(http.MethodPost, "/api/calculate", strings.NewReader(malformedJSON)) + 
req.Header.Set("Content-Type", "application/json") + req.SetBasicAuth("user", "ghp_123456789012345678901234567890123456") + w := httptest.NewRecorder() + + s.handleCalculate(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected 400 for malformed JSON, got %d", w.Code) + } +} + +func TestHandleCalculateWithInvalidPRURL(t *testing.T) { + s := New() + + jsonBody := `{"pr_url": "not-a-github-url"}` + req := httptest.NewRequest(http.MethodPost, "/api/calculate", strings.NewReader(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.SetBasicAuth("user", "ghp_123456789012345678901234567890123456") + w := httptest.NewRecorder() + + s.handleCalculate(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected 400 for invalid PR URL, got %d", w.Code) + } +} + +func TestHandleWebUIWithQueryParams(t *testing.T) { + s := New() + + req := httptest.NewRequest(http.MethodGet, "/web?pr_url=https://github.com/owner/repo/pull/123", http.NoBody) + w := httptest.NewRecorder() + + s.handleWebUI(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200, got %d", w.Code) + } + + body := w.Body.String() + if !strings.Contains(body, "owner/repo") { + t.Error("Expected response to contain owner/repo reference") + } +} + +func TestServeHTTPWithCSRFProtection(t *testing.T) { + s := New() + s.SetCORSConfig("https://example.com", false) + + // POST request from cross-origin should be blocked + req := httptest.NewRequest(http.MethodPost, "/api/calculate", strings.NewReader(`{}`)) + req.Header.Set("Origin", "https://malicious.com") + req.Header.Set("Sec-Fetch-Site", "cross-site") + w := httptest.NewRecorder() + + s.ServeHTTP(w, req) + + // Should be blocked by CSRF protection + if w.Code != http.StatusForbidden { + t.Errorf("Expected 403 for cross-origin POST, got %d", w.Code) + } +} + +func TestValidateGitHubPRURLMoreCases(t *testing.T) { + s := New() + + tests := []struct { + name string + url string + wantErr bool + }{ + { + name: "valid URL", 
+ url: "https://github.com/owner/repo/pull/123", + wantErr: false, + }, + { + name: "URL too long", + url: "https://github.com/" + strings.Repeat("a", 250) + "/repo/pull/123", + wantErr: true, + }, + { + name: "non-github domain", + url: "https://gitlab.com/owner/repo/pull/123", + wantErr: true, + }, + { + name: "http not https", + url: "http://github.com/owner/repo/pull/123", + wantErr: true, + }, + { + name: "URL with credentials", + url: "https://user:pass@github.com/owner/repo/pull/123", + wantErr: true, + }, + { + name: "URL with query params", + url: "https://github.com/owner/repo/pull/123?tab=files", + wantErr: true, + }, + { + name: "URL with fragment", + url: "https://github.com/owner/repo/pull/123#discussion", + wantErr: true, + }, + { + name: "invalid path format", + url: "https://github.com/owner/repo/issues/123", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := s.validateGitHubPRURL(tt.url) + if (err != nil) != tt.wantErr { + t.Errorf("validateGitHubPRURL() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestParseConfigFromQueryWithValues(t *testing.T) { + query := url.Values{} + query.Set("salary", "150000") + query.Set("benefits", "1.3") + + cfg := parseConfigFromQuery(query) + if cfg == nil { + t.Fatal("Expected config, got nil") + } + if cfg.AnnualSalary != 150000 { + t.Errorf("Expected salary 150000, got %f", cfg.AnnualSalary) + } + if cfg.BenefitsMultiplier != 1.3 { + t.Errorf("Expected benefits 1.3, got %f", cfg.BenefitsMultiplier) + } +} + +func TestParseConfigFromQueryEmpty(t *testing.T) { + query := url.Values{} + + cfg := parseConfigFromQuery(query) + if cfg != nil { + t.Errorf("Expected nil for empty query, got %+v", cfg) + } +} + +func TestParseConfigFromQueryInvalidValues(t *testing.T) { + query := url.Values{} + query.Set("salary", "not-a-number") + query.Set("benefits", "invalid") + + cfg := parseConfigFromQuery(query) + if cfg == nil { + t.Fatal("Expected config 
struct with zero values") + } + if cfg.AnnualSalary != 0 { + t.Errorf("Expected salary 0 for invalid input, got %f", cfg.AnnualSalary) + } +} + +func TestHandleCalculateWithMaxBytesReader(t *testing.T) { + s := New() + + // Create a very large request body (> 1MB) + largeBody := strings.Repeat("x", 2<<20) // 2MB + req := httptest.NewRequest(http.MethodPost, "/api/calculate", strings.NewReader(largeBody)) + req.Header.Set("Content-Type", "application/json") + req.SetBasicAuth("user", "ghp_123456789012345678901234567890123456") + w := httptest.NewRecorder() + + s.handleCalculate(w, req) + + // Should reject large request + if w.Code == http.StatusOK { + t.Error("Expected error for oversized request body") + } +} + +func TestHandleWebUIWithoutPRParam(t *testing.T) { + s := New() + + req := httptest.NewRequest(http.MethodGet, "/web", http.NoBody) + w := httptest.NewRecorder() + + s.handleWebUI(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200 for /web without params, got %d", w.Code) + } +} + +func TestServeHTTPWithAllowAllCORS(t *testing.T) { + s := New() + s.SetCORSConfig("", true) + + req := httptest.NewRequest(http.MethodGet, "/health", http.NoBody) + req.Header.Set("Origin", "https://example.com") + w := httptest.NewRecorder() + + s.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200, got %d", w.Code) + } + + // Check CORS header was set + if w.Header().Get("Access-Control-Allow-Origin") != "https://example.com" { + t.Error("Expected CORS header to be set") + } +} + +func TestServeHTTPWithOPTIONS(t *testing.T) { + s := New() + s.SetCORSConfig("https://example.com", false) + + req := httptest.NewRequest(http.MethodOptions, "/api/calculate", http.NoBody) + req.Header.Set("Origin", "https://example.com") + req.Header.Set("Access-Control-Request-Method", "POST") + w := httptest.NewRecorder() + + s.ServeHTTP(w, req) + + if w.Code != http.StatusNoContent { + t.Errorf("Expected 204 for OPTIONS, got %d", w.Code) + } +} + +func 
TestParseRepoSampleRequestWithConfigQuery(t *testing.T) { + s := New() + ctx := context.Background() + req := httptest.NewRequest(http.MethodGet, "/api/repo/sample?owner=test&repo=test&salary=200000&benefits=1.5", http.NoBody) + + result, err := s.parseRepoSampleRequest(ctx, req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if result.Config == nil { + t.Fatal("Expected config to be set") + } + if result.Config.AnnualSalary != 200000 { + t.Errorf("Expected salary 200000, got %f", result.Config.AnnualSalary) + } +} + +func TestParseOrgSampleRequestWithSample(t *testing.T) { + s := New() + ctx := context.Background() + req := httptest.NewRequest(http.MethodGet, "/api/org/sample?org=test&sample=50", http.NoBody) + + result, err := s.parseOrgSampleRequest(ctx, req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if result.SampleSize != 50 { + t.Errorf("Expected sample size 50, got %d", result.SampleSize) + } +} diff --git a/internal/server/server_test_mocks.go b/internal/server/server_test_mocks.go new file mode 100644 index 0000000..b7b033f --- /dev/null +++ b/internal/server/server_test_mocks.go @@ -0,0 +1,41 @@ +package server + +import ( + "fmt" + "time" + + "github.com/codeGROOVE-dev/prcost/pkg/cost" + "github.com/codeGROOVE-dev/prcost/pkg/github" +) + +// Helper functions to create test data. 
+func newMockPRData(author string, linesAdded int, eventCount int) *cost.PRData { + events := make([]cost.ParticipantEvent, eventCount) + baseTime := time.Now().Add(-24 * time.Hour) + for i := range eventCount { + events[i] = cost.ParticipantEvent{ + Actor: fmt.Sprintf("actor%d", i), + Timestamp: baseTime.Add(time.Duration(i) * time.Hour), + } + } + return &cost.PRData{ + Author: author, + LinesAdded: linesAdded, + CreatedAt: baseTime, + Events: events, + } +} + +func newMockPRSummaries(count int) []github.PRSummary { + summaries := make([]github.PRSummary, count) + for i := range count { + summaries[i] = github.PRSummary{ + Number: i + 1, + Owner: "test-owner", + Repo: "test-repo", + Author: fmt.Sprintf("author%d", i), + UpdatedAt: time.Now().Add(-time.Duration(i) * time.Hour), + } + } + return summaries +} diff --git a/pkg/cost/analyze.go b/pkg/cost/analyze.go index c109a9a..e0062a4 100644 --- a/pkg/cost/analyze.go +++ b/pkg/cost/analyze.go @@ -22,12 +22,14 @@ type PRFetcher interface { } // AnalysisRequest contains parameters for analyzing a set of PRs. +// +//nolint:govet // fieldalignment: struct field order optimized for API clarity type AnalysisRequest struct { + Fetcher PRFetcher // PR data fetcher + Config Config // Cost calculation configuration Samples []PRSummaryInfo // PRs to analyze Logger *slog.Logger // Optional logger for progress - Fetcher PRFetcher // PR data fetcher Concurrency int // Number of concurrent fetches (0 = sequential) - Config Config // Cost calculation configuration } // PRSummaryInfo contains basic PR information needed for fetching. 
diff --git a/pkg/cost/cost_test.go b/pkg/cost/cost_test.go index 6a27a62..bcf2fe5 100644 --- a/pkg/cost/cost_test.go +++ b/pkg/cost/cost_test.go @@ -843,7 +843,6 @@ func TestAnalyzePRsSequentialSuccess(t *testing.T) { } result, err := AnalyzePRs(ctx, req) - if err != nil { t.Fatalf("Expected no error, got: %v", err) } @@ -897,7 +896,6 @@ func TestAnalyzePRsSequentialPartialFailure(t *testing.T) { } result, err := AnalyzePRs(ctx, req) - if err != nil { t.Fatalf("Expected no error, got: %v", err) } @@ -1001,7 +999,6 @@ func TestAnalyzePRsParallelSuccess(t *testing.T) { } result, err := AnalyzePRs(ctx, req) - if err != nil { t.Fatalf("Expected no error, got: %v", err) } @@ -1055,7 +1052,6 @@ func TestAnalyzePRsParallelPartialFailure(t *testing.T) { } result, err := AnalyzePRs(ctx, req) - if err != nil { t.Fatalf("Expected no error, got: %v", err) } @@ -1149,7 +1145,6 @@ func TestAnalyzePRsWithLogger(t *testing.T) { } result, err := AnalyzePRs(ctx, req) - if err != nil { t.Fatalf("Expected no error, got: %v", err) } @@ -1196,7 +1191,6 @@ func TestAnalyzePRsConcurrencyDefault(t *testing.T) { } result, err := AnalyzePRs(ctx, req) - if err != nil { t.Fatalf("Expected no error, got: %v", err) } diff --git a/pkg/cost/extrapolate.go b/pkg/cost/extrapolate.go index d119f6e..c836e46 100644 --- a/pkg/cost/extrapolate.go +++ b/pkg/cost/extrapolate.go @@ -114,6 +114,8 @@ type ExtrapolatedBreakdown struct { // // The function computes the average cost per PR from the samples, then multiplies // by the total PR count to estimate population-wide costs. 
+// +//nolint:revive,maintidx // Complex calculation function benefits from cohesion func ExtrapolateFromSamples(breakdowns []Breakdown, totalPRs, totalAuthors, actualOpenPRs int, daysInPeriod int, cfg Config) ExtrapolatedBreakdown { if len(breakdowns) == 0 { return ExtrapolatedBreakdown{ @@ -380,7 +382,8 @@ func ExtrapolateFromSamples(breakdowns []Breakdown, totalPRs, totalAuthors, actu // Formula: baseline annual waste - (re-modeled waste with 40min PRs) - (R2R subscription cost) // Baseline annual waste: preventable cost extrapolated to 52 weeks // uniqueUserCount already defined above for PR tracking calculation - baselineAnnualWaste := (extCodeChurnCost + extDeliveryDelayCost + extAutomatedUpdatesCost + extPRTrackingCost) * (52.0 / (float64(daysInPeriod) / 7.0)) + preventableCost := extCodeChurnCost + extDeliveryDelayCost + extAutomatedUpdatesCost + extPRTrackingCost + baselineAnnualWaste := preventableCost * (52.0 / (float64(daysInPeriod) / 7.0)) // Re-model with 40-minute PR merge times // We need to recalculate delivery delay and future costs assuming all PRs take 40 minutes (2/3 hour) diff --git a/pkg/github/fetch_test.go b/pkg/github/fetch_test.go index e998757..b28001b 100644 --- a/pkg/github/fetch_test.go +++ b/pkg/github/fetch_test.go @@ -3,6 +3,7 @@ package github import ( "encoding/json" "os" + "strings" "testing" "time" @@ -275,3 +276,235 @@ func TestPRDataFromPRXWithRealData(t *testing.T) { t.Logf("PR 1891: %d human events out of %d total events", len(costData.Events), len(prxData.Events)) } + +func TestGetCacheDir(t *testing.T) { + dir, err := getCacheDir() + if err != nil { + t.Fatalf("getCacheDir() error = %v", err) + } + if dir == "" { + t.Error("getCacheDir() returned empty string") + } + + // Should contain prcost in the path + if !strings.Contains(dir, "prcost") { + t.Errorf("getCacheDir() = %q, expected to contain 'prcost'", dir) + } +} + +func TestIsCommonBot(t *testing.T) { + tests := []struct { + name string + username string + want bool 
+ }{ + {"dependabot", "dependabot[bot]", true}, + {"renovate", "renovate-bot", true}, + {"github-actions", "github-actions", true}, + {"codecov", "codecov-commenter", true}, + {"greenkeeper", "greenkeeper[bot]", true}, + {"snyk", "snyk-bot", true}, + {"allcontributors", "allcontributors[bot]", true}, + {"imgbot", "ImgBot", true}, // Case insensitive + {"stalebot", "stalebot", true}, + {"mergify", "mergify[bot]", true}, + {"netlify", "netlify[bot]", true}, + {"vercel", "vercel[bot]", true}, + {"codefactor", "codefactor-io", true}, + {"deepsource", "deepsource-autofix[bot]", true}, + {"pre-commit", "pre-commit-ci[bot]", true}, + {"regular user", "john-doe", false}, + {"bot in middle", "robot-person", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isCommonBot(tt.username) + if got != tt.want { + t.Errorf("isCommonBot(%q) = %v, want %v", tt.username, got, tt.want) + } + }) + } +} + +func TestIsCommonBotCaseSensitivity(t *testing.T) { + tests := []struct { + name string + username string + want bool + }{ + {"uppercase BOT", "DEPENDABOT[bot]", true}, + {"mixed case", "DePeNdAbOt[bot]", true}, + {"lowercase all", "dependabot[bot]", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isCommonBot(tt.username) + if got != tt.want { + t.Errorf("isCommonBot(%q) = %v, want %v", tt.username, got, tt.want) + } + }) + } +} + +func TestExtractParticipantEventsEdgeCases(t *testing.T) { + now := time.Now() + + tests := []struct { + name string + events []prx.Event + expectedCount int + expectedFirst string + }{ + { + name: "single event", + events: []prx.Event{ + {Timestamp: now, Actor: "alice", Bot: false}, + }, + expectedCount: 1, + expectedFirst: "alice", + }, + { + name: "github user filtered out", + events: []prx.Event{ + {Timestamp: now, Actor: "github", Bot: false}, + }, + expectedCount: 0, + }, + { + name: "mix of valid and invalid", + events: []prx.Event{ + {Timestamp: now, Actor: "alice", Bot: 
false}, + {Timestamp: now, Actor: "github", Bot: false}, + {Timestamp: now, Actor: "bot[bot]", Bot: true}, + }, + expectedCount: 1, + expectedFirst: "alice", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractParticipantEvents(tt.events) + + if len(result) != tt.expectedCount { + t.Errorf("Expected %d events, got %d", tt.expectedCount, len(result)) + } + + if tt.expectedCount > 0 && result[0].Actor != tt.expectedFirst { + t.Errorf("Expected first actor %s, got %s", tt.expectedFirst, result[0].Actor) + } + }) + } +} + +func TestPRDataFromPRXWithRealSprinklerData(t *testing.T) { + // Load real prx output from sprinkler PR #37 + data, err := os.ReadFile("../../testdata/sprinkler_pr_37_clean.json") + if err != nil { + t.Skipf("Skipping real data test: %v", err) + } + + var prxData prx.PullRequestData + if err := json.Unmarshal(data, &prxData); err != nil { + t.Fatalf("Failed to parse PR data: %v", err) + } + + costData := PRDataFromPRX(&prxData) + + // Validate sprinkler PR #37 specific data + if costData.Author != "tstromberg" { + t.Errorf("Expected author 'tstromberg', got '%s'", costData.Author) + } + + if costData.LinesAdded != 324 { + t.Errorf("Expected 324 lines added, got %d", costData.LinesAdded) + } + + // Should have filtered out bot events (github check_run) + for _, event := range costData.Events { + if event.Actor == "github" { + t.Error("Should have filtered out github automation events") + } + } + + // PR #37 had 6 events total: commit, pr_opened, merged, closed, pr_merged, check_run + // Bot event (check_run) should be filtered, github actor events filtered + // So we should have: commit (Thomas Stromberg), pr_opened (tstromberg), merged, closed, pr_merged + if len(costData.Events) != 5 { + t.Errorf("Expected 5 human events, got %d", len(costData.Events)) + } + + t.Logf("Sprinkler PR 37: %d human events out of %d total events", len(costData.Events), len(prxData.Events)) +} + +func 
TestGetCacheDirCreatesDirectory(t *testing.T) { + // This test actually calls getCacheDir to improve coverage + dir, err := getCacheDir() + if err != nil { + t.Fatalf("getCacheDir() error = %v", err) + } + if dir == "" { + t.Error("getCacheDir() returned empty string") + } + + // Verify directory was created + info, err := os.Stat(dir) + if err != nil { + t.Errorf("Cache directory was not created: %v", err) + } + if !info.IsDir() { + t.Error("Cache path is not a directory") + } +} + +func TestIsCommonBotVariations(t *testing.T) { + tests := []struct { + username string + want bool + }{ + {"dependabot", true}, + {"dependabot[bot]", true}, + {"renovate", true}, + {"renovate-bot", true}, + {"github-actions", true}, + {"github-actions[bot]", true}, + {"codecov", true}, + {"codecov-commenter", true}, + {"greenkeeper", true}, + {"greenkeeper[bot]", true}, + {"snyk-bot", true}, + {"allcontributors", true}, + {"allcontributors[bot]", true}, + {"imgbot", true}, + {"ImgBot", true}, // case insensitive + {"stalebot", true}, + {"mergify", true}, + {"mergify[bot]", true}, + {"netlify", true}, + {"netlify[bot]", true}, + {"vercel", true}, + {"vercel[bot]", true}, + {"codefactor-io", true}, + {"deepsource-autofix", true}, + {"deepsource-autofix[bot]", true}, + {"pre-commit-ci", true}, + {"pre-commit-ci[bot]", true}, + {"ready-to-review", true}, + {"ready-to-review[bot]", true}, + {"regular-user", false}, + {"robot", false}, + {"botman", false}, + {"john-doe", false}, + } + + for _, tt := range tests { + t.Run(tt.username, func(t *testing.T) { + got := isCommonBot(tt.username) + if got != tt.want { + t.Errorf("isCommonBot(%q) = %v, want %v", tt.username, got, tt.want) + } + }) + } +} diff --git a/pkg/github/query.go b/pkg/github/query.go index bd10993..c516ca3 100644 --- a/pkg/github/query.go +++ b/pkg/github/query.go @@ -13,18 +13,16 @@ import ( ) // PRSummary holds minimal information about a PR for sampling and fetching. 
-// -//nolint:govet // fieldalignment: struct field order optimized for readability type PRSummary struct { - Owner string // Repository owner - Repo string // Repository name - Number int // PR number - Author string // PR author login - UpdatedAt time.Time // Last update time + UpdatedAt time.Time + Owner string + Repo string + Author string + Number int } // ProgressCallback is called during PR fetching to report progress. -// Parameters: queryName (e.g., "recent", "old", "early"), currentPage, totalPRsSoFar +// Parameters: queryName (e.g., "recent", "old", "early"), currentPage, totalPRsSoFar. type ProgressCallback func(queryName string, page int, prCount int) // FetchPRsFromRepo queries GitHub GraphQL API for all PRs in a repository @@ -48,7 +46,10 @@ type ProgressCallback func(queryName string, page int, prCount int) // - Slice of PRSummary for all matching PRs (deduplicated) func FetchPRsFromRepo(ctx context.Context, owner, repo string, since time.Time, token string, progress ProgressCallback) ([]PRSummary, error) { // Query 1: Recent activity (updated DESC) - get up to 1000 PRs - recent, hitLimit, err := fetchPRsFromRepoWithSort(ctx, owner, repo, since, token, "UPDATED_AT", "DESC", 1000, "recent", progress) + recent, hitLimit, err := fetchPRsFromRepoWithSort(ctx, repoSortParams{ + owner: owner, repo: repo, since: since, token: token, + field: "UPDATED_AT", direction: "DESC", maxPRs: 1000, queryName: "recent", progress: progress, + }) if err != nil { return nil, err } @@ -60,7 +61,10 @@ func FetchPRsFromRepo(ctx context.Context, owner, repo string, since time.Time, // Hit limit - need more coverage for earlier periods // Query 2: Old activity (updated ASC) - get ~500 more - old, _, err := fetchPRsFromRepoWithSort(ctx, owner, repo, since, token, "UPDATED_AT", "ASC", 500, "old", progress) + old, _, err := fetchPRsFromRepoWithSort(ctx, repoSortParams{ + owner: owner, repo: repo, since: since, token: token, + field: "UPDATED_AT", direction: "ASC", maxPRs: 500, 
queryName: "old", progress: progress, + }) if err != nil { slog.Warn("Failed to fetch old PRs, falling back to recent only", "error", err) return recent, nil @@ -86,7 +90,10 @@ func FetchPRsFromRepo(ctx context.Context, owner, repo string, since time.Time, slog.Info("Gap > 1 week detected, fetching early period PRs to fill coverage hole") // Query 3: Early period (created ASC) - get ~250 more - early, _, err := fetchPRsFromRepoWithSort(ctx, owner, repo, since, token, "CREATED_AT", "ASC", 250, "early", progress) + early, _, err := fetchPRsFromRepoWithSort(ctx, repoSortParams{ + owner: owner, repo: repo, since: since, token: token, + field: "CREATED_AT", direction: "ASC", maxPRs: 250, queryName: "early", progress: progress, + }) if err != nil { slog.Warn("Failed to fetch early PRs, proceeding with recent+old", "error", err) return deduplicatePRs(append(recent, old...)), nil @@ -103,12 +110,27 @@ func FetchPRsFromRepo(ctx context.Context, owner, repo string, since time.Time, return deduplicatePRs(append(recent, old...)), nil } +// repoSortParams contains parameters for sorted PR queries. +type repoSortParams struct { + since time.Time + progress ProgressCallback + owner string + repo string + token string + field string + direction string + queryName string + maxPRs int +} + // fetchPRsFromRepoWithSort queries GitHub GraphQL API with configurable sort order. // Returns PRs and a boolean indicating if the API limit (1000) was hit. 
-func fetchPRsFromRepoWithSort( - ctx context.Context, owner, repo string, since time.Time, - token, field, direction string, maxPRs int, queryName string, progress ProgressCallback, -) ([]PRSummary, bool, error) { +func fetchPRsFromRepoWithSort(ctx context.Context, params repoSortParams) ([]PRSummary, bool, error) { + owner, repo := params.owner, params.repo + since, token := params.since, params.token + field, direction := params.field, params.direction + maxPRs, queryName := params.maxPRs, params.queryName + progress := params.progress query := fmt.Sprintf(` query($owner: String!, $name: String!, $cursor: String) { repository(owner: $owner, name: $name) { @@ -185,18 +207,16 @@ func fetchPRsFromRepoWithSort( Data struct { Repository struct { PullRequests struct { - TotalCount int - PageInfo struct { + PageInfo struct { HasNextPage bool EndCursor string } Nodes []struct { Number int UpdatedAt time.Time - Author struct { - Login string - } + Author struct{ Login string } } + TotalCount int } } } @@ -316,7 +336,10 @@ func FetchPRsFromOrg(ctx context.Context, org string, since time.Time, token str sinceStr := since.Format("2006-01-02") // Query 1: Recent activity (updated desc) - get up to 1000 PRs - recent, hitLimit, err := fetchPRsFromOrgWithSort(ctx, org, sinceStr, token, "updated", "desc", 1000, "recent", progress) + recent, hitLimit, err := fetchPRsFromOrgWithSort(ctx, orgSortParams{ + org: org, sinceStr: sinceStr, token: token, + field: "updated", direction: "desc", maxPRs: 1000, queryName: "recent", progress: progress, + }) if err != nil { return nil, err } @@ -332,7 +355,10 @@ func FetchPRsFromOrg(ctx context.Context, org string, since time.Time, token str // Hit limit - need more coverage for earlier periods // Query 2: Old activity (updated asc) - get ~500 more - old, _, err := fetchPRsFromOrgWithSort(ctx, org, sinceStr, token, "updated", "asc", 500, "old", progress) + old, _, err := fetchPRsFromOrgWithSort(ctx, orgSortParams{ + org: org, sinceStr: 
sinceStr, token: token, + field: "updated", direction: "asc", maxPRs: 500, queryName: "old", progress: progress, + }) if err != nil { slog.Warn("Failed to fetch old PRs from org, falling back to recent only", "error", err) return recent, nil @@ -358,7 +384,10 @@ func FetchPRsFromOrg(ctx context.Context, org string, since time.Time, token str slog.Info("Gap > 1 week detected, fetching early period PRs to fill coverage hole (org)") // Query 3: Early period (created asc) - get ~250 more - early, _, err := fetchPRsFromOrgWithSort(ctx, org, sinceStr, token, "created", "asc", 250, "early", progress) + early, _, err := fetchPRsFromOrgWithSort(ctx, orgSortParams{ + org: org, sinceStr: sinceStr, token: token, + field: "created", direction: "asc", maxPRs: 250, queryName: "early", progress: progress, + }) if err != nil { slog.Warn("Failed to fetch early PRs from org, proceeding with recent+old", "error", err) return deduplicatePRsByOwnerRepoNumber(append(recent, old...)), nil @@ -375,11 +404,26 @@ func FetchPRsFromOrg(ctx context.Context, org string, since time.Time, token str return deduplicatePRsByOwnerRepoNumber(append(recent, old...)), nil } +// orgSortParams contains parameters for sorted org PR queries. +type orgSortParams struct { + progress ProgressCallback + org string + sinceStr string + token string + field string + direction string + queryName string + maxPRs int +} + // fetchPRsFromOrgWithSort queries GitHub Search API with configurable sort order. // Returns PRs and a boolean indicating if the API limit (1000) was hit. 
-func fetchPRsFromOrgWithSort( - ctx context.Context, org, sinceStr, token, field, direction string, maxPRs int, queryName string, progress ProgressCallback, -) ([]PRSummary, bool, error) { +func fetchPRsFromOrgWithSort(ctx context.Context, params orgSortParams) ([]PRSummary, bool, error) { + org, sinceStr := params.org, params.sinceStr + token := params.token + field, direction := params.field, params.direction + maxPRs, queryName := params.maxPRs, params.queryName + progress := params.progress // Build search query with sort // Query format: org:myorg is:pr updated:>2025-07-25 sort:updated-desc searchQuery := fmt.Sprintf("org:%s is:pr %s:>%s sort:%s-%s", org, field, sinceStr, field, direction) @@ -464,24 +508,20 @@ func fetchPRsFromOrgWithSort( var result struct { Data struct { Search struct { - IssueCount int - PageInfo struct { + PageInfo struct { HasNextPage bool EndCursor string } Nodes []struct { - Number int - UpdatedAt time.Time - Author struct { - Login string - } + Number int + UpdatedAt time.Time + Author struct{ Login string } Repository struct { - Owner struct { - Login string - } - Name string + Owner struct{ Login string } + Name string } } + IssueCount int } } Errors []struct { @@ -823,14 +863,14 @@ func CountOpenPRsInRepo(ctx context.Context, owner, repo, token string) (int, er } var result struct { + Errors []struct { + Message string + } Data struct { Search struct { IssueCount int `json:"issueCount"` } `json:"search"` } `json:"data"` - Errors []struct { - Message string - } } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { @@ -904,14 +944,14 @@ func CountOpenPRsInOrg(ctx context.Context, org, token string) (int, error) { } var result struct { + Errors []struct { + Message string + } Data struct { Search struct { IssueCount int `json:"issueCount"` } `json:"search"` } `json:"data"` - Errors []struct { - Message string - } } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { diff --git a/pkg/github/query_test.go 
b/pkg/github/query_test.go new file mode 100644 index 0000000..0fdf041 --- /dev/null +++ b/pkg/github/query_test.go @@ -0,0 +1,224 @@ +package github + +import ( + "testing" + "time" +) + +func TestIsBot(t *testing.T) { + tests := []struct { + name string + prAuthor string + want bool + }{ + { + name: "dependabot", + prAuthor: "dependabot[bot]", + want: true, + }, + { + name: "renovate", + prAuthor: "renovate[bot]", + want: true, + }, + { + name: "github-actions", + prAuthor: "github-actions[bot]", + want: true, + }, + { + name: "human user", + prAuthor: "testuser", + want: false, + }, + { + name: "user with bot in name", + prAuthor: "robot-person", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsBot(tt.prAuthor) + if got != tt.want { + t.Errorf("IsBot(%q) = %v, want %v", tt.prAuthor, got, tt.want) + } + }) + } +} + +func TestCountBotPRs(t *testing.T) { + prs := []PRSummary{ + {Author: "dependabot[bot]"}, + {Author: "renovate[bot]"}, + {Author: "testuser"}, + {Author: "anotheruser"}, + {Author: "github-actions[bot]"}, + } + + botCount := CountBotPRs(prs) + if botCount != 3 { + t.Errorf("CountBotPRs() = %d, want 3", botCount) + } +} + +func TestSamplePRs(t *testing.T) { + // Create sample PRs + prs := make([]PRSummary, 100) + for i := range prs { + prs[i] = PRSummary{ + Number: i + 1, + Owner: "testowner", + Repo: "testrepo", + } + } + + tests := []struct { + name string + totalPRs int + targetSize int + wantSize int + }{ + { + name: "sample 10 from 100", + totalPRs: 100, + targetSize: 10, + wantSize: 10, + }, + { + name: "sample more than available", + totalPRs: 100, + targetSize: 150, + wantSize: 100, + }, + { + name: "sample with small set", + totalPRs: 5, + targetSize: 10, + wantSize: 5, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testPRs := prs[:tt.totalPRs] + result := SamplePRs(testPRs, tt.targetSize) + if len(result) != tt.wantSize { + t.Errorf("SamplePRs() returned 
%d PRs, want %d", len(result), tt.wantSize) + } + }) + } +} + +func TestCountUniqueAuthors(t *testing.T) { + prs := []PRSummary{ + {Author: "user1"}, + {Author: "user2"}, + {Author: "user1"}, // duplicate + {Author: "user3"}, + {Author: "user2"}, // duplicate + {Author: "dependabot[bot]"}, + } + + count := CountUniqueAuthors(prs) + if count != 3 { + t.Errorf("CountUniqueAuthors() = %d, want 3", count) + } +} + +func TestCalculateActualTimeWindow(t *testing.T) { + now := time.Now() + prs := []PRSummary{ + {UpdatedAt: now.Add(-10 * 24 * time.Hour)}, + {UpdatedAt: now.Add(-5 * 24 * time.Hour)}, + {UpdatedAt: now.Add(-1 * 24 * time.Hour)}, + } + + // When PRs don't cover full requested period, function returns requested days + days, hitLimit := CalculateActualTimeWindow(prs, 30) + if days != 30 { + t.Errorf("CalculateActualTimeWindow() = %d days, want 30 days (requested)", days) + } + if hitLimit { + t.Error("CalculateActualTimeWindow() hitLimit = true, want false") + } + + // Test with empty PRs + days2, hitLimit2 := CalculateActualTimeWindow([]PRSummary{}, 30) + if days2 != 30 { + t.Errorf("CalculateActualTimeWindow(empty) = %d days, want 30", days2) + } + if hitLimit2 { + t.Error("CalculateActualTimeWindow(empty) hitLimit = true, want false") + } +} + +func TestDeduplicatePRs(t *testing.T) { + now := time.Now() + earlier := now.Add(-1 * time.Hour) + later := now.Add(1 * time.Hour) + + prs := []PRSummary{ + {Number: 1, Owner: "owner", Repo: "repo", UpdatedAt: earlier}, + {Number: 2, Owner: "owner", Repo: "repo", UpdatedAt: now}, + {Number: 1, Owner: "owner", Repo: "repo", UpdatedAt: later}, // duplicate - first occurrence kept + } + + result := deduplicatePRs(prs) + if len(result) != 2 { + t.Errorf("deduplicatePRs() returned %d PRs, want 2", len(result)) + } + + // Verify we have both unique PRs + numbers := make(map[int]bool) + for _, pr := range result { + numbers[pr.Number] = true + } + if !numbers[1] || !numbers[2] { + t.Error("deduplicatePRs() did not include 
all unique PRs") + } +} + +func TestDeduplicatePRsByOwnerRepoNumber(t *testing.T) { + prs := []PRSummary{ + {Number: 1, Owner: "owner1", Repo: "repo1", UpdatedAt: time.Now()}, + {Number: 1, Owner: "owner2", Repo: "repo1", UpdatedAt: time.Now()}, // different owner + {Number: 1, Owner: "owner1", Repo: "repo1", UpdatedAt: time.Now().Add(1 * time.Hour)}, // duplicate + } + + result := deduplicatePRsByOwnerRepoNumber(prs) + if len(result) != 2 { + t.Errorf("deduplicatePRsByOwnerRepoNumber() returned %d PRs, want 2", len(result)) + } +} + +func TestIsBotEdgeCases(t *testing.T) { + tests := []struct { + name string + author string + want bool + }{ + {"empty string", "", false}, + {"just brackets", "[bot]", true}, + {"app suffix", "myapp[bot]", true}, + {"greenkeeper", "greenkeeper[bot]", true}, + {"snyk", "snyk[bot]", true}, + {"imgbot", "imgbot[bot]", true}, + {"allcontributors", "allcontributors[bot]", true}, + {"stale", "stale[bot]", true}, + {"codecov", "codecov[bot]", true}, + {"whitesource", "whitesource[bot]", true}, + {"normal user no brackets", "username", false}, + {"bot in middle of name", "robot-user", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsBot(tt.author) + if got != tt.want { + t.Errorf("IsBot(%q) = %v, want %v", tt.author, got, tt.want) + } + }) + } +} diff --git a/testdata/sprinkler_pr_37.json b/testdata/sprinkler_pr_37.json new file mode 100644 index 0000000..073957e --- /dev/null +++ b/testdata/sprinkler_pr_37.json @@ -0,0 +1,21 @@ +2025/10/28 22:09:24 INFO cache miss: GraphQL pull request not in cache owner=codeGROOVE-dev repo=sprinkler pr=37 +2025/10/28 22:09:24 INFO fetching pull request via GraphQL owner=codeGROOVE-dev repo=sprinkler pr=37 +2025/10/28 22:09:24 INFO HTTP request starting method=POST url=https://api.github.com/graphql host=api.github.com +2025/10/28 22:09:25 INFO HTTP response received status=200 url=https://api.github.com/graphql elapsed=1.018633041s +2025/10/28 22:09:25 INFO 
GraphQL query completed cost=1 remaining=4971 limit=5000 +2025/10/28 22:09:25 INFO GitHub API request starting method=GET url="https://api.github.com/repos/codeGROOVE-dev/sprinkler/collaborators?affiliation=all&per_page=100" headers="map[Accept:application/vnd.github.v3+json Authorization:Bearer gho_...UIqb User-Agent:]" +2025/10/28 22:09:25 INFO HTTP request starting method=GET url="https://api.github.com/repos/codeGROOVE-dev/sprinkler/collaborators?affiliation=all&per_page=100" host=api.github.com +2025/10/28 22:09:25 INFO HTTP response received status=200 url="https://api.github.com/repos/codeGROOVE-dev/sprinkler/collaborators?affiliation=all&per_page=100" elapsed=170.363583ms +2025/10/28 22:09:25 INFO GitHub API response received status="200 OK" url="https://api.github.com/repos/codeGROOVE-dev/sprinkler/collaborators?affiliation=all&per_page=100" elapsed=170.53075ms rate_limits="map[Retry-After: X-RateLimit-Limit:5000 X-RateLimit-Remaining:4997 X-RateLimit-Reset:1761706998 X-RateLimit-Resource:collaborators X-RateLimit-Used:3]" +2025/10/28 22:09:25 INFO GitHub API request starting method=GET url=https://api.github.com/repos/codeGROOVE-dev/sprinkler/rulesets headers="map[Accept:application/vnd.github.v3+json Authorization:Bearer gho_...UIqb User-Agent:]" +2025/10/28 22:09:25 INFO HTTP request starting method=GET url=https://api.github.com/repos/codeGROOVE-dev/sprinkler/rulesets host=api.github.com +2025/10/28 22:09:25 INFO HTTP response received status=200 url=https://api.github.com/repos/codeGROOVE-dev/sprinkler/rulesets elapsed=208.483375ms +2025/10/28 22:09:25 INFO GitHub API response received status="200 OK" url=https://api.github.com/repos/codeGROOVE-dev/sprinkler/rulesets elapsed=208.550458ms rate_limits="map[Retry-After: X-RateLimit-Limit:5000 X-RateLimit-Remaining:4975 X-RateLimit-Reset:1761706718 X-RateLimit-Resource:core X-RateLimit-Used:25]" +2025/10/28 22:09:25 INFO fetched required checks from rulesets count=0 checks=[] +2025/10/28 22:09:25 INFO 
GitHub API request starting method=GET url="https://api.github.com/repos/codeGROOVE-dev/sprinkler/commits/03baab46ffa62f2d360eaaae7402bebe103639d8/check-runs?per_page=100" headers="map[Accept:application/vnd.github.v3+json Authorization:Bearer gho_...UIqb User-Agent:]" +2025/10/28 22:09:25 INFO HTTP request starting method=GET url="https://api.github.com/repos/codeGROOVE-dev/sprinkler/commits/03baab46ffa62f2d360eaaae7402bebe103639d8/check-runs?per_page=100" host=api.github.com +2025/10/28 22:09:26 INFO HTTP response received status=200 url="https://api.github.com/repos/codeGROOVE-dev/sprinkler/commits/03baab46ffa62f2d360eaaae7402bebe103639d8/check-runs?per_page=100" elapsed=197.456917ms +2025/10/28 22:09:26 INFO GitHub API response received status="200 OK" url="https://api.github.com/repos/codeGROOVE-dev/sprinkler/commits/03baab46ffa62f2d360eaaae7402bebe103639d8/check-runs?per_page=100" elapsed=197.724458ms rate_limits="map[Retry-After: X-RateLimit-Limit:5000 X-RateLimit-Remaining:4974 X-RateLimit-Reset:1761706718 X-RateLimit-Resource:core X-RateLimit-Used:26]" +2025/10/28 22:09:26 INFO fetched check runs via REST count=1 +2025/10/28 22:09:26 INFO successfully fetched pull request via hybrid GraphQL+REST owner=codeGROOVE-dev repo=sprinkler pr=37 event_count=6 api_calls_made="3 (vs 13+ with REST)" +{"events":[{"timestamp":"2025-10-29T02:04:26Z","kind":"commit","actor":"Thomas Stromberg","body":"Add TestCheckEventRaceCondition"},{"timestamp":"2025-10-29T02:04:47Z","kind":"pr_opened","actor":"tstromberg","write_access":2},{"timestamp":"2025-10-29T02:05:04Z","kind":"merged","actor":"tstromberg"},{"timestamp":"2025-10-29T02:05:04Z","kind":"closed","actor":"tstromberg"},{"timestamp":"2025-10-29T02:05:04Z","kind":"pr_merged","actor":"tstromberg"},{"timestamp":"2025-10-29T02:05:10Z","kind":"check_run","actor":"github","outcome":"success","body":"Kusari Inspector","description":"Security Analysis Passed: No security issues 
found","bot":true}],"pull_request":{"created_at":"2025-10-29T02:04:47Z","updated_at":"2025-10-29T02:05:04Z","closed_at":"2025-10-29T02:05:04Z","merged_at":"2025-10-29T02:05:04Z","approval_summary":{"approvals_with_write_access":0,"approvals_with_unknown_access":0,"approvals_without_write_access":0,"changes_requested":0},"check_summary":{"success":{"Kusari Inspector":"Security Analysis Passed: No security issues found"},"failing":{},"pending":{},"cancelled":{},"skipped":{},"stale":{},"neutral":{}},"mergeable":null,"assignees":[],"participant_access":{"Thomas Stromberg":0,"tstromberg":2},"mergeable_state":"unknown","mergeable_state_description":"Merge status is being calculated","author":"tstromberg","body":"","title":"Add TestCheckEventRaceCondition","merged_by":"tstromberg","state":"merged","test_state":"passing","head_sha":"03baab46ffa62f2d360eaaae7402bebe103639d8","number":37,"changed_files":1,"deletions":0,"additions":324,"author_write_access":2,"author_bot":false,"merged":true,"draft":false}} diff --git a/testdata/sprinkler_pr_37_clean.json b/testdata/sprinkler_pr_37_clean.json new file mode 100644 index 0000000..49f6b62 --- /dev/null +++ b/testdata/sprinkler_pr_37_clean.json @@ -0,0 +1,86 @@ +{ + "events": [ + { + "timestamp": "2025-10-29T02:04:26Z", + "kind": "commit", + "actor": "Thomas Stromberg", + "body": "Add TestCheckEventRaceCondition" + }, + { + "timestamp": "2025-10-29T02:04:47Z", + "kind": "pr_opened", + "actor": "tstromberg", + "write_access": 2 + }, + { + "timestamp": "2025-10-29T02:05:04Z", + "kind": "merged", + "actor": "tstromberg" + }, + { + "timestamp": "2025-10-29T02:05:04Z", + "kind": "closed", + "actor": "tstromberg" + }, + { + "timestamp": "2025-10-29T02:05:04Z", + "kind": "pr_merged", + "actor": "tstromberg" + }, + { + "timestamp": "2025-10-29T02:05:10Z", + "kind": "check_run", + "actor": "github", + "outcome": "success", + "body": "Kusari Inspector", + "description": "Security Analysis Passed: No security issues found", + "bot": true + 
} + ], + "pull_request": { + "created_at": "2025-10-29T02:04:47Z", + "updated_at": "2025-10-29T02:05:04Z", + "closed_at": "2025-10-29T02:05:04Z", + "merged_at": "2025-10-29T02:05:04Z", + "approval_summary": { + "approvals_with_write_access": 0, + "approvals_with_unknown_access": 0, + "approvals_without_write_access": 0, + "changes_requested": 0 + }, + "check_summary": { + "success": { + "Kusari Inspector": "Security Analysis Passed: No security issues found" + }, + "failing": {}, + "pending": {}, + "cancelled": {}, + "skipped": {}, + "stale": {}, + "neutral": {} + }, + "mergeable": null, + "assignees": [], + "participant_access": { + "Thomas Stromberg": 0, + "tstromberg": 2 + }, + "mergeable_state": "unknown", + "mergeable_state_description": "Merge status is being calculated", + "author": "tstromberg", + "body": "", + "title": "Add TestCheckEventRaceCondition", + "merged_by": "tstromberg", + "state": "merged", + "test_state": "passing", + "head_sha": "03baab46ffa62f2d360eaaae7402bebe103639d8", + "number": 37, + "changed_files": 1, + "deletions": 0, + "additions": 324, + "author_write_access": 2, + "author_bot": false, + "merged": true, + "draft": false + } +}