diff --git a/CHANGELOG.md b/CHANGELOG.md index 228fa75d..0d4d1d0d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ Most recent version is listed first. # v0.1.2 - ong/middleware: Fix a number of CORS issues: https://github.com/komuw/ong/pull/442 - ong/middleware: Eliminate panics: https://github.com/komuw/ong/pull/459 +- ong/middleware: AntiReplay function should not take a pointer to request: https://github.com/komuw/ong/pull/460 # v0.1.1 - ong/middleware: do not show hint: https://github.com/komuw/ong/pull/457 diff --git a/config/config.go b/config/config.go index 457a10c7..4941faf7 100644 --- a/config/config.go +++ b/config/config.go @@ -85,7 +85,7 @@ const ( // It is a no-op. // // [replay attacks]: https://en.wikipedia.org/wiki/Replay_attack -func DefaultSessionAntiReplayFunc(r *http.Request) string { return "" } +func DefaultSessionAntiReplayFunc(r http.Request) string { return "" } // ClientIPstrategy is a middleware option that describes the strategy to use when fetching the client's IP address. // @@ -274,7 +274,7 @@ func New( corsCacheDuration time.Duration, csrfTokenDuration time.Duration, sessionCookieDuration time.Duration, - sessionAntiReplayFunc func(r *http.Request) string, + sessionAntiReplayFunc func(r http.Request) string, // server maxBodyBytes uint64, serverLogLevel slog.Level, @@ -663,7 +663,7 @@ type middlewareOpts struct { // session SessionCookieDuration time.Duration - SessionAntiReplayFunc func(r *http.Request) string + SessionAntiReplayFunc func(r http.Request) string // Does NOT take a pointer to http.Request for security reasons. } // String implements [fmt.Stringer] @@ -732,7 +732,7 @@ func newMiddlewareOpts( corsCacheDuration time.Duration, csrfTokenDuration time.Duration, sessionCookieDuration time.Duration, - sessionAntiReplayFunc func(r *http.Request) string, + sessionAntiReplayFunc func(r http.Request) string, ) (middlewareOpts, error) { if err := acme.Validate(domain); err != nil { return middlewareOpts{}, err @@ -1096,7 +1096,7 @@ func (o Opts) Equal(other Opts) bool { if o.SessionCookieDuration != other.SessionCookieDuration { return false } - if o.SessionAntiReplayFunc(&http.Request{}) != other.SessionAntiReplayFunc(&http.Request{}) { + if o.SessionAntiReplayFunc(http.Request{}) != other.SessionAntiReplayFunc(http.Request{}) { return false } } diff --git a/config/config_test.go b/config/config_test.go index 2b2dba88..a69425ce 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -60,7 +60,7 @@ func validOpts(t *testing.T) Opts { // Expire session cookie after 6hours. 6*time.Hour, // Use a given header to try and mitigate against replay-attacks. - func(r *http.Request) string { return r.Header.Get("Anti-Replay") }, + func(r http.Request) string { return r.Header.Get("Anti-Replay") }, // // The maximum size in bytes for incoming request bodies. 2*1024*1024, diff --git a/config/example_test.go b/config/example_test.go index a13fd33c..4acaf8c1 100644 --- a/config/example_test.go +++ b/config/example_test.go @@ -57,7 +57,7 @@ func ExampleNew() { // Expire session cookie after 6hours. 6*time.Hour, // Use a given header to try and mitigate against replay-attacks. - func(r *http.Request) string { return r.Header.Get("Anti-Replay") }, + func(r http.Request) string { return r.Header.Get("Anti-Replay") }, // // The maximum size in bytes for incoming request bodies. 2*1024*1024, diff --git a/middleware/session.go b/middleware/session.go index 0516e31b..75c6a908 100644 --- a/middleware/session.go +++ b/middleware/session.go @@ -21,7 +21,7 @@ func session( secretKey string, domain string, sessionCookieDuration time.Duration, - antiReplay func(r *http.Request) string, + antiReplay func(r http.Request) string, ) http.HandlerFunc { if sessionCookieDuration < 1*time.Second { // It is measured in seconds. sessionCookieDuration = config.DefaultSessionCookieDuration @@ -31,7 +31,7 @@ func session( // 1. Set anti replay data. // 2. Read from cookies and check for session cookie. // 3. Get that cookie and save it to r.context - r = sess.Initialise(r, secretKey, antiReplay(r)) + r = sess.Initialise(r, secretKey, antiReplay(*r)) srw := newSessRW(w, r, domain, secretKey, sessionCookieDuration) diff --git a/middleware/session_test.go b/middleware/session_test.go index 356a806a..d5cb49ca 100644 --- a/middleware/session_test.go +++ b/middleware/session_test.go @@ -78,7 +78,7 @@ func TestSession(t *testing.T) { secretKey, domain, config.DefaultSessionCookieDuration, - func(r *http.Request) string { return r.RemoteAddr }, + func(r http.Request) string { return r.RemoteAddr }, ) rec := httptest.NewRecorder() @@ -105,7 +105,7 @@ func TestSession(t *testing.T) { value := "John Doe" header := "Anti-Replay" - antiReplayFunc := func(r *http.Request) string { + antiReplayFunc := func(r http.Request) string { return r.Header.Get(header) } wrappedHandler := session( @@ -150,7 +150,7 @@ func TestSession(t *testing.T) { // very important to do this assignment, since `antiReplayFunc` checks for IP mismatch. req2.Header.Add(header, headerVal) - req2 = cookie.SetAntiReplay(req2, antiReplayFunc(req2)) + req2 = cookie.SetAntiReplay(req2, antiReplayFunc(*req2)) attest.Ok(t, err) req2.AddCookie(&http.Cookie{ Name: res.Cookies()[0].Name, @@ -175,7 +175,7 @@ func TestSession(t *testing.T) { secretKey, domain, config.DefaultSessionCookieDuration, - func(r *http.Request) string { return r.RemoteAddr }, + func(r http.Request) string { return r.RemoteAddr }, ) rec := httptest.NewRecorder() @@ -205,7 +205,7 @@ func TestSession(t *testing.T) { key := "name" value := "John Doe" - antiReplayFunc := func(r *http.Request) string { return r.RemoteAddr } + antiReplayFunc := func(r http.Request) string { return r.RemoteAddr } wrappedHandler := session( someSessionHandler(msg, key, value), secretKey, @@ -233,7 +233,7 @@ func TestSession(t *testing.T) { req2 := httptest.NewRequest(http.MethodGet, "/hey-uri", nil) // very important to do this assignment, since `antiReplayFunc` checks for IP mismatch. req2.RemoteAddr = ip1 - req2 = cookie.SetAntiReplay(req2, antiReplayFunc(req2)) + req2 = cookie.SetAntiReplay(req2, antiReplayFunc(*req2)) req2.AddCookie(&http.Cookie{ Name: res.Cookies()[0].Name, Value: res.Cookies()[0].Value, @@ -249,7 +249,7 @@ func TestSession(t *testing.T) { req3 := httptest.NewRequest(http.MethodGet, "/hey-uri", nil) ip2 := "148.65.4.3" req3.RemoteAddr = ip2 - req3 = cookie.SetAntiReplay(req3, antiReplayFunc(req3)) + req3 = cookie.SetAntiReplay(req3, antiReplayFunc(*req3)) req3.AddCookie(&http.Cookie{ Name: res.Cookies()[0].Name, Value: res.Cookies()[0].Value, @@ -275,7 +275,7 @@ func TestSession(t *testing.T) { secretKey, domain, config.DefaultSessionCookieDuration, - func(r *http.Request) string { return r.RemoteAddr }, + func(r http.Request) string { return r.RemoteAddr }, ) runhandler := func() {