Skip to content

Commit

Permalink
use error groups and supply managers with a context
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Williams committed Dec 8, 2022
1 parent 64cec29 commit bd3ffdd
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 77 deletions.
3 changes: 3 additions & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
linters-settings:
nestif:
min-complexity: 6
35 changes: 23 additions & 12 deletions cmd/fwtk-input-filter-sets/fwtk-input-filter-sets.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ package main

import (
"bufio"
"context"
"flag"
"os"
"sync"
"time"

"github.com/google/nftables"
"github.com/google/nftables/expr"
"golang.org/x/sync/errgroup"
"golang.org/x/sys/unix"

"github.com/ngrok/firewall_toolkit/pkg/expressions"
Expand Down Expand Up @@ -163,11 +164,11 @@ func main() {

// manager mode will keep running refreshing sets based on what's in the files
if *mode == "manager" {
var wg sync.WaitGroup
wg.Add(4)
ctx, cancel := context.WithCancel(context.Background())
eg, gctx := errgroup.WithContext(ctx)
defer cancel()

ipv4SetManager, err := set.ManagerInit(
&wg,
c,
ipv4Set,
ipSource.getIPList,
Expand All @@ -180,7 +181,6 @@ func main() {
}

ipv6SetManager, err := set.ManagerInit(
&wg,
c,
ipv6Set,
ipSource.getIPList,
Expand All @@ -193,7 +193,6 @@ func main() {
}

portSetManager, err := set.ManagerInit(
&wg,
c,
portSet,
portSource.getPortList,
Expand All @@ -206,7 +205,6 @@ func main() {
}

ruleManager, err := rule.ManagerInit(
&wg,
c,
ruleTarget,
ruleInfo.createRuleData,
Expand All @@ -218,12 +216,25 @@ func main() {
logger.Default.Fatal(err)
}

go ipv4SetManager.Start()
go ipv6SetManager.Start()
go portSetManager.Start()
go ruleManager.Start()
eg.Go(func() error {
return ipv4SetManager.Start(gctx)
})

wg.Wait()
eg.Go(func() error {
return ipv6SetManager.Start(gctx)
})

eg.Go(func() error {
return portSetManager.Start(gctx)
})

eg.Go(func() error {
return ruleManager.Start(gctx)
})

if err := eg.Wait(); err != nil {
logger.Default.Fatal(err)
}
}
}

Expand Down
58 changes: 26 additions & 32 deletions pkg/rule/rule_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
package rule

import (
"context"
"os"
"os/signal"
"sync"
"syscall"
"time"

Expand All @@ -18,7 +18,6 @@ type RulesUpdateFunc func() ([]RuleData, error)

// Represents a table/chain ruleset managed by the manager goroutine
type ManagedRules struct {
WaitGroup *sync.WaitGroup
Conn *nftables.Conn
RuleTarget RuleTarget
rulesUpdateFunc RulesUpdateFunc
Expand All @@ -27,9 +26,8 @@ type ManagedRules struct {
}

// Create a rule manager
func ManagerInit(wg *sync.WaitGroup, c *nftables.Conn, ruleTarget RuleTarget, f RulesUpdateFunc, interval time.Duration, logger logger.Logger) (ManagedRules, error) {
func ManagerInit(c *nftables.Conn, ruleTarget RuleTarget, f RulesUpdateFunc, interval time.Duration, logger logger.Logger) (ManagedRules, error) {
return ManagedRules{
WaitGroup: wg,
Conn: c,
RuleTarget: ruleTarget,
rulesUpdateFunc: f,
Expand All @@ -39,43 +37,39 @@ func ManagerInit(wg *sync.WaitGroup, c *nftables.Conn, ruleTarget RuleTarget, f
}

// Start the rule manager goroutine
func (r *ManagedRules) Start() {
func (r *ManagedRules) Start(ctx context.Context) error {
r.logger.Infof("starting rule manager for table/chain %v/%v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name)
defer r.WaitGroup.Done()

sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)

ticker := time.NewTicker(r.interval)
done := make(chan bool)

go func() {
for {
select {
case <-done:
return
case <-ticker.C:
ruleData, err := r.rulesUpdateFunc()
if err != nil {
r.logger.Errorf("error with rules update function for table/chain %v/%v: %v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name, err)
}
for {
select {
case <-ctx.Done():
r.logger.Infof("got context done, stopping rule update loop for table/chain %v/%v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name)
return nil
case sig := <-sigChan:
r.logger.Infof("got %s, stopping rule update loop for table/chain %v/%v", sig, r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name)
return nil
case <-ticker.C:
ruleData, err := r.rulesUpdateFunc()
if err != nil {
r.logger.Errorf("error with rules update function for table/chain %v/%v: %v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name, err)
}

flush, err := r.RuleTarget.Update(r.Conn, ruleData)
if err != nil {
r.logger.Errorf("error updating rules: %v", err)
}
flush, err := r.RuleTarget.Update(r.Conn, ruleData)
if err != nil {
r.logger.Errorf("error updating rules: %v", err)
}

// only flush if things went well above
if flush {
r.logger.Infof("flushing rules for table/chain %v/%v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name)
if err := r.Conn.Flush(); err != nil {
r.logger.Errorf("error flushing rules for table/chain %v/%v: %v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name, err)
}
// only flush if things went well above
if flush {
r.logger.Infof("flushing rules for table/chain %v/%v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name)
if err := r.Conn.Flush(); err != nil {
r.logger.Errorf("error flushing rules for table/chain %v/%v: %v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name, err)
}
}
}
}()

<-sigChan
r.logger.Infof("got sigterm, stopping rule update loop for table/chain %v/%v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name)
}
}
60 changes: 27 additions & 33 deletions pkg/set/set_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
package set

import (
"context"
"os"
"os/signal"
"sync"
"syscall"
"time"

Expand All @@ -18,7 +18,6 @@ type SetUpdateFunc func() ([]SetData, error)

// Represents a set managed by the manager goroutine
type ManagedSet struct {
WaitGroup *sync.WaitGroup
Conn *nftables.Conn
Set Set
setUpdateFunc SetUpdateFunc
Expand All @@ -27,9 +26,8 @@ type ManagedSet struct {
}

// Create a set manager
func ManagerInit(wg *sync.WaitGroup, c *nftables.Conn, set Set, f SetUpdateFunc, interval time.Duration, logger logger.Logger) (ManagedSet, error) {
func ManagerInit(c *nftables.Conn, set Set, f SetUpdateFunc, interval time.Duration, logger logger.Logger) (ManagedSet, error) {
return ManagedSet{
WaitGroup: wg,
Conn: c,
Set: set,
setUpdateFunc: f,
Expand All @@ -39,44 +37,40 @@ func ManagerInit(wg *sync.WaitGroup, c *nftables.Conn, set Set, f SetUpdateFunc,
}

// Start the set manager goroutine
func (s *ManagedSet) Start() {
func (s *ManagedSet) Start(ctx context.Context) error {
s.logger.Infof("starting set manager for table/set %v/%v", s.Set.Set.Table.Name, s.Set.Set.Name)
defer s.WaitGroup.Done()

sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)

ticker := time.NewTicker(s.interval)
done := make(chan bool)

go func() {
for {
select {
case <-done:
return
case <-ticker.C:
data, err := s.setUpdateFunc()
if err != nil {
s.logger.Errorf("error with set update function for table/set %v/%v: %v", s.Set.Set.Table.Name, s.Set.Set.Name, err)
continue
}
for {
select {
case <-ctx.Done():
s.logger.Infof("got context done, stopping set update loop for table/set %v/%v", s.Set.Set.Table.Name, s.Set.Set.Name)
return nil
case sig := <-sigChan:
s.logger.Infof("got %s, stopping set update loop for table/set %v/%v", sig, s.Set.Set.Table.Name, s.Set.Set.Name)
return nil
case <-ticker.C:
data, err := s.setUpdateFunc()
if err != nil {
s.logger.Errorf("error with set update function for table/set %v/%v: %v", s.Set.Set.Table.Name, s.Set.Set.Name, err)
continue
}

flush, err := s.Set.UpdateElements(s.Conn, data)
if err != nil {
s.logger.Errorf("error updating table/set %v/%v: %v", s.Set.Set.Table.Name, s.Set.Set.Name, err)
continue
}
flush, err := s.Set.UpdateElements(s.Conn, data)
if err != nil {
s.logger.Errorf("error updating table/set %v/%v: %v", s.Set.Set.Table.Name, s.Set.Set.Name, err)
continue
}

// only flush if things went well above
if flush {
if err := s.Conn.Flush(); err != nil {
s.logger.Errorf("error flushing table/set %v/%v: %v", s.Set.Set.Table.Name, s.Set.Set.Name, err)
}
// only flush if things went well above
if flush {
if err := s.Conn.Flush(); err != nil {
s.logger.Errorf("error flushing table/set %v/%v: %v", s.Set.Set.Table.Name, s.Set.Set.Name, err)
}
}
}
}()

<-sigChan
s.logger.Infof("got sigterm, stopping set update loop for table/set %v/%v", s.Set.Set.Table.Name, s.Set.Set.Name)
}
}

0 comments on commit bd3ffdd

Please sign in to comment.