forked from src-d/go-mysql-server
/
convert_dates.go
186 lines (162 loc) · 5.54 KB
/
convert_dates.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
package analyzer
import (
"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
"github.com/src-d/go-mysql-server/sql/expression/function"
"github.com/src-d/go-mysql-server/sql/expression/function/aggregation"
"github.com/src-d/go-mysql-server/sql/plan"
)
// convertDates walks the plan with plan.TransformUp and wraps expressions of
// date and datetime type in a convert expression so that the date range is
// validated. A node that is not yet resolved is returned unchanged. The ctx
// and a parameters are not used by this rule.
func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
	if !n.Resolved() {
		return n, nil
	}

	// applied maps a column to the name it has been replaced by in nodes
	// that were already processed.
	applied := make(map[tableCol]string)

	return plan.TransformUp(n, func(node sql.Node) (sql.Node, error) {
		exprNode, isExpressioner := node.(sql.Expressioner)
		if !isExpressioner {
			return node, nil
		}

		// pending holds the replacements discovered while processing this
		// node. They are not used to rewrite this node itself; they are
		// merged into applied once the node is done, so only parent nodes
		// see them.
		pending := make(map[tableCol]string)

		// roots is the set (keyed by String()) of the expressions listed
		// directly in a Project's projections or a GroupBy's aggregates.
		roots := make(map[string]bool)

		// rewrite runs addDateConvert over every expression in exprs. When
		// rootLevel is true the expressions are the node's projections or
		// aggregates: aliasing is allowed, and any expression whose text
		// changed without becoming an alias is recorded in pending.
		rewrite := func(exprs []sql.Expression, rootLevel bool) ([]sql.Expression, error) {
			out := make([]sql.Expression, 0, len(exprs))
			for _, original := range exprs {
				converted, err := expression.TransformUp(original, func(e sql.Expression) (sql.Expression, error) {
					return addDateConvert(e, node, applied, pending, roots, rootLevel)
				})
				if err != nil {
					return nil, err
				}
				out = append(out, converted)

				if !rootLevel {
					continue
				}
				if _, aliased := converted.(*expression.Alias); aliased {
					continue
				}
				before, after := original.String(), converted.String()
				if after != before {
					pending[tableCol{"", before}] = after
				}
			}
			return out, nil
		}

		var result sql.Node
		switch typed := exprNode.(type) {
		case *plan.GroupBy:
			for _, agg := range typed.Aggregate {
				roots[agg.String()] = true
			}
			aggregates, err := rewrite(typed.Aggregate, true)
			if err != nil {
				return nil, err
			}
			groupings, err := rewrite(typed.Grouping, false)
			if err != nil {
				return nil, err
			}
			result = plan.NewGroupBy(aggregates, groupings, typed.Child)

		case *plan.Project:
			for _, proj := range typed.Projections {
				roots[proj.String()] = true
			}
			projections, err := rewrite(typed.Projections, true)
			if err != nil {
				return nil, err
			}
			result = plan.NewProject(projections, typed.Child)

		default:
			transformed, err := plan.TransformExpressions(node, func(e sql.Expression) (sql.Expression, error) {
				return addDateConvert(e, node, applied, pending, roots, false)
			})
			if err != nil {
				return nil, err
			}
			result = transformed
		}

		// This node is finished: publish what was found here so that the
		// necessary changes are made in parent nodes.
		for col, name := range pending {
			applied[col] = name
		}

		return result, nil
	})
}
// addDateConvert returns e wrapped in a convert expression when its type is
// sql.Date or sql.Timestamp, and returns it as-is otherwise.
//
// node is the plan node that owns e. replacements holds the column
// replacements gathered from already processed nodes and is only read;
// nodeReplacements collects the ones found in the current node and is written
// to. expressions is the set of String() forms of the root expressions of the
// current Project or GroupBy, and aliasRootProjections says whether a
// converted root GetField may be turned into an alias.
func addDateConvert(
	e sql.Expression,
	node sql.Node,
	replacements, nodeReplacements map[tableCol]string,
	expressions map[string]bool,
	aliasRootProjections bool,
) (sql.Expression, error) {
	switch typed := e.(type) {
	case *aggregation.Max:
		// Convert the argument of the aggregation, never aliasing it.
		inner, err := addDateConvert(typed.Child, node, replacements, nodeReplacements, expressions, false)
		if err != nil {
			return nil, err
		}
		return aggregation.NewMax(inner), nil

	case *aggregation.Min:
		inner, err := addDateConvert(typed.Child, node, replacements, nodeReplacements, expressions, false)
		if err != nil {
			return nil, err
		}
		return aggregation.NewMin(inner), nil

	case *expression.Convert,
		*expression.Arithmetic,
		*function.DateAdd,
		*function.DateSub,
		*expression.Star,
		*expression.DefaultColumn,
		*expression.Alias:
		// No need to wrap expressions that already validate times, such as
		// convert, date_add, etc, nor those whose Type method cannot be
		// called because they are placeholders.
		return e, nil
	}

	field, isField := e.(*expression.GetField)

	// A column that already has a replacement is known to be converted, so
	// just point the GetField at the replacement name instead of converting
	// it again.
	if isField {
		if name, found := replacements[tableCol{field.Table(), field.Name()}]; found {
			return expression.NewGetField(field.Index(), field.Type(), name, field.IsNullable()), nil
		}
	}

	converted := e
	switch e.Type() {
	case sql.Date:
		converted = expression.NewConvert(e, expression.ConvertToDate)
	case sql.Timestamp:
		converted = expression.NewConvert(e, expression.ConvertToDatetime)
	}

	// Aliasing only applies to a GetField when the caller allows it.
	if !isField || !aliasRootProjections {
		return converted, nil
	}

	// It also only applies inside a project or a group by.
	switch node.(type) {
	case *plan.Project, *plan.GroupBy:
	default:
		return converted, nil
	}

	// The GetField has to be one of the node's root expressions.
	if !expressions[e.String()] {
		return converted, nil
	}

	// Still a GetField means nothing was wrapped, so there is nothing to
	// alias.
	if _, unchanged := converted.(*expression.GetField); unchanged {
		return converted, nil
	}

	// The GetField was wrapped in a convert: alias it back to the column
	// name and record the replacement so it is propagated up the chain.
	nodeReplacements[tableCol{field.Table(), field.Name()}] = field.Name()
	return expression.NewAlias(converted, field.Name()), nil
}