From 91d2a4b932d52742bd3d15a69f460bcf53cb115d Mon Sep 17 00:00:00 2001 From: wangjianyu Date: Tue, 19 Jul 2022 21:38:32 +0800 Subject: [PATCH] koordlet: optimize auditor UT with httptest.Server (#382) Signed-off-by: wangjianyu --- pkg/koordlet/audit/auditor.go | 2 +- pkg/koordlet/audit/auditor_test.go | 78 ++++++++---------------------- 2 files changed, 20 insertions(+), 60 deletions(-) diff --git a/pkg/koordlet/audit/auditor.go b/pkg/koordlet/audit/auditor.go index cb5bd1e41..0192d101d 100644 --- a/pkg/koordlet/audit/auditor.go +++ b/pkg/koordlet/audit/auditor.go @@ -159,7 +159,7 @@ func (a *auditor) HttpHandler() func(http.ResponseWriter, *http.Request) { } else { activeReader = a.findActiveReader(pageToken) if activeReader == nil { - http.Error(rw, fmt.Sprintf("reader %v is existed", pageToken), http.StatusConflict) + http.Error(rw, fmt.Sprintf("invalid pageToken %s", pageToken), http.StatusConflict) return } } diff --git a/pkg/koordlet/audit/auditor_test.go b/pkg/koordlet/audit/auditor_test.go index 7562f1496..7f18d05cb 100644 --- a/pkg/koordlet/audit/auditor_test.go +++ b/pkg/koordlet/audit/auditor_test.go @@ -18,50 +18,23 @@ package audit import ( "bytes" - "context" "encoding/json" "fmt" "io/ioutil" - "net" "net/http" + "net/http/httptest" "testing" "time" ) -type TestServer struct { - l net.Listener - server *http.Server -} - -func (t *TestServer) Serve() { - t.server.Serve(t.l) -} - -func (t *TestServer) Shutdown() error { - t.l.Close() - return t.server.Shutdown(context.TODO()) -} - -func (t *TestServer) URL(size int, pageToken string) string { - url := fmt.Sprintf("http://127.0.0.1:%d?size=%d", t.l.Addr().(*net.TCPAddr).Port, size) +func makeRequestUrl(size int, serverUrl, pageToken string) string { + url := fmt.Sprintf("%s?size=%d", serverUrl, size) if pageToken != "" { url += fmt.Sprintf("&pageToken=%s", pageToken) } return url } -func mustCreateHttpServer(t *testing.T, handler http.Handler) *TestServer { - l, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatal(err) - } - server := &http.Server{Handler: handler} - return &TestServer{ - l: l, - server: server, - } -} - func TestAuditorLogger(t *testing.T) { tempDir := t.TempDir() @@ -76,14 +49,11 @@ func TestAuditorLogger(t *testing.T) { } logger.Flush() - server := mustCreateHttpServer(t, http.HandlerFunc(auditor.HttpHandler())) - defer server.Shutdown() - go func() { - server.Serve() - }() + server := httptest.NewServer(http.HandlerFunc(auditor.HttpHandler())) + defer server.Close() client := http.Client{} - req, _ := http.NewRequest("GET", server.URL(10, ""), nil) + req, _ := http.NewRequest("GET", makeRequestUrl(10, server.URL, ""), nil) req.Header.Add("Accept", "application/json") resp, err := client.Do(req) @@ -103,7 +73,7 @@ func TestAuditorLogger(t *testing.T) { } // continue read logs - req, _ = http.NewRequest("GET", server.URL(1, response.NextPageToken), nil) + req, _ = http.NewRequest("GET", makeRequestUrl(1, server.URL, response.NextPageToken), nil) req.Header.Add("Accept", "application/json") resp, err = client.Do(req) if err != nil { @@ -127,7 +97,7 @@ func TestAuditorLogger(t *testing.T) { count := 0 stepSize := 5 for { - req, _ = http.NewRequest("GET", server.URL(stepSize, response.NextPageToken), nil) + req, _ = http.NewRequest("GET", makeRequestUrl(stepSize, server.URL, response.NextPageToken), nil) req.Header.Add("Accept", "application/json") resp, err = client.Do(req) if err != nil { @@ -149,7 +119,6 @@ func TestAuditorLogger(t *testing.T) { t.Errorf("failed to read to the end, expected %v actual %v", len(blocks)-11, count) } }() - } func TestAuditorLoggerTxtOutput(t *testing.T) { @@ -166,14 +135,11 @@ func TestAuditorLoggerTxtOutput(t *testing.T) { } logger.Flush() - server := mustCreateHttpServer(t, http.HandlerFunc(auditor.HttpHandler())) - defer server.Shutdown() - go func() { - server.Serve() - }() + server := httptest.NewServer(http.HandlerFunc(auditor.HttpHandler())) + defer server.Close() client := http.Client{} - req, _ := http.NewRequest("GET", server.URL(10, ""), nil) + req, _ := http.NewRequest("GET", makeRequestUrl(10, server.URL, ""), nil) resp, err := client.Do(req) if err != nil { t.Fatalf("failed to get events: %v", err) @@ -207,14 +173,11 @@ func TestAuditorLoggerReaderInvalidPageToken(t *testing.T) { } logger.Flush() - server := mustCreateHttpServer(t, http.HandlerFunc(auditor.HttpHandler())) - defer server.Shutdown() - go func() { - server.Serve() - }() + server := httptest.NewServer(http.HandlerFunc(auditor.HttpHandler())) + defer server.Close() client := http.Client{} - req, _ := http.NewRequest("GET", server.URL(10, ""), nil) + req, _ := http.NewRequest("GET", makeRequestUrl(10, server.URL, ""), nil) resp, err := client.Do(req) if err != nil { t.Fatalf("failed to get events: %v", err) @@ -238,7 +201,7 @@ func TestAuditorLoggerReaderInvalidPageToken(t *testing.T) { time.Sleep(time.Second) // request with expired token - req, _ = http.NewRequest("GET", server.URL(10, nextPageTokens[0]), nil) + req, _ = http.NewRequest("GET", makeRequestUrl(10, server.URL, nextPageTokens[0]), nil) resp, err = client.Do(req) if err != nil { t.Fatalf("failed to get events: %v", err) @@ -248,7 +211,7 @@ func TestAuditorLoggerReaderInvalidPageToken(t *testing.T) { } // request with not exists token - req, _ = http.NewRequest("GET", server.URL(10, "not-exists-token"), nil) + req, _ = http.NewRequest("GET", makeRequestUrl(10, server.URL, "not-exists-token"), nil) resp, err = client.Do(req) if err != nil { t.Fatalf("failed to get events: %v", err) @@ -272,16 +235,13 @@ func TestAuditorLoggerMaxActiveReaders(t *testing.T) { } logger.Flush() - server := mustCreateHttpServer(t, http.HandlerFunc(ad.HttpHandler())) - defer server.Shutdown() - go func() { - server.Serve() - }() + server := httptest.NewServer(http.HandlerFunc(ad.HttpHandler())) + defer server.Close() client := http.Client{} for i := 0; i < c.MaxConcurrentReaders+5; i++ { - req, _ := http.NewRequest("GET", server.URL(10, ""), nil) + req, _ := http.NewRequest("GET", makeRequestUrl(10, server.URL, ""), nil) resp, err := client.Do(req) if err != nil { t.Fatalf("failed to get events: %v", err)