Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hubble: Use netip.Addr instead of net.IP in getter functions #23143

Merged
merged 3 commits into from
Mar 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 7 additions & 3 deletions daemon/cmd/fqdn.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import (
"github.com/cilium/cilium/pkg/fqdn/re"
"github.com/cilium/cilium/pkg/identity"
secIDCache "github.com/cilium/cilium/pkg/identity/cache"
"github.com/cilium/cilium/pkg/ip"
ippkg "github.com/cilium/cilium/pkg/ip"
"github.com/cilium/cilium/pkg/logging/logfields"
"github.com/cilium/cilium/pkg/metrics"
"github.com/cilium/cilium/pkg/option"
Expand Down Expand Up @@ -398,7 +398,11 @@ func (d *Daemon) updateSelectors(ctx context.Context, selectorWithIPsToUpdate ma

// lookupEPByIP returns the endpoint that this IP belongs to
func (d *Daemon) lookupEPByIP(endpointIP net.IP) (endpoint *endpoint.Endpoint, err error) {
e := d.endpointManager.LookupIP(endpointIP)
endpointAddr, ok := ippkg.AddrFromIP(endpointIP)
if !ok {
return nil, fmt.Errorf("invalid IP %s for endpoint lookup", endpointIP)
}
e := d.endpointManager.LookupIP(endpointAddr)
if e == nil {
return nil, fmt.Errorf("Cannot find endpoint with IP %s", endpointIP.String())
}
Expand Down Expand Up @@ -581,7 +585,7 @@ func (d *Daemon) notifyOnDNSMsg(lookupTime time.Time, ep *endpoint.Endpoint, epI
// doesn't happen in the case, we play it safe and don't purge the zombie
// in case of races.
log.WithField(logfields.EndpointID, ep.ID).Debug("Recording DNS lookup in endpoint specific cache")
if updated := ep.DNSHistory.Update(lookupTime, qname, ip.MustAddrsFromIPs(responseIPs), int(TTL)); updated {
if updated := ep.DNSHistory.Update(lookupTime, qname, ippkg.MustAddrsFromIPs(responseIPs), int(TTL)); updated {
ep.DNSZombies.ForceExpireByNameIP(lookupTime, qname, responseIPs...)
ep.SyncEndpointHeaderFile()
}
Expand Down
47 changes: 17 additions & 30 deletions daemon/cmd/hubble.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package cmd
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"time"

Expand Down Expand Up @@ -39,7 +39,6 @@ import (
"github.com/cilium/cilium/pkg/hubble/server"
"github.com/cilium/cilium/pkg/hubble/server/serveroption"
"github.com/cilium/cilium/pkg/identity"
ippkg "github.com/cilium/cilium/pkg/ip"
"github.com/cilium/cilium/pkg/ipcache"
"github.com/cilium/cilium/pkg/loadbalancer"
"github.com/cilium/cilium/pkg/logging"
Expand Down Expand Up @@ -297,8 +296,6 @@ func (d *Daemon) launchHubble() {

// GetIdentity looks up identity by ID from Cilium's identity cache. Hubble uses the identity info
// to populate source and destination labels of flows.
//
// - IdentityGetter: https://github.com/cilium/hubble/blob/04ab72591faca62a305ce0715108876167182e04/pkg/parser/getters/getters.go#L40
func (d *Daemon) GetIdentity(securityIdentity uint32) (*identity.Identity, error) {
ident := d.identityAllocator.LookupIdentityByID(context.Background(), identity.NumericIdentity(securityIdentity))
if ident == nil {
Expand All @@ -309,9 +306,10 @@ func (d *Daemon) GetIdentity(securityIdentity uint32) (*identity.Identity, error

// GetEndpointInfo returns endpoint info for a given IP address. Hubble uses this function to populate
// fields like namespace and pod name for local endpoints.
//
// - EndpointGetter: https://github.com/cilium/hubble/blob/04ab72591faca62a305ce0715108876167182e04/pkg/parser/getters/getters.go#L34
func (d *Daemon) GetEndpointInfo(ip net.IP) (endpoint v1.EndpointInfo, ok bool) {
func (d *Daemon) GetEndpointInfo(ip netip.Addr) (endpoint v1.EndpointInfo, ok bool) {
if !ip.IsValid() {
return nil, false
}
ep := d.endpointManager.LookupIP(ip)
if ep == nil {
return nil, false
Expand All @@ -334,19 +332,16 @@ func (d *Daemon) GetEndpoints() map[policy.Endpoint]struct{} {

// GetNamesOf implements DNSGetter.GetNamesOf. It looks up DNS names of a given IP from the
// FQDN cache of an endpoint specified by sourceEpID.
//
// - DNSGetter: https://github.com/cilium/hubble/blob/04ab72591faca62a305ce0715108876167182e04/pkg/parser/getters/getters.go#L27
func (d *Daemon) GetNamesOf(sourceEpID uint32, ip net.IP) []string {
func (d *Daemon) GetNamesOf(sourceEpID uint32, ip netip.Addr) []string {
ep := d.endpointManager.LookupCiliumID(uint16(sourceEpID))
if ep == nil {
return nil
}

addr, ok := ippkg.AddrFromIP(ip)
if !ok {
if !ip.IsValid() {
return nil
}
names := ep.DNSHistory.LookupIP(addr)
names := ep.DNSHistory.LookupIP(ip)

for i := range names {
names[i] = strings.TrimSuffix(names[i], ".")
Expand All @@ -357,13 +352,11 @@ func (d *Daemon) GetNamesOf(sourceEpID uint32, ip net.IP) []string {

// GetServiceByAddr looks up service by IP/port. Hubble uses this function to annotate flows
// with service information.
//
// - ServiceGetter: https://github.com/cilium/hubble/blob/04ab72591faca62a305ce0715108876167182e04/pkg/parser/getters/getters.go#L52
func (d *Daemon) GetServiceByAddr(ip net.IP, port uint16) *flowpb.Service {
addrCluster, ok := cmtypes.AddrClusterFromIP(ip)
if !ok {
func (d *Daemon) GetServiceByAddr(ip netip.Addr, port uint16) *flowpb.Service {
if !ip.IsValid() {
return nil
}
addrCluster := cmtypes.AddrClusterFrom(ip, 0)
addr := loadbalancer.L3n4Addr{
AddrCluster: addrCluster,
L4Addr: loadbalancer.L4Addr{
Expand All @@ -382,8 +375,8 @@ func (d *Daemon) GetServiceByAddr(ip net.IP, port uint16) *flowpb.Service {

// GetK8sMetadata returns the Kubernetes metadata for the given IP address.
// It implements hubble parser's IPGetter.GetK8sMetadata.
func (d *Daemon) GetK8sMetadata(ip net.IP) *ipcache.K8sMetadata {
if ip == nil {
func (d *Daemon) GetK8sMetadata(ip netip.Addr) *ipcache.K8sMetadata {
if !ip.IsValid() {
return nil
}
return d.ipcache.GetK8sMetadata(ip.String())
Expand All @@ -392,8 +385,8 @@ func (d *Daemon) GetK8sMetadata(ip net.IP) *ipcache.K8sMetadata {
// LookupSecIDByIP returns the security ID for the given IP. If the security ID
// cannot be found, ok is false.
// It implements hubble parser's IPGetter.LookupSecIDByIP.
func (d *Daemon) LookupSecIDByIP(ip net.IP) (id ipcache.Identity, ok bool) {
if ip == nil {
func (d *Daemon) LookupSecIDByIP(ip netip.Addr) (id ipcache.Identity, ok bool) {
if !ip.IsValid() {
return ipcache.Identity{}, false
}

Expand All @@ -403,20 +396,14 @@ func (d *Daemon) LookupSecIDByIP(ip net.IP) (id ipcache.Identity, ok bool) {

ipv6Prefixes, ipv4Prefixes := d.GetCIDRPrefixLengths()
prefixes := ipv4Prefixes
bits := net.IPv4len * 8
if ip.To4() == nil {
if ip.Is6() {
prefixes = ipv6Prefixes
bits = net.IPv6len * 8
}
for _, prefixLen := range prefixes {
// note: we perform a lookup even when `prefixLen == bits`, as some
// entries derived by a single address cidr-range will not have been
// found by the above lookup
mask := net.CIDRMask(prefixLen, bits)
cidr := net.IPNet{
IP: ip.Mask(mask),
Mask: mask,
}
cidr, _ := ip.Prefix(prefixLen)
if id, ok = d.ipcache.LookupByPrefix(cidr.String()); ok {
return id, ok
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/endpointmanager/cell.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package endpointmanager

import (
"context"
"net"
"net/netip"
"sync"
"time"

Expand Down Expand Up @@ -51,7 +51,7 @@ type EndpointsLookup interface {
LookupIPv6(ipv6 string) *endpoint.Endpoint

// LookupIP looks up endpoint by IP address
LookupIP(ip net.IP) (ep *endpoint.Endpoint)
LookupIP(ip netip.Addr) (ep *endpoint.Endpoint)

// LookupPodName looks up endpoint by namespace + pod name, e.g. "prod/pod-0"
LookupPodName(name string) *endpoint.Endpoint
Expand Down
11 changes: 5 additions & 6 deletions pkg/endpointmanager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"time"
Expand Down Expand Up @@ -316,13 +315,13 @@ func (mgr *endpointManager) LookupIPv6(ipv6 string) *endpoint.Endpoint {
}

// LookupIP looks up endpoint by IP address
func (mgr *endpointManager) LookupIP(ip net.IP) (ep *endpoint.Endpoint) {
addr := ip.String()
func (mgr *endpointManager) LookupIP(ip netip.Addr) (ep *endpoint.Endpoint) {
ipStr := ip.String()
mgr.mutex.RLock()
if ip.To4() != nil {
ep = mgr.lookupIPv4(addr)
if ip.Is4() {
ep = mgr.lookupIPv4(ipStr)
} else {
ep = mgr.lookupIPv6(addr)
ep = mgr.lookupIPv6(ipStr)
}
mgr.mutex.RUnlock()
return ep
Expand Down
38 changes: 21 additions & 17 deletions pkg/fqdn/dnsproxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"math"
"net"
"net/netip"
"regexp"
"strconv"
"strings"
Expand All @@ -27,6 +28,7 @@ import (
"github.com/cilium/cilium/pkg/fqdn/re"
"github.com/cilium/cilium/pkg/fqdn/restore"
"github.com/cilium/cilium/pkg/identity"
ippkg "github.com/cilium/cilium/pkg/ip"
"github.com/cilium/cilium/pkg/ipcache"
"github.com/cilium/cilium/pkg/lock"
"github.com/cilium/cilium/pkg/logging"
Expand Down Expand Up @@ -473,7 +475,7 @@ type LookupEndpointIDByIPFunc func(ip net.IP) (endpoint *endpoint.Endpoint, err
// LookupSecIDByIPFunc Func wraps logic to lookup an IP's security ID from the
// ipcache.
// See DNSProxy.LookupSecIDByIP for usage.
type LookupSecIDByIPFunc func(ip net.IP) (secID ipcache.Identity, exists bool)
type LookupSecIDByIPFunc func(ip netip.Addr) (secID ipcache.Identity, exists bool)

// LookupIPsBySecIDFunc Func wraps logic to lookup an IPs by security ID from the
// ipcache.
Expand Down Expand Up @@ -798,22 +800,24 @@ func (p *DNSProxy) ServeDNS(w dns.ResponseWriter, request *dns.Msg) {
logfields.Identity: ep.GetIdentity(),
})

targetServerIP, targetServerPort, targetServerAddr, err := p.lookupTargetDNSServer(w)
targetServerIP, targetServerPort, targetServerAddrStr, err := p.lookupTargetDNSServer(w)
if err != nil {
log.WithError(err).Error("cannot extract destination IP:port from DNS request")
stat.Err = fmt.Errorf("Cannot extract destination IP:port from DNS request: %w", err)
stat.ProcessingTime.End(false)
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, 0, targetServerAddr, request, protocol, false, &stat)
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, 0, targetServerAddrStr, request, protocol, false, &stat)
p.sendRefused(scopedLog, w, request)
return
}

targetServerID := identity.ReservedIdentityWorld
if serverSecID, exists := p.LookupSecIDByIP(targetServerIP); !exists {
scopedLog.WithField("server", targetServerAddr).Debug("cannot find server ip in ipcache, defaulting to WORLD")
// Ignore invalid IP - getter will handle invalid value.
targetServerAddr, _ := ippkg.AddrFromIP(targetServerIP)
if serverSecID, exists := p.LookupSecIDByIP(targetServerAddr); !exists {
scopedLog.WithField("server", targetServerAddrStr).Debug("cannot find server ip in ipcache, defaulting to WORLD")
} else {
targetServerID = serverSecID.ID
scopedLog.WithField("server", targetServerAddr).Debugf("Found target server to of DNS request secID %+v", serverSecID)
scopedLog.WithField("server", targetServerAddrStr).Debugf("Found target server to of DNS request secID %+v", serverSecID)
}

// The allowed check is first because we don't want to use DNS responses that
Expand All @@ -829,7 +833,7 @@ func (p *DNSProxy) ServeDNS(w dns.ResponseWriter, request *dns.Msg) {
scopedLog.WithError(err).Error("Rejecting DNS query from endpoint due to error")
stat.Err = fmt.Errorf("Rejecting DNS query from endpoint due to error: %w", err)
stat.ProcessingTime.End(false)
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddr, request, protocol, false, &stat)
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddrStr, request, protocol, false, &stat)
p.sendRefused(scopedLog, w, request)
return

Expand All @@ -841,12 +845,12 @@ func (p *DNSProxy) ServeDNS(w dns.ResponseWriter, request *dns.Msg) {
// information for metrics.
stat.Err = p.sendRefused(scopedLog, w, request)
stat.ProcessingTime.End(true)
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddr, request, protocol, false, &stat)
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddrStr, request, protocol, false, &stat)
return
}

scopedLog.Debug("Forwarding DNS request for a name that is allowed")
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddr, request, protocol, true, &stat)
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddrStr, request, protocol, true, &stat)

// Keep the same L4 protocol. This handles DNS re-requests over TCP, for
// requests that were too large for UDP.
Expand All @@ -860,7 +864,7 @@ func (p *DNSProxy) ServeDNS(w dns.ResponseWriter, request *dns.Msg) {
scopedLog.Error("Cannot parse DNS proxy client network to select forward client")
stat.Err = fmt.Errorf("Cannot parse DNS proxy client network to select forward client: %w", err)
stat.ProcessingTime.End(false)
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddr, request, protocol, false, &stat)
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddrStr, request, protocol, false, &stat)
p.sendRefused(scopedLog, w, request)
return
}
Expand All @@ -880,12 +884,12 @@ func (p *DNSProxy) ServeDNS(w dns.ResponseWriter, request *dns.Msg) {
}}
client.Dialer = &dialer

conn, err := client.Dial(targetServerAddr)
conn, err := client.Dial(targetServerAddrStr)
if err != nil {
err := fmt.Errorf("failed to dial connection to %v: %w", targetServerAddr, err)
err := fmt.Errorf("failed to dial connection to %v: %w", targetServerAddrStr, err)
stat.Err = err
scopedLog.WithError(err).Error("Failed to dial connection to the upstream DNS server, cannot service DNS request")
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddr, request, protocol, false, &stat)
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddrStr, request, protocol, false, &stat)
p.sendRefused(scopedLog, w, request)
return
}
Expand All @@ -898,12 +902,12 @@ func (p *DNSProxy) ServeDNS(w dns.ResponseWriter, request *dns.Msg) {
stat.Err = err
if stat.IsTimeout() {
scopedLog.WithError(err).Warn("Timeout waiting for response to forwarded proxied DNS lookup")
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddr, request, protocol, false, &stat)
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddrStr, request, protocol, false, &stat)
return
}
scopedLog.WithError(err).Error("Cannot forward proxied DNS lookup")
stat.Err = fmt.Errorf("cannot forward proxied DNS lookup: %w", err)
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddr, request, protocol, false, &stat)
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddrStr, request, protocol, false, &stat)
p.sendRefused(scopedLog, w, request)
return
}
Expand All @@ -912,7 +916,7 @@ func (p *DNSProxy) ServeDNS(w dns.ResponseWriter, request *dns.Msg) {
stat.Success = true

scopedLog.Debug("Notifying with DNS response to original DNS query")
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddr, response, protocol, true, &stat)
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddrStr, response, protocol, true, &stat)

scopedLog.Debug("Responding to original DNS query")
// restore the ID to the one in the initial request so it matches what the requester expects.
Expand All @@ -922,7 +926,7 @@ func (p *DNSProxy) ServeDNS(w dns.ResponseWriter, request *dns.Msg) {
if err != nil {
scopedLog.WithError(err).Error("Cannot forward proxied DNS response")
stat.Err = fmt.Errorf("Cannot forward proxied DNS response: %w", err)
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddr, response, protocol, true, &stat)
p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddrStr, response, protocol, true, &stat)
} else {
p.Lock()
// Add the server to the set of used DNS servers. This set is never GCd, but is limited by set
Expand Down
2 changes: 1 addition & 1 deletion pkg/fqdn/dnsproxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ func (s *DNSProxyTestSuite) SetUpTest(c *C) {
return endpoint.NewEndpointWithState(s, s, testipcache.NewMockIPCache(), &endpoint.FakeEndpointProxy{}, testidentity.NewMockIdentityAllocator(nil), uint16(epID1), endpoint.StateReady), nil
},
// LookupSecIDByIP
func(ip net.IP) (ipcache.Identity, bool) {
func(ip netip.Addr) (ipcache.Identity, bool) {
DNSServerListenerAddr := (s.dnsServer.Listener.Addr()).(*net.TCPAddr)
switch {
case ip.String() == DNSServerListenerAddr.IP.String():
Expand Down
4 changes: 2 additions & 2 deletions pkg/hubble/parser/common/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
package common

import (
"net"
"net/netip"

"github.com/sirupsen/logrus"

Expand Down Expand Up @@ -36,7 +36,7 @@ func NewEndpointResolver(
}
}

func (r *EndpointResolver) ResolveEndpoint(ip net.IP, datapathSecurityIdentity uint32) *pb.Endpoint {
func (r *EndpointResolver) ResolveEndpoint(ip netip.Addr, datapathSecurityIdentity uint32) *pb.Endpoint {
// The datapathSecurityIdentity parameter is the numeric security identity
// obtained from the datapath.
// The numeric identity from the datapath can differ from the one we obtain
Expand Down