/
migrate.go
119 lines (108 loc) · 2.89 KB
/
migrate.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package db
import (
"bytes"
"context"
"embed"
"fmt"
"sort"
"time"
"github.com/jackc/pgx/v5"
)
// migrations holds the *.sql files from the migrations/ directory, embedded
// into the binary at build time. Migrate lists and reads migration files from
// this filesystem.
//
//go:embed migrations/*.sql
var migrations embed.FS
// Migrate runs database migrations.
//
// It applies every embedded migrations/*.sql file that is not yet recorded in
// the migrations table, in lexical filename order. It creates that table first
// if it does not exist. Reading the applied set and running the pending files
// both happen inside UseGlobalLock(…, "migrations", …), presumably so that
// concurrent callers do not apply the same migration twice.
//
// For each pending file, the statements returned by splitMigrationStatements
// are queued in one pgx batch. The INSERT that records the filename is queued
// last in that same batch. Progress is printed to stdout.
//
// Every database call uses its own context with a two-second timeout. All of
// those contexts are cancelled when Migrate returns.
func Migrate() error {
	// Collect the embedded migration filenames in a deterministic order.
	entries, err := migrations.ReadDir("migrations")
	if err != nil {
		return fmt.Errorf("reading embedded migrations: %w", err)
	}
	filenames := make([]string, 0, len(entries))
	for _, entry := range entries {
		filenames = append(filenames, entry.Name())
	}
	sort.Strings(filenames)

	// newCtx hands out a fresh timeout context for each database call. The
	// cancel funcs are collected and all run on return, including early returns.
	// NOTE(review): the same limit also bounds each whole migration batch and
	// the context given to UseGlobalLock. Long migrations may need more — confirm.
	const callTimeout = 2 * time.Second
	var cancels []context.CancelFunc
	defer func() {
		for _, cancel := range cancels {
			cancel()
		}
	}()
	newCtx := func() context.Context {
		ctx, cancel := context.WithTimeout(context.Background(), callTimeout)
		cancels = append(cancels, cancel)
		return ctx
	}

	// Ensure the bookkeeping table exists before taking the lock.
	c := dbConn()
	if _, err := c.Exec(newCtx(), "CREATE TABLE IF NOT EXISTS migrations (filename TEXT PRIMARY KEY)"); err != nil {
		return fmt.Errorf("creating migrations table: %w", err)
	}

	// pgx reports "no result" when more results are read from a batch than
	// the server returned. The server can return fewer results than queued
	// statements (e.g. for comment-only pieces), so running out of results is
	// not treated as a failure.
	// NOTE(review): this matches pgx's error text, not a typed error —
	// re-check it when upgrading pgx.
	const noResult = "no result"

	_, lockErr := UseGlobalLock(newCtx(), "migrations", func() (struct{}, error) {
		// Load the set of migrations that have already been applied.
		rows, err := c.Query(newCtx(), "SELECT filename FROM migrations")
		if err != nil {
			return struct{}{}, fmt.Errorf("listing applied migrations: %w", err)
		}
		defer rows.Close()
		applied := make(map[string]struct{})
		for rows.Next() {
			var filename string
			if err := rows.Scan(&filename); err != nil {
				return struct{}{}, fmt.Errorf("scanning applied migration: %w", err)
			}
			applied[filename] = struct{}{}
		}
		// rows.Next returns false on a read error as well as at the end.
		// Without this check a failed read would make applied migrations look
		// pending and they would be run again.
		if err := rows.Err(); err != nil {
			return struct{}{}, fmt.Errorf("listing applied migrations: %w", err)
		}
		// Release the rows now rather than at return, before issuing batches.
		rows.Close()

		// Run every pending migration in filename order.
		for _, filename := range filenames {
			if _, ok := applied[filename]; ok {
				fmt.Println("[db] Migration", filename, "already ran - skipping!")
				continue
			}

			migrationSQL, err := migrations.ReadFile("migrations/" + filename)
			if err != nil {
				return struct{}{}, fmt.Errorf("reading migration %s: %w", filename, err)
			}
			statements := splitMigrationStatements(migrationSQL)

			// Queue the migration's statements followed by the INSERT that
			// records it, in a single batch.
			batch := &pgx.Batch{}
			for _, stmt := range statements {
				batch.Queue(string(stmt))
			}
			batch.Queue("INSERT INTO migrations (filename) VALUES ($1)", filename)

			fmt.Print("[db] Running migration ", filename, "...")
			results := c.SendBatch(newCtx(), batch)
			for i := 0; i < len(statements)+1; i++ {
				if _, err := results.Exec(); err != nil && err.Error() != noResult {
					// The Exec error is the one worth reporting; the Close
					// error on this path is intentionally discarded.
					_ = results.Close()
					return struct{}{}, fmt.Errorf("running migration %s: %w", filename, err)
				}
			}
			// Close can surface an error reported after the last result was
			// read. Ignoring it would report a failed migration as a success.
			if err := results.Close(); err != nil && err.Error() != noResult {
				return struct{}{}, fmt.Errorf("finishing migration %s: %w", filename, err)
			}
			fmt.Println(" success!")
		}

		fmt.Println("[db] All migrations ran!")
		return struct{}{}, nil
	})
	return lockErr
}

// splitMigrationStatements returns the statements to queue for one migration
// file.
//
// A file that begins with "-- nosplit" is returned whole, as a single
// statement. Any other file is split on every ";" and whitespace-only pieces
// (such as the newline after the final ";") are dropped. The split is purely
// textual: a ";" inside a string literal, comment or function body also
// splits, so files containing such text must opt out with "-- nosplit".
func splitMigrationStatements(sql []byte) [][]byte {
	if bytes.HasPrefix(sql, []byte("-- nosplit")) {
		return [][]byte{sql}
	}
	var statements [][]byte
	for _, part := range bytes.Split(sql, []byte(";")) {
		if len(bytes.TrimSpace(part)) == 0 {
			continue
		}
		statements = append(statements, part)
	}
	return statements
}