Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automated cherry pick of #85410: fix potential memory leak issue in processing watch request #94351

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 2 additions & 3 deletions staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch.go
Expand Up @@ -64,6 +64,8 @@ func (w *realTimeoutFactory) TimeoutCh() (<-chan time.Time, func() bool) {
// serveWatch will serve a watch response.
// TODO: the functionality in this method and in WatchServer.Serve is not cleanly decoupled.
func serveWatch(watcher watch.Interface, scope *RequestScope, mediaTypeOptions negotiation.MediaTypeOptions, req *http.Request, w http.ResponseWriter, timeout time.Duration) {
defer watcher.Stop()

options, err := optionsForTransform(mediaTypeOptions, req)
if err != nil {
scope.err(err, w, req)
Expand Down Expand Up @@ -201,7 +203,6 @@ func (s *WatchServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// ensure the connection times out
timeoutCh, cleanup := s.TimeoutFactory.TimeoutCh()
defer cleanup()
defer s.Watching.Stop()

// begin the stream
w.Header().Set("Content-Type", s.MediaType)
Expand Down Expand Up @@ -286,8 +287,6 @@ func (s *WatchServer) HandleWS(ws *websocket.Conn) {
streamBuf := &bytes.Buffer{}
ch := s.Watching.ResultChan()

defer s.Watching.Stop()

for {
select {
case <-done:
Expand Down
92 changes: 82 additions & 10 deletions staging/src/k8s.io/apiserver/pkg/endpoints/watch_test.go
Expand Up @@ -43,6 +43,7 @@ import (
"k8s.io/apimachinery/pkg/watch"
example "k8s.io/apiserver/pkg/apis/example"
"k8s.io/apiserver/pkg/endpoints/handlers"
"k8s.io/apiserver/pkg/endpoints/handlers/responsewriters"
apitesting "k8s.io/apiserver/pkg/endpoints/testing"
"k8s.io/apiserver/pkg/registry/rest"
"k8s.io/client-go/dynamic"
Expand Down Expand Up @@ -565,6 +566,21 @@ func (t *fakeTimeoutFactory) TimeoutCh() (<-chan time.Time, func() bool) {
}
}

// serveWatch will serve a watch response according to the watcher and watchServer.
// Before watchServer.ServeHTTP, an error may occur like k8s.io/apiserver/pkg/endpoints/handlers/watch.go#serveWatch does.
func serveWatch(watcher watch.Interface, watchServer *handlers.WatchServer, preServeErr error) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
defer watcher.Stop()

if preServeErr != nil {
responsewriters.ErrorNegotiated(preServeErr, watchServer.Scope.Serializer, watchServer.Scope.Kind.GroupVersion(), w, req)
return
}

watchServer.ServeHTTP(w, req)
}
}

func TestWatchHTTPErrors(t *testing.T) {
watcher := watch.NewFake()
timeoutCh := make(chan time.Time)
Expand All @@ -590,9 +606,7 @@ func TestWatchHTTPErrors(t *testing.T) {
TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done},
}

s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
watchServer.ServeHTTP(w, req)
}))
s := httptest.NewServer(serveWatch(watcher, watchServer, nil))
defer s.Close()

// Setup a client
Expand Down Expand Up @@ -629,6 +643,68 @@ func TestWatchHTTPErrors(t *testing.T) {
}
}

func TestWatchHTTPErrorsBeforeServe(t *testing.T) {
watcher := watch.NewFake()
timeoutCh := make(chan time.Time)
done := make(chan struct{})

info, ok := runtime.SerializerInfoForMediaType(codecs.SupportedMediaTypes(), runtime.ContentTypeJSON)
if !ok || info.StreamSerializer == nil {
t.Fatal(info)
}
serializer := info.StreamSerializer

// Setup a new watchserver
watchServer := &handlers.WatchServer{
Scope: &handlers.RequestScope{
Serializer: runtime.NewSimpleNegotiatedSerializer(info),
Kind: testGroupVersion.WithKind("test"),
},
Watching: watcher,

MediaType: "testcase/json",
Framer: serializer.Framer,
Encoder: newCodec,
EmbeddedEncoder: newCodec,

Fixup: func(obj runtime.Object) runtime.Object { return obj },
TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done},
}

errStatus := errors.NewInternalError(fmt.Errorf("we got an error"))

s := httptest.NewServer(serveWatch(watcher, watchServer, errStatus))
defer s.Close()

// Setup a client
dest, _ := url.Parse(s.URL)
dest.Path = "/" + prefix + "/" + newGroupVersion.Group + "/" + newGroupVersion.Version + "/simple"
dest.RawQuery = "watch=true"

req, _ := http.NewRequest("GET", dest.String(), nil)
client := http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

// We had already got an error before watch serve started
decoder := json.NewDecoder(resp.Body)
var status *metav1.Status
err = decoder.Decode(&status)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if status.Kind != "Status" || status.APIVersion != "v1" || status.Code != 500 || status.Status != "Failure" || !strings.Contains(status.Message, "we got an error") {
t.Fatalf("error: %#v", status)
}

// check for leaks
if !watcher.IsStopped() {
t.Errorf("Leaked watcher goruntine after request done")
}
}

func TestWatchHTTPDynamicClientErrors(t *testing.T) {
watcher := watch.NewFake()
timeoutCh := make(chan time.Time)
Expand All @@ -654,9 +730,7 @@ func TestWatchHTTPDynamicClientErrors(t *testing.T) {
TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done},
}

s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
watchServer.ServeHTTP(w, req)
}))
s := httptest.NewServer(serveWatch(watcher, watchServer, nil))
defer s.Close()
defer s.CloseClientConnections()

Expand Down Expand Up @@ -699,9 +773,7 @@ func TestWatchHTTPTimeout(t *testing.T) {
TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done},
}

s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
watchServer.ServeHTTP(w, req)
}))
s := httptest.NewServer(serveWatch(watcher, watchServer, nil))
defer s.Close()

// Setup a client
Expand Down Expand Up @@ -729,7 +801,7 @@ func TestWatchHTTPTimeout(t *testing.T) {
close(timeoutCh)
select {
case <-done:
if !watcher.Stopped {
if !watcher.IsStopped() {
t.Errorf("Leaked watch on timeout")
}
case <-time.After(wait.ForeverTestTimeout):
Expand Down