Skip to content

Commit

Permalink
fqdn: Add Protocol to DNS Proxy Cache
Browse files Browse the repository at this point in the history
DNS Proxy indexes domain selectors by port
only. In cases where protocols collide on port
the DNS proxy may have a more restrictive selector
than it should because it does not merge port
protocols for L7 policies (only ports).

All callers of the DNS Proxy are updated
to add protocol to any DNS Proxy entries, and all
tests are updated to test for port-protocol
merge errors.

Signed-off-by: Nate Sweet <nathanjsweet@pm.me>
  • Loading branch information
nathanjsweet committed Mar 26, 2024
1 parent 1941679 commit bc7fbf3
Show file tree
Hide file tree
Showing 14 changed files with 405 additions and 273 deletions.
18 changes: 12 additions & 6 deletions pkg/endpoint/bpf.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,14 @@ func (e *Endpoint) writeHeaderfile(prefix string) error {
// instead of returning a 0 port number.
type proxyPolicy struct {
*policy.L4Filter
ps *policy.PerSelectorPolicy
port uint16
ps *policy.PerSelectorPolicy
port uint16
protocol uint8
}

// newProxyPolicy returns a new instance of proxyPolicy by value
func (e *Endpoint) newProxyPolicy(l4 *policy.L4Filter, ps *policy.PerSelectorPolicy, port uint16) proxyPolicy {
return proxyPolicy{L4Filter: l4, ps: ps, port: port}
func (e *Endpoint) newProxyPolicy(l4 *policy.L4Filter, ps *policy.PerSelectorPolicy, port uint16, proto uint8) proxyPolicy {
return proxyPolicy{L4Filter: l4, ps: ps, port: port, protocol: proto}
}

// GetPort returns the destination port number on which the proxy policy applies
Expand All @@ -191,6 +192,11 @@ func (p *proxyPolicy) GetPort() uint16 {
return p.port
}

// GetProtocol returns the destination protocol number on which the proxy policy applies
func (p *proxyPolicy) GetProtocol() uint8 {
return p.protocol
}

// GetListener returns the listener name referenced by the policy, if any
func (p *proxyPolicy) GetListener() string {
return p.ps.GetListener()
Expand Down Expand Up @@ -226,7 +232,7 @@ func (e *Endpoint) addNewRedirectsFromDesiredPolicy(ingress bool, desiredRedirec
}
// proxyID() returns also the destination port for the policy,
// which may be resolved from a named port
proxyID, dstPort := e.proxyID(l4, v.Listener)
proxyID, dstPort, dstProto := e.proxyID(l4, v.Listener)
if proxyID == "" {
// Skip redirects for which a proxyID cannot be created.
// This may happen due to the named port mapping not
Expand All @@ -244,7 +250,7 @@ func (e *Endpoint) addNewRedirectsFromDesiredPolicy(ingress bool, desiredRedirec

var redirectPort uint16

pp := e.newProxyPolicy(l4, v, dstPort)
pp := e.newProxyPolicy(l4, v, dstPort, dstProto)
proxyPort, err, finalizeFunc, revertFunc := e.proxy.CreateOrUpdateRedirect(e.aliveCtx, &pp, proxyID, e, proxyWaitGroup)
if err != nil {
// Skip redirects that can not be created or updated. This
Expand Down
9 changes: 6 additions & 3 deletions pkg/endpoint/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,10 @@ func (s *EndpointSuite) TestProxyID(c *C) {
e := &Endpoint{ID: 123, policyRevision: 0}
e.UpdateLogger(nil)

id, port := e.proxyID(&policy.L4Filter{Port: 8080, Protocol: api.ProtoTCP, Ingress: true}, "")
id, port, proto := e.proxyID(&policy.L4Filter{Port: 8080, Protocol: api.ProtoTCP, Ingress: true}, "")
c.Assert(id, Not(Equals), "")
c.Assert(port, Equals, uint16(8080))
c.Assert(proto, Equals, uint8(6))

endpointID, ingress, protocol, port, listener, err := policy.ParseProxyID(id)
c.Assert(endpointID, Equals, uint16(123))
Expand All @@ -537,9 +538,10 @@ func (s *EndpointSuite) TestProxyID(c *C) {
c.Assert(listener, Equals, "")
c.Assert(err, IsNil)

id, port = e.proxyID(&policy.L4Filter{Port: 8080, Protocol: api.ProtoTCP, Ingress: true, L7Parser: policy.ParserTypeCRD}, "test-listener")
id, port, proto = e.proxyID(&policy.L4Filter{Port: 8080, Protocol: api.ProtoTCP, Ingress: true, L7Parser: policy.ParserTypeCRD}, "test-listener")
c.Assert(id, Not(Equals), "")
c.Assert(port, Equals, uint16(8080))
c.Assert(proto, Equals, uint8(6))
endpointID, ingress, protocol, port, listener, err = policy.ParseProxyID(id)
c.Assert(endpointID, Equals, uint16(123))
c.Assert(ingress, Equals, true)
Expand All @@ -549,9 +551,10 @@ func (s *EndpointSuite) TestProxyID(c *C) {
c.Assert(err, IsNil)

// Undefined named port
id, port = e.proxyID(&policy.L4Filter{PortName: "foobar", Protocol: api.ProtoTCP, Ingress: true}, "")
id, port, proto = e.proxyID(&policy.L4Filter{PortName: "foobar", Protocol: api.ProtoTCP, Ingress: true}, "")
c.Assert(id, Equals, "")
c.Assert(port, Equals, uint16(0))
c.Assert(proto, Equals, uint8(0))
}

func TestEndpoint_GetK8sPodLabels(t *testing.T) {
Expand Down
17 changes: 12 additions & 5 deletions pkg/endpoint/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,25 @@ func (e *Endpoint) getNamedPortEgress(npMap types.NamedPortMultiMap, name string
}

// proxyID returns a unique string to identify a proxy mapping,
// and the resolved destination port number, if any.
// and the resolved destination port and protocol numbers, if any.
// Must be called with e.mutex held.
func (e *Endpoint) proxyID(l4 *policy.L4Filter, listener string) (string, uint16) {
func (e *Endpoint) proxyID(l4 *policy.L4Filter, listener string) (string, uint16, uint8) {
port := uint16(l4.Port)
protocol := uint8(l4.U8Proto)
// Calculate protocol if it is 0 (default) and
// is not "ANY" (that is, it was not calculated).
if protocol == 0 && !l4.Protocol.IsAny() {
proto, _ := u8proto.ParseProtocol(string(l4.Protocol))
protocol = uint8(proto)
}
if port == 0 && l4.PortName != "" {
port = e.GetNamedPort(l4.Ingress, l4.PortName, uint8(l4.U8Proto))
port = e.GetNamedPort(l4.Ingress, l4.PortName, protocol)
if port == 0 {
return "", 0
return "", 0, 0
}
}

return policy.ProxyID(e.ID, l4.Ingress, string(l4.Protocol), port, listener), port
return policy.ProxyID(e.ID, l4.Ingress, string(l4.Protocol), port, listener), port, protocol
}

var unrealizedRedirect = errors.New("Proxy port for redirect not found")
Expand Down
9 changes: 5 additions & 4 deletions pkg/fqdn/dnsproxy/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,18 @@ import (
"github.com/cilium/dns"

"github.com/cilium/cilium/pkg/fqdn/restore"
"github.com/cilium/cilium/pkg/u8proto"
)

// lookupTargetDNSServer finds the intended DNS target server for a specific
// request (passed in via ServeDNS). The IP:port combination is
// request (passed in via ServeDNS). The IP:port:protocol combination is
// returned.
func lookupTargetDNSServer(w dns.ResponseWriter) (serverIP net.IP, serverPort restore.PortProto, addrStr string, err error) {
func lookupTargetDNSServer(w dns.ResponseWriter) (serverIP net.IP, serverPortProto restore.PortProto, addrStr string, err error) {
switch addr := (w.LocalAddr()).(type) {
case *net.UDPAddr:
return addr.IP, restore.PortProto(addr.Port), addr.String(), nil
return addr.IP, restore.MakeV2PortProto(uint16(addr.Port), uint8(u8proto.UDP)), addr.String(), nil
case *net.TCPAddr:
return addr.IP, restore.PortProto(addr.Port), addr.String(), nil
return addr.IP, restore.MakeV2PortProto(uint16(addr.Port), uint8(u8proto.TCP)), addr.String(), nil
default:
return nil, 0, addr.String(), fmt.Errorf("Cannot extract address information for type %T: %+v", addr, addr)
}
Expand Down
34 changes: 23 additions & 11 deletions pkg/fqdn/dnsproxy/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,17 @@ import (
"github.com/cilium/cilium/pkg/defaults"
"github.com/cilium/cilium/pkg/fqdn/dns"
"github.com/cilium/cilium/pkg/fqdn/re"
"github.com/cilium/cilium/pkg/fqdn/restore"
"github.com/cilium/cilium/pkg/identity"
"github.com/cilium/cilium/pkg/labels"
"github.com/cilium/cilium/pkg/policy"
"github.com/cilium/cilium/pkg/policy/api"
"github.com/cilium/cilium/pkg/u8proto"
)

const (
udpProto = uint8(u8proto.UDP)
tcpProto = uint8(u8proto.TCP)
)

type DNSProxyHelperTestSuite struct{}
Expand All @@ -34,6 +41,8 @@ func (s *DNSProxyHelperTestSuite) TestSetPortRulesForID(c *C) {
epID := uint64(1)
pea := perEPAllow{}
cache := make(regexCache)
udpProtoPort8053 := restore.MakeV2PortProto(8053, udpProto)

rules[new(MockCachedSelector)] = &policy.PerSelectorPolicy{
L7Rules: api.L7Rules{
DNS: []api.PortRuleDNS{
Expand All @@ -42,7 +51,8 @@ func (s *DNSProxyHelperTestSuite) TestSetPortRulesForID(c *C) {
},
},
}
err := pea.setPortRulesForID(cache, epID, 8053, rules)

err := pea.setPortRulesForID(cache, epID, udpProtoPort8053, rules)
c.Assert(err, Equals, nil)
c.Assert(len(cache), Equals, 1)

Expand All @@ -56,16 +66,17 @@ func (s *DNSProxyHelperTestSuite) TestSetPortRulesForID(c *C) {
},
},
}
err = pea.setPortRulesForID(cache, epID, 8053, rules)

err = pea.setPortRulesForID(cache, epID, udpProtoPort8053, rules)
c.Assert(err, Equals, nil)
c.Assert(len(cache), Equals, 2)

delete(rules, selector2)
err = pea.setPortRulesForID(cache, epID, 8053, rules)
err = pea.setPortRulesForID(cache, epID, udpProtoPort8053, rules)
c.Assert(err, Equals, nil)
c.Assert(len(cache), Equals, 1)

err = pea.setPortRulesForID(cache, epID, 8053, nil)
err = pea.setPortRulesForID(cache, epID, udpProtoPort8053, nil)
c.Assert(err, Equals, nil)
c.Assert(len(cache), Equals, 0)

Expand All @@ -79,7 +90,7 @@ func (s *DNSProxyHelperTestSuite) TestSetPortRulesForID(c *C) {
},
},
}
err = pea.setPortRulesForID(cache, epID, 8053, rules)
err = pea.setPortRulesForID(cache, epID, udpProtoPort8053, rules)

c.Assert(err, NotNil)
c.Assert(len(cache), Equals, 0)
Expand All @@ -91,34 +102,35 @@ func (s *DNSProxyHelperTestSuite) TestSetPortRulesForIDFromUnifiedFormat(c *C) {
epID := uint64(1)
pea := perEPAllow{}
cache := make(regexCache)
udpProtoPort8053 := restore.MakeV2PortProto(8053, udpProto)
rules[new(MockCachedSelector)] = regexp.MustCompile("^.*[.]cilium[.]io$")
rules[new(MockCachedSelector)] = regexp.MustCompile("^.*[.]cilium[.]io$")

err := pea.setPortRulesForIDFromUnifiedFormat(cache, epID, 8053, rules)
err := pea.setPortRulesForIDFromUnifiedFormat(cache, epID, udpProtoPort8053, rules)
c.Assert(err, Equals, nil)
c.Assert(len(cache), Equals, 1)

selector2 := new(MockCachedSelector)
rules[selector2] = regexp.MustCompile("^sub[.]cilium[.]io")
err = pea.setPortRulesForIDFromUnifiedFormat(cache, epID, 8053, rules)
err = pea.setPortRulesForIDFromUnifiedFormat(cache, epID, udpProtoPort8053, rules)
c.Assert(err, Equals, nil)
c.Assert(len(cache), Equals, 2)

delete(rules, selector2)
err = pea.setPortRulesForIDFromUnifiedFormat(cache, epID, 8053, rules)
err = pea.setPortRulesForIDFromUnifiedFormat(cache, epID, udpProtoPort8053, rules)
c.Assert(err, Equals, nil)
c.Assert(len(cache), Equals, 1)

err = pea.setPortRulesForIDFromUnifiedFormat(cache, epID, 8053, nil)
err = pea.setPortRulesForIDFromUnifiedFormat(cache, epID, udpProtoPort8053, nil)
c.Assert(err, Equals, nil)
c.Assert(len(cache), Equals, 0)

delete(rules, selector2)
err = pea.setPortRulesForIDFromUnifiedFormat(cache, epID, 8053, rules)
err = pea.setPortRulesForIDFromUnifiedFormat(cache, epID, udpProtoPort8053, rules)
c.Assert(err, Equals, nil)
c.Assert(len(cache), Equals, 1)

err = pea.setPortRulesForIDFromUnifiedFormat(cache, epID, 8053, nil)
err = pea.setPortRulesForIDFromUnifiedFormat(cache, epID, udpProtoPort8053, nil)
c.Assert(err, Equals, nil)
c.Assert(len(cache), Equals, 0)
}
Expand Down

0 comments on commit bc7fbf3

Please sign in to comment.