Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 3 additions & 21 deletions pkg/cli/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,16 @@ import (
"fmt"
"os/exec"
"strings"
"sync"

"github.com/github/gh-aw/pkg/logger"
"github.com/github/gh-aw/pkg/syncutil"
"github.com/github/gh-aw/pkg/workflow"
)

var repoLog = logger.New("cli:repo")

// repoSlugCacheState holds the cached repository slug and protects it with a mutex.
// Using a mutex-guarded struct instead of sync.Once avoids the data race that arises
// when resetting sync.Once via struct assignment (= sync.Once{}) after first use.
type repoSlugCacheState struct {
mu sync.Mutex
result string
err error
done bool
}

// Global cache for current repository info
var currentRepoSlugCache repoSlugCacheState
var currentRepoSlugCache syncutil.OnceLoader[string]

// getCurrentRepoSlugUncached gets the current repository slug (owner/repo) using gh CLI (uncached)
// Falls back to git remote parsing if gh CLI is not available
Expand Down Expand Up @@ -91,15 +81,7 @@ func getCurrentRepoSlugUncached() (string, error) {
// GetCurrentRepoSlug gets the current repository slug with caching.
// This is the recommended function to use for repository access across the codebase.
func GetCurrentRepoSlug() (string, error) {
result, err := func() (string, error) {
currentRepoSlugCache.mu.Lock()
defer currentRepoSlugCache.mu.Unlock()
if !currentRepoSlugCache.done {
currentRepoSlugCache.result, currentRepoSlugCache.err = getCurrentRepoSlugUncached()
currentRepoSlugCache.done = true
}
return currentRepoSlugCache.result, currentRepoSlugCache.err
}()
result, err := currentRepoSlugCache.Get(getCurrentRepoSlugUncached)

if err != nil {
return "", err
Expand Down
6 changes: 1 addition & 5 deletions pkg/cli/repo_test_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,5 @@ package cli
// ClearCurrentRepoSlugCache clears the current repository slug cache.
// This is useful for testing when repository context might have changed.
func ClearCurrentRepoSlugCache() {
currentRepoSlugCache.mu.Lock()
defer currentRepoSlugCache.mu.Unlock()
currentRepoSlugCache.result = ""
currentRepoSlugCache.err = nil
currentRepoSlugCache.done = false
currentRepoSlugCache.Reset()
}
36 changes: 36 additions & 0 deletions pkg/syncutil/onceloader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package syncutil

import "sync"

// OnceLoader caches the result of a fallible, expensive one-shot fetch.
// Safe for concurrent use; loader is invoked at most once.
type OnceLoader[T any] struct {
mu sync.Mutex
result T
err error
done bool
}

// Get returns the cached result, invoking loader exactly once.
func (o *OnceLoader[T]) Get(loader func() (T, error)) (T, error) {
o.mu.Lock()
defer o.mu.Unlock()

if !o.done {
o.result, o.err = loader()
o.done = true
}

return o.result, o.err
}

// Reset clears cached state.
func (o *OnceLoader[T]) Reset() {
o.mu.Lock()
defer o.mu.Unlock()

var zero T
o.result = zero
o.err = nil
o.done = false
}
119 changes: 119 additions & 0 deletions pkg/syncutil/onceloader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package syncutil

import (
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
)

func TestOnceLoaderGetCachesSuccess(t *testing.T) {
var loader OnceLoader[string]
var calls atomic.Int32

load := func() (string, error) {
calls.Add(1)
return "ok", nil
}

got1, err1 := loader.Get(load)
got2, err2 := loader.Get(load)

if err1 != nil || err2 != nil {
t.Fatalf("expected nil errors, got err1=%v err2=%v", err1, err2)
}
if got1 != "ok" || got2 != "ok" {
t.Fatalf("expected cached value 'ok', got %q and %q", got1, got2)
}
if calls.Load() != 1 {
t.Fatalf("expected loader to be called once, got %d", calls.Load())
}
}

func TestOnceLoaderGetCachesError(t *testing.T) {
var loader OnceLoader[string]
var calls atomic.Int32
expectedErr := errors.New("boom")

load := func() (string, error) {
calls.Add(1)
return "", expectedErr
}

got1, err1 := loader.Get(load)
got2, err2 := loader.Get(load)

if got1 != "" || got2 != "" {
t.Fatalf("expected empty cached values, got %q and %q", got1, got2)
}
if !errors.Is(err1, expectedErr) || !errors.Is(err2, expectedErr) {
t.Fatalf("expected cached errors to wrap %v, got err1=%v err2=%v", expectedErr, err1, err2)
}
if calls.Load() != 1 {
t.Fatalf("expected loader to be called once, got %d", calls.Load())
}
}

func TestOnceLoaderGetConcurrentSingleInvoke(t *testing.T) {
var loader OnceLoader[string]
var calls atomic.Int32
const workers = 50

load := func() (string, error) {
calls.Add(1)
return "value", nil
}

var wg sync.WaitGroup
wg.Add(workers)
for range workers {
go func() {
defer wg.Done()
got, err := loader.Get(load)
if err != nil {
t.Errorf("expected nil error, got %v", err)
return
}
if got != "value" {
t.Errorf("expected value, got %q", got)
}
}()
}
wg.Wait()

if calls.Load() != 1 {
t.Fatalf("expected loader to be called once under concurrency, got %d", calls.Load())
}
}

func TestOnceLoaderReset(t *testing.T) {
var loader OnceLoader[string]
var calls atomic.Int32

load := func() (string, error) {
n := calls.Add(1)
return fmt.Sprintf("v%d", n), nil
}

got1, err1 := loader.Get(load)
if err1 != nil {
t.Fatalf("unexpected error: %v", err1)
}
if got1 != "v1" {
t.Fatalf("expected first value v1, got %q", got1)
}

loader.Reset()

got2, err2 := loader.Get(load)
if err2 != nil {
t.Fatalf("unexpected error after reset: %v", err2)
}
if got2 != "v2" {
t.Fatalf("expected second value v2 after reset, got %q", got2)
}
if calls.Load() != 2 {
t.Fatalf("expected loader to run twice with reset, got %d", calls.Load())
}
}
24 changes: 3 additions & 21 deletions pkg/workflow/repository_features_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import (
"github.com/cli/go-gh/v2/pkg/api"
"github.com/cli/go-gh/v2/pkg/repository"
"github.com/github/gh-aw/pkg/console"
"github.com/github/gh-aw/pkg/syncutil"
)

var repositoryFeaturesLog = newValidationLogger("repository_features")
Expand All @@ -59,22 +60,11 @@ type RepositoryFeatures struct {
HasIssues bool
}

// currentRepositoryCacheState holds the cached current repository and protects it
// with a mutex. Using a mutex-guarded struct instead of sync.Once avoids the data
// race that arises when resetting sync.Once via struct assignment (= sync.Once{})
// after first use.
type currentRepositoryCacheState struct {
mu sync.Mutex
result string
err error
done bool
}

// Global cache for repository features and current repository info
var (
repositoryFeaturesCache = sync.Map{} // sync.Map is thread-safe and efficient for read-heavy workloads
repositoryFeaturesLoggedCache = sync.Map{} // Tracks which repositories have had their success messages logged
currentRepositoryCache currentRepositoryCacheState
currentRepositoryCache syncutil.OnceLoader[string]
)

// validateRepositoryFeatures validates that required repository features are enabled
Expand Down Expand Up @@ -157,15 +147,7 @@ func (c *Compiler) validateRepositoryFeatures(workflowData *WorkflowData) error

// getCurrentRepository gets the current repository from git context (with caching)
func getCurrentRepository() (string, error) {
result, err := func() (string, error) {
currentRepositoryCache.mu.Lock()
defer currentRepositoryCache.mu.Unlock()
if !currentRepositoryCache.done {
currentRepositoryCache.result, currentRepositoryCache.err = getCurrentRepositoryUncached()
currentRepositoryCache.done = true
}
return currentRepositoryCache.result, currentRepositoryCache.err
}()
result, err := currentRepositoryCache.Get(getCurrentRepositoryUncached)

if err != nil {
return "", err
Expand Down