From 21c9dc2e61e12b0e3a1f7b364e9f533349f4af3e Mon Sep 17 00:00:00 2001 From: Umputun Date: Sun, 15 Jan 2023 18:44:40 -0600 Subject: [PATCH] extend basic auth with simplified version and keep authorized status in ctx --- basic_auth.go | 23 ++++++++++++++++++++- basic_auth_test.go | 50 +++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/basic_auth.go b/basic_auth.go index 6ed0c68..f63d9ce 100644 --- a/basic_auth.go +++ b/basic_auth.go @@ -1,9 +1,13 @@ package rest import ( + "context" + "crypto/subtle" "net/http" ) +const baContextKey = "authorizedWithBasicAuth" + // BasicAuth middleware requires basic auth and matches user & passwd with client-provided checker func BasicAuth(checker func(user, passwd string) bool) func(http.Handler) http.Handler { @@ -19,8 +23,25 @@ func BasicAuth(checker func(user, passwd string) bool) func(http.Handler) http.H w.WriteHeader(http.StatusForbidden) return } - h.ServeHTTP(w, r) + h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), contextKey(baContextKey), true))) } return http.HandlerFunc(fn) } } + +// BasicAuthWithUserPasswd middleware requires basic auth and matches user & passwd with client-provided values +func BasicAuthWithUserPasswd(user, passwd string) func(http.Handler) http.Handler { + checkFn := func(reqUser, reqPasswd string) bool { + matchUser := subtle.ConstantTimeCompare([]byte(user), []byte(reqUser)) + matchPass := subtle.ConstantTimeCompare([]byte(passwd), []byte(reqPasswd)) + return matchUser == 1 && matchPass == 1 + } + return BasicAuth(checkFn) +} + +// IsAuthorized returns true is user authorized. +// it can be used in handlers to check if BasicAuth middleware was applied +func IsAuthorized(ctx context.Context) bool { + v := ctx.Value(contextKey(baContextKey)) + return nil != v && v.(bool) +} diff --git a/basic_auth_test.go b/basic_auth_test.go index 160dcef..0212228 100644 --- a/basic_auth_test.go +++ b/basic_auth_test.go @@ -17,7 +17,13 @@ func TestBasicAuth(t *testing.T) { return user == "dev" && passwd == "good" }) - ts := httptest.NewServer(mw(getTestHandlerBlah())) + ts := httptest.NewServer(mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("request %s", r.URL) + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("blah")) + require.NoError(t, err) + assert.True(t, IsAuthorized(r.Context())) + }))) defer ts.Close() u := fmt.Sprintf("%s%s", ts.URL, "/something") @@ -50,3 +56,45 @@ func TestBasicAuth(t *testing.T) { assert.Equal(t, http.StatusForbidden, resp.StatusCode) } } + +func TestBasicAuthWithUserPasswd(t *testing.T) { + mw := BasicAuthWithUserPasswd("dev", "good") + + ts := httptest.NewServer(mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("request %s", r.URL) + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("blah")) + require.NoError(t, err) + assert.True(t, IsAuthorized(r.Context())) + }))) + defer ts.Close() + + u := fmt.Sprintf("%s%s", ts.URL, "/something") + + client := http.Client{Timeout: 5 * time.Second} + + { + req, err := http.NewRequest("GET", u, http.NoBody) + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + } + + { + req, err := http.NewRequest("GET", u, http.NoBody) + require.NoError(t, err) + req.SetBasicAuth("dev", "good") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + } + { + req, err := http.NewRequest("GET", u, http.NoBody) + require.NoError(t, err) + req.SetBasicAuth("dev", "bad") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusForbidden, resp.StatusCode) + } +}