From 8b33a9ed4253c7a0423edc648140bb6b05007878 Mon Sep 17 00:00:00 2001 From: "Dr. Stefan Schimanski" Date: Tue, 20 Sep 2016 16:26:29 +0200 Subject: [PATCH 1/3] Move TimeoutHandler+MaxInFlightLimit to Config.New() --- pkg/genericapiserver/config.go | 33 +++++++++++++++++++----- pkg/genericapiserver/genericapiserver.go | 27 +++---------------- 2 files changed, 31 insertions(+), 29 deletions(-) diff --git a/pkg/genericapiserver/config.go b/pkg/genericapiserver/config.go index 3a2342dc3054..842fd57b5906 100644 --- a/pkg/genericapiserver/config.go +++ b/pkg/genericapiserver/config.go @@ -23,6 +23,7 @@ import ( "net" "net/http" "os" + "regexp" "strconv" "strings" "time" @@ -157,6 +158,11 @@ type Config struct { // OpenAPIDefinitions is a map of type to OpenAPI spec for all types used in this API server. Failure to provide // this map or any of the models used by the server APIs will result in spec generation failure. OpenAPIDefinitions *common.OpenAPIDefinitions + + // MaxRequestsInFlight is the maximum number of parallel non-long-running requests. Every further + // request has to wait. + MaxRequestsInFlight int + LongRunningRequestRE string } func NewConfig(options *options.ServerRunOptions) *Config { @@ -191,6 +197,8 @@ func NewConfig(options *options.ServerRunOptions) *Config { Version: "unversioned", }, }, + MaxRequestsInFlight: options.MaxRequestsInFlight, + LongRunningRequestRE: options.LongRunningRequestRE, } } @@ -386,21 +394,34 @@ func (c Config) New() (*GenericAPIServer, error) { handler = authenticatedHandler } - // TODO: Make this optional? Consumers of GenericAPIServer depend on this currently. - s.Handler = handler - - // After all wrapping is done, put a context filter around both handlers - var err error - handler, err = api.NewRequestContextFilter(c.RequestContextMapper, s.Handler) + handler, err := api.NewRequestContextFilter(c.RequestContextMapper, handler) if err != nil { glog.Fatalf("Could not initialize request context filter for s.Handler: %v", err) } + + longRunningRE := regexp.MustCompile(c.LongRunningRequestRE) + longRunningRequestCheck := apiserver.BasicLongRunningRequestCheck(longRunningRE, map[string]string{"watch": "true"}) + longRunningTimeout := func(req *http.Request) (<-chan time.Time, string) { + // TODO unify this with apiserver.MaxInFlightLimit + if longRunningRequestCheck(req) { + return nil, "" + } + return time.After(globalTimeout), "" + } + handler = apiserver.TimeoutHandler(apiserver.RecoverPanics(handler, s.NewRequestInfoResolver()), longRunningTimeout) + + var inFlightTokens chan bool + if c.MaxRequestsInFlight > 0 { + inFlightTokens = make(chan bool, c.MaxRequestsInFlight) + } + handler = apiserver.MaxInFlightLimit(inFlightTokens, longRunningRequestCheck, handler) s.Handler = handler handler, err = api.NewRequestContextFilter(c.RequestContextMapper, s.InsecureHandler) if err != nil { glog.Fatalf("Could not initialize request context filter for s.InsecureHandler: %v", err) } + handler = apiserver.TimeoutHandler(apiserver.RecoverPanics(handler, s.NewRequestInfoResolver()), longRunningTimeout) s.InsecureHandler = handler s.installGroupsDiscoveryHandler() diff --git a/pkg/genericapiserver/genericapiserver.go b/pkg/genericapiserver/genericapiserver.go index ffdd51003610..d0ebc11ac285 100644 --- a/pkg/genericapiserver/genericapiserver.go +++ b/pkg/genericapiserver/genericapiserver.go @@ -22,7 +22,6 @@ import ( "net" "net/http" "path" - "regexp" "sort" "strconv" "strings" @@ -248,34 +247,16 @@ func (s *GenericAPIServer) Run(options *options.ServerRunOptions) { if s.enableOpenAPISupport { s.InstallOpenAPI() } - // We serve on 2 ports. See docs/admin/accessing-the-api.md + secureLocation := "" if options.SecurePort != 0 { secureLocation = net.JoinHostPort(options.BindAddress.String(), strconv.Itoa(options.SecurePort)) } - insecureLocation := net.JoinHostPort(options.InsecureBindAddress.String(), strconv.Itoa(options.InsecurePort)) - - var sem chan bool - if options.MaxRequestsInFlight > 0 { - sem = make(chan bool, options.MaxRequestsInFlight) - } - - longRunningRE := regexp.MustCompile(options.LongRunningRequestRE) - longRunningRequestCheck := apiserver.BasicLongRunningRequestCheck(longRunningRE, map[string]string{"watch": "true"}) - longRunningTimeout := func(req *http.Request) (<-chan time.Time, string) { - // TODO unify this with apiserver.MaxInFlightLimit - if longRunningRequestCheck(req) { - return nil, "" - } - return time.After(globalTimeout), "" - } - secureStartedCh := make(chan struct{}) if secureLocation != "" { - handler := apiserver.TimeoutHandler(apiserver.RecoverPanics(s.Handler, s.NewRequestInfoResolver()), longRunningTimeout) secureServer := &http.Server{ Addr: secureLocation, - Handler: apiserver.MaxInFlightLimit(sem, longRunningRequestCheck, handler), + Handler: s.Handler, MaxHeaderBytes: 1 << 20, TLSConfig: &tls.Config{ // Can't use SSLv3 because of POODLE and BEAST @@ -342,10 +323,10 @@ func (s *GenericAPIServer) Run(options *options.ServerRunOptions) { close(secureStartedCh) } - handler := apiserver.TimeoutHandler(apiserver.RecoverPanics(s.InsecureHandler, s.NewRequestInfoResolver()), longRunningTimeout) + insecureLocation := net.JoinHostPort(options.InsecureBindAddress.String(), strconv.Itoa(options.InsecurePort)) http := &http.Server{ Addr: insecureLocation, - Handler: handler, + Handler: s.InsecureHandler, MaxHeaderBytes: 1 << 20, } From 3799ffa0a894507e7982a51cf3cc2d74c4816191 Mon Sep 17 00:00:00 2001 From: "Dr. Stefan Schimanski" Date: Tue, 20 Sep 2016 16:52:36 +0200 Subject: [PATCH 2/3] Simplify genericapiserver.Run() --- pkg/genericapiserver/genericapiserver.go | 25 +++++++++--------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/pkg/genericapiserver/genericapiserver.go b/pkg/genericapiserver/genericapiserver.go index d0ebc11ac285..4b61f5cc2ae2 100644 --- a/pkg/genericapiserver/genericapiserver.go +++ b/pkg/genericapiserver/genericapiserver.go @@ -248,12 +248,9 @@ func (s *GenericAPIServer) Run(options *options.ServerRunOptions) { s.InstallOpenAPI() } - secureLocation := "" - if options.SecurePort != 0 { - secureLocation = net.JoinHostPort(options.BindAddress.String(), strconv.Itoa(options.SecurePort)) - } secureStartedCh := make(chan struct{}) - if secureLocation != "" { + if options.SecurePort != 0 { + secureLocation := net.JoinHostPort(options.BindAddress.String(), strconv.Itoa(options.SecurePort)) secureServer := &http.Server{ Addr: secureLocation, Handler: s.Handler, @@ -301,10 +298,6 @@ func (s *GenericAPIServer) Run(options *options.ServerRunOptions) { notifyStarted := sync.Once{} for { - // err == systemd.SdNotifyNoSocket when not running on a systemd system - if err := systemd.SdNotify("READY=1\n"); err != nil && err != systemd.SdNotifyNoSocket { - glog.Errorf("Unable to send systemd daemon successful start message: %v\n", err) - } if err := secureServer.ListenAndServeTLS(options.TLSCertFile, options.TLSPrivateKeyFile); err != nil { glog.Errorf("Unable to listen for secure (%v); will try again.", err) } else { @@ -316,20 +309,15 @@ func (s *GenericAPIServer) Run(options *options.ServerRunOptions) { } }() } else { - // err == systemd.SdNotifyNoSocket when not running on a systemd system - if err := systemd.SdNotify("READY=1\n"); err != nil && err != systemd.SdNotifyNoSocket { - glog.Errorf("Unable to send systemd daemon successful start message: %v\n", err) - } close(secureStartedCh) } insecureLocation := net.JoinHostPort(options.InsecureBindAddress.String(), strconv.Itoa(options.InsecurePort)) - http := &http.Server{ + insecureServer := &http.Server{ Addr: insecureLocation, Handler: s.InsecureHandler, MaxHeaderBytes: 1 << 20, } - insecureStartedCh := make(chan struct{}) glog.Infof("Serving insecurely on %s", insecureLocation) go func() { @@ -337,7 +325,7 @@ func (s *GenericAPIServer) Run(options *options.ServerRunOptions) { notifyStarted := sync.Once{} for { - if err := http.ListenAndServe(); err != nil { + if err := insecureServer.ListenAndServe(); err != nil { glog.Errorf("Unable to listen for insecure (%v); will try again.", err) } else { notifyStarted.Do(func() { @@ -352,6 +340,11 @@ func (s *GenericAPIServer) Run(options *options.ServerRunOptions) { <-insecureStartedCh s.RunPostStartHooks(PostStartHookContext{}) + // err == systemd.SdNotifyNoSocket when not running on a systemd system + if err := systemd.SdNotify("READY=1\n"); err != nil && err != systemd.SdNotifyNoSocket { + glog.Errorf("Unable to send systemd daemon successful start message: %v\n", err) + } + select {} } From 87356c0623cd759df5879dc753c8a885bce4ef64 Mon Sep 17 00:00:00 2001 From: "Dr. Stefan Schimanski" Date: Wed, 21 Sep 2016 11:36:44 +0200 Subject: [PATCH 3/3] Cleanup handler chain --- hack/.linted_packages | 2 + pkg/api/requestcontext.go | 26 +- pkg/apiserver/apiserver.go | 14 +- pkg/apiserver/apiserver_test.go | 75 ---- pkg/apiserver/audit/audit.go | 6 +- pkg/apiserver/handler_impersonation_test.go | 2 +- pkg/apiserver/handlers.go | 324 +----------------- pkg/apiserver/handlers_test.go | 179 ---------- pkg/auth/handlers/handlers.go | 12 +- pkg/auth/handlers/handlers_test.go | 32 +- pkg/genericapiserver/config.go | 127 +++---- pkg/genericapiserver/filters/cors.go | 83 +++++ pkg/genericapiserver/filters/cors_test.go | 93 +++++ pkg/genericapiserver/filters/doc.go | 19 + pkg/genericapiserver/filters/longrunning.go | 46 +++ pkg/genericapiserver/filters/maxinflight.go | 60 ++++ .../filters/maxinflight_test.go | 143 ++++++++ pkg/genericapiserver/filters/timeout.go | 259 ++++++++++++++ pkg/genericapiserver/filters/timeout_test.go | 81 +++++ pkg/genericapiserver/genericapiserver.go | 69 ++-- pkg/genericapiserver/routes/index.go | 3 + pkg/genericapiserver/routes/profiling.go | 1 + pkg/genericapiserver/routes/swaggerui.go | 1 + pkg/genericapiserver/routes/version.go | 3 +- pkg/master/master.go | 2 +- 25 files changed, 934 insertions(+), 728 deletions(-) create mode 100644 pkg/genericapiserver/filters/cors.go create mode 100644 pkg/genericapiserver/filters/cors_test.go create mode 100644 pkg/genericapiserver/filters/doc.go create mode 100644 pkg/genericapiserver/filters/longrunning.go create mode 100644 pkg/genericapiserver/filters/maxinflight.go create mode 100644 pkg/genericapiserver/filters/maxinflight_test.go create mode 100644 pkg/genericapiserver/filters/timeout.go create mode 100644 pkg/genericapiserver/filters/timeout_test.go diff --git a/hack/.linted_packages b/hack/.linted_packages index b46e1a2e1cc6..da961ba18c59 100644 --- a/hack/.linted_packages +++ b/hack/.linted_packages @@ -105,6 +105,8 @@ pkg/controller/volume/reconciler pkg/controller/volume/statusupdater pkg/conversion/queryparams pkg/credentialprovider/aws +pkg/genericapiserver/filters +pkg/genericapiserver/routes pkg/hyperkube pkg/kubelet/api pkg/kubelet/container diff --git a/pkg/api/requestcontext.go b/pkg/api/requestcontext.go index 14983b2d41fe..724fea811daa 100644 --- a/pkg/api/requestcontext.go +++ b/pkg/api/requestcontext.go @@ -20,6 +20,8 @@ import ( "errors" "net/http" "sync" + + "github.com/golang/glog" ) // RequestContextMapper keeps track of the context associated with a particular request @@ -89,21 +91,21 @@ func (c *requestContextMap) remove(req *http.Request) { delete(c.contexts, req) } -// NewRequestContextFilter ensures there is a Context object associated with the request before calling the passed handler. +// WithRequestContext ensures there is a Context object associated with the request before calling the passed handler. // After the passed handler runs, the context is cleaned up. -func NewRequestContextFilter(mapper RequestContextMapper, handler http.Handler) (http.Handler, error) { - if mapper, ok := mapper.(*requestContextMap); ok { - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if mapper.init(req, NewContext()) { - // If we were the ones to successfully initialize, pair with a remove - defer mapper.remove(req) - } - handler.ServeHTTP(w, req) - }), nil - } else { - return handler, errors.New("Unknown RequestContextMapper implementation.") +func WithRequestContext(handler http.Handler, mapper RequestContextMapper) http.Handler { + rcMap, ok := mapper.(*requestContextMap) + if !ok { + glog.Fatal("Unknown RequestContextMapper implementation.") } + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if rcMap.init(req, NewContext()) { + // If we were the ones to successfully initialize, pair with a remove + defer rcMap.remove(req) + } + handler.ServeHTTP(w, req) + }) } // IsEmpty returns true if there are no contexts registered, or an error if it could not be determined. Intended for use by tests. diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 21f2754308ef..b49fc907f7b0 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -293,8 +293,8 @@ func keepUnversioned(group string) bool { return group == "" || group == "extensions" } -// Adds a service to return the supported api versions at /apis. -func AddApisWebService(s runtime.NegotiatedSerializer, container *restful.Container, apiPrefix string, f func(req *restful.Request) []unversioned.APIGroup) { +// NewApisWebService returns a webservice serving the available api version under /apis. +func NewApisWebService(s runtime.NegotiatedSerializer, apiPrefix string, f func(req *restful.Request) []unversioned.APIGroup) *restful.WebService { // Because in release 1.1, /apis returns response with empty APIVersion, we // use StripVersionNegotiatedSerializer to keep the response backwards // compatible. @@ -309,12 +309,12 @@ func AddApisWebService(s runtime.NegotiatedSerializer, container *restful.Contai Produces(s.SupportedMediaTypes()...). Consumes(s.SupportedMediaTypes()...). Writes(unversioned.APIGroupList{})) - container.Add(ws) + return ws } -// Adds a service to return the supported versions, preferred version, and name -// of a group. E.g., a such web service will be registered at /apis/extensions. -func AddGroupWebService(s runtime.NegotiatedSerializer, container *restful.Container, path string, group unversioned.APIGroup) { +// NewGroupWebService returns a webservice serving the supported versions, preferred version, and name +// of a group. E.g., such a web service will be registered at /apis/extensions. +func NewGroupWebService(s runtime.NegotiatedSerializer, path string, group unversioned.APIGroup) *restful.WebService { ss := s if keepUnversioned(group.Name) { // Because in release 1.1, /apis/extensions returns response with empty @@ -332,7 +332,7 @@ func AddGroupWebService(s runtime.NegotiatedSerializer, container *restful.Conta Produces(s.SupportedMediaTypes()...). Consumes(s.SupportedMediaTypes()...). Writes(unversioned.APIGroup{})) - container.Add(ws) + return ws } // Adds a service to return the supported resources, E.g., a such web service diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index a16aeecc7a15..6255820b7b5b 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -43,7 +43,6 @@ import ( "k8s.io/kubernetes/pkg/fields" "k8s.io/kubernetes/pkg/labels" "k8s.io/kubernetes/pkg/runtime" - "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/diff" "k8s.io/kubernetes/pkg/util/sets" "k8s.io/kubernetes/pkg/watch" @@ -2995,80 +2994,6 @@ func TestCreateTimeout(t *testing.T) { } } -func TestCORSAllowedOrigins(t *testing.T) { - table := []struct { - allowedOrigins []string - origin string - allowed bool - }{ - {[]string{}, "example.com", false}, - {[]string{"example.com"}, "example.com", true}, - {[]string{"example.com"}, "not-allowed.com", false}, - {[]string{"not-matching.com", "example.com"}, "example.com", true}, - {[]string{".*"}, "example.com", true}, - } - - for _, item := range table { - allowedOriginRegexps, err := util.CompileRegexps(item.allowedOrigins) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - - handler := CORS( - handle(map[string]rest.Storage{}), - allowedOriginRegexps, nil, nil, "true", - ) - server := httptest.NewServer(handler) - defer server.Close() - client := http.Client{} - - request, err := http.NewRequest("GET", server.URL+"/version", nil) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - request.Header.Set("Origin", item.origin) - - response, err := client.Do(request) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - - if item.allowed { - if !reflect.DeepEqual(item.origin, response.Header.Get("Access-Control-Allow-Origin")) { - t.Errorf("Expected %#v, Got %#v", item.origin, response.Header.Get("Access-Control-Allow-Origin")) - } - - if response.Header.Get("Access-Control-Allow-Credentials") == "" { - t.Errorf("Expected Access-Control-Allow-Credentials header to be set") - } - - if response.Header.Get("Access-Control-Allow-Headers") == "" { - t.Errorf("Expected Access-Control-Allow-Headers header to be set") - } - - if response.Header.Get("Access-Control-Allow-Methods") == "" { - t.Errorf("Expected Access-Control-Allow-Methods header to be set") - } - } else { - if response.Header.Get("Access-Control-Allow-Origin") != "" { - t.Errorf("Expected Access-Control-Allow-Origin header to not be set") - } - - if response.Header.Get("Access-Control-Allow-Credentials") != "" { - t.Errorf("Expected Access-Control-Allow-Credentials header to not be set") - } - - if response.Header.Get("Access-Control-Allow-Headers") != "" { - t.Errorf("Expected Access-Control-Allow-Headers header to not be set") - } - - if response.Header.Get("Access-Control-Allow-Methods") != "" { - t.Errorf("Expected Access-Control-Allow-Methods header to not be set") - } - } - } -} - func TestCreateChecksAPIVersion(t *testing.T) { handler := handle(map[string]rest.Storage{"simple": &SimpleRESTStorage{}}) server := httptest.NewServer(handler) diff --git a/pkg/apiserver/audit/audit.go b/pkg/apiserver/audit/audit.go index 63803d220306..1a9d437e34a1 100644 --- a/pkg/apiserver/audit/audit.go +++ b/pkg/apiserver/audit/audit.go @@ -72,7 +72,8 @@ var _ http.Flusher = &fancyResponseWriterDelegator{} var _ http.Hijacker = &fancyResponseWriterDelegator{} // WithAudit decorates a http.Handler with audit logging information for all the -// requests coming to the server. Each audit log contains two entries: +// requests coming to the server. If out is nil, no decoration takes place. +// Each audit log contains two entries: // 1. the request line containing: // - unique id allowing to match the response line (see 2) // - source ip of the request @@ -85,6 +86,9 @@ var _ http.Hijacker = &fancyResponseWriterDelegator{} // - the unique id from 1 // - response code func WithAudit(handler http.Handler, attributeGetter apiserver.RequestAttributeGetter, out io.Writer) http.Handler { + if out == nil { + return handler + } return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { attribs := attributeGetter.GetAttribs(req) asuser := req.Header.Get("Impersonate-User") diff --git a/pkg/apiserver/handler_impersonation_test.go b/pkg/apiserver/handler_impersonation_test.go index 50cce24be1e2..313870d79e62 100644 --- a/pkg/apiserver/handler_impersonation_test.go +++ b/pkg/apiserver/handler_impersonation_test.go @@ -302,7 +302,7 @@ func TestImpersonationFilter(t *testing.T) { delegate.ServeHTTP(w, req) }) }(WithImpersonation(doNothingHandler, requestContextMapper, impersonateAuthorizer{})) - handler, _ = api.NewRequestContextFilter(requestContextMapper, handler) + handler = api.WithRequestContext(handler, requestContextMapper) server := httptest.NewServer(handler) defer server.Close() diff --git a/pkg/apiserver/handlers.go b/pkg/apiserver/handlers.go index bcf29c49065b..ea4eff9ff2d0 100644 --- a/pkg/apiserver/handlers.go +++ b/pkg/apiserver/handlers.go @@ -17,16 +17,10 @@ limitations under the License. package apiserver import ( - "bufio" - "encoding/json" "fmt" - "net" "net/http" - "regexp" "runtime/debug" "strings" - "sync" - "time" "github.com/golang/glog" "k8s.io/kubernetes/pkg/api" @@ -53,10 +47,6 @@ var namespaceSubresources = sets.NewString("status", "finalize") // NamespaceSubResourcesForTest exports namespaceSubresources for testing in pkg/master/master_test.go, so we never drift var NamespaceSubResourcesForTest = sets.NewString(namespaceSubresources.List()...) -// Constant for the retry-after interval on rate limiting. -// TODO: maybe make this dynamic? or user-adjustable? -const RetryAfter = "1" - // IsReadOnlyReq() is true for any (or at least many) request which has no observable // side effects on state of apiserver (though there may be internal side effects like // caching and logging). @@ -80,59 +70,6 @@ func ReadOnly(handler http.Handler) http.Handler { }) } -type LongRunningRequestCheck func(r *http.Request) bool - -// BasicLongRunningRequestCheck pathRegex operates against the url path, the queryParams match is case insensitive. -// Any one match flags the request. -// TODO tighten this check to eliminate the abuse potential by malicious clients that start setting queryParameters -// to bypass the rate limitter. This could be done using a full parse and special casing the bits we need. -func BasicLongRunningRequestCheck(pathRegex *regexp.Regexp, queryParams map[string]string) LongRunningRequestCheck { - return func(r *http.Request) bool { - if pathRegex.MatchString(r.URL.Path) { - return true - } - - for key, expectedValue := range queryParams { - if strings.ToLower(expectedValue) == strings.ToLower(r.URL.Query().Get(key)) { - return true - } - } - - return false - } -} - -// MaxInFlight limits the number of in-flight requests to buffer size of the passed in channel. -func MaxInFlightLimit(c chan bool, longRunningRequestCheck LongRunningRequestCheck, handler http.Handler) http.Handler { - if c == nil { - return handler - } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if longRunningRequestCheck(r) { - // Skip tracking long running events. - handler.ServeHTTP(w, r) - return - } - select { - case c <- true: - defer func() { <-c }() - handler.ServeHTTP(w, r) - default: - tooManyRequests(r, w) - } - }) -} - -func tooManyRequests(req *http.Request, w http.ResponseWriter) { - // "Too Many Requests" response is returned before logger is setup for the request. - // So we need to explicitly log it here. - defer httplog.NewLogged(req, &w).Log() - - // Return a 429 status indicating "Too Many Requests" - w.Header().Set("Retry-After", RetryAfter) - http.Error(w, "Too many requests, please try again later.", errors.StatusTooManyRequests) -} - // RecoverPanics wraps an http Handler to recover and log panics. func RecoverPanics(handler http.Handler, resolver *RequestInfoResolver) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { @@ -168,261 +105,6 @@ func RecoverPanics(handler http.Handler, resolver *RequestInfoResolver) http.Han }) } -var errConnKilled = fmt.Errorf("kill connection/stream") - -// TimeoutHandler returns an http.Handler that runs h with a timeout -// determined by timeoutFunc. The new http.Handler calls h.ServeHTTP to handle -// each request, but if a call runs for longer than its time limit, the -// handler responds with a 503 Service Unavailable error and the message -// provided. (If msg is empty, a suitable default message will be sent.) After -// the handler times out, writes by h to its http.ResponseWriter will return -// http.ErrHandlerTimeout. If timeoutFunc returns a nil timeout channel, no -// timeout will be enforced. -func TimeoutHandler(h http.Handler, timeoutFunc func(*http.Request) (timeout <-chan time.Time, msg string)) http.Handler { - return &timeoutHandler{h, timeoutFunc} -} - -type timeoutHandler struct { - handler http.Handler - timeout func(*http.Request) (<-chan time.Time, string) -} - -func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - after, msg := t.timeout(r) - if after == nil { - t.handler.ServeHTTP(w, r) - return - } - - done := make(chan struct{}) - tw := newTimeoutWriter(w) - go func() { - t.handler.ServeHTTP(tw, r) - close(done) - }() - select { - case <-done: - return - case <-after: - tw.timeout(msg) - } -} - -type timeoutWriter interface { - http.ResponseWriter - timeout(string) -} - -func newTimeoutWriter(w http.ResponseWriter) timeoutWriter { - base := &baseTimeoutWriter{w: w} - - _, notifiable := w.(http.CloseNotifier) - _, hijackable := w.(http.Hijacker) - - switch { - case notifiable && hijackable: - return &closeHijackTimeoutWriter{base} - case notifiable: - return &closeTimeoutWriter{base} - case hijackable: - return &hijackTimeoutWriter{base} - default: - return base - } -} - -type baseTimeoutWriter struct { - w http.ResponseWriter - - mu sync.Mutex - // if the timeout handler has timedout - timedOut bool - // if this timeout writer has wrote header - wroteHeader bool - // if this timeout writer has been hijacked - hijacked bool -} - -func (tw *baseTimeoutWriter) Header() http.Header { - tw.mu.Lock() - defer tw.mu.Unlock() - - if tw.timedOut { - return http.Header{} - } - - return tw.w.Header() -} - -func (tw *baseTimeoutWriter) Write(p []byte) (int, error) { - tw.mu.Lock() - defer tw.mu.Unlock() - - if tw.timedOut { - return 0, http.ErrHandlerTimeout - } - if tw.hijacked { - return 0, http.ErrHijacked - } - - tw.wroteHeader = true - return tw.w.Write(p) -} - -func (tw *baseTimeoutWriter) Flush() { - tw.mu.Lock() - defer tw.mu.Unlock() - - if tw.timedOut { - return - } - - if flusher, ok := tw.w.(http.Flusher); ok { - flusher.Flush() - } -} - -func (tw *baseTimeoutWriter) WriteHeader(code int) { - tw.mu.Lock() - defer tw.mu.Unlock() - - if tw.timedOut || tw.wroteHeader || tw.hijacked { - return - } - - tw.wroteHeader = true - tw.w.WriteHeader(code) -} - -func (tw *baseTimeoutWriter) timeout(msg string) { - tw.mu.Lock() - defer tw.mu.Unlock() - - tw.timedOut = true - - // The timeout writer has not been used by the inner handler. - // We can safely timeout the HTTP request by sending by a timeout - // handler - if !tw.wroteHeader && !tw.hijacked { - tw.w.WriteHeader(http.StatusGatewayTimeout) - if msg != "" { - tw.w.Write([]byte(msg)) - } else { - enc := json.NewEncoder(tw.w) - enc.Encode(errors.NewServerTimeout(api.Resource(""), "", 0)) - } - } else { - // The timeout writer has been used by the inner handler. There is - // no way to timeout the HTTP request at the point. We have to shutdown - // the connection for HTTP1 or reset stream for HTTP2. - // - // Note from: Brad Fitzpatrick - // if the ServeHTTP goroutine panics, that will do the best possible thing for both - // HTTP/1 and HTTP/2. In HTTP/1, assuming you're replying with at least HTTP/1.1 and - // you've already flushed the headers so it's using HTTP chunking, it'll kill the TCP - // connection immediately without a proper 0-byte EOF chunk, so the peer will recognize - // the response as bogus. In HTTP/2 the server will just RST_STREAM the stream, leaving - // the TCP connection open, but resetting the stream to the peer so it'll have an error, - // like the HTTP/1 case. - panic(errConnKilled) - } -} - -func (tw *baseTimeoutWriter) closeNotify() <-chan bool { - tw.mu.Lock() - defer tw.mu.Unlock() - - if tw.timedOut { - done := make(chan bool) - close(done) - return done - } - - return tw.w.(http.CloseNotifier).CloseNotify() -} - -func (tw *baseTimeoutWriter) hijack() (net.Conn, *bufio.ReadWriter, error) { - tw.mu.Lock() - defer tw.mu.Unlock() - - if tw.timedOut { - return nil, nil, http.ErrHandlerTimeout - } - conn, rw, err := tw.w.(http.Hijacker).Hijack() - if err == nil { - tw.hijacked = true - } - return conn, rw, err -} - -type closeTimeoutWriter struct { - *baseTimeoutWriter -} - -func (tw *closeTimeoutWriter) CloseNotify() <-chan bool { - return tw.closeNotify() -} - -type hijackTimeoutWriter struct { - *baseTimeoutWriter -} - -func (tw *hijackTimeoutWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return tw.hijack() -} - -type closeHijackTimeoutWriter struct { - *baseTimeoutWriter -} - -func (tw *closeHijackTimeoutWriter) CloseNotify() <-chan bool { - return tw.closeNotify() -} - -func (tw *closeHijackTimeoutWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return tw.hijack() -} - -// TODO: use restful.CrossOriginResourceSharing -// Simple CORS implementation that wraps an http Handler -// For a more detailed implementation use https://github.com/martini-contrib/cors -// or implement CORS at your proxy layer -// Pass nil for allowedMethods and allowedHeaders to use the defaults -func CORS(handler http.Handler, allowedOriginPatterns []*regexp.Regexp, allowedMethods []string, allowedHeaders []string, allowCredentials string) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - origin := req.Header.Get("Origin") - if origin != "" { - allowed := false - for _, pattern := range allowedOriginPatterns { - if allowed = pattern.MatchString(origin); allowed { - break - } - } - if allowed { - w.Header().Set("Access-Control-Allow-Origin", origin) - // Set defaults for methods and headers if nothing was passed - if allowedMethods == nil { - allowedMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE"} - } - if allowedHeaders == nil { - allowedHeaders = []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "X-Requested-With", "If-Modified-Since"} - } - w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ", ")) - w.Header().Set("Access-Control-Allow-Headers", strings.Join(allowedHeaders, ", ")) - w.Header().Set("Access-Control-Allow-Credentials", allowCredentials) - - // Stop here if its a preflight OPTIONS request - if req.Method == "OPTIONS" { - w.WriteHeader(http.StatusNoContent) - return - } - } - } - // Dispatch to the next handler - handler.ServeHTTP(w, req) - }) -} - // RequestAttributeGetter is a function that extracts authorizer.Attributes from an http.Request type RequestAttributeGetter interface { GetAttribs(req *http.Request) (attribs authorizer.Attributes) @@ -467,7 +149,11 @@ func (r *requestAttributeGetter) GetAttribs(req *http.Request) authorizer.Attrib } // WithAuthorizationCheck passes all authorized requests on to handler, and returns a forbidden error otherwise. -func WithAuthorizationCheck(handler http.Handler, getAttribs RequestAttributeGetter, a authorizer.Authorizer) http.Handler { +func WithAuthorization(handler http.Handler, getAttribs RequestAttributeGetter, a authorizer.Authorizer) http.Handler { + if a == nil { + glog.Warningf("Authorization is disabled") + return handler + } return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { authorized, reason, err := a.Authorize(getAttribs.GetAttribs(req)) if err != nil { diff --git a/pkg/apiserver/handlers_test.go b/pkg/apiserver/handlers_test.go index 50058bdff3fd..a9d679b021eb 100644 --- a/pkg/apiserver/handlers_test.go +++ b/pkg/apiserver/handlers_test.go @@ -17,19 +17,12 @@ limitations under the License. package apiserver import ( - "fmt" - "io/ioutil" "net/http" "net/http/httptest" "reflect" - "regexp" - "strings" - "sync" "testing" - "time" "k8s.io/kubernetes/pkg/api" - "k8s.io/kubernetes/pkg/api/errors" "k8s.io/kubernetes/pkg/api/testapi" "k8s.io/kubernetes/pkg/apis/extensions" "k8s.io/kubernetes/pkg/auth/authorizer" @@ -42,17 +35,6 @@ func (fakeRL) Stop() {} func (f fakeRL) TryAccept() bool { return bool(f) } func (f fakeRL) Accept() {} -func expectHTTP(url string, code int) error { - r, err := http.Get(url) - if err != nil { - return fmt.Errorf("unexpected error: %v", err) - } - if r.StatusCode != code { - return fmt.Errorf("unexpected response: %v", r.StatusCode) - } - return nil -} - func getPath(resource, namespace, name string) string { return testapi.Default.ResourcePath(resource, namespace, name) } @@ -61,111 +43,6 @@ func pathWithPrefix(prefix, resource, namespace, name string) string { return testapi.Default.ResourcePathWithPrefix(prefix, resource, namespace, name) } -// Tests that MaxInFlightLimit works, i.e. -// - "long" requests such as proxy or watch, identified by regexp are not accounted despite -// hanging for the long time, -// - "short" requests are correctly accounted, i.e. there can be only size of channel passed to the -// constructor in flight at any given moment, -// - subsequent "short" requests are rejected instantly with appropriate error, -// - subsequent "long" requests are handled normally, -// - we correctly recover after some "short" requests finish, i.e. we can process new ones. -func TestMaxInFlight(t *testing.T) { - const AllowedInflightRequestsNo = 3 - // Size of inflightRequestsChannel determines how many concurrent inflight requests - // are allowed. - inflightRequestsChannel := make(chan bool, AllowedInflightRequestsNo) - // notAccountedPathsRegexp specifies paths requests to which we don't account into - // requests in flight. - notAccountedPathsRegexp := regexp.MustCompile(".*\\/watch") - longRunningRequestCheck := BasicLongRunningRequestCheck(notAccountedPathsRegexp, map[string]string{"watch": "true"}) - - // Calls is used to wait until all server calls are received. We are sending - // AllowedInflightRequestsNo of 'long' not-accounted requests and the same number of - // 'short' accounted ones. - calls := &sync.WaitGroup{} - calls.Add(AllowedInflightRequestsNo * 2) - - // Responses is used to wait until all responses are - // received. This prevents some async requests getting EOF - // errors from prematurely closing the server - responses := sync.WaitGroup{} - responses.Add(AllowedInflightRequestsNo * 2) - - // Block is used to keep requests in flight for as long as we need to. All requests will - // be unblocked at the same time. - block := sync.WaitGroup{} - block.Add(1) - - server := httptest.NewServer( - MaxInFlightLimit( - inflightRequestsChannel, - longRunningRequestCheck, - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // A short, accounted request that does not wait for block WaitGroup. - if strings.Contains(r.URL.Path, "dontwait") { - return - } - if calls != nil { - calls.Done() - } - block.Wait() - }), - ), - ) - defer server.Close() - - // These should hang, but not affect accounting. use a query param match - for i := 0; i < AllowedInflightRequestsNo; i++ { - // These should hang waiting on block... - go func() { - if err := expectHTTP(server.URL+"/foo/bar?watch=true", http.StatusOK); err != nil { - t.Error(err) - } - responses.Done() - }() - } - // Check that sever is not saturated by not-accounted calls - if err := expectHTTP(server.URL+"/dontwait", http.StatusOK); err != nil { - t.Error(err) - } - - // These should hang and be accounted, i.e. saturate the server - for i := 0; i < AllowedInflightRequestsNo; i++ { - // These should hang waiting on block... - go func() { - if err := expectHTTP(server.URL, http.StatusOK); err != nil { - t.Error(err) - } - responses.Done() - }() - } - // We wait for all calls to be received by the server - calls.Wait() - // Disable calls notifications in the server - calls = nil - - // Do this multiple times to show that it rate limit rejected requests don't block. - for i := 0; i < 2; i++ { - if err := expectHTTP(server.URL, errors.StatusTooManyRequests); err != nil { - t.Error(err) - } - } - // Validate that non-accounted URLs still work. use a path regex match - if err := expectHTTP(server.URL+"/dontwait/watch", http.StatusOK); err != nil { - t.Error(err) - } - - // Let all hanging requests finish - block.Done() - - // Show that we recover from being blocked up. - // Too avoid flakyness we need to wait until at least one of the requests really finishes. - responses.Wait() - if err := expectHTTP(server.URL, http.StatusOK); err != nil { - t.Error(err) - } -} - func TestReadOnly(t *testing.T) { server := httptest.NewServer(ReadOnly(http.HandlerFunc( func(w http.ResponseWriter, req *http.Request) { @@ -184,62 +61,6 @@ func TestReadOnly(t *testing.T) { } } -func TestTimeout(t *testing.T) { - sendResponse := make(chan struct{}, 1) - writeErrors := make(chan error, 1) - timeout := make(chan time.Time, 1) - resp := "test response" - timeoutResp := "test timeout" - - ts := httptest.NewServer(TimeoutHandler(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - <-sendResponse - _, err := w.Write([]byte(resp)) - writeErrors <- err - }), - func(*http.Request) (<-chan time.Time, string) { - return timeout, timeoutResp - })) - defer ts.Close() - - // No timeouts - sendResponse <- struct{}{} - res, err := http.Get(ts.URL) - if err != nil { - t.Error(err) - } - if res.StatusCode != http.StatusOK { - t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusOK) - } - body, _ := ioutil.ReadAll(res.Body) - if string(body) != resp { - t.Errorf("got body %q; expected %q", string(body), resp) - } - if err := <-writeErrors; err != nil { - t.Errorf("got unexpected Write error on first request: %v", err) - } - - // Times out - timeout <- time.Time{} - res, err = http.Get(ts.URL) - if err != nil { - t.Error(err) - } - if res.StatusCode != http.StatusGatewayTimeout { - t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusServiceUnavailable) - } - body, _ = ioutil.ReadAll(res.Body) - if string(body) != timeoutResp { - t.Errorf("got body %q; expected %q", string(body), timeoutResp) - } - - // Now try to send a response - sendResponse <- struct{}{} - if err := <-writeErrors; err != http.ErrHandlerTimeout { - t.Errorf("got Write error of %v; expected %v", err, http.ErrHandlerTimeout) - } -} - func TestGetAttribs(t *testing.T) { r := &requestAttributeGetter{api.NewRequestContextMapper(), &RequestInfoResolver{sets.NewString("api", "apis"), sets.NewString("api")}} diff --git a/pkg/auth/handlers/handlers.go b/pkg/auth/handlers/handlers.go index 0f6e226e9c01..215032ff12d7 100644 --- a/pkg/auth/handlers/handlers.go +++ b/pkg/auth/handlers/handlers.go @@ -41,12 +41,15 @@ func init() { prometheus.MustRegister(authenticatedUserCounter) } -// NewRequestAuthenticator creates an http handler that tries to authenticate the given request as a user, and then +// WithAuthentication creates an http handler that tries to authenticate the given request as a user, and then // stores any such user found onto the provided context for the request. If authentication fails or returns an error // the failed handler is used. On success, handler is invoked to serve the request. -func NewRequestAuthenticator(mapper api.RequestContextMapper, auth authenticator.Request, failed http.Handler, handler http.Handler) (http.Handler, error) { - return api.NewRequestContextFilter( - mapper, +func WithAuthentication(handler http.Handler, mapper api.RequestContextMapper, auth authenticator.Request, failed http.Handler) http.Handler { + if auth == nil { + glog.Warningf("Authentication is disabled") + return handler + } + return api.WithRequestContext( http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { user, ok, err := auth.AuthenticateRequest(req) if err != nil || !ok { @@ -65,6 +68,7 @@ func NewRequestAuthenticator(mapper api.RequestContextMapper, auth authenticator handler.ServeHTTP(w, req) }), + mapper, ) } diff --git a/pkg/auth/handlers/handlers_test.go b/pkg/auth/handlers/handlers_test.go index 1118081207b5..141f41fcc2cb 100644 --- a/pkg/auth/handlers/handlers_test.go +++ b/pkg/auth/handlers/handlers_test.go @@ -30,14 +30,7 @@ import ( func TestAuthenticateRequest(t *testing.T) { success := make(chan struct{}) contextMapper := api.NewRequestContextMapper() - auth, err := NewRequestAuthenticator( - contextMapper, - authenticator.RequestFunc(func(req *http.Request) (user.Info, bool, error) { - return &user.DefaultInfo{Name: "user"}, true, nil - }), - http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - t.Errorf("unexpected call to failed") - }), + auth := WithAuthentication( http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { ctx, ok := contextMapper.Get(req) if ctx == nil || !ok { @@ -49,6 +42,13 @@ func TestAuthenticateRequest(t *testing.T) { } close(success) }), + contextMapper, + authenticator.RequestFunc(func(req *http.Request) (user.Info, bool, error) { + return &user.DefaultInfo{Name: "user"}, true, nil + }), + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + t.Errorf("unexpected call to failed") + }), ) auth.ServeHTTP(httptest.NewRecorder(), &http.Request{}) @@ -66,7 +66,10 @@ func TestAuthenticateRequest(t *testing.T) { func TestAuthenticateRequestFailed(t *testing.T) { failed := make(chan struct{}) contextMapper := api.NewRequestContextMapper() - auth, err := NewRequestAuthenticator( + auth := WithAuthentication( + http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { + t.Errorf("unexpected call to handler") + }), contextMapper, authenticator.RequestFunc(func(req *http.Request) (user.Info, bool, error) { return nil, false, nil @@ -74,9 +77,6 @@ func TestAuthenticateRequestFailed(t *testing.T) { http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { close(failed) }), - http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { - t.Errorf("unexpected call to handler") - }), ) auth.ServeHTTP(httptest.NewRecorder(), &http.Request{}) @@ -94,7 +94,10 @@ func TestAuthenticateRequestFailed(t *testing.T) { func TestAuthenticateRequestError(t *testing.T) { failed := make(chan struct{}) contextMapper := api.NewRequestContextMapper() - auth, err := NewRequestAuthenticator( + auth := WithAuthentication( + http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { + t.Errorf("unexpected call to handler") + }), contextMapper, authenticator.RequestFunc(func(req *http.Request) (user.Info, bool, error) { return nil, false, errors.New("failure") @@ -102,9 +105,6 @@ func TestAuthenticateRequestError(t *testing.T) { http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { close(failed) }), - http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { - t.Errorf("unexpected call to handler") - }), ) auth.ServeHTTP(httptest.NewRecorder(), &http.Request{}) diff --git a/pkg/genericapiserver/config.go b/pkg/genericapiserver/config.go index 842fd57b5906..aabcf1f0b639 100644 --- a/pkg/genericapiserver/config.go +++ b/pkg/genericapiserver/config.go @@ -25,7 +25,6 @@ import ( "os" "regexp" "strconv" - "strings" "time" "github.com/emicklei/go-restful" @@ -40,15 +39,15 @@ import ( "k8s.io/kubernetes/pkg/apiserver/audit" "k8s.io/kubernetes/pkg/auth/authenticator" "k8s.io/kubernetes/pkg/auth/authorizer" - "k8s.io/kubernetes/pkg/auth/handlers" + authhandlers "k8s.io/kubernetes/pkg/auth/handlers" "k8s.io/kubernetes/pkg/cloudprovider" + genericfilters "k8s.io/kubernetes/pkg/genericapiserver/filters" "k8s.io/kubernetes/pkg/genericapiserver/openapi/common" "k8s.io/kubernetes/pkg/genericapiserver/options" "k8s.io/kubernetes/pkg/genericapiserver/routes" genericvalidation "k8s.io/kubernetes/pkg/genericapiserver/validation" ipallocator "k8s.io/kubernetes/pkg/registry/core/service/ipallocator" "k8s.io/kubernetes/pkg/runtime" - "k8s.io/kubernetes/pkg/util" utilnet "k8s.io/kubernetes/pkg/util/net" ) @@ -333,18 +332,58 @@ func (c Config) New() (*GenericAPIServer, error) { }) } + if len(c.AuditLogPath) != 0 { + s.auditWriter = &lumberjack.Logger{ + Filename: c.AuditLogPath, + MaxAge: c.AuditLogMaxAge, + MaxBackups: c.AuditLogMaxBackups, + MaxSize: c.AuditLogMaxSize, + } + } + // Send correct mime type for .svg files. // TODO: remove when https://github.com/golang/go/commit/21e47d831bafb59f22b1ea8098f709677ec8ce33 // makes it into all of our supported go versions (only in v1.7.1 now). mime.AddExtensionType(".svg", "image/svg+xml") - // Register root handler. - // We do not register this using restful Webservice since we do not want to surface this in api docs. - // Allow GenericAPIServer to be embedded in contexts which already have something registered at the root + s.installAPI(&c) + s.Handler, s.InsecureHandler = s.buildHandlerChains(&c, http.Handler(s.Mux.BaseMux().(*http.ServeMux))) + + return s, nil +} + +func (s *GenericAPIServer) buildHandlerChains(c *Config, handler http.Handler) (secure http.Handler, insecure http.Handler) { + longRunningRE := regexp.MustCompile(c.LongRunningRequestRE) + longRunningFunc := genericfilters.BasicLongRunningRequestCheck(longRunningRE, map[string]string{"watch": "true"}) + + // filters which insecure and secure have in common + handler = genericfilters.WithCORS(handler, c.CorsAllowedOriginList, nil, nil, "true") + + // insecure filters + insecure = handler + insecure = api.WithRequestContext(insecure, c.RequestContextMapper) + insecure = apiserver.RecoverPanics(insecure, s.NewRequestInfoResolver()) + insecure = genericfilters.WithTimeoutForNonLongRunningRequests(insecure, longRunningFunc) + + // secure filters + attributeGetter := apiserver.NewRequestAttributeGetter(c.RequestContextMapper, s.NewRequestInfoResolver()) + secure = handler + secure = apiserver.WithAuthorization(secure, attributeGetter, c.Authorizer) + secure = apiserver.WithImpersonation(secure, c.RequestContextMapper, c.Authorizer) + secure = audit.WithAudit(secure, attributeGetter, s.auditWriter) // before impersonation to read original user + secure = authhandlers.WithAuthentication(secure, c.RequestContextMapper, c.Authenticator, authhandlers.Unauthorized(c.SupportsBasicAuth)) + secure = api.WithRequestContext(secure, c.RequestContextMapper) + secure = apiserver.RecoverPanics(secure, s.NewRequestInfoResolver()) + secure = genericfilters.WithTimeoutForNonLongRunningRequests(secure, longRunningFunc) + secure = genericfilters.WithMaxInFlightLimit(secure, c.MaxRequestsInFlight, longRunningFunc) + + return +} + +func (s *GenericAPIServer) installAPI(c *Config) { if c.EnableIndex { routes.Index{}.Install(s.Mux, s.HandlerContainer) } - if c.EnableSwaggerSupport && c.EnableSwaggerUI { routes.SwaggerUI{}.Install(s.Mux, s.HandlerContainer) } @@ -354,79 +393,7 @@ func (c Config) New() (*GenericAPIServer, error) { if c.EnableVersion { routes.Version{}.Install(s.Mux, s.HandlerContainer) } - - handler := http.Handler(s.Mux.BaseMux().(*http.ServeMux)) - - // TODO: handle CORS and auth using go-restful - // See github.com/emicklei/go-restful/blob/master/examples/restful-CORS-filter.go, and - // github.com/emicklei/go-restful/blob/master/examples/restful-basic-authentication.go - - if len(c.CorsAllowedOriginList) > 0 { - allowedOriginRegexps, err := util.CompileRegexps(c.CorsAllowedOriginList) - if err != nil { - glog.Fatalf("Invalid CORS allowed origin, --cors-allowed-origins flag was set to %v - %v", strings.Join(c.CorsAllowedOriginList, ","), err) - } - handler = apiserver.CORS(handler, allowedOriginRegexps, nil, nil, "true") - } - - s.InsecureHandler = handler - - attributeGetter := apiserver.NewRequestAttributeGetter(c.RequestContextMapper, s.NewRequestInfoResolver()) - handler = apiserver.WithAuthorizationCheck(handler, attributeGetter, c.Authorizer) - handler = apiserver.WithImpersonation(handler, c.RequestContextMapper, c.Authorizer) - if len(c.AuditLogPath) != 0 { - // audit handler must comes before the impersonationFilter to read the original user - writer := &lumberjack.Logger{ - Filename: c.AuditLogPath, - MaxAge: c.AuditLogMaxAge, - MaxBackups: c.AuditLogMaxBackups, - MaxSize: c.AuditLogMaxSize, - } - handler = audit.WithAudit(handler, attributeGetter, writer) - } - - // Install Authenticator - if c.Authenticator != nil { - authenticatedHandler, err := handlers.NewRequestAuthenticator(c.RequestContextMapper, c.Authenticator, handlers.Unauthorized(c.SupportsBasicAuth), handler) - if err != nil { - glog.Fatalf("Could not initialize authenticator: %v", err) - } - handler = authenticatedHandler - } - - handler, err := api.NewRequestContextFilter(c.RequestContextMapper, handler) - if err != nil { - glog.Fatalf("Could not initialize request context filter for s.Handler: %v", err) - } - - longRunningRE := regexp.MustCompile(c.LongRunningRequestRE) - longRunningRequestCheck := apiserver.BasicLongRunningRequestCheck(longRunningRE, map[string]string{"watch": "true"}) - longRunningTimeout := func(req *http.Request) (<-chan time.Time, string) { - // TODO unify this with apiserver.MaxInFlightLimit - if longRunningRequestCheck(req) { - return nil, "" - } - return time.After(globalTimeout), "" - } - handler = apiserver.TimeoutHandler(apiserver.RecoverPanics(handler, s.NewRequestInfoResolver()), longRunningTimeout) - - var inFlightTokens chan bool - if c.MaxRequestsInFlight > 0 { - inFlightTokens = make(chan bool, c.MaxRequestsInFlight) - } - handler = apiserver.MaxInFlightLimit(inFlightTokens, longRunningRequestCheck, handler) - s.Handler = handler - - handler, err = api.NewRequestContextFilter(c.RequestContextMapper, s.InsecureHandler) - if err != nil { - glog.Fatalf("Could not initialize request context filter for s.InsecureHandler: %v", err) - } - handler = apiserver.TimeoutHandler(apiserver.RecoverPanics(handler, s.NewRequestInfoResolver()), longRunningTimeout) - s.InsecureHandler = handler - - s.installGroupsDiscoveryHandler() - - return s, nil + s.HandlerContainer.Add(s.DynamicApisDiscovery()) } func DefaultAndValidateRunOptions(options *options.ServerRunOptions) { diff --git a/pkg/genericapiserver/filters/cors.go b/pkg/genericapiserver/filters/cors.go new file mode 100644 index 000000000000..ac904e221d07 --- /dev/null +++ b/pkg/genericapiserver/filters/cors.go @@ -0,0 +1,83 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "net/http" + "regexp" + "strings" + + "github.com/golang/glog" + + "k8s.io/kubernetes/pkg/util" +) + +// TODO: use restful.CrossOriginResourceSharing +// See github.com/emicklei/go-restful/blob/master/examples/restful-CORS-filter.go, and +// github.com/emicklei/go-restful/blob/master/examples/restful-basic-authentication.go +// Or, for a more detailed implementation use https://github.com/martini-contrib/cors +// or implement CORS at your proxy layer. + +// WithCORS is a simple CORS implementation that wraps an http Handler. +// Pass nil for allowedMethods and allowedHeaders to use the defaults. If allowedOriginPatterns +// is empty or nil, no CORS support is installed. +func WithCORS(handler http.Handler, allowedOriginPatterns []string, allowedMethods []string, allowedHeaders []string, allowCredentials string) http.Handler { + if len(allowedOriginPatterns) == 0 { + return handler + } + allowedOriginPatternsREs := allowedOriginRegexps(allowedOriginPatterns) + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + origin := req.Header.Get("Origin") + if origin != "" { + allowed := false + for _, re := range allowedOriginPatternsREs { + if allowed = re.MatchString(origin); allowed { + break + } + } + if allowed { + w.Header().Set("Access-Control-Allow-Origin", origin) + // Set defaults for methods and headers if nothing was passed + if allowedMethods == nil { + allowedMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE"} + } + if allowedHeaders == nil { + allowedHeaders = []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "X-Requested-With", "If-Modified-Since"} + } + w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ", ")) + w.Header().Set("Access-Control-Allow-Headers", strings.Join(allowedHeaders, ", ")) + w.Header().Set("Access-Control-Allow-Credentials", allowCredentials) + + // Stop here if its a preflight OPTIONS request + if req.Method == "OPTIONS" { + w.WriteHeader(http.StatusNoContent) + return + } + } + } + // Dispatch to the next handler + handler.ServeHTTP(w, req) + }) +} + +func allowedOriginRegexps(allowedOrigins []string) []*regexp.Regexp { + res, err := util.CompileRegexps(allowedOrigins) + if err != nil { + glog.Fatalf("Invalid CORS allowed origin, --cors-allowed-origins flag was set to %v - %v", strings.Join(allowedOrigins, ","), err) + } + return res +} diff --git a/pkg/genericapiserver/filters/cors_test.go b/pkg/genericapiserver/filters/cors_test.go new file mode 100644 index 000000000000..d62e83d51be5 --- /dev/null +++ b/pkg/genericapiserver/filters/cors_test.go @@ -0,0 +1,93 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +func TestCORSAllowedOrigins(t *testing.T) { + table := []struct { + allowedOrigins []string + origin string + allowed bool + }{ + {[]string{}, "example.com", false}, + {[]string{"example.com"}, "example.com", true}, + {[]string{"example.com"}, "not-allowed.com", false}, + {[]string{"not-matching.com", "example.com"}, "example.com", true}, + {[]string{".*"}, "example.com", true}, + } + + for _, item := range table { + handler := WithCORS( + http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}), + item.allowedOrigins, nil, nil, "true", + ) + server := httptest.NewServer(handler) + defer server.Close() + client := http.Client{} + + request, err := http.NewRequest("GET", server.URL+"/version", nil) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + request.Header.Set("Origin", item.origin) + + response, err := client.Do(request) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if item.allowed { + if !reflect.DeepEqual(item.origin, response.Header.Get("Access-Control-Allow-Origin")) { + t.Errorf("Expected %#v, Got %#v", item.origin, response.Header.Get("Access-Control-Allow-Origin")) + } + + if response.Header.Get("Access-Control-Allow-Credentials") == "" { + t.Errorf("Expected Access-Control-Allow-Credentials header to be set") + } + + if response.Header.Get("Access-Control-Allow-Headers") == "" { + t.Errorf("Expected Access-Control-Allow-Headers header to be set") + } + + if response.Header.Get("Access-Control-Allow-Methods") == "" { + t.Errorf("Expected Access-Control-Allow-Methods header to be set") + } + } else { + if response.Header.Get("Access-Control-Allow-Origin") != "" { + t.Errorf("Expected Access-Control-Allow-Origin header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Credentials") != "" { + t.Errorf("Expected Access-Control-Allow-Credentials header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Headers") != "" { + t.Errorf("Expected Access-Control-Allow-Headers header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Methods") != "" { + t.Errorf("Expected Access-Control-Allow-Methods header to not be set") + } + } + } +} diff --git a/pkg/genericapiserver/filters/doc.go b/pkg/genericapiserver/filters/doc.go new file mode 100644 index 000000000000..3fbe2978e10f --- /dev/null +++ b/pkg/genericapiserver/filters/doc.go @@ -0,0 +1,19 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package filters contains all the http handler chain filters which +// are not api related. +package filters // import "k8s.io/kubernetes/pkg/genericapiserver/filters" diff --git a/pkg/genericapiserver/filters/longrunning.go b/pkg/genericapiserver/filters/longrunning.go new file mode 100644 index 000000000000..7e26dccee4af --- /dev/null +++ b/pkg/genericapiserver/filters/longrunning.go @@ -0,0 +1,46 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "net/http" + "regexp" + "strings" +) + +// LongRunningRequestCheck is a predicate which is true for long-running http requests. +type LongRunningRequestCheck func(r *http.Request) bool + +// BasicLongRunningRequestCheck pathRegex operates against the url path, the queryParams match is case insensitive. +// Any one match flags the request. +// TODO tighten this check to eliminate the abuse potential by malicious clients that start setting queryParameters +// to bypass the rate limitter. This could be done using a full parse and special casing the bits we need. +func BasicLongRunningRequestCheck(pathRegex *regexp.Regexp, queryParams map[string]string) LongRunningRequestCheck { + return func(r *http.Request) bool { + if pathRegex.MatchString(r.URL.Path) { + return true + } + + for key, expectedValue := range queryParams { + if strings.ToLower(expectedValue) == strings.ToLower(r.URL.Query().Get(key)) { + return true + } + } + + return false + } +} diff --git a/pkg/genericapiserver/filters/maxinflight.go b/pkg/genericapiserver/filters/maxinflight.go new file mode 100644 index 000000000000..12f3cfc176ba --- /dev/null +++ b/pkg/genericapiserver/filters/maxinflight.go @@ -0,0 +1,60 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "net/http" + + "k8s.io/kubernetes/pkg/api/errors" + "k8s.io/kubernetes/pkg/httplog" +) + +// Constant for the retry-after interval on rate limiting. +// TODO: maybe make this dynamic? or user-adjustable? +const retryAfter = "1" + +// WithMaxInFlightLimit limits the number of in-flight requests to buffer size of the passed in channel. +func WithMaxInFlightLimit(handler http.Handler, limit int, longRunningRequestCheck LongRunningRequestCheck) http.Handler { + if limit == 0 { + return handler + } + c := make(chan bool, limit) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if longRunningRequestCheck(r) { + // Skip tracking long running events. + handler.ServeHTTP(w, r) + return + } + select { + case c <- true: + defer func() { <-c }() + handler.ServeHTTP(w, r) + default: + tooManyRequests(r, w) + } + }) +} + +func tooManyRequests(req *http.Request, w http.ResponseWriter) { + // "Too Many Requests" response is returned before logger is setup for the request. + // So we need to explicitly log it here. + defer httplog.NewLogged(req, &w).Log() + + // Return a 429 status indicating "Too Many Requests" + w.Header().Set("Retry-After", retryAfter) + http.Error(w, "Too many requests, please try again later.", errors.StatusTooManyRequests) +} diff --git a/pkg/genericapiserver/filters/maxinflight_test.go b/pkg/genericapiserver/filters/maxinflight_test.go new file mode 100644 index 000000000000..d5b2bed5bedd --- /dev/null +++ b/pkg/genericapiserver/filters/maxinflight_test.go @@ -0,0 +1,143 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "fmt" + "net/http" + "net/http/httptest" + "regexp" + "strings" + "sync" + "testing" + + "k8s.io/kubernetes/pkg/api/errors" +) + +// Tests that MaxInFlightLimit works, i.e. +// - "long" requests such as proxy or watch, identified by regexp are not accounted despite +// hanging for the long time, +// - "short" requests are correctly accounted, i.e. there can be only size of channel passed to the +// constructor in flight at any given moment, +// - subsequent "short" requests are rejected instantly with appropriate error, +// - subsequent "long" requests are handled normally, +// - we correctly recover after some "short" requests finish, i.e. we can process new ones. +func TestMaxInFlight(t *testing.T) { + const AllowedInflightRequestsNo = 3 + + // notAccountedPathsRegexp specifies paths requests to which we don't account into + // requests in flight. + notAccountedPathsRegexp := regexp.MustCompile(".*\\/watch") + longRunningRequestCheck := BasicLongRunningRequestCheck(notAccountedPathsRegexp, map[string]string{"watch": "true"}) + + // Calls is used to wait until all server calls are received. We are sending + // AllowedInflightRequestsNo of 'long' not-accounted requests and the same number of + // 'short' accounted ones. + calls := &sync.WaitGroup{} + calls.Add(AllowedInflightRequestsNo * 2) + + // Responses is used to wait until all responses are + // received. This prevents some async requests getting EOF + // errors from prematurely closing the server + responses := sync.WaitGroup{} + responses.Add(AllowedInflightRequestsNo * 2) + + // Block is used to keep requests in flight for as long as we need to. All requests will + // be unblocked at the same time. + block := sync.WaitGroup{} + block.Add(1) + + server := httptest.NewServer( + WithMaxInFlightLimit( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // A short, accounted request that does not wait for block WaitGroup. + if strings.Contains(r.URL.Path, "dontwait") { + return + } + if calls != nil { + calls.Done() + } + block.Wait() + }), + AllowedInflightRequestsNo, + longRunningRequestCheck, + ), + ) + defer server.Close() + + // These should hang, but not affect accounting. use a query param match + for i := 0; i < AllowedInflightRequestsNo; i++ { + // These should hang waiting on block... + go func() { + if err := expectHTTP(server.URL+"/foo/bar?watch=true", http.StatusOK); err != nil { + t.Error(err) + } + responses.Done() + }() + } + // Check that sever is not saturated by not-accounted calls + if err := expectHTTP(server.URL+"/dontwait", http.StatusOK); err != nil { + t.Error(err) + } + + // These should hang and be accounted, i.e. saturate the server + for i := 0; i < AllowedInflightRequestsNo; i++ { + // These should hang waiting on block... + go func() { + if err := expectHTTP(server.URL, http.StatusOK); err != nil { + t.Error(err) + } + responses.Done() + }() + } + // We wait for all calls to be received by the server + calls.Wait() + // Disable calls notifications in the server + calls = nil + + // Do this multiple times to show that it rate limit rejected requests don't block. + for i := 0; i < 2; i++ { + if err := expectHTTP(server.URL, errors.StatusTooManyRequests); err != nil { + t.Error(err) + } + } + // Validate that non-accounted URLs still work. use a path regex match + if err := expectHTTP(server.URL+"/dontwait/watch", http.StatusOK); err != nil { + t.Error(err) + } + + // Let all hanging requests finish + block.Done() + + // Show that we recover from being blocked up. + // Too avoid flakyness we need to wait until at least one of the requests really finishes. + responses.Wait() + if err := expectHTTP(server.URL, http.StatusOK); err != nil { + t.Error(err) + } +} + +func expectHTTP(url string, code int) error { + r, err := http.Get(url) + if err != nil { + return fmt.Errorf("unexpected error: %v", err) + } + if r.StatusCode != code { + return fmt.Errorf("unexpected response: %v", r.StatusCode) + } + return nil +} diff --git a/pkg/genericapiserver/filters/timeout.go b/pkg/genericapiserver/filters/timeout.go new file mode 100644 index 000000000000..4315af95e1e4 --- /dev/null +++ b/pkg/genericapiserver/filters/timeout.go @@ -0,0 +1,259 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "bufio" + "encoding/json" + "fmt" + "net" + "net/http" + "sync" + "time" + + "k8s.io/kubernetes/pkg/api" + "k8s.io/kubernetes/pkg/api/errors" +) + +const globalTimeout = time.Minute + +var errConnKilled = fmt.Errorf("kill connection/stream") + +// WithTimeoutForNonLongRunningRequests times out non-long-running requests after the time given by globalTimeout. +func WithTimeoutForNonLongRunningRequests(handler http.Handler, longRunning LongRunningRequestCheck) http.Handler { + timeoutFunc := func(req *http.Request) (<-chan time.Time, string) { + // TODO unify this with apiserver.MaxInFlightLimit + if longRunning(req) { + return nil, "" + } + return time.After(globalTimeout), "" + } + return WithTimeout(handler, timeoutFunc) +} + +// WithTimeout returns an http.Handler that runs h with a timeout +// determined by timeoutFunc. The new http.Handler calls h.ServeHTTP to handle +// each request, but if a call runs for longer than its time limit, the +// handler responds with a 503 Service Unavailable error and the message +// provided. (If msg is empty, a suitable default message will be sent.) After +// the handler times out, writes by h to its http.ResponseWriter will return +// http.ErrHandlerTimeout. If timeoutFunc returns a nil timeout channel, no +// timeout will be enforced. +func WithTimeout(h http.Handler, timeoutFunc func(*http.Request) (timeout <-chan time.Time, msg string)) http.Handler { + return &timeoutHandler{h, timeoutFunc} +} + +type timeoutHandler struct { + handler http.Handler + timeout func(*http.Request) (<-chan time.Time, string) +} + +func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + after, msg := t.timeout(r) + if after == nil { + t.handler.ServeHTTP(w, r) + return + } + + done := make(chan struct{}) + tw := newTimeoutWriter(w) + go func() { + t.handler.ServeHTTP(tw, r) + close(done) + }() + select { + case <-done: + return + case <-after: + tw.timeout(msg) + } +} + +type timeoutWriter interface { + http.ResponseWriter + timeout(string) +} + +func newTimeoutWriter(w http.ResponseWriter) timeoutWriter { + base := &baseTimeoutWriter{w: w} + + _, notifiable := w.(http.CloseNotifier) + _, hijackable := w.(http.Hijacker) + + switch { + case notifiable && hijackable: + return &closeHijackTimeoutWriter{base} + case notifiable: + return &closeTimeoutWriter{base} + case hijackable: + return &hijackTimeoutWriter{base} + default: + return base + } +} + +type baseTimeoutWriter struct { + w http.ResponseWriter + + mu sync.Mutex + // if the timeout handler has timedout + timedOut bool + // if this timeout writer has wrote header + wroteHeader bool + // if this timeout writer has been hijacked + hijacked bool +} + +func (tw *baseTimeoutWriter) Header() http.Header { + tw.mu.Lock() + defer tw.mu.Unlock() + + if tw.timedOut { + return http.Header{} + } + + return tw.w.Header() +} + +func (tw *baseTimeoutWriter) Write(p []byte) (int, error) { + tw.mu.Lock() + defer tw.mu.Unlock() + + if tw.timedOut { + return 0, http.ErrHandlerTimeout + } + if tw.hijacked { + return 0, http.ErrHijacked + } + + tw.wroteHeader = true + return tw.w.Write(p) +} + +func (tw *baseTimeoutWriter) Flush() { + tw.mu.Lock() + defer tw.mu.Unlock() + + if tw.timedOut { + return + } + + if flusher, ok := tw.w.(http.Flusher); ok { + flusher.Flush() + } +} + +func (tw *baseTimeoutWriter) WriteHeader(code int) { + tw.mu.Lock() + defer tw.mu.Unlock() + + if tw.timedOut || tw.wroteHeader || tw.hijacked { + return + } + + tw.wroteHeader = true + tw.w.WriteHeader(code) +} + +func (tw *baseTimeoutWriter) timeout(msg string) { + tw.mu.Lock() + defer tw.mu.Unlock() + + tw.timedOut = true + + // The timeout writer has not been used by the inner handler. + // We can safely timeout the HTTP request by sending by a timeout + // handler + if !tw.wroteHeader && !tw.hijacked { + tw.w.WriteHeader(http.StatusGatewayTimeout) + if msg != "" { + tw.w.Write([]byte(msg)) + } else { + enc := json.NewEncoder(tw.w) + enc.Encode(errors.NewServerTimeout(api.Resource(""), "", 0)) + } + } else { + // The timeout writer has been used by the inner handler. There is + // no way to timeout the HTTP request at the point. We have to shutdown + // the connection for HTTP1 or reset stream for HTTP2. + // + // Note from: Brad Fitzpatrick + // if the ServeHTTP goroutine panics, that will do the best possible thing for both + // HTTP/1 and HTTP/2. In HTTP/1, assuming you're replying with at least HTTP/1.1 and + // you've already flushed the headers so it's using HTTP chunking, it'll kill the TCP + // connection immediately without a proper 0-byte EOF chunk, so the peer will recognize + // the response as bogus. In HTTP/2 the server will just RST_STREAM the stream, leaving + // the TCP connection open, but resetting the stream to the peer so it'll have an error, + // like the HTTP/1 case. + panic(errConnKilled) + } +} + +func (tw *baseTimeoutWriter) closeNotify() <-chan bool { + tw.mu.Lock() + defer tw.mu.Unlock() + + if tw.timedOut { + done := make(chan bool) + close(done) + return done + } + + return tw.w.(http.CloseNotifier).CloseNotify() +} + +func (tw *baseTimeoutWriter) hijack() (net.Conn, *bufio.ReadWriter, error) { + tw.mu.Lock() + defer tw.mu.Unlock() + + if tw.timedOut { + return nil, nil, http.ErrHandlerTimeout + } + conn, rw, err := tw.w.(http.Hijacker).Hijack() + if err == nil { + tw.hijacked = true + } + return conn, rw, err +} + +type closeTimeoutWriter struct { + *baseTimeoutWriter +} + +func (tw *closeTimeoutWriter) CloseNotify() <-chan bool { + return tw.closeNotify() +} + +type hijackTimeoutWriter struct { + *baseTimeoutWriter +} + +func (tw *hijackTimeoutWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return tw.hijack() +} + +type closeHijackTimeoutWriter struct { + *baseTimeoutWriter +} + +func (tw *closeHijackTimeoutWriter) CloseNotify() <-chan bool { + return tw.closeNotify() +} + +func (tw *closeHijackTimeoutWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return tw.hijack() +} diff --git a/pkg/genericapiserver/filters/timeout_test.go b/pkg/genericapiserver/filters/timeout_test.go new file mode 100644 index 000000000000..989ce331c2fd --- /dev/null +++ b/pkg/genericapiserver/filters/timeout_test.go @@ -0,0 +1,81 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestTimeout(t *testing.T) { + sendResponse := make(chan struct{}, 1) + writeErrors := make(chan error, 1) + timeout := make(chan time.Time, 1) + resp := "test response" + timeoutResp := "test timeout" + + ts := httptest.NewServer(WithTimeout(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + <-sendResponse + _, err := w.Write([]byte(resp)) + writeErrors <- err + }), + func(*http.Request) (<-chan time.Time, string) { + return timeout, timeoutResp + })) + defer ts.Close() + + // No timeouts + sendResponse <- struct{}{} + res, err := http.Get(ts.URL) + if err != nil { + t.Error(err) + } + if res.StatusCode != http.StatusOK { + t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusOK) + } + body, _ := ioutil.ReadAll(res.Body) + if string(body) != resp { + t.Errorf("got body %q; expected %q", string(body), resp) + } + if err := <-writeErrors; err != nil { + t.Errorf("got unexpected Write error on first request: %v", err) + } + + // Times out + timeout <- time.Time{} + res, err = http.Get(ts.URL) + if err != nil { + t.Error(err) + } + if res.StatusCode != http.StatusGatewayTimeout { + t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusServiceUnavailable) + } + body, _ = ioutil.ReadAll(res.Body) + if string(body) != timeoutResp { + t.Errorf("got body %q; expected %q", string(body), timeoutResp) + } + + // Now try to send a response + sendResponse <- struct{}{} + if err := <-writeErrors; err != http.ErrHandlerTimeout { + t.Errorf("got Write error of %v; expected %v", err, http.ErrHandlerTimeout) + } +} diff --git a/pkg/genericapiserver/genericapiserver.go b/pkg/genericapiserver/genericapiserver.go index 4b61f5cc2ae2..521d8bdd7345 100644 --- a/pkg/genericapiserver/genericapiserver.go +++ b/pkg/genericapiserver/genericapiserver.go @@ -19,6 +19,7 @@ package genericapiserver import ( "crypto/tls" "fmt" + "io" "net" "net/http" "path" @@ -31,9 +32,9 @@ import ( systemd "github.com/coreos/go-systemd/daemon" "github.com/emicklei/go-restful" "github.com/emicklei/go-restful/swagger" + "github.com/go-openapi/spec" "github.com/golang/glog" - "github.com/go-openapi/spec" "k8s.io/kubernetes/pkg/admission" "k8s.io/kubernetes/pkg/api" "k8s.io/kubernetes/pkg/api/rest" @@ -51,8 +52,6 @@ import ( "k8s.io/kubernetes/pkg/util/sets" ) -const globalTimeout = time.Minute - // Info about an API group. type APIGroupInfo struct { GroupMeta apimachinery.GroupMeta @@ -159,6 +158,7 @@ type GenericAPIServer struct { enableOpenAPISupport bool openAPIInfo spec.Info openAPIDefaultResponse spec.Response + openAPIDefinitions *common.OpenAPIDefinitions // PostStartHooks are each called after the server has started listening, in a separate go func for each // with no guaranteee of ordering between them. The map key is a name used for error reporting. @@ -166,7 +166,9 @@ type GenericAPIServer struct { postStartHooks map[string]PostStartHookFunc postStartHookLock sync.Mutex postStartHooksCalled bool - openAPIDefinitions *common.OpenAPIDefinitions + + // Writer to write the audit log to. + auditWriter io.Writer } // RequestContextMapper is exposed so that third party resource storage can be build in a different location. @@ -215,32 +217,8 @@ func NewHandlerContainer(mux *http.ServeMux, s runtime.NegotiatedSerializer) *re return container } -// Installs handler at /apis to list all group versions for discovery -func (s *GenericAPIServer) installGroupsDiscoveryHandler() { - apiserver.AddApisWebService(s.Serializer, s.HandlerContainer, s.apiPrefix, func(req *restful.Request) []unversioned.APIGroup { - s.apiGroupsForDiscoveryLock.RLock() - defer s.apiGroupsForDiscoveryLock.RUnlock() - - // Return the list of supported groups in sorted order (to have a deterministic order). - groups := []unversioned.APIGroup{} - groupNames := make([]string, len(s.apiGroupsForDiscovery)) - var i int = 0 - for groupName := range s.apiGroupsForDiscovery { - groupNames[i] = groupName - i++ - } - sort.Strings(groupNames) - for _, groupName := range groupNames { - apiGroup := s.apiGroupsForDiscovery[groupName] - // Add ServerAddressByClientCIDRs. - apiGroup.ServerAddressByClientCIDRs = s.getServerAddressByClientCIDRs(req.Request) - groups = append(groups, apiGroup) - } - return groups - }) -} - func (s *GenericAPIServer) Run(options *options.ServerRunOptions) { + // install APIs which depend on other APIs to be installed if s.enableSwaggerSupport { s.InstallSwaggerAPI() } @@ -415,7 +393,7 @@ func (s *GenericAPIServer) InstallAPIGroup(apiGroupInfo *APIGroupInfo) error { } s.AddAPIGroupForDiscovery(apiGroup) - apiserver.AddGroupWebService(s.Serializer, s.HandlerContainer, apiPrefix+"/"+apiGroup.Name, apiGroup) + s.HandlerContainer.Add(apiserver.NewGroupWebService(s.Serializer, apiPrefix+"/"+apiGroup.Name, apiGroup)) } apiserver.InstallServiceErrorHandler(s.Serializer, s.HandlerContainer, s.NewRequestInfoResolver(), apiVersions) return nil @@ -438,8 +416,7 @@ func (s *GenericAPIServer) RemoveAPIGroupForDiscovery(groupName string) { func (s *GenericAPIServer) getServerAddressByClientCIDRs(req *http.Request) []unversioned.ServerAddressByClientCIDR { addressCIDRMap := []unversioned.ServerAddressByClientCIDR{ { - ClientCIDR: "0.0.0.0/0", - + ClientCIDR: "0.0.0.0/0", ServerAddress: s.ExternalAddress, }, } @@ -556,6 +533,34 @@ func (s *GenericAPIServer) InstallOpenAPI() { } } +// DynamicApisDiscovery returns a webservice serving api group discovery. +// Note: during the server runtime apiGroupsForDiscovery might change. +func (s *GenericAPIServer) DynamicApisDiscovery() *restful.WebService { + return apiserver.NewApisWebService(s.Serializer, s.apiPrefix, func(req *restful.Request) []unversioned.APIGroup { + s.apiGroupsForDiscoveryLock.RLock() + defer s.apiGroupsForDiscoveryLock.RUnlock() + + // sort to have a deterministic order + sortedGroups := []unversioned.APIGroup{} + groupNames := make([]string, 0, len(s.apiGroupsForDiscovery)) + for groupName := range s.apiGroupsForDiscovery { + groupNames = append(groupNames, groupName) + } + sort.Strings(groupNames) + for _, groupName := range groupNames { + sortedGroups = append(sortedGroups, s.apiGroupsForDiscovery[groupName]) + } + + serverCIDR := s.getServerAddressByClientCIDRs(req.Request) + groups := make([]unversioned.APIGroup, len(sortedGroups)) + for i := range sortedGroups { + groups[i] = sortedGroups[i] + groups[i].ServerAddressByClientCIDRs = serverCIDR + } + return groups + }) +} + // NewDefaultAPIGroupInfo returns an APIGroupInfo stubbed with "normal" values // exposed for easier composition from other packages func NewDefaultAPIGroupInfo(group string) APIGroupInfo { diff --git a/pkg/genericapiserver/routes/index.go b/pkg/genericapiserver/routes/index.go index 986024cf3d21..f3280d5b040d 100644 --- a/pkg/genericapiserver/routes/index.go +++ b/pkg/genericapiserver/routes/index.go @@ -26,9 +26,12 @@ import ( "k8s.io/kubernetes/pkg/apiserver" ) +// Index provides a webservice for the http root / listing all known paths. type Index struct{} +// Install adds the Index webservice to the given mux. func (i Index) Install(mux *apiserver.PathRecorderMux, c *restful.Container) { + // do not register this using restful Webservice since we do not want to surface this in api docs. mux.BaseMux().HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { status := http.StatusOK if r.URL.Path != "/" && r.URL.Path != "/index.html" { diff --git a/pkg/genericapiserver/routes/profiling.go b/pkg/genericapiserver/routes/profiling.go index b00e61ca6894..fd3dc369a124 100644 --- a/pkg/genericapiserver/routes/profiling.go +++ b/pkg/genericapiserver/routes/profiling.go @@ -26,6 +26,7 @@ import ( // Profiling adds handlers for pprof under /debug/pprof. type Profiling struct{} +// Install adds the Profiling webservice to the given mux. func (d Profiling) Install(mux *apiserver.PathRecorderMux, c *restful.Container) { mux.BaseMux().HandleFunc("/debug/pprof/", pprof.Index) mux.BaseMux().HandleFunc("/debug/pprof/profile", pprof.Profile) diff --git a/pkg/genericapiserver/routes/swaggerui.go b/pkg/genericapiserver/routes/swaggerui.go index ec529f535029..94349f3cdd32 100644 --- a/pkg/genericapiserver/routes/swaggerui.go +++ b/pkg/genericapiserver/routes/swaggerui.go @@ -29,6 +29,7 @@ import ( // SwaggerUI exposes files in third_party/swagger-ui/ under /swagger-ui. type SwaggerUI struct{} +// Install adds the SwaggerUI webservice to the given mux. func (l SwaggerUI) Install(mux *apiserver.PathRecorderMux, c *restful.Container) { fileServer := http.FileServer(&assetfs.AssetFS{ Asset: swagger.Asset, diff --git a/pkg/genericapiserver/routes/version.go b/pkg/genericapiserver/routes/version.go index 6d0db2be1789..b4ead0f27f03 100644 --- a/pkg/genericapiserver/routes/version.go +++ b/pkg/genericapiserver/routes/version.go @@ -25,9 +25,10 @@ import ( "k8s.io/kubernetes/pkg/version" ) +// Version provides a webservice with version information. type Version struct{} -// InstallVersionHandler registers the APIServer's `/version` handler +// Install registers the APIServer's `/version` handler. func (v Version) Install(mux *apiserver.PathRecorderMux, c *restful.Container) { // Set up a service to return the git code version. versionWS := new(restful.WebService) diff --git a/pkg/master/master.go b/pkg/master/master.go index ff0afd7e0149..aaafe8920eec 100644 --- a/pkg/master/master.go +++ b/pkg/master/master.go @@ -744,7 +744,7 @@ func (m *Master) InstallThirdPartyResource(rsrc *extensions.ThirdPartyResource) if err := thirdparty.InstallREST(m.HandlerContainer); err != nil { glog.Errorf("Unable to setup thirdparty api: %v", err) } - apiserver.AddGroupWebService(api.Codecs, m.HandlerContainer, path, apiGroup) + m.HandlerContainer.Add(apiserver.NewGroupWebService(api.Codecs, path, apiGroup)) m.addThirdPartyResourceStorage(path, plural.Resource, thirdparty.Storage[plural.Resource].(*thirdpartyresourcedataetcd.REST), apiGroup) apiserver.InstallServiceErrorHandler(api.Codecs, m.HandlerContainer, m.NewRequestInfoResolver(), []string{thirdparty.GroupVersion.String()})