/
provider.go
100 lines (81 loc) · 2.43 KB
/
provider.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
// Package provider provides an iterator for iterating over all of the
// migration statements that need to be applied. It will provide the statements
// in the correct order based on the Edition priority and migration version.
package provider
import (
"sort"
"github.com/hashicorp/boundary/internal/db/schema/internal/edition"
)
// nilVersion is the version used for an edition that has no entry in a
// DatabaseState. Because it is below every non-negative migration version,
// all of such an edition's migrations are treated as pending.
const nilVersion = -1

type (
	// migration is one pending migration: its version, the name of the
	// edition it belongs to, and its SQL statements.
	migration struct {
		version    int
		edition    string
		statements []byte
	}

	// Provider provides the migrations to the schema.Manager in the correct
	// order. It acts as a cursor: pos indexes the current entry of
	// migrations and starts at -1, i.e. before the first migration.
	Provider struct {
		pos        int
		migrations []migration
	}

	// DatabaseState maps edition names to versions. New treats each value as
	// the version an edition is already at: only migrations with a greater
	// version are provided.
	DatabaseState map[string]int
)
// New creates a Provider positioned before the first pending migration.
//
// dbState is compared against editions to decide which migrations still
// need to be applied: for each edition, only migrations whose version is
// greater than the version recorded for that edition in dbState are kept.
// An edition with no entry in dbState is treated as being at nilVersion
// (-1), so every migration with a non-negative version is kept.
//
// The kept migrations are ordered by Edition priority (via
// edition.Editions.Sort) and, within an edition, by ascending version.
// The caller's editions slice is not reordered; a private copy is sorted.
func New(dbState DatabaseState, editions edition.Editions) *Provider {
	// Sort a private copy. Sorting the argument in place would reorder the
	// caller's slice as a side effect, and would be a data race if that slice
	// is shared across goroutines (e.g. a package-level registry).
	// editions[:0:0] keeps the edition.Editions type but has zero capacity,
	// so append always copies into a fresh backing array.
	sorted := append(editions[:0:0], editions...)
	// ensure editions in priority order
	sorted.Sort()

	var all []migration
	for _, e := range sorted {
		dbVer, ok := dbState[e.Name]
		if !ok {
			dbVer = nilVersion
		}

		pending := make([]migration, 0, len(e.Migrations))
		for ver, statements := range e.Migrations {
			if ver > dbVer {
				pending = append(pending, migration{
					version:    ver,
					edition:    e.Name,
					statements: statements,
				})
			}
		}
		// Do not rely on the range order over e.Migrations. It is random if,
		// as presumed, Migrations is a map (TODO confirm). Order this
		// edition's migrations by version explicitly.
		sort.SliceStable(pending, func(i, j int) bool {
			return pending[i].version < pending[j].version
		})
		all = append(all, pending...)
	}

	return &Provider{
		pos:        -1,
		migrations: all,
	}
}
// Next advances the Provider to the following migration. It reports whether
// such a migration exists; a false result means every migration has been
// consumed. A new Provider starts before the first migration, so Next must
// be called before reading Version, Edition or Statements.
func (p *Provider) Next() bool {
	p.pos++
	return p.pos < len(p.migrations)
}
// Version returns the version of the current migration. It returns -1 when
// the Provider is not positioned on a migration, i.e. before the first call
// to Next or after Next has returned false.
func (p *Provider) Version() int {
	if i := p.pos; i >= 0 && i < len(p.migrations) {
		return p.migrations[i].version
	}
	return -1
}
// Edition returns the name of the edition that the current migration belongs
// to. It returns the empty string when the Provider is not positioned on a
// migration, i.e. before the first call to Next or after Next has returned
// false.
func (p *Provider) Edition() string {
	if i := p.pos; i >= 0 && i < len(p.migrations) {
		return p.migrations[i].edition
	}
	return ""
}
// Statements returns the SQL statements of the current migration. It
// returns nil when the Provider is not positioned on a migration, i.e.
// before the first call to Next or after Next has returned false.
func (p *Provider) Statements() []byte {
	if i := p.pos; i >= 0 && i < len(p.migrations) {
		return p.migrations[i].statements
	}
	return nil
}