Skip to content

Commit

Permalink
refactor(firewall): move duplicated rule parsing code into own functi…
Browse files Browse the repository at this point in the history
…ons (#796)
  • Loading branch information
phm07 committed Jun 27, 2024
1 parent 186e182 commit 12354bb
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 196 deletions.
125 changes: 73 additions & 52 deletions internal/cmd/firewall/add_rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net"

"github.com/spf13/cobra"
"github.com/spf13/pflag"

"github.com/hetznercloud/cli/internal/cmd/base"
"github.com/hetznercloud/cli/internal/cmd/cmpl"
Expand Down Expand Up @@ -40,13 +41,6 @@ var AddRuleCmd = base.Cmd{
return cmd
},
Run: func(s state.State, cmd *cobra.Command, args []string) error {
direction, _ := cmd.Flags().GetString("direction")
protocol, _ := cmd.Flags().GetString("protocol")
sourceIPs, _ := cmd.Flags().GetStringArray("source-ips")
destinationIPs, _ := cmd.Flags().GetStringArray("destination-ips")
port, _ := cmd.Flags().GetString("port")
description, _ := cmd.Flags().GetString("description")

idOrName := args[0]
firewall, _, err := s.Client().Firewall().Get(s, idOrName)
if err != nil {
Expand All @@ -56,53 +50,12 @@ var AddRuleCmd = base.Cmd{
return fmt.Errorf("Firewall not found: %v", idOrName)
}

d := hcloud.FirewallRuleDirection(direction)
rule := hcloud.FirewallRule{
Direction: d,
Protocol: hcloud.FirewallRuleProtocol(protocol),
}

if port != "" {
rule.Port = hcloud.String(port)
}

if description != "" {
rule.Description = hcloud.String(description)
}

switch rule.Protocol {
case hcloud.FirewallRuleProtocolUDP, hcloud.FirewallRuleProtocolTCP:
if port == "" {
return fmt.Errorf("port is required (--port)")
}
default:
if port != "" {
return fmt.Errorf("port is not allowed for this protocol")
}
}

switch d {
case hcloud.FirewallRuleDirectionOut:
rule.DestinationIPs = make([]net.IPNet, len(destinationIPs))
for i, ip := range destinationIPs {
n, err := ValidateFirewallIP(ip)
if err != nil {
return fmt.Errorf("destination error on index %d: %s", i, err)
}
rule.DestinationIPs[i] = *n
}
case hcloud.FirewallRuleDirectionIn:
rule.SourceIPs = make([]net.IPNet, len(sourceIPs))
for i, ip := range sourceIPs {
n, err := ValidateFirewallIP(ip)
if err != nil {
return fmt.Errorf("source ips error on index %d: %s", i, err)
}
rule.SourceIPs[i] = *n
}
rule, err := parseRuleFromArgs(cmd.Flags())
if err != nil {
return err
}

rules := append(firewall.Rules, rule)
rules := append(firewall.Rules, *rule)

actions, _, err := s.Client().Firewall().SetRules(s, firewall,
hcloud.FirewallSetRulesOpts{Rules: rules},
Expand All @@ -119,3 +72,71 @@ var AddRuleCmd = base.Cmd{
return nil
},
}

func parseRuleFromArgs(flags *pflag.FlagSet) (*hcloud.FirewallRule, error) {
direction, _ := flags.GetString("direction")
protocol, _ := flags.GetString("protocol")
sourceIPs, _ := flags.GetStringArray("source-ips")
destinationIPs, _ := flags.GetStringArray("destination-ips")
port, _ := flags.GetString("port")
description, _ := flags.GetString("description")

rule := &hcloud.FirewallRule{
SourceIPs: make([]net.IPNet, 0),
DestinationIPs: make([]net.IPNet, 0),
}

switch hcloud.FirewallRuleDirection(direction) {
case hcloud.FirewallRuleDirectionIn, hcloud.FirewallRuleDirectionOut:
rule.Direction = hcloud.FirewallRuleDirection(direction)
default:
return nil, fmt.Errorf("invalid direction: %s", direction)
}

switch hcloud.FirewallRuleProtocol(protocol) {
case hcloud.FirewallRuleProtocolTCP, hcloud.FirewallRuleProtocolUDP, hcloud.FirewallRuleProtocolICMP, hcloud.FirewallRuleProtocolESP, hcloud.FirewallRuleProtocolGRE:
rule.Protocol = hcloud.FirewallRuleProtocol(protocol)
default:
return nil, fmt.Errorf("invalid protocol: %s", protocol)
}

if port != "" {
rule.Port = hcloud.Ptr(port)
}

if description != "" {
rule.Description = hcloud.Ptr(description)
}

switch rule.Protocol {
case hcloud.FirewallRuleProtocolUDP, hcloud.FirewallRuleProtocolTCP:
if port == "" {
return nil, fmt.Errorf("port is required (--port)")
}
default:
if port != "" {
return nil, fmt.Errorf("port is not allowed for this protocol")
}
}

switch rule.Direction {
case hcloud.FirewallRuleDirectionOut:
for i, ip := range destinationIPs {
n, err := ValidateFirewallIP(ip)
if err != nil {
return nil, fmt.Errorf("destination error on index %d: %s", i, err)
}
rule.DestinationIPs = append(rule.DestinationIPs, *n)
}
case hcloud.FirewallRuleDirectionIn:
for i, ip := range sourceIPs {
n, err := ValidateFirewallIP(ip)
if err != nil {
return nil, fmt.Errorf("source ips error on index %d: %s", i, err)
}
rule.SourceIPs = append(rule.SourceIPs, *n)
}
}

return rule, nil
}
2 changes: 1 addition & 1 deletion internal/cmd/firewall/add_rule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestAddRule(t *testing.T) {
Rules: []hcloud.FirewallRule{{
Direction: hcloud.FirewallRuleDirectionIn,
SourceIPs: []net.IPNet{{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}},
DestinationIPs: nil,
DestinationIPs: []net.IPNet{},
Protocol: hcloud.FirewallRuleProtocolTCP,
Port: hcloud.Ptr("80"),
Description: hcloud.Ptr("http"),
Expand Down
95 changes: 54 additions & 41 deletions internal/cmd/firewall/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package firewall
import (
"encoding/json"
"fmt"
"io/ioutil"
"io"
"net"
"os"

Expand Down Expand Up @@ -41,49 +41,12 @@ var CreateCmd = base.CreateCmd{
}

rulesFile, _ := cmd.Flags().GetString("rules-file")

if len(rulesFile) > 0 {
var data []byte
var err error
if rulesFile == "-" {
data, err = ioutil.ReadAll(os.Stdin)
} else {
data, err = ioutil.ReadFile(rulesFile)
}
if err != nil {
return nil, nil, err
}
var rules []schema.FirewallRule
err = json.Unmarshal(data, &rules)
if rulesFile != "" {
rules, err := parseRulesFile(rulesFile)
if err != nil {
return nil, nil, err
}
for _, rule := range rules {
var sourceNets []net.IPNet
for i, sourceIP := range rule.SourceIPs {
_, sourceNet, err := net.ParseCIDR(sourceIP)
if err != nil {
return nil, nil, fmt.Errorf("invalid CIDR on index %d : %s", i, err)
}
sourceNets = append(sourceNets, *sourceNet)
}
var destNets []net.IPNet
for i, destIP := range rule.DestinationIPs {
_, destNet, err := net.ParseCIDR(destIP)
if err != nil {
return nil, nil, fmt.Errorf("invalid CIDR on index %d : %s", i, err)
}
destNets = append(destNets, *destNet)
}
opts.Rules = append(opts.Rules, hcloud.FirewallRule{
Direction: hcloud.FirewallRuleDirection(rule.Direction),
SourceIPs: sourceNets,
DestinationIPs: destNets,
Protocol: hcloud.FirewallRuleProtocol(rule.Protocol),
Port: rule.Port,
Description: rule.Description,
})
}
opts.Rules = rules
}

result, _, err := s.Client().Firewall().Create(s, opts)
Expand All @@ -100,3 +63,53 @@ var CreateCmd = base.CreateCmd{
return result.Firewall, util.Wrap("firewall", hcloud.SchemaFromFirewall(result.Firewall)), err
},
}

func parseRulesFile(path string) ([]hcloud.FirewallRule, error) {
var (
data []byte
err error
)
if path == "-" {
data, err = io.ReadAll(os.Stdin)
} else {
data, err = os.ReadFile(path)
}
if err != nil {
return nil, err
}

var ruleSchemas []schema.FirewallRule
err = json.Unmarshal(data, &ruleSchemas)
if err != nil {
return nil, err
}

rules := make([]hcloud.FirewallRule, 0, len(ruleSchemas))
for _, rule := range ruleSchemas {
var sourceNets []net.IPNet
for i, sourceIP := range rule.SourceIPs {
_, sourceNet, err := net.ParseCIDR(sourceIP)
if err != nil {
return nil, fmt.Errorf("invalid CIDR on index %d : %s", i, err)
}
sourceNets = append(sourceNets, *sourceNet)
}
var destNets []net.IPNet
for i, destIP := range rule.DestinationIPs {
_, destNet, err := net.ParseCIDR(destIP)
if err != nil {
return nil, fmt.Errorf("invalid CIDR on index %d : %s", i, err)
}
destNets = append(destNets, *destNet)
}
rules = append(rules, hcloud.FirewallRule{
Direction: hcloud.FirewallRuleDirection(rule.Direction),
SourceIPs: sourceNets,
DestinationIPs: destNets,
Protocol: hcloud.FirewallRuleProtocol(rule.Protocol),
Port: rule.Port,
Description: rule.Description,
})
}
return rules, nil
}
57 changes: 4 additions & 53 deletions internal/cmd/firewall/delete_rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package firewall

import (
"fmt"
"net"
"reflect"

"github.com/spf13/cobra"
Expand Down Expand Up @@ -41,13 +40,6 @@ var DeleteRuleCmd = base.Cmd{
return cmd
},
Run: func(s state.State, cmd *cobra.Command, args []string) error {
direction, _ := cmd.Flags().GetString("direction")
protocol, _ := cmd.Flags().GetString("protocol")
sourceIPs, _ := cmd.Flags().GetStringArray("source-ips")
destinationIPs, _ := cmd.Flags().GetStringArray("destination-ips")
port, _ := cmd.Flags().GetString("port")
description, _ := cmd.Flags().GetString("description")

idOrName := args[0]
firewall, _, err := s.Client().Firewall().Get(s, idOrName)
if err != nil {
Expand All @@ -57,55 +49,14 @@ var DeleteRuleCmd = base.Cmd{
return fmt.Errorf("Firewall not found: %v", idOrName)
}

d := hcloud.FirewallRuleDirection(direction)
rule := hcloud.FirewallRule{
Direction: d,
Protocol: hcloud.FirewallRuleProtocol(protocol),
}
if port != "" {
rule.Port = hcloud.String(port)
}
if description != "" {
rule.Description = hcloud.String(description)
}

switch rule.Protocol {
case hcloud.FirewallRuleProtocolTCP, hcloud.FirewallRuleProtocolUDP:
if port == "" {
return fmt.Errorf("port is required (--port)")
}
default:
if port != "" {
return fmt.Errorf("port is not allowed for this protocol")
}
}

switch d {
case hcloud.FirewallRuleDirectionOut:
rule.DestinationIPs = make([]net.IPNet, len(destinationIPs))
for i, ip := range destinationIPs {
n, err := ValidateFirewallIP(ip)
if err != nil {
return fmt.Errorf("destination ips error on index %d: %s", i, err)
}
rule.DestinationIPs[i] = *n
rule.SourceIPs = make([]net.IPNet, 0)
}
case hcloud.FirewallRuleDirectionIn:
rule.SourceIPs = make([]net.IPNet, len(sourceIPs))
for i, ip := range sourceIPs {
n, err := ValidateFirewallIP(ip)
if err != nil {
return fmt.Errorf("source ips error on index %d: %s", i, err)
}
rule.DestinationIPs = make([]net.IPNet, 0)
rule.SourceIPs[i] = *n
}
rule, err := parseRuleFromArgs(cmd.Flags())
if err != nil {
return err
}

var rules = make([]hcloud.FirewallRule, 0)
for _, existingRule := range firewall.Rules {
if !reflect.DeepEqual(existingRule, rule) {
if !reflect.DeepEqual(existingRule, *rule) {
rules = append(rules, existingRule)
}
}
Expand Down
Loading

0 comments on commit 12354bb

Please sign in to comment.