Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/shim/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ func main() {
setupLog.Error(err, "unable to set up placement shim")
os.Exit(1)
}
metrics.Registry.MustRegister(placementShim)
placementShim.RegisterRoutes(mux)
}

Expand Down
Empty file removed internal/shim/placement/.gitkeep
Empty file.
112 changes: 99 additions & 13 deletions internal/shim/placement/shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,19 @@ package placement
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strconv"
"time"

"github.com/cobaltcore-dev/cortex/pkg/conf"
"github.com/cobaltcore-dev/cortex/pkg/multicluster"
"github.com/cobaltcore-dev/cortex/pkg/sso"
hv1 "github.com/cobaltcore-dev/openstack-hypervisor-operator/api/v1"
"github.com/prometheus/client_golang/prometheus"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/handler"
Expand All @@ -30,6 +33,13 @@ var (
setupLog = ctrl.Log.WithName("placement-shim")
)

// contextKey is a private, zero-sized type for the context keys of this
// package. Because the type is unexported, no other package can construct an
// equal key.
type contextKey struct{}

// routePatternKey is the context key under which the measurement middleware
// (installed in RegisterRoutes) stores the route pattern that the forward
// method later reads back.
var routePatternKey contextKey

// config holds configuration for the placement shim.
type config struct {
// SSO is an optional reference to a Kubernetes secret containing
Expand Down Expand Up @@ -58,6 +68,45 @@ type Shim struct {
// HTTP client that can talk to openstack placement, if needed, over
// ingress with single-sign-on.
httpClient *http.Client

// downstreamRequestTimer is a prometheus histogram to measure the duration
// (and count) of requests coming from the client that wants to talk to the
// placement API.
downstreamRequestTimer *prometheus.HistogramVec
// upstreamRequestTimer is a prometheus histogram to measure the duration
// (and count) of requests to the upstream placement API by route and method.
upstreamRequestTimer *prometheus.HistogramVec
}

// statusCapturingResponseWriter wraps http.ResponseWriter to capture the
// HTTP status code that is sent to the client, for use in metrics labels.
//
// net/http commits the response status on the first final (non-1xx)
// WriteHeader call, or implicitly as 200 on the first Write. Later
// WriteHeader calls do not change what the client receives, so they must not
// change the recorded status either.
type statusCapturingResponseWriter struct {
	http.ResponseWriter
	// statusCode is the status reported in metrics. Callers may pre-seed it
	// (e.g. with http.StatusOK) for handlers that never write anything.
	statusCode int
	// wroteHeader is true once the final status has been committed, either
	// explicitly through WriteHeader or implicitly through Write.
	wroteHeader bool
}

// WriteHeader records code as the response status and forwards the call to
// the wrapped writer. Only the first final status is recorded. Informational
// 1xx headers (other than 101 Switching Protocols) may precede the final
// status and are forwarded without being recorded.
func (w *statusCapturingResponseWriter) WriteHeader(code int) {
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !w.wroteHeader && !informational {
		w.statusCode = code
		w.wroteHeader = true
	}
	// Always forward, so the wrapped writer keeps its own handling (and
	// diagnostics) for superfluous calls.
	w.ResponseWriter.WriteHeader(code)
}

// Write forwards b to the wrapped writer. A Write without a preceding
// WriteHeader commits the response implicitly; in that case the recorded
// status defaults to http.StatusOK unless a status was pre-seeded.
func (w *statusCapturingResponseWriter) Write(b []byte) (int, error) {
	if !w.wroteHeader {
		w.wroteHeader = true
		if w.statusCode == 0 {
			w.statusCode = http.StatusOK
		}
	}
	return w.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped http.ResponseWriter so that
// http.ResponseController can reach the underlying writer through this
// wrapper.
func (w *statusCapturingResponseWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}

// Describe implements prometheus.Collector. It sends the descriptors of the
// downstream timer, then those of the upstream timer, to ch.
func (s *Shim) Describe(ch chan<- *prometheus.Desc) {
	for _, timer := range []*prometheus.HistogramVec{
		s.downstreamRequestTimer,
		s.upstreamRequestTimer,
	} {
		timer.Describe(ch)
	}
}

// Collect implements prometheus.Collector. It sends the current metrics of
// the downstream timer, then those of the upstream timer, to ch.
func (s *Shim) Collect(ch chan<- prometheus.Metric) {
	for _, timer := range []*prometheus.HistogramVec{
		s.downstreamRequestTimer,
		s.upstreamRequestTimer,
	} {
		timer.Collect(ch)
	}
}

// Start is called after the manager has started and the cache is running.
Expand Down Expand Up @@ -138,17 +187,17 @@ func (s *Shim) predicateRemoteHypervisor() predicate.Predicate {
})
}

// SetupWithManager registers field indexes on the manager's cache so that
// subsequent list calls are served from the informer cache rather than
// hitting the API server. This must be called before the manager is started.
//
// Calling IndexField internally invokes GetInformer, which creates and
// registers a shared informer for the indexed type (hv1.Hypervisor) with the
// cache. The informer is started later when mgr.Start() is called. This
// means no separate controller or empty Reconcile loop is needed — the
// index registration alone is sufficient to warm the cache.
// SetupWithManager sets up the controller with the manager.
// It registers watches for the Hypervisor CRD across all clusters and sets up
// the HTTP client for talking to the placement API.
func (s *Shim) SetupWithManager(ctx context.Context, mgr ctrl.Manager) (err error) {
setupLog.Info("Setting up placement shim with manager")

// Bind the Start method to the manager.
if err := mgr.Add(s); err != nil {
return err
}

s.config, err = conf.GetConfig[config]()
if err != nil {
setupLog.Error(err, "Failed to load placement shim config")
Expand All @@ -158,6 +207,19 @@ func (s *Shim) SetupWithManager(ctx context.Context, mgr ctrl.Manager) (err erro
if err := s.config.validate(); err != nil {
return err
}

// Initialize Prometheus histogram timers for request monitoring.
s.downstreamRequestTimer = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "cortex_placement_shim_downstream_request_duration_seconds",
Help: "Duration of downstream requests to the placement shim from clients.",
Buckets: prometheus.DefBuckets,
}, []string{"method", "pattern", "responsecode"})
s.upstreamRequestTimer = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "cortex_placement_shim_upstream_request_duration_seconds",
Help: "Duration of upstream requests from the placement shim to the placement API.",
Buckets: prometheus.DefBuckets,
}, []string{"method", "pattern", "responsecode"})

// Check that the provided client is a multicluster client, since we need
// that to watch for hypervisors across clusters.
mcl, ok := s.Client.(*multicluster.Client)
Expand All @@ -178,8 +240,11 @@ func (s *Shim) SetupWithManager(ctx context.Context, mgr ctrl.Manager) (err erro

// forward proxies the incoming HTTP request to the upstream placement API
// and copies the response (status, headers, body) back to the client.
// The route pattern for metric labels is read from the request context
// (set by the measurement middleware in RegisterRoutes).
func (s *Shim) forward(w http.ResponseWriter, r *http.Request) {
log := logf.FromContext(r.Context())
ctx := r.Context()
log := logf.FromContext(ctx)

if s.httpClient == nil {
log.Info("placement shim not yet initialized, rejecting request")
Expand All @@ -204,7 +269,7 @@ func (s *Shim) forward(w http.ResponseWriter, r *http.Request) {
upstream.RawQuery = r.URL.RawQuery

// Create upstream request preserving method, body, and context.
upstreamReq, err := http.NewRequestWithContext(r.Context(), r.Method, upstream.String(), r.Body)
upstreamReq, err := http.NewRequestWithContext(ctx, r.Method, upstream.String(), r.Body)
if err != nil {
log.Error(err, "failed to create upstream request", "url", upstream.String())
http.Error(w, "failed to create upstream request", http.StatusBadGateway)
Expand All @@ -214,9 +279,14 @@ func (s *Shim) forward(w http.ResponseWriter, r *http.Request) {
// Copy all incoming headers.
upstreamReq.Header = r.Header.Clone()

resp, err := s.httpClient.Do(upstreamReq) //nolint:gosec // G704: intentional reverse proxy; host is fixed by operator config, only path varies
pattern, _ := ctx.Value(routePatternKey).(string)
start := time.Now()
resp, err := s.httpClient.Do(upstreamReq) //nolint:gosec // G704: intentional reverse proxy
if err != nil {
log.Error(err, "failed to reach placement API", "url", upstream.String())
s.upstreamRequestTimer.
WithLabelValues(r.Method, pattern, strconv.Itoa(http.StatusBadGateway)).
Observe(time.Since(start).Seconds())
http.Error(w, "failed to reach placement API", http.StatusBadGateway)
return
}
Expand All @@ -232,6 +302,11 @@ func (s *Shim) forward(w http.ResponseWriter, r *http.Request) {
if _, err := io.Copy(w, resp.Body); err != nil {
log.Error(err, "failed to copy upstream response body")
}
// Observe after the body is fully consumed so the duration includes
// the time spent streaming the response from upstream.
s.upstreamRequestTimer.
WithLabelValues(r.Method, pattern, strconv.Itoa(resp.StatusCode)).
Observe(time.Since(start).Seconds())
}

// RegisterRoutes binds all Placement API handlers to the given mux. The
Expand Down Expand Up @@ -283,7 +358,18 @@ func (s *Shim) RegisterRoutes(mux *http.ServeMux) {
}
for _, h := range handlers {
setupLog.Info("Registering route", "method", h.method, "pattern", h.pattern)
mux.HandleFunc(h.method+" "+h.pattern, h.handler)
routePattern := fmt.Sprintf("%s %s", h.method, h.pattern)
handlerPattern := h.pattern
next := h.handler
mux.HandleFunc(routePattern, func(w http.ResponseWriter, r *http.Request) {
r = r.WithContext(context.WithValue(r.Context(), routePatternKey, handlerPattern))
sw := &statusCapturingResponseWriter{ResponseWriter: w, statusCode: http.StatusOK}
start := time.Now()
next.ServeHTTP(sw, r)
s.downstreamRequestTimer.
WithLabelValues(r.Method, handlerPattern, strconv.Itoa(sw.statusCode)).
Observe(time.Since(start).Seconds())
Comment thread
PhilippMatthes marked this conversation as resolved.
})
}
setupLog.Info("Successfully registered placement API routes")
}
127 changes: 121 additions & 6 deletions internal/shim/placement/shim_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,47 @@
package placement

import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/prometheus/client_golang/prometheus"
dto "github.com/prometheus/client_model/go"
)

const validUUID = "d9b3a520-2a3c-4f6b-8b9a-1c2d3e4f5a6b"

// timerLabels are the histogram label names used by both request timers.
var timerLabels = []string{"method", "pattern", "responsecode"}

// histSampleCount returns the number of observations recorded by the histogram
// in h that is identified by the label values lvs. It returns 0 when nothing
// has been observed for that label combination.
//
// The test is failed immediately (t.Fatalf) when the series cannot be looked
// up, when the returned observer cannot be read as a prometheus.Metric, or
// when the metric cannot be serialized.
func histSampleCount(t *testing.T, h *prometheus.HistogramVec, lvs ...string) uint64 {
	t.Helper()
	obs, err := h.GetMetricWithLabelValues(lvs...)
	if err != nil {
		t.Fatalf("failed to get metric with labels %v: %v", lvs, err)
	}
	// Use a checked assertion so an unexpected observer type is reported as a
	// test failure with context instead of a panic.
	metric, ok := obs.(prometheus.Metric)
	if !ok {
		t.Fatalf("observer for labels %v has type %T, which does not implement prometheus.Metric", lvs, obs)
	}
	m := &dto.Metric{}
	if err := metric.Write(m); err != nil {
		t.Fatalf("failed to write metric: %v", err)
	}
	return m.GetHistogram().GetSampleCount()
}

// newTestTimers builds a fresh, unregistered pair of histogram vecs for
// tests: one for downstream and one for upstream request timings. Both use
// the default buckets and the shared timerLabels.
func newTestTimers() (downstream, upstream *prometheus.HistogramVec) {
	newTimer := func(name string) *prometheus.HistogramVec {
		opts := prometheus.HistogramOpts{
			Name:    name,
			Buckets: prometheus.DefBuckets,
		}
		return prometheus.NewHistogramVec(opts, timerLabels)
	}
	downstream = newTimer("test_downstream")
	upstream = newTimer("test_upstream")
	return downstream, upstream
}

// newTestShim creates a Shim backed by an upstream test server that returns
// the given status and body for every request. It records the last request
// path in *gotPath when non-nil.
Expand All @@ -28,9 +60,12 @@ func newTestShim(t *testing.T, status int, body string, gotPath *string) *Shim {
}
}))
t.Cleanup(upstream.Close)
down, up := newTestTimers()
return &Shim{
config: config{PlacementURL: upstream.URL},
httpClient: upstream.Client(),
config: config{PlacementURL: upstream.URL},
httpClient: upstream.Client(),
downstreamRequestTimer: down,
upstreamRequestTimer: up,
}
}

Expand Down Expand Up @@ -127,6 +162,7 @@ func TestForward(t *testing.T) {
config: config{PlacementURL: upstream.URL},
httpClient: upstream.Client(),
}
s.downstreamRequestTimer, s.upstreamRequestTimer = newTestTimers()
target := tt.path
if tt.query != "" {
target += "?" + tt.query
Expand Down Expand Up @@ -158,9 +194,12 @@ func TestForward(t *testing.T) {
}

func TestForwardUpstreamUnreachable(t *testing.T) {
down, up := newTestTimers()
s := &Shim{
config: config{PlacementURL: "http://127.0.0.1:1"},
httpClient: &http.Client{},
config: config{PlacementURL: "http://127.0.0.1:1"},
httpClient: &http.Client{},
downstreamRequestTimer: down,
upstreamRequestTimer: up,
}
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
w := httptest.NewRecorder()
Expand All @@ -175,9 +214,12 @@ func TestRegisterRoutes(t *testing.T) {
w.WriteHeader(http.StatusOK)
}))
defer upstream.Close()
down, up := newTestTimers()
s := &Shim{
config: config{PlacementURL: upstream.URL},
httpClient: upstream.Client(),
config: config{PlacementURL: upstream.URL},
httpClient: upstream.Client(),
downstreamRequestTimer: down,
upstreamRequestTimer: up,
}
mux := http.NewServeMux()
s.RegisterRoutes(mux)
Expand Down Expand Up @@ -207,3 +249,76 @@ func TestRegisterRoutes(t *testing.T) {
})
}
}

// TestRegisterRoutesDownstreamMetrics checks that a request served through
// the routes registered by RegisterRoutes is recorded once in the downstream
// timer under the expected (method, pattern, responsecode) labels.
func TestRegisterRoutesDownstreamMetrics(t *testing.T) {
	// Upstream stand-in that answers every request with 200.
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer backend.Close()

	downTimer, upTimer := newTestTimers()
	shim := &Shim{
		config:                 config{PlacementURL: backend.URL},
		httpClient:             backend.Client(),
		downstreamRequestTimer: downTimer,
		upstreamRequestTimer:   upTimer,
	}
	mux := http.NewServeMux()
	shim.RegisterRoutes(mux)

	// Send one request through the mux so the route wrapper records it.
	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/resource_providers", http.NoBody))

	if got, want := rec.Code, http.StatusOK; got != want {
		t.Fatalf("status = %d, want %d", got, want)
	}
	// Exactly one observation is expected for this label combination.
	count := histSampleCount(t, downTimer, "GET", "/resource_providers", "200")
	if count != 1 {
		t.Errorf("downstream observation count = %d, want 1", count)
	}
}

func TestForwardUpstreamMetrics(t *testing.T) {
t.Run("success records upstream status", func(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer upstream.Close()
down, up := newTestTimers()
s := &Shim{
config: config{PlacementURL: upstream.URL},
httpClient: upstream.Client(),
downstreamRequestTimer: down,
upstreamRequestTimer: up,
}
// Set the route pattern via context, as the RegisterRoutes wrapper would.
req := httptest.NewRequest(http.MethodGet, "/traits", http.NoBody)
req = req.WithContext(context.WithValue(req.Context(), routePatternKey, "/traits"))
w := httptest.NewRecorder()
s.forward(w, req)

if n := histSampleCount(t, up, "GET", "/traits", "404"); n != 1 {
t.Errorf("upstream observation count = %d, want 1", n)
}
})

t.Run("unreachable upstream records 502", func(t *testing.T) {
down, up := newTestTimers()
s := &Shim{
config: config{PlacementURL: "http://127.0.0.1:1"},
httpClient: &http.Client{},
downstreamRequestTimer: down,
upstreamRequestTimer: up,
}
req := httptest.NewRequest(http.MethodGet, "/usages", http.NoBody)
req = req.WithContext(context.WithValue(req.Context(), routePatternKey, "/usages"))
w := httptest.NewRecorder()
s.forward(w, req)

if n := histSampleCount(t, up, "GET", "/usages", "502"); n != 1 {
t.Errorf("upstream observation count = %d, want 1", n)
}
})
}
Loading