From cea049c06aa5be34e1e23bc4bf927b8680ca068b Mon Sep 17 00:00:00 2001 From: Umputun Date: Sun, 20 Aug 2023 17:38:34 -0500 Subject: [PATCH] check if user's provider is in the list of current providers (#176) --- middleware/auth.go | 22 ++++++- middleware/auth_test.go | 138 ++++++++++++++++++++++++++-------------- 2 files changed, 111 insertions(+), 49 deletions(-) diff --git a/middleware/auth.go b/middleware/auth.go index a507e57b..64d6c7ce 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -110,7 +110,7 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler { return } - if claims.Handshake != nil { // handshake in token indicate special use cases, not for login + if claims.Handshake != nil { // handshake in token indicates special use cases, not for login onError(h, w, r, fmt.Errorf("invalid kind of token")) return } @@ -128,6 +128,13 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler { return } + // check if user provider is allowed + if !a.isProviderAllowed(claims.User.ID) { + onError(h, w, r, fmt.Errorf("user %s/%s provider is not allowed", claims.User.Name, claims.User.ID)) + a.JWTService.Reset(w) + return + } + if a.JWTService.IsExpired(claims) { if claims, err = a.refreshExpiredToken(w, claims, tkn); err != nil { a.JWTService.Reset(w) @@ -146,6 +153,19 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler { return f } +// isProviderAllowed checks if user provider is allowed, user id looks like "provider_1234567890" +// this check is needed to reject users from providers what are used to be allowed but not anymore. +// Such users made token before the provider was disabled and should not be allowed to login anymore. +func (a *Authenticator) isProviderAllowed(userID string) bool { + userProvider := strings.Split(userID, "_")[0] + for _, p := range a.Providers { + if p.Name() == userProvider { + return true + } + } + return false +} + // refreshExpiredToken makes a new token with passed claims func (a *Authenticator) refreshExpiredToken(w http.ResponseWriter, claims token.Claims, tkn string) (token.Claims, error) { diff --git a/middleware/auth_test.go b/middleware/auth_test.go index c2fd1774..18208aee 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -16,18 +16,21 @@ import ( "github.com/stretchr/testify/require" "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/provider" "github.com/go-pkgz/auth/token" ) -var testJwtValid = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19fQ.OWPdibrSSSHuOV3DzzLH5soO6kUcERELL7_GLf7Ja_E" +var testJwtValid = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9fX0.orBYt_pVA4uvCCw0JMQLla3DA0mpjRTl_U9vT_wtI30" -var testJwtExpired = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6MTE4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19fQ.lJNUjG_9rpAghqy5GwIOrgfQnGDnF3PW5sGzKdijmmg" +var testJwtValidWrongProvider = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjNfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9fX0.p0w7GmXKwujm0ROn0RIACnBwN4KmPcqXDMS9YoFq4jQ" + +var testJwtExpired = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6MTE4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9fX0.PlRRc5YA6pvoVOT4NLLOoTwU2Kn3GaTfbjr6j-P6RhA" var testJwtWithHandshake = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19LCJoYW5kc2hha2UiOnsic3RhdGUiOiIxMjM0NTYiLCJmcm9tIjoiZnJvbSIsImlkIjoibXlpZC0xMjM0NTYifX0._2X1cAEoxjLA7XuN8xW8V9r7rYfP_m9lSRz_9_UFzac" var testJwtNoUser = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjI3ODkxOTE4MjIsImp0aSI6InJhbmRvbSBpZCIsImlzcyI6InJlbWFyazQyIiwibmJmIjoxNTI2ODg0MjIyfQ.sBpblkbBRzZsBSPPNrTWqA5h7h54solrw5L4IypJT_o" -var testJwtWithRole = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn0sInJvbGUiOiJlbXBsb3llZSJ9fQ.VLW4_LUDZq_eFc9F1Zx1lbv2Whic2VHy6C0dJ5azL8A" +var testJwtWithRole = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9LCJyb2xlIjoiZW1wbG95ZWUifX0.o95raB0aNl2TWUs43Tu6xyX5Y3Fa5wv6_6RFJuN-d6g" func TestAuthJWTCookie(t *testing.T) { a := makeTestAuth(t) @@ -36,7 +39,7 @@ func TestAuthJWTCookie(t *testing.T) { handler := func(w http.ResponseWriter, r *http.Request) { u, err := token.GetUserInfo(r) assert.NoError(t, err) - assert.Equal(t, token.User{Name: "name1", ID: "id1", Picture: "http://example.com/pic.png", + assert.Equal(t, token.User{Name: "name1", ID: "provider1_id1", Picture: "http://example.com/pic.png", IP: "127.0.0.1", Email: "me@example.com", Audience: "test_sys", Attributes: map[string]interface{}{"boola": true, "stra": "stra-val"}}, u) w.WriteHeader(201) @@ -45,40 +48,61 @@ func TestAuthJWTCookie(t *testing.T) { server := httptest.NewServer(mux) defer server.Close() + client := &http.Client{Timeout: 5 * time.Second} expiration := int(365 * 24 * time.Hour.Seconds()) //nolint - req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "random id") - client := &http.Client{Timeout: 5 * time.Second} - resp, err := client.Do(req) - require.NoError(t, err) - assert.Equal(t, 201, resp.StatusCode, "valid token user") + t.Run("valid token", func(t *testing.T) { + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.Nil(t, err) + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") - req, err = http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "wrong id") - resp, err = client.Do(req) - require.NoError(t, err) - assert.Equal(t, 401, resp.StatusCode, "xsrf mismatch") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 201, resp.StatusCode, "valid token user") + }) - req, err = http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtExpired, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "random id") - resp, err = client.Do(req) - require.NoError(t, err) - assert.Equal(t, 201, resp.StatusCode, "token expired and refreshed") + t.Run("valid token, wrong provider", func(t *testing.T) { + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.Nil(t, err) + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValidWrongProvider, HttpOnly: true, Path: "/", + MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") - req, err = http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtNoUser, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) - req.Header.Add("X-XSRF-TOKEN", "random id") - resp, err = client.Do(req) - require.NoError(t, err) - assert.Equal(t, 401, resp.StatusCode, "no user info in the token") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 401, resp.StatusCode, "user name1/provider3_id1 provider is not allowed") + }) + + t.Run("xsrf mismatch", func(t *testing.T) { + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.Nil(t, err) + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "wrong id") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 401, resp.StatusCode, "xsrf mismatch") + }) + + t.Run("token expired and refreshed", func(t *testing.T) { + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.Nil(t, err) + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtExpired, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 201, resp.StatusCode, "token expired and refreshed") + }) + + t.Run("no user info in the token", func(t *testing.T) { + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.Nil(t, err) + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtNoUser, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 401, resp.StatusCode, "no user info in the token") + }) } func TestAuthJWTHeader(t *testing.T) { @@ -87,19 +111,32 @@ func TestAuthJWTHeader(t *testing.T) { defer server.Close() client := &http.Client{Timeout: 5 * time.Second} - req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.Header.Add("X-JWT", testJwtValid) - resp, err := client.Do(req) - require.NoError(t, err) - assert.Equal(t, 201, resp.StatusCode, "valid token user") + t.Run("valid token", func(t *testing.T) { + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.Nil(t, err) + req.Header.Add("X-JWT", testJwtValid) + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 201, resp.StatusCode, "valid token user") + }) - req, err = http.NewRequest("GET", server.URL+"/auth", http.NoBody) - require.Nil(t, err) - req.Header.Add("X-JWT", testJwtExpired) - resp, err = client.Do(req) - require.NoError(t, err) - assert.Equal(t, 401, resp.StatusCode, "token expired") + t.Run("valid token, wrong provider", func(t *testing.T) { + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.Nil(t, err) + req.Header.Add("X-JWT", testJwtValidWrongProvider) + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 401, resp.StatusCode, "wrong provider") + }) + + t.Run("token expired", func(t *testing.T) { + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.Nil(t, err) + req.Header.Add("X-JWT", testJwtExpired) + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 401, resp.StatusCode, "token expired") + }) } func TestAuthJWTRefresh(t *testing.T) { @@ -177,7 +214,7 @@ func TestAuthJWTRefreshConcurrentWithCache(t *testing.T) { // make another expired token c, err := a.JWTService.Parse(testJwtExpired) require.NoError(t, err) - c.User.ID = "other ID" + c.User.ID = "provider1_other ID" tkSvc := a.JWTService.(*token.Service) tkn, err := tkSvc.Token(c) require.NoError(t, err) @@ -413,7 +450,7 @@ func TestRBAC(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { u, err := token.GetUserInfo(r) assert.NoError(t, err) - assert.Equal(t, token.User{Name: "name1", ID: "id1", Picture: "http://example.com/pic.png", + assert.Equal(t, token.User{Name: "name1", ID: "provider1_id1", Picture: "http://example.com/pic.png", IP: "127.0.0.1", Email: "me@example.com", Audience: "test_sys", Attributes: map[string]interface{}{"boola": true, "stra": "stra-val"}, Role: "employee"}, u) @@ -438,7 +475,7 @@ func TestRBAC(t *testing.T) { require.NoError(t, err) assert.Equal(t, 201, resp.StatusCode, "valid token user") - // employee route only, token without employee role + // employee route only, token without an employee role expiration = int(365 * 24 * time.Hour.Seconds()) //nolint req, err = http.NewRequest("GET", server.URL+"/authForEmployees", http.NoBody) require.Nil(t, err) @@ -480,6 +517,7 @@ func makeTestMux(_ *testing.T, a *Authenticator, required bool) http.Handler { } func makeTestAuth(_ *testing.T) Authenticator { + j := token.NewService(token.Opts{ SecretReader: token.SecretFunc(func(string) (string, error) { return "xyz 12345", nil }), SecureCookies: false, @@ -497,6 +535,10 @@ func makeTestAuth(_ *testing.T) Authenticator { JWTService: j, Validator: token.ValidatorFunc(func(token string, claims token.Claims) bool { return true }), L: logger.Std, + Providers: []provider.Service{ + {Provider: provider.DirectHandler{ProviderName: "provider1"}}, + {Provider: provider.DirectHandler{ProviderName: "provider2"}}, + }, } }