From bd3ffdd1c633b8813b7a4c67415d2e85cb9cd55b Mon Sep 17 00:00:00 2001 From: Joe Williams Date: Thu, 8 Dec 2022 15:23:47 -0800 Subject: [PATCH] use error groups and supply managers with a context --- .golangci.yaml | 3 + .../fwtk-input-filter-sets.go | 35 +++++++---- pkg/rule/rule_manager.go | 58 ++++++++---------- pkg/set/set_manager.go | 60 +++++++++---------- 4 files changed, 79 insertions(+), 77 deletions(-) create mode 100644 .golangci.yaml diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..e11e5e2 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,3 @@ +linters-settings: + nestif: + min-complexity: 6 \ No newline at end of file diff --git a/cmd/fwtk-input-filter-sets/fwtk-input-filter-sets.go b/cmd/fwtk-input-filter-sets/fwtk-input-filter-sets.go index f5d866d..4eb9afe 100644 --- a/cmd/fwtk-input-filter-sets/fwtk-input-filter-sets.go +++ b/cmd/fwtk-input-filter-sets/fwtk-input-filter-sets.go @@ -23,13 +23,14 @@ package main import ( "bufio" + "context" "flag" "os" - "sync" "time" "github.com/google/nftables" "github.com/google/nftables/expr" + "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" "github.com/ngrok/firewall_toolkit/pkg/expressions" @@ -163,11 +164,11 @@ func main() { // manager mode will keep running refreshing sets based on what's in the files if *mode == "manager" { - var wg sync.WaitGroup - wg.Add(4) + ctx, cancel := context.WithCancel(context.Background()) + eg, gctx := errgroup.WithContext(ctx) + defer cancel() ipv4SetManager, err := set.ManagerInit( - &wg, c, ipv4Set, ipSource.getIPList, @@ -180,7 +181,6 @@ func main() { } ipv6SetManager, err := set.ManagerInit( - &wg, c, ipv6Set, ipSource.getIPList, @@ -193,7 +193,6 @@ func main() { } portSetManager, err := set.ManagerInit( - &wg, c, portSet, portSource.getPortList, @@ -206,7 +205,6 @@ func main() { } ruleManager, err := rule.ManagerInit( - &wg, c, ruleTarget, ruleInfo.createRuleData, @@ -218,12 +216,25 @@ func main() { logger.Default.Fatal(err) } - go ipv4SetManager.Start() - go ipv6SetManager.Start() - go portSetManager.Start() - go ruleManager.Start() + eg.Go(func() error { + return ipv4SetManager.Start(gctx) + }) - wg.Wait() + eg.Go(func() error { + return ipv6SetManager.Start(gctx) + }) + + eg.Go(func() error { + return portSetManager.Start(gctx) + }) + + eg.Go(func() error { + return ruleManager.Start(gctx) + }) + + if err := eg.Wait(); err != nil { + logger.Default.Fatal(err) + } } } diff --git a/pkg/rule/rule_manager.go b/pkg/rule/rule_manager.go index 0b390ab..dede01f 100644 --- a/pkg/rule/rule_manager.go +++ b/pkg/rule/rule_manager.go @@ -3,9 +3,9 @@ package rule import ( + "context" "os" "os/signal" - "sync" "syscall" "time" @@ -18,7 +18,6 @@ type RulesUpdateFunc func() ([]RuleData, error) // Represents a table/chain ruleset managed by the manager goroutine type ManagedRules struct { - WaitGroup *sync.WaitGroup Conn *nftables.Conn RuleTarget RuleTarget rulesUpdateFunc RulesUpdateFunc @@ -27,9 +26,8 @@ type ManagedRules struct { } // Create a rule manager -func ManagerInit(wg *sync.WaitGroup, c *nftables.Conn, ruleTarget RuleTarget, f RulesUpdateFunc, interval time.Duration, logger logger.Logger) (ManagedRules, error) { +func ManagerInit(c *nftables.Conn, ruleTarget RuleTarget, f RulesUpdateFunc, interval time.Duration, logger logger.Logger) (ManagedRules, error) { return ManagedRules{ - WaitGroup: wg, Conn: c, RuleTarget: ruleTarget, rulesUpdateFunc: f, @@ -39,43 +37,39 @@ func ManagerInit(wg *sync.WaitGroup, c *nftables.Conn, ruleTarget RuleTarget, f } // Start the rule manager goroutine -func (r *ManagedRules) Start() { +func (r *ManagedRules) Start(ctx context.Context) error { r.logger.Infof("starting rule manager for table/chain %v/%v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name) - defer r.WaitGroup.Done() sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) - ticker := time.NewTicker(r.interval) - done := make(chan bool) - go func() { - for { - select { - case <-done: - return - case <-ticker.C: - ruleData, err := r.rulesUpdateFunc() - if err != nil { - r.logger.Errorf("error with rules update function for table/chain %v/%v: %v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name, err) - } + for { + select { + case <-ctx.Done(): + r.logger.Infof("got context done, stopping rule update loop for table/chain %v/%v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name) + return nil + case sig := <-sigChan: + r.logger.Infof("got %s, stopping rule update loop for table/chain %v/%v", sig, r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name) + return nil + case <-ticker.C: + ruleData, err := r.rulesUpdateFunc() + if err != nil { + r.logger.Errorf("error with rules update function for table/chain %v/%v: %v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name, err) + } - flush, err := r.RuleTarget.Update(r.Conn, ruleData) - if err != nil { - r.logger.Errorf("error updating rules: %v", err) - } + flush, err := r.RuleTarget.Update(r.Conn, ruleData) + if err != nil { + r.logger.Errorf("error updating rules: %v", err) + } - // only flush if things went well above - if flush { - r.logger.Infof("flushing rules for table/chain %v/%v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name) - if err := r.Conn.Flush(); err != nil { - r.logger.Errorf("error flushing rules for table/chain %v/%v: %v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name, err) - } + // only flush if things went well above + if flush { + r.logger.Infof("flushing rules for table/chain %v/%v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name) + if err := r.Conn.Flush(); err != nil { + r.logger.Errorf("error flushing rules for table/chain %v/%v: %v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name, err) } } } - }() - - <-sigChan - r.logger.Infof("got sigterm, stopping rule update loop for table/chain %v/%v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name) + } } diff --git a/pkg/set/set_manager.go b/pkg/set/set_manager.go index d076c66..dd6939b 100644 --- a/pkg/set/set_manager.go +++ b/pkg/set/set_manager.go @@ -3,9 +3,9 @@ package set import ( + "context" "os" "os/signal" - "sync" "syscall" "time" @@ -18,7 +18,6 @@ type SetUpdateFunc func() ([]SetData, error) // Represents a set managed by the manager goroutine type ManagedSet struct { - WaitGroup *sync.WaitGroup Conn *nftables.Conn Set Set setUpdateFunc SetUpdateFunc @@ -27,9 +26,8 @@ type ManagedSet struct { } // Create a set manager -func ManagerInit(wg *sync.WaitGroup, c *nftables.Conn, set Set, f SetUpdateFunc, interval time.Duration, logger logger.Logger) (ManagedSet, error) { +func ManagerInit(c *nftables.Conn, set Set, f SetUpdateFunc, interval time.Duration, logger logger.Logger) (ManagedSet, error) { return ManagedSet{ - WaitGroup: wg, Conn: c, Set: set, setUpdateFunc: f, @@ -39,44 +37,40 @@ func ManagerInit(wg *sync.WaitGroup, c *nftables.Conn, set Set, f SetUpdateFunc, } // Start the set manager goroutine -func (s *ManagedSet) Start() { +func (s *ManagedSet) Start(ctx context.Context) error { s.logger.Infof("starting set manager for table/set %v/%v", s.Set.Set.Table.Name, s.Set.Set.Name) - defer s.WaitGroup.Done() sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) - ticker := time.NewTicker(s.interval) - done := make(chan bool) - go func() { - for { - select { - case <-done: - return - case <-ticker.C: - data, err := s.setUpdateFunc() - if err != nil { - s.logger.Errorf("error with set update function for table/set %v/%v: %v", s.Set.Set.Table.Name, s.Set.Set.Name, err) - continue - } + for { + select { + case <-ctx.Done(): + s.logger.Infof("got context done, stopping set update loop for table/set %v/%v", s.Set.Set.Table.Name, s.Set.Set.Name) + return nil + case sig := <-sigChan: + s.logger.Infof("got %s, stopping set update loop for table/set %v/%v", sig, s.Set.Set.Table.Name, s.Set.Set.Name) + return nil + case <-ticker.C: + data, err := s.setUpdateFunc() + if err != nil { + s.logger.Errorf("error with set update function for table/set %v/%v: %v", s.Set.Set.Table.Name, s.Set.Set.Name, err) + continue + } - flush, err := s.Set.UpdateElements(s.Conn, data) - if err != nil { - s.logger.Errorf("error updating table/set %v/%v: %v", s.Set.Set.Table.Name, s.Set.Set.Name, err) - continue - } + flush, err := s.Set.UpdateElements(s.Conn, data) + if err != nil { + s.logger.Errorf("error updating table/set %v/%v: %v", s.Set.Set.Table.Name, s.Set.Set.Name, err) + continue + } - // only flush if things went well above - if flush { - if err := s.Conn.Flush(); err != nil { - s.logger.Errorf("error flushing table/set %v/%v: %v", s.Set.Set.Table.Name, s.Set.Set.Name, err) - } + // only flush if things went well above + if flush { + if err := s.Conn.Flush(); err != nil { + s.logger.Errorf("error flushing table/set %v/%v: %v", s.Set.Set.Table.Name, s.Set.Set.Name, err) } } } - }() - - <-sigChan - s.logger.Infof("got sigterm, stopping set update loop for table/set %v/%v", s.Set.Set.Table.Name, s.Set.Set.Name) + } }