diff --git a/internal/frontend/badge_test.go b/internal/frontend/badge_test.go index f90bf7188..def3ef9a9 100644 --- a/internal/frontend/badge_test.go +++ b/internal/frontend/badge_test.go @@ -11,7 +11,7 @@ import ( ) func TestBadgeHandler_ServeSVG(t *testing.T) { - _, handler, _ := newTestServer(t, nil) + _, handler, _ := newTestServer(t, nil, nil) w := httptest.NewRecorder() handler.ServeHTTP(w, httptest.NewRequest("GET", "/badge/net/http", nil)) if got, want := w.Result().Header.Get("Content-Type"), "image/svg+xml"; got != want { @@ -20,7 +20,7 @@ func TestBadgeHandler_ServeSVG(t *testing.T) { } func TestBadgeHandler_ServeBadgeTool(t *testing.T) { - _, handler, _ := newTestServer(t, nil) + _, handler, _ := newTestServer(t, nil, nil) tests := []struct { url string diff --git a/internal/frontend/fetch_test.go b/internal/frontend/fetch_test.go index 2a1bc7a2a..40237535f 100644 --- a/internal/frontend/fetch_test.go +++ b/internal/frontend/fetch_test.go @@ -79,7 +79,7 @@ func TestFetch(t *testing.T) { }, } { t.Run(test.name, func(t *testing.T) { - s, _, teardown := newTestServer(t, testModulesForProxy) + s, _, teardown := newTestServer(t, testModulesForProxy, nil) defer teardown() ctx, cancel := context.WithTimeout(context.Background(), testFetchTimeout) @@ -143,7 +143,7 @@ func TestFetchErrors(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), test.fetchTimeout) defer cancel() - s, _, teardown := newTestServer(t, testModulesForProxy) + s, _, teardown := newTestServer(t, testModulesForProxy, nil) defer teardown() got, err := s.fetchAndPoll(ctx, s.getDataSource(ctx), test.modulePath, test.fullPath, test.version) @@ -180,7 +180,7 @@ func TestFetchPathAlreadyExists(t *testing.T) { t.Fatal(err) } - s, _, teardown := newTestServer(t, testModulesForProxy) + s, _, teardown := newTestServer(t, testModulesForProxy, nil) defer teardown() got, _ := s.fetchAndPoll(ctx, s.getDataSource(ctx), sample.ModulePath, sample.PackagePath, sample.VersionString) if got != test.want { diff --git a/internal/frontend/server_test.go b/internal/frontend/server_test.go index 46a46bbc2..a0bfaeaf3 100644 --- a/internal/frontend/server_test.go +++ b/internal/frontend/server_test.go @@ -6,7 +6,9 @@ package frontend import ( "context" + "errors" "fmt" + "io" "net/http" "net/http/httptest" "os" @@ -15,6 +17,8 @@ import ( "testing" "time" + "github.com/alicebob/miniredis/v2" + "github.com/go-redis/redis/v8" "github.com/google/safehtml/template" "github.com/jba/templatecheck" "golang.org/x/net/html" @@ -41,7 +45,7 @@ func TestMain(m *testing.M) { } func TestHTMLInjection(t *testing.T) { - _, handler, _ := newTestServer(t, nil) + _, handler, _ := newTestServer(t, nil, nil) w := httptest.NewRecorder() handler.ServeHTTP(w, httptest.NewRequest("GET", "/UHOH", nil)) if strings.Contains(w.Body.String(), "") { @@ -1168,7 +1172,7 @@ func testServer(t *testing.T, testCases []serverTestCase, experimentNames ...str if err := testDB.InsertExcludedPrefix(ctx, excludedModulePath, "testuser", "testreason"); err != nil { t.Fatal(err) } - _, handler, _ := newTestServer(t, nil, experimentNames...) + _, handler, _ := newTestServer(t, nil, nil, experimentNames...) experimentsSet := experiment.NewSet(experimentNames...) @@ -1218,7 +1222,7 @@ func isSubset(subset, set *experiment.Set) bool { } func TestServerErrors(t *testing.T) { - _, handler, _ := newTestServer(t, nil) + _, handler, _ := newTestServer(t, nil, nil) for _, test := range []struct { name, path string wantCode int @@ -1287,7 +1291,13 @@ func TestServer404Redirect(t *testing.T) { t.Fatal(err) } - _, handler, _ := newTestServer(t, nil) + rs, err := miniredis.Run() + if err != nil { + t.Fatal(err) + } + defer rs.Close() + + _, handler, _ := newTestServer(t, nil, redis.NewClient(&redis.Options{Addr: rs.Addr()})) for _, test := range []struct { name, path, flash string @@ -1304,7 +1314,8 @@ func TestServer404Redirect(t *testing.T) { if w.Code != http.StatusFound { t.Errorf("%q: got status code = %d, want %d", test.path, w.Code, http.StatusFound) } - c := findCookie(cookie.AlternativeModuleFlash, w.Result().Cookies()) + res := w.Result() + c := findCookie(cookie.AlternativeModuleFlash, res.Cookies()) if c == nil && test.flash != "" { t.Error("got no flash cookie, expected one") } else if c != nil { @@ -1313,9 +1324,28 @@ func TestServer404Redirect(t *testing.T) { t.Fatal(err) } if val != test.flash { - t.Errorf("got cookie value %q, want %q", val, test.flash) + t.Fatalf("got cookie value %q, want %q", val, test.flash) + } + // If we have a cookie, then following the redirect URL with the cookie + // should serve a "redirected from" banner. + loc := res.Header.Get("Location") + r := httptest.NewRequest("GET", loc, nil) + r.AddCookie(c) + w = httptest.NewRecorder() + handler.ServeHTTP(w, r) + if err := checkBanner(w.Result().Body, val); err != nil { + t.Fatalf("banner: %v", err) + } + // Visiting the same page again without the cookie should not + // display the banner. + r = httptest.NewRequest("GET", loc, nil) + w = httptest.NewRecorder() + handler.ServeHTTP(w, r) + if err := checkBanner(w.Result().Body, val); err != errNoBanner { + t.Fatalf("banner #2: got %v, want %v", err, errNoBanner) } } + }) } } @@ -1329,6 +1359,24 @@ func findCookie(name string, cookies []*http.Cookie) *http.Cookie { return nil } +var errNoBanner = errors.New("no redirect banner") + +func checkBanner(body io.ReadCloser, path string) error { + doc, err := html.Parse(body) + if err != nil { + return err + } + _ = body.Close() + + if in(".UnitHeader-redirectedFromBanner--none")(doc) == nil { + return errNoBanner + } + if err := in(".UnitHeader-redirectedFromBanner", hasText(path))(doc); err != nil { + return err + } + return nil +} + func mustRequest(urlPath string, t *testing.T) *http.Request { t.Helper() r, err := http.NewRequest(http.MethodGet, "http://localhost"+urlPath, nil) @@ -1390,7 +1438,7 @@ func TestTagRoute(t *testing.T) { } } -func newTestServer(t *testing.T, proxyModules []*proxy.Module, experimentNames ...string) (*Server, http.Handler, func()) { +func newTestServer(t *testing.T, proxyModules []*proxy.Module, redisClient *redis.Client, experimentNames ...string) (*Server, http.Handler, func()) { t.Helper() proxyClient, teardown := proxy.SetupTestClient(t, proxyModules) sourceClient := source.NewClient(sourceTimeout) @@ -1413,7 +1461,7 @@ func newTestServer(t *testing.T, proxyModules []*proxy.Module, experimentNames . t.Fatal(err) } mux := http.NewServeMux() - s.Install(mux.Handle, nil, nil) + s.Install(mux.Handle, redisClient, nil) var exps []*internal.Experiment for _, n := range experimentNames { @@ -1424,6 +1472,7 @@ func newTestServer(t *testing.T, proxyModules []*proxy.Module, experimentNames . t.Fatal(err) } mw := middleware.Chain( + middleware.RedirectedFrom(), middleware.Experiment(exp), middleware.LatestVersions(s.GetLatestInfo)) return s, mw(mux), func() {