-
Notifications
You must be signed in to change notification settings - Fork 71
/
locked_iptables.go
153 lines (125 loc) · 4.4 KB
/
locked_iptables.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package rules
import (
"fmt"
"os/exec"
"strings"
)
//go:generate counterfeiter -o ../fakes/iptables.go --fake-name IPTables . iptables

// iptables is the argument-list rule API that LockedIPTables delegates to.
// A rule is passed as separate arguments (for example "-s", "10.0.0.1",
// "-j", "ACCEPT"). Implementations need not lock themselves: LockedIPTables
// holds its Locker around every call it makes through this interface.
//
// NOTE(review): the per-method descriptions below are presumed to match the
// concrete implementation that is wired in (likely go-iptables) — confirm.
type iptables interface {
	// Exists reports whether rulespec is present in table/chain.
	Exists(table string, chain string, rulespec ...string) (bool, error)
	// Insert places rulespec at position pos of table/chain.
	Insert(table string, chain string, pos int, rulespec ...string) error
	// AppendUnique appends rulespec to table/chain unless it already exists.
	AppendUnique(table string, chain string, rulespec ...string) error
	// Delete removes rulespec from table/chain.
	Delete(table string, chain string, rulespec ...string) error
	// List returns the rules of table/chain, one string per rule.
	List(table string, chain string) ([]string, error)
	// NewChain creates chain in table.
	NewChain(table string, chain string) error
	// ClearChain removes all rules from chain in table.
	ClearChain(table string, chain string) error
	// DeleteChain removes chain from table.
	DeleteChain(table string, chain string) error
}
//go:generate counterfeiter -o ../fakes/iptables_extended.go --fake-name IPTablesAdapter . IPTablesAdapter

// IPTablesAdapter is the rule-oriented iptables API offered to callers.
// Unlike the internal iptables interface, it takes whole IPTablesRule values
// rather than loose argument lists. It also adds BulkInsert and BulkAppend
// for applying many rules to one chain in a single call. LockedIPTables in
// this package has a matching method set.
type IPTablesAdapter interface {
	// Exists reports whether rulespec is present in table/chain.
	Exists(table string, chain string, rulespec IPTablesRule) (bool, error)
	// Delete removes rulespec from table/chain.
	Delete(table string, chain string, rulespec IPTablesRule) error
	// List returns the rules of table/chain, one string per rule.
	List(table string, chain string) ([]string, error)
	// NewChain creates chain in table.
	NewChain(table string, chain string) error
	// ClearChain removes all rules from chain in table.
	ClearChain(table string, chain string) error
	// DeleteChain removes chain from table.
	DeleteChain(table string, chain string) error
	// BulkInsert inserts every rule in rulespec at position pos of table/chain.
	BulkInsert(table string, chain string, pos int, rulespec ...IPTablesRule) error
	// BulkAppend appends every rule in rulespec to table/chain.
	BulkAppend(table string, chain string, rulespec ...IPTablesRule) error
}
//go:generate counterfeiter -o ../fakes/locker.go --fake-name Locker . locker

// locker guards the critical section around each iptables operation made by
// LockedIPTables. LockedIPTables calls Lock first. If Lock returns an error,
// it gives up without calling Unlock. Otherwise it calls Unlock once the
// operation finishes, whether or not the operation succeeded.
type locker interface {
	// Lock acquires the lock.
	Lock() (err error)
	// Unlock releases a lock previously acquired with Lock.
	Unlock() (err error)
}
//go:generate counterfeiter -o ../fakes/restorer.go --fake-name Restorer . restorer

// restorer applies a complete iptables-restore payload in a single call.
// Restorer in this package implements it by piping ruleState to
// `iptables-restore --noflush`.
type restorer interface {
	Restore(ruleState string) (err error)
}
// Restorer implements restorer by shelling out to iptables-restore. It runs
// with --noflush, so rules not mentioned in the payload are left in place
// rather than the affected tables being flushed first.
type Restorer struct{}

// Restore runs `iptables-restore --noflush` with input on stdin and waits for
// it to exit.
//
// On failure the returned error wraps the underlying exec error, so callers
// can match it with errors.Is/errors.As (for example exec.ErrNotFound or
// *exec.ExitError). The command's combined stdout/stderr is appended to the
// message for diagnosis.
func (r *Restorer) Restore(input string) error {
	cmd := exec.Command("iptables-restore", "--noflush")
	cmd.Stdin = strings.NewReader(input)
	output, err := cmd.CombinedOutput()
	if err != nil {
		return fmt.Errorf("iptables-restore error: %w combined output: %s", err, string(output))
	}
	return nil
}
// LockedIPTables serializes iptables access. Each method acquires Locker,
// performs its operation, and then releases Locker, including when the
// operation fails. Single-rule and chain operations go through IPTables;
// the bulk methods go through Restorer. Its method set matches
// IPTablesAdapter.
type LockedIPTables struct {
	// IPTables performs Exists, Delete, List and the chain operations.
	IPTables iptables
	// Locker is held for the duration of every operation.
	Locker locker
	// Restorer applies the iptables-restore payloads built for BulkInsert and
	// BulkAppend.
	Restorer restorer
}
// handleIPTablesError merges two errors into one:
//   - err1: the error from a failed iptables operation.
//   - err2: the result of the Unlock that followed it.
//
// The message has the form "iptables call: <err1> and unlock: <err2>"; a nil
// err2 renders as "<nil>". err1 is wrapped with %w so callers can still match
// it with errors.Is/errors.As. err2 only contributes text. err1 must be
// non-nil.
func handleIPTablesError(err1, err2 error) error {
	return fmt.Errorf("iptables call: %w and unlock: %+v", err1, err2)
}
// Exists asks IPTables whether rulespec is present in table/chain, holding
// the lock for the duration of the call.
//
// Error handling:
//   - Lock fails: the result is "lock: <err>" (wrapping err) and IPTables is
//     not called.
//   - The check fails: the result is false and the error combines the check
//     error with the Unlock result.
//   - Otherwise: the Unlock error, if any, is returned alongside the result.
func (l *LockedIPTables) Exists(table, chain string, rulespec IPTablesRule) (bool, error) {
	if err := l.Locker.Lock(); err != nil {
		return false, fmt.Errorf("lock: %w", err)
	}
	exists, err := l.IPTables.Exists(table, chain, rulespec...)
	if err != nil {
		return false, handleIPTablesError(err, l.Locker.Unlock())
	}
	return exists, l.Locker.Unlock()
}
// bulkAction applies every rule in rulespec to table in a single
// iptables-restore transaction, holding the lock only around the Restore call.
//
// The payload looks like:
//
//	*<table>
//	<prefix> <rule args joined by spaces>   (one line per rule)
//	COMMIT
//
// prefix is the action plus target, e.g. "-A CHAIN" or "-I CHAIN POS".
//
// Error handling:
//   - Lock fails: the result is "lock: <err>" (wrapping err) and nothing is
//     restored.
//   - Restore fails: its error is combined with the Unlock result.
//   - Otherwise: the Unlock error, if any, is returned.
func (l *LockedIPTables) bulkAction(table, prefix string, rulespec ...IPTablesRule) error {
	// The payload is line-oriented. A newline inside the table name, the
	// prefix, or a rule argument would split a directive across lines and
	// change what iptables-restore executes, so reject such input before
	// touching the lock.
	if strings.Contains(table, "\n") || strings.Contains(prefix, "\n") {
		return fmt.Errorf("iptables-restore input: newline in table %q or rule prefix %q", table, prefix)
	}

	var input strings.Builder
	fmt.Fprintf(&input, "*%s\n", table)
	for _, r := range rulespec {
		rule := strings.Join(r, " ")
		if strings.Contains(rule, "\n") {
			return fmt.Errorf("iptables-restore input: newline in rule %q", rule)
		}
		fmt.Fprintf(&input, "%s %s\n", prefix, rule)
	}
	input.WriteString("COMMIT\n")

	if err := l.Locker.Lock(); err != nil {
		return fmt.Errorf("lock: %w", err)
	}
	if err := l.Restorer.Restore(input.String()); err != nil {
		return handleIPTablesError(err, l.Locker.Unlock())
	}
	return l.Locker.Unlock()
}
// BulkInsert inserts every rule in rulespec at position pos of table/chain.
// All rules go through one iptables-restore transaction, using the
// "-I <chain> <pos>" action.
func (l *LockedIPTables) BulkInsert(table, chain string, pos int, rulespec ...IPTablesRule) error {
	insertAt := fmt.Sprintf("-I %s %d", chain, pos)
	return l.bulkAction(table, insertAt, rulespec...)
}
// BulkAppend appends every rule in rulespec to table/chain. All rules go
// through one iptables-restore transaction, using the "-A <chain>" action.
func (l *LockedIPTables) BulkAppend(table, chain string, rulespec ...IPTablesRule) error {
	appendTo := fmt.Sprintf("-A %s", chain)
	return l.bulkAction(table, appendTo, rulespec...)
}
// Delete asks IPTables to remove rulespec from table/chain, holding the lock
// for the duration of the call.
//
// Error handling:
//   - Lock fails: the result is "lock: <err>" (wrapping err) and IPTables is
//     not called.
//   - The delete fails: its error is combined with the Unlock result.
//   - Otherwise: the Unlock error, if any, is returned.
func (l *LockedIPTables) Delete(table, chain string, rulespec IPTablesRule) error {
	if err := l.Locker.Lock(); err != nil {
		return fmt.Errorf("lock: %w", err)
	}
	if err := l.IPTables.Delete(table, chain, rulespec...); err != nil {
		return handleIPTablesError(err, l.Locker.Unlock())
	}
	return l.Locker.Unlock()
}
// List returns the rules that IPTables reports for table/chain, holding the
// lock for the duration of the call.
//
// Error handling:
//   - Lock fails: the result is "lock: <err>" (wrapping err) and IPTables is
//     not called.
//   - Listing fails: the result is nil and the error combines the list error
//     with the Unlock result.
//   - Otherwise: the Unlock error, if any, is returned alongside the rules.
func (l *LockedIPTables) List(table, chain string) ([]string, error) {
	if err := l.Locker.Lock(); err != nil {
		return nil, fmt.Errorf("lock: %w", err)
	}
	rules, err := l.IPTables.List(table, chain)
	if err != nil {
		return nil, handleIPTablesError(err, l.Locker.Unlock())
	}
	return rules, l.Locker.Unlock()
}
// NewChain creates chain in table via IPTables.NewChain while holding the
// lock.
func (l *LockedIPTables) NewChain(table, chain string) error {
	create := l.IPTables.NewChain
	return l.chainExec(table, chain, create)
}
// ClearChain empties chain in table via IPTables.ClearChain while holding
// the lock.
func (l *LockedIPTables) ClearChain(table, chain string) error {
	clear := l.IPTables.ClearChain
	return l.chainExec(table, chain, clear)
}
// DeleteChain removes chain from table via IPTables.DeleteChain while
// holding the lock.
func (l *LockedIPTables) DeleteChain(table, chain string) error {
	remove := l.IPTables.DeleteChain
	return l.chainExec(table, chain, remove)
}
// chainExec runs a chain-level IPTables operation (e.g. NewChain, ClearChain
// or DeleteChain) on table/chain while holding the lock.
//
// Error handling:
//   - Lock fails: the result is "lock: <err>" (wrapping err) and action is
//     not run.
//   - action fails: its error is combined with the Unlock result.
//   - Otherwise: the Unlock error, if any, is returned.
func (l *LockedIPTables) chainExec(table, chain string, action func(string, string) error) error {
	if err := l.Locker.Lock(); err != nil {
		return fmt.Errorf("lock: %w", err)
	}
	if err := action(table, chain); err != nil {
		return handleIPTablesError(err, l.Locker.Unlock())
	}
	return l.Locker.Unlock()
}