Skip to content

Commit

Permalink
Literal declarations for IPTables rules
Browse files Browse the repository at this point in the history
  • Loading branch information
davidefalcone1 committed Jun 15, 2021
1 parent 5005ada commit 4f9b464
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 55 deletions.
105 changes: 50 additions & 55 deletions pkg/liqonet/iptables/iptables.go
Expand Up @@ -61,12 +61,8 @@ const (
remoteNATPodCIDR = "RemoteNATPodCIDR"
)

// IPTableRule struct that holds all the information of an iptables rule.
type IPTableRule struct {
Table string
Chain string
RuleSpec []string
}
// IPTableRule is a slice of string. This is the format used by module go-iptables.
type IPTableRule []string

// IPTHandler a handler that exposes all the functions needed to configure the iptables chains and rules.
type IPTHandler struct {
Expand Down Expand Up @@ -132,7 +128,7 @@ func (h IPTHandler) createLiqoChains(chains map[string]string) error {
}

// Function that guarrantees Liqo rules are inserted.
func (h IPTHandler) ensureLiqoRules(rules map[string]string) error {
func (h IPTHandler) ensureLiqoRules(rules map[string]IPTableRule) error {
for chain, rule := range rules {
if err := h.insertLiqoRuleIfNotExists(chain, rule); err != nil {
return err
Expand All @@ -141,14 +137,14 @@ func (h IPTHandler) ensureLiqoRules(rules map[string]string) error {
return nil
}

// Function that returns the set of Liqo rules. Value is a set of rules
// and key is the chain the set of rules should be inserted in.
func getLiqoRules() map[string]string {
return map[string]string{
inputChain: fmt.Sprintf("-j %s", liqonetInputChain),
forwardChain: fmt.Sprintf("-j %s", liqonetForwardingChain),
preroutingChain: fmt.Sprintf("-j %s", liqonetPreroutingChain),
postroutingChain: fmt.Sprintf("-j %s", liqonetPostroutingChain),
// Function that returns the set of Liqo rules. Value is a rule
// and key is the chain the rule should be inserted in.
func getLiqoRules() map[string]IPTableRule {
return map[string]IPTableRule{
inputChain: {"-j", liqonetInputChain},
forwardChain: {"-j", liqonetForwardingChain},
preroutingChain: {"-j", liqonetPreroutingChain},
postroutingChain: {"-j", liqonetPostroutingChain},
}
}

Expand Down Expand Up @@ -187,7 +183,7 @@ func (h IPTHandler) removeLiqoRules() error {
// Get Liqo rules
liqoRules := getLiqoRules()
for chain, rule := range liqoRules {
if err := h.deleteRulesInChain(chain, []string{rule}); err != nil {
if err := h.deleteRulesInChain(chain, []IPTableRule{rule}); err != nil {
return err
}
}
Expand Down Expand Up @@ -405,18 +401,18 @@ func (h IPTHandler) deleteChainRulesPerCluster(tep *netv1alpha1.TunnelEndpoint)
return nil
}

func (h IPTHandler) deleteRulesInChain(chain string, rules []string) error {
func (h IPTHandler) deleteRulesInChain(chain string, rules []IPTableRule) error {
table := getTableFromChain(chain)
existingRules, err := h.ListRulesInChain(chain)
if err != nil {
return fmt.Errorf("unable to list rules in chain %s (table %s): %w", chain, table, err)
}
for _, rule := range rules {
if !slice.ContainsString(existingRules, rule, nil) {
if !slice.ContainsString(existingRules, strings.Join(rule, " "), nil) {
continue
}
// Rule exists, then delete it
if err := h.ipt.Delete(table, chain, strings.Split(rule, " ")...); err != nil {
if err := h.ipt.Delete(table, chain, rule...); err != nil {
return err
}
klog.Infof("Deleted rule %s in chain %s (table %s)", rule, chain, table)
Expand Down Expand Up @@ -491,23 +487,24 @@ func (h IPTHandler) EnsurePreroutingRules(tep *netv1alpha1.TunnelEndpoint) error
return h.updateRulesPerChain(getClusterPreRoutingChain(clusterID), rules)
}

func getPreRoutingRules(tep *netv1alpha1.TunnelEndpoint) ([]string, error) {
func getPreRoutingRules(tep *netv1alpha1.TunnelEndpoint) ([]IPTableRule, error) {
// Check tep fields
if err := checkTep(tep); err != nil {
return nil, fmt.Errorf("invalid TunnelEndpoint resource: %w", err)
}
localPodCIDR := tep.Status.LocalPodCIDR
localRemappedPodCIDR, remotePodCIDR := utils.GetPodCIDRS(tep)

rules := make([]string, 0)
rules := make([]IPTableRule, 0)
if localRemappedPodCIDR == consts.DefaultCIDRValue {
// Remote cluster has not remapped home PodCIDR,
// this means there is no need to NAT
return rules, nil
}
// Remote cluster has remapped home PodCIDR
rules = append(rules, fmt.Sprintf("-s %s -d %s -j %s --to %s",
remotePodCIDR, localRemappedPodCIDR, NETMAP, localPodCIDR))
rules = append(rules,
IPTableRule{"-s", remotePodCIDR, "-d", localRemappedPodCIDR, "-j", NETMAP, "--to", localPodCIDR},
)
return rules, nil
}

Expand Down Expand Up @@ -560,7 +557,7 @@ func (h IPTHandler) createIptablesChainIfNotExists(table, newChain string) error
return nil
}

func (h IPTHandler) insertLiqoRuleIfNotExists(chain, rule string) error {
func (h IPTHandler) insertLiqoRuleIfNotExists(chain string, rule IPTableRule) error {
table := getTableFromChain(chain)
// Get the list of rules for the specified chain
existingRules, err := h.ipt.List(table, chain)
Expand All @@ -570,37 +567,37 @@ func (h IPTHandler) insertLiqoRuleIfNotExists(chain, rule string) error {
// Check if the rule exists and at the same time if it exists more then once
numOccurrences := 0
for _, existingRule := range existingRules {
if strings.Contains(existingRule, rule) {
if strings.Contains(existingRule, strings.Join(rule, " ")) {
numOccurrences++
}
}
// If the occurrences if greater then one, remove the rule
if numOccurrences > 1 {
for i := 0; i < numOccurrences; i++ {
if err = h.ipt.Delete(table, chain, strings.Split(rule, " ")...); err != nil {
if err = h.ipt.Delete(table, chain, rule...); err != nil {
return fmt.Errorf("unable to delete iptable rule \"%s\": %w", rule, err)
}
}
if err = h.ipt.Insert(table, chain, 1, strings.Split(rule, " ")...); err != nil {
if err = h.ipt.Insert(table, chain, 1, rule...); err != nil {
return fmt.Errorf("unable to insert iptable rule \"%s\": %w", rule, err)
}
}
if numOccurrences == 1 {
// If the occurrence is one then check the position and if not at the first one we delete and reinsert it
if strings.Contains(existingRules[0], rule) {
if strings.Contains(existingRules[0], strings.Join(rule, " ")) {
return nil
}
if err = h.ipt.Delete(table, chain, strings.Split(rule, " ")...); err != nil {
if err = h.ipt.Delete(table, chain, rule...); err != nil {
return fmt.Errorf("unable to delete iptable rule \"%s\": %w", rule, err)
}
if err = h.ipt.Insert(table, chain, 1, strings.Split(rule, " ")...); err != nil {
if err = h.ipt.Insert(table, chain, 1, rule...); err != nil {
return fmt.Errorf("unable to inserte iptable rule \"%s\": %w", rule, err)
}
return nil
}
if numOccurrences == 0 {
// If the occurrence is zero then insert the rule in first position
if err = h.ipt.Insert(table, chain, 1, strings.Split(rule, " ")...); err != nil {
if err = h.ipt.Insert(table, chain, 1, rule...); err != nil {
return fmt.Errorf("unable to insert iptable rule \"%s\": %w", rule, err)
}
klog.Infof("Inserted rule '%s' in chain %s of table %s", rule, chain, table)
Expand All @@ -609,14 +606,14 @@ func (h IPTHandler) insertLiqoRuleIfNotExists(chain, rule string) error {
}

// Function to update specific rules in a given chain.
func (h IPTHandler) updateSpecificRulesPerChain(chain string, existingRules, newRules []string) error {
func (h IPTHandler) updateSpecificRulesPerChain(chain string, existingRules []string, newRules []IPTableRule) error {
table := getTableFromChain(chain)
for _, existingRule := range existingRules {
// Remove existing rules that are not in the set of new rules,
// they are outdated.
outdated := true
for _, newRule := range newRules {
if existingRule == newRule {
if strings.Contains(existingRule, strings.Join(newRule, " ")) {
outdated = false
}
}
Expand All @@ -636,23 +633,23 @@ func (h IPTHandler) updateSpecificRulesPerChain(chain string, existingRules, new
}

// Function to updates rules in a given chain.
func (h IPTHandler) updateRulesPerChain(chain string, newRules []string) error {
func (h IPTHandler) updateRulesPerChain(chain string, newRules []IPTableRule) error {
existingRules, err := h.ListRulesInChain(chain)
if err != nil {
return fmt.Errorf("cannot list rules in chain %s (table %s): %w", chain, getTableFromChain(chain), err)
}
return h.updateSpecificRulesPerChain(chain, existingRules, newRules)
}

func (h IPTHandler) insertRulesIfNotPresent(table, chain string, rules []string) error {
func (h IPTHandler) insertRulesIfNotPresent(table, chain string, rules []IPTableRule) error {
for _, rule := range rules {
exists, err := h.ipt.Exists(table, chain, strings.Split(rule, " ")...)
exists, err := h.ipt.Exists(table, chain, rule...)
if err != nil {
klog.Errorf("unable to check if rule '%s' exists in chain %s in table %s: %w", rule, chain, table, err)
return err
}
if !exists {
if err := h.ipt.AppendUnique(table, chain, strings.Split(rule, " ")...); err != nil {
if err := h.ipt.AppendUnique(table, chain, rule...); err != nil {
return err
}
klog.Infof("Inserting rule '%s' in chain %s in table %s", rule, chain, table)
Expand All @@ -661,7 +658,7 @@ func (h IPTHandler) insertRulesIfNotPresent(table, chain string, rules []string)
return nil
}

func getPostroutingRules(tep *netv1alpha1.TunnelEndpoint) ([]string, error) {
func getPostroutingRules(tep *netv1alpha1.TunnelEndpoint) ([]IPTableRule, error) {
if err := checkTep(tep); err != nil {
return nil, fmt.Errorf("invalid TunnelEndpoint resource: %w", err)
}
Expand All @@ -677,10 +674,9 @@ func getPostroutingRules(tep *netv1alpha1.TunnelEndpoint) ([]string, error) {
localRemappedPodCIDR, clusterID)
return nil, err
}
return []string{
fmt.Sprintf("-s %s -d %s -j %s --to %s",
localPodCIDR, remotePodCIDR, NETMAP, localRemappedPodCIDR),
fmt.Sprintf("! -s %s -d %s -j %s --to-source %s", localPodCIDR, remotePodCIDR, SNAT, natIP),
return []IPTableRule{
{"-s", localPodCIDR, "-d", remotePodCIDR, "-j", NETMAP, "--to", localRemappedPodCIDR},
{"!", "-s", localPodCIDR, "-d", remotePodCIDR, "-j", SNAT, "--to-source", natIP},
}, nil
}
// Get the first IP address from the podCIDR of the local cluster
Expand All @@ -690,9 +686,8 @@ func getPostroutingRules(tep *netv1alpha1.TunnelEndpoint) ([]string, error) {
tep.Spec.PodCIDR, clusterID)
return nil, err
}
return []string{
fmt.Sprintf("! -s %s -d %s -j %s --to-source %s",
localPodCIDR, remotePodCIDR, SNAT, natIP),
return []IPTableRule{
{"!", "-s", localPodCIDR, "-d", remotePodCIDR, "-j", SNAT, "--to-source", natIP},
}, nil
}

Expand Down Expand Up @@ -736,34 +731,34 @@ func checkTep(tep *netv1alpha1.TunnelEndpoint) error {
// Function that returns the set of rules used in Liqo chains (e.g. LIQO-PREROUTING)
// related to a remote cluster. Return value is a map of slices in which value
// is the a set of rules and key is the chain the set of rules should belong to.
func getChainRulesPerCluster(tep *netv1alpha1.TunnelEndpoint) (map[string][]string, error) {
func getChainRulesPerCluster(tep *netv1alpha1.TunnelEndpoint) (map[string][]IPTableRule, error) {
if err := checkTep(tep); err != nil {
return nil, fmt.Errorf("invalid TunnelEndpoint resource: %w", err)
}
clusterID := tep.Spec.ClusterID
localRemappedPodCIDR, remotePodCIDR := utils.GetPodCIDRS(tep)

// Init chain rules
chainRules := make(map[string][]string)
chainRules[liqonetPostroutingChain] = make([]string, 0)
chainRules[liqonetPreroutingChain] = make([]string, 0)
chainRules[liqonetForwardingChain] = make([]string, 0)
chainRules[liqonetInputChain] = make([]string, 0)
chainRules := make(map[string][]IPTableRule)
chainRules[liqonetPostroutingChain] = make([]IPTableRule, 0)
chainRules[liqonetPreroutingChain] = make([]IPTableRule, 0)
chainRules[liqonetForwardingChain] = make([]IPTableRule, 0)
chainRules[liqonetInputChain] = make([]IPTableRule, 0)

// For these rules, source in not necessary since
// the remotePodCIDR is unique in home cluster
chainRules[liqonetPostroutingChain] = append(chainRules[liqonetPostroutingChain],
fmt.Sprintf("-d %s -j %s", remotePodCIDR, getClusterPostRoutingChain(clusterID)))
IPTableRule{"-d", remotePodCIDR, "-j", getClusterPostRoutingChain(clusterID)})
chainRules[liqonetInputChain] = append(chainRules[liqonetInputChain],
fmt.Sprintf("-d %s -j %s", remotePodCIDR, getClusterInputChain(clusterID)))
IPTableRule{"-d", remotePodCIDR, "-j", getClusterInputChain(clusterID)})
chainRules[liqonetForwardingChain] = append(chainRules[liqonetForwardingChain],
fmt.Sprintf("-d %s -j %s", remotePodCIDR, getClusterForwardChain(clusterID)))
IPTableRule{"-d", remotePodCIDR, "-j", getClusterForwardChain(clusterID)})
if localRemappedPodCIDR != consts.DefaultCIDRValue {
// For the following rule, source is necessary
// because more remote clusters could have
// remapped home PodCIDR in the same way, then only use dst is not enough.
chainRules[liqonetPreroutingChain] = append(chainRules[liqonetPreroutingChain],
fmt.Sprintf("-s %s -d %s -j %s", remotePodCIDR, localRemappedPodCIDR, getClusterPreRoutingChain(clusterID)))
IPTableRule{"-s", remotePodCIDR, "-d", localRemappedPodCIDR, "-j", getClusterPreRoutingChain(clusterID)})
}
return chainRules, nil
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/liqonet/iptables/iptables_test.go
Expand Up @@ -55,6 +55,7 @@ var _ = Describe("iptables", func() {
postRoutingRules, err := h.ListRulesInChain(postroutingChain)
Expect(err).To(BeNil())
Expect(postRoutingRules).To(ContainElement(fmt.Sprintf("-j %s", liqonetPostroutingChain)))
Expect(postRoutingRules).To(ContainElement(fmt.Sprintf("-j %s", MASQUERADE)))

// Check existence of LIQO-PREROUTING chain
Expect(natChains).To(ContainElement(liqonetPostroutingChain))
Expand Down Expand Up @@ -401,6 +402,7 @@ var _ = Describe("iptables", func() {
Expect(err).To(BeNil())
Expect(postRoutingRules).ToNot(ContainElements([]string{
fmt.Sprintf("-j %s", liqonetPostroutingChain),
fmt.Sprintf("-j %s", MASQUERADE),
}))

preRoutingRules, err := h.ListRulesInChain(preroutingChain)
Expand Down

0 comments on commit 4f9b464

Please sign in to comment.