Skip to content

Commit

Permalink
Refactored the XSRF plugin to only validate the token in Before and g…
Browse files Browse the repository at this point in the history
…enerate the protection in the Commit stage
  • Loading branch information
Mara Mihali committed Oct 20, 2020
1 parent 0f92e44 commit f6479e6
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 209 deletions.
75 changes: 36 additions & 39 deletions safehttp/plugins/xsrf/html/xsrf.go
Expand Up @@ -88,67 +88,64 @@ func addCookieID(w *safehttp.ResponseWriter) (*safehttp.Cookie, error) {
// interceptor checks for the presence of the XSRF token in the request body
// (expected to have been injected) and validates it.
func (it *Interceptor) Before(w *safehttp.ResponseWriter, r *safehttp.IncomingRequest, _ safehttp.InterceptorConfig) safehttp.Result {
needsValidation := !xsrf.StatePreserving(r)
if xsrf.StatePreserving(r) {
return safehttp.NotWritten()
}

cookieID, err := r.Cookie(cookieIDKey)
if err != nil {
if needsValidation {
return w.WriteError(safehttp.StatusForbidden)
}
cookieID, err = addCookieID(w)
if err != nil {
// An error is returned when the plugin fails to Set the Set-Cookie
// header in the response writer as this is a server misconfiguration.
return w.WriteError(safehttp.StatusInternalServerError)
}
return w.WriteError(safehttp.StatusForbidden)
}

actionID := r.URL.Path()
if needsValidation {
f, err := r.PostForm()
f, err := r.PostForm()
if err != nil {
// We fallback to checking whether the form is multipart. Both types
// are valid in an incoming request as long as the XSRF token is
// present.
mf, err := r.MultipartForm(32 << 20)
if err != nil {
// We fallback to checking whether the form is multipart. Both types
// are valid in an incoming request as long as the XSRF token is
// present.
mf, err := r.MultipartForm(32 << 20)
if err != nil {
return w.WriteError(safehttp.StatusBadRequest)
}
f = &mf.Form
return w.WriteError(safehttp.StatusBadRequest)
}
f = &mf.Form
}

tok := f.String(TokenKey, "")
if f.Err() != nil || tok == "" {
return w.WriteError(safehttp.StatusUnauthorized)
}
tok := f.String(TokenKey, "")
if f.Err() != nil || tok == "" {
return w.WriteError(safehttp.StatusUnauthorized)
}

if ok := xsrftoken.Valid(tok, it.SecretAppKey, cookieID.Value(), actionID); !ok {
return w.WriteError(safehttp.StatusForbidden)
}
if ok := xsrftoken.Valid(tok, it.SecretAppKey, cookieID.Value(), r.URL.Path()); !ok {
return w.WriteError(safehttp.StatusForbidden)
}

tok := xsrftoken.Generate(it.SecretAppKey, cookieID.Value(), actionID)
r.SetContext(context.WithValue(r.Context(), tokenCtxKey{}, tok))
return safehttp.NotWritten()
}

// Commit adds the XSRF token corresponding to the safehttp.TemplateResponse
// with key "XSRFToken". The token corresponds to the user information found in
// the request.
func (it *Interceptor) Commit(w *safehttp.ResponseWriter, r *safehttp.IncomingRequest, resp safehttp.Response, _ safehttp.InterceptorConfig) safehttp.Result {
cookieID, err := addCookieID(w)
if err != nil {
if xsrf.StatePreserving(r) {
// This should never happen as, if this is a state-changing request and it lacks the cookie, it would've been already rejected by Before.
return w.WriteError(safehttp.StatusInternalServerError)
}
cookieID, err = addCookieID(w)
if err != nil {
// This is a server misconfiguration.
return w.WriteError(safehttp.StatusInternalServerError)
}
}

tok := xsrftoken.Generate(it.SecretAppKey, cookieID.Value(), r.URL.Path())
r.SetContext(context.WithValue(r.Context(), tokenCtxKey{}, tok))

tmplResp, ok := resp.(safehttp.TemplateResponse)
if !ok {
return safehttp.NotWritten()
}

tok, err := Token(r)
if err != nil {
// The token should have been added in the Before stage and if that is
// not the case, a server misconfiguration occured.
return w.WriteError(safehttp.StatusInternalServerError)
}

// TODO(maramihali@): Change the key when function names are exported by
// htmlinject
// TODO: what should happen if the XSRFToken key is not present in the
// tr.FuncMap?
tmplResp.FuncMap[htmlinject.XSRFTokensDefaultFuncName] = func() string { return tok }
Expand Down
241 changes: 118 additions & 123 deletions safehttp/plugins/xsrf/html/xsrf_test.go
Expand Up @@ -15,7 +15,6 @@
package xsrfhtml

import (
"context"
"github.com/google/go-cmp/cmp"
"github.com/google/go-safeweb/safehttp"
"github.com/google/go-safeweb/safehttp/safehttptest"
Expand Down Expand Up @@ -216,66 +215,66 @@ func TestMissingTokenInBody(t *testing.T) {
}
}

func TestBeforeTokenInRequestContext(t *testing.T) {
rec := safehttptest.NewResponseRecorder()
req := safehttptest.NewRequest(safehttp.MethodGet, "https://foo.com/pizza", nil)
req.Header.Set("Cookie", cookieIDKey+"=abcdef")

i := Interceptor{SecretAppKey: "testSecretAppKey"}
i.Before(rec.ResponseWriter, req, nil)

tok, err := Token(req)
if tok == "" {
t.Error(`Token(req): got "", want token`)
}
if err != nil {
t.Errorf("Token(req): got %v, want nil", err)
}

if want, got := safehttp.StatusOK, rec.Status(); want != got {
t.Errorf("rec.Status(): got %v, want %v", got, want)
}
if diff := cmp.Diff(map[string][]string{}, map[string][]string(rec.Header())); diff != "" {
t.Errorf("rec.Header() mismatch (-want +got):\n%s", diff)
}
if want, got := "", rec.Body(); got != want {
t.Errorf("rec.Body(): got %q want %q", got, want)
}

}

func TestTokenInRequestContext(t *testing.T) {
req := safehttptest.NewRequest(safehttp.MethodGet, "/", nil)
req.SetContext(context.WithValue(req.Context(), tokenCtxKey{}, "pizza"))

got, err := Token(req)
if want := "pizza"; want != got {
t.Errorf("Token(req): got %v, want %v", got, want)
}
if err != nil {
t.Errorf("Token(req): got %v, want nil", err)
}
}

func TestMissingTokenInRequestContext(t *testing.T) {
req := safehttptest.NewRequest(safehttp.MethodGet, "/", nil)
req.SetContext(context.Background())

got, err := Token(req)
if want := ""; want != got {
t.Errorf("Token(req): got %v, want %v", got, want)
}
if err == nil {
t.Error("Token(req): got nil, want error")
}
}
// func TestBeforeTokenInRequestContext(t *testing.T) {
// rec := safehttptest.NewResponseRecorder()
// req := safehttptest.NewRequest(safehttp.MethodGet, "https://foo.com/pizza", nil)
// req.Header.Set("Cookie", cookieIDKey+"=abcdef")

// i := Interceptor{SecretAppKey: "testSecretAppKey"}
// i.Before(rec.ResponseWriter, req, nil)

// tok, err := Token(req)
// if tok == "" {
// t.Error(`Token(req): got "", want token`)
// }
// if err != nil {
// t.Errorf("Token(req): got %v, want nil", err)
// }

// if want, got := safehttp.StatusOK, rec.Status(); want != got {
// t.Errorf("rec.Status(): got %v, want %v", got, want)
// }
// if diff := cmp.Diff(map[string][]string{}, map[string][]string(rec.Header())); diff != "" {
// t.Errorf("rec.Header() mismatch (-want +got):\n%s", diff)
// }
// if want, got := "", rec.Body(); got != want {
// t.Errorf("rec.Body(): got %q want %q", got, want)
// }

// }

// func TestTokenInRequestContext(t *testing.T) {
// req := safehttptest.NewRequest(safehttp.MethodGet, "/", nil)
// req.SetContext(context.WithValue(req.Context(), tokenCtxKey{}, "pizza"))

// got, err := Token(req)
// if want := "pizza"; want != got {
// t.Errorf("Token(req): got %v, want %v", got, want)
// }
// if err != nil {
// t.Errorf("Token(req): got %v, want nil", err)
// }
// }

// func TestMissingTokenInRequestContext(t *testing.T) {
// req := safehttptest.NewRequest(safehttp.MethodGet, "/", nil)
// req.SetContext(context.Background())

// got, err := Token(req)
// if want := ""; want != got {
// t.Errorf("Token(req): got %v, want %v", got, want)
// }
// if err == nil {
// t.Error("Token(req): got nil, want error")
// }
// }

func TestMissingCookieInGetRequest(t *testing.T) {
rec := safehttptest.NewResponseRecorder()
req := safehttptest.NewRequest(safehttp.MethodGet, "https://foo.com/pizza", nil)

i := Interceptor{SecretAppKey: "testSecretAppKey"}
i.Before(rec.ResponseWriter, req, nil)
i.Commit(rec.ResponseWriter, req, nil, nil)

if want, got := safehttp.StatusOK, rec.Status(); want != got {
t.Errorf("rec.Status(): got %v, want %v", got, want)
Expand Down Expand Up @@ -316,70 +315,70 @@ func TestMissingCookiePostRequest(t *testing.T) {
}
}

func TestCommitToken(t *testing.T) {
rec := safehttptest.NewResponseRecorder()
req := safehttptest.NewRequest(safehttp.MethodGet, "https://foo.com/pizza", nil)
req.SetContext(context.WithValue(req.Context(), tokenCtxKey{}, "pizza"))

i := Interceptor{SecretAppKey: "testSecretAppKey"}
tr := safehttp.TemplateResponse{FuncMap: map[string]interface{}{}}
i.Commit(rec.ResponseWriter, req, tr, nil)

tok, ok := tr.FuncMap["XSRFToken"]
if !ok {
t.Fatal(`tr.FuncMap["XSRFToken"] not found`)
}

fn, ok := tok.(func() string)
if !ok {
t.Fatalf(`tr.FuncMap["XSRFToken"]: got %T, want "func() string"`, fn)
}
if want, got := "pizza", fn(); want != got {
t.Errorf(`tr.FuncMap["XSRFToken"](): got %q, want %q`, got, want)
}

if want, got := safehttp.StatusOK, rec.Status(); want != got {
t.Errorf("rec.Status(): got %v, want %v", got, want)
}
if diff := cmp.Diff(map[string][]string{}, map[string][]string(rec.Header())); diff != "" {
t.Errorf("rec.Header() mismatch (-want +got):\n%s", diff)
}

if want, got := "", rec.Body(); got != want {
t.Errorf("rec.Body(): got %q want %q", got, want)
}
}

func TestCommitMissingToken(t *testing.T) {
rec := safehttptest.NewResponseRecorder()
req := safehttptest.NewRequest(safehttp.MethodGet, "https://foo.com/pizza", nil)
req.SetContext(context.Background())

i := Interceptor{SecretAppKey: "testSecretAppKey"}
tr := safehttp.TemplateResponse{FuncMap: map[string]interface{}{}}
i.Commit(rec.ResponseWriter, req, tr, nil)

wantFuncMap := map[string]interface{}{}
if diff := cmp.Diff(wantFuncMap, tr.FuncMap); diff != "" {
t.Errorf("tr.FuncMap: mismatch (-want +got):\n%s", diff)
}

if want, got := safehttp.StatusInternalServerError, rec.Status(); got != want {
t.Errorf("rec.Status(): got %v, want %v", got, want)
}
wantHeaders := map[string][]string{
"Content-Type": {"text/plain; charset=utf-8"},
"X-Content-Type-Options": {"nosniff"},
}

if diff := cmp.Diff(wantHeaders, map[string][]string(rec.Header())); diff != "" {
t.Errorf("rec.Header() mismatch (-want +got):\n%s", diff)
}

if want, got := "Internal Server Error\n", rec.Body(); got != want {
t.Errorf("rec.Body(): got %q want %q", got, want)
}
}
// func TestCommitToken(t *testing.T) {
// rec := safehttptest.NewResponseRecorder()
// req := safehttptest.NewRequest(safehttp.MethodGet, "https://foo.com/pizza", nil)
// req.SetContext(context.WithValue(req.Context(), tokenCtxKey{}, "pizza"))

// i := Interceptor{SecretAppKey: "testSecretAppKey"}
// tr := safehttp.TemplateResponse{FuncMap: map[string]interface{}{}}
// i.Commit(rec.ResponseWriter, req, tr, nil)

// tok, ok := tr.FuncMap["XSRFToken"]
// if !ok {
// t.Fatal(`tr.FuncMap["XSRFToken"] not found`)
// }

// fn, ok := tok.(func() string)
// if !ok {
// t.Fatalf(`tr.FuncMap["XSRFToken"]: got %T, want "func() string"`, fn)
// }
// if want, got := "pizza", fn(); want != got {
// t.Errorf(`tr.FuncMap["XSRFToken"](): got %q, want %q`, got, want)
// }

// if want, got := safehttp.StatusOK, rec.Status(); want != got {
// t.Errorf("rec.Status(): got %v, want %v", got, want)
// }
// if diff := cmp.Diff(map[string][]string{}, map[string][]string(rec.Header())); diff != "" {
// t.Errorf("rec.Header() mismatch (-want +got):\n%s", diff)
// }

// if want, got := "", rec.Body(); got != want {
// t.Errorf("rec.Body(): got %q want %q", got, want)
// }
// }

// func TestCommitMissingToken(t *testing.T) {
// rec := safehttptest.NewResponseRecorder()
// req := safehttptest.NewRequest(safehttp.MethodGet, "https://foo.com/pizza", nil)
// req.SetContext(context.Background())

// i := Interceptor{SecretAppKey: "testSecretAppKey"}
// tr := safehttp.TemplateResponse{FuncMap: map[string]interface{}{}}
// i.Commit(rec.ResponseWriter, req, tr, nil)

// wantFuncMap := map[string]interface{}{}
// if diff := cmp.Diff(wantFuncMap, tr.FuncMap); diff != "" {
// t.Errorf("tr.FuncMap: mismatch (-want +got):\n%s", diff)
// }

// if want, got := safehttp.StatusInternalServerError, rec.Status(); got != want {
// t.Errorf("rec.Status(): got %v, want %v", got, want)
// }
// wantHeaders := map[string][]string{
// "Content-Type": {"text/plain; charset=utf-8"},
// "X-Content-Type-Options": {"nosniff"},
// }

// if diff := cmp.Diff(wantHeaders, map[string][]string(rec.Header())); diff != "" {
// t.Errorf("rec.Header() mismatch (-want +got):\n%s", diff)
// }

// if want, got := "Internal Server Error\n", rec.Body(); got != want {
// t.Errorf("rec.Body(): got %q want %q", got, want)
// }
// }

func TestCommitNotTemplateResponse(t *testing.T) {
rec := safehttptest.NewResponseRecorder()
Expand All @@ -392,10 +391,6 @@ func TestCommitNotTemplateResponse(t *testing.T) {
t.Errorf("rec.Status(): got %v, want %v", got, want)
}

if diff := cmp.Diff(map[string][]string{}, map[string][]string(rec.Header())); diff != "" {
t.Errorf("rec.Header() mismatch (-want +got):\n%s", diff)
}

if want, got := "", rec.Body(); got != want {
t.Errorf("rec.Body(): got %q want %q", got, want)
}
Expand Down

0 comments on commit f6479e6

Please sign in to comment.