/
version.go
142 lines (131 loc) · 3.58 KB
/
version.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
package main
import (
"strings"
"time"
"github.com/jmoiron/sqlx"
mytokenlib "github.com/oidc-mytoken/lib"
log "github.com/sirupsen/logrus"
"github.com/oidc-mytoken/server/internal/db"
"github.com/oidc-mytoken/server/internal/db/dbmigrate"
"github.com/oidc-mytoken/server/internal/db/dbrepo/versionrepo"
"github.com/oidc-mytoken/server/internal/model/version"
"github.com/oidc-mytoken/server/internal/utils/dbcl"
)
// did looks up version in state and reports whether its Before and After
// fields are marked Valid. Only the first entry whose Version matches is
// considered; if no entry matches, both results are false.
func did(state versionrepo.DBVersionState, version string) (beforeDone, afterDone bool) {
	for _, e := range state {
		if e.Version != version {
			continue
		}
		return e.Before.Valid, e.After.Valid
	}
	return false, false
}
// getDoneMap evaluates did for every version listed in dbmigrate.Versions and
// returns the results as two maps keyed by version: the first holds the
// "before" flags, the second the "after" flags.
func getDoneMap(state versionrepo.DBVersionState) (map[string]bool, map[string]bool) {
	n := len(dbmigrate.Versions)
	beforeDone := make(map[string]bool, n)
	afterDone := make(map[string]bool, n)
	for _, ver := range dbmigrate.Versions {
		b, a := did(state, ver)
		beforeDone[ver] = b
		afterDone[ver] = a
	}
	return beforeDone, afterDone
}
func migrateDB(mytokenNodes []string) error {
v := "v" + version.VERSION
dbState, err := versionrepo.GetVersionState(log.StandardLogger(), nil)
if err != nil {
return err
}
return runUpdates(dbState, mytokenNodes, v)
}
// runUpdates drives one migration run: it first runs the "before" updates,
// then, only if at least one "after" update is still outstanding, waits until
// all mytokenNodes report the given version and finally runs the "after"
// updates. The first error encountered is returned.
func runUpdates(dbState versionrepo.DBVersionState, mytokenNodes []string, version string) error {
	beforeDone, afterDone := getDoneMap(dbState)
	err := runBeforeUpdates(beforeDone)
	if err != nil {
		return err
	}
	if anyAfterUpdates(afterDone) {
		waitUntilAllNodesOnVersion(mytokenNodes, version)
		return runAfterUpdates(afterDone)
	}
	// No after commands left to run, so there is nothing to wait for.
	return nil
}
// runBeforeUpdates walks dbmigrate.Versions in order and passes each
// version's Before commands to updateCallback, using
// versionrepo.SetVersionBefore as the completion callback. It stops at and
// returns the first error.
func runBeforeUpdates(beforeDone map[string]bool) error {
	for _, ver := range dbmigrate.Versions {
		cmds := dbmigrate.MigrationCommands[ver].Before
		err := updateCallback(cmds, ver, beforeDone, versionrepo.SetVersionBefore)
		if err != nil {
			return err
		}
	}
	return nil
}
// anyAfterUpdates reports whether dbmigrate.MigrationCommands contains at
// least one version that has non-empty After commands and is not yet flagged
// in afterDone.
func anyAfterUpdates(afterDone map[string]bool) bool {
	for ver, cmds := range dbmigrate.MigrationCommands {
		if len(cmds.After) == 0 || afterDone[ver] {
			continue
		}
		return true
	}
	return false
}
// runAfterUpdates walks dbmigrate.Versions in order and passes each version's
// After commands to updateCallback, using versionrepo.SetVersionAfter as the
// completion callback. It stops at and returns the first error.
func runAfterUpdates(afterDone map[string]bool) error {
	for _, ver := range dbmigrate.Versions {
		cmds := dbmigrate.MigrationCommands[ver].After
		err := updateCallback(cmds, ver, afterDone, versionrepo.SetVersionAfter)
		if err != nil {
			return err
		}
	}
	return nil
}
// updateCallback applies the commands of a single migration step.
//
// It always logs that the version is being processed. If cmds is empty it
// does nothing further; if done[version] is set it logs that the update is
// skipped. Otherwise it runs cmds through dbcl.RunDBCommands and, on success,
// invokes dbUpdateCallback for version inside db.Transact. Any error from
// these steps is returned.
func updateCallback(
	cmds, version string, done map[string]bool,
	dbUpdateCallback func(log.Ext1FieldLogger, *sqlx.Tx, string) error,
) error {
	entry := log.WithField("version", version)
	entry.Info("Updating DB to version")
	switch {
	case cmds == "":
		return nil
	case done[version]:
		entry.Info("Skipping Update; DB already has this version.")
		return nil
	}
	if err := dbcl.RunDBCommands(cmds, dbConfig.DBConf, true); err != nil {
		return err
	}
	markDone := func(tx *sqlx.Tx) error {
		return dbUpdateCallback(log.StandardLogger(), tx, version)
	}
	return db.Transact(log.StandardLogger(), markDone)
}
// waitUntilAllNodesOnVersion blocks until every node in mytokenNodes reports
// the given version.
//
// The outstanding nodes are polled via getVersionForNode; nodes that reported
// the expected version are not queried again. A node whose version cannot be
// fetched is logged and stays in the pending set. Between polling rounds the
// function sleeps for one minute, but it returns immediately once no node is
// pending (including when mytokenNodes is empty).
func waitUntilAllNodesOnVersion(mytokenNodes []string, version string) {
	const pollInterval = 60 * time.Second
	pending := mytokenNodes
	for {
		var remaining []string
		for _, n := range pending {
			v, err := getVersionForNode(n)
			if err != nil {
				// The returned version is meaningless on error; keep the node
				// pending and record which node failed.
				log.WithError(err).WithField("node", n).Error("Could not get version of node")
				remaining = append(remaining, n)
				continue
			}
			if v != version {
				remaining = append(remaining, n)
			}
		}
		pending = remaining
		if len(pending) == 0 {
			return
		}
		// Only wait when there is still something to re-check.
		time.Sleep(pollInterval)
	}
}
// getVersionForNode returns the version reported in the server metadata of
// the given node.
//
// node may be a full URL or a bare host; if it does not start with "http://"
// or "https://", "https://" is prepended. The scheme check looks at the full
// scheme separator so that hosts whose name merely begins with "http"
// (e.g. "httpd.example.org") still get a scheme. An error from
// mytokenlib.NewMytokenServer is returned together with an empty version.
func getVersionForNode(node string) (string, error) {
	hasScheme := strings.HasPrefix(node, "http://") || strings.HasPrefix(node, "https://")
	if !hasScheme {
		node = "https://" + node
	}
	my, err := mytokenlib.NewMytokenServer(node)
	if err != nil {
		return "", err
	}
	return my.ServerMetadata.Version, nil
}