-
Notifications
You must be signed in to change notification settings - Fork 0
/
retain_trie.go
153 lines (137 loc) · 3.08 KB
/
retain_trie.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package trie
import (
"strings"
"github.com/morgeq/iotfast/server/mqtt/retained"
gmqtt "github.com/morgeq/iotfast/server/mqtt"
)
type (
	// topicTrie is the root of a retained-message trie. It is an alias of
	// topicNode: the root is simply a node that has no parent.
	topicTrie = topicNode

	// children maps the name of the next topic level to its child node.
	children = map[string]*topicNode

	// topicNode is a single topic level in the trie.
	topicNode struct {
		children  children       // child nodes keyed by the next topic level
		msg       *gmqtt.Message // retained message stored at exactly this topic, or nil
		parent    *topicNode     // parent node; nil for a node built by newNode (the root)
		topicName string         // full topic name, recorded when addRetainMsg stores msg
	}
)
// newTopicTrie returns an empty trie: a root node with no message, no parent
// and no children.
func newTopicTrie() *topicTrie {
	root := newNode()
	return root
}
// newNode allocates a parentless node whose children map is initialized
// (non-nil) and empty, so children can be inserted into it straight away.
func newNode() *topicNode {
	n := new(topicNode)
	n.children = make(children)
	return n
}
// newChild allocates an empty node whose parent is t. It does not link the new
// node into t.children; the caller is responsible for that.
func (t *topicNode) newChild() *topicNode {
	child := &topicNode{parent: t}
	child.children = make(children)
	return child
}
// find walks the trie along the "/"-separated levels of topicName and returns
// the node for exactly that topic. Each level is looked up literally, so
// wildcards are not expanded.
//
// It returns nil when the path does not exist or when the node at the end of
// the path holds no retained message.
func (t *topicTrie) find(topicName string) *topicNode {
	pNode := t
	for _, lv := range strings.Split(topicName, "/") {
		// Single comma-ok lookup per level instead of check-then-index.
		child, ok := pNode.children[lv]
		if !ok {
			return nil
		}
		pNode = child
	}
	if pNode.msg == nil {
		return nil
	}
	return pNode
}
// matchTopic walks the trie below t and calls fn for every retained message
// whose topic matches the topic filter, given here already split into its
// levels (topicSlice).
//
// Filter levels are interpreted as follows:
//   - "#" matches the current node and every node below it;
//   - "+" matches any single child at the current level;
//   - anything else matches only the child with exactly that name.
//
// As with preOrderTraverse, the walk stops as soon as fn returns false.
// An empty topicSlice matches nothing.
func (t *topicTrie) matchTopic(topicSlice []string, fn retained.IterateFn) {
	t.matchTopicLevels(topicSlice, fn)
}

// matchTopicLevels implements matchTopic. It returns false once fn has asked to
// stop, so the stop signal propagates out of every level of the recursion.
func (t *topicTrie) matchTopicLevels(topicSlice []string, fn retained.IterateFn) bool {
	if len(topicSlice) == 0 {
		return true
	}
	last := len(topicSlice) == 1

	// visit handles a child that matched the current filter level: emit its
	// message when the filter is exhausted, otherwise descend one level.
	visit := func(n *topicNode) bool {
		if !last {
			return n.matchTopicLevels(topicSlice[1:], fn)
		}
		if n.msg == nil {
			return true
		}
		return fn(n.msg)
	}

	switch topicSlice[0] {
	case "#":
		return t.preOrderTraverse(fn)
	case "+":
		// Every child at the current level matches "+".
		for _, child := range t.children {
			if !visit(child) {
				return false
			}
		}
		return true
	default:
		if child, ok := t.children[topicSlice[0]]; ok {
			return visit(child)
		}
		return true
	}
}
// getMatchedMessages returns a copy (via Message.Copy) of every retained
// message whose topic matches topicFilter. It returns a nil slice when
// nothing matches.
func (t *topicTrie) getMatchedMessages(topicFilter string) []*gmqtt.Message {
	var matched []*gmqtt.Message
	collect := func(m *gmqtt.Message) bool {
		matched = append(matched, m.Copy())
		return true // never stop early: gather all matches
	}
	t.matchTopic(strings.Split(topicFilter, "/"), collect)
	return matched
}
// isSystemTopic reports whether topicName begins with '$', the prefix that
// marks system topics. The empty string is not a system topic.
func isSystemTopic(topicName string) bool {
	return strings.HasPrefix(topicName, "$")
}
// addRetainMsg stores message as the retained message for topicName. Missing
// nodes along the "/"-separated path are created on the way, and any message
// previously stored at that exact topic is replaced.
func (t *topicTrie) addRetainMsg(topicName string, message *gmqtt.Message) {
	node := t
	for _, level := range strings.Split(topicName, "/") {
		next, exists := node.children[level]
		if !exists {
			next = node.newChild()
			node.children[level] = next
		}
		node = next
	}
	node.msg, node.topicName = message, topicName
}
// remove deletes the retained message stored at exactly topicName, if any.
// A topic that is not present in the trie is ignored.
//
// After clearing the message, remove unlinks the node and then each ancestor
// that is left with neither a message nor children. This keeps deleted topics
// from leaving empty branches behind. The root node itself is never removed.
func (t *topicTrie) remove(topicName string) {
	topicSlice := strings.Split(topicName, "/")
	pNode := t
	for _, lv := range topicSlice {
		child, ok := pNode.children[lv]
		if !ok {
			return // topic not present: nothing to remove
		}
		pNode = child
	}
	pNode.msg = nil

	// Walk back towards the root. While the current node is at level i, its
	// key in its parent's children map is topicSlice[i]. Stop at the first node
	// that still holds a message or has children; it (and everything above it)
	// is still needed.
	for i := len(topicSlice) - 1; i >= 0; i-- {
		if pNode.msg != nil || len(pNode.children) > 0 {
			break
		}
		parent := pNode.parent
		if parent == nil {
			break // defensive: only the root should lack a parent
		}
		delete(parent.children, topicSlice[i])
		pNode = parent
	}
}
// preOrderTraverse calls fn for t's own message (if any) and then recursively
// for every descendant's message. Children are visited in map iteration order,
// which is unspecified.
//
// It returns false as soon as fn returns false, so the caller can stop too. It
// also returns false for a nil receiver. Otherwise it returns true after
// visiting the whole subtree.
func (t *topicTrie) preOrderTraverse(fn retained.IterateFn) bool {
	if t == nil {
		return false
	}
	if t.msg != nil && !fn(t.msg) {
		return false
	}
	for _, child := range t.children {
		if keepGoing := child.preOrderTraverse(fn); !keepGoing {
			return false
		}
	}
	return true
}