Skip to content

Commit

Permalink
extend basic auth with simplified version and keep authorized status …
Browse files Browse the repository at this point in the history
…in ctx
  • Loading branch information
umputun committed Jan 16, 2023
1 parent 04ee52f commit 21c9dc2
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 2 deletions.
23 changes: 22 additions & 1 deletion basic_auth.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package rest

import (
"context"
"crypto/subtle"
"net/http"
)

const baContextKey = "authorizedWithBasicAuth"

// BasicAuth middleware requires basic auth and matches user & passwd with client-provided checker
func BasicAuth(checker func(user, passwd string) bool) func(http.Handler) http.Handler {

Expand All @@ -19,8 +23,25 @@ func BasicAuth(checker func(user, passwd string) bool) func(http.Handler) http.H
w.WriteHeader(http.StatusForbidden)
return
}
h.ServeHTTP(w, r)
h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), contextKey(baContextKey), true)))
}
return http.HandlerFunc(fn)
}
}

// BasicAuthWithUserPasswd middleware requires basic auth and matches user & passwd with client-provided values
func BasicAuthWithUserPasswd(user, passwd string) func(http.Handler) http.Handler {
checkFn := func(reqUser, reqPasswd string) bool {
matchUser := subtle.ConstantTimeCompare([]byte(user), []byte(reqUser))
matchPass := subtle.ConstantTimeCompare([]byte(passwd), []byte(reqPasswd))
return matchUser == 1 && matchPass == 1
}
return BasicAuth(checkFn)
}

// IsAuthorized returns true is user authorized.
// it can be used in handlers to check if BasicAuth middleware was applied
func IsAuthorized(ctx context.Context) bool {
v := ctx.Value(contextKey(baContextKey))
return nil != v && v.(bool)
}
50 changes: 49 additions & 1 deletion basic_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@ func TestBasicAuth(t *testing.T) {
return user == "dev" && passwd == "good"
})

ts := httptest.NewServer(mw(getTestHandlerBlah()))
ts := httptest.NewServer(mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("request %s", r.URL)
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte("blah"))
require.NoError(t, err)
assert.True(t, IsAuthorized(r.Context()))
})))
defer ts.Close()

u := fmt.Sprintf("%s%s", ts.URL, "/something")
Expand Down Expand Up @@ -50,3 +56,45 @@ func TestBasicAuth(t *testing.T) {
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
}
}

func TestBasicAuthWithUserPasswd(t *testing.T) {
mw := BasicAuthWithUserPasswd("dev", "good")

ts := httptest.NewServer(mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("request %s", r.URL)
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte("blah"))
require.NoError(t, err)
assert.True(t, IsAuthorized(r.Context()))
})))
defer ts.Close()

u := fmt.Sprintf("%s%s", ts.URL, "/something")

client := http.Client{Timeout: 5 * time.Second}

{
req, err := http.NewRequest("GET", u, http.NoBody)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
}

{
req, err := http.NewRequest("GET", u, http.NoBody)
require.NoError(t, err)
req.SetBasicAuth("dev", "good")
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
{
req, err := http.NewRequest("GET", u, http.NoBody)
require.NoError(t, err)
req.SetBasicAuth("dev", "bad")
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
}
}

0 comments on commit 21c9dc2

Please sign in to comment.