diff --git a/go.mod b/go.mod index d8af59440..75a35c4a9 100644 --- a/go.mod +++ b/go.mod @@ -47,6 +47,7 @@ require ( golang.org/x/mod v0.2.0 golang.org/x/net v0.0.0-20200625001655-4c5254603344 golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 + golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae // indirect golang.org/x/text v0.3.3 // indirect google.golang.org/genproto v0.0.0-20200701001935-0939c5918c31 // indirect diff --git a/pkg/sqlite/query.go b/pkg/sqlite/query.go index e49fa9c93..daab7522d 100644 --- a/pkg/sqlite/query.go +++ b/pkg/sqlite/query.go @@ -1,3 +1,5 @@ +//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -o sqlitefakes/fake_rowscanner.go . RowScanner +//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -o sqlitefakes/fake_querier.go . Querier package sqlite import ( @@ -13,10 +15,28 @@ import ( "github.com/operator-framework/operator-registry/pkg/registry" ) -type SQLQuerier struct { +type RowScanner interface { + Next() bool + Close() error + Scan(dest ...interface{}) error +} + +type Querier interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (RowScanner, error) +} + +type dbQuerierAdapter struct { db *sql.DB } +func (a dbQuerierAdapter) QueryContext(ctx context.Context, query string, args ...interface{}) (RowScanner, error) { + return a.db.QueryContext(ctx, query, args...) +} + +type SQLQuerier struct { + db Querier +} + var _ registry.Query = &SQLQuerier{} func NewSQLLiteQuerier(dbFilename string) (*SQLQuerier, error) { @@ -25,11 +45,15 @@ func NewSQLLiteQuerier(dbFilename string) (*SQLQuerier, error) { return nil, err } - return &SQLQuerier{db}, nil + return &SQLQuerier{dbQuerierAdapter{db}}, nil } func NewSQLLiteQuerierFromDb(db *sql.DB) *SQLQuerier { - return &SQLQuerier{db} + return &SQLQuerier{dbQuerierAdapter{db}} +} + +func NewSQLLiteQuerierFromDBQuerier(q Querier) *SQLQuerier { + return &SQLQuerier{q} } func (s *SQLQuerier) ListTables(ctx context.Context) ([]string, error) { @@ -900,7 +924,7 @@ func (s *SQLQuerier) GetCurrentCSVNameForChannel(ctx context.Context, pkgName, c return "", nil } -func (s *SQLQuerier) ListBundles(ctx context.Context) (bundles []*api.Bundle, err error) { +func (s *SQLQuerier) ListBundles(ctx context.Context) ([]*api.Bundle, error) { query := `SELECT DISTINCT channel_entry.entry_id, operatorbundle.bundle, operatorbundle.bundlepath, channel_entry.operatorbundle_name, channel_entry.package_name, channel_entry.channel_name, operatorbundle.replaces, operatorbundle.skips, operatorbundle.version, operatorbundle.skiprange, @@ -918,23 +942,25 @@ func (s *SQLQuerier) ListBundles(ctx context.Context) (bundles []*api.Bundle, er } defer rows.Close() - bundles = []*api.Bundle{} + var bundles []*api.Bundle bundlesMap := map[string]*api.Bundle{} for rows.Next() { - var entryID sql.NullInt64 - var bundle sql.NullString - var bundlePath sql.NullString - var bundleName sql.NullString - var pkgName sql.NullString - var channelName sql.NullString - var replaces sql.NullString - var skips sql.NullString - var version sql.NullString - var skipRange sql.NullString - var depType sql.NullString - var depValue sql.NullString - var propType sql.NullString - var propValue sql.NullString + var ( + entryID sql.NullInt64 + bundle sql.NullString + bundlePath sql.NullString + bundleName sql.NullString + pkgName sql.NullString + channelName sql.NullString + replaces sql.NullString + skips sql.NullString + version sql.NullString + skipRange sql.NullString + depType sql.NullString + depValue sql.NullString + propType sql.NullString + propValue sql.NullString + ) if err := rows.Scan(&entryID, &bundle, &bundlePath, &bundleName, &pkgName, &channelName, &replaces, &skips, &version, &skipRange, &depType, &depValue, &propType, &propValue); err != nil { return nil, err } @@ -946,29 +972,18 @@ func (s *SQLQuerier) ListBundles(ctx context.Context) (bundles []*api.Bundle, er bundleKey := fmt.Sprintf("%s/%s/%s/%s", bundleName.String, version.String, bundlePath.String, channelName.String) bundleItem, ok := bundlesMap[bundleKey] if ok { - // Create new dependency object if depType.Valid && depValue.Valid { - dep := &api.Dependency{} - dep.Type = depType.String - dep.Value = depValue.String - - // Add new dependency to the existing list - existingDeps := bundleItem.Dependencies - existingDeps = append(existingDeps, dep) - bundleItem.Dependencies = existingDeps + bundleItem.Dependencies = append(bundleItem.Dependencies, &api.Dependency{ + Type: depType.String, + Value: depValue.String, + }) } - - // Create new property object if propType.Valid && propValue.Valid { - prop := &api.Property{} - prop.Type = propType.String - prop.Value = propValue.String - - // Add new property to the existing list - existingProps := bundleItem.Properties - existingProps = append(existingProps, prop) - bundleItem.Properties = existingProps + bundleItem.Properties = append(bundleItem.Properties, &api.Property{ + Type: propType.String, + Value: propValue.String, + }) } } else { // Create new bundle @@ -987,30 +1002,34 @@ func (s *SQLQuerier) ListBundles(ctx context.Context) (bundles []*api.Bundle, er out.Version = version.String out.SkipRange = skipRange.String out.Replaces = replaces.String - out.Skips = strings.Split(skips.String, ",") + if skips.Valid { + out.Skips = strings.Split(skips.String, ",") + } provided, required, err := s.GetApisForEntry(ctx, entryID.Int64) if err != nil { return nil, err } - out.ProvidedApis = provided - out.RequiredApis = required - - // Create new dependency and dependency list - dep := &api.Dependency{} - dependencies := []*api.Dependency{} - dep.Type = depType.String - dep.Value = depValue.String - dependencies = append(dependencies, dep) - out.Dependencies = dependencies - - // Create new property and property list - prop := &api.Property{} - properties := []*api.Property{} - prop.Type = propType.String - prop.Value = propValue.String - properties = append(properties, prop) - out.Properties = properties + if len(provided) > 0 { + out.ProvidedApis = provided + } + if len(required) > 0 { + out.RequiredApis = required + } + + if depType.Valid && depValue.Valid { + out.Dependencies = []*api.Dependency{{ + Type: depType.String, + Value: depValue.String, + }} + } + + if propType.Valid && propValue.Valid { + out.Properties = []*api.Property{{ + Type: propType.String, + Value: propValue.String, + }} + } bundlesMap[bundleKey] = out } @@ -1028,16 +1047,16 @@ func (s *SQLQuerier) ListBundles(ctx context.Context) (bundles []*api.Bundle, er bundles = append(bundles, v) } - return + return bundles, nil } func unique(deps []*api.Dependency) []*api.Dependency { - keys := make(map[string]bool) - list := []*api.Dependency{} + keys := make(map[string]struct{}) + var list []*api.Dependency for _, entry := range deps { depKey := fmt.Sprintf("%s/%s", entry.Type, entry.Value) if _, value := keys[depKey]; !value { - keys[depKey] = true + keys[depKey] = struct{}{} list = append(list, entry) } } @@ -1045,12 +1064,12 @@ func unique(deps []*api.Dependency) []*api.Dependency { } func uniqueProps(props []*api.Property) []*api.Property { - keys := make(map[string]bool) - list := []*api.Property{} + keys := make(map[string]struct{}) + var list []*api.Property for _, entry := range props { propKey := fmt.Sprintf("%s/%s", entry.Type, entry.Value) if _, value := keys[propKey]; !value { - keys[propKey] = true + keys[propKey] = struct{}{} list = append(list, entry) } } diff --git a/pkg/sqlite/query_test.go b/pkg/sqlite/query_test.go new file mode 100644 index 000000000..c657771b7 --- /dev/null +++ b/pkg/sqlite/query_test.go @@ -0,0 +1,278 @@ +package sqlite_test + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "testing" + + "github.com/operator-framework/operator-registry/pkg/api" + "github.com/operator-framework/operator-registry/pkg/sqlite" + "github.com/operator-framework/operator-registry/pkg/sqlite/sqlitefakes" + "github.com/stretchr/testify/assert" +) + +func TestListBundles(t *testing.T) { + type Columns struct { + EntryID sql.NullInt64 + Bundle sql.NullString + BundlePath sql.NullString + BundleName sql.NullString + PackageName sql.NullString + ChannelName sql.NullString + Replaces sql.NullString + Skips sql.NullString + Version sql.NullString + SkipRange sql.NullString + DependencyType sql.NullString + DependencyValue sql.NullString + PropertyType sql.NullString + PropertyValue sql.NullString + } + + var NoRows sqlitefakes.FakeRowScanner + NoRows.NextReturns(false) + + ScanFromColumns := func(t *testing.T, dsts []interface{}, cols Columns) { + ct := reflect.TypeOf(cols) + if len(dsts) != ct.NumField() { + t.Fatalf("expected %d columns, got %d", ct.NumField(), len(dsts)) + } + for i, dst := range dsts { + f := ct.Field(i) + dv := reflect.ValueOf(dst) + if dv.Kind() != reflect.Ptr { + t.Fatalf("scan argument at index %d is not a pointer", i) + } + if !f.Type.AssignableTo(dv.Elem().Type()) { + t.Fatalf("%s is not assignable to argument %s at index %d", f.Type, dv.Elem().Type(), i) + } + dv.Elem().Set(reflect.ValueOf(cols).Field(i)) + } + } + + for _, tc := range []struct { + Name string + Querier func(t *testing.T) sqlite.Querier + Bundles []*api.Bundle + ErrorMessage string + }{ + { + Name: "returns error when query returns error", + Querier: func(t *testing.T) sqlite.Querier { + var q sqlitefakes.FakeQuerier + q.QueryContextReturns(nil, fmt.Errorf("test")) + return &q + }, + ErrorMessage: "test", + }, + { + Name: "returns error when scan returns error", + Querier: func(t *testing.T) sqlite.Querier { + var ( + q sqlitefakes.FakeQuerier + r sqlitefakes.FakeRowScanner + ) + q.QueryContextReturns(&r, nil) + r.NextReturnsOnCall(0, true) + r.ScanReturns(fmt.Errorf("test")) + return &q + }, + ErrorMessage: "test", + }, + { + Name: "skips row without valid bundle name", + Querier: func(t *testing.T) sqlite.Querier { + var ( + q sqlitefakes.FakeQuerier + r sqlitefakes.FakeRowScanner + ) + q.QueryContextReturns(&r, nil) + r.NextReturnsOnCall(0, true) + r.ScanCalls(func(args ...interface{}) error { + ScanFromColumns(t, args, Columns{ + Version: sql.NullString{Valid: true}, + BundlePath: sql.NullString{Valid: true}, + ChannelName: sql.NullString{Valid: true}, + }) + return nil + }) + return &q + }, + }, + { + Name: "skips row without valid version", + Querier: func(t *testing.T) sqlite.Querier { + var ( + q sqlitefakes.FakeQuerier + r sqlitefakes.FakeRowScanner + ) + q.QueryContextReturns(&r, nil) + r.NextReturnsOnCall(0, true) + r.ScanCalls(func(args ...interface{}) error { + ScanFromColumns(t, args, Columns{ + BundleName: sql.NullString{Valid: true}, + BundlePath: sql.NullString{Valid: true}, + ChannelName: sql.NullString{Valid: true}, + }) + return nil + }) + return &q + }, + }, + { + Name: "skips row without valid bundle path", + Querier: func(t *testing.T) sqlite.Querier { + var ( + q sqlitefakes.FakeQuerier + r sqlitefakes.FakeRowScanner + ) + q.QueryContextReturns(&r, nil) + r.NextReturnsOnCall(0, true) + r.ScanCalls(func(args ...interface{}) error { + ScanFromColumns(t, args, Columns{ + BundleName: sql.NullString{Valid: true}, + Version: sql.NullString{Valid: true}, + ChannelName: sql.NullString{Valid: true}, + }) + return nil + }) + return &q + }, + }, + { + Name: "skips row without valid channel name", + Querier: func(t *testing.T) sqlite.Querier { + var ( + q sqlitefakes.FakeQuerier + r sqlitefakes.FakeRowScanner + ) + q.QueryContextReturns(&r, nil) + r.NextReturnsOnCall(0, true) + r.ScanCalls(func(args ...interface{}) error { + ScanFromColumns(t, args, Columns{ + BundleName: sql.NullString{Valid: true}, + Version: sql.NullString{Valid: true}, + BundlePath: sql.NullString{Valid: true}, + }) + return nil + }) + return &q + }, + }, + { + Name: "bundle dependencies are null when dependency type or value is null", + Querier: func(t *testing.T) sqlite.Querier { + var ( + q sqlitefakes.FakeQuerier + r sqlitefakes.FakeRowScanner + ) + q.QueryContextReturns(&r, nil) + r.NextReturnsOnCall(0, true) + r.ScanCalls(func(args ...interface{}) error { + ScanFromColumns(t, args, Columns{ + BundleName: sql.NullString{Valid: true}, + Version: sql.NullString{Valid: true}, + ChannelName: sql.NullString{Valid: true}, + BundlePath: sql.NullString{Valid: true}, + }) + return nil + }) + return &q + }, + Bundles: []*api.Bundle{ + {}, + }, + }, + { + Name: "all dependencies and properties are returned", + Querier: func(t *testing.T) sqlite.Querier { + var ( + q sqlitefakes.FakeQuerier + r sqlitefakes.FakeRowScanner + ) + q.QueryContextReturns(&NoRows, nil) + q.QueryContextReturnsOnCall(0, &r, nil) + r.NextReturnsOnCall(0, true) + r.NextReturnsOnCall(1, true) + cols := []Columns{ + { + BundleName: sql.NullString{Valid: true, String: "BundleName"}, + Version: sql.NullString{Valid: true, String: "Version"}, + ChannelName: sql.NullString{Valid: true, String: "ChannelName"}, + BundlePath: sql.NullString{Valid: true, String: "BundlePath"}, + DependencyType: sql.NullString{Valid: true, String: "Dependency1Type"}, + DependencyValue: sql.NullString{Valid: true, String: "Dependency1Value"}, + PropertyType: sql.NullString{Valid: true, String: "Property1Type"}, + PropertyValue: sql.NullString{Valid: true, String: "Property1Value"}, + }, + { + BundleName: sql.NullString{Valid: true, String: "BundleName"}, + Version: sql.NullString{Valid: true, String: "Version"}, + ChannelName: sql.NullString{Valid: true, String: "ChannelName"}, + BundlePath: sql.NullString{Valid: true, String: "BundlePath"}, + DependencyType: sql.NullString{Valid: true, String: "Dependency2Type"}, + DependencyValue: sql.NullString{Valid: true, String: "Dependency2Value"}, + PropertyType: sql.NullString{Valid: true, String: "Property2Type"}, + PropertyValue: sql.NullString{Valid: true, String: "Property2Value"}, + }, + } + var i int + r.ScanCalls(func(args ...interface{}) error { + if i < len(cols) { + ScanFromColumns(t, args, cols[i]) + i++ + } + return nil + }) + return &q + }, + Bundles: []*api.Bundle{ + { + CsvName: "BundleName", + ChannelName: "ChannelName", + BundlePath: "BundlePath", + Version: "Version", + Dependencies: []*api.Dependency{ + { + Type: "Dependency1Type", + Value: "Dependency1Value", + }, + { + Type: "Dependency2Type", + Value: "Dependency2Value", + }, + }, + Properties: []*api.Property{ + { + Type: "Property1Type", + Value: "Property1Value", + }, + { + Type: "Property2Type", + Value: "Property2Value", + }, + }, + }, + }, + }, + } { + t.Run(tc.Name, func(t *testing.T) { + var q sqlite.Querier + if tc.Querier != nil { + q = tc.Querier(t) + } + sq := sqlite.NewSQLLiteQuerierFromDBQuerier(q) + bundles, err := sq.ListBundles(context.Background()) + + assert := assert.New(t) + assert.Equal(tc.Bundles, bundles) + if tc.ErrorMessage == "" { + assert.NoError(err) + } else { + assert.EqualError(err, tc.ErrorMessage) + } + }) + } +} diff --git a/pkg/sqlite/sqlitefakes/fake_querier.go b/pkg/sqlite/sqlitefakes/fake_querier.go new file mode 100644 index 000000000..8d5ebca06 --- /dev/null +++ b/pkg/sqlite/sqlitefakes/fake_querier.go @@ -0,0 +1,120 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package sqlitefakes + +import ( + "context" + "sync" + + "github.com/operator-framework/operator-registry/pkg/sqlite" +) + +type FakeQuerier struct { + QueryContextStub func(context.Context, string, ...interface{}) (sqlite.RowScanner, error) + queryContextMutex sync.RWMutex + queryContextArgsForCall []struct { + arg1 context.Context + arg2 string + arg3 []interface{} + } + queryContextReturns struct { + result1 sqlite.RowScanner + result2 error + } + queryContextReturnsOnCall map[int]struct { + result1 sqlite.RowScanner + result2 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeQuerier) QueryContext(arg1 context.Context, arg2 string, arg3 ...interface{}) (sqlite.RowScanner, error) { + fake.queryContextMutex.Lock() + ret, specificReturn := fake.queryContextReturnsOnCall[len(fake.queryContextArgsForCall)] + fake.queryContextArgsForCall = append(fake.queryContextArgsForCall, struct { + arg1 context.Context + arg2 string + arg3 []interface{} + }{arg1, arg2, arg3}) + fake.recordInvocation("QueryContext", []interface{}{arg1, arg2, arg3}) + fake.queryContextMutex.Unlock() + if fake.QueryContextStub != nil { + return fake.QueryContextStub(arg1, arg2, arg3...) + } + if specificReturn { + return ret.result1, ret.result2 + } + fakeReturns := fake.queryContextReturns + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeQuerier) QueryContextCallCount() int { + fake.queryContextMutex.RLock() + defer fake.queryContextMutex.RUnlock() + return len(fake.queryContextArgsForCall) +} + +func (fake *FakeQuerier) QueryContextCalls(stub func(context.Context, string, ...interface{}) (sqlite.RowScanner, error)) { + fake.queryContextMutex.Lock() + defer fake.queryContextMutex.Unlock() + fake.QueryContextStub = stub +} + +func (fake *FakeQuerier) QueryContextArgsForCall(i int) (context.Context, string, []interface{}) { + fake.queryContextMutex.RLock() + defer fake.queryContextMutex.RUnlock() + argsForCall := fake.queryContextArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeQuerier) QueryContextReturns(result1 sqlite.RowScanner, result2 error) { + fake.queryContextMutex.Lock() + defer fake.queryContextMutex.Unlock() + fake.QueryContextStub = nil + fake.queryContextReturns = struct { + result1 sqlite.RowScanner + result2 error + }{result1, result2} +} + +func (fake *FakeQuerier) QueryContextReturnsOnCall(i int, result1 sqlite.RowScanner, result2 error) { + fake.queryContextMutex.Lock() + defer fake.queryContextMutex.Unlock() + fake.QueryContextStub = nil + if fake.queryContextReturnsOnCall == nil { + fake.queryContextReturnsOnCall = make(map[int]struct { + result1 sqlite.RowScanner + result2 error + }) + } + fake.queryContextReturnsOnCall[i] = struct { + result1 sqlite.RowScanner + result2 error + }{result1, result2} +} + +func (fake *FakeQuerier) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.queryContextMutex.RLock() + defer fake.queryContextMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeQuerier) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ sqlite.Querier = new(FakeQuerier) diff --git a/pkg/sqlite/sqlitefakes/fake_rowscanner.go b/pkg/sqlite/sqlitefakes/fake_rowscanner.go new file mode 100644 index 000000000..3d0cb5b11 --- /dev/null +++ b/pkg/sqlite/sqlitefakes/fake_rowscanner.go @@ -0,0 +1,238 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package sqlitefakes + +import ( + "sync" + + "github.com/operator-framework/operator-registry/pkg/sqlite" +) + +type FakeRowScanner struct { + CloseStub func() error + closeMutex sync.RWMutex + closeArgsForCall []struct { + } + closeReturns struct { + result1 error + } + closeReturnsOnCall map[int]struct { + result1 error + } + NextStub func() bool + nextMutex sync.RWMutex + nextArgsForCall []struct { + } + nextReturns struct { + result1 bool + } + nextReturnsOnCall map[int]struct { + result1 bool + } + ScanStub func(...interface{}) error + scanMutex sync.RWMutex + scanArgsForCall []struct { + arg1 []interface{} + } + scanReturns struct { + result1 error + } + scanReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeRowScanner) Close() error { + fake.closeMutex.Lock() + ret, specificReturn := fake.closeReturnsOnCall[len(fake.closeArgsForCall)] + fake.closeArgsForCall = append(fake.closeArgsForCall, struct { + }{}) + fake.recordInvocation("Close", []interface{}{}) + fake.closeMutex.Unlock() + if fake.CloseStub != nil { + return fake.CloseStub() + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.closeReturns + return fakeReturns.result1 +} + +func (fake *FakeRowScanner) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeRowScanner) CloseCalls(stub func() error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = stub +} + +func (fake *FakeRowScanner) CloseReturns(result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + fake.closeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRowScanner) CloseReturnsOnCall(i int, result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + if fake.closeReturnsOnCall == nil { + fake.closeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.closeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeRowScanner) Next() bool { + fake.nextMutex.Lock() + ret, specificReturn := fake.nextReturnsOnCall[len(fake.nextArgsForCall)] + fake.nextArgsForCall = append(fake.nextArgsForCall, struct { + }{}) + fake.recordInvocation("Next", []interface{}{}) + fake.nextMutex.Unlock() + if fake.NextStub != nil { + return fake.NextStub() + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.nextReturns + return fakeReturns.result1 +} + +func (fake *FakeRowScanner) NextCallCount() int { + fake.nextMutex.RLock() + defer fake.nextMutex.RUnlock() + return len(fake.nextArgsForCall) +} + +func (fake *FakeRowScanner) NextCalls(stub func() bool) { + fake.nextMutex.Lock() + defer fake.nextMutex.Unlock() + fake.NextStub = stub +} + +func (fake *FakeRowScanner) NextReturns(result1 bool) { + fake.nextMutex.Lock() + defer fake.nextMutex.Unlock() + fake.NextStub = nil + fake.nextReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeRowScanner) NextReturnsOnCall(i int, result1 bool) { + fake.nextMutex.Lock() + defer fake.nextMutex.Unlock() + fake.NextStub = nil + if fake.nextReturnsOnCall == nil { + fake.nextReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.nextReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeRowScanner) Scan(arg1 ...interface{}) error { + fake.scanMutex.Lock() + ret, specificReturn := fake.scanReturnsOnCall[len(fake.scanArgsForCall)] + fake.scanArgsForCall = append(fake.scanArgsForCall, struct { + arg1 []interface{} + }{arg1}) + fake.recordInvocation("Scan", []interface{}{arg1}) + fake.scanMutex.Unlock() + if fake.ScanStub != nil { + return fake.ScanStub(arg1...) + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.scanReturns + return fakeReturns.result1 +} + +func (fake *FakeRowScanner) ScanCallCount() int { + fake.scanMutex.RLock() + defer fake.scanMutex.RUnlock() + return len(fake.scanArgsForCall) +} + +func (fake *FakeRowScanner) ScanCalls(stub func(...interface{}) error) { + fake.scanMutex.Lock() + defer fake.scanMutex.Unlock() + fake.ScanStub = stub +} + +func (fake *FakeRowScanner) ScanArgsForCall(i int) []interface{} { + fake.scanMutex.RLock() + defer fake.scanMutex.RUnlock() + argsForCall := fake.scanArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeRowScanner) ScanReturns(result1 error) { + fake.scanMutex.Lock() + defer fake.scanMutex.Unlock() + fake.ScanStub = nil + fake.scanReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRowScanner) ScanReturnsOnCall(i int, result1 error) { + fake.scanMutex.Lock() + defer fake.scanMutex.Unlock() + fake.ScanStub = nil + if fake.scanReturnsOnCall == nil { + fake.scanReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.scanReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeRowScanner) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + fake.nextMutex.RLock() + defer fake.nextMutex.RUnlock() + fake.scanMutex.RLock() + defer fake.scanMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeRowScanner) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ sqlite.RowScanner = new(FakeRowScanner)