diff --git a/sql/catalog.go b/sql/catalog.go index b393e7e56f..c9f9dd2d7e 100644 --- a/sql/catalog.go +++ b/sql/catalog.go @@ -37,15 +37,15 @@ var ErrAsOfNotSupported = errors.NewKind("AS OF not supported for database %s") // expression with a view when the view definition has its own AS OF expressions. var ErrIncompatibleAsOf = errors.NewKind("incompatible use of AS OF: %s") -// Catalog holds databases, tables and functions. +// Catalog holds sliceProvider, tables and functions. type Catalog struct { FunctionRegistry *ProcessList *MemoryManager - mu sync.RWMutex - dbs Databases - locks sessionLocks + mu sync.RWMutex + provider DatabaseProvider + locks sessionLocks } type tableLocks map[string]struct{} @@ -56,118 +56,121 @@ type sessionLocks map[uint32]dbLocks // NewCatalog returns a new empty Catalog. func NewCatalog() *Catalog { + return NewCatalogWithDbProvider(&sliceProvider{}) +} + +// NewCatalogWithDbProvider returns a new empty Catalog. +func NewCatalogWithDbProvider(provider DatabaseProvider) *Catalog { return &Catalog{ FunctionRegistry: NewFunctionRegistry(), MemoryManager: NewMemoryManager(ProcessMemory), ProcessList: NewProcessList(), + provider: provider, locks: make(sessionLocks), } } -// AllDatabases returns all databases in the catalog. -func (c *Catalog) AllDatabases() Databases { +// AllDatabases returns all sliceProvider in the catalog. +func (c *Catalog) AllDatabases() []Database { c.mu.RLock() defer c.mu.RUnlock() - var result = make(Databases, len(c.dbs)) - copy(result, c.dbs) - return result + return c.provider.AllDatabases() } // AddDatabase adds a new database to the catalog. func (c *Catalog) AddDatabase(db Database) { c.mu.Lock() - c.dbs.Add(db) - c.mu.Unlock() + defer c.mu.Unlock() + c.provider.AddDatabase(db) } // RemoveDatabase removes a database from the catalog. func (c *Catalog) RemoveDatabase(dbName string) { c.mu.Lock() - c.dbs.Delete(dbName) - c.mu.Unlock() + defer c.mu.Unlock() + c.provider.DropDatabase(dbName) } func (c *Catalog) HasDB(db string) bool { c.mu.RLock() defer c.mu.RUnlock() - _, err := c.dbs.Database(db) - - return err == nil + return c.provider.HasDatabase(db) } // Database returns the database with the given name. func (c *Catalog) Database(db string) (Database, error) { c.mu.RLock() defer c.mu.RUnlock() - return c.dbs.Database(db) -} - -// Table returns the table in the given database with the given name. -func (c *Catalog) Table(ctx *Context, db, table string) (Table, Database, error) { - c.mu.RLock() - defer c.mu.RUnlock() - return c.dbs.Table(ctx, db, table) + return c.provider.Database(db) } -// TableAsOf returns the table in the given database with the given name, as it existed at the time given. The database -// named must support timed queries. -func (c *Catalog) TableAsOf(ctx *Context, db, table string, time interface{}) (Table, Database, error) { - c.mu.RLock() - defer c.mu.RUnlock() - return c.dbs.TableAsOf(ctx, db, table, time) -} +// LockTable adds a lock for the given table and session client. It is assumed +// the database is the current database in use. +func (c *Catalog) LockTable(ctx *Context, table string) { + id := ctx.ID() + db := ctx.GetCurrentDatabase() -// Databases is a collection of Database. -type Databases []Database + c.mu.Lock() + defer c.mu.Unlock() -// Database returns the Database with the given name if it exists. -func (d Databases) Database(name string) (Database, error) { - if len(d) == 0 { - return nil, ErrDatabaseNotFound.New(name) + if _, ok := c.locks[id]; !ok { + c.locks[id] = make(dbLocks) } - name = strings.ToLower(name) - var dbNames []string - for _, db := range d { - if strings.ToLower(db.Name()) == name { - return db, nil - } - dbNames = append(dbNames, db.Name()) + if _, ok := c.locks[id][db]; !ok { + c.locks[id][db] = make(tableLocks) } - similar := similartext.Find(dbNames, name) - return nil, ErrDatabaseNotFound.New(name + similar) -} -// Add adds a new database. -func (d *Databases) Add(db Database) { - *d = append(*d, db) + c.locks[id][db][table] = struct{}{} } -// Delete removes a database. -func (d *Databases) Delete(dbName string) { - idx := -1 - for i, db := range *d { - if db.Name() == dbName { - idx = i - break +// UnlockTables unlocks all tables for which the given session client has a +// lock. +func (c *Catalog) UnlockTables(ctx *Context, id uint32) error { + c.mu.Lock() + defer c.mu.Unlock() + + var errors []string + for db, tables := range c.locks[id] { + for t := range tables { + database, err := c.provider.Database(db) + if err != nil { + return err + } + + table, _, err := database.GetTableInsensitive(ctx, t) + if err == nil { + if lockable, ok := table.(Lockable); ok { + if e := lockable.Unlock(ctx, id); e != nil { + errors = append(errors, e.Error()) + } + } + } else { + errors = append(errors, err.Error()) + } } } - if idx != -1 { - *d = append((*d)[:idx], (*d)[idx+1:]...) + delete(c.locks, id) + if len(errors) > 0 { + return fmt.Errorf("error unlocking tables for %d: %s", id, strings.Join(errors, ", ")) } + + return nil } -// Table returns the Table with the given name if it exists. -func (d Databases) Table(ctx *Context, dbName string, tableName string) (Table, Database, error) { - db, err := d.Database(dbName) +// Table returns the table in the given database with the given name. +func (c *Catalog) Table(ctx *Context, dbName, tableName string) (Table, Database, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + db, err := c.Database(dbName) if err != nil { return nil, nil, err } tbl, ok, err := db.GetTableInsensitive(ctx, tableName) - if err != nil { return nil, nil, err } else if !ok { @@ -177,20 +180,13 @@ func (d Databases) Table(ctx *Context, dbName string, tableName string) (Table, return tbl, db, nil } -func suggestSimilarTables(db Database, ctx *Context, tableName string) error { - tableNames, err := db.GetTableNames(ctx) - if err != nil { - return err - } - - similar := similartext.Find(tableNames, tableName) - return ErrTableNotFound.New(tableName + similar) -} +// TableAsOf returns the table in the given database with the given name, as it existed at the time given. The database +// named must support timed queries. +func (c *Catalog) TableAsOf(ctx *Context, dbName, tableName string, asOf interface{}) (Table, Database, error) { + c.mu.RLock() + defer c.mu.RUnlock() -// TableAsOf returns the table with the name given at the time given, if it existed. The database named must implement -// sql.VersionedDatabase or an error is returned. -func (d Databases) TableAsOf(ctx *Context, dbName string, tableName string, asOf interface{}) (Table, Database, error) { - db, err := d.Database(dbName) + db, err := c.Database(dbName) if err != nil { return nil, nil, err } @@ -211,6 +207,16 @@ func (d Databases) TableAsOf(ctx *Context, dbName string, tableName string, asOf return tbl, versionedDb, nil } +func suggestSimilarTables(db Database, ctx *Context, tableName string) error { + tableNames, err := db.GetTableNames(ctx) + if err != nil { + return err + } + + similar := similartext.Find(tableNames, tableName) + return ErrTableNotFound.New(tableName + similar) +} + func suggestSimilarTablesAsOf(db VersionedDatabase, ctx *Context, tableName string, time interface{}) error { tableNames, err := db.GetTableNamesAsOf(ctx, time) if err != nil { @@ -221,52 +227,61 @@ func suggestSimilarTablesAsOf(db VersionedDatabase, ctx *Context, tableName stri return ErrTableNotFound.New(tableName + similar) } -// LockTable adds a lock for the given table and session client. It is assumed -// the database is the current database in use. -func (c *Catalog) LockTable(ctx *Context, table string) { - id := ctx.ID() - db := ctx.GetCurrentDatabase() +// sliceProvider is a collection of Database. +type sliceProvider []Database - c.mu.Lock() - defer c.mu.Unlock() +var _ DatabaseProvider = &sliceProvider{} - if _, ok := c.locks[id]; !ok { - c.locks[id] = make(dbLocks) +// Database returns the Database with the given name if it exists. +func (d *sliceProvider) Database(name string) (Database, error) { + if len(*d) == 0 { + return nil, ErrDatabaseNotFound.New(name) } - if _, ok := c.locks[id][db]; !ok { - c.locks[id][db] = make(tableLocks) + name = strings.ToLower(name) + var dbNames []string + for _, db := range *d { + if strings.ToLower(db.Name()) == name { + return db, nil + } + dbNames = append(dbNames, db.Name()) } + similar := similartext.Find(dbNames, name) + return nil, ErrDatabaseNotFound.New(name + similar) +} - c.locks[id][db][table] = struct{}{} +// HasDatabase returns the Database with the given name if it exists. +func (d *sliceProvider) HasDatabase(name string) bool { + name = strings.ToLower(name) + for _, db := range *d { + if strings.ToLower(db.Name()) == name { + return true + } + } + return false } -// UnlockTables unlocks all tables for which the given session client has a -// lock. -func (c *Catalog) UnlockTables(ctx *Context, id uint32) error { - c.mu.Lock() - defer c.mu.Unlock() +// AllDatabases returns the Database with the given name if it exists. +func (d *sliceProvider) AllDatabases() []Database { + return *d +} - var errors []string - for db, tables := range c.locks[id] { - for t := range tables { - table, _, err := c.dbs.Table(ctx, db, t) - if err == nil { - if lockable, ok := table.(Lockable); ok { - if e := lockable.Unlock(ctx, id); e != nil { - errors = append(errors, e.Error()) - } - } - } else { - errors = append(errors, err.Error()) - } +// AddDatabase adds a new database. +func (d *sliceProvider) AddDatabase(db Database) { + *d = append(*d, db) +} + +// DropDatabase removes a database. +func (d *sliceProvider) DropDatabase(dbName string) { + idx := -1 + for i, db := range *d { + if db.Name() == dbName { + idx = i + break } } - delete(c.locks, id) - if len(errors) > 0 { - return fmt.Errorf("error unlocking tables for %d: %s", id, strings.Join(errors, ", ")) + if idx != -1 { + *d = append((*d)[:idx], (*d)[idx+1:]...) } - - return nil } diff --git a/sql/catalog_test.go b/sql/catalog_test.go index 14365b9850..1a67a391b9 100644 --- a/sql/catalog_test.go +++ b/sql/catalog_test.go @@ -27,7 +27,7 @@ import ( func TestAllDatabases(t *testing.T) { require := require.New(t) - var dbs = sql.Databases{ + var dbs = []sql.Database{ memory.NewDatabase("a"), memory.NewDatabase("b"), memory.NewDatabase("c"), diff --git a/sql/core.go b/sql/core.go index 032c79ddd7..1e4dde7c12 100644 --- a/sql/core.go +++ b/sql/core.go @@ -477,6 +477,15 @@ type RowUpdater interface { Closer } +// DatabaseProvider is a collection of Database. +type DatabaseProvider interface { + Database(name string) (Database, error) + HasDatabase(name string) bool + AllDatabases() []Database + AddDatabase(db Database) + DropDatabase(name string) +} + // Database represents the database. type Database interface { Nameable diff --git a/sql/index_registry.go b/sql/index_registry.go index 7d5021afe4..4ec8df1609 100644 --- a/sql/index_registry.go +++ b/sql/index_registry.go @@ -91,7 +91,7 @@ func (r *IndexRegistry) RegisterIndexDriver(driver IndexDriver) { // LoadIndexes creates load functions for all indexes for all dbs, tables and drivers. These functions are called // as needed by the query -func (r *IndexRegistry) LoadIndexes(ctx *Context, dbs Databases) error { +func (r *IndexRegistry) LoadIndexes(ctx *Context, dbs []Database) error { r.driversMut.RLock() defer r.driversMut.RUnlock() r.mut.Lock() diff --git a/sql/index_registry_test.go b/sql/index_registry_test.go index 96a131001b..dfd49c9deb 100644 --- a/sql/index_registry_test.go +++ b/sql/index_registry_test.go @@ -283,7 +283,7 @@ func TestLoadIndexes(t *testing.T) { registry.RegisterIndexDriver(d1) registry.RegisterIndexDriver(d2) - dbs := Databases{ + dbs := []Database{ dummyDB{ name: "db1", tables: map[string]Table{ @@ -331,7 +331,7 @@ func TestLoadOutdatedIndexes(t *testing.T) { registry := NewIndexRegistry() registry.RegisterIndexDriver(d) - dbs := Databases{ + dbs := []Database{ dummyDB{ name: "db1", tables: map[string]Table{