diff --git a/pkg/builder/webhook.go b/pkg/builder/webhook.go index 6263f030a0..6f4726d274 100644 --- a/pkg/builder/webhook.go +++ b/pkg/builder/webhook.go @@ -17,6 +17,7 @@ limitations under the License. package builder import ( + "context" "errors" "net/http" "net/url" @@ -49,6 +50,7 @@ type WebhookBuilder struct { config *rest.Config recoverPanic *bool logConstructor func(base logr.Logger, req *admission.Request) logr.Logger + contextFunc func(context.Context, *http.Request) context.Context err error } @@ -90,6 +92,12 @@ func (blder *WebhookBuilder) WithLogConstructor(logConstructor func(base logr.Lo return blder } +// WithContextFunc overrides the webhook's WithContextFunc. +func (blder *WebhookBuilder) WithContextFunc(contextFunc func(context.Context, *http.Request) context.Context) *WebhookBuilder { + blder.contextFunc = contextFunc + return blder +} + // RecoverPanic indicates whether panics caused by the webhook should be recovered. // Defaults to true. func (blder *WebhookBuilder) RecoverPanic(recoverPanic bool) *WebhookBuilder { @@ -205,6 +213,7 @@ func (blder *WebhookBuilder) registerDefaultingWebhook() error { mwh := blder.getDefaultingWebhook() if mwh != nil { mwh.LogConstructor = blder.logConstructor + mwh.WithContextFunc = blder.contextFunc path := generateMutatePath(blder.gvk) if blder.customDefaulterCustomPath != "" { generatedCustomPath, err := generateCustomPath(blder.customDefaulterCustomPath) @@ -243,6 +252,7 @@ func (blder *WebhookBuilder) registerValidatingWebhook() error { vwh := blder.getValidatingWebhook() if vwh != nil { vwh.LogConstructor = blder.logConstructor + vwh.WithContextFunc = blder.contextFunc path := generateValidatePath(blder.gvk) if blder.customValidatorCustomPath != "" { generatedCustomPath, err := generateCustomPath(blder.customValidatorCustomPath) diff --git a/pkg/builder/webhook_test.go b/pkg/builder/webhook_test.go index eb70af2e0a..72538ef7bf 100644 --- a/pkg/builder/webhook_test.go +++ b/pkg/builder/webhook_test.go @@ -49,8 +49,14 @@ const ( svcBaseAddr = "http://svc-name.svc-ns.svc" customPath = "/custom-path" + + userAgentHeader = "User-Agent" + userAgentCtxKey agentCtxKey = "UserAgent" + userAgentValue = "test" ) +type agentCtxKey string + var _ = Describe("webhook", func() { Describe("New", func() { Context("v1 AdmissionReview", func() { @@ -315,6 +321,9 @@ func runTests(admissionReviewVersion string) { WithLogConstructor(func(base logr.Logger, req *admission.Request) logr.Logger { return admission.DefaultLogConstructor(testingLogger, req) }). + WithContextFunc(func(ctx context.Context, request *http.Request) context.Context { + return context.WithValue(ctx, userAgentCtxKey, request.Header.Get(userAgentHeader)) + }). Complete() ExpectWithOffset(1, err).NotTo(HaveOccurred()) svr := m.GetWebhookServer() @@ -344,6 +353,30 @@ func runTests(admissionReviewVersion string) { } } }`) + readerWithCxt := strings.NewReader(admissionReviewGV + admissionReviewVersion + `", + "request":{ + "uid":"07e52e8d-4513-11e9-a716-42010a800270", + "kind":{ + "group":"foo.test.org", + "version":"v1", + "kind":"TestValidator" + }, + "resource":{ + "group":"foo.test.org", + "version":"v1", + "resource":"testvalidator" + }, + "namespace":"default", + "name":"foo", + "operation":"UPDATE", + "object":{ + "replica":1 + }, + "oldObject":{ + "replica":1 + } + } +}`) ctx, cancel := context.WithCancel(specCtx) cancel() @@ -373,6 +406,20 @@ func runTests(admissionReviewVersion string) { ExpectWithOffset(1, w.Body).To(ContainSubstring(`"allowed":false`)) ExpectWithOffset(1, w.Body).To(ContainSubstring(`"code":403`)) EventuallyWithOffset(1, logBuffer).Should(gbytes.Say(`"msg":"Validating object","object":{"name":"foo","namespace":"default"},"namespace":"default","name":"foo","resource":{"group":"foo.test.org","version":"v1","resource":"testvalidator"},"user":"","requestID":"07e52e8d-4513-11e9-a716-42010a800270"`)) + + By("sending a request to a validating webhook with context header validation") + path = generateValidatePath(testValidatorGVK) + _, err = readerWithCxt.Seek(0, 0) + ExpectWithOffset(1, err).NotTo(HaveOccurred()) + req = httptest.NewRequest("POST", svcBaseAddr+path, readerWithCxt) + req.Header.Add("Content-Type", "application/json") + req.Header.Add(userAgentHeader, userAgentValue) + w = httptest.NewRecorder() + svr.WebhookMux().ServeHTTP(w, req) + ExpectWithOffset(1, w.Code).To(Equal(http.StatusOK)) + By("sanity checking the response contains reasonable field") + ExpectWithOffset(1, w.Body).To(ContainSubstring(`"allowed":true`)) + ExpectWithOffset(1, w.Body).To(ContainSubstring(`"code":200`)) }) It("should scaffold a custom validating webhook with a custom path", func(specCtx SpecContext) { @@ -1009,6 +1056,7 @@ func (*TestCustomDefaulter) Default(ctx context.Context, obj runtime.Object) err if d.Replica < 2 { d.Replica = 2 } + return nil } @@ -1035,6 +1083,7 @@ func (*TestCustomValidator) ValidateCreate(ctx context.Context, obj runtime.Obje if v.Replica < 0 { return nil, errors.New("number of replica should be greater than or equal to 0") } + return nil, nil } @@ -1056,6 +1105,12 @@ func (*TestCustomValidator) ValidateUpdate(ctx context.Context, oldObj, newObj r if v.Replica < old.Replica { return nil, fmt.Errorf("new replica %v should not be fewer than old replica %v", v.Replica, old.Replica) } + + userAgent, ok := ctx.Value(userAgentCtxKey).(string) + if ok && userAgent != userAgentValue { + return nil, fmt.Errorf("expected %s value is %q in TestCustomValidator got %q", userAgentCtxKey, userAgentValue, userAgent) + } + return nil, nil }