-
Notifications
You must be signed in to change notification settings - Fork 532
/
schemas.go
203 lines (165 loc) · 4.9 KB
/
schemas.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
package postgres
import (
"bytes"
"fmt"
"io/fs"
"path/filepath"
"sort"
"strconv"
"strings"
"github.com/ignite/cli/v29/ignite/pkg/errors"
)
// SchemasDir defines the name for the embedded schema directory.
// WalkFrom walks this directory inside the Schemas file system to find the
// numerically named schema files.
const SchemasDir = "schemas"
const (
// defaultSchemasTableName is the default name of the table that records the
// applied schema versions; NewSchemas turns it into "<namespace>_schema"
// when a namespace is given.
defaultSchemasTableName = "schema"
// sqlBeginTX and sqlCommitTX are the transaction control commands written by
// ScriptBuilder.BeginTX and ScriptBuilder.CommitTX.
sqlBeginTX = "BEGIN"
sqlCommitTX = "COMMIT"
// sqlCommandSuffix terminates SQL commands; ScriptBuilder.AppendCommand adds
// it to commands that do not already end with it.
sqlCommandSuffix = ";"
// tplSchemaInsertSQL records an applied schema version.
// Format arguments: the schemas table name (%s) and the version (%d).
tplSchemaInsertSQL = `
INSERT INTO %s(version)
VALUES(%d)
`
// tplSchemaTableDDL creates the schemas table when it does not exist.
// The single format argument (%[1]v) is the table name, reused to name the
// primary key constraint.
// NOTE(review): version is a SMALLINT (max 32767 in PostgreSQL) while schema
// versions are handled as uint64 in Go, so larger version numbers would fail
// on insert — confirm this limit is intended.
tplSchemaTableDDL = `
CREATE TABLE IF NOT EXISTS %[1]v (
version SMALLINT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT %[1]v_pk PRIMARY KEY (version)
)
`
// tplSchemaVersionSQL selects the highest recorded schema version, or 0
// when no version has been recorded yet.
// Format argument: the schemas table name (%s).
tplSchemaVersionSQL = `
SELECT COALESCE(MAX(version), 0)
FROM %s
`
)
// SchemasWalkFunc is the type of the function called by WalkFrom.
// It receives a schema version together with the SQL script WalkFrom built
// for that version. Returning a non-nil error stops the walk, and WalkFrom
// returns that same error.
type SchemasWalkFunc func(version uint64, script []byte) error
// NewSchemas creates a new embedded SQL schema manager.
//
// The fs argument provides the schema files, which are looked up inside its
// SchemasDir directory. Applied schema versions are recorded in the "schema"
// table by default; a non-empty namespace prefixes that name, giving
// "<namespace>_schema", so that different packages can keep their schemas
// in the same database.
func NewSchemas(fs fs.FS, namespace string) Schemas {
	name := defaultSchemasTableName
	if len(namespace) > 0 {
		name = namespace + "_" + name
	}
	return Schemas{tableName: name, fs: fs}
}
// Schemas defines a type to manage versioning of embedded SQL schemas.
// Each schema file must live inside the embedded schemas directory and the name
// of each schema file must be numeric, where the number represents the version.
// NewSchemas is the constructor for this type.
type Schemas struct {
// tableName is the table where applied schema versions are recorded.
tableName string
// fs holds the embedded schema files; WalkFrom reads them from SchemasDir.
fs fs.FS
}
// GetTableDDL returns the DDL statement that creates the schemas table,
// named after this manager's table name, when it does not exist yet.
func (s Schemas) GetTableDDL() string {
	name := s.tableName
	return fmt.Sprintf(tplSchemaTableDDL, name)
}
// GetSchemaVersionSQL returns the SQL query that reads the current schema
// version: the highest version recorded in the schemas table, or 0 when no
// version has been recorded yet.
func (s Schemas) GetSchemaVersionSQL() string {
	name := s.tableName
	return fmt.Sprintf(tplSchemaVersionSQL, name)
}
// WalkFrom calls a function for SQL schemas starting from a specific version.
// This is useful to apply newer schemas that are not yet applied.
//
// fn is called in ascending version order for every schema whose version is
// greater than or equal to fromVersion. For each one it receives the version
// and a script that, within a single transaction, records the version in the
// schemas table and then runs the schema file's SQL. The walk stops at the
// first error returned by fn, and that error is returned.
//
// Schema files must live directly inside SchemasDir and be named after their
// positive numeric version (for example "1.sql"). Before fn is ever called,
// an error is returned in these cases:
//   - the directory cannot be read;
//   - a file name is not a positive number;
//   - SchemasDir contains a nested directory;
//   - two files resolve to the same version (e.g. "1.sql" and "001.sql").
//
// A schema file that cannot be read aborts the walk when it is reached.
func (s Schemas) WalkFrom(fromVersion uint64, fn SchemasWalkFunc) error {
	// Index the path of every schema file by version. All files are indexed,
	// not only the ones >= fromVersion, so that a duplicated version is always
	// reported instead of one file silently replacing the other.
	paths := map[uint64]string{}
	err := fs.WalkDir(s.fs, SchemasDir, func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			return errors.Errorf("failed to read schema %s: %w", path, err)
		}
		if path == SchemasDir {
			return nil
		}
		// A nested directory would otherwise be indexed as a schema version
		// (when its name is numeric) and its own files walked as well.
		if d.IsDir() {
			return errors.Errorf("unexpected directory in schemas '%s'", path)
		}
		// Extract the schema file version from the file name
		version := extractSchemaVersion(path)
		if version == 0 {
			return errors.Errorf("invalid schema file name '%s'", path)
		}
		if prev, ok := paths[version]; ok {
			return errors.Errorf("schema files '%s' and '%s' have the same version %d", prev, path, version)
		}
		paths[version] = path
		return nil
	})
	if err != nil {
		return err
	}
	for _, ver := range sortedSchemaVersions(paths) {
		// Skip the versions below the requested starting version
		if ver < fromVersion {
			continue
		}
		p := paths[ver]
		// Read the SQL script from the schema file
		script, err := fs.ReadFile(s.fs, p)
		if err != nil {
			return errors.Errorf("failed to read schema '%s': %w", p, err)
		}
		// Create the SQL script to change the schema to the
		// current version within a single transaction
		b := ScriptBuilder{}
		b.BeginTX()
		b.AppendCommand(s.getSchemaVersionInsertSQL(ver))
		b.AppendScript(script)
		b.CommitTX()
		if err := fn(ver, b.Bytes()); err != nil {
			return err
		}
	}
	return nil
}
// getSchemaVersionInsertSQL returns the SQL command that records version in
// this manager's schemas table.
func (s Schemas) getSchemaVersionInsertSQL(version uint64) string {
	table := s.tableName
	return fmt.Sprintf(tplSchemaInsertSQL, table, version)
}
// scriptTrailingSpace lists the whitespace ignored when checking whether a
// command or script already ends with sqlCommandSuffix.
const scriptTrailingSpace = " \t\r\n\v\f"

// ScriptBuilder builds database DDL/SQL scripts that execute multiple commands.
// The zero value is an empty script, ready to use.
type ScriptBuilder struct {
	buf bytes.Buffer
}

// BeginTX appends a command to start a database transaction.
func (b *ScriptBuilder) BeginTX() {
	b.AppendCommand(sqlBeginTX)
}

// CommitTX appends a command to commit a database transaction.
func (b *ScriptBuilder) CommitTX() {
	b.AppendCommand(sqlCommitTX)
}

// AppendCommand appends a command to the script.
// sqlCommandSuffix is added unless the command already ends with it;
// trailing whitespace is ignored for that check, so "SELECT 1;\n" is not
// terminated twice.
func (b *ScriptBuilder) AppendCommand(cmd string) {
	b.buf.WriteString(cmd)
	if !strings.HasSuffix(strings.TrimRight(cmd, scriptTrailingSpace), sqlCommandSuffix) {
		b.buf.WriteString(sqlCommandSuffix)
	}
}

// AppendScript appends a database DDL/SQL script.
//
// The script is always left terminated, so that whatever is appended next
// (typically COMMIT) runs as a separate statement:
//   - When the last non-blank character is not sqlCommandSuffix, one is
//     added on a new line. The new line keeps a trailing "--" comment from
//     swallowing it; if the script already ended with a terminated statement
//     followed by a comment, this yields an empty statement, which PostgreSQL
//     accepts.
//   - The script always ends with a newline.
//
// Scripts that are empty or contain only whitespace are ignored.
func (b *ScriptBuilder) AppendScript(s []byte) {
	body := bytes.TrimRight(s, scriptTrailingSpace)
	if len(body) == 0 {
		return
	}
	b.buf.Write(body)
	if !bytes.HasSuffix(body, []byte(sqlCommandSuffix)) {
		b.buf.WriteString("\n" + sqlCommandSuffix)
	}
	b.buf.WriteByte('\n')
}

// Bytes returns the whole script as bytes.
// The slice aliases the builder's internal buffer and is only valid until
// the next append.
func (b *ScriptBuilder) Bytes() []byte {
	return b.buf.Bytes()
}
// extractSchemaVersion returns the version encoded in a schema file name:
// the base name without its extension, parsed as an unsigned decimal
// integer (e.g. "schemas/3.sql" -> 3). It returns 0 when that name is not
// numeric or does not fit in a uint64.
func extractSchemaVersion(fileName string) uint64 {
	base := filepath.Base(fileName)
	name := strings.TrimSuffix(base, filepath.Ext(base))
	// The names of the schema files MUST be numeric. Parse with bitSize 64 so
	// the accepted range matches the uint64 result on every platform; bitSize 0
	// would limit versions to 32 bits on 32-bit architectures.
	version, err := strconv.ParseUint(name, 10, 64)
	if err != nil {
		return 0
	}
	return version
}
// sortedSchemaVersions returns the versions (the keys of paths) in ascending
// order. The result is never nil, even when paths is empty or nil.
func sortedSchemaVersions(paths map[uint64]string) []uint64 {
	out := make([]uint64, len(paths))
	i := 0
	for v := range paths {
		out[i] = v
		i++
	}
	sort.Slice(out, func(a, b int) bool { return out[a] < out[b] })
	return out
}