diff --git a/auth/import_users.go b/auth/import_users.go index 0febf31f..198d2613 100644 --- a/auth/import_users.go +++ b/auth/import_users.go @@ -236,11 +236,8 @@ func (u *UserToImport) validatedUserInfo() (map[string]interface{}, error) { if providers, ok := info["providerUserInfo"]; ok { for _, p := range providers.([]*UserProvider) { - if p.UID == "" { - return nil, fmt.Errorf("user provdier must specify a uid") - } - if p.ProviderID == "" { - return nil, fmt.Errorf("user provider must specify a provider ID") + if err := validateProviderUserInfo(p); err != nil { + return nil, err } } } diff --git a/auth/user_mgt.go b/auth/user_mgt.go index 7bb54426..e4ada620 100644 --- a/auth/user_mgt.go +++ b/auth/user_mgt.go @@ -30,8 +30,9 @@ import ( ) const ( - maxLenPayloadCC = 1000 - defaultProviderID = "firebase" + maxLenPayloadCC = 1000 + defaultProviderID = "firebase" + idToolkitV1Endpoint = "https://identitytoolkit.googleapis.com/v1" // Maximum number of users allowed to batch get at a time. maxGetAccountsBatchSize = 100 @@ -217,6 +218,34 @@ func (u *UserToUpdate) PhotoURL(url string) *UserToUpdate { return u.set("photoUrl", url) } +// ProviderToLink links this user to the specified provider. +// +// Linking a provider to an existing user account does not invalidate the +// refresh token of that account. In other words, the existing account would +// continue to be able to access resources, despite not having used the newly +// linked provider to log in. If you wish to force the user to authenticate +// with this new provider, you need to (a) revoke their refresh token (see +// https://firebase.google.com/docs/auth/admin/manage-sessions#revoke_refresh_tokens), +// and (b) ensure no other authentication methods are present on this account. +func (u *UserToUpdate) ProviderToLink(userProvider *UserProvider) *UserToUpdate { + return u.set("linkProviderUserInfo", userProvider) +} + +// ProvidersToDelete unlinks this user from the specified providers. +func (u *UserToUpdate) ProvidersToDelete(providerIds []string) *UserToUpdate { + // skip setting the value to empty if it's already empty. + if len(providerIds) == 0 { + if u.params == nil { + return u + } + if _, ok := u.params["providersToDelete"]; !ok { + return u + } + } + + return u.set("providersToDelete", providerIds) +} + // revokeRefreshTokens revokes all refresh tokens for a user by setting the validSince property // to the present in epoch seconds. func (u *UserToUpdate) revokeRefreshTokens() *UserToUpdate { @@ -296,6 +325,78 @@ func (u *UserToUpdate) validatedRequest() (map[string]interface{}, error) { return nil, err } } + + if linkProviderUserInfo, ok := req["linkProviderUserInfo"]; ok { + userProvider := linkProviderUserInfo.(*UserProvider) + if err := validateProviderUserInfo(userProvider); err != nil { + return nil, err + } + + // Although we don't really advertise it, we want to also handle linking of + // non-federated idps with this call. So if we detect one of them, we'll + // adjust the properties parameter appropriately. This *does* imply that a + // conflict could arise, e.g. if the user provides a phoneNumber property, + // but also provides a providerToLink with a 'phone' provider id. In that + // case, we'll return an error. + + if userProvider.ProviderID == "email" { + if _, ok := req["email"]; ok { + // We could relax this to only return an error if the email addrs don't + // match. But for now, we'll be extra picky. + return nil, errors.New( + "both UserToUpdate.Email and UserToUpdate.ProviderToLink.ProviderID='email' " + + "were set; to link to the email/password provider, only specify the " + + "UserToUpdate.Email field") + } + req["email"] = userProvider.UID + delete(req, "linkProviderUserInfo") + } else if userProvider.ProviderID == "phone" { + if _, ok := req["phoneNumber"]; ok { + // We could relax this to only return an error if the phone numbers don't + // match. But for now, we'll be extra picky. + return nil, errors.New( + "both UserToUpdate.PhoneNumber and UserToUpdate.ProviderToLink.ProviderID='phone' " + + "were set; to link to the phone provider, only specify the " + + "UserToUpdate.PhoneNumber field") + } + req["phoneNumber"] = userProvider.UID + delete(req, "linkProviderUserInfo") + } + } + + if providersToDelete, ok := req["providersToDelete"]; ok { + var deleteProvider []string + list, ok := req["deleteProvider"] + if ok { + deleteProvider = list.([]string) + } + + for _, providerToDelete := range providersToDelete.([]string) { + if providerToDelete == "" { + return nil, errors.New("providersToDelete must not include empty strings") + } + + // If we've been told to unlink the phone provider both via setting + // phoneNumber to "" *and* by setting providersToDelete to include + // 'phone', then we'll reject that. Though it might also be reasonable to + // relax this restriction and just unlink it. + if providerToDelete == "phone" { + for _, prov := range deleteProvider { + if prov == "phone" { + return nil, errors.New("both UserToUpdate.PhoneNumber='' and " + + "UserToUpdate.ProvidersToDelete=['phone'] were set; to unlink from a " + + "phone provider, only specify the UserToUpdate.PhoneNumber='' field") + } + } + } + + deleteProvider = append(deleteProvider, providerToDelete) + } + + req["deleteProvider"] = deleteProvider + delete(req, "providersToDelete") + } + return req, nil } @@ -455,6 +556,16 @@ func validatePhone(phone string) error { return nil } +func validateProviderUserInfo(p *UserProvider) error { + if p.UID == "" { + return fmt.Errorf("user provider must specify a uid") + } + if p.ProviderID == "" { + return fmt.Errorf("user provider must specify a provider ID") + } + return nil +} + func validateProvider(providerID string, providerUID string) error { if providerID == "" { return fmt.Errorf("providerID must be a non-empty string") @@ -498,6 +609,47 @@ func (c *baseClient) GetUserByPhoneNumber(ctx context.Context, phone string) (*U }) } +// GetUserByProviderID gets the user data for the user corresponding to a given provider ID. +// +// See +// [Retrieve user data](https://firebase.google.com/docs/auth/admin/manage-users#retrieve_user_data) +// for code samples and detailed documentation. +// +// `providerID` indicates the provider, such as 'google.com' for the Google provider. +// `providerUID` is the user identifier for the given provider. +func (c *baseClient) GetUserByProviderID(ctx context.Context, providerID string, providerUID string) (*UserRecord, error) { + // Although we don't really advertise it, we want to also handle non-federated + // IDPs with this call. So if we detect one of them, we'll reroute this + // request appropriately. + if providerID == "phone" { + return c.GetUserByPhoneNumber(ctx, providerUID) + } else if providerID == "email" { + return c.GetUserByEmail(ctx, providerUID) + } + + if err := validateProvider(providerID, providerUID); err != nil { + return nil, err + } + + getUsersResult, err := c.GetUsers(ctx, []UserIdentifier{&ProviderIdentifier{providerID, providerUID}}) + if err != nil { + return nil, err + } + + if len(getUsersResult.Users) == 0 { + return nil, &internal.FirebaseError{ + ErrorCode: internal.NotFound, + String: fmt.Sprintf("cannot find user from providerID: { %s, %s }", providerID, providerUID), + Response: nil, + Ext: map[string]interface{}{ + authErrorCode: userNotFound, + }, + } + } + + return getUsersResult.Users[0], nil +} + type userQuery struct { field string value string diff --git a/auth/user_mgt_test.go b/auth/user_mgt_test.go index e8393a7e..32b2e726 100644 --- a/auth/user_mgt_test.go +++ b/auth/user_mgt_test.go @@ -141,22 +141,96 @@ func TestGetUserByPhoneNumber(t *testing.T) { } } +func TestGetUserByProviderIDNotFound(t *testing.T) { + mockUsers := []byte(`{ "users": [] }`) + s := echoServer(mockUsers, t) + defer s.Close() + + userRecord, err := s.Client.GetUserByProviderID(context.Background(), "google.com", "google_uid1") + want := "cannot find user from providerID: { google.com, google_uid1 }" + if userRecord != nil || err == nil || err.Error() != want || !IsUserNotFound(err) { + t.Errorf("GetUserByProviderID() = (%v, %q); want = (nil, %q)", userRecord, err, want) + } +} + +func TestGetUserByProviderId(t *testing.T) { + cases := []struct { + providerID string + providerUID string + want string + }{ + { + "google.com", + "google_uid1", + `{"federatedUserId":[{"providerId":"google.com","rawId":"google_uid1"}]}`, + }, { + "phone", + "+15555550001", + `{"phoneNumber":["+15555550001"]}`, + }, { + "email", + "user@example.com", + `{"email":["user@example.com"]}`, + }, + } + + // The resulting user isn't parsed, so it just needs to exist (even if it's empty). + mockUsers := []byte(`{ "users": [{}] }`) + s := echoServer(mockUsers, t) + defer s.Close() + + for _, tc := range cases { + t.Run(tc.providerID+":"+tc.providerUID, func(t *testing.T) { + + _, err := s.Client.GetUserByProviderID(context.Background(), tc.providerID, tc.providerUID) + if err != nil { + t.Fatalf("GetUserByProviderID() = %q", err) + } + + got := string(s.Rbody) + if got != tc.want { + t.Errorf("GetUserByProviderID() Req = %v; want = %v", got, tc.want) + } + + wantPath := "/projects/mock-project-id/accounts:lookup" + if s.Req[0].RequestURI != wantPath { + t.Errorf("GetUserByProviderID() URL = %q; want = %q", s.Req[0].RequestURI, wantPath) + } + }) + } +} + func TestInvalidGetUser(t *testing.T) { client := &Client{ baseClient: &baseClient{}, } + user, err := client.GetUser(context.Background(), "") if user != nil || err == nil { t.Errorf("GetUser('') = (%v, %v); want = (nil, error)", user, err) } + user, err = client.GetUserByEmail(context.Background(), "") if user != nil || err == nil { t.Errorf("GetUserByEmail('') = (%v, %v); want = (nil, error)", user, err) } + user, err = client.GetUserByPhoneNumber(context.Background(), "") if user != nil || err == nil { t.Errorf("GetUserPhoneNumber('') = (%v, %v); want = (nil, error)", user, err) } + + userRecord, err := client.GetUserByProviderID(context.Background(), "", "google_uid1") + want := "providerID must be a non-empty string" + if userRecord != nil || err == nil || err.Error() != want { + t.Errorf("GetUserByProviderID() = (%v, %q); want = (nil, %q)", userRecord, err, want) + } + + userRecord, err = client.GetUserByProviderID(context.Background(), "google.com", "") + want = "providerUID must be a non-empty string" + if userRecord != nil || err == nil || err.Error() != want { + t.Errorf("GetUserByProviderID() = (%v, %q); want = (nil, %q)", userRecord, err, want) + } } // Checks to see if the users list contain the given uids. Order is ignored. @@ -654,6 +728,48 @@ func TestInvalidUpdateUser(t *testing.T) { }, { (&UserToUpdate{}).Password("short"), "password must be a string at least 6 characters long", + }, { + (&UserToUpdate{}).ProviderToLink(&UserProvider{UID: "google_uid"}), + "user provider must specify a provider ID", + }, { + (&UserToUpdate{}).ProviderToLink(&UserProvider{ProviderID: "google.com"}), + "user provider must specify a uid", + }, { + (&UserToUpdate{}).ProviderToLink(&UserProvider{ProviderID: "google.com", UID: ""}), + "user provider must specify a uid", + }, { + (&UserToUpdate{}).ProviderToLink(&UserProvider{ProviderID: "", UID: "google_uid"}), + "user provider must specify a provider ID", + }, { + (&UserToUpdate{}).ProvidersToDelete([]string{""}), + "providersToDelete must not include empty strings", + }, { + (&UserToUpdate{}). + Email("user@example.com"). + ProviderToLink(&UserProvider{ + ProviderID: "email", + UID: "user@example.com", + }), + "both UserToUpdate.Email and UserToUpdate.ProviderToLink.ProviderID='email' " + + "were set; to link to the email/password provider, only specify the " + + "UserToUpdate.Email field", + }, { + (&UserToUpdate{}). + PhoneNumber("+15555550001"). + ProviderToLink(&UserProvider{ + ProviderID: "phone", + UID: "+15555550001", + }), + "both UserToUpdate.PhoneNumber and UserToUpdate.ProviderToLink.ProviderID='phone' " + + "were set; to link to the phone provider, only specify the " + + "UserToUpdate.PhoneNumber field", + }, { + (&UserToUpdate{}). + PhoneNumber(""). + ProvidersToDelete([]string{"phone"}), + "both UserToUpdate.PhoneNumber='' and " + + "UserToUpdate.ProvidersToDelete=['phone'] were set; to unlink from a " + + "phone provider, only specify the UserToUpdate.PhoneNumber='' field", }, } @@ -752,6 +868,43 @@ var updateUserCases = []struct { "deleteProvider": []string{"phone"}, }, }, + { + (&UserToUpdate{}).ProviderToLink(&UserProvider{ + ProviderID: "google.com", + UID: "google_uid", + }), + map[string]interface{}{ + "linkProviderUserInfo": &UserProvider{ + ProviderID: "google.com", + UID: "google_uid", + }}, + }, + { + (&UserToUpdate{}).PhoneNumber("").ProvidersToDelete([]string{"google.com"}), + map[string]interface{}{ + "deleteProvider": []string{"phone", "google.com"}, + }, + }, + { + (&UserToUpdate{}).ProvidersToDelete([]string{"email", "phone"}), + map[string]interface{}{ + "deleteProvider": []string{"email", "phone"}, + }, + }, + { + (&UserToUpdate{}).ProviderToLink(&UserProvider{ + ProviderID: "email", + UID: "user@example.com", + }), + map[string]interface{}{"email": "user@example.com"}, + }, + { + (&UserToUpdate{}).ProviderToLink(&UserProvider{ + ProviderID: "phone", + UID: "+15555550001", + }), + map[string]interface{}{"phoneNumber": "+15555550001"}, + }, { (&UserToUpdate{}).CustomClaims(map[string]interface{}{"a": strings.Repeat("a", 992)}), map[string]interface{}{"customAttributes": fmt.Sprintf(`{"a":%q}`, strings.Repeat("a", 992))}, @@ -1115,7 +1268,7 @@ func TestUserToImportError(t *testing.T) { ProviderID: "google.com", }, }), - "user provdier must specify a uid", + "user provider must specify a uid", }, } diff --git a/firebase.go b/firebase.go index 40f09a0b..68e0fdc7 100644 --- a/firebase.go +++ b/firebase.go @@ -38,7 +38,7 @@ import ( var defaultAuthOverrides = make(map[string]interface{}) // Version of the Firebase Go Admin SDK. -const Version = "4.2.0" +const Version = "4.3.0" // firebaseEnvName is the name of the environment variable with the Config. const firebaseEnvName = "FIREBASE_CONFIG" diff --git a/integration/auth/user_mgt_test.go b/integration/auth/user_mgt_test.go index c80d5e30..13a27ddf 100644 --- a/integration/auth/user_mgt_test.go +++ b/integration/auth/user_mgt_test.go @@ -80,6 +80,34 @@ func TestGetUser(t *testing.T) { } } +func TestGetUserByProviderID(t *testing.T) { + // TODO(rsgowman): Once we can link a provider id with a user, just do that + // here instead of importing a new user. + importUserUID := randomUID() + providerUID := "google_" + importUserUID + userToImport := (&auth.UserToImport{}). + UID(importUserUID). + Email(randomEmail(importUserUID)). + PhoneNumber(randomPhoneNumber()). + ProviderData([](*auth.UserProvider){ + &auth.UserProvider{ + ProviderID: "google.com", + UID: providerUID, + }, + }) + importUser(t, importUserUID, userToImport) + defer deleteUser(importUserUID) + + userRecord, err := client.GetUserByProviderID(context.Background(), "google.com", providerUID) + if err != nil { + t.Fatalf("GetUserByProviderID() = %q", err) + } + + if userRecord.UID != importUserUID { + t.Errorf("GetUserByProviderID().UID = %v; want = %v", userRecord.UID, importUserUID) + } +} + func TestGetNonExistingUser(t *testing.T) { user, err := client.GetUser(context.Background(), "non.existing") if user != nil || !auth.IsUserNotFound(err) { @@ -90,6 +118,16 @@ func TestGetNonExistingUser(t *testing.T) { if user != nil || !auth.IsUserNotFound(err) { t.Errorf("GetUserByEmail(non.existing) = (%v, %v); want = (nil, error)", user, err) } + + user, err = client.GetUserByPhoneNumber(context.Background(), "+14044040404") + if user != nil || !auth.IsUserNotFound(err) { + t.Errorf("GetUser(non.existing) = (%v, %v); want = (nil, error)", user, err) + } + + user, err = client.GetUserByProviderID(context.Background(), "google.com", "a-uid-that-doesnt-exist") + if user != nil || !auth.IsUserNotFound(err) { + t.Errorf("GetUser(non.existing) = (%v, %v); want = (nil, error)", user, err) + } } func TestGetUsers(t *testing.T) { @@ -386,37 +424,239 @@ func TestCreateUser(t *testing.T) { } func TestUpdateUser(t *testing.T) { - user := newUserWithParams(t) - defer deleteUser(user.UID) + // Creates a new user for testing purposes. The user's uid will be + // '$name_$tenRandomChars' and email will be + // '$name_$tenRandomChars@example.com'. + createTestUser := func(name string) *auth.UserRecord { + // TODO(rsgowman: This function could usefully be employed throughout + // this file. + tenRandomChars := generateRandomAlphaNumericString(10) + userRecord, err := client.CreateUser(context.Background(), + (&auth.UserToCreate{}). + UID(name+"_"+tenRandomChars). + DisplayName(name). + Email(name+"_"+tenRandomChars+"@example.com"), + ) + if err != nil { + t.Fatal(err) + } + return userRecord + } - uid := randomUID() - newEmail := randomEmail(uid) - newPhone := randomPhoneNumber() - want := auth.UserInfo{ - UID: user.UID, - Email: newEmail, - PhoneNumber: newPhone, - DisplayName: "Updated Name", - ProviderID: "firebase", - PhotoURL: "https://example.com/updated.png", - } - params := (&auth.UserToUpdate{}). - Email(newEmail). - PhoneNumber(newPhone). - DisplayName("Updated Name"). - PhotoURL("https://example.com/updated.png"). - EmailVerified(true). - Password("newpassowrd") - got, err := client.UpdateUser(context.Background(), user.UID, params) - if err != nil { - t.Fatal(err) + mapToProviderUIDs := func(userInfos [](*auth.UserInfo)) []string { + providerUIDs := []string{} + for i := range userInfos { + providerUIDs = append(providerUIDs, userInfos[i].UID) + } + return providerUIDs } - if !reflect.DeepEqual(*got.UserInfo, want) { - t.Errorf("UpdateUser().UserInfo = (%#v, %v); want = (%#v, nil)", *got.UserInfo, err, want) + + mapToProviderIDs := func(userInfos [](*auth.UserInfo)) []string { + providerIDs := []string{} + for i := range userInfos { + providerIDs = append(providerIDs, userInfos[i].ProviderID) + } + return providerIDs + } + + contains := func(list []string, target string) bool { + for i := range list { + if list[i] == target { + return true + } + } + return false + } + + containsAll := func(list []string, targets []string) bool { + for i := range targets { + if !contains(list, targets[i]) { + return false + } + } + return true } - if !got.EmailVerified { - t.Error("UpdateUser().EmailVerified = false; want = true") + + containsNone := func(list []string, targets []string) bool { + for i := range targets { + if contains(list, targets[i]) { + return false + } + } + return true } + + updateUser := createTestUser("UpdateUser") + defer deleteUser(updateUser.UID) + + t.Run("SimpleUpdate", func(t *testing.T) { + uid := randomUID() + newEmail := randomEmail(uid) + newPhone := randomPhoneNumber() + want := auth.UserInfo{ + UID: updateUser.UID, + Email: newEmail, + PhoneNumber: newPhone, + DisplayName: "Updated Name", + ProviderID: "firebase", + PhotoURL: "https://example.com/updated.png", + } + params := (&auth.UserToUpdate{}). + Email(newEmail). + PhoneNumber(newPhone). + DisplayName("Updated Name"). + PhotoURL("https://example.com/updated.png"). + EmailVerified(true). + Password("newpassowrd") + got, err := client.UpdateUser(context.Background(), updateUser.UID, params) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(*got.UserInfo, want) { + t.Errorf("UpdateUser().UserInfo = (%#v, %v); want = (%#v, nil)", *got.UserInfo, err, want) + } + if !got.EmailVerified { + t.Error("UpdateUser().EmailVerified = false; want = true") + } + }) + + t.Run("LinkFederatedProvider", func(t *testing.T) { + // Link user to federated provider + googleFederatedUID := "google_uid_" + generateRandomAlphaNumericString(10) + params := (&auth.UserToUpdate{}). + ProviderToLink((&auth.UserProvider{ + ProviderID: "google.com", + UID: googleFederatedUID, + })) + userRecord, err := client.UpdateUser(context.Background(), updateUser.UID, params) + if err != nil { + t.Fatal(err) + } + defer func() { + // Unlink user from federated provider + params = (&auth.UserToUpdate{}).ProvidersToDelete([]string{"google.com"}) + userRecord, err = client.UpdateUser(context.Background(), updateUser.UID, params) + if err != nil { + t.Fatal(err) + } + }() + + // Ensure link operation worked as expected + providerUIDs := mapToProviderUIDs(userRecord.ProviderUserInfo) + providerIDs := mapToProviderIDs(userRecord.ProviderUserInfo) + if !contains(providerUIDs, googleFederatedUID) { + t.Errorf("UpdateUser().ProviderUserInfo[*].UID = %v; want include %q", + providerUIDs, googleFederatedUID) + } + if !contains(providerIDs, "google.com") { + t.Errorf("UpdateUser().ProviderUserInfo[*].ProviderID = %v; want include 'google.com'", + providerIDs) + } + }) + + t.Run("UnlinkFederatedProvider", func(t *testing.T) { + // Link user to federated provider + googleFederatedUID := "google_uid_" + generateRandomAlphaNumericString(10) + params := (&auth.UserToUpdate{}). + ProviderToLink((&auth.UserProvider{ + ProviderID: "google.com", + UID: googleFederatedUID, + })) + userRecord, err := client.UpdateUser(context.Background(), updateUser.UID, params) + if err != nil { + t.Fatal(err) + } + + // Unlink user from federated provider + params = (&auth.UserToUpdate{}).ProvidersToDelete([]string{"google.com"}) + userRecord, err = client.UpdateUser(context.Background(), updateUser.UID, params) + if err != nil { + t.Fatal(err) + } + + // Ensure unlink operation worked as expected + providerUIDs := mapToProviderUIDs(userRecord.ProviderUserInfo) + providerIDs := mapToProviderIDs(userRecord.ProviderUserInfo) + if contains(providerUIDs, googleFederatedUID) { + t.Errorf("UpdateUser().ProviderUserInfo[*].UID = %v; want NOT include %q", + providerUIDs, googleFederatedUID) + } + if contains(providerIDs, "google.com") { + t.Errorf("UpdateUser().ProviderUserInfo[*].ProviderID = %v; want NOT include 'google.com'", + providerIDs) + } + }) + + t.Run("UnlinkMultipleProvidersAtOnce", func(t *testing.T) { + deletePhoneNumberUser(t, "+15555550001") + + googleFederatedUID := "google_uid_" + generateRandomAlphaNumericString(10) + facebookFederatedUID := "facebook_uid_" + generateRandomAlphaNumericString(10) + + userRecord, err := client.UpdateUser(context.Background(), updateUser.UID, + (&auth.UserToUpdate{}). + PhoneNumber("+15555550001"). + ProviderToLink((&auth.UserProvider{ + ProviderID: "google.com", + UID: googleFederatedUID, + }))) + if err != nil { + t.Fatal(err) + } + userRecord, err = client.UpdateUser(context.Background(), updateUser.UID, + (&auth.UserToUpdate{}). + ProviderToLink((&auth.UserProvider{ + ProviderID: "facebook.com", + UID: facebookFederatedUID, + }))) + if err != nil { + t.Fatal(err) + } + + providerUIDs := mapToProviderUIDs(userRecord.ProviderUserInfo) + providerIDs := mapToProviderIDs(userRecord.ProviderUserInfo) + wantAll := []string{googleFederatedUID, facebookFederatedUID, "+15555550001"} + if !containsAll(providerUIDs, wantAll) { + t.Errorf("UpdateUser().ProviderUserInfo[*].UID want include all %v; got %v", + wantAll, providerUIDs) + } + wantAll = []string{"google.com", "facebook.com", "phone"} + if !containsAll(providerIDs, wantAll) { + t.Errorf("UpdateUser().ProviderUserInfo[*].ProviderID want include all %v; got %v", + wantAll, providerIDs) + } + + userRecord, err = client.UpdateUser(context.Background(), updateUser.UID, + (&auth.UserToUpdate{}). + ProvidersToDelete([]string{"google.com", "facebook.com", "phone"})) + if err != nil { + t.Fatal(err) + } + + providerUIDs = mapToProviderUIDs(userRecord.ProviderUserInfo) + providerIDs = mapToProviderIDs(userRecord.ProviderUserInfo) + notWantAll := []string{googleFederatedUID, facebookFederatedUID, "+15555550001"} + if !containsNone(providerUIDs, notWantAll) { + t.Errorf("UpdateUser().ProviderUserInfo[*].UID want not include all %v; got %v", + notWantAll, providerUIDs) + } + notWantAll = []string{"google.com", "facebook.com", "phone"} + if !containsNone(providerIDs, notWantAll) { + t.Errorf("UpdateUser().ProviderUserInfo[*].ProviderID want not include all %v; got %v", + notWantAll, providerIDs) + } + }) + + t.Run("ErrorsGivenEmptyProvidersToDelete", func(t *testing.T) { + userRecord := createTestUser("ErrorWithEmptyProvidersToDeleteUser") + defer deleteUser(userRecord.UID) + + gotUserRecord, err := client.UpdateUser(context.Background(), userRecord.UID, + (&auth.UserToUpdate{}).ProvidersToDelete([]string{})) + if err == nil || gotUserRecord != nil { + t.Errorf("UpdateUser() = (%#v, nil); want (nil, error)", gotUserRecord) + } + }) } func TestDisableUser(t *testing.T) { @@ -922,27 +1162,35 @@ func signInWithEmailLink(email, oobCode string) (string, error) { var seededRand = rand.New(rand.NewSource(time.Now().UnixNano())) func randomUID() string { - var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") - b := make([]rune, 32) - for i := range b { - b[i] = letters[seededRand.Intn(len(letters))] - } - return string(b) + return generateRandomAlphaNumericString(32) } func randomPhoneNumber() string { - var digits = []rune("0123456789") - b := make([]rune, 10) - for i := range b { - b[i] = digits[rand.Intn(len(digits))] - } - return "+1" + string(b) + return "+1" + generateRandomNumericString(10) } func randomEmail(uid string) string { return strings.ToLower(fmt.Sprintf("%s@example.%s.com", uid[:12], uid[12:])) } +func generateRandomNumericString(length int) string { + digits := []rune("0123456789") + return generateRandomString(length, digits) +} + +func generateRandomAlphaNumericString(length int) string { + letters := []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + return generateRandomString(length, letters) +} + +func generateRandomString(length int, runes []rune) string { + b := make([]rune, length) + for i := range b { + b[i] = runes[seededRand.Intn(len(runes))] + } + return string(b) +} + func newUserWithParams(t *testing.T) *auth.UserRecord { uid := randomUID() email := randomEmail(uid) @@ -983,3 +1231,22 @@ func importUser(t *testing.T, uid string, userToImport *auth.UserToImport) *auth } return userRecord } + +// Helper function that deletes the user with the specified phone number if it +// exists. +// TODO(rsgowman): This function was ported from node.js port; a number of tests +// there use this, but haven't been ported to go yet. Do so. +func deletePhoneNumberUser(t *testing.T, phoneNumber string) { + userRecord, err := client.GetUserByPhoneNumber(context.Background(), phoneNumber) + if err != nil { + if auth.IsUserNotFound(err) { + // User already doesn't exist. + return + } + t.Fatal(err) + } + + if err = client.DeleteUser(context.Background(), userRecord.UID); err != nil { + t.Fatal(err) + } +}