Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use error groups and supply managers with a context #5

Merged
merged 1 commit into from
Dec 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
linters-settings:
nestif:
min-complexity: 6
35 changes: 23 additions & 12 deletions cmd/fwtk-input-filter-sets/fwtk-input-filter-sets.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ package main

import (
"bufio"
"context"
"flag"
"os"
"sync"
"time"

"github.com/google/nftables"
"github.com/google/nftables/expr"
"golang.org/x/sync/errgroup"
"golang.org/x/sys/unix"

"github.com/ngrok/firewall_toolkit/pkg/expressions"
Expand Down Expand Up @@ -163,11 +164,11 @@ func main() {

// manager mode will keep running refreshing sets based on what's in the files
if *mode == "manager" {
var wg sync.WaitGroup
wg.Add(4)
ctx, cancel := context.WithCancel(context.Background())
eg, gctx := errgroup.WithContext(ctx)
defer cancel()

ipv4SetManager, err := set.ManagerInit(
&wg,
c,
ipv4Set,
ipSource.getIPList,
Expand All @@ -180,7 +181,6 @@ func main() {
}

ipv6SetManager, err := set.ManagerInit(
&wg,
c,
ipv6Set,
ipSource.getIPList,
Expand All @@ -193,7 +193,6 @@ func main() {
}

portSetManager, err := set.ManagerInit(
&wg,
c,
portSet,
portSource.getPortList,
Expand All @@ -206,7 +205,6 @@ func main() {
}

ruleManager, err := rule.ManagerInit(
&wg,
c,
ruleTarget,
ruleInfo.createRuleData,
Expand All @@ -218,12 +216,25 @@ func main() {
logger.Default.Fatal(err)
}

go ipv4SetManager.Start()
go ipv6SetManager.Start()
go portSetManager.Start()
go ruleManager.Start()
eg.Go(func() error {
return ipv4SetManager.Start(gctx)
})

wg.Wait()
eg.Go(func() error {
return ipv6SetManager.Start(gctx)
})

eg.Go(func() error {
return portSetManager.Start(gctx)
})

eg.Go(func() error {
return ruleManager.Start(gctx)
})

if err := eg.Wait(); err != nil {
logger.Default.Fatal(err)
}
}
}

Expand Down
58 changes: 26 additions & 32 deletions pkg/rule/rule_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
package rule

import (
"context"
"os"
"os/signal"
"sync"
"syscall"
"time"

Expand All @@ -18,7 +18,6 @@ type RulesUpdateFunc func() ([]RuleData, error)

// Represents a table/chain ruleset managed by the manager goroutine
type ManagedRules struct {
WaitGroup *sync.WaitGroup
Conn *nftables.Conn
RuleTarget RuleTarget
rulesUpdateFunc RulesUpdateFunc
Expand All @@ -27,9 +26,8 @@ type ManagedRules struct {
}

// Create a rule manager
func ManagerInit(wg *sync.WaitGroup, c *nftables.Conn, ruleTarget RuleTarget, f RulesUpdateFunc, interval time.Duration, logger logger.Logger) (ManagedRules, error) {
func ManagerInit(c *nftables.Conn, ruleTarget RuleTarget, f RulesUpdateFunc, interval time.Duration, logger logger.Logger) (ManagedRules, error) {
return ManagedRules{
WaitGroup: wg,
Conn: c,
RuleTarget: ruleTarget,
rulesUpdateFunc: f,
Expand All @@ -39,43 +37,39 @@ func ManagerInit(wg *sync.WaitGroup, c *nftables.Conn, ruleTarget RuleTarget, f
}

// Start the rule manager goroutine
func (r *ManagedRules) Start() {
func (r *ManagedRules) Start(ctx context.Context) error {
joewilliams marked this conversation as resolved.
Show resolved Hide resolved
r.logger.Infof("starting rule manager for table/chain %v/%v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name)
defer r.WaitGroup.Done()

sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)

ticker := time.NewTicker(r.interval)
done := make(chan bool)

go func() {
for {
select {
case <-done:
return
case <-ticker.C:
ruleData, err := r.rulesUpdateFunc()
if err != nil {
r.logger.Errorf("error with rules update function for table/chain %v/%v: %v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name, err)
}
for {
select {
case <-ctx.Done():
r.logger.Infof("got context done, stopping rule update loop for table/chain %v/%v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name)
return nil
case sig := <-sigChan:
r.logger.Infof("got %s, stopping rule update loop for table/chain %v/%v", sig, r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name)
return nil
case <-ticker.C:
ruleData, err := r.rulesUpdateFunc()
if err != nil {
r.logger.Errorf("error with rules update function for table/chain %v/%v: %v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name, err)
}

flush, err := r.RuleTarget.Update(r.Conn, ruleData)
if err != nil {
r.logger.Errorf("error updating rules: %v", err)
}
flush, err := r.RuleTarget.Update(r.Conn, ruleData)
if err != nil {
r.logger.Errorf("error updating rules: %v", err)
}

// only flush if things went well above
if flush {
r.logger.Infof("flushing rules for table/chain %v/%v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name)
if err := r.Conn.Flush(); err != nil {
r.logger.Errorf("error flushing rules for table/chain %v/%v: %v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name, err)
}
// only flush if things went well above
if flush {
r.logger.Infof("flushing rules for table/chain %v/%v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name)
if err := r.Conn.Flush(); err != nil {
r.logger.Errorf("error flushing rules for table/chain %v/%v: %v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name, err)
}
}
}
}()

<-sigChan
r.logger.Infof("got sigterm, stopping rule update loop for table/chain %v/%v", r.RuleTarget.Table.Name, r.RuleTarget.Chain.Name)
}
}
60 changes: 27 additions & 33 deletions pkg/set/set_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
package set

import (
"context"
"os"
"os/signal"
"sync"
"syscall"
"time"

Expand All @@ -18,7 +18,6 @@ type SetUpdateFunc func() ([]SetData, error)

// Represents a set managed by the manager goroutine
type ManagedSet struct {
WaitGroup *sync.WaitGroup
Conn *nftables.Conn
Set Set
setUpdateFunc SetUpdateFunc
Expand All @@ -27,9 +26,8 @@ type ManagedSet struct {
}

// Create a set manager
func ManagerInit(wg *sync.WaitGroup, c *nftables.Conn, set Set, f SetUpdateFunc, interval time.Duration, logger logger.Logger) (ManagedSet, error) {
func ManagerInit(c *nftables.Conn, set Set, f SetUpdateFunc, interval time.Duration, logger logger.Logger) (ManagedSet, error) {
return ManagedSet{
WaitGroup: wg,
Conn: c,
Set: set,
setUpdateFunc: f,
Expand All @@ -39,44 +37,40 @@ func ManagerInit(wg *sync.WaitGroup, c *nftables.Conn, set Set, f SetUpdateFunc,
}

// Start the set manager goroutine
func (s *ManagedSet) Start() {
func (s *ManagedSet) Start(ctx context.Context) error {
s.logger.Infof("starting set manager for table/set %v/%v", s.Set.Set.Table.Name, s.Set.Set.Name)
defer s.WaitGroup.Done()

sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)

ticker := time.NewTicker(s.interval)
done := make(chan bool)

go func() {
for {
select {
case <-done:
return
case <-ticker.C:
data, err := s.setUpdateFunc()
if err != nil {
s.logger.Errorf("error with set update function for table/set %v/%v: %v", s.Set.Set.Table.Name, s.Set.Set.Name, err)
continue
}
for {
select {
case <-ctx.Done():
s.logger.Infof("got context done, stopping set update loop for table/set %v/%v", s.Set.Set.Table.Name, s.Set.Set.Name)
return nil
case sig := <-sigChan:
s.logger.Infof("got %s, stopping set update loop for table/set %v/%v", sig, s.Set.Set.Table.Name, s.Set.Set.Name)
return nil
case <-ticker.C:
data, err := s.setUpdateFunc()
if err != nil {
s.logger.Errorf("error with set update function for table/set %v/%v: %v", s.Set.Set.Table.Name, s.Set.Set.Name, err)
continue
}

flush, err := s.Set.UpdateElements(s.Conn, data)
if err != nil {
s.logger.Errorf("error updating table/set %v/%v: %v", s.Set.Set.Table.Name, s.Set.Set.Name, err)
continue
}
flush, err := s.Set.UpdateElements(s.Conn, data)
if err != nil {
s.logger.Errorf("error updating table/set %v/%v: %v", s.Set.Set.Table.Name, s.Set.Set.Name, err)
continue
}

// only flush if things went well above
if flush {
if err := s.Conn.Flush(); err != nil {
s.logger.Errorf("error flushing table/set %v/%v: %v", s.Set.Set.Table.Name, s.Set.Set.Name, err)
}
// only flush if things went well above
if flush {
if err := s.Conn.Flush(); err != nil {
s.logger.Errorf("error flushing table/set %v/%v: %v", s.Set.Set.Table.Name, s.Set.Set.Name, err)
}
}
}
}()

<-sigChan
s.logger.Infof("got sigterm, stopping set update loop for table/set %v/%v", s.Set.Set.Table.Name, s.Set.Set.Name)
}
}