Skip to content
This repository has been archived by the owner on Jun 28, 2018. It is now read-only.

Avoid version collisions #58

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ sudo: required
go:
- 1.4
- 1.5
- 1.6

services:
- docker
Expand Down
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,15 @@ go migrate.Up(pipe, "driver://url", "./path")
The format of migration files looks like this:

```
001_initial_plan_to_do_sth.up.sql # up migration instructions
001_initial_plan_to_do_sth.down.sql # down migration instructions
002_xxx.up.sql
002_xxx.down.sql
20060102150405_initial_plan_to_do_sth.up.sql # up migration instructions
20060102150405_initial_plan_to_do_sth.down.sql # down migration instructions
20060102150506_xxx.up.sql
20060102150506_xxx.down.sql
...
```

Why two files? This way you could still do sth like
``psql -f ./db/migrations/001_initial_plan_to_do_sth.up.sql`` and there is no
``psql -f ./db/migrations/20060102150405_initial_plan_to_do_sth.up.sql`` and there is no
need for any custom markup language to divide up and down migrations. Please note
that the filename extension depends on the driver.

Expand Down
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ go: &go
- $GOPATH:/go
go-test:
<<: *go
command: sh -c 'go get -t -v ./... && go test -v ./...'
command: sh -c 'go get -t -v ./... && go test -p=1 -v ./...'
links:
- postgres
- mysql
Expand Down
8 changes: 6 additions & 2 deletions driver/bash/bash.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@ func (driver *Driver) Migrate(f file.File, pipe chan interface{}) {
return
}

func (driver *Driver) Version() (uint64, error) {
return uint64(0), nil
func (driver *Driver) Version() (file.Version, error) {
return file.Version(0), nil
}

func (driver *Driver) Versions() (file.Versions, error) {
return file.Versions{0}, nil
}

func init() {
Expand Down
87 changes: 34 additions & 53 deletions driver/cassandra/cassandra.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package cassandra
import (
"fmt"
"net/url"
"sort"
"strconv"
"strings"
"time"
Expand All @@ -19,25 +20,7 @@ type Driver struct {
}

const (
tableName = "schema_migrations"
versionRow = 1
)

type counterStmt bool

func (c counterStmt) String() string {
sign := ""
if bool(c) {
sign = "+"
} else {
sign = "-"
}
return "UPDATE " + tableName + " SET version = version " + sign + " 1 where versionRow = ?"
}

const (
up counterStmt = true
down counterStmt = false
tableName = "schema_migrations"
)

// Cassandra Driver URL format:
Expand Down Expand Up @@ -78,14 +61,14 @@ func (driver *Driver) Initialize(rawurl string) error {
}

driver.session, err = cluster.CreateSession()

if err != nil {
return err
}

if err := driver.ensureVersionTableExists(); err != nil {
return err
}

return nil
}

Expand All @@ -95,59 +78,43 @@ func (driver *Driver) Close() error {
}

func (driver *Driver) ensureVersionTableExists() error {
err := driver.session.Query("CREATE TABLE IF NOT EXISTS " + tableName + " (version counter, versionRow bigint primary key);").Exec()
if err != nil {
return err
}

_, err = driver.Version()
if err != nil {
driver.session.Query(up.String(), versionRow).Exec()
}

return nil
err := driver.session.Query("CREATE TABLE IF NOT EXISTS " + tableName + " (version bigint primary key);").Exec()
return err
}

func (driver *Driver) FilenameExtension() string {
return "cql"
}

func (driver *Driver) version(d direction.Direction, invert bool) error {
var stmt counterStmt
switch d {
case direction.Up:
stmt = up
case direction.Down:
stmt = down
}
if invert {
stmt = !stmt
}
return driver.session.Query(stmt.String(), versionRow).Exec()
}

func (driver *Driver) Migrate(f file.File, pipe chan interface{}) {
var err error
defer func() {
if err != nil {
// Invert version direction if we couldn't apply the changes for some reason.
if err := driver.version(f.Direction, true); err != nil {
pipe <- err
if errRollback := driver.session.Query("DELETE FROM "+tableName+" WHERE version = ?", f.Version).Exec(); errRollback != nil {
pipe <- errRollback
}
pipe <- err
}
close(pipe)
}()

pipe <- f
if err = driver.version(f.Direction, false); err != nil {
return
}

if err = f.ReadContent(); err != nil {
return
}

if f.Direction == direction.Up {
if err = driver.session.Query("INSERT INTO "+tableName+" (version) VALUES (?)", f.Version).Exec(); err != nil {
return
}
} else if f.Direction == direction.Down {
if err = driver.session.Query("DELETE FROM "+tableName+" WHERE version = ?", f.Version).Exec(); err != nil {
return
}
}

for _, query := range strings.Split(string(f.Content), ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
Expand All @@ -160,10 +127,24 @@ func (driver *Driver) Migrate(f file.File, pipe chan interface{}) {
}
}

func (driver *Driver) Version() (uint64, error) {
func (driver *Driver) Version() (file.Version, error) {
versions, err := driver.Versions()
if len(versions) == 0 {
return 0, err
}
return versions[0], err
}

func (driver *Driver) Versions() (file.Versions, error) {
versions := file.Versions{}
iter := driver.session.Query("SELECT version FROM " + tableName).Iter()
var version int64
err := driver.session.Query("SELECT version FROM "+tableName+" WHERE versionRow = ?", versionRow).Scan(&version)
return uint64(version) - 1, err
for iter.Scan(&version) {
versions = append(versions, file.Version(version))
}
err := iter.Close()
sort.Sort(sort.Reverse(versions))
return versions, err
}

func init() {
Expand Down
62 changes: 52 additions & 10 deletions driver/cassandra/cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cassandra
import (
"net/url"
"os"
"reflect"
"testing"
"time"

Expand All @@ -17,7 +18,7 @@ func TestMigrate(t *testing.T) {

host := os.Getenv("CASSANDRA_PORT_9042_TCP_ADDR")
port := os.Getenv("CASSANDRA_PORT_9042_TCP_PORT")
driverUrl := "cassandra://" + host + ":" + port + "/system"
driverUrl := "cassandra://" + host + ":" + port + "/system?protocol=4"

// prepare a clean test database
u, err := url.Parse(driverUrl)
Expand All @@ -29,19 +30,20 @@ func TestMigrate(t *testing.T) {
cluster.Keyspace = u.Path[1:len(u.Path)]
cluster.Consistency = gocql.All
cluster.Timeout = 1 * time.Minute
cluster.ProtoVersion = 4

session, err = cluster.CreateSession()

if err != nil {
t.Fatal(err)
}

if err := session.Query(`CREATE KEYSPACE IF NOT EXISTS migrate WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1};`).Exec(); err != nil {
if err := resetKeySpace(session); err != nil {
t.Fatal(err)
}

cluster.Keyspace = "migrate"
session, err = cluster.CreateSession()
driverUrl = "cassandra://" + host + ":" + port + "/migrate"
driverUrl = "cassandra://" + host + ":" + port + "/migrate?protocol=4"

d := &Driver{}
if err := d.Initialize(driverUrl); err != nil {
Expand All @@ -51,8 +53,8 @@ func TestMigrate(t *testing.T) {
files := []file.File{
{
Path: "/foobar",
FileName: "001_foobar.up.sql",
Version: 1,
FileName: "20060102150405_foobar.up.sql",
Version: 20060102150405,
Name: "foobar",
Direction: direction.Up,
Content: []byte(`
Expand All @@ -66,8 +68,8 @@ func TestMigrate(t *testing.T) {
},
{
Path: "/foobar",
FileName: "002_foobar.down.sql",
Version: 1,
FileName: "20060102150405_foobar.down.sql",
Version: 20060102150405,
Name: "foobar",
Direction: direction.Down,
Content: []byte(`
Expand All @@ -76,8 +78,8 @@ func TestMigrate(t *testing.T) {
},
{
Path: "/foobar",
FileName: "002_foobar.up.sql",
Version: 2,
FileName: "20060102150406_foobar.up.sql",
Version: 20060102150406,
Name: "foobar",
Direction: direction.Up,
Content: []byte(`
Expand All @@ -95,6 +97,26 @@ func TestMigrate(t *testing.T) {
t.Fatal(errs)
}

version, err := d.Version()
if err != nil {
t.Fatal(err)
}

if version != 20060102150405 {
t.Errorf("Expected version to be: %d, got: %d", 20060102150405, version)
}

// Check versions applied in DB
expectedVersions := file.Versions{20060102150405}
versions, err := d.Versions()
if err != nil {
t.Errorf("Could not fetch versions: %s", err)
}

if !reflect.DeepEqual(versions, expectedVersions) {
t.Errorf("Expected versions to be: %v, got: %v", expectedVersions, versions)
}

pipe = pipep.New()
go d.Migrate(files[1], pipe)
errs = pipep.ReadErrors(pipe)
Expand All @@ -109,8 +131,28 @@ func TestMigrate(t *testing.T) {
t.Error("Expected test case to fail")
}

// Check versions applied in DB
expectedVersions = file.Versions{}
versions, err = d.Versions()
if err != nil {
t.Errorf("Could not fetch versions: %s", err)
}

if !reflect.DeepEqual(versions, expectedVersions) {
t.Errorf("Expected versions to be: %v, got: %v", expectedVersions, versions)
}

if err := resetKeySpace(session); err != nil {
t.Fatal(err)
}

if err := d.Close(); err != nil {
t.Fatal(err)
}

}

func resetKeySpace(session *gocql.Session) error {
session.Query(`DROP KEYSPACE migrate;`).Exec()
return session.Query(`CREATE KEYSPACE IF NOT EXISTS migrate WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1};`).Exec()
}
5 changes: 4 additions & 1 deletion driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ type Driver interface {
Migrate(file file.File, pipe chan interface{})

// Version returns the current migration version.
Version() (uint64, error)
Version() (file.Version, error)

// Versions returns the list of applied migrations
Versions() (file.Versions, error)
}

// New returns Driver and calls Initialize on it
Expand Down
26 changes: 23 additions & 3 deletions driver/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (driver *Driver) Close() error {
}

func (driver *Driver) ensureVersionTableExists() error {
_, err := driver.db.Exec("CREATE TABLE IF NOT EXISTS " + tableName + " (version int not null primary key);")
_, err := driver.db.Exec("CREATE TABLE IF NOT EXISTS " + tableName + " (version bigint not null primary key);")

if _, isWarn := err.(mysql.MySQLWarnings); err != nil && !isWarn {
return err
Expand Down Expand Up @@ -167,8 +167,8 @@ func (driver *Driver) Migrate(f file.File, pipe chan interface{}) {
}
}

func (driver *Driver) Version() (uint64, error) {
var version uint64
func (driver *Driver) Version() (file.Version, error) {
var version file.Version
err := driver.db.QueryRow("SELECT version FROM " + tableName + " ORDER BY version DESC").Scan(&version)
switch {
case err == sql.ErrNoRows:
Expand All @@ -180,6 +180,26 @@ func (driver *Driver) Version() (uint64, error) {
}
}

func (driver *Driver) Versions() (file.Versions, error) {
versions := file.Versions{}

rows, err := driver.db.Query("SELECT version FROM " + tableName + " ORDER BY version DESC")
if err != nil {
return versions, err
}
defer rows.Close()
for rows.Next() {
var version file.Version
err := rows.Scan(&version)
if err != nil {
return versions, err
}
versions = append(versions, version)
}
err = rows.Err()
return versions, err
}

func init() {
driver.RegisterDriver("mysql", &Driver{})
}