Skip to content

Commit

Permalink
Merge pull request #3 from ngrok/fixes
Browse files Browse the repository at this point in the history
Refactor rule api and a few fixes/typos
  • Loading branch information
Joe Williams committed Dec 6, 2022
2 parents 3dfe0ad + fb7401b commit fc18996
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 40 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ This is a collection of golang libraries and tools for managing nftables. It pro
* `pkg/expressions` includes nftables expression partials for generating common firewall rules.
* `pkg/xtables` library for bpf/ebpf nftables rule creation. It supports adding all three types of xtables bpf match configurations: bytecode, pinned bpf programs and socket file descriptors.
* `pkg/set` is a library for managing nftables sets, it supports IPv4, IPv6 and port based set types.
* `pkg/rule` is a library for managing nftable rules, it uses rule "user data" to provide unique IDs for each rule in a given chain.
* `pkg/logger` supports the stdlib log and [zerolog](https://github.com/rs/zerolog), or bring your own logger.
* `pkg/utils` utility functions for validating IPs and etc.
* `cmd/*` provides tools you can use to manage nftables built on top of the firewall_toolkit, also serves as an example of how to use the library.
Expand Down
3 changes: 2 additions & 1 deletion cmd/fwtk-input-filter-bpf/fwtk-input-filter-bpf.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ func main() {
logger.Default.Fatal(err)
}

ruleTarget := rule.NewRuleTarget(nfTable, nfChain)
bpfRule := rule.NewRuleData([]byte{0xd, 0xe, 0xa, 0xd}, expressions.MatchBpfWithVerdict(xtBpfInfoBytes, nfVerdict))
added, err := rule.Add(c, nfTable, nfChain, bpfRule)
added, err := ruleTarget.Add(c, bpfRule)
if err != nil {
logger.Default.Fatalf("adding rule %x failed: %v", bpfRule.ID, err)
}
Expand Down
19 changes: 10 additions & 9 deletions cmd/fwtk-input-filter-sets/fwtk-input-filter-sets.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,17 @@ func main() {
}

// create all the sets you plan to use
ipv4Set, err := set.New("ipv4_blocklist", c, nfTable, nftables.TypeIPAddr)
ipv4Set, err := set.New(c, nfTable, "ipv4_blocklist", nftables.TypeIPAddr)
if err != nil {
logger.Default.Fatalf("new set failed %v", err)
}

ipv6Set, err := set.New("ipv6_blocklist", c, nfTable, nftables.TypeIP6Addr)
ipv6Set, err := set.New(c, nfTable, "ipv6_blocklist", nftables.TypeIP6Addr)
if err != nil {
logger.Default.Fatalf("new set failed %v", err)
}

portSet, err := set.New("port_blocklist", c, nfTable, nftables.TypeInetService)
portSet, err := set.New(c, nfTable, "port_blocklist", nftables.TypeInetService)
if err != nil {
logger.Default.Fatalf("new set failed %v", err)
}
Expand Down Expand Up @@ -131,6 +131,8 @@ func main() {
logger.Default.Fatalf("add elements flush failed: %v", err)
}

ruleTarget := rule.NewRuleTarget(nfTable, nfChain)

ruleInfo := newRuleInfo(portSet, ipv4Set, ipv6Set)

ruleData, err := ruleInfo.createRuleData()
Expand All @@ -140,7 +142,7 @@ func main() {

flush := false
for _, rD := range ruleData {
added, err := rule.Add(c, nfTable, nfChain, rD)
added, err := ruleTarget.Add(c, rD)
if err != nil {
logger.Default.Fatalf("adding rule %x failed: %v", rD.ID, err)
}
Expand All @@ -167,7 +169,7 @@ func main() {
ipv4SetManager, err := set.ManagerInit(
&wg,
c,
&ipv4Set,
ipv4Set,
ipSource.getIPList,
RefreshInterval,
logger.Default,
Expand All @@ -180,7 +182,7 @@ func main() {
ipv6SetManager, err := set.ManagerInit(
&wg,
c,
&ipv6Set,
ipv6Set,
ipSource.getIPList,
RefreshInterval,
logger.Default,
Expand All @@ -193,7 +195,7 @@ func main() {
portSetManager, err := set.ManagerInit(
&wg,
c,
&portSet,
portSet,
portSource.getPortList,
RefreshInterval,
logger.Default,
Expand All @@ -206,8 +208,7 @@ func main() {
ruleManager, err := rule.ManagerInit(
&wg,
c,
nfTable,
nfChain,
ruleTarget,
ruleInfo.createRuleData,
RefreshInterval,
logger.Default,
Expand Down
28 changes: 21 additions & 7 deletions pkg/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,23 @@ import (
"github.com/google/nftables"
)

// RuleTarget represents a location to manipulate nftables rules
type RuleTarget struct {
Table *nftables.Table
Chain *nftables.Chain
}

// Create a new location to manipulate nftables rules
func NewRuleTarget(table *nftables.Table, chain *nftables.Chain) RuleTarget {
return RuleTarget{
Table: table,
Chain: chain,
}
}

// Add a rule with a given ID to a specific table and chain, returns true if the rule was added
func Add(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, ruleData RuleData) (bool, error) {
exists, err := Exists(c, table, chain, ruleData)
func (r *RuleTarget) Add(c *nftables.Conn, ruleData RuleData) (bool, error) {
exists, err := r.Exists(c, ruleData)
if err != nil {
return false, err
}
Expand All @@ -22,7 +36,7 @@ func Add(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, ruleDat
return false, nil
}

add(c, table, chain, ruleData)
add(c, r.Table, r.Chain, ruleData)
return true, nil
}

Expand All @@ -36,8 +50,8 @@ func add(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, ruleDat
}

// Delete a rule with a given ID from a specific table and chain, returns true if the rule was deleted
func Delete(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, ruleData RuleData) (bool, error) {
rules, err := c.GetRules(table, chain)
func (r *RuleTarget) Delete(c *nftables.Conn, ruleData RuleData) (bool, error) {
rules, err := c.GetRules(r.Table, r.Chain)
if err != nil {
return false, err
}
Expand All @@ -57,8 +71,8 @@ func Delete(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, rule
}

// Determine if a rule with a given ID exists in a specific table and chain
func Exists(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, ruleData RuleData) (bool, error) {
rules, err := c.GetRules(table, chain)
func (r *RuleTarget) Exists(c *nftables.Conn, ruleData RuleData) (bool, error) {
rules, err := c.GetRules(r.Table, r.Chain)
if err != nil {
return false, err
}
Expand Down
24 changes: 11 additions & 13 deletions pkg/rule/rule_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,18 @@ type RulesUpdateFunc func() ([]RuleData, error)
type ManagedRules struct {
WaitGroup *sync.WaitGroup
Conn *nftables.Conn
Table *nftables.Table
Chain *nftables.Chain
RuleTarget RuleTarget
rulesUpdateFunc RulesUpdateFunc
interval time.Duration
logger logger.Logger
}

// Create a rule manager
func ManagerInit(wg *sync.WaitGroup, c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, f RulesUpdateFunc, interval time.Duration, logger logger.Logger) (ManagedRules, error) {
func ManagerInit(wg *sync.WaitGroup, c *nftables.Conn, ruleTarget RuleTarget, f RulesUpdateFunc, interval time.Duration, logger logger.Logger) (ManagedRules, error) {
return ManagedRules{
WaitGroup: wg,
Conn: c,
Table: table,
Chain: chain,
RuleTarget: ruleTarget,
rulesUpdateFunc: f,
interval: interval,
logger: logger,
Expand All @@ -42,7 +40,7 @@ func ManagerInit(wg *sync.WaitGroup, c *nftables.Conn, table *nftables.Table, ch

// Start the rule manager goroutine
func (r *ManagedRules) Start() {
r.logger.Infof("starting rule manager for table/chain %v/%v", r.Table.Name, r.Chain.Name)
r.logger.Infof("starting rule manager for table/chain %v/%v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name)
defer r.WaitGroup.Done()

sigChan := make(chan os.Signal, 1)
Expand All @@ -62,19 +60,19 @@ func (r *ManagedRules) Start() {

ruleData, err := r.rulesUpdateFunc()
if err != nil {
r.logger.Errorf("error with rules update function for table/chain %v/%v: %v", r.Table.Name, r.Chain.Name, err)
r.logger.Errorf("error with rules update function for table/chain %v/%v: %v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name, err)
flush = false
}

for _, rD := range ruleData {
added, err := Add(r.Conn, r.Table, r.Chain, rD)
added, err := r.RuleTarget.Add(r.Conn, rD)
if err != nil {
r.logger.Errorf("error adding rule %x for table/chain %v/%v: %v", rD.ID, r.Table.Name, r.Chain.Name, err)
r.logger.Errorf("error adding rule %x for table/chain %v/%v: %v", rD.ID, r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name, err)
flush = false
}

if added {
r.logger.Infof("added rule %x for table/chain %v/%v", rD.ID, r.Table.Name, r.Chain.Name)
r.logger.Infof("added rule %x for table/chain %v/%v", rD.ID, r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name)
addCount++
}
}
Expand All @@ -86,15 +84,15 @@ func (r *ManagedRules) Start() {

// only flush if things went well above
if flush {
r.logger.Infof("flushing %v rules for table/chain %v/%v", addCount, r.Table.Name, r.Chain.Name)
r.logger.Infof("flushing %v rules for table/chain %v/%v", addCount, r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name)
if err := r.Conn.Flush(); err != nil {
r.logger.Errorf("error flushing rules for table/chain %v/%v: %v", r.Table.Name, r.Chain.Name, err)
r.logger.Errorf("error flushing rules for table/chain %v/%v: %v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name, err)
}
}
}
}
}()

<-sigChan
r.logger.Infof("got sigterm, stopping rule update loop for table/chain %v/%v", r.Table.Name, r.Chain.Name)
r.logger.Infof("got sigterm, stopping rule update loop for table/chain %v/%v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name)
}
2 changes: 1 addition & 1 deletion pkg/set/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type Set struct {
}

// Create a new set on a table with a given key type
func New(name string, c *nftables.Conn, table *nftables.Table, keyType nftables.SetDatatype) (Set, error) {
func New(c *nftables.Conn, table *nftables.Table, name string, keyType nftables.SetDatatype) (Set, error) {
// we've seen problems where sets need to be initialized with a value otherwise nftables seems to default to the
// native endianness, likely little endian, which is always incorrect for network stuff resulting in backwards ips, etc.
// we set everything to documentation values and then immediately delete them leaving empty, correctly created sets.
Expand Down
6 changes: 3 additions & 3 deletions pkg/set/set_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ type SetUpdateFunc func() ([]SetData, error)
type ManagedSet struct {
WaitGroup *sync.WaitGroup
Conn *nftables.Conn
Set *Set
Set Set
setUpdateFunc SetUpdateFunc
interval time.Duration
logger logger.Logger
}

// Create a set manager
func ManagerInit(wg *sync.WaitGroup, c *nftables.Conn, set *Set, f SetUpdateFunc, interval time.Duration, logger logger.Logger) (ManagedSet, error) {
func ManagerInit(wg *sync.WaitGroup, c *nftables.Conn, set Set, f SetUpdateFunc, interval time.Duration, logger logger.Logger) (ManagedSet, error) {
return ManagedSet{
WaitGroup: wg,
Conn: c,
Expand All @@ -40,7 +40,7 @@ func ManagerInit(wg *sync.WaitGroup, c *nftables.Conn, set *Set, f SetUpdateFunc

// Start the set manager goroutine
func (s *ManagedSet) Start() {
s.logger.Infof("starting set manager fortable/set %v/%v", s.Set.Set.Table.Name, s.Set.Set.Name)
s.logger.Infof("starting set manager for table/set %v/%v", s.Set.Set.Table.Name, s.Set.Set.Name)
defer s.WaitGroup.Done()

sigChan := make(chan os.Signal, 1)
Expand Down
36 changes: 30 additions & 6 deletions pkg/set/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,25 @@ import (
)

func TestNewSetBadType(t *testing.T) {
c := testDialWithWant(t, [][]byte{})
want := [][]byte{
// batch begin
{0x0, 0x0, 0x0, 0xa},
// add testtable
// "0x74, 0x65, 0x73, 0x74, 0x74, 0x61, 0x62, 0x6c, 0x65" == "testtable"
{0x1, 0x0, 0x0, 0x0, 0xe, 0x0, 0x1, 0x0, 0x74, 0x65, 0x73, 0x74, 0x74, 0x61, 0x62, 0x6c, 0x65, 0x0, 0x0, 0x0, 0x8, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0},
// batch end
{0x0, 0x0, 0x0, 0xa},
}
c := testDialWithWant(t, want)

table := c.AddTable(&nftables.Table{
Family: nftables.TableFamilyINet,
Name: "testtable",
})
res, err := New("testset", c, table, nftables.TypeARPHRD)
res, err := New(c, table, "testset", nftables.TypeARPHRD)
assert.Error(t, err)
assert.Equal(t, Set{}, res)
c.Flush()
}

func TestNewV4Set(t *testing.T) {
Expand Down Expand Up @@ -57,14 +67,15 @@ func TestNewV4Set(t *testing.T) {
Family: nftables.TableFamilyINet,
Name: "testtable",
})
res, err := New("testset", c, table, nftables.TypeIPAddr)
res, err := New(c, table, "testset", nftables.TypeIPAddr)
assert.Nil(t, err)

assert.True(t, res.Set.Counter)
assert.True(t, res.Set.Interval)
assert.Equal(t, "testset", res.Set.Name)
assert.Equal(t, "testtable", res.Set.Table.Name)
assert.Equal(t, nftables.TypeIPAddr, res.Set.KeyType)
c.Flush()
}

func TestNewV6Set(t *testing.T) {
Expand Down Expand Up @@ -97,14 +108,15 @@ func TestNewV6Set(t *testing.T) {
Family: nftables.TableFamilyINet,
Name: "testtable",
})
res, err := New("testset", c, table, nftables.TypeIP6Addr)
res, err := New(c, table, "testset", nftables.TypeIP6Addr)
assert.Nil(t, err)

assert.True(t, res.Set.Counter)
assert.True(t, res.Set.Interval)
assert.Equal(t, "testset", res.Set.Name)
assert.Equal(t, "testtable", res.Set.Table.Name)
assert.Equal(t, nftables.TypeIP6Addr, res.Set.KeyType)
c.Flush()
}

func TestNewPortSet(t *testing.T) {
Expand Down Expand Up @@ -137,14 +149,15 @@ func TestNewPortSet(t *testing.T) {
Family: nftables.TableFamilyINet,
Name: "testtable",
})
res, err := New("testset", c, table, nftables.TypeInetService)
res, err := New(c, table, "testset", nftables.TypeInetService)
assert.Nil(t, err)

assert.True(t, res.Set.Counter)
assert.True(t, res.Set.Interval)
assert.Equal(t, "testset", res.Set.Name)
assert.Equal(t, "testtable", res.Set.Table.Name)
assert.Equal(t, nftables.TypeInetService, res.Set.KeyType)
c.Flush()
}

func TestClearAndAddElements(t *testing.T) {
Expand Down Expand Up @@ -181,10 +194,20 @@ func TestClearAndAddElements(t *testing.T) {
assert.Nil(t, err)
err = set.ClearAndAddElements(c, []SetData{setData})
assert.Nil(t, err)
c.Flush()
}

func TestUpdateSetBadType(t *testing.T) {
c := testDialWithWant(t, [][]byte{})
want := [][]byte{
// batch begin
{0x0, 0x0, 0x0, 0xa},
// "0xe" == unix.NFT_MSG_DELSETELEM
{0x1, 0x0, 0x0, 0x0, 0xe, 0x0, 0x1, 0x0, 0x74, 0x65, 0x73, 0x74, 0x74, 0x61, 0x62, 0x6c, 0x65, 0x0, 0x0, 0x0, 0xc, 0x0, 0x2, 0x0, 0x74, 0x65, 0x73, 0x74, 0x73, 0x65, 0x74, 0x0},
// batch end
{0x0, 0x0, 0x0, 0xa},
}

c := testDialWithWant(t, want)

nfTable := &nftables.Table{
Family: nftables.TableFamilyINet,
Expand All @@ -204,6 +227,7 @@ func TestUpdateSetBadType(t *testing.T) {
assert.Nil(t, err)
err = set.ClearAndAddElements(c, []SetData{setData})
assert.Error(t, err)
c.Flush()
}

func testDialWithWant(t *testing.T, want [][]byte) *nftables.Conn {
Expand Down

0 comments on commit fc18996

Please sign in to comment.