Skip to content

Commit

Permalink
allow keys to be empty interfaces rather than strings
Browse files Browse the repository at this point in the history
  • Loading branch information
Nick Randall committed Dec 5, 2017
1 parent beedcd5 commit 766d4e4
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 102 deletions.
7 changes: 4 additions & 3 deletions .travis.yml
@@ -1,10 +1,11 @@
language: go

go:
- 1.7
- 1.x

before_install:
- go get -t -v ./...
install:
- go get -u github.com/golang/dep/...
- dep ensure

script:
- go test -v -race -coverprofile=coverage.txt -covermode=atomic
Expand Down
12 changes: 6 additions & 6 deletions cache.go
Expand Up @@ -4,9 +4,9 @@ import "context"

// The Cache interface. If a custom cache is provided, it must implement this interface.
type Cache interface {
Get(context.Context, string) (Thunk, bool)
Set(context.Context, string, Thunk)
Delete(context.Context, string) bool
Get(context.Context, interface{}) (Thunk, bool)
Set(context.Context, interface{}, Thunk)
Delete(context.Context, interface{}) bool
Clear()
}

Expand All @@ -16,13 +16,13 @@ type Cache interface {
type NoCache struct{}

// Get is a NOOP
func (c *NoCache) Get(context.Context, string) (Thunk, bool) { return nil, false }
func (c *NoCache) Get(context.Context, interface{}) (Thunk, bool) { return nil, false }

// Set is a NOOP
func (c *NoCache) Set(context.Context, string, Thunk) { return }
func (c *NoCache) Set(context.Context, interface{}, Thunk) { return }

// Delete is a NOOP
func (c *NoCache) Delete(context.Context, string) bool { return false }
func (c *NoCache) Delete(context.Context, interface{}) bool { return false }

// Clear is a NOOP
func (c *NoCache) Clear() { return }
14 changes: 7 additions & 7 deletions dataloader.go
Expand Up @@ -20,8 +20,8 @@ import (
// different access permissions and consider creating a new instance per
// web request.
type Interface interface {
Load(context.Context, string) Thunk
LoadMany(context.Context, []string) ThunkMany
Load(context.Context, interface{}) Thunk
LoadMany(context.Context, []interface{}) ThunkMany
Clear(context.Context, string) Interface
ClearAll() Interface
Prime(ctx context.Context, key string, value interface{}) Interface
Expand All @@ -31,7 +31,7 @@ type Interface interface {
// It's important that the length of the input keys matches the length of the output results.
//
// The keys passed to this function are guaranteed to be unique
type BatchFunc func(context.Context, []string) []*Result
type BatchFunc func(context.Context, []interface{}) []*Result

// Result is the data structure that a BatchFunc returns.
// It contains the resolved data, and any errors that may have occurred while fetching the data.
Expand Down Expand Up @@ -100,7 +100,7 @@ type ThunkMany func() ([]interface{}, []error)

// type used to on input channel
type batchRequest struct {
key string
key interface{}
channel chan *Result
}

Expand Down Expand Up @@ -191,7 +191,7 @@ func NewBatchedLoader(batchFn BatchFunc, opts ...Option) *Loader {
}

// Load load/resolves the given key, returning a channel that will contain the value and error
func (l *Loader) Load(originalContext context.Context, key string) Thunk {
func (l *Loader) Load(originalContext context.Context, key interface{}) Thunk {
ctx, finish := l.tracer.TraceLoad(originalContext, key)

c := make(chan *Result, 1)
Expand Down Expand Up @@ -267,7 +267,7 @@ func (l *Loader) Load(originalContext context.Context, key string) Thunk {
}

// LoadMany loads mulitiple keys, returning a thunk (type: ThunkMany) that will resolve the keys passed in.
func (l *Loader) LoadMany(originalContext context.Context, keys []string) ThunkMany {
func (l *Loader) LoadMany(originalContext context.Context, keys []interface{}) ThunkMany {
ctx, finish := l.tracer.TraceLoadMany(originalContext, keys)

length := len(keys)
Expand Down Expand Up @@ -386,7 +386,7 @@ func (b *batcher) end() {

// execute the batch of all items in queue
func (b *batcher) batch(originalContext context.Context) {
var keys []string
var keys []interface{}
var reqs []*batchRequest
var items []*Result
var panicErr interface{}
Expand Down
110 changes: 55 additions & 55 deletions dataloader_test.go
Expand Up @@ -81,7 +81,7 @@ func TestLoader(t *testing.T) {
t.Parallel()
errorLoader, _ := ErrorLoader(0)
ctx := context.Background()
future := errorLoader.LoadMany(ctx, []string{"1", "2", "3"})
future := errorLoader.LoadMany(ctx, []interface{}{"1", "2", "3"})
_, err := future()
if len(err) != 3 {
t.Error("LoadMany didn't return right number of errors")
Expand All @@ -90,13 +90,13 @@ func TestLoader(t *testing.T) {

t.Run("test LoadMany returns len(errors) == len(keys)", func(t *testing.T) {
t.Parallel()
loader, _ := OneErrorLoader(0)
loader, _ := OneErrorLoader(3)
ctx := context.Background()
future := loader.LoadMany(ctx, []string{"1", "2", "3"})
future := loader.LoadMany(ctx, []interface{}{"1", "2", "3"})
_, err := future()
log.Printf("errs: %#v", err)
if len(err) != 3 {
t.Error("LoadMany didn't return right number of errors (should match size of input)")
return
t.Errorf("LoadMany didn't return right number of errors (should match size of input)")
}

if err[0] == nil {
Expand All @@ -112,7 +112,7 @@ func TestLoader(t *testing.T) {
t.Parallel()
identityLoader, _ := IDLoader(0)
ctx := context.Background()
future := identityLoader.LoadMany(ctx, []string{"1", "2", "3"})
future := identityLoader.LoadMany(ctx, []interface{}{"1", "2", "3"})
go future()
go future()
})
Expand All @@ -127,7 +127,7 @@ func TestLoader(t *testing.T) {
}()
panicLoader, _ := PanicLoader(0)
ctx := context.Background()
future := panicLoader.LoadMany(ctx, []string{"1"})
future := panicLoader.LoadMany(ctx, []interface{}{"1"})
_, errs := future()
if len(errs) < 1 || errs[0].Error() != "Panic received in batch function: Programming error" {
t.Error("Panic was not propagated as an error.")
Expand All @@ -138,7 +138,7 @@ func TestLoader(t *testing.T) {
t.Parallel()
identityLoader, _ := IDLoader(0)
ctx := context.Background()
future := identityLoader.LoadMany(ctx, []string{"1", "2", "3"})
future := identityLoader.LoadMany(ctx, []interface{}{"1", "2", "3"})
results, _ := future()
if results[0].(string) != "1" || results[1].(string) != "2" || results[2].(string) != "3" {
t.Error("loadmany didn't return the right value")
Expand All @@ -162,8 +162,8 @@ func TestLoader(t *testing.T) {
}

calls := *loadCalls
inner := []string{"1", "2"}
expected := [][]string{inner}
inner := []interface{}{"1", "2"}
expected := [][]interface{}{inner}
if !reflect.DeepEqual(calls, expected) {
t.Errorf("did not call batchFn in right order. Expected %#v, got %#v", expected, calls)
}
Expand All @@ -176,7 +176,7 @@ func TestLoader(t *testing.T) {

n := 10
reqs := []Thunk{}
keys := []string{}
keys := []interface{}{}
for i := 0; i < n; i++ {
key := strconv.Itoa(i)
reqs = append(reqs, faultyLoader.Load(ctx, key))
Expand Down Expand Up @@ -215,9 +215,9 @@ func TestLoader(t *testing.T) {
}

calls := *loadCalls
inner1 := []string{"1", "2"}
inner2 := []string{"3"}
expected := [][]string{inner1, inner2}
inner1 := []interface{}{"1", "2"}
inner2 := []interface{}{"3"}
expected := [][]interface{}{inner1, inner2}
if !reflect.DeepEqual(calls, expected) {
t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls)
}
Expand All @@ -240,8 +240,8 @@ func TestLoader(t *testing.T) {
}

calls := *loadCalls
inner := []string{"1"}
expected := [][]string{inner}
inner := []interface{}{"1"}
expected := [][]interface{}{inner}
if !reflect.DeepEqual(calls, expected) {
t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls)
}
Expand All @@ -265,8 +265,8 @@ func TestLoader(t *testing.T) {
}

calls := *loadCalls
inner := []string{"1"}
expected := [][]string{inner}
inner := []interface{}{"1"}
expected := [][]interface{}{inner}
if !reflect.DeepEqual(calls, expected) {
t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls)
}
Expand Down Expand Up @@ -300,8 +300,8 @@ func TestLoader(t *testing.T) {
}

calls := *loadCalls
inner := []string{"1", "A"}
expected := [][]string{inner}
inner := []interface{}{"1", "A"}
expected := [][]interface{}{inner}
if !reflect.DeepEqual(calls, expected) {
t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls)
}
Expand All @@ -328,8 +328,8 @@ func TestLoader(t *testing.T) {
}

calls := *loadCalls
inner := []string{"1"}
expected := [][]string{inner}
inner := []interface{}{"1"}
expected := [][]interface{}{inner}
if !reflect.DeepEqual(calls, expected) {
t.Errorf("did not batch queries. Expected %#v, got %#v", expected, calls)
}
Expand Down Expand Up @@ -366,8 +366,8 @@ func TestLoader(t *testing.T) {
}

calls := *loadCalls
inner := []string{"1", "A", "B"}
expected := [][]string{inner}
inner := []interface{}{"1", "A", "B"}
expected := [][]interface{}{inner}
if !reflect.DeepEqual(calls, expected) {
t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls)
}
Expand Down Expand Up @@ -400,8 +400,8 @@ func TestLoader(t *testing.T) {
}

calls := *loadCalls
inner := []string{"1", "A", "B"}
expected := [][]string{inner}
inner := []interface{}{"1", "A", "B"}
expected := [][]interface{}{inner}
if !reflect.DeepEqual(calls, expected) {
t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls)
}
Expand Down Expand Up @@ -432,8 +432,8 @@ func TestLoader(t *testing.T) {
}

calls := *loadCalls
inner := []string{"1", "A", "B"}
expected := [][]string{inner}
inner := []interface{}{"1", "A", "B"}
expected := [][]interface{}{inner}
if !reflect.DeepEqual(calls, expected) {
t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls)
}
Expand All @@ -442,10 +442,10 @@ func TestLoader(t *testing.T) {
}

// test helpers
func IDLoader(max int) (*Loader, *[][]string) {
func IDLoader(max int) (*Loader, *[][]interface{}) {
var mu sync.Mutex
var loadCalls [][]string
identityLoader := NewBatchedLoader(func(_ context.Context, keys []string) []*Result {
var loadCalls [][]interface{}
identityLoader := NewBatchedLoader(func(_ context.Context, keys []interface{}) []*Result {
var results []*Result
mu.Lock()
loadCalls = append(loadCalls, keys)
Expand All @@ -457,10 +457,10 @@ func IDLoader(max int) (*Loader, *[][]string) {
}, WithBatchCapacity(max))
return identityLoader, &loadCalls
}
func BatchOnlyLoader(max int) (*Loader, *[][]string) {
func BatchOnlyLoader(max int) (*Loader, *[][]interface{}) {
var mu sync.Mutex
var loadCalls [][]string
identityLoader := NewBatchedLoader(func(_ context.Context, keys []string) []*Result {
var loadCalls [][]interface{}
identityLoader := NewBatchedLoader(func(_ context.Context, keys []interface{}) []*Result {
var results []*Result
mu.Lock()
loadCalls = append(loadCalls, keys)
Expand All @@ -472,10 +472,10 @@ func BatchOnlyLoader(max int) (*Loader, *[][]string) {
}, WithBatchCapacity(max), WithClearCacheOnBatch())
return identityLoader, &loadCalls
}
func ErrorLoader(max int) (*Loader, *[][]string) {
func ErrorLoader(max int) (*Loader, *[][]interface{}) {
var mu sync.Mutex
var loadCalls [][]string
identityLoader := NewBatchedLoader(func(_ context.Context, keys []string) []*Result {
var loadCalls [][]interface{}
identityLoader := NewBatchedLoader(func(_ context.Context, keys []interface{}) []*Result {
var results []*Result
mu.Lock()
loadCalls = append(loadCalls, keys)
Expand All @@ -487,11 +487,11 @@ func ErrorLoader(max int) (*Loader, *[][]string) {
}, WithBatchCapacity(max))
return identityLoader, &loadCalls
}
func OneErrorLoader(max int) (*Loader, *[][]string) {
func OneErrorLoader(max int) (*Loader, *[][]interface{}) {
var mu sync.Mutex
var loadCalls [][]string
identityLoader := NewBatchedLoader(func(_ context.Context, keys []string) []*Result {
var results []*Result
var loadCalls [][]interface{}
identityLoader := NewBatchedLoader(func(_ context.Context, keys []interface{}) []*Result {
results := make([]*Result, max, max)
mu.Lock()
loadCalls = append(loadCalls, keys)
mu.Unlock()
Expand All @@ -500,23 +500,23 @@ func OneErrorLoader(max int) (*Loader, *[][]string) {
if i == 0 {
err = errors.New("always error on the first key")
}
results = append(results, &Result{key, err})
results[i] = &Result{key, err}
}
return results
}, WithBatchCapacity(max))
return identityLoader, &loadCalls
}
func PanicLoader(max int) (*Loader, *[][]string) {
var loadCalls [][]string
panicLoader := NewBatchedLoader(func(_ context.Context, keys []string) []*Result {
func PanicLoader(max int) (*Loader, *[][]interface{}) {
var loadCalls [][]interface{}
panicLoader := NewBatchedLoader(func(_ context.Context, keys []interface{}) []*Result {
panic("Programming error")
}, WithBatchCapacity(max), withSilentLogger())
return panicLoader, &loadCalls
}
func BadLoader(max int) (*Loader, *[][]string) {
func BadLoader(max int) (*Loader, *[][]interface{}) {
var mu sync.Mutex
var loadCalls [][]string
identityLoader := NewBatchedLoader(func(_ context.Context, keys []string) []*Result {
var loadCalls [][]interface{}
identityLoader := NewBatchedLoader(func(_ context.Context, keys []interface{}) []*Result {
var results []*Result
mu.Lock()
loadCalls = append(loadCalls, keys)
Expand All @@ -526,11 +526,11 @@ func BadLoader(max int) (*Loader, *[][]string) {
}, WithBatchCapacity(max))
return identityLoader, &loadCalls
}
func NoCacheLoader(max int) (*Loader, *[][]string) {
func NoCacheLoader(max int) (*Loader, *[][]interface{}) {
var mu sync.Mutex
var loadCalls [][]string
var loadCalls [][]interface{}
cache := &NoCache{}
identityLoader := NewBatchedLoader(func(_ context.Context, keys []string) []*Result {
identityLoader := NewBatchedLoader(func(_ context.Context, keys []interface{}) []*Result {
var results []*Result
mu.Lock()
loadCalls = append(loadCalls, keys)
Expand All @@ -544,11 +544,11 @@ func NoCacheLoader(max int) (*Loader, *[][]string) {
}

// FaultyLoader gives len(keys)-1 results.
func FaultyLoader() (*Loader, *[][]string) {
func FaultyLoader() (*Loader, *[][]interface{}) {
var mu sync.Mutex
var loadCalls [][]string
var loadCalls [][]interface{}

loader := NewBatchedLoader(func(_ context.Context, keys []string) []*Result {
loader := NewBatchedLoader(func(_ context.Context, keys []interface{}) []*Result {
var results []*Result
mu.Lock()
loadCalls = append(loadCalls, keys)
Expand All @@ -573,7 +573,7 @@ func FaultyLoader() (*Loader, *[][]string) {
///////////////////////////////////////////////////
var a = &Avg{}

func batchIdentity(_ context.Context, keys []string) (results []*Result) {
func batchIdentity(_ context.Context, keys []interface{}) (results []*Result) {
a.Add(len(keys))
for _, key := range keys {
results = append(results, &Result{key, nil})
Expand Down

0 comments on commit 766d4e4

Please sign in to comment.