diff --git a/pkg/api/ldap_debug.go b/pkg/api/ldap_debug.go index ad223cdb66f9e1b..e970ee5e60a34bd 100644 --- a/pkg/api/ldap_debug.go +++ b/pkg/api/ldap_debug.go @@ -221,6 +221,11 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) response.Respon ReqContext: c, ExternalUser: user, SignupAllowed: hs.Cfg.LDAPAllowSignup, + UserLookupParams: models.UserLookupParams{ + UserID: &query.Result.Id, // Upsert by ID only + Email: nil, + Login: nil, + }, } err = bus.Dispatch(c.Req.Context(), upsertCmd) diff --git a/pkg/api/login_oauth.go b/pkg/api/login_oauth.go index 92d19b4afdfee35..3faac6ced1bc829 100644 --- a/pkg/api/login_oauth.go +++ b/pkg/api/login_oauth.go @@ -313,6 +313,11 @@ func syncUser( ReqContext: ctx, ExternalUser: extUser, SignupAllowed: connect.IsSignupAllowed(), + UserLookupParams: models.UserLookupParams{ + Email: &extUser.Email, + UserID: nil, + Login: nil, + }, } if err := bus.Dispatch(ctx.Req.Context(), cmd); err != nil { return nil, err diff --git a/pkg/api/user_test.go b/pkg/api/user_test.go index e9aa77d46c317d5..49c9e09e7703bf2 100644 --- a/pkg/api/user_test.go +++ b/pkg/api/user_test.go @@ -67,7 +67,8 @@ func TestUserAPIEndpoint_userLoggedIn(t *testing.T) { } idToken := "testidtoken" token = token.WithExtra(map[string]interface{}{"id_token": idToken}) - query := &models.GetUserByAuthInfoQuery{Login: "loginuser", AuthModule: "test", AuthId: "test"} + login := "loginuser" + query := &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test", UserLookupParams: models.UserLookupParams{Login: &login}} cmd := &models.UpdateAuthInfoCommand{ UserId: user.Id, AuthId: query.AuthId, diff --git a/pkg/login/ldap_login.go b/pkg/login/ldap_login.go index 2658de7fd9d565e..55029f05b17605f 100644 --- a/pkg/login/ldap_login.go +++ b/pkg/login/ldap_login.go @@ -57,9 +57,13 @@ var loginUsingLDAP = func(ctx context.Context, query *models.LoginUserQuery) (bo ReqContext: query.ReqContext, ExternalUser: externalUser, SignupAllowed: setting.LDAPAllowSignup, + UserLookupParams: models.UserLookupParams{ + Login: &externalUser.Login, + Email: &externalUser.Email, + UserID: nil, + }, } - err = bus.Dispatch(ctx, upsert) - if err != nil { + if err = bus.Dispatch(ctx, upsert); err != nil { return true, err } query.User = upsert.Result diff --git a/pkg/models/user_auth.go b/pkg/models/user_auth.go index 221b899c0e1ef79..1dba8ed27050413 100644 --- a/pkg/models/user_auth.go +++ b/pkg/models/user_auth.go @@ -55,11 +55,11 @@ type RequestURIKey struct{} // COMMANDS type UpsertUserCommand struct { - ReqContext *ReqContext - ExternalUser *ExternalUserInfo + ReqContext *ReqContext + ExternalUser *ExternalUserInfo + UserLookupParams + Result *User SignupAllowed bool - - Result *User } type SetAuthInfoCommand struct { @@ -96,9 +96,14 @@ type LoginUserQuery struct { type GetUserByAuthInfoQuery struct { AuthModule string AuthId string - UserId int64 - Email string - Login string + UserLookupParams +} + +type UserLookupParams struct { + // Describes lookup order as well + UserID *int64 // if set, will try to find the user by id + Email *string // if set, will try to find the user by email + Login *string // if set, will try to find the user by login } type GetExternalUserInfoByLoginQuery struct { diff --git a/pkg/services/contexthandler/auth_jwt.go b/pkg/services/contexthandler/auth_jwt.go index 748bd13c120b7cf..882971f788fa516 100644 --- a/pkg/services/contexthandler/auth_jwt.go +++ b/pkg/services/contexthandler/auth_jwt.go @@ -66,6 +66,11 @@ func (h *ContextHandler) initContextWithJWT(ctx *models.ReqContext, orgId int64) ReqContext: ctx, SignupAllowed: h.Cfg.JWTAuthAutoSignUp, ExternalUser: extUser, + UserLookupParams: models.UserLookupParams{ + UserID: nil, + Login: &query.Login, + Email: &query.Email, + }, } if err := bus.Dispatch(ctx.Req.Context(), upsert); err != nil { ctx.Logger.Error("Failed to upsert JWT user", "error", err) diff --git a/pkg/services/contexthandler/authproxy/authproxy.go b/pkg/services/contexthandler/authproxy/authproxy.go index d3efa3073e16949..db7f3596324c07a 100644 --- a/pkg/services/contexthandler/authproxy/authproxy.go +++ b/pkg/services/contexthandler/authproxy/authproxy.go @@ -247,6 +247,11 @@ func (auth *AuthProxy) LoginViaLDAP() (int64, error) { ReqContext: auth.ctx, SignupAllowed: auth.cfg.LDAPAllowSignup, ExternalUser: extUser, + UserLookupParams: models.UserLookupParams{ + Login: &extUser.Login, + Email: &extUser.Email, + UserID: nil, + }, } if err := bus.Dispatch(auth.ctx.Req.Context(), upsert); err != nil { return 0, err @@ -303,6 +308,11 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) { ReqContext: auth.ctx, SignupAllowed: auth.cfg.AuthProxyAutoSignUp, ExternalUser: extUser, + UserLookupParams: models.UserLookupParams{ + UserID: nil, + Login: &extUser.Login, + Email: &extUser.Email, + }, } err := bus.Dispatch(auth.ctx.Req.Context(), upsert) diff --git a/pkg/services/login/authinfoservice/service.go b/pkg/services/login/authinfoservice/service.go index e1a13cea8c2b93a..0c845584562b5a9 100644 --- a/pkg/services/login/authinfoservice/service.go +++ b/pkg/services/login/authinfoservice/service.go @@ -89,11 +89,12 @@ func (s *Implementation) LookupAndFix(ctx context.Context, query *models.GetUser } // if user id was specified and doesn't match the user_auth entry, remove it - if query.UserId != 0 && query.UserId != authQuery.Result.UserId { - err := s.DeleteAuthInfo(ctx, &models.DeleteAuthInfoCommand{ + if query.UserLookupParams.UserID != nil && + *query.UserLookupParams.UserID != 0 && + *query.UserLookupParams.UserID != authQuery.Result.UserId { + if err := s.DeleteAuthInfo(ctx, &models.DeleteAuthInfoCommand{ UserAuth: authQuery.Result, - }) - if err != nil { + }); err != nil { s.logger.Error("Error removing user_auth entry", "error", err) } @@ -124,42 +125,42 @@ func (s *Implementation) LookupAndFix(ctx context.Context, query *models.GetUser return false, nil, nil, models.ErrUserNotFound } -func (s *Implementation) LookupByOneOf(userId int64, email string, login string) (bool, *models.User, error) { - foundUser := false +func (s *Implementation) LookupByOneOf(ctx context.Context, params *models.UserLookupParams) (*models.User, error) { var user *models.User var err error + foundUser := false // If not found, try to find the user by id - if userId != 0 { - foundUser, user, err = s.getUserById(userId) + if params.UserID != nil && *params.UserID != 0 { + foundUser, user, err = s.getUserById(*params.UserID) if err != nil { - return false, nil, err + return nil, err } } // If not found, try to find the user by email address - if !foundUser && email != "" { - user = &models.User{Email: email} + if !foundUser && params.Email != nil && *params.Email != "" { + user = &models.User{Email: *params.Email} foundUser, err = s.getUser(user) if err != nil { - return false, nil, err + return nil, err } } // If not found, try to find the user by login - if !foundUser && login != "" { - user = &models.User{Login: login} + if !foundUser && params.Login != nil && *params.Login != "" { + user = &models.User{Login: *params.Login} foundUser, err = s.getUser(user) if err != nil { - return false, nil, err + return nil, err } } if !foundUser { - return false, nil, models.ErrUserNotFound + return nil, models.ErrUserNotFound } - return foundUser, user, nil + return user, nil } func (s *Implementation) GenericOAuthLookup(ctx context.Context, authModule string, authId string, userID int64) (*models.UserAuth, error) { @@ -188,7 +189,7 @@ func (s *Implementation) LookupAndUpdate(ctx context.Context, query *models.GetU // 2. FindByUserDetails if !foundUser { - _, user, err = s.LookupByOneOf(query.UserId, query.Email, query.Login) + user, err = s.LookupByOneOf(ctx, &query.UserLookupParams) if err != nil { return nil, err } diff --git a/pkg/services/login/authinfoservice/user_auth_test.go b/pkg/services/login/authinfoservice/user_auth_test.go index a5cf659db1b8086..dd147106b5db1bf 100644 --- a/pkg/services/login/authinfoservice/user_auth_test.go +++ b/pkg/services/login/authinfoservice/user_auth_test.go @@ -39,7 +39,7 @@ func TestUserAuth(t *testing.T) { // By Login login := "loginuser0" - query := &models.GetUserByAuthInfoQuery{Login: login} + query := &models.GetUserByAuthInfoQuery{UserLookupParams: models.UserLookupParams{Login: &login}} user, err := srv.LookupAndUpdate(context.Background(), query) require.Nil(t, err) @@ -48,7 +48,9 @@ func TestUserAuth(t *testing.T) { // By ID id := user.Id - _, user, err = srv.LookupByOneOf(id, "", "") + user, err = srv.LookupByOneOf(context.Background(), &models.UserLookupParams{ + UserID: &id, + }) require.Nil(t, err) require.Equal(t, user.Id, id) @@ -56,7 +58,9 @@ func TestUserAuth(t *testing.T) { // By Email email := "user1@test.com" - _, user, err = srv.LookupByOneOf(0, email, "") + user, err = srv.LookupByOneOf(context.Background(), &models.UserLookupParams{ + Email: &email, + }) require.Nil(t, err) require.Equal(t, user.Email, email) @@ -64,7 +68,9 @@ func TestUserAuth(t *testing.T) { // Don't find nonexistent user email = "nonexistent@test.com" - _, user, err = srv.LookupByOneOf(0, email, "") + user, err = srv.LookupByOneOf(context.Background(), &models.UserLookupParams{ + Email: &email, + }) require.Equal(t, models.ErrUserNotFound, err) require.Nil(t, user) @@ -81,7 +87,7 @@ func TestUserAuth(t *testing.T) { // create user_auth entry login := "loginuser0" - query.Login = login + query.UserLookupParams.Login = &login user, err = srv.LookupAndUpdate(context.Background(), query) require.Nil(t, err) @@ -95,9 +101,9 @@ func TestUserAuth(t *testing.T) { require.Equal(t, user.Login, login) // get with non-matching id - id := user.Id + idPlusOne := user.Id + 1 - query.UserId = id + 1 + query.UserLookupParams.UserID = &idPlusOne user, err = srv.LookupAndUpdate(context.Background(), query) require.Nil(t, err) @@ -140,7 +146,9 @@ func TestUserAuth(t *testing.T) { login := "loginuser0" // Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table - query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test", AuthId: "test"} + query := &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test", UserLookupParams: models.UserLookupParams{ + Login: &login, + }} user, err := srv.LookupAndUpdate(context.Background(), query) require.Nil(t, err) @@ -189,7 +197,9 @@ func TestUserAuth(t *testing.T) { // Calling srv.LookupAndUpdateQuery on an existing user will populate an entry in the user_auth table // Make the first log-in during the past getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) } - query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test1", AuthId: "test1"} + query := &models.GetUserByAuthInfoQuery{AuthModule: "test1", AuthId: "test1", UserLookupParams: models.UserLookupParams{ + Login: &login, + }} user, err := srv.LookupAndUpdate(context.Background(), query) getTime = time.Now @@ -199,7 +209,9 @@ func TestUserAuth(t *testing.T) { // Add a second auth module for this user // Have this module's last log-in be more recent getTime = func() time.Time { return time.Now().AddDate(0, 0, -1) } - query = &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test2", AuthId: "test2"} + query = &models.GetUserByAuthInfoQuery{AuthModule: "test2", AuthId: "test2", UserLookupParams: models.UserLookupParams{ + Login: &login, + }} user, err = srv.LookupAndUpdate(context.Background(), query) getTime = time.Now @@ -239,16 +251,21 @@ func TestUserAuth(t *testing.T) { // Expect to pass since there's a matching login user getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) } - query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: genericOAuthModule, AuthId: ""} + query := &models.GetUserByAuthInfoQuery{AuthModule: genericOAuthModule, AuthId: "", UserLookupParams: models.UserLookupParams{ + Login: &login, + }} user, err := srv.LookupAndUpdate(context.Background(), query) getTime = time.Now require.Nil(t, err) require.Equal(t, user.Login, login) + otherLoginUser := "aloginuser" // Should throw a "user not found" error since there's no matching login user getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) } - query = &models.GetUserByAuthInfoQuery{Login: "aloginuser", AuthModule: genericOAuthModule, AuthId: ""} + query = &models.GetUserByAuthInfoQuery{AuthModule: genericOAuthModule, AuthId: "", UserLookupParams: models.UserLookupParams{ + Login: &otherLoginUser, + }} user, err = srv.LookupAndUpdate(context.Background(), query) getTime = time.Now diff --git a/pkg/services/login/loginservice/loginservice.go b/pkg/services/login/loginservice/loginservice.go index 6c088feb1399240..789a081f378b07b 100644 --- a/pkg/services/login/loginservice/loginservice.go +++ b/pkg/services/login/loginservice/loginservice.go @@ -45,11 +45,9 @@ func (ls *Implementation) UpsertUser(ctx context.Context, cmd *models.UpsertUser extUser := cmd.ExternalUser user, err := ls.AuthInfoService.LookupAndUpdate(ctx, &models.GetUserByAuthInfoQuery{ - AuthModule: extUser.AuthModule, - AuthId: extUser.AuthId, - UserId: extUser.UserId, - Email: extUser.Email, - Login: extUser.Login, + AuthModule: extUser.AuthModule, + AuthId: extUser.AuthId, + UserLookupParams: cmd.UserLookupParams, }) if err != nil { if !errors.Is(err, models.ErrUserNotFound) { diff --git a/pkg/services/login/loginservice/loginservice_test.go b/pkg/services/login/loginservice/loginservice_test.go index b4aee777aaf913d..4803918752beb5e 100644 --- a/pkg/services/login/loginservice/loginservice_test.go +++ b/pkg/services/login/loginservice/loginservice_test.go @@ -93,10 +93,12 @@ func Test_teamSync(t *testing.T) { AuthInfoService: authInfoMock, } - upserCmd := &models.UpsertUserCommand{ExternalUser: &models.ExternalUserInfo{Email: "test_user@example.org"}} + email := "test_user@example.org" + upserCmd := &models.UpsertUserCommand{ExternalUser: &models.ExternalUserInfo{Email: email}, + UserLookupParams: models.UserLookupParams{Email: &email}} expectedUser := &models.User{ Id: 1, - Email: "test_user@example.org", + Email: email, Name: "test_user", Login: "test_user", }