-
Notifications
You must be signed in to change notification settings - Fork 0
/
rule.go
175 lines (143 loc) · 4.19 KB
/
rule.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
//go:build linux
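/*
Package rule provides helpers for managing nftables rules, identified by an ID
stored in each rule's UserData, within a specific table and chain.
*/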
package rule
import (
"bytes"
"fmt"
"github.com/google/nftables"
)
// RuleTarget identifies where rules are managed: one chain inside one nftables
// table. Build it with NewRuleTarget; its methods read and modify the rules of
// that chain.
type RuleTarget struct {
	table *nftables.Table // table that the managed chain belongs to
	chain *nftables.Chain // chain whose rules are listed, added and deleted
}
// NewRuleTarget returns a RuleTarget that operates on the rules of chain within
// table. Both pointers are stored as given; they are not copied.
func NewRuleTarget(table *nftables.Table, chain *nftables.Chain) RuleTarget {
	var rt RuleTarget
	rt.table = table
	rt.chain = chain
	return rt
}
// Add hands ruleData to c as a new rule in the target table and chain, unless a
// rule whose UserData equals ruleData.ID is already present.
//
// It returns true when the rule was passed to c. It returns false with a nil
// error when a matching rule already exists, and false with the error when the
// existence check fails.
// NOTE(review): the rule is submitted via c.AddRule; presumably it only takes
// effect once the caller flushes c — confirm against the nftables Conn docs.
func (r *RuleTarget) Add(c *nftables.Conn, ruleData RuleData) (bool, error) {
	present, err := r.Exists(c, ruleData)
	if err != nil || present {
		// Either the lookup failed, or there is nothing to do (err is nil then).
		return false, err
	}
	add(c, r.table, r.chain, ruleData)
	return true, nil
}
// add builds an nftables rule from ruleData for the given table and chain and
// passes it to c.AddRule. The ID goes into UserData, which is the field
// findRuleByID and genRuleDelta match on.
func add(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, ruleData RuleData) {
	newRule := &nftables.Rule{
		Table:    table,
		Chain:    chain,
		Exprs:    ruleData.Expressions,
		UserData: ruleData.ID,
	}
	c.AddRule(newRule)
}
// Delete removes the rule whose UserData equals ruleData.ID from the target
// table and chain.
//
// It returns true when a matching rule was found and c.DelRule accepted it. It
// returns false with a nil error when no rule carries that ID. Errors from
// c.GetRules or c.DelRule are returned unchanged.
func (r *RuleTarget) Delete(c *nftables.Conn, ruleData RuleData) (bool, error) {
	rules, err := c.GetRules(r.table, r.chain)
	if err != nil {
		return false, err
	}
	rule := findRuleByID(ruleData.ID, rules)
	// findRuleByID reports "not found" with a zero-value rule whose Table is
	// nil. Test the pointer itself (as Exists does) rather than Table.Name,
	// which would dereference nil and panic for a missing rule.
	if rule == nil || rule.Table == nil {
		return false, nil
	}
	if err := c.DelRule(rule); err != nil {
		return false, err
	}
	return true, nil
}
// Exists reports whether the target table and chain contain a rule whose
// UserData equals ruleData.ID. Errors from c.GetRules are returned unchanged.
func (r *RuleTarget) Exists(c *nftables.Conn, ruleData RuleData) (bool, error) {
	current, err := c.GetRules(r.table, r.chain)
	if err != nil {
		return false, err
	}
	// findRuleByID yields a zero-value rule (nil Table) when nothing matches.
	match := findRuleByID(ruleData.ID, current)
	return match.Table != nil, nil
}
// Compare existing and incoming rule IDs adding/removing the difference
//
// First return value is true if the number of rules has changed, false if there were no updates. The second
// and third return values indicate the number of rules added or removed, respectively.
func (r *RuleTarget) Update(c *nftables.Conn, rules []RuleData) (bool, int, int, error) {
var modified bool
existingRules, err := c.GetRules(r.table, r.chain)
if err != nil {
return false, 0, 0, fmt.Errorf("error getting existing rules for update: %v", err)
}
addRDList, removeRDList := genRuleDelta(existingRules, rules)
if len(removeRDList) > 0 {
for _, rule := range removeRDList {
err := c.DelRule(rule)
if err != nil {
return false, 0, 0, err
}
modified = true
}
}
if len(addRDList) > 0 {
for _, rule := range addRDList {
add(c, r.table, r.chain, rule)
modified = true
}
}
return modified, len(addRDList), len(removeRDList), nil
}
// GetTableAndChain returns the table and chain pointers this RuleTarget was
// created with.
func (r *RuleTarget) GetTableAndChain() (*nftables.Table, *nftables.Chain) {
	t, ch := r.table, r.chain
	return t, ch
}
// Get returns one RuleData per rule currently in the target table and chain,
// in the order c.GetRules reports them. ID is taken from each rule's UserData
// and Expressions from its Exprs; both slices are shared, not copied.
func (r *RuleTarget) Get(c *nftables.Conn) ([]RuleData, error) {
	rules, err := c.GetRules(r.table, r.chain)
	if err != nil {
		return nil, err
	}
	out := make([]RuleData, 0, len(rules))
	for _, rl := range rules {
		out = append(out, RuleData{Expressions: rl.Exprs, ID: rl.UserData})
	}
	return out, nil
}
// genRuleDelta computes the changes that turn existingRules into the set
// described by newRules, matching rules by ID (existing UserData vs RuleData.ID).
//
// The rules are as follows:
//   - Every existing rule whose ID is absent from newRules is removed.
//   - When several existing rules share a wanted ID, the first is kept and the
//     rest are removed.
//   - Each wanted ID with no existing rule is added once, taken from its first
//     occurrence in newRules.
//
// Both result slices follow the order of their inputs, so the output is
// deterministic.
func genRuleDelta(existingRules []*nftables.Rule, newRules []RuleData) (add []RuleData, remove []*nftables.Rule) {
	wanted := make(map[string]struct{}, len(newRules))
	for _, rd := range newRules {
		wanted[string(rd.ID)] = struct{}{}
	}

	// present holds IDs for which exactly one rule is (or will be) in the chain.
	present := make(map[string]struct{}, len(existingRules))
	for _, rule := range existingRules {
		id := string(rule.UserData)
		if _, ok := wanted[id]; !ok {
			remove = append(remove, rule)
			continue
		}
		if _, dup := present[id]; dup {
			// A rule with this ID is already kept; drop the duplicate so it
			// cannot linger untracked.
			remove = append(remove, rule)
			continue
		}
		present[id] = struct{}{}
	}

	for _, rd := range newRules {
		id := string(rd.ID)
		if _, ok := present[id]; ok {
			continue
		}
		// Mark as present so duplicate IDs in newRules are added only once.
		present[id] = struct{}{}
		add = append(add, rd)
	}
	return add, remove
}
// findRuleByID returns the first rule in rules whose UserData equals id
// byte-for-byte. When none matches it returns a pointer to a zero-value rule
// rather than nil; callers detect "not found" by that rule's nil Table.
func findRuleByID(id []byte, rules []*nftables.Rule) *nftables.Rule {
	for i := range rules {
		if candidate := rules[i]; bytes.Equal(id, candidate.UserData) {
			return candidate
		}
	}
	return &nftables.Rule{}
}