Skip to content

Commit

Permalink
Merge pull request #114 from costasd/allow_supplying_command
Browse files Browse the repository at this point in the history
Feat: Support overriding binaries' path
  • Loading branch information
squeed committed Nov 2, 2023
2 parents 60f4899 + f61413f commit 50d824b
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 3 deletions.
26 changes: 23 additions & 3 deletions iptables/iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,20 @@ func Timeout(timeout int) option {
}
}

// New creates a new IPTables configured with the options passed as parameter.
// For backwards compatibility, by default always uses IPv4 and timeout 0.
func Path(path string) option {
return func(ipt *IPTables) {
ipt.path = path
}
}

// New creates a new IPTables configured with the options passed as parameters.
// Supported parameters are:
//
// IPFamily(Protocol)
// Timeout(int)
// Path(string)
//
// For backwards compatibility, by default New uses IPv4 and timeout 0.
// i.e. you can create an IPv6 IPTables using a timeout of 5 seconds passing
// the IPFamily and Timeout options as follow:
//
Expand All @@ -123,13 +135,21 @@ func New(opts ...option) (*IPTables, error) {
ipt := &IPTables{
proto: ProtocolIPv4,
timeout: 0,
path: "",
}

for _, opt := range opts {
opt(ipt)
}

path, err := exec.LookPath(getIptablesCommand(ipt.proto))
// if path wasn't preset through New(Path()), autodiscover it
cmd := ""
if ipt.path == "" {
cmd = getIptablesCommand(ipt.proto)
} else {
cmd = ipt.path
}
path, err := exec.LookPath(cmd)
if err != nil {
return nil, err
}
Expand Down
48 changes: 48 additions & 0 deletions iptables/iptables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,54 @@ func TestTimeout(t *testing.T) {

}

// force usage of -legacy or -nft commands and check that they're detected correctly
func TestLegacyDetection(t *testing.T) {
testCases := []struct {
in string
mode string
err bool
}{
{
"iptables-legacy",
"legacy",
false,
},
{
"ip6tables-legacy",
"legacy",
false,
},
{
"iptables-nft",
"nf_tables",
false,
},
{
"ip6tables-nft",
"nf_tables",
false,
},
}

for i, tt := range testCases {
t.Run(fmt.Sprint(i), func(t *testing.T) {
ipt, err := New(Path(tt.in))
if err == nil && tt.err {
t.Fatal("expected err, got none")
} else if err != nil && !tt.err {
t.Fatalf("unexpected err %s", err)
}

if !strings.Contains(ipt.path, tt.in) {
t.Fatalf("Expected path %s in %s", tt.in, ipt.path)
}
if ipt.mode != tt.mode {
t.Fatalf("Expected %s iptables, but got %s", tt.mode, ipt.mode)
}
})
}
}

func randChain(t *testing.T) string {
n, err := rand.Int(rand.Reader, big.NewInt(1000000))
if err != nil {
Expand Down

0 comments on commit 50d824b

Please sign in to comment.