-
Notifications
You must be signed in to change notification settings - Fork 70
/
list.go
306 lines (266 loc) · 5.78 KB
/
list.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
package sortedlist
import "math"
// List 实现了排序链表的数据结构
type List interface {
// Remove 移除节点
Remove(key int64) bool
// Add 新增节点
Add(key int64, data interface{})
// Range 过滤范围内的 key 并返回 Iter 对象
Range(lower, upper int64) Iter
// All 迭代所有对象
All() Iter
}
// Iter 迭代器对象
type Iter interface {
Next() bool
Value() interface{}
}
type iter struct {
cursor int
data []interface{}
}
// Next 推进迭代器
func (it *iter) Next() bool {
it.cursor++
if len(it.data) > it.cursor {
return true
}
return false
}
// Value 返回迭代器当前 value
func (it *iter) Value() interface{} {
return it.data[it.cursor]
}
type avlNode struct {
h int
key int64
value interface{}
left *avlNode
right *avlNode
}
type aVLTree struct {
tree *avlNode
}
// NewTree 生成 AVL 树
func NewTree() List {
return &aVLTree{&avlNode{h: -2}}
}
func (a *aVLTree) Add(k int64, v interface{}) {
a.tree = insert(k, v, a.tree)
}
func (a *aVLTree) Remove(k int64) bool {
if a.tree.search(k) {
a.tree.delete(k)
return true
}
return false
}
func (a *aVLTree) All() Iter {
return a.tree.values(0, math.MaxInt64)
}
func (a *aVLTree) Range(lower, upper int64) Iter {
return a.tree.values(lower, upper)
}
func (a *aVLTree) getMaxValue() int64 {
return a.tree.maxNode().key
}
func (a *aVLTree) getMinValue() int64 {
return a.tree.minNode().key
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
func insert(k int64, v interface{}, t *avlNode) *avlNode {
if t == nil {
return &avlNode{key: k, value: v}
}
if t.h == -2 {
t.key = k
t.value = v
t.h = 0
return t
}
cmp := k - t.key
if cmp > 0 {
// 将节点插入到右子树中
t.right = insert(k, v, t.right)
} else if cmp < 0 {
// 将节点插入到左子树中
t.left = insert(k, v, t.left)
} else if cmp == 0 {
t.value = v
}
// 维持树平衡
t = t.keepBalance(k)
t.h = max(t.left.height(), t.right.height()) + 1
return t
}
func (t *avlNode) search(k int64) bool {
if t == nil {
return false
}
cmp := k - t.key
if cmp > 0 {
// 如果 v 大于当前节点值,继续从右子树中寻找
return t.right.search(k)
} else if cmp < 0 {
// 如果 v 小于当前节点值,继续从左子树中寻找
return t.left.search(k)
} else {
// 相等则表示找到
return true
}
}
func (t *avlNode) delete(k int64) *avlNode {
if t == nil {
return t
}
cmp := k - t.key
if cmp > 0 {
// 如果 v 大于当前节点值,继续从右子树中删除
t.right = t.right.delete(k)
} else if cmp < 0 {
// 如果 v 小于当前节点值,继续从左子树中删除
t.left = t.left.delete(k)
} else {
// 找到 v
if t.left != nil && t.right != nil {
// 如果该节点既有左子树又有右子树
// 使用右子树中的最小节点取代删除节点,然后删除右子树中的最小节点
minnode := t.right.minNode()
t.key = minnode.key
t.value = minnode.value
t.right = t.right.delete(t.key)
} else if t.left != nil {
// 如果只有左子树,则直接删除节点
t = t.left
} else {
// 只有右子树或空树
t = t.right
}
}
if t != nil {
t.h = max(t.left.height(), t.right.height()) + 1
t = t.keepBalance(k)
}
return t
}
func (t *avlNode) minNode() *avlNode {
if t == nil {
return nil
}
// 整棵树的最左边节点就是值最小的节点
if t.left == nil {
return t
} else {
return t.left.minNode()
}
}
func (t *avlNode) maxNode() *avlNode {
if t == nil {
return nil
}
// 整棵树的最右边节点就是值最大的节点
if t.right == nil {
return t
} else {
return t.right.maxNode()
}
}
/*
左左情况:右旋
*
*
*
*/
func (t *avlNode) llRotate() *avlNode {
node := t.left
t.left = node.right
node.right = t
node.h = max(node.left.height(), node.right.height()) + 1
t.h = max(t.left.height(), t.right.height()) + 1
return node
}
/*
右右情况:左旋
*
*
*
*/
func (t *avlNode) rrRotate() *avlNode {
node := t.right
t.right = node.left
node.left = t
node.h = max(node.left.height(), node.right.height()) + 1
t.h = max(t.left.height(), t.right.height()) + 1
return node
}
/*
左右情况:先左旋 后右旋
*
*
*
*/
func (t *avlNode) lrRotate() *avlNode {
t.left = t.left.rrRotate()
return t.llRotate()
}
/*
右左情况:先右旋 后左旋
*
*
*
*/
func (t *avlNode) rlRotate() *avlNode {
t.right = t.right.llRotate()
return t.rrRotate()
}
func (t *avlNode) keepBalance(k int64) *avlNode {
// 左子树失衡
if t.left.height()-t.right.height() == 2 {
if k-t.left.key < 0 {
// 当插入的节点在失衡节点的左子树的左子树中,直接右旋
t = t.llRotate()
} else {
// 当插入的节点在失衡节点的左子树的右子树中,先左旋后右旋
t = t.lrRotate()
}
} else if t.right.height()-t.left.height() == 2 {
if t.right.right.height() > t.right.left.height() {
// 当插入的节点在失衡节点的右子树的右子树中,直接左旋
t = t.rrRotate()
} else {
// 当插入的节点在失衡节点的右子树的左子树中,先右旋后左旋
t = t.rlRotate()
}
}
// 调整树高度
t.h = max(t.left.height(), t.right.height()) + 1
return t
}
func (t *avlNode) height() int {
if t != nil {
return t.h
}
return -1
}
// appendValue 中序遍历按顺序获取所有值
func appendValue(values []interface{}, lower, upper int64, t *avlNode) []interface{} {
if t != nil {
values = appendValue(values, lower, upper, t.left)
if t.key >= lower && t.key <= upper {
values = append(values, t.value)
}
values = appendValue(values, lower, upper, t.right)
}
return values
}
func (t *avlNode) values(lower, upper int64) Iter {
it := &iter{data: []interface{}{nil}}
it.data = appendValue(it.data, lower, upper, t)
return it
}