Skip to content

Commit c3a8e63

Browse files
Maisem Alimaisem
authored andcommitted
util/linuxfw: add additional nftable detection logic
We were previously using the netlink API to see if there are chains/rules that already exist. This works fine in environments where there is either full nftable support or no support at all. However, we have identified certain environments which have partial nftable support and the only feasible way of detecting such an environment is to try to create some of the chains that we need. This adds a check to create a dummy postrouting chain which is immediately deleted. The goal of the check is to ensure we are able to use nftables and that it won't error out later. This check is only done in the path where we detected that the system has no preexisting nftable rules. Updates tailscale#5621 Updates tailscale#8555 Updates tailscale#8762 Signed-off-by: Maisem Ali <maisem@tailscale.com>
1 parent b47cf04 commit c3a8e63

File tree

3 files changed

+119
-20
lines changed

3 files changed

+119
-20
lines changed

util/linuxfw/nftables.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,29 @@ func DebugNetfilter(logf logger.Logf) error {
105105

106106
// detectNetfilter returns the number of nftables rules present in the system.
107107
func detectNetfilter() (int, error) {
108+
// Frist try creating a dummy postrouting chain. Emperically, we have
109+
// noticed that on some devices there is partial nftables support and the
110+
// kernel rejects some chains that are valid on other devices. This is a
111+
// workaround to detect that case.
112+
//
113+
// This specifically allows us to run in on GKE nodes using COS images which
114+
// have partial nftables support (as of 2023-10-18). When we try to create a
115+
// dummy postrouting chain, we get an error like:
116+
// add chain: conn.Receive: netlink receive: no such file or directory
117+
nft, err := newNfTablesRunner(logger.Discard)
118+
if err != nil {
119+
return 0, FWModeNotSupportedError{
120+
Mode: FirewallModeNfTables,
121+
Err: fmt.Errorf("cannot create nftables runner: %w", err),
122+
}
123+
}
124+
if err := nft.createDummyPostroutingChains(); err != nil {
125+
return 0, FWModeNotSupportedError{
126+
Mode: FirewallModeNfTables,
127+
Err: err,
128+
}
129+
}
130+
108131
conn, err := nftables.New()
109132
if err != nil {
110133
return 0, FWModeNotSupportedError{
@@ -129,6 +152,7 @@ func detectNetfilter() (int, error) {
129152
}
130153
validRules += len(rules)
131154
}
155+
132156
return validRules, nil
133157
}
134158

util/linuxfw/nftables_runner.go

Lines changed: 75 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"golang.org/x/sys/unix"
2121
"tailscale.com/net/tsaddr"
2222
"tailscale.com/types/logger"
23+
"tailscale.com/types/ptr"
2324
)
2425

2526
const (
@@ -316,8 +317,33 @@ func (n *nftablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error {
316317
return n.conn.Flush()
317318
}
318319

319-
// createTableIfNotExist creates a nftables table via connection c if it does not exist within the given family.
320-
func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
320+
// deleteTableIfExists deletes a nftables table via connection c if it exists
321+
// within the given family.
322+
func deleteTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) error {
323+
t, err := getTableIfExists(c, family, name)
324+
if err != nil {
325+
return fmt.Errorf("get table: %w", err)
326+
}
327+
if t == nil {
328+
// Table does not exist, so nothing to delete.
329+
return nil
330+
}
331+
c.DelTable(t)
332+
if err := c.Flush(); err != nil {
333+
if t, err = getTableIfExists(c, family, name); t == nil && err == nil {
334+
// Check if the table still exists. If it does not, then the error
335+
// is due to the table not existing, so we can ignore it. Maybe a
336+
// concurrent process deleted the table.
337+
return nil
338+
}
339+
return fmt.Errorf("del table: %w", err)
340+
}
341+
return nil
342+
}
343+
344+
// getTableIfExists returns the table with the given name from the given family
345+
// if it exists. If none match, it returns (nil, nil).
346+
func getTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
321347
tables, err := c.ListTables()
322348
if err != nil {
323349
return nil, fmt.Errorf("get tables: %w", err)
@@ -327,7 +353,17 @@ func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name s
327353
return table, nil
328354
}
329355
}
356+
return nil, nil
357+
}
330358

359+
// createTableIfNotExist creates a nftables table via connection c if it does
360+
// not exist within the given family.
361+
func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
362+
if t, err := getTableIfExists(c, family, name); err != nil {
363+
return nil, fmt.Errorf("get table: %w", err)
364+
} else if t != nil {
365+
return t, nil
366+
}
331367
t := c.AddTable(&nftables.Table{
332368
Family: family,
333369
Name: name,
@@ -365,24 +401,6 @@ func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*n
365401
return nil, errorChainNotFound{table.Name, name}
366402
}
367403

368-
// getChainsFromTable returns all chains from the given table.
369-
func getChainsFromTable(c *nftables.Conn, table *nftables.Table) ([]*nftables.Chain, error) {
370-
chains, err := c.ListChainsOfTableFamily(table.Family)
371-
if err != nil {
372-
return nil, fmt.Errorf("list chains: %w", err)
373-
}
374-
375-
var ret []*nftables.Chain
376-
for _, chain := range chains {
377-
// Table family is already checked so table name is unique
378-
if chain.Table.Name == table.Name {
379-
ret = append(ret, chain)
380-
}
381-
}
382-
383-
return ret, nil
384-
}
385-
386404
// isTSChain reports whether `name` begins with "ts-" (and is thus a
387405
// Tailscale-managed chain).
388406
func isTSChain(name string) bool {
@@ -804,6 +822,43 @@ func (n *nftablesRunner) AddChains() error {
804822
return n.conn.Flush()
805823
}
806824

825+
// These are dummy chains and tables we create to detect if nftables is
826+
// available. We create them, then delete them. If we can create and delete
827+
// them, then we can use nftables. If we can't, then we assume that we're
828+
// running on a system that doesn't support nftables. See
829+
// createDummyPostroutingChains.
830+
const (
831+
tsDummyChainName = "ts-test-postrouting"
832+
tsDummyTableName = "ts-test-nat"
833+
)
834+
835+
// createDummyPostroutingChains creates dummy postrouting chains in netfilter
836+
// via netfilter via nftables, as a last resort measure to detect that nftables
837+
// can be used. It cleans up the dummy chains after creation.
838+
func (n *nftablesRunner) createDummyPostroutingChains() (retErr error) {
839+
polAccept := ptr.To(nftables.ChainPolicyAccept)
840+
for _, table := range n.getNATTables() {
841+
nat, err := createTableIfNotExist(n.conn, table.Proto, tsDummyTableName)
842+
if err != nil {
843+
return fmt.Errorf("create nat table: %w", err)
844+
}
845+
defer func(fm nftables.TableFamily) {
846+
if err := deleteTableIfExists(n.conn, table.Proto, tsDummyTableName); err != nil && retErr == nil {
847+
retErr = fmt.Errorf("delete %q table: %w", tsDummyTableName, err)
848+
}
849+
}(table.Proto)
850+
851+
table.Nat = nat
852+
if err = createChainIfNotExist(n.conn, chainInfo{nat, tsDummyChainName, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, polAccept}); err != nil {
853+
return fmt.Errorf("create %q chain: %w", tsDummyChainName, err)
854+
}
855+
if err := deleteChainIfExists(n.conn, nat, tsDummyChainName); err != nil {
856+
return fmt.Errorf("delete %q chain: %w", tsDummyChainName, err)
857+
}
858+
}
859+
return nil
860+
}
861+
807862
// deleteChainIfExists deletes a chain if it exists.
808863
func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) error {
809864
chain, err := getChainFromTable(c, table, name)

util/linuxfw/nftables_runner_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,26 @@ func (t *testFWDetector) nftDetect() (int, error) {
851851
return t.nftRuleCount, t.nftErr
852852
}
853853

854+
// TestCreateDummyPostroutingChains tests that on a system with nftables
855+
// available, the function does not return an error and that the dummy
856+
// postrouting chains are cleaned up.
857+
func TestCreateDummyPostroutingChains(t *testing.T) {
858+
conn := newSysConn(t)
859+
runner := newFakeNftablesRunner(t, conn)
860+
if err := runner.createDummyPostroutingChains(); err != nil {
861+
t.Fatalf("createDummyPostroutingChains() failed: %v", err)
862+
}
863+
for _, table := range runner.getNATTables() {
864+
nt, err := getTableIfExists(conn, table.Proto, tsDummyTableName)
865+
if err != nil {
866+
t.Fatalf("getTableIfExists() failed: %v", err)
867+
}
868+
if nt != nil {
869+
t.Fatalf("expected table to be nil, got %v", nt)
870+
}
871+
}
872+
}
873+
854874
func TestPickFirewallModeFromInstalledRules(t *testing.T) {
855875
tests := []struct {
856876
name string

0 commit comments

Comments
 (0)