-
Notifications
You must be signed in to change notification settings - Fork 4
/
tree_traversal.go
106 lines (90 loc) · 2.58 KB
/
tree_traversal.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
package mbtree
// RSearch walks upward from node id through its ancestors, following each
// node's Pid link, and streams the visited ids over the returned channel.
// The producer goroutine closes the channel once the walk ends.
//
// Emission rules, as implemented below:
//   - if id is t.RootId, or GetNode(id) returns nil, nothing is emitted;
//   - each non-root node on the path (starting with id itself) is emitted
//     only when the filter built by getFilterFromArgs(args...) accepts it;
//   - when the walk reaches t.RootId and that node exists, the root id is
//     emitted without being filtered and the walk ends;
//   - if a parent id does not resolve to a node, the walk ends silently.
//
// The walk only stops at the root or at a missing node, so a parent chain
// that cycles without reaching either never finishes.
//
// The channel has t.MaxTreeLevel buffer slots. A receiver that stops
// reading early therefore does not strand the producer, as long as no more
// than that many ids are still unsent. Beyond that, the goroutine blocks
// forever.
func (t *SafeMultiBranchTree) RSearch(id int64, args ...FilterFunc) chan int64 {
	out := make(chan int64, t.MaxTreeLevel)
	go func() {
		defer close(out)
		if id == t.RootId {
			return
		}
		cur := t.GetNode(id)
		if cur == nil {
			return
		}
		accept := getFilterFromArgs(args...)
		// Advance to the parent after each step; a parent that cannot be
		// resolved ends the walk.
		for curID := id; cur != nil; cur = t.GetNode(curID) {
			if curID == t.RootId {
				out <- t.RootId
				return
			}
			if accept(cur) {
				out <- curID
			}
			curID = cur.Pid
		}
	}()
	return out
}
// DepthFirstTraversal streams the ids of the subtree rooted at nid in
// depth-first pre-order: each node is emitted before its children.
// args optionally supplies a FilterFunc, resolved via getFilterFromArgs,
// that decides which nodes are emitted. The returned channel is closed once
// the walk completes, and it should be drained fully.
func (t *SafeMultiBranchTree) DepthFirstTraversal(nid int64, args ...FilterFunc) chan int64 {
	ids := t.traversal(nid, DEPTH, args...)
	return ids
}
// WidthFirstTraversal streams the ids of the subtree rooted at nid in
// breadth-first (level) order: siblings are emitted before any of their
// descendants. args optionally supplies a FilterFunc, resolved via
// getFilterFromArgs, that decides which nodes are emitted. The returned
// channel is closed once the walk completes, and it should be drained fully.
func (t *SafeMultiBranchTree) WidthFirstTraversal(nid int64, args ...FilterFunc) chan int64 {
	ids := t.traversal(nid, WIDTH, args...)
	return ids
}
// traversal walks the subtree rooted at nid and streams the ids of the
// visited nodes over the returned channel. The producer goroutine closes the
// channel when the walk finishes.
//
// mode selects the visiting order:
//   - DEPTH yields a depth-first pre-order walk;
//   - WIDTH yields a breadth-first (level-order) walk;
//   - any other mode yields an empty, already-closed channel.
//
// The optional filter (built by getFilterFromArgs) decides what is emitted.
// The start node is emitted only if it passes the filter, but its children
// are examined either way. Every node's children go through
// FilterNodesWithin before being queued. Assuming that helper keeps only the
// nodes the filter accepts (TODO confirm), a rejected child is skipped
// together with its whole subtree.
//
// The channel is unbuffered and the goroutine has no cancellation path.
// A receiver that stops reading before the channel is closed leaves the
// goroutine blocked forever.
func (t *SafeMultiBranchTree) traversal(nid int64, mode TraversalMode, args ...FilterFunc) chan int64 {
	c := make(chan int64)
	go func() {
		defer close(c)
		// Any other mode has no rule for advancing the work list; without
		// this guard the loop would re-send the same id forever.
		if mode != DEPTH && mode != WIDTH {
			return
		}
		node := t.GetNode(nid)
		if node == nil {
			return
		}
		filter := getFilterFromArgs(args...)
		if filter(node) {
			c <- node.Id
		}
		children := t.FilterNodesWithin(node.ChildIds, filter)
		// children[:0:0] is empty with zero capacity. The appends below
		// therefore always allocate a private backing array, rather than
		// writing past the end of a slice returned by FilterNodesWithin.
		pending := children[:0:0]

		if mode == DEPTH {
			// LIFO stack with its top at the end. Children are pushed in
			// reverse so the first child is popped next, which preserves
			// pre-order. Pushing and popping at the tail is O(1) per step,
			// whereas prepending children copied the whole pending list.
			for {
				for i := len(children) - 1; i >= 0; i-- {
					pending = append(pending, children[i])
				}
				if len(pending) == 0 {
					return
				}
				top := len(pending) - 1
				next := pending[top]
				pending = pending[:top]
				c <- next.Id
				children = t.FilterNodesWithin(next.ChildIds, filter)
			}
		}

		// WIDTH: FIFO queue; each node's children wait behind the nodes
		// already queued.
		pending = append(pending, children...)
		for len(pending) > 0 {
			next := pending[0]
			pending = pending[1:]
			c <- next.Id
			pending = append(pending, t.FilterNodesWithin(next.ChildIds, filter)...)
		}
	}()
	return c
}