From 0741616d2ea92fb372f8d5469d9256f3c856ffc1 Mon Sep 17 00:00:00 2001 From: Dave Jeffrey Date: Thu, 11 Jun 2015 11:11:28 +0100 Subject: [PATCH] Don't load in all drivers by default #40 Requires activating drivers with a _ style import, e.g. import "_ github.com/mattes/migrate/driver/postgres" --- README.md | 3 +++ driver/bash/bash.go | 5 ++++ driver/cassandra/cassandra.go | 5 ++++ driver/driver.go | 50 ++++++++--------------------------- driver/mysql/mysql.go | 5 ++++ driver/postgres/postgres.go | 5 ++++ driver/registry/registry.go | 20 ++++++++++++++ driver/sqlite3/sqlite3.go | 5 ++++ migrate/migrate_test.go | 2 ++ 9 files changed, 61 insertions(+), 39 deletions(-) create mode 100644 driver/registry/registry.go diff --git a/README.md b/README.md index ae09afc7..32873cba 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,9 @@ See GoDoc here: http://godoc.org/github.com/mattes/migrate/migrate ```go import "github.com/mattes/migrate/migrate" +// Import any required drivers so that they are registered and available +import _ "github.com/mattes/migrate/drivers/mysql" + // use synchronous versions of migration functions ... allErrors, ok := migrate.UpSync("driver://url", "./path") if !ok { diff --git a/driver/bash/bash.go b/driver/bash/bash.go index 156371e3..ed5ffa00 100644 --- a/driver/bash/bash.go +++ b/driver/bash/bash.go @@ -2,6 +2,7 @@ package bash import ( + "github.com/mattes/migrate/driver/registry" "github.com/mattes/migrate/file" _ "github.com/mattes/migrate/migrate/direction" ) @@ -30,3 +31,7 @@ func (driver *Driver) Migrate(f file.File, pipe chan interface{}) { func (driver *Driver) Version() (uint64, error) { return uint64(0), nil } + +func init() { + registry.RegisterDriver("bash", Driver{}) +} diff --git a/driver/cassandra/cassandra.go b/driver/cassandra/cassandra.go index 8d72f43e..bb4c75c1 100644 --- a/driver/cassandra/cassandra.go +++ b/driver/cassandra/cassandra.go @@ -4,6 +4,7 @@ package cassandra import ( "fmt" "github.com/gocql/gocql" + "github.com/mattes/migrate/driver/registry" "github.com/mattes/migrate/file" "github.com/mattes/migrate/migrate/direction" "net/url" @@ -153,3 +154,7 @@ func (driver *Driver) Version() (uint64, error) { err := driver.session.Query("SELECT version FROM "+tableName+" WHERE versionRow = ?", versionRow).Scan(&version) return uint64(version) - 1, err } + +func init() { + registry.RegisterDriver("cassandra", Driver{}) +} diff --git a/driver/driver.go b/driver/driver.go index 9c00074b..69f6fe6e 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -5,12 +5,9 @@ import ( "errors" "fmt" neturl "net/url" // alias to allow `url string` func signature in New + "reflect" - "github.com/mattes/migrate/driver/bash" - "github.com/mattes/migrate/driver/cassandra" - "github.com/mattes/migrate/driver/mysql" - "github.com/mattes/migrate/driver/postgres" - "github.com/mattes/migrate/driver/sqlite3" + "github.com/mattes/migrate/driver/registry" "github.com/mattes/migrate/file" ) @@ -47,51 +44,26 @@ func New(url string) (Driver, error) { return nil, err } - switch u.Scheme { - case "postgres": - d := &postgres.Driver{} - verifyFilenameExtension("postgres", d) - if err := d.Initialize(url); err != nil { - return nil, err - } - return d, nil - - case "mysql": - d := &mysql.Driver{} - verifyFilenameExtension("mysql", d) - if err := d.Initialize(url); err != nil { + driver := registry.GetDriver(u.Scheme) + if driver != nil { + blankDriver := reflect.New(reflect.TypeOf(driver)).Interface() + d, ok := blankDriver.(Driver) + if !ok { + err := errors.New(fmt.Sprintf("Driver '%s' does not implement the Driver interface")) return nil, err } - return d, nil - - case "bash": - d := &bash.Driver{} - verifyFilenameExtension("bash", d) + verifyFilenameExtension(u.Scheme, d) if err := d.Initialize(url); err != nil { return nil, err } - return d, nil - case "cassandra": - d := &cassandra.Driver{} - verifyFilenameExtension("cassanda", d) - if err := d.Initialize(url); err != nil { - return nil, err - } - return d, nil - case "sqlite3": - d := &sqlite3.Driver{} - verifyFilenameExtension("sqlite3", d) - if err := d.Initialize(url); err != nil { - return nil, err - } return d, nil - default: + } else { return nil, errors.New(fmt.Sprintf("Driver '%s' not found.", u.Scheme)) } } -// verifyFilenameExtension panics if the drivers filename extension +// verifyFilenameExtension panics if the driver's filename extension // is not correct or empty. func verifyFilenameExtension(driverName string, d Driver) { f := d.FilenameExtension() diff --git a/driver/mysql/mysql.go b/driver/mysql/mysql.go index ef026ee8..2a3fec64 100644 --- a/driver/mysql/mysql.go +++ b/driver/mysql/mysql.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "github.com/go-sql-driver/mysql" + "github.com/mattes/migrate/driver/registry" "github.com/mattes/migrate/file" "github.com/mattes/migrate/migrate/direction" "regexp" @@ -177,3 +178,7 @@ func (driver *Driver) Version() (uint64, error) { return version, nil } } + +func init() { + registry.RegisterDriver("mysql", Driver{}) +} diff --git a/driver/postgres/postgres.go b/driver/postgres/postgres.go index bacd2ee6..af386950 100644 --- a/driver/postgres/postgres.go +++ b/driver/postgres/postgres.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "github.com/lib/pq" + "github.com/mattes/migrate/driver/registry" "github.com/mattes/migrate/file" "github.com/mattes/migrate/migrate/direction" "strconv" @@ -119,3 +120,7 @@ func (driver *Driver) Version() (uint64, error) { return version, nil } } + +func init() { + registry.RegisterDriver("postgres", Driver{}) +} diff --git a/driver/registry/registry.go b/driver/registry/registry.go new file mode 100644 index 00000000..cbb3bb95 --- /dev/null +++ b/driver/registry/registry.go @@ -0,0 +1,20 @@ +// Package registry maintains a map of imported and available drivers +package registry + +var driverRegistry map[string]interface{} + +// Registers a driver so it can be created from its name. Drivers should +// call this from an init() function so that they registers themselvse on +// import +func RegisterDriver(name string, driver interface{}) { + driverRegistry[name] = driver +} + +// Retrieves a registered driver by name +func GetDriver(name string) interface{} { + return driverRegistry[name] +} + +func init() { + driverRegistry = make(map[string]interface{}) +} diff --git a/driver/sqlite3/sqlite3.go b/driver/sqlite3/sqlite3.go index a33b7e5c..c009b7b1 100644 --- a/driver/sqlite3/sqlite3.go +++ b/driver/sqlite3/sqlite3.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "github.com/mattes/migrate/driver/registry" "github.com/mattes/migrate/file" "github.com/mattes/migrate/migrate/direction" "github.com/mattn/go-sqlite3" @@ -123,3 +124,7 @@ func (driver *Driver) Version() (uint64, error) { return version, nil } } + +func init() { + registry.RegisterDriver("sqlite3", Driver{}) +} diff --git a/migrate/migrate_test.go b/migrate/migrate_test.go index e2e0793c..4fe1140d 100644 --- a/migrate/migrate_test.go +++ b/migrate/migrate_test.go @@ -3,6 +3,8 @@ package migrate import ( "io/ioutil" "testing" + // Ensure imports for each driver we wish to test + _ "github.com/mattes/migrate/driver/postgres" ) // Add Driver URLs here to test basic Up, Down, .. functions.