diff --git a/backends/nft/dnat-rule.go b/backends/nft/dnat-rule.go index 7fe7bd711..9bce74b2a 100644 --- a/backends/nft/dnat-rule.go +++ b/backends/nft/dnat-rule.go @@ -45,16 +45,8 @@ func writeUint64(w *bytes.Buffer, v uint64) (int, error) { } func (d dnatRule) WriteTo(rule *bytes.Buffer, endpointsMap string, endpointsOffset uint64) { - var protoMatch string - switch d.Protocol { - case localnetv1.Protocol_TCP: - protoMatch = "tcp dport" - case localnetv1.Protocol_UDP: - protoMatch = "udp dport" - case localnetv1.Protocol_SCTP: - protoMatch = "sctp dport" - default: - klog.Errorf("unknown protocol: %v", d.Protocol) + protoMatch := protoMatch(d.Protocol) + if protoMatch == "" { return } @@ -71,13 +63,21 @@ func (d dnatRule) WriteTo(rule *bytes.Buffer, endpointsMap string, endpointsOffs return } - // printf is nice but take 50% on CPU time so... optimize! + // printf is nice but takes 50% on CPU time so... optimize! for _, port := range ports { rule.WriteString(" ") rule.WriteString(protoMatch) - rule.WriteByte(' ') - writeInt32(rule, port.Port) + if port.NodePort == 0 { + rule.WriteByte(' ') + writeInt32(rule, port.Port) + } else { + rule.WriteString(" { ") + writeInt32(rule, port.Port) + rule.WriteString(", ") + writeInt32(rule, port.NodePort) + rule.WriteByte('}') + } // handle reject case if len(d.EndpointIPs) == 0 { @@ -101,7 +101,7 @@ func (d dnatRule) WriteTo(rule *bytes.Buffer, endpointsMap string, endpointsOffs rule.WriteString(endpointsMap) } - if port.Port != port.TargetPort { + if port.Port != port.TargetPort || port.Port != port.NodePort { rule.WriteByte(':') writeInt32(rule, port.TargetPort) } @@ -111,3 +111,17 @@ func (d dnatRule) WriteTo(rule *bytes.Buffer, endpointsMap string, endpointsOffs return } + +func protoMatch(protocol localnetv1.Protocol) string { + switch protocol { + case localnetv1.Protocol_TCP: + return "tcp dport" + case localnetv1.Protocol_UDP: + return "udp dport" + case localnetv1.Protocol_SCTP: + return "sctp dport" + default: + klog.Errorf("unknown protocol: %v", protocol) + return "" + } +} diff --git a/backends/nft/nft.go b/backends/nft/nft.go index fa0383af4..86172e374 100644 --- a/backends/nft/nft.go +++ b/backends/nft/nft.go @@ -375,6 +375,16 @@ func Callback(ch <-chan *client.ServiceEndpoints) { } } + // handle node ports + for _, port := range svc.Ports { + if port.NodePort == 0 { + continue + } + + chain := chainBuffers.Get("chain", "nodeports") + fmt.Fprintf(chain, " "+protoMatch(port.Protocol)+" %d jump %s\n", port.NodePort, svc_chain) + } + // handle external IPs dispatch extIPs := svc.IPs.ExternalIPs.V4 if set.v6 { @@ -480,13 +490,21 @@ func addDispatchChains(family string, chainBuffers *chainBufferSet) { if chainBuffers.Get("chain", "dnat_external").Len() != 0 { fmt.Fprint(chainBuffers.Get("chain", "z_dnat_all"), " jump dnat_external\n") } + + chain := chainBuffers.Get("chain", "hook_nat_prerouting") + fmt.Fprintf(chain, " type nat hook prerouting priority %d;\n", *hookPrio) + if chainBuffers.Get("chain", "z_dnat_all").Len() != 0 { - fmt.Fprintf(chainBuffers.Get("chain", "hook_nat_prerouting"), - " type nat hook prerouting priority %d;\n jump z_dnat_all\n", *hookPrio) + chain.WriteString(" jump z_dnat_all\n") fmt.Fprintf(chainBuffers.Get("chain", "hook_nat_output"), " type nat hook output priority %d;\n jump z_dnat_all\n", *hookPrio) } + if chainBuffers.Get("chain", "nodeports").Len() != 0 { + // nodeports has a fib match valid only in prerouting + chain.WriteString(" fib daddr . iif type local jump nodeports\n") + } + // filtering if chainBuffers.Get("chain", "filter_external").Len() != 0 { fmt.Fprint(chainBuffers.Get("chain", "z_filter_all"), " jump filter_external\n")