Skip to content

Commit

Permalink
Move Address/pod creation to podWatcher, consider all Server updates
Browse files Browse the repository at this point in the history
Moved the Address/pod object creation logic into the podWatcher, on the `getOrNewPodPublisher()` method, called from `Subscribe()`. The latter receives `service`, `hostname` and `ip`, which can be empty depending on the case.

With this change also the auxiliary functions `getIndexedPod()` and `podReceivingTraffic()` were moved into the `watcher` package.

Also the Server updates handler now triggers updates for all podPublishers regardless of the Server selectors. The endpointProfileTranslator is now tracking the last message sent, to avoid sending dupe messages.

Other changes:
- Added fields `defaultOpaquePorts`, `k8sAPI` and `metadataAPI` into the podPublisher struct.
- Removed the fields `ip` and `port` from the endpointProfileTranslator, which are not used.
  • Loading branch information
alpeb committed Sep 19, 2023
1 parent 67a5da1 commit 1093f78
Show file tree
Hide file tree
Showing 6 changed files with 342 additions and 305 deletions.
29 changes: 15 additions & 14 deletions controller/api/destination/endpoint_profile_translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,20 @@ import (
pb "github.com/linkerd/linkerd2-proxy-api/go/destination"
"github.com/linkerd/linkerd2/controller/api/destination/watcher"
"github.com/linkerd/linkerd2/controller/k8s"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
)

type endpointProfileTranslator struct {
enableH2Upgrade bool
controllerNS string
identityTrustDomain string
defaultOpaquePorts map[uint32]struct{}
ip string
port uint32
stream pb.Destination_GetProfileServer
lastMessage string

k8sAPI *k8s.API
metadataAPI *k8s.MetadataAPI
log *logrus.Entry
log *log.Entry
}

// newEndpointProfileTranslator translates pod updates and protocol updates to
Expand All @@ -30,32 +29,29 @@ func newEndpointProfileTranslator(
controllerNS,
identityTrustDomain string,
defaultOpaquePorts map[uint32]struct{},
ip string,
port uint32,
stream pb.Destination_GetProfileServer,
k8sAPI *k8s.API,
metadataAPI *k8s.MetadataAPI,
log *logrus.Entry,
) *endpointProfileTranslator {
return &endpointProfileTranslator{
enableH2Upgrade: enableH2Upgrade,
controllerNS: controllerNS,
identityTrustDomain: identityTrustDomain,
defaultOpaquePorts: defaultOpaquePorts,
ip: ip,
port: port,
stream: stream,
k8sAPI: k8sAPI,
metadataAPI: metadataAPI,
log: log.WithField("component", "endpoint-profile-translator"),
}
}

func (ept *endpointProfileTranslator) Update(address *watcher.Address) error {
// Update sends a DestinationProfile message into the stream, if the same
// message hasn't been sent already. If it has, false is returned.
func (ept *endpointProfileTranslator) Update(address *watcher.Address) (bool, error) {
opaquePorts := watcher.GetAnnotatedOpaquePorts(address.Pod, ept.defaultOpaquePorts)
endpoint, err := ept.createEndpoint(*address, opaquePorts)
if err != nil {
return fmt.Errorf("failed to create endpoint: %w", err)
return false, fmt.Errorf("failed to create endpoint: %w", err)
}

// The protocol for an endpoint should only be updated if there is a pod,
Expand All @@ -68,7 +64,7 @@ func (ept *endpointProfileTranslator) Update(address *watcher.Address) error {
} else if endpoint.ProtocolHint.OpaqueTransport == nil {
port, err := getInboundPort(&address.Pod.Spec)
if err != nil {
return err
return false, err
}

endpoint.ProtocolHint.OpaqueTransport = &pb.ProtocolHint_OpaqueTransport{
Expand All @@ -82,12 +78,17 @@ func (ept *endpointProfileTranslator) Update(address *watcher.Address) error {
Endpoint: endpoint,
OpaqueProtocol: address.OpaqueProtocol,
}
msg := profile.String()
if msg == ept.lastMessage {
return false, nil
}
ept.lastMessage = msg
ept.log.Debugf("sending protocol update: %+v", profile)
if err := ept.stream.Send(profile); err != nil {
return fmt.Errorf("failed to send protocol update: %w", err)
return false, fmt.Errorf("failed to send protocol update: %w", err)
}

return nil
return true, nil
}

func (ept *endpointProfileTranslator) createEndpoint(address watcher.Address, opaquePorts map[uint32]struct{}) (*pb.WeightedAddr, error) {
Expand Down
153 changes: 9 additions & 144 deletions controller/api/destination/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,27 +275,7 @@ func (s *server) getProfileByIP(
}

if svcID == nil {
// If the IP does not map to a service, check if it maps to a pod
var pod *corev1.Pod
targetIP := ip.String()
pod, err = getPodByPodIP(s.k8sAPI, targetIP, port, s.log)
if err != nil {
return err
}
if pod != nil {
targetIP = pod.Status.PodIP
} else {
pod, err = getPodByHostIP(s.k8sAPI, targetIP, port, s.log)
if err != nil {
return err
}
}

address, err := watcher.CreateAddress(s.k8sAPI, s.metadataAPI, s.defaultOpaquePorts, pod, targetIP, port)
if err != nil {
return fmt.Errorf("failed to create address: %w", err)
}
return s.subscribeToEndpointProfile(&address, port, stream)
return s.subscribeToEndpointProfile(nil, "", ip.String(), port, stream)
}

fqn := fmt.Sprintf("%s.%s.svc.%s", svcID.Name, svcID.Namespace, s.clusterDomain)
Expand All @@ -317,11 +297,7 @@ func (s *server) getProfileByName(
// name. When we fetch the profile using a pod's DNS name, we want to
// return an endpoint in the profile response.
if hostname != "" {
address, err := s.getEndpointByHostname(s.k8sAPI, hostname, service, port)
if err != nil {
return fmt.Errorf("failed to get pod for hostname %s: %w", hostname, err)
}
return s.subscribeToEndpointProfile(address, port, stream)
return s.subscribeToEndpointProfile(&service, hostname, "", port, stream)
}

return s.subscribeToServiceProfile(service, token, host, port, stream)
Expand Down Expand Up @@ -466,7 +442,9 @@ func (s *server) subscribeToServiceWithoutContext(
//
// This function does not return until the stream is closed.
func (s *server) subscribeToEndpointProfile(
address *watcher.Address,
service *watcher.ServiceID,
hostname,
ip string,
port uint32,
stream pb.Destination_GetProfileServer,
) error {
Expand All @@ -475,20 +453,17 @@ func (s *server) subscribeToEndpointProfile(
s.controllerNS,
s.identityTrustDomain,
s.defaultOpaquePorts,
address.IP,
port,
stream,
s.k8sAPI,
s.metadataAPI,
s.log,
)

if err := translator.Update(address); err != nil {
var err error
ip, err = s.pods.Subscribe(service, hostname, ip, port, translator)
if err != nil {
return err
}

s.pods.Subscribe(address.Pod, address.IP, port, translator)
defer s.pods.Unsubscribe(address.IP, port, translator)
defer s.pods.Unsubscribe(ip, port, translator)

select {
case <-s.shutdown:
Expand Down Expand Up @@ -528,116 +503,6 @@ func getSvcID(k8sAPI *k8s.API, clusterIP string, log *logging.Entry) (*watcher.S
return service, nil
}

// getEndpointByHostname returns a pod that maps to the given hostname (or an
// instanceID). The hostname is generally the prefix of the pod's DNS name;
// since it may be arbitrary we need to look at the corresponding service's
// Endpoints object to see whether the hostname matches a pod.
func (s *server) getEndpointByHostname(k8sAPI *k8s.API, hostname string, svcID watcher.ServiceID, port uint32) (*watcher.Address, error) {
ep, err := k8sAPI.Endpoint().Lister().Endpoints(svcID.Namespace).Get(svcID.Name)
if err != nil {
return nil, err
}

for _, subset := range ep.Subsets {
for _, addr := range subset.Addresses {

if hostname == addr.Hostname {
if addr.TargetRef != nil && addr.TargetRef.Kind == "Pod" {
podName := addr.TargetRef.Name
podNamespace := addr.TargetRef.Namespace
pod, err := k8sAPI.Pod().Lister().Pods(podNamespace).Get(podName)
if err != nil {
return nil, err
}
address, err := watcher.CreateAddress(s.k8sAPI, s.metadataAPI, s.defaultOpaquePorts, pod, addr.IP, port)
if err != nil {
return nil, err
}
return &address, nil
}
return &watcher.Address{
IP: addr.IP,
Port: port,
}, nil

}
}
}

return nil, fmt.Errorf("no pod found in Endpoints %s/%s for hostname %s", svcID.Namespace, svcID.Name, hostname)
}

// getPodByHostIP returns a pod that maps to the given IP address in the host
// network. It must have a container port that exposes `port` as a host port.
func getPodByHostIP(k8sAPI *k8s.API, hostIP string, port uint32, log *logging.Entry) (*corev1.Pod, error) {
addr := net.JoinHostPort(hostIP, fmt.Sprintf("%d", port))
hostIPPods, err := getIndexedPods(k8sAPI, watcher.HostIPIndex, addr)
if err != nil {
return nil, status.Error(codes.Unknown, err.Error())
}
if len(hostIPPods) == 1 {
log.Debugf("found %s:%d on the host network", hostIP, port)
return hostIPPods[0], nil
}
if len(hostIPPods) > 1 {
conflictingPods := []string{}
for _, pod := range hostIPPods {
conflictingPods = append(conflictingPods, fmt.Sprintf("%s:%s", pod.Namespace, pod.Name))
}
log.Warnf("found conflicting %s:%d endpoint on the host network: %s", hostIP, port, strings.Join(conflictingPods, ","))
return nil, status.Errorf(codes.FailedPrecondition, "found %d pods with a conflicting host network endpoint %s:%d", len(hostIPPods), hostIP, port)
}

return nil, nil
}

// getPodByPodIP returns a pod that maps to the given IP address in the pod network
func getPodByPodIP(k8sAPI *k8s.API, podIP string, port uint32, log *logging.Entry) (*corev1.Pod, error) {
podIPPods, err := getIndexedPods(k8sAPI, watcher.PodIPIndex, podIP)
if err != nil {
return nil, status.Error(codes.Unknown, err.Error())
}
if len(podIPPods) == 1 {
log.Debugf("found %s on the pod network", podIP)
return podIPPods[0], nil
}
if len(podIPPods) > 1 {
conflictingPods := []string{}
for _, pod := range podIPPods {
conflictingPods = append(conflictingPods, fmt.Sprintf("%s:%s", pod.Namespace, pod.Name))
}
log.Warnf("found conflicting %s IP on the pod network: %s", podIP, strings.Join(conflictingPods, ","))
return nil, status.Errorf(codes.FailedPrecondition, "found %d pods with a conflicting pod network IP %s", len(podIPPods), podIP)
}

log.Debugf("no pod found for %s:%d", podIP, port)
return nil, nil
}

func getIndexedPods(k8sAPI *k8s.API, indexName string, podIP string) ([]*corev1.Pod, error) {
objs, err := k8sAPI.Pod().Informer().GetIndexer().ByIndex(indexName, podIP)
if err != nil {
return nil, fmt.Errorf("failed getting %s indexed pods: %w", indexName, err)
}
pods := make([]*corev1.Pod, 0)
for _, obj := range objs {
pod := obj.(*corev1.Pod)
if !podReceivingTraffic(pod) {
continue
}
pods = append(pods, pod)
}
return pods, nil
}

func podReceivingTraffic(pod *corev1.Pod) bool {
phase := pod.Status.Phase
podTerminated := phase == corev1.PodSucceeded || phase == corev1.PodFailed
podTerminating := pod.DeletionTimestamp != nil

return !podTerminating && !podTerminated
}

////////////
/// util ///
////////////
Expand Down
96 changes: 0 additions & 96 deletions controller/api/destination/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package destination
import (
"context"
"fmt"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -690,101 +689,6 @@ spec:
})
}

func TestIpWatcherGetPod(t *testing.T) {
podIP := "10.255.0.1"
hostIP := "172.0.0.1"
var hostPort1 uint32 = 22345
var hostPort2 uint32 = 22346
expectedPodName := "hostPortPod1"
k8sConfigs := []string{`
apiVersion: v1
kind: Pod
metadata:
name: hostPortPod1
namespace: ns
spec:
containers:
- image: test
name: hostPortContainer1
ports:
- containerPort: 12345
hostIP: 172.0.0.1
hostPort: 22345
- image: test
name: hostPortContainer2
ports:
- containerPort: 12346
hostIP: 172.0.0.1
hostPort: 22346
status:
phase: Running
podIP: 10.255.0.1
hostIP: 172.0.0.1`,
`
apiVersion: v1
kind: Pod
metadata:
name: pod
namespace: ns
status:
phase: Running
podIP: 10.255.0.1`,
}
t.Run("get pod by host IP and host port", func(t *testing.T) {
k8sAPI, err := k8s.NewFakeAPI(k8sConfigs...)
if err != nil {
t.Fatalf("failed to create new fake API: %s", err)
}

err = watcher.InitializeIndexers(k8sAPI)
if err != nil {
t.Fatalf("initializeIndexers returned an error: %s", err)
}

k8sAPI.Sync(nil)
// Get host IP pod that is mapped to the port `hostPort1`
pod, err := getPodByHostIP(k8sAPI, hostIP, hostPort1, logging.WithFields(nil))
if err != nil {
t.Fatalf("failed to get pod: %s", err)
}
if pod == nil {
t.Fatalf("failed to find pod mapped to %s:%d", hostIP, hostPort1)
}
if pod.Name != expectedPodName {
t.Fatalf("expected pod name to be %s, but got %s", expectedPodName, pod.Name)
}
// Get host IP pod that is mapped to the port `hostPort2`; this tests
// that the indexer properly adds multiple containers from a single
// pod.
pod, err = getPodByHostIP(k8sAPI, hostIP, hostPort2, logging.WithFields(nil))
if err != nil {
t.Fatalf("failed to get pod: %s", err)
}
if pod == nil {
t.Fatalf("failed to find pod mapped to %s:%d", hostIP, hostPort2)
}
if pod.Name != expectedPodName {
t.Fatalf("expected pod name to be %s, but got %s", expectedPodName, pod.Name)
}
// Get host IP pod with unmapped host port
pod, err = getPodByHostIP(k8sAPI, hostIP, 12347, logging.WithFields(nil))
if err != nil {
t.Fatalf("expected no error when getting host IP pod with unmapped host port, but got: %s", err)
}
if pod != nil {
t.Fatal("expected no pod to be found with unmapped host port")
}
// Get pod IP pod and expect an error
_, err = getPodByPodIP(k8sAPI, podIP, 12346, logging.WithFields(nil))
if err == nil {
t.Fatal("expected error when getting by pod IP and unmapped host port, but got none")
}
if !strings.Contains(err.Error(), "pods with a conflicting pod network IP") {
t.Fatalf("expected error to be pod IP address conflict, but got: %s", err)
}
})
}

func assertSingleProfile(t *testing.T, updates []*pb.DestinationProfile) *pb.DestinationProfile {
t.Helper()
// Under normal conditions the creation of resources by the fake API will
Expand Down
Loading

0 comments on commit 1093f78

Please sign in to comment.