diff --git a/internal/frontend/badge_test.go b/internal/frontend/badge_test.go
index f90bf7188..def3ef9a9 100644
--- a/internal/frontend/badge_test.go
+++ b/internal/frontend/badge_test.go
@@ -11,7 +11,7 @@ import (
)
func TestBadgeHandler_ServeSVG(t *testing.T) {
- _, handler, _ := newTestServer(t, nil)
+ _, handler, _ := newTestServer(t, nil, nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, httptest.NewRequest("GET", "/badge/net/http", nil))
if got, want := w.Result().Header.Get("Content-Type"), "image/svg+xml"; got != want {
@@ -20,7 +20,7 @@ func TestBadgeHandler_ServeSVG(t *testing.T) {
}
func TestBadgeHandler_ServeBadgeTool(t *testing.T) {
- _, handler, _ := newTestServer(t, nil)
+ _, handler, _ := newTestServer(t, nil, nil)
tests := []struct {
url string
diff --git a/internal/frontend/fetch_test.go b/internal/frontend/fetch_test.go
index 2a1bc7a2a..40237535f 100644
--- a/internal/frontend/fetch_test.go
+++ b/internal/frontend/fetch_test.go
@@ -79,7 +79,7 @@ func TestFetch(t *testing.T) {
},
} {
t.Run(test.name, func(t *testing.T) {
- s, _, teardown := newTestServer(t, testModulesForProxy)
+ s, _, teardown := newTestServer(t, testModulesForProxy, nil)
defer teardown()
ctx, cancel := context.WithTimeout(context.Background(), testFetchTimeout)
@@ -143,7 +143,7 @@ func TestFetchErrors(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), test.fetchTimeout)
defer cancel()
- s, _, teardown := newTestServer(t, testModulesForProxy)
+ s, _, teardown := newTestServer(t, testModulesForProxy, nil)
defer teardown()
got, err := s.fetchAndPoll(ctx, s.getDataSource(ctx), test.modulePath, test.fullPath, test.version)
@@ -180,7 +180,7 @@ func TestFetchPathAlreadyExists(t *testing.T) {
t.Fatal(err)
}
- s, _, teardown := newTestServer(t, testModulesForProxy)
+ s, _, teardown := newTestServer(t, testModulesForProxy, nil)
defer teardown()
got, _ := s.fetchAndPoll(ctx, s.getDataSource(ctx), sample.ModulePath, sample.PackagePath, sample.VersionString)
if got != test.want {
diff --git a/internal/frontend/server_test.go b/internal/frontend/server_test.go
index 46a46bbc2..a0bfaeaf3 100644
--- a/internal/frontend/server_test.go
+++ b/internal/frontend/server_test.go
@@ -6,7 +6,9 @@ package frontend
import (
"context"
+ "errors"
"fmt"
+ "io"
"net/http"
"net/http/httptest"
"os"
@@ -15,6 +17,8 @@ import (
"testing"
"time"
+ "github.com/alicebob/miniredis/v2"
+ "github.com/go-redis/redis/v8"
"github.com/google/safehtml/template"
"github.com/jba/templatecheck"
"golang.org/x/net/html"
@@ -41,7 +45,7 @@ func TestMain(m *testing.M) {
}
func TestHTMLInjection(t *testing.T) {
- _, handler, _ := newTestServer(t, nil)
+ _, handler, _ := newTestServer(t, nil, nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, httptest.NewRequest("GET", "/UHOH", nil))
if strings.Contains(w.Body.String(), "") {
@@ -1168,7 +1172,7 @@ func testServer(t *testing.T, testCases []serverTestCase, experimentNames ...str
if err := testDB.InsertExcludedPrefix(ctx, excludedModulePath, "testuser", "testreason"); err != nil {
t.Fatal(err)
}
- _, handler, _ := newTestServer(t, nil, experimentNames...)
+ _, handler, _ := newTestServer(t, nil, nil, experimentNames...)
experimentsSet := experiment.NewSet(experimentNames...)
@@ -1218,7 +1222,7 @@ func isSubset(subset, set *experiment.Set) bool {
}
func TestServerErrors(t *testing.T) {
- _, handler, _ := newTestServer(t, nil)
+ _, handler, _ := newTestServer(t, nil, nil)
for _, test := range []struct {
name, path string
wantCode int
@@ -1287,7 +1291,13 @@ func TestServer404Redirect(t *testing.T) {
t.Fatal(err)
}
- _, handler, _ := newTestServer(t, nil)
+ rs, err := miniredis.Run()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer rs.Close()
+
+ _, handler, _ := newTestServer(t, nil, redis.NewClient(&redis.Options{Addr: rs.Addr()}))
for _, test := range []struct {
name, path, flash string
@@ -1304,7 +1314,8 @@ func TestServer404Redirect(t *testing.T) {
if w.Code != http.StatusFound {
t.Errorf("%q: got status code = %d, want %d", test.path, w.Code, http.StatusFound)
}
- c := findCookie(cookie.AlternativeModuleFlash, w.Result().Cookies())
+ res := w.Result()
+ c := findCookie(cookie.AlternativeModuleFlash, res.Cookies())
if c == nil && test.flash != "" {
t.Error("got no flash cookie, expected one")
} else if c != nil {
@@ -1313,9 +1324,28 @@ func TestServer404Redirect(t *testing.T) {
t.Fatal(err)
}
if val != test.flash {
- t.Errorf("got cookie value %q, want %q", val, test.flash)
+ t.Fatalf("got cookie value %q, want %q", val, test.flash)
+ }
+ // If we have a cookie, then following the redirect URL with the cookie
+ // should serve a "redirected from" banner.
+ loc := res.Header.Get("Location")
+ r := httptest.NewRequest("GET", loc, nil)
+ r.AddCookie(c)
+ w = httptest.NewRecorder()
+ handler.ServeHTTP(w, r)
+ if err := checkBanner(w.Result().Body, val); err != nil {
+ t.Fatalf("banner: %v", err)
+ }
+ // Visiting the same page again without the cookie should not
+ // display the banner.
+ r = httptest.NewRequest("GET", loc, nil)
+ w = httptest.NewRecorder()
+ handler.ServeHTTP(w, r)
+ if err := checkBanner(w.Result().Body, val); err != errNoBanner {
+ t.Fatalf("banner #2: got %v, want %v", err, errNoBanner)
}
}
+
})
}
}
@@ -1329,6 +1359,24 @@ func findCookie(name string, cookies []*http.Cookie) *http.Cookie {
return nil
}
+var errNoBanner = errors.New("no redirect banner")
+
+func checkBanner(body io.ReadCloser, path string) error {
+ doc, err := html.Parse(body)
+ if err != nil {
+ return err
+ }
+ _ = body.Close()
+
+ if in(".UnitHeader-redirectedFromBanner--none")(doc) == nil {
+ return errNoBanner
+ }
+ if err := in(".UnitHeader-redirectedFromBanner", hasText(path))(doc); err != nil {
+ return err
+ }
+ return nil
+}
+
func mustRequest(urlPath string, t *testing.T) *http.Request {
t.Helper()
r, err := http.NewRequest(http.MethodGet, "http://localhost"+urlPath, nil)
@@ -1390,7 +1438,7 @@ func TestTagRoute(t *testing.T) {
}
}
-func newTestServer(t *testing.T, proxyModules []*proxy.Module, experimentNames ...string) (*Server, http.Handler, func()) {
+func newTestServer(t *testing.T, proxyModules []*proxy.Module, redisClient *redis.Client, experimentNames ...string) (*Server, http.Handler, func()) {
t.Helper()
proxyClient, teardown := proxy.SetupTestClient(t, proxyModules)
sourceClient := source.NewClient(sourceTimeout)
@@ -1413,7 +1461,7 @@ func newTestServer(t *testing.T, proxyModules []*proxy.Module, experimentNames .
t.Fatal(err)
}
mux := http.NewServeMux()
- s.Install(mux.Handle, nil, nil)
+ s.Install(mux.Handle, redisClient, nil)
var exps []*internal.Experiment
for _, n := range experimentNames {
@@ -1424,6 +1472,7 @@ func newTestServer(t *testing.T, proxyModules []*proxy.Module, experimentNames .
t.Fatal(err)
}
mw := middleware.Chain(
+ middleware.RedirectedFrom(),
middleware.Experiment(exp),
middleware.LatestVersions(s.GetLatestInfo))
return s, mw(mux), func() {