@@ -20,6 +20,7 @@ import (
2020 "golang.org/x/sys/unix"
2121 "tailscale.com/net/tsaddr"
2222 "tailscale.com/types/logger"
23+ "tailscale.com/types/ptr"
2324)
2425
2526const (
@@ -316,8 +317,33 @@ func (n *nftablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error {
316317 return n .conn .Flush ()
317318}
318319
319- // createTableIfNotExist creates a nftables table via connection c if it does not exist within the given family.
320- func createTableIfNotExist (c * nftables.Conn , family nftables.TableFamily , name string ) (* nftables.Table , error ) {
320+ // deleteTableIfExists deletes a nftables table via connection c if it exists
321+ // within the given family.
322+ func deleteTableIfExists (c * nftables.Conn , family nftables.TableFamily , name string ) error {
323+ t , err := getTableIfExists (c , family , name )
324+ if err != nil {
325+ return fmt .Errorf ("get table: %w" , err )
326+ }
327+ if t == nil {
328+ // Table does not exist, so nothing to delete.
329+ return nil
330+ }
331+ c .DelTable (t )
332+ if err := c .Flush (); err != nil {
333+ if t , err = getTableIfExists (c , family , name ); t == nil && err == nil {
334+ // Check if the table still exists. If it does not, then the error
335+ // is due to the table not existing, so we can ignore it. Maybe a
336+ // concurrent process deleted the table.
337+ return nil
338+ }
339+ return fmt .Errorf ("del table: %w" , err )
340+ }
341+ return nil
342+ }
343+
344+ // getTableIfExists returns the table with the given name from the given family
345+ // if it exists. If none match, it returns (nil, nil).
346+ func getTableIfExists (c * nftables.Conn , family nftables.TableFamily , name string ) (* nftables.Table , error ) {
321347 tables , err := c .ListTables ()
322348 if err != nil {
323349 return nil , fmt .Errorf ("get tables: %w" , err )
@@ -327,7 +353,17 @@ func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name s
327353 return table , nil
328354 }
329355 }
356+ return nil , nil
357+ }
330358
359+ // createTableIfNotExist creates a nftables table via connection c if it does
360+ // not exist within the given family.
361+ func createTableIfNotExist (c * nftables.Conn , family nftables.TableFamily , name string ) (* nftables.Table , error ) {
362+ if t , err := getTableIfExists (c , family , name ); err != nil {
363+ return nil , fmt .Errorf ("get table: %w" , err )
364+ } else if t != nil {
365+ return t , nil
366+ }
331367 t := c .AddTable (& nftables.Table {
332368 Family : family ,
333369 Name : name ,
@@ -365,24 +401,6 @@ func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*n
365401 return nil , errorChainNotFound {table .Name , name }
366402}
367403
368- // getChainsFromTable returns all chains from the given table.
369- func getChainsFromTable (c * nftables.Conn , table * nftables.Table ) ([]* nftables.Chain , error ) {
370- chains , err := c .ListChainsOfTableFamily (table .Family )
371- if err != nil {
372- return nil , fmt .Errorf ("list chains: %w" , err )
373- }
374-
375- var ret []* nftables.Chain
376- for _ , chain := range chains {
377- // Table family is already checked so table name is unique
378- if chain .Table .Name == table .Name {
379- ret = append (ret , chain )
380- }
381- }
382-
383- return ret , nil
384- }
385-
386404// isTSChain reports whether `name` begins with "ts-" (and is thus a
387405// Tailscale-managed chain).
388406func isTSChain (name string ) bool {
@@ -804,6 +822,43 @@ func (n *nftablesRunner) AddChains() error {
804822 return n .conn .Flush ()
805823}
806824
825+ // These are dummy chains and tables we create to detect if nftables is
826+ // available. We create them, then delete them. If we can create and delete
827+ // them, then we can use nftables. If we can't, then we assume that we're
828+ // running on a system that doesn't support nftables. See
829+ // createDummyPostroutingChains.
830+ const (
831+ tsDummyChainName = "ts-test-postrouting"
832+ tsDummyTableName = "ts-test-nat"
833+ )
834+
835+ // createDummyPostroutingChains creates dummy postrouting chains in netfilter
836+ // via netfilter via nftables, as a last resort measure to detect that nftables
837+ // can be used. It cleans up the dummy chains after creation.
838+ func (n * nftablesRunner ) createDummyPostroutingChains () (retErr error ) {
839+ polAccept := ptr .To (nftables .ChainPolicyAccept )
840+ for _ , table := range n .getNATTables () {
841+ nat , err := createTableIfNotExist (n .conn , table .Proto , tsDummyTableName )
842+ if err != nil {
843+ return fmt .Errorf ("create nat table: %w" , err )
844+ }
845+ defer func (fm nftables.TableFamily ) {
846+ if err := deleteTableIfExists (n .conn , table .Proto , tsDummyTableName ); err != nil && retErr == nil {
847+ retErr = fmt .Errorf ("delete %q table: %w" , tsDummyTableName , err )
848+ }
849+ }(table .Proto )
850+
851+ table .Nat = nat
852+ if err = createChainIfNotExist (n .conn , chainInfo {nat , tsDummyChainName , nftables .ChainTypeNAT , nftables .ChainHookPostrouting , nftables .ChainPriorityNATSource , polAccept }); err != nil {
853+ return fmt .Errorf ("create %q chain: %w" , tsDummyChainName , err )
854+ }
855+ if err := deleteChainIfExists (n .conn , nat , tsDummyChainName ); err != nil {
856+ return fmt .Errorf ("delete %q chain: %w" , tsDummyChainName , err )
857+ }
858+ }
859+ return nil
860+ }
861+
807862// deleteChainIfExists deletes a chain if it exists.
808863func deleteChainIfExists (c * nftables.Conn , table * nftables.Table , name string ) error {
809864 chain , err := getChainFromTable (c , table , name )
0 commit comments