Skip to content

Commit

Permalink
Port filter rules to nftables
Browse files Browse the repository at this point in the history
  • Loading branch information
danwinship committed Oct 31, 2023
1 parent 6cff415 commit 0c5c620
Show file tree
Hide file tree
Showing 3 changed files with 329 additions and 57 deletions.
260 changes: 258 additions & 2 deletions pkg/proxy/nftables/helpers_test.go
Expand Up @@ -19,6 +19,7 @@ package nftables
import (
"context"
"fmt"
"regexp"
"runtime"
"sort"
"strings"
Expand All @@ -29,6 +30,8 @@ import (
"github.com/lithammer/dedent"

"k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/sets"
netutils "k8s.io/utils/net"
)

// getLine returns a string containing the file and line number of the caller, if
Expand All @@ -55,6 +58,15 @@ var objectOrder = map[string]int{
// anything else: 0
}

// For most chains we leave the rules in order (because the order matters), but for chains
// with per-service rules, we don't know what order syncProxyRules is going to output them
// in, but the order doesn't matter anyway. So we sort the rules in those chains.
var sortedChains = sets.New(
kubeServicesFilterChain,
kubeExternalServicesChain,
kubeFirewallChain,
)

// sortNFTablesTransaction sorts an nftables transaction into a standard order for comparison
func sortNFTablesTransaction(tx string) string {
lines := strings.Split(tx, "\n")
Expand Down Expand Up @@ -93,8 +105,13 @@ func sortNFTablesTransaction(tx string) string {
return wi[4] < wj[4]
}

// Leave rules in the order they were added in
if wi[1] == "rule" {
// Sort rules in chains that need to be sorted
if sortedChains.Has(wi[4]) {
return li < lj
}

// Otherwise leave rules in the order they were originally added.
return false
}

Expand Down Expand Up @@ -146,6 +163,224 @@ func assertNFTablesChainEqual(t *testing.T, line string, nft *knftables.Fake, ch
}
}

// nftablesTracer holds data used while virtually tracing a packet through a set of
// iptables rules
type nftablesTracer struct {
nft *knftables.Fake
nodeIPs sets.Set[string]
t *testing.T

// matches accumulates the list of rules that were matched, for debugging purposes.
matches []string

// outputs accumulates the list of matched terminal rule targets (endpoint
// IP:ports, or a special target like "REJECT") and is eventually used to generate
// the return value of tracePacket.
outputs []string

// markMasq tracks whether the packet has been marked for masquerading
markMasq bool
}

// newNFTablesTracer creates an nftablesTracer. nodeIPs are the IP to treat as local node
// IPs (for determining whether rules with "fib saddr type local" or "fib daddr type
// local" match).
func newNFTablesTracer(t *testing.T, nft *knftables.Fake, nodeIPs []string) *nftablesTracer {
return &nftablesTracer{
nft: nft,
nodeIPs: sets.New(nodeIPs...),
t: t,
}
}

func (tracer *nftablesTracer) addressMatches(ipStr, not, ruleAddress string) bool {
ip := netutils.ParseIPSloppy(ipStr)
if ip == nil {
tracer.t.Fatalf("Bad IP in test case: %s", ipStr)
}

var match bool
if strings.Contains(ruleAddress, "/") {
_, cidr, err := netutils.ParseCIDRSloppy(ruleAddress)
if err != nil {
tracer.t.Errorf("Bad CIDR in kube-proxy output: %v", err)
}
match = cidr.Contains(ip)
} else {
ip2 := netutils.ParseIPSloppy(ruleAddress)
if ip2 == nil {
tracer.t.Errorf("Bad IP/CIDR in kube-proxy output: %s", ruleAddress)
}
match = ip.Equal(ip2)
}

if not == "!= " {
return !match
} else {
return match
}
}

// We intentionally don't try to parse arbitrary nftables rules, as the syntax is quite
// complicated and context sensitive. (E.g., "ip daddr" could be the start of an address
// comparison, or it could be the start of a set/map lookup.) Instead, we just have
// regexps to recognize the specific pieces of rules that we create in proxier.go.
// Anything matching ignoredRegexp gets stripped out of the rule, and then what's left
// *must* match one of the cases in runChain or an error will be logged. In cases where
// the regexp doesn't end with `$`, and the matched rule succeeds against the input data,
// runChain will continue trying to match the rest of the rule. E.g., "ip daddr 10.0.0.1
// drop" would first match destAddrRegexp, and then (assuming destIP was "10.0.0.1") would
// match verdictRegexp.

var destAddrRegexp = regexp.MustCompile(`^ip6* daddr (!= )?(\S+)`)
var destAddrLocalRegexp = regexp.MustCompile(`^fib daddr type local`)
var destPortRegexp = regexp.MustCompile(`^(tcp|udp|sctp) dport (\d+)`)

var jumpRegexp = regexp.MustCompile(`^(jump|goto) (\S+)$`)
var verdictRegexp = regexp.MustCompile(`^(drop|reject)$`)

var ignoredRegexp = regexp.MustCompile(strings.Join(
[]string{
// Ignore comments (which can only appear at the end of a rule).
` *comment "[^"]*"$`,

// The trace tests only check new connections, so for our purposes, this
// check always succeeds (and thus can be ignored).
`^ct state new`,

// Likewise, this rule never matches and thus never drops anything, and so
// can be ignored.
`^ct state invalid drop$`,
},
"|",
))

// runChain runs the given packet through the rules in the given table and chain, updating
// tracer's internal state accordingly. It returns true if it hits a terminal action.
func (tracer *nftablesTracer) runChain(chname, sourceIP, protocol, destIP, destPort string) bool {
ch := tracer.nft.Table.Chains[chname]
if ch == nil {
tracer.t.Errorf("unknown chain %q", chname)
return true
}

for _, ruleObj := range ch.Rules {
rule := ignoredRegexp.ReplaceAllLiteralString(ruleObj.Rule, "")
for rule != "" {
rule = strings.TrimLeft(rule, " ")

switch {
case destAddrRegexp.MatchString(rule):
// `^ip6* daddr (!= )?(\S+)`
// Tests whether destIP does/doesn't match a literal.
match := destAddrRegexp.FindStringSubmatch(rule)
rule = strings.TrimPrefix(rule, match[0])
not, ip := match[1], match[2]
if !tracer.addressMatches(destIP, not, ip) {
rule = ""
break
}

case destAddrLocalRegexp.MatchString(rule):
// `^fib daddr type local`
// Tests whether destIP is a local IP.
match := destAddrLocalRegexp.FindStringSubmatch(rule)
rule = strings.TrimPrefix(rule, match[0])
if !tracer.nodeIPs.Has(destIP) {
rule = ""
break
}

case destPortRegexp.MatchString(rule):
// `^(tcp|udp|sctp) dport (\d+)`
// Tests whether destPort matches a literal.
match := destPortRegexp.FindStringSubmatch(rule)
rule = strings.TrimPrefix(rule, match[0])
proto, port := match[1], match[2]
if protocol != proto || destPort != port {
rule = ""
break
}

case jumpRegexp.MatchString(rule):
// `^(jump|goto) (\S+)$`
// Jumps to another chain.
match := jumpRegexp.FindStringSubmatch(rule)
rule = strings.TrimPrefix(rule, match[0])
action, destChain := match[1], match[2]

tracer.matches = append(tracer.matches, ruleObj.Rule)
terminated := tracer.runChain(destChain, sourceIP, protocol, destIP, destPort)
if terminated {
// destChain reached a terminal statement, so we
// terminate too.
return true
} else if action == "goto" {
// After a goto, return to our calling chain
// (without terminating) rather than continuing
// with this chain.
return false
}

case verdictRegexp.MatchString(rule):
// `^(drop|reject)$`
// Drop/reject the packet and terminate processing.
match := verdictRegexp.FindStringSubmatch(rule)
verdict := match[1]

tracer.matches = append(tracer.matches, ruleObj.Rule)
tracer.outputs = append(tracer.outputs, strings.ToUpper(verdict))
return true

default:
tracer.t.Errorf("unmatched rule: %s", ruleObj.Rule)
rule = ""
}
}
}

return false
}

// tracePacket determines what would happen to a packet with the given sourceIP, destIP,
// and destPort, given the indicated iptables ruleData. nodeIPs are the local node IPs (for
// rules matching "local"). (The protocol value should be lowercase as in nftables
// rules, not uppercase as in corev1.)
//
// The return values are: an array of matched rules (for debugging), the final packet
// destinations (a comma-separated list of IPs, or one of the special targets "ACCEPT",
// "DROP", or "REJECT"), and whether the packet would be masqueraded.
func tracePacket(t *testing.T, nft *knftables.Fake, sourceIP, protocol, destIP, destPort string, nodeIPs []string) ([]string, string, bool) {
tracer := newNFTablesTracer(t, nft, nodeIPs)

// Collect "base chains" (ie, the chains that are run by netfilter directly rather
// than only being run when they are jumped to). Skip postrouting because it only
// does masquerading and we handle that separately.
var baseChains []string
for chname, ch := range nft.Table.Chains {
if ch.Priority != nil && chname != "nat-postrouting" {
baseChains = append(baseChains, chname)
}
}

// Sort by priority
sort.Slice(baseChains, func(i, j int) bool {
// FIXME: IPv4 vs IPv6 doesn't actually matter here
iprio, _ := knftables.ParsePriority(knftables.IPv4Family, string(*nft.Table.Chains[baseChains[i]].Priority))
jprio, _ := knftables.ParsePriority(knftables.IPv4Family, string(*nft.Table.Chains[baseChains[j]].Priority))
return iprio < jprio
})

for _, chname := range baseChains {
terminated := tracer.runChain(chname, sourceIP, protocol, destIP, destPort)
if terminated {
break
}
}

return tracer.matches, strings.Join(tracer.outputs, ", "), tracer.markMasq
}

type packetFlowTest struct {
name string
sourceIP string
Expand All @@ -158,7 +393,28 @@ type packetFlowTest struct {

func runPacketFlowTests(t *testing.T, line string, nft *knftables.Fake, nodeIPs []string, testCases []packetFlowTest) {
for _, tc := range testCases {
t.Logf("Skipping test %s which doesn't work yet", tc.name)
if tc.output != "DROP" && tc.output != "REJECT" && tc.output != "" {
t.Logf("Skipping test %s which doesn't work yet", tc.name)
continue
}
t.Run(tc.name, func(t *testing.T) {
protocol := strings.ToLower(string(tc.protocol))
if protocol == "" {
protocol = "tcp"
}
matches, output, masq := tracePacket(t, nft, tc.sourceIP, protocol, tc.destIP, fmt.Sprintf("%d", tc.destPort), nodeIPs)
var errors []string
if output != tc.output {
errors = append(errors, fmt.Sprintf("wrong output: expected %q got %q", tc.output, output))
}
if masq != tc.masq {
errors = append(errors, fmt.Sprintf("wrong masq: expected %v got %v", tc.masq, masq))
}
if errors != nil {
t.Errorf("Test %q of a packet from %s to %s:%d%s got result:\n%s\n\nBy matching:\n%s\n\n",
tc.name, tc.sourceIP, tc.destIP, tc.destPort, line, strings.Join(errors, "\n"), strings.Join(matches, "\n"))
}
})
}
}

Expand Down

0 comments on commit 0c5c620

Please sign in to comment.