Skip to content

Commit

Permalink
check if user's provider is in the list of current providers (#176)
Browse files Browse the repository at this point in the history
  • Loading branch information
umputun committed Aug 20, 2023
1 parent 8da8a5c commit cea049c
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 49 deletions.
22 changes: 21 additions & 1 deletion middleware/auth.go
Expand Up @@ -110,7 +110,7 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler {
return
}

if claims.Handshake != nil { // handshake in token indicate special use cases, not for login
if claims.Handshake != nil { // handshake in token indicates special use cases, not for login
onError(h, w, r, fmt.Errorf("invalid kind of token"))
return
}
Expand All @@ -128,6 +128,13 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler {
return
}

// check if user provider is allowed
if !a.isProviderAllowed(claims.User.ID) {
onError(h, w, r, fmt.Errorf("user %s/%s provider is not allowed", claims.User.Name, claims.User.ID))
a.JWTService.Reset(w)
return
}

if a.JWTService.IsExpired(claims) {
if claims, err = a.refreshExpiredToken(w, claims, tkn); err != nil {
a.JWTService.Reset(w)
Expand All @@ -146,6 +153,19 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler {
return f
}

// isProviderAllowed checks if user provider is allowed, user id looks like "provider_1234567890"
// this check is needed to reject users from providers what are used to be allowed but not anymore.
// Such users made token before the provider was disabled and should not be allowed to login anymore.
func (a *Authenticator) isProviderAllowed(userID string) bool {
userProvider := strings.Split(userID, "_")[0]
for _, p := range a.Providers {
if p.Name() == userProvider {
return true
}
}
return false
}

// refreshExpiredToken makes a new token with passed claims
func (a *Authenticator) refreshExpiredToken(w http.ResponseWriter, claims token.Claims, tkn string) (token.Claims, error) {

Expand Down
138 changes: 90 additions & 48 deletions middleware/auth_test.go
Expand Up @@ -16,18 +16,21 @@ import (
"github.com/stretchr/testify/require"

"github.com/go-pkgz/auth/logger"
"github.com/go-pkgz/auth/provider"
"github.com/go-pkgz/auth/token"
)

var testJwtValid = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19fQ.OWPdibrSSSHuOV3DzzLH5soO6kUcERELL7_GLf7Ja_E"
var testJwtValid = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9fX0.orBYt_pVA4uvCCw0JMQLla3DA0mpjRTl_U9vT_wtI30"

var testJwtExpired = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6MTE4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19fQ.lJNUjG_9rpAghqy5GwIOrgfQnGDnF3PW5sGzKdijmmg"
var testJwtValidWrongProvider = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjNfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9fX0.p0w7GmXKwujm0ROn0RIACnBwN4KmPcqXDMS9YoFq4jQ"

var testJwtExpired = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6MTE4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9fX0.PlRRc5YA6pvoVOT4NLLOoTwU2Kn3GaTfbjr6j-P6RhA"

var testJwtWithHandshake = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19LCJoYW5kc2hha2UiOnsic3RhdGUiOiIxMjM0NTYiLCJmcm9tIjoiZnJvbSIsImlkIjoibXlpZC0xMjM0NTYifX0._2X1cAEoxjLA7XuN8xW8V9r7rYfP_m9lSRz_9_UFzac"

var testJwtNoUser = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjI3ODkxOTE4MjIsImp0aSI6InJhbmRvbSBpZCIsImlzcyI6InJlbWFyazQyIiwibmJmIjoxNTI2ODg0MjIyfQ.sBpblkbBRzZsBSPPNrTWqA5h7h54solrw5L4IypJT_o"

var testJwtWithRole = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn0sInJvbGUiOiJlbXBsb3llZSJ9fQ.VLW4_LUDZq_eFc9F1Zx1lbv2Whic2VHy6C0dJ5azL8A"
var testJwtWithRole = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9LCJyb2xlIjoiZW1wbG95ZWUifX0.o95raB0aNl2TWUs43Tu6xyX5Y3Fa5wv6_6RFJuN-d6g"

func TestAuthJWTCookie(t *testing.T) {
a := makeTestAuth(t)
Expand All @@ -36,7 +39,7 @@ func TestAuthJWTCookie(t *testing.T) {
handler := func(w http.ResponseWriter, r *http.Request) {
u, err := token.GetUserInfo(r)
assert.NoError(t, err)
assert.Equal(t, token.User{Name: "name1", ID: "id1", Picture: "http://example.com/pic.png",
assert.Equal(t, token.User{Name: "name1", ID: "provider1_id1", Picture: "http://example.com/pic.png",
IP: "127.0.0.1", Email: "me@example.com", Audience: "test_sys",
Attributes: map[string]interface{}{"boola": true, "stra": "stra-val"}}, u)
w.WriteHeader(201)
Expand All @@ -45,40 +48,61 @@ func TestAuthJWTCookie(t *testing.T) {
server := httptest.NewServer(mux)
defer server.Close()

client := &http.Client{Timeout: 5 * time.Second}
expiration := int(365 * 24 * time.Hour.Seconds()) //nolint
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")

client := &http.Client{Timeout: 5 * time.Second}
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, 201, resp.StatusCode, "valid token user")
t.Run("valid token", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")

req, err = http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "wrong id")
resp, err = client.Do(req)
require.NoError(t, err)
assert.Equal(t, 401, resp.StatusCode, "xsrf mismatch")
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, 201, resp.StatusCode, "valid token user")
})

req, err = http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtExpired, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")
resp, err = client.Do(req)
require.NoError(t, err)
assert.Equal(t, 201, resp.StatusCode, "token expired and refreshed")
t.Run("valid token, wrong provider", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValidWrongProvider, HttpOnly: true, Path: "/",
MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")

req, err = http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtNoUser, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")
resp, err = client.Do(req)
require.NoError(t, err)
assert.Equal(t, 401, resp.StatusCode, "no user info in the token")
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, 401, resp.StatusCode, "user name1/provider3_id1 provider is not allowed")
})

t.Run("xsrf mismatch", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "wrong id")
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, 401, resp.StatusCode, "xsrf mismatch")
})

t.Run("token expired and refreshed", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtExpired, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, 201, resp.StatusCode, "token expired and refreshed")
})

t.Run("no user info in the token", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtNoUser, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, 401, resp.StatusCode, "no user info in the token")
})
}

func TestAuthJWTHeader(t *testing.T) {
Expand All @@ -87,19 +111,32 @@ func TestAuthJWTHeader(t *testing.T) {
defer server.Close()

client := &http.Client{Timeout: 5 * time.Second}
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.Header.Add("X-JWT", testJwtValid)
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, 201, resp.StatusCode, "valid token user")
t.Run("valid token", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.Header.Add("X-JWT", testJwtValid)
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, 201, resp.StatusCode, "valid token user")
})

req, err = http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.Header.Add("X-JWT", testJwtExpired)
resp, err = client.Do(req)
require.NoError(t, err)
assert.Equal(t, 401, resp.StatusCode, "token expired")
t.Run("valid token, wrong provider", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.Header.Add("X-JWT", testJwtValidWrongProvider)
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, 401, resp.StatusCode, "wrong provider")
})

t.Run("token expired", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.Header.Add("X-JWT", testJwtExpired)
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, 401, resp.StatusCode, "token expired")
})
}

func TestAuthJWTRefresh(t *testing.T) {
Expand Down Expand Up @@ -177,7 +214,7 @@ func TestAuthJWTRefreshConcurrentWithCache(t *testing.T) {
// make another expired token
c, err := a.JWTService.Parse(testJwtExpired)
require.NoError(t, err)
c.User.ID = "other ID"
c.User.ID = "provider1_other ID"
tkSvc := a.JWTService.(*token.Service)
tkn, err := tkSvc.Token(c)
require.NoError(t, err)
Expand Down Expand Up @@ -413,7 +450,7 @@ func TestRBAC(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
u, err := token.GetUserInfo(r)
assert.NoError(t, err)
assert.Equal(t, token.User{Name: "name1", ID: "id1", Picture: "http://example.com/pic.png",
assert.Equal(t, token.User{Name: "name1", ID: "provider1_id1", Picture: "http://example.com/pic.png",
IP: "127.0.0.1", Email: "me@example.com", Audience: "test_sys",
Attributes: map[string]interface{}{"boola": true, "stra": "stra-val"},
Role: "employee"}, u)
Expand All @@ -438,7 +475,7 @@ func TestRBAC(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, 201, resp.StatusCode, "valid token user")

// employee route only, token without employee role
// employee route only, token without an employee role
expiration = int(365 * 24 * time.Hour.Seconds()) //nolint
req, err = http.NewRequest("GET", server.URL+"/authForEmployees", http.NoBody)
require.Nil(t, err)
Expand Down Expand Up @@ -480,6 +517,7 @@ func makeTestMux(_ *testing.T, a *Authenticator, required bool) http.Handler {
}

func makeTestAuth(_ *testing.T) Authenticator {

j := token.NewService(token.Opts{
SecretReader: token.SecretFunc(func(string) (string, error) { return "xyz 12345", nil }),
SecureCookies: false,
Expand All @@ -497,6 +535,10 @@ func makeTestAuth(_ *testing.T) Authenticator {
JWTService: j,
Validator: token.ValidatorFunc(func(token string, claims token.Claims) bool { return true }),
L: logger.Std,
Providers: []provider.Service{
{Provider: provider.DirectHandler{ProviderName: "provider1"}},
{Provider: provider.DirectHandler{ProviderName: "provider2"}},
},
}
}

Expand Down

0 comments on commit cea049c

Please sign in to comment.