Skip to content

Commit

Permalink
AntiReplay function should not take a pointer to request (#460)
Browse files Browse the repository at this point in the history
- Doing is so, would enable people to mutate that request.

- Somehow tangentially related; https://jub0bs.com/posts/2023-02-08-fearless-cors/#do-not-support-custom-callbacks
  • Loading branch information
komuw committed Jun 18, 2024
1 parent 847e9b2 commit b711fc1
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Most recent version is listed first.
# v0.1.2
- ong/middleware: Fix a number of CORS issues: https://github.com/komuw/ong/pull/442
- ong/middleware: Eliminate panics: https://github.com/komuw/ong/pull/459
- ong/middleware: AntiReplay function should not take a pointer to request: https://github.com/komuw/ong/pull/460

# v0.1.1
- ong/middleware: do not show hint: https://github.com/komuw/ong/pull/457
Expand Down
10 changes: 5 additions & 5 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ const (
// It is a no-op.
//
// [replay attacks]: https://en.wikipedia.org/wiki/Replay_attack
func DefaultSessionAntiReplayFunc(r *http.Request) string { return "" }
func DefaultSessionAntiReplayFunc(r http.Request) string { return "" }

// ClientIPstrategy is a middleware option that describes the strategy to use when fetching the client's IP address.
//
Expand Down Expand Up @@ -274,7 +274,7 @@ func New(
corsCacheDuration time.Duration,
csrfTokenDuration time.Duration,
sessionCookieDuration time.Duration,
sessionAntiReplayFunc func(r *http.Request) string,
sessionAntiReplayFunc func(r http.Request) string,
// server
maxBodyBytes uint64,
serverLogLevel slog.Level,
Expand Down Expand Up @@ -663,7 +663,7 @@ type middlewareOpts struct {

// session
SessionCookieDuration time.Duration
SessionAntiReplayFunc func(r *http.Request) string
SessionAntiReplayFunc func(r http.Request) string // Does NOT take a pointer to http.Request for security reasons.
}

// String implements [fmt.Stringer]
Expand Down Expand Up @@ -732,7 +732,7 @@ func newMiddlewareOpts(
corsCacheDuration time.Duration,
csrfTokenDuration time.Duration,
sessionCookieDuration time.Duration,
sessionAntiReplayFunc func(r *http.Request) string,
sessionAntiReplayFunc func(r http.Request) string,
) (middlewareOpts, error) {
if err := acme.Validate(domain); err != nil {
return middlewareOpts{}, err
Expand Down Expand Up @@ -1096,7 +1096,7 @@ func (o Opts) Equal(other Opts) bool {
if o.SessionCookieDuration != other.SessionCookieDuration {
return false
}
if o.SessionAntiReplayFunc(&http.Request{}) != other.SessionAntiReplayFunc(&http.Request{}) {
if o.SessionAntiReplayFunc(http.Request{}) != other.SessionAntiReplayFunc(http.Request{}) {
return false
}
}
Expand Down
2 changes: 1 addition & 1 deletion config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func validOpts(t *testing.T) Opts {
// Expire session cookie after 6hours.
6*time.Hour,
// Use a given header to try and mitigate against replay-attacks.
func(r *http.Request) string { return r.Header.Get("Anti-Replay") },
func(r http.Request) string { return r.Header.Get("Anti-Replay") },
//
// The maximum size in bytes for incoming request bodies.
2*1024*1024,
Expand Down
2 changes: 1 addition & 1 deletion config/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func ExampleNew() {
// Expire session cookie after 6hours.
6*time.Hour,
// Use a given header to try and mitigate against replay-attacks.
func(r *http.Request) string { return r.Header.Get("Anti-Replay") },
func(r http.Request) string { return r.Header.Get("Anti-Replay") },
//
// The maximum size in bytes for incoming request bodies.
2*1024*1024,
Expand Down
4 changes: 2 additions & 2 deletions middleware/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func session(
secretKey string,
domain string,
sessionCookieDuration time.Duration,
antiReplay func(r *http.Request) string,
antiReplay func(r http.Request) string,
) http.HandlerFunc {
if sessionCookieDuration < 1*time.Second { // It is measured in seconds.
sessionCookieDuration = config.DefaultSessionCookieDuration
Expand All @@ -31,7 +31,7 @@ func session(
// 1. Set anti replay data.
// 2. Read from cookies and check for session cookie.
// 3. Get that cookie and save it to r.context
r = sess.Initialise(r, secretKey, antiReplay(r))
r = sess.Initialise(r, secretKey, antiReplay(*r))

srw := newSessRW(w, r, domain, secretKey, sessionCookieDuration)

Expand Down
16 changes: 8 additions & 8 deletions middleware/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func TestSession(t *testing.T) {
secretKey,
domain,
config.DefaultSessionCookieDuration,
func(r *http.Request) string { return r.RemoteAddr },
func(r http.Request) string { return r.RemoteAddr },
)

rec := httptest.NewRecorder()
Expand All @@ -105,7 +105,7 @@ func TestSession(t *testing.T) {
value := "John Doe"

header := "Anti-Replay"
antiReplayFunc := func(r *http.Request) string {
antiReplayFunc := func(r http.Request) string {
return r.Header.Get(header)
}
wrappedHandler := session(
Expand Down Expand Up @@ -150,7 +150,7 @@ func TestSession(t *testing.T) {

// very important to do this assignment, since `antiReplayFunc` checks for IP mismatch.
req2.Header.Add(header, headerVal)
req2 = cookie.SetAntiReplay(req2, antiReplayFunc(req2))
req2 = cookie.SetAntiReplay(req2, antiReplayFunc(*req2))
attest.Ok(t, err)
req2.AddCookie(&http.Cookie{
Name: res.Cookies()[0].Name,
Expand All @@ -175,7 +175,7 @@ func TestSession(t *testing.T) {
secretKey,
domain,
config.DefaultSessionCookieDuration,
func(r *http.Request) string { return r.RemoteAddr },
func(r http.Request) string { return r.RemoteAddr },
)

rec := httptest.NewRecorder()
Expand Down Expand Up @@ -205,7 +205,7 @@ func TestSession(t *testing.T) {
key := "name"
value := "John Doe"

antiReplayFunc := func(r *http.Request) string { return r.RemoteAddr }
antiReplayFunc := func(r http.Request) string { return r.RemoteAddr }
wrappedHandler := session(
someSessionHandler(msg, key, value),
secretKey,
Expand Down Expand Up @@ -233,7 +233,7 @@ func TestSession(t *testing.T) {
req2 := httptest.NewRequest(http.MethodGet, "/hey-uri", nil)
// very important to do this assignment, since `antiReplayFunc` checks for IP mismatch.
req2.RemoteAddr = ip1
req2 = cookie.SetAntiReplay(req2, antiReplayFunc(req2))
req2 = cookie.SetAntiReplay(req2, antiReplayFunc(*req2))
req2.AddCookie(&http.Cookie{
Name: res.Cookies()[0].Name,
Value: res.Cookies()[0].Value,
Expand All @@ -249,7 +249,7 @@ func TestSession(t *testing.T) {
req3 := httptest.NewRequest(http.MethodGet, "/hey-uri", nil)
ip2 := "148.65.4.3"
req3.RemoteAddr = ip2
req3 = cookie.SetAntiReplay(req3, antiReplayFunc(req3))
req3 = cookie.SetAntiReplay(req3, antiReplayFunc(*req3))
req3.AddCookie(&http.Cookie{
Name: res.Cookies()[0].Name,
Value: res.Cookies()[0].Value,
Expand All @@ -275,7 +275,7 @@ func TestSession(t *testing.T) {
secretKey,
domain,
config.DefaultSessionCookieDuration,
func(r *http.Request) string { return r.RemoteAddr },
func(r http.Request) string { return r.RemoteAddr },
)

runhandler := func() {
Expand Down

0 comments on commit b711fc1

Please sign in to comment.