forked from pressly/goose
-
Notifications
You must be signed in to change notification settings - Fork 0
/
provider.go
172 lines (153 loc) · 4.24 KB
/
provider.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
package goose
import (
"io/fs"
"path/filepath"
"runtime"
"time"
)
const (
	// defaultTableName is the table name a new Provider starts with
	// (see NewProvider and the Tablename option).
	defaultTableName = "goose_db_version"
	// defaultTimestampFormat is a Go reference-time layout producing
	// YYYYMMDDhhmmss, used for timestamp version numbers by default.
	defaultTimestampFormat = "20060102150405"
	// defaultProviderPackage is the package name used in templates when no
	// package name is supplied.
	defaultProviderPackage = "migrations"
)
// providerOptions is a functional option applied to a Provider by NewProvider.
type providerOptions func(p *Provider)

// defaultProvider is the Provider used by the package's general (non-method)
// functions. It is built by NewProvider with no options.
var defaultProvider = NewProvider()
// TimestampFormat returns an option that stores format as the provider's
// timestamp format (a Go time layout).
func TimestampFormat(format string) func(p *Provider) {
	setFormat := func(target *Provider) {
		target.timestampFormat = format
	}
	return setFormat
}
// TimeFunction returns an option that sets the function the provider calls to
// obtain the current time for timestamp numbers. Passing nil selects time.Now.
func TimeFunction(fn func() time.Time) func(p *Provider) {
	clock := time.Now
	if fn != nil {
		clock = fn
	}
	return func(target *Provider) {
		target.timeFn = clock
	}
}
// Verbose returns an option that sets the provider's verbose flag to b.
func Verbose(b bool) func(p *Provider) {
	enabled := b
	apply := func(target *Provider) {
		target.verbose = enabled
	}
	return apply
}
// SequentialVersion returns an option that makes the provider use sequential
// versioning. A non-empty versionTemplate replaces the provider's sequence
// template; an empty one leaves the existing template untouched.
func SequentialVersion(versionTemplate string) func(p *Provider) {
	return func(target *Provider) {
		target.sequential = true
		if versionTemplate == "" {
			// Keep whatever template is already configured.
			return
		}
		target.seqVersionTemplate = versionTemplate
	}
}
// TimestampVersion is an option that makes the provider use timestamp-based
// versioning instead of sequential versioning (it clears the sequential flag).
// Its signature matches providerOptions, so it can be passed to NewProvider
// directly, without being called first.
func TimestampVersion(p *Provider) {
	p.sequential = false
}
// Filesystem returns an option that sets baseFS as the provider's base
// filesystem.
func Filesystem(baseFS fs.FS) func(p *Provider) {
	chosen := baseFS
	return func(target *Provider) {
		target.baseFS = chosen
	}
}
// Log returns an option that sets the Logger the provider reports through.
func Log(log Logger) func(p *Provider) {
	logger := log
	return func(target *Provider) {
		target.log = logger
	}
}
// Dialect returns an option that resolves the named SQL dialect with
// SelectDialect, passing the provider's current table name, and installs the
// result as the provider's dialect.
//
// If SelectDialect returns an error, it is reported via the provider's
// logger (p.log.Fatal). If that logger returns instead of terminating, the
// option stops there. The provider keeps its previously configured dialect
// and is not left with the value SelectDialect returned alongside the error.
func Dialect(dialect string) func(p *Provider) {
	return func(p *Provider) {
		// Use a distinct name so the resolved dialect does not shadow the
		// requested dialect name.
		selected, err := SelectDialect(p.tableName, dialect)
		if err != nil {
			p.log.Fatal(err)
			return
		}
		p.dialect = selected
	}
}
// dirPath returns the directory of the source file containing the caller of
// dirPath's caller, i.e. two frames up the stack. If runtime.Caller cannot
// recover that frame, the file name is treated as empty and the result is ".".
func dirPath() string {
	const skipFrames = 2 // 0 = dirPath, 1 = its caller, 2 = the caller's caller
	_, callerFile, _, ok := runtime.Caller(skipFrames)
	if !ok {
		callerFile = ""
	}
	return filepath.Dir(callerFile)
}
// BaseDir returns an option that sets the provider's base directory. When dir
// is empty, the directory of the source file that called BaseDir is used
// instead. This is only useful for the Create* and Fix functions.
func BaseDir(dir string) func(p *Provider) {
	resolved := dir
	if resolved == "" {
		// dirPath must be called directly from BaseDir (not from inside the
		// returned closure): it skips exactly two frames, dirPath and BaseDir,
		// to reach BaseDir's caller.
		resolved = dirPath()
	}
	return func(target *Provider) {
		target.baseDir = resolved
	}
}
// DialectObject returns an option that installs an already-constructed
// SQLDialect on the provider and sets the provider's current table name on it.
func DialectObject(dialect SQLDialect) func(p *Provider) {
	install := func(target *Provider) {
		target.dialect = dialect
		target.dialect.SetTableName(target.tableName)
	}
	return install
}
// Tablename returns an option that sets the provider's table name and passes
// the same name to the provider's current dialect via SetTableName.
func Tablename(tablename string) func(p *Provider) {
	name := tablename
	return func(target *Provider) {
		target.tableName = name
		target.dialect.SetTableName(name)
	}
}
// ProviderPackage returns an option that sets the package name and provider
// variable name used in templates. An empty packageName falls back to
// defaultProviderPackage; providerVar is stored exactly as given.
func ProviderPackage(packageName, providerVar string) func(p *Provider) {
	pkg := packageName
	if len(pkg) == 0 {
		pkg = defaultProviderPackage
	}
	return func(target *Provider) {
		target.packageName, target.providerVarName = pkg, providerVar
	}
}
// Provider holds the configuration used by this package's functions:
// versioning scheme, filesystem, logger, SQL dialect, table name, Go
// migrations and template settings for the Create/Fix helpers. Construct one
// with NewProvider and configure it with the option functions in this file.
type Provider struct {
	// timestampFormat is the Go time layout for timestamp version numbers;
	// NewProvider sets defaultTimestampFormat, TimestampFormat overrides it.
	timestampFormat string
	// timeFn returns the current time used for timestamp numbers.
	// defaults to time.Now
	timeFn func() time.Time
	// verbose is set by the Verbose option.
	verbose bool
	// whether to use sequential versioning instead of timestamp based versioning
	sequential bool
	// baseFS is the filesystem set by the Filesystem option; NewProvider
	// defaults it to osFS{}.
	baseFS fs.FS
	// log is the Logger errors are reported through (e.g. the Dialect option
	// calls log.Fatal when SelectDialect fails).
	log Logger
	// dialect is the SQL dialect in use; NewProvider defaults it to
	// &PostgresDialect{}. Tablename and DialectObject call SetTableName on it.
	dialect SQLDialect
	// registeredGoMigrations holds Go migrations for this provider.
	// NOTE(review): presumably keyed by migration version — confirm against
	// the registration code.
	registeredGoMigrations map[int64]*Migration
	// tableName is the table name handed to the dialect; NewProvider defaults
	// it to defaultTableName.
	tableName string
	// seqVersionTemplate is the format template used for the digits of the
	// sequence number in sequential versioning (set by SequentialVersion).
	// NewProvider leaves it empty; the effective default (%05d according to
	// the original note) is applied elsewhere — TODO confirm where.
	seqVersionTemplate string
	// packageName is the name of the package to use for Create functions
	packageName string
	// providerVarName is the name of the provider var for create functions
	providerVarName string
	// baseDir is used for Create/Fix if the dir is not passed
	// (see the BaseDir option and the Provider.BaseDir method).
	baseDir string
}
// NewProvider builds a Provider populated with the package defaults and then
// applies each option in the order given, so later options override earlier
// ones. Defaults:
//   - timestamp versioning (sequential = false) with defaultTimestampFormat
//   - time.Now as the time function
//   - verbose off
//   - osFS{} as the filesystem and the package-level log as the logger
//   - &PostgresDialect{} as the dialect, with defaultTableName as table name
//   - an empty Go-migration map and defaultProviderPackage as package name
func NewProvider(options ...providerOptions) *Provider {
	// verbose and sequential keep their zero value (false).
	p := new(Provider)
	p.timestampFormat = defaultTimestampFormat
	p.timeFn = time.Now
	p.baseFS = osFS{}
	p.log = log
	p.dialect = &PostgresDialect{}
	p.registeredGoMigrations = make(map[int64]*Migration)
	p.tableName = defaultTableName
	p.packageName = defaultProviderPackage

	for i := range options {
		options[i](p)
	}
	return p
}
// BaseDir resolves the directory the Create/Fix helpers should use. When a
// base directory is configured and dir is empty or ".", the configured base
// directory is returned; otherwise dir is returned unchanged.
func (p *Provider) BaseDir(dir string) string {
	switch {
	case p.baseDir == "":
		return dir
	case dir == "", dir == ".":
		return p.baseDir
	default:
		return dir
	}
}