forked from kolide/goose
-
Notifications
You must be signed in to change notification settings - Fork 2
/
migrate.go
234 lines (190 loc) · 5.58 KB
/
migrate.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
package goose
import (
"database/sql"
"errors"
"fmt"
"log"
"path/filepath"
"runtime"
"sort"
)
// Package-level sentinel errors and the default client.
var (
// ErrNoCurrentVersion is returned by Migrations.Current when no migration
// in the list has the requested version.
ErrNoCurrentVersion = errors.New("no current version found")
// ErrNoNextVersion is returned by Migrations.Next when no migration in the
// list has a version greater than the one given, and by Migrations.Last
// when the list is empty.
ErrNoNextVersion = errors.New("no next version found")
// globalGoose is the package-level Client on which the legacy
// package-level AddMigration function registers migrations. It is
// configured with the "goose_db_version" table name and the Postgres
// dialect.
globalGoose = &Client{
TableName: "goose_db_version",
Dialect: &PostgresDialect{},
}
)
// Migrations is a list of migrations. Its Len, Swap and Less methods
// implement sort.Interface, ordering entries by ascending Version.
type Migrations []*Migration
// The three methods below implement sort.Interface so that a Migrations
// value can be handed to sort.Sort.

// Len reports how many migrations the list holds.
func (ms Migrations) Len() int {
	return len(ms)
}

// Swap exchanges the migrations at positions i and j.
func (ms Migrations) Swap(i, j int) {
	ms[j], ms[i] = ms[i], ms[j]
}

// Less reports whether the migration at i has a lower version than the one
// at j. Two entries sharing a version cannot be ordered, so that case is
// treated as fatal: both sources are logged and log.Fatalf ends the process.
func (ms Migrations) Less(i, j int) bool {
	left, right := ms[i], ms[j]
	if left.Version == right.Version {
		log.Fatalf("goose: duplicate version %v detected:\n%v\n%v", left.Version, left.Source, right.Source)
	}
	return left.Version < right.Version
}
// Current returns the migration whose Version equals current, or
// ErrNoCurrentVersion when the list contains no such migration.
func (ms Migrations) Current(current int64) (*Migration, error) {
	for _, m := range ms {
		if m.Version != current {
			continue
		}
		return m, nil
	}
	return nil, ErrNoCurrentVersion
}
// Next returns the first migration, in list order, whose Version is greater
// than current, or ErrNoNextVersion when there is none. The result is the
// lowest such version only if the list is already sorted ascending.
func (ms Migrations) Next(current int64) (*Migration, error) {
	for idx := 0; idx < len(ms); idx++ {
		if candidate := ms[idx]; candidate.Version > current {
			return candidate, nil
		}
	}
	return nil, ErrNoNextVersion
}
// Last returns the final migration in the list, or ErrNoNextVersion when
// the list is empty.
func (ms Migrations) Last() (*Migration, error) {
	if n := len(ms); n > 0 {
		return ms[n-1], nil
	}
	return nil, ErrNoNextVersion
}
// String formats every migration with fmt.Sprintln, so each one ends in a
// newline, and returns the pieces concatenated in list order. An empty list
// yields the empty string.
func (ms Migrations) String() string {
	var out string
	for idx := range ms {
		out += fmt.Sprintln(ms[idx])
	}
	return out
}
// AddMigration registers a Go migration, made of an up and a down function,
// on the client.
//
// The migration's Source is the file of the function that called
// AddMigration, and its Version is whatever NumericComponent derives from
// that file name. Next and Previous start at -1.
func (c *Client) AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) {
	// Skip one stack frame so the recorded file is the caller's, not this one.
	_, source, _, _ := runtime.Caller(1)

	// The parse error is dropped here because this method has no error
	// return; a bad source name leaves Version at NumericComponent's
	// first return value.
	version, _ := NumericComponent(source)

	c.Migrations = append(c.Migrations, &Migration{
		Version:  version,
		Next:     -1,
		Previous: -1,
		UpFn:     up,
		DownFn:   down,
		Source:   source,
	})
}
// AddMigration registers a Go migration on the package-level default client.
// It exists for legacy code that uses goose through package-level functions;
// calling the AddMigration method on a Client is the preferred way.
func AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) {
	// This deliberately does not delegate to globalGoose.AddMigration: the
	// extra call would add a stack frame, and runtime.Caller(1) would then
	// report this file instead of the file that registered the migration.
	_, source, _, _ := runtime.Caller(1)

	// The parse error is dropped here because this function has no error
	// return.
	version, _ := NumericComponent(source)

	globalGoose.Migrations = append(globalGoose.Migrations, &Migration{
		Version:  version,
		Next:     -1,
		Previous: -1,
		UpFn:     up,
		DownFn:   down,
		Source:   source,
	})
}
// collectMigrations gathers the migrations that lie between current and
// target (as decided by versionFilter) from two places: the *.sql files
// directly inside dirpath, and the Go migrations registered on the client.
// The combined list is returned sorted and linked by
// sortAndConnectMigrations.
//
// An empty dirpath skips the file scan entirely. Any glob failure, or any
// source whose version NumericComponent cannot determine, aborts the
// collection and is returned as the error.
func (c *Client) collectMigrations(dirpath string, current, target int64) (Migrations, error) {
	var collected Migrations

	// SQL files are only considered when a directory was explicitly given.
	if dirpath != "" {
		paths, err := filepath.Glob(dirpath + "/*.sql")
		if err != nil {
			return nil, err
		}
		for _, path := range paths {
			version, err := NumericComponent(path)
			if err != nil {
				return nil, err
			}
			if !versionFilter(version, current, target) {
				continue
			}
			collected = append(collected, &Migration{
				Version:  version,
				Next:     -1,
				Previous: -1,
				Source:   path,
			})
		}
	}

	// Registered Go migrations are appended by pointer, so the linking done
	// below mutates the same values that c.Migrations holds.
	for _, registered := range c.Migrations {
		version, err := NumericComponent(registered.Source)
		if err != nil {
			return nil, err
		}
		if !versionFilter(version, current, target) {
			continue
		}
		collected = append(collected, registered)
	}

	return sortAndConnectMigrations(collected), nil
}
// sortAndConnectMigrations sorts migrations in place with sort.Sort and then
// links neighbours: each entry's Previous becomes the version of the entry
// before it and its Next the version of the entry after it. The first
// entry's Previous and the last entry's Next are set to -1. The same
// (sorted) slice is returned.
//
// Both links are rewritten for every entry. The Migration values are
// pointers that may be reused across calls, so a Next left over from an
// earlier call with a different set of migrations must not survive on the
// entry that is now last.
func sortAndConnectMigrations(migrations Migrations) Migrations {
	sort.Sort(migrations)

	for i, m := range migrations {
		m.Previous = -1
		// Reset unconditionally; the following iteration overwrites this
		// unless m turns out to be the last entry.
		m.Next = -1
		if i > 0 {
			before := migrations[i-1]
			m.Previous = before.Version
			before.Next = m.Version
		}
	}
	return migrations
}
// versionFilter reports whether version v has to be run to move the database
// from current to target.
//
// When target is above current the accepted range is (current, target];
// when target is below current it is (target, current]. When the two are
// equal there is nothing to run and the result is always false.
func versionFilter(v, current, target int64) bool {
	switch {
	case current < target:
		// Moving up: everything after current, up to and including target.
		return current < v && v <= target
	case current > target:
		// Moving down: current itself, down to but excluding target.
		return target < v && v <= current
	default:
		return false
	}
}
// GetDBVersion returns the current migration version recorded in the
// database's version table.
//
// If the dialect's version query fails, the table is taken to be missing:
// createVersionTable is called and the result is version 0 together with
// whatever error that call produced (nil on success).
//
// Otherwise the rows are read in the order the query yields them. The most
// recent record for each migration says whether it was applied or rolled
// back, so only the first row seen for a given version counts; later rows
// for a version already seen as not applied are skipped. The first version
// found applied is the current version.
//
// Errors:
//   - a row that cannot be scanned returns the scan error;
//   - a failure while iterating returns rows.Err();
//   - a table with no applied version returns ErrNoCurrentVersion.
func (c *Client) GetDBVersion(db *sql.DB) (int64, error) {
	rows, err := c.Dialect.dbVersionQuery(db, c.TableName)
	if err != nil {
		return 0, c.createVersionTable(db)
	}
	defer rows.Close()

	// Versions whose first-seen record says "not applied".
	rolledBack := make(map[int64]struct{})

	for rows.Next() {
		var row MigrationRecord
		if err := rows.Scan(&row.VersionId, &row.IsApplied); err != nil {
			// Report the failure to the caller instead of ending the
			// process from inside library code.
			return 0, err
		}

		if _, seen := rolledBack[row.VersionId]; seen {
			continue
		}

		// If the version has been applied we're done.
		if row.IsApplied {
			return row.VersionId, nil
		}

		// The latest record for this version says it is not applied.
		rolledBack[row.VersionId] = struct{}{}
	}

	// rows.Next also returns false when iteration fails, so tell a real
	// error apart from simply running out of rows.
	if err := rows.Err(); err != nil {
		return 0, err
	}

	// The table exists but holds no applied version.
	return 0, ErrNoCurrentVersion
}
// createVersionTable creates the version table named c.TableName and inserts
// the initial record (version 0, applied) into it. Both statements come from
// the client's dialect and run in one transaction: if either fails the
// transaction is rolled back and that error is returned; otherwise the
// result of the commit is returned.
func (c *Client) createVersionTable(db *sql.DB) error {
	tx, err := db.Begin()
	if err != nil {
		return err
	}

	const (
		initialVersion = 0
		initialApplied = true
	)
	dialect := c.Dialect

	_, err = tx.Exec(dialect.createVersionTableSql(c.TableName))
	if err == nil {
		_, err = tx.Exec(dialect.insertVersionSql(c.TableName), initialVersion, initialApplied)
	}
	if err != nil {
		// The rollback result is ignored; the statement failure is the
		// error the caller gets.
		tx.Rollback()
		return err
	}
	return tx.Commit()
}