forked from aakash-rajur/sqlxgen
-
Notifications
You must be signed in to change notification settings - Fork 0
/
split_where_contexts.go
133 lines (95 loc) · 2.18 KB
/
split_where_contexts.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package prepare
import (
"regexp"
"github.com/joomcode/errorx"
"github.com/mvoorberg/sqlxgen/internal/utils/linked_list"
)
// splitWhereContexts extracts the WHERE fragments of query that sit inside a
// parenthesised sub-expression.
//
// For every "where" keyword located by findAllWhereMatches, the innermost pair
// of parentheses enclosing that keyword is chosen; the fragment runs from the
// first byte of the keyword up to (but excluding) that pair's closing ')'.
// Keywords that are not enclosed by any parentheses contribute nothing.
//
// Results:
//   - an empty query yields an empty, non-nil slice;
//   - if no keyword is found, or none is enclosed by parentheses, the whole
//     query is returned as the only element;
//   - a failure from findAllWhereMatches is wrapped as errorx.IllegalFormat
//     and returned together with an empty slice.
func splitWhereContexts(query string) ([]string, error) {
	if len(query) == 0 {
		return []string{}, nil
	}

	fragments := make([]string, 0)

	whereMatches, err := findAllWhereMatches(query)
	if err != nil {
		return fragments, errorx.IllegalFormat.Wrap(err, "failed to find where matches")
	}

	if len(whereMatches) == 0 {
		return []string{query}, nil
	}

	allParens := findAllParentheses(query)

	for _, m := range whereMatches {
		// Submatch group 2 is the keyword itself; keywordLast is the offset
		// of its final byte.
		keywordFirst, keywordLast := m[4], m[5]-1

		// Sentinel wider than the query, so any real enclosing pair is tighter.
		best := parentheses{Start: -1, End: len(query)}
		for _, p := range allParens {
			encloses := p.Start < keywordFirst && keywordLast < p.End
			tighter := best.Start < p.Start && p.End < best.End
			if encloses && tighter {
				best = p
			}
		}

		if best.Start == -1 {
			// Top-level keyword: no parentheses around it.
			continue
		}

		if fragment := query[keywordFirst:best.End]; fragment != "" {
			fragments = append(fragments, fragment)
		}
	}

	if len(fragments) == 0 {
		return []string{query}, nil
	}
	return fragments, nil
}
// whereKeywordPattern matches the lowercase keyword "where" when it is
// preceded by ')' or whitespace and followed by '(' or whitespace.
// Submatch groups: 1 = leading delimiter, 2 = "where", 3 = trailing delimiter.
//
// It is compiled once at package initialisation rather than on every call.
//
// NOTE(review): the match is case-sensitive, so "WHERE" is not found;
// confirm whether callers normalise the query's case beforehand.
var whereKeywordPattern = regexp.MustCompile(`(\)|\s)(where)(\(|\s)`)

// findAllWhereMatches returns the submatch index slices (in the format of
// regexp.Regexp.FindAllStringSubmatchIndex) for every non-overlapping
// occurrence of whereKeywordPattern in query. For each match m, query[m[4]:m[5]]
// is the "where" keyword itself.
//
// When nothing matches, an empty non-nil slice is returned. The error result
// is always nil now that the pattern is compiled at package scope; it is kept
// so existing callers remain source-compatible.
func findAllWhereMatches(query string) ([][]int, error) {
	matches := whereKeywordPattern.FindAllStringSubmatchIndex(query, -1)
	if len(matches) == 0 {
		return [][]int{}, nil
	}
	return matches, nil
}
// findAllParentheses pairs every '(' in content with its matching ')' and
// returns one parentheses value per matched pair, holding the byte offsets of
// the two characters. Pairs are appended in the order their ')' is reached,
// so an inner pair precedes any pair that encloses it.
//
// Unmatched characters are ignored: a ')' with no open '(' before it is
// skipped, and a '(' still open at the end of content produces no pair.
// Quoting and comments are not recognised, so parentheses inside SQL string
// literals are paired like any others.
func findAllParentheses(content string) []parentheses {
	ps := make([]parentheses, 0)
	stack := linked_list.NewStack[parenthesesNode]()

	// open mirrors the number of nodes on the stack. Checking it before
	// peeking means an unmatched ')' (e.g. one inside a string literal) never
	// dereferences whatever Peek returns for an empty stack.
	open := 0

	for index, character := range content {
		switch character {
		case '(':
			node := parenthesesNode{
				Character: character,
				Index:     index,
			}
			stack.Push(&node)
			open++
		case ')':
			if open == 0 {
				// Unmatched ')': there is no '(' to pair it with.
				continue
			}
			// The second result can be ignored: open > 0 guarantees a node.
			prevNode, _ := stack.Peek()
			if prevNode.Character != '(' {
				continue
			}
			stack.Pop()
			open--
			ps = append(ps, parentheses{
				Start: prevNode.Index,
				End:   index,
			})
		}
	}

	return ps
}
// parentheses holds a pair of byte offsets into a string: Start is where a
// '(' appears and End is where its matching ')' appears.
type parentheses struct {
	Start, End int
}
// parenthesesNode is a stack entry used by findAllParentheses while pairing
// brackets: it remembers which character was pushed and where it appeared.
type parenthesesNode struct {
	// Character is the pushed rune; findAllParentheses only ever pushes '('.
	Character rune
	// Index is the byte offset of Character within the scanned string.
	Index int
}