Skip to content

Commit

Permalink
Add support for iptables in nftables mode.
Browse files Browse the repository at this point in the history
Iptables also has the ability to work in nftables mode, where it is
supposed to act like iptables but use the nftables subsystem.
Unfortunately, it isn't exactly the same.

The biggest difference is that counter output is iptables-save style,
rather than with "-c N N".

Also, improve some tests.
  • Loading branch information
squeed committed Aug 3, 2018
1 parent 25d087f commit 5c15b20
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 31 deletions.
97 changes: 76 additions & 21 deletions iptables/iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@ import (
// Adds the output of stderr to exec.ExitError
type Error struct {
exec.ExitError
cmd exec.Cmd
msg string
cmd exec.Cmd
msg string
exitStatus *int //for overriding
}

func (e *Error) ExitStatus() int {
if e.exitStatus != nil {
return *e.exitStatus
}
return e.Sys().(syscall.WaitStatus).ExitStatus()
}

Expand Down Expand Up @@ -65,6 +69,7 @@ type IPTables struct {
v1 int
v2 int
v3 int
mode string // the underlying iptables operating mode, e.g. nf_tables
}

// New creates a new IPTables.
Expand All @@ -81,12 +86,10 @@ func NewWithProtocol(proto Protocol) (*IPTables, error) {
return nil, err
}
vstring, err := getIptablesVersionString(path)
v1, v2, v3, err := extractIptablesVersion(vstring)
v1, v2, v3, mode, err := extractIptablesVersion(vstring)

checkPresent, waitPresent, randomFullyPresent := getIptablesCommandSupport(v1, v2, v3)

checkPresent, waitPresent, randomFullyPresent, err := getIptablesCommandSupport(v1, v2, v3)
if err != nil {
return nil, fmt.Errorf("error checking iptables version: %v", err)
}
ipt := IPTables{
path: path,
proto: proto,
Expand All @@ -96,6 +99,7 @@ func NewWithProtocol(proto Protocol) (*IPTables, error) {
v1: v1,
v2: v2,
v3: v3,
mode: mode,
}
return &ipt, nil
}
Expand Down Expand Up @@ -266,10 +270,27 @@ func (ipt *IPTables) executeList(args []string) ([]string, error) {
}

rules := strings.Split(stdout.String(), "\n")

// strip trailing newline
if len(rules) > 0 && rules[len(rules)-1] == "" {
rules = rules[:len(rules)-1]
}

// nftables mode doesn't return an error code when listing a non-existent
// chain. Patch that up.
if len(rules) == 0 && ipt.mode == "nf_tables" {
v := 1
return nil, &Error{
cmd: exec.Cmd{Args: args},
msg: "iptables: No chain/target/match by that name.",
exitStatus: &v,
}
}

for i, rule := range rules {
rules[i] = filterRuleOutput(rule)
}

return rules, nil
}

Expand All @@ -284,11 +305,18 @@ func (ipt *IPTables) NewChain(table, chain string) error {
func (ipt *IPTables) ClearChain(table, chain string) error {
err := ipt.NewChain(table, chain)

// the exit code for "this table already exists" is different for
// different iptables modes
existsErr := 1
if ipt.mode == "nf_tables" {
existsErr = 4
}

eerr, eok := err.(*Error)
switch {
case err == nil:
return nil
case eok && eerr.ExitStatus() == 1:
case eok && eerr.ExitStatus() == existsErr:
// chain already exists. Flush (clear) it.
return ipt.run("-t", table, "-F", chain)
default:
Expand Down Expand Up @@ -357,7 +385,7 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
if err := cmd.Run(); err != nil {
switch e := err.(type) {
case *exec.ExitError:
return &Error{*e, cmd, stderr.String()}
return &Error{*e, cmd, stderr.String(), nil}
default:
return err
}
Expand All @@ -376,36 +404,40 @@ func getIptablesCommand(proto Protocol) string {
}

// Checks if iptables has the "-C" and "--wait" flag
func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool, error) {

return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), iptablesHasRandomFully(v1, v2, v3), nil
func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool) {
return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), iptablesHasRandomFully(v1, v2, v3)
}

// getIptablesVersion returns the first three components of the iptables version.
// e.g. "iptables v1.3.66" would return (1, 3, 66, nil)
func extractIptablesVersion(str string) (int, int, int, error) {
versionMatcher := regexp.MustCompile("v([0-9]+)\\.([0-9]+)\\.([0-9]+)")
// getIptablesVersion returns the first three components of the iptables version
// and the operating mode (e.g. nf_tables or legacy)
// e.g. "iptables v1.3.66" would return (1, 3, 66, legacy, nil)
func extractIptablesVersion(str string) (int, int, int, string, error) {
versionMatcher := regexp.MustCompile(`v([0-9]+)\.([0-9]+)\.([0-9]+)(?:\s+\((\w+))?`)
result := versionMatcher.FindStringSubmatch(str)
if result == nil {
return 0, 0, 0, fmt.Errorf("no iptables version found in string: %s", str)
return 0, 0, 0, "", fmt.Errorf("no iptables version found in string: %s", str)
}

v1, err := strconv.Atoi(result[1])
if err != nil {
return 0, 0, 0, err
return 0, 0, 0, "", err
}

v2, err := strconv.Atoi(result[2])
if err != nil {
return 0, 0, 0, err
return 0, 0, 0, "", err
}

v3, err := strconv.Atoi(result[3])
if err != nil {
return 0, 0, 0, err
return 0, 0, 0, "", err
}

return v1, v2, v3, nil
mode := "legacy"
if result[4] != "" {
mode = result[4]
}
return v1, v2, v3, mode, nil
}

// Runs "iptables --version" to get the version string
Expand Down Expand Up @@ -473,3 +505,26 @@ func (ipt *IPTables) existsForOldIptables(table, chain string, rulespec []string
}
return strings.Contains(stdout.String(), rs), nil
}

// counterRegex is the regex used to detect nftables counter format
var counterRegex = regexp.MustCompile(`^\[([0-9]+):([0-9]+)\] `)

// filterRuleOutput works around some inconsistencies in output.
// For example, when iptables is in legacy vs. nftables mode, it produces
// different results.
func filterRuleOutput(rule string) string {
out := rule

// work around an output difference in nftables mode where counters
// are output in iptables-save format, rather than iptables -S format
// The string begins with "[0:0]"
//
// Fixes #49
if groups := counterRegex.FindStringSubmatch(out); groups != nil {
// drop the brackets
out = out[len(groups[0]):]
out = fmt.Sprintf("%s -c %s %s", out, groups[1], groups[2])
}

return out
}
98 changes: 90 additions & 8 deletions iptables/iptables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,10 @@ func mustTestableIptables() []*IPTables {
}

func TestChain(t *testing.T) {
for _, ipt := range mustTestableIptables() {
runChainTests(t, ipt)
for i, ipt := range mustTestableIptables() {
t.Run(fmt.Sprint(i), func(t *testing.T) {
runChainTests(t, ipt)
})
}
}

Expand Down Expand Up @@ -179,8 +181,10 @@ func runChainTests(t *testing.T, ipt *IPTables) {
}

func TestRules(t *testing.T) {
for _, ipt := range mustTestableIptables() {
runRulesTests(t, ipt)
for i, ipt := range mustTestableIptables() {
t.Run(fmt.Sprint(i), func(t *testing.T) {
runRulesTests(t, ipt)
})
}
}

Expand Down Expand Up @@ -265,12 +269,17 @@ func runRulesTests(t *testing.T, ipt *IPTables) {
t.Fatalf("ListWithCounters failed: %v", err)
}

suffix := " -c 0 0 -j ACCEPT"
if ipt.mode == "nf_tables" {
suffix = " -j ACCEPT -c 0 0"
}

expected = []string{
"-N " + chain,
"-A " + chain + " -s " + subnet1 + " -d " + address1 + " -c 0 0 -j ACCEPT",
"-A " + chain + " -s " + subnet2 + " -d " + address2 + " -c 0 0 -j ACCEPT",
"-A " + chain + " -s " + subnet2 + " -d " + address1 + " -c 0 0 -j ACCEPT",
"-A " + chain + " -s " + address1 + " -d " + subnet2 + " -c 0 0 -j ACCEPT",
"-A " + chain + " -s " + subnet1 + " -d " + address1 + suffix,
"-A " + chain + " -s " + subnet2 + " -d " + address2 + suffix,
"-A " + chain + " -s " + subnet2 + " -d " + address1 + suffix,
"-A " + chain + " -s " + address1 + " -d " + subnet2 + suffix,
}

if !reflect.DeepEqual(rules, expected) {
Expand Down Expand Up @@ -408,3 +417,76 @@ func TestIsNotExist(t *testing.T) {
t.Fatal("IsNotExist returned false, expected true")
}
}

func TestFilterRuleOutput(t *testing.T) {
testCases := []struct {
name string
in string
out string
}{
{
"legacy output",
"-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT",
"-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT",
},
{
"nft output",
"[99:42] -A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT",
"-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT -c 99 42",
},
}

for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
actual := filterRuleOutput(tt.in)
if actual != tt.out {
t.Fatalf("expect %s actual %s", tt.out, actual)
}
})
}
}

func TestExtractIptablesVersion(t *testing.T) {
testCases := []struct {
in string
v1, v2, v3 int
mode string
err bool
}{
{
"iptables v1.8.0 (nf_tables)",
1, 8, 0,
"nf_tables",
false,
},
{
"iptables v1.8.0 (legacy)",
1, 8, 0,
"legacy",
false,
},
{
"iptables v1.6.2",
1, 6, 2,
"legacy",
false,
},
}

for i, tt := range testCases {
t.Run(fmt.Sprint(i), func(t *testing.T) {
v1, v2, v3, mode, err := extractIptablesVersion(tt.in)
if err == nil && tt.err {
t.Fatal("expected err, got none")
} else if err != nil && !tt.err {
t.Fatalf("unexpected err %s", err)
}

if v1 != tt.v1 || v2 != tt.v2 || v3 != tt.v3 || mode != tt.mode {
t.Fatalf("expected %d %d %d %s, got %d %d %d %s",
tt.v1, tt.v2, tt.v3, tt.mode,
v1, v2, v3, mode)
}
})
}
}
7 changes: 5 additions & 2 deletions test
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,14 @@ split=(${TEST// / })
TEST=${split[@]/#/${REPO_PATH}/}

echo "Running tests..."
go test -i ${TEST}
bin=$(mktemp)

go test -c -o ${bin} ${COVER} -i ${TEST}
if [[ -z "$SUDO_PERMITTED" ]]; then
echo "Test aborted for safety reasons. Please set the SUDO_PERMITTED variable."
exit 1
fi

sudo -E bash -c "PATH=\$GOROOT/bin:\$PATH go test ${COVER} $@ ${TEST}"
sudo -E bash -c "${bin} $@ ${TEST}"
echo "Success"
rm "${bin}"

0 comments on commit 5c15b20

Please sign in to comment.