-
Notifications
You must be signed in to change notification settings - Fork 2
/
compiler.go
153 lines (140 loc) · 3.97 KB
/
compiler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package goen
import (
	"fmt"
	"reflect"

	sqr "gopkg.in/Masterminds/squirrel.v1"
)
// DefaultCompiler compiles every patch into its own SQL statement, in list
// order.
var DefaultCompiler PatchCompiler = PatchCompilerFunc(compilePatch)

// BulkCompiler merges runs of adjacent, compatible patches into a single
// statement and hands every other patch to DefaultCompiler.
var BulkCompiler PatchCompiler = &bulkCompiler{}
// CompilerOptions carries the inputs handed to a PatchCompiler.
type CompilerOptions struct {
	// StmtBuilder is the builder every generated statement is created from.
	StmtBuilder sqr.StatementBuilderType

	// Patches holds the patches to compile; compilers walk it front to back.
	Patches *PatchList
}
// PatchCompiler turns a list of patches into SQL statements.
type PatchCompiler interface {
	// Compile builds statements for opts.Patches using opts.StmtBuilder.
	Compile(opts *CompilerOptions) *SqlizerList
}
// PatchCompilerFunc lets a plain function be used as a PatchCompiler.
type PatchCompilerFunc func(*CompilerOptions) *SqlizerList

// Compile forwards opts to the underlying function and returns its result.
func (f PatchCompilerFunc) Compile(opts *CompilerOptions) *SqlizerList {
	result := f(opts)
	return result
}
// compilePatch backs DefaultCompiler. It emits exactly one statement per
// patch, in list order:
//   - PatchInsert: INSERT INTO table (columns...) VALUES (values...)
//   - PatchUpdate: UPDATE table SET column = value, ... [WHERE RowKey]
//   - PatchDelete: DELETE FROM table [WHERE RowKey]
//
// A nil RowKey leaves the UPDATE/DELETE without a WHERE clause.
//
// It panics if a patch has a different number of columns and values, or if
// its Kind is none of the kinds above.
func compilePatch(opts *CompilerOptions) (sqlizers *SqlizerList) {
	sqlizers = NewSqlizerList()
	for curr := opts.Patches.Front(); curr != nil; curr = curr.Next() {
		patch := curr.GetValue()
		if len(patch.Columns) != len(patch.Values) {
			panic("goen: number of columns and values are mismatched")
		}
		var sqlizer sqr.Sqlizer
		switch patch.Kind {
		case PatchInsert:
			sqlizer = opts.StmtBuilder.Insert(patch.TableName).
				Columns(patch.Columns...).
				Values(patch.Values...)
		case PatchUpdate:
			stmt := opts.StmtBuilder.Update(patch.TableName)
			for i, col := range patch.Columns {
				stmt = stmt.Set(col, patch.Values[i])
			}
			if patch.RowKey != nil {
				stmt = stmt.Where(patch.RowKey)
			}
			sqlizer = stmt
		case PatchDelete:
			stmt := opts.StmtBuilder.Delete(patch.TableName)
			if patch.RowKey != nil {
				stmt = stmt.Where(patch.RowKey)
			}
			sqlizer = stmt
		default:
			// Format with %v rather than string(patch.Kind). If the kind has an
			// integer underlying type, string(...) produces a single rune, not
			// the number. %v prints the value (or its String method) whatever
			// the underlying type is.
			panic(fmt.Sprintf("goen: unable to make sql statement for unknown kind (%v)", patch.Kind))
		}
		sqlizers.PushBack(sqlizer)
	}
	return sqlizers
}
// bulkCompiler is the field-less PatchCompiler behind BulkCompiler.
type bulkCompiler struct{}
// Compile implements PatchCompiler. It walks opts.Patches in order and merges
// each run of adjacent patches accepted by isCompat into one statement:
//   - INSERTs become a single multi-row INSERT.
//   - UPDATEs share one SET clause (isCompat only merges them when their
//     values are identical), and their row keys are OR-ed together.
//   - DELETEs have their row keys OR-ed together.
//
// Each merged statement has the same effect as compiling the run's patches
// one by one with DefaultCompiler. In particular, a patch with a nil RowKey
// gets no WHERE clause from DefaultCompiler, so a merged UPDATE/DELETE that
// contains one gets no WHERE clause either. Patches of any other kind are
// handed to DefaultCompiler one at a time.
//
// It panics if any patch has a different number of columns and values.
func (compiler *bulkCompiler) Compile(opts *CompilerOptions) (sqlizers *SqlizerList) {
	sqlizers = NewSqlizerList()
	for curr := opts.Patches.Front(); curr != nil; curr = curr.Next() {
		patch := curr.GetValue()
		compiler.checkShape(patch)
		switch patch.Kind {
		case PatchInsert, PatchUpdate, PatchDelete:
			group := []*Patch{patch}
			for curr.Next() != nil && compiler.isCompat(patch, curr.Next().GetValue()) {
				curr = curr.Next()
				next := curr.GetValue()
				// isCompat compares columns only, so each merged patch needs
				// its own check. Otherwise an INSERT row with the wrong number
				// of values could end up in the merged statement.
				compiler.checkShape(next)
				group = append(group, next)
			}
			sqlizers.PushBack(compiler.merge(opts.StmtBuilder, group))
		default:
			fallbackOpts := &CompilerOptions{
				StmtBuilder: opts.StmtBuilder,
				Patches:     NewPatchList(),
			}
			fallbackOpts.Patches.PushBack(patch)
			sqlizers.PushBackList(DefaultCompiler.Compile(fallbackOpts))
		}
	}
	return sqlizers
}

// checkShape panics if p has a different number of columns and values.
func (compiler *bulkCompiler) checkShape(p *Patch) {
	if len(p.Columns) != len(p.Values) {
		panic("goen: number of columns and values are mismatched")
	}
}

// merge builds a single statement for a non-empty run of mutually compatible
// patches. All patches in the run share group[0]'s kind, table and columns.
func (compiler *bulkCompiler) merge(builder sqr.StatementBuilderType, group []*Patch) sqr.Sqlizer {
	head := group[0]
	switch head.Kind {
	case PatchInsert:
		stmt := builder.Insert(head.TableName).Columns(head.Columns...)
		for _, p := range group {
			stmt = stmt.Values(p.Values...)
		}
		return stmt
	case PatchUpdate:
		// isCompat only groups UPDATEs with identical columns and values, so
		// the head's SET clause is valid for every patch in the run.
		stmt := builder.Update(head.TableName)
		for i, col := range head.Columns {
			stmt = stmt.Set(col, head.Values[i])
		}
		if cond, ok := compiler.rowKeyCond(group); ok {
			stmt = stmt.Where(cond)
		}
		return stmt
	default: // PatchDelete: Compile only passes the three batchable kinds.
		stmt := builder.Delete(head.TableName)
		if cond, ok := compiler.rowKeyCond(group); ok {
			stmt = stmt.Where(cond)
		}
		return stmt
	}
}

// rowKeyCond ORs together the row keys of group. ok is false when any patch
// has a nil RowKey. Such a patch is unconstrained (DefaultCompiler emits no
// WHERE clause for it), so the merged statement must not be filtered.
// Filtering would narrow that patch to the other keys; with only nil keys it
// would produce an empty predicate.
func (compiler *bulkCompiler) rowKeyCond(group []*Patch) (cond sqr.Or, ok bool) {
	cond = make(sqr.Or, 0, len(group))
	for _, p := range group {
		if p.RowKey == nil {
			return nil, false
		}
		cond = append(cond, p.RowKey)
	}
	return cond, true
}
// isCompat reports whether p2 can be merged into the same statement as p1.
// The two patches must have the same kind, target the same table, and list
// identical columns in identical order. UPDATE patches must also carry
// deeply-equal values, because a merged UPDATE shares a single SET clause.
func (compiler *bulkCompiler) isCompat(p1, p2 *Patch) bool {
	sameTarget := p1.Kind == p2.Kind &&
		p1.TableName == p2.TableName &&
		len(p1.Columns) == len(p2.Columns)
	if !sameTarget {
		return false
	}
	for i, col := range p1.Columns {
		if p2.Columns[i] != col {
			return false
		}
	}
	if p1.Kind != PatchUpdate {
		return true
	}
	// Compare the raw Go values rather than their "database/sql/driver".Valuer
	// output. Valuer exists only to convert Go types to SQL types, and an
	// implementation that changed the actual value would be illegal, so the
	// raw values are authoritative.
	return reflect.DeepEqual(p1.Values, p2.Values)
}