diff --git a/auth/user_mgt.go b/auth/user_mgt.go index e1a18a0f..e601d8bc 100644 --- a/auth/user_mgt.go +++ b/auth/user_mgt.go @@ -34,6 +34,12 @@ const ( maxLenPayloadCC = 1000 defaultProviderID = "firebase" idToolkitV1Endpoint = "https://identitytoolkit.googleapis.com/v1" + + // Maximum number of users allowed to batch get at a time. + maxGetAccountsBatchSize = 100 + + // Maximum number of users allowed to batch delete at a time. + maxDeleteAccountsBatchSize = 1000 ) // 'REDACTED', encoded as a base64 string. @@ -57,6 +63,9 @@ type UserInfo struct { type UserMetadata struct { CreationTimestamp int64 LastLogInTimestamp int64 + // The time at which the user was last active (ID token refreshed), or 0 if + // the user was never active. + LastRefreshTimestamp int64 } // UserRecord contains metadata associated with a Firebase user account. @@ -491,6 +500,15 @@ func validatePhone(phone string) error { return nil } +func validateProvider(providerID string, providerUID string) error { + if providerID == "" { + return fmt.Errorf("providerID must be a non-empty string") + } else if providerUID == "" { + return fmt.Errorf("providerUID must be a non-empty string") + } + return nil +} + // End of validators // GetUser gets the user data corresponding to the specified user ID. @@ -545,12 +563,13 @@ func (q *userQuery) build() map[string]interface{} { } } +type getAccountInfoResponse struct { + Users []*userQueryResponse `json:"users"` +} + func (c *baseClient) getUser(ctx context.Context, query *userQuery) (*UserRecord, error) { - var parsed struct { - Users []*userQueryResponse `json:"users"` - } - _, err := c.post(ctx, "/accounts:lookup", query.build(), &parsed) - if err != nil { + var parsed getAccountInfoResponse + if _, err := c.post(ctx, "/accounts:lookup", query.build(), &parsed); err != nil { return nil, err } @@ -561,6 +580,195 @@ func (c *baseClient) getUser(ctx context.Context, query *userQuery) (*UserRecord return parsed.Users[0].makeUserRecord() } +// A UserIdentifier identifies a user to be looked up. +type UserIdentifier interface { + matches(ur *UserRecord) bool + populate(req *getAccountInfoRequest) +} + +// A UIDIdentifier is used for looking up an account by uid. +// +// See GetUsers function. +type UIDIdentifier struct { + UID string +} + +func (id UIDIdentifier) matches(ur *UserRecord) bool { + return id.UID == ur.UID +} + +func (id UIDIdentifier) populate(req *getAccountInfoRequest) { + req.LocalID = append(req.LocalID, id.UID) +} + +// An EmailIdentifier is used for looking up an account by email. +// +// See GetUsers function. +type EmailIdentifier struct { + Email string +} + +func (id EmailIdentifier) matches(ur *UserRecord) bool { + return id.Email == ur.Email +} + +func (id EmailIdentifier) populate(req *getAccountInfoRequest) { + req.Email = append(req.Email, id.Email) +} + +// A PhoneIdentifier is used for looking up an account by phone number. +// +// See GetUsers function. +type PhoneIdentifier struct { + PhoneNumber string +} + +func (id PhoneIdentifier) matches(ur *UserRecord) bool { + return id.PhoneNumber == ur.PhoneNumber +} + +func (id PhoneIdentifier) populate(req *getAccountInfoRequest) { + req.PhoneNumber = append(req.PhoneNumber, id.PhoneNumber) +} + +// A ProviderIdentifier is used for looking up an account by federated provider. +// +// See GetUsers function. +type ProviderIdentifier struct { + ProviderID string + ProviderUID string +} + +func (id ProviderIdentifier) matches(ur *UserRecord) bool { + for _, userInfo := range ur.ProviderUserInfo { + if id.ProviderID == userInfo.ProviderID && id.ProviderUID == userInfo.UID { + return true + } + } + return false +} + +func (id ProviderIdentifier) populate(req *getAccountInfoRequest) { + req.FederatedUserID = append( + req.FederatedUserID, + federatedUserIdentifier{ProviderID: id.ProviderID, RawID: id.ProviderUID}) +} + +// A GetUsersResult represents the result of the GetUsers() API. +type GetUsersResult struct { + // Set of UserRecords corresponding to the set of users that were requested. + // Only users that were found are listed here. The result set is unordered. + Users []*UserRecord + + // Set of UserIdentifiers that were requested, but not found. + NotFound []UserIdentifier +} + +type federatedUserIdentifier struct { + ProviderID string `json:"providerId,omitempty"` + RawID string `json:"rawId,omitempty"` +} + +type getAccountInfoRequest struct { + LocalID []string `json:"localId,omitempty"` + Email []string `json:"email,omitempty"` + PhoneNumber []string `json:"phoneNumber,omitempty"` + FederatedUserID []federatedUserIdentifier `json:"federatedUserId,omitempty"` +} + +func (req *getAccountInfoRequest) validate() error { + for i := range req.LocalID { + if err := validateUID(req.LocalID[i]); err != nil { + return err + } + } + + for i := range req.Email { + if err := validateEmail(req.Email[i]); err != nil { + return err + } + } + + for i := range req.PhoneNumber { + if err := validatePhone(req.PhoneNumber[i]); err != nil { + return err + } + } + + for i := range req.FederatedUserID { + id := &req.FederatedUserID[i] + if err := validateProvider(id.ProviderID, id.RawID); err != nil { + return err + } + } + + return nil +} + +func isUserFound(id UserIdentifier, urs [](*UserRecord)) bool { + for i := range urs { + if id.matches(urs[i]) { + return true + } + } + return false +} + +// GetUsers returns the user data corresponding to the specified identifiers. +// +// There are no ordering guarantees; in particular, the nth entry in the users +// result list is not guaranteed to correspond to the nth entry in the input +// parameters list. +// +// A maximum of 100 identifiers may be supplied. If more than 100 +// identifiers are supplied, this method returns an error. +// +// Returns the corresponding user records. An error is returned instead if any +// of the identifiers are invalid or if more than 100 identifiers are +// specified. +func (c *baseClient) GetUsers( + ctx context.Context, identifiers []UserIdentifier, +) (*GetUsersResult, error) { + if len(identifiers) == 0 { + return &GetUsersResult{[](*UserRecord){}, [](UserIdentifier){}}, nil + } else if len(identifiers) > maxGetAccountsBatchSize { + return nil, fmt.Errorf( + "`identifiers` parameter must have <= %d entries", maxGetAccountsBatchSize) + } + + var request getAccountInfoRequest + for i := range identifiers { + identifiers[i].populate(&request) + } + + if err := request.validate(); err != nil { + return nil, err + } + + var parsed getAccountInfoResponse + if _, err := c.post(ctx, "/accounts:lookup", request, &parsed); err != nil { + return nil, err + } + + var userRecords [](*UserRecord) + for _, user := range parsed.Users { + userRecord, err := user.makeUserRecord() + if err != nil { + return nil, err + } + userRecords = append(userRecords, userRecord) + } + + var notFound []UserIdentifier + for i := range identifiers { + if !isUserFound(identifiers[i], userRecords) { + notFound = append(notFound, identifiers[i]) + } + } + + return &GetUsersResult{userRecords, notFound}, nil +} + type userQueryResponse struct { UID string `json:"localId,omitempty"` DisplayName string `json:"displayName,omitempty"` @@ -569,6 +777,7 @@ type userQueryResponse struct { PhotoURL string `json:"photoUrl,omitempty"` CreationTimestamp int64 `json:"createdAt,string,omitempty"` LastLogInTimestamp int64 `json:"lastLoginAt,string,omitempty"` + LastRefreshAt string `json:"lastRefreshAt,omitempty"` ProviderID string `json:"providerId,omitempty"` CustomAttributes string `json:"customAttributes,omitempty"` Disabled bool `json:"disabled,omitempty"` @@ -592,8 +801,7 @@ func (r *userQueryResponse) makeUserRecord() (*UserRecord, error) { func (r *userQueryResponse) makeExportedUserRecord() (*ExportedUserRecord, error) { var customClaims map[string]interface{} if r.CustomAttributes != "" { - err := json.Unmarshal([]byte(r.CustomAttributes), &customClaims) - if err != nil { + if err := json.Unmarshal([]byte(r.CustomAttributes), &customClaims); err != nil { return nil, err } if len(customClaims) == 0 { @@ -609,6 +817,15 @@ func (r *userQueryResponse) makeExportedUserRecord() (*ExportedUserRecord, error hash = "" } + var lastRefreshTimestamp int64 + if r.LastRefreshAt != "" { + t, err := time.Parse(time.RFC3339, r.LastRefreshAt) + if err != nil { + return nil, err + } + lastRefreshTimestamp = t.Unix() * 1000 + } + return &ExportedUserRecord{ UserRecord: &UserRecord{ UserInfo: &UserInfo{ @@ -626,8 +843,9 @@ func (r *userQueryResponse) makeExportedUserRecord() (*ExportedUserRecord, error TenantID: r.TenantID, TokensValidAfterMillis: r.ValidSinceSeconds * 1000, UserMetadata: &UserMetadata{ - LastLogInTimestamp: r.LastLogInTimestamp, - CreationTimestamp: r.CreationTimestamp, + LastLogInTimestamp: r.LastLogInTimestamp, + CreationTimestamp: r.CreationTimestamp, + LastRefreshTimestamp: lastRefreshTimestamp, }, }, PasswordHash: hash, @@ -728,6 +946,91 @@ func (c *baseClient) DeleteUser(ctx context.Context, uid string) error { return err } +// A DeleteUsersResult represents the result of the DeleteUsers() call. +type DeleteUsersResult struct { + // The number of users that were deleted successfully (possibly zero). Users + // that did not exist prior to calling DeleteUsers() are considered to be + // successfully deleted. + SuccessCount int + + // The number of users that failed to be deleted (possibly zero). + FailureCount int + + // A list of DeleteUsersErrorInfo instances describing the errors that were + // encountered during the deletion. Length of this list is equal to the value + // of FailureCount. + Errors []*DeleteUsersErrorInfo +} + +// DeleteUsersErrorInfo represents an error encountered while deleting a user +// account. +// +// The Index field corresponds to the index of the failed user in the uids +// array that was passed to DeleteUsers(). +type DeleteUsersErrorInfo struct { + Index int `json:"index,omitEmpty"` + Reason string `json:"message,omitEmpty"` +} + +// DeleteUsers deletes the users specified by the given identifiers. +// +// Deleting a non-existing user won't generate an error. (i.e. this method is +// idempotent.) Non-existing users are considered to be successfully +// deleted, and are therefore counted in the DeleteUsersResult.SuccessCount +// value. +// +// A maximum of 1000 identifiers may be supplied. If more than 1000 +// identifiers are supplied, this method returns an error. +// +// This API is currently rate limited at the server to 1 QPS. If you exceed +// this, you may get a quota exceeded error. Therefore, if you want to delete +// more than 1000 users, you may need to add a delay to ensure you don't go +// over this limit. +// +// Returns the total number of successful/failed deletions, as well as the +// array of errors that correspond to the failed deletions. An error is +// returned if any of the identifiers are invalid or if more than 1000 +// identifiers are specified. +func (c *baseClient) DeleteUsers(ctx context.Context, uids []string) (*DeleteUsersResult, error) { + if len(uids) == 0 { + return &DeleteUsersResult{}, nil + } else if len(uids) > maxDeleteAccountsBatchSize { + return nil, fmt.Errorf( + "`uids` parameter must have <= %d entries", maxDeleteAccountsBatchSize) + } + + var payload struct { + LocalIds []string `json:"localIds"` + Force bool `json:"force"` + } + payload.Force = true + + for i := range uids { + if err := validateUID(uids[i]); err != nil { + return nil, err + } + + payload.LocalIds = append(payload.LocalIds, uids[i]) + } + + type batchDeleteAccountsResponse struct { + Errors []*DeleteUsersErrorInfo `json:"errors"` + } + + resp := batchDeleteAccountsResponse{} + if _, err := c.post(ctx, "/accounts:batchDelete", payload, &resp); err != nil { + return nil, err + } + + result := DeleteUsersResult{ + FailureCount: len(resp.Errors), + SuccessCount: len(uids) - len(resp.Errors), + Errors: resp.Errors, + } + + return &result, nil +} + // SessionCookie creates a new Firebase session cookie from the given ID token and expiry // duration. The returned JWT can be set as a server-side session cookie with a custom cookie // policy. Expiry duration must be at least 5 minutes but may not exceed 14 days. diff --git a/auth/user_mgt_test.go b/auth/user_mgt_test.go index b2591f1e..0ee5c678 100644 --- a/auth/user_mgt_test.go +++ b/auth/user_mgt_test.go @@ -24,6 +24,7 @@ import ( "net/http" "net/http/httptest" "reflect" + "sort" "strconv" "strings" "testing" @@ -157,6 +158,236 @@ func TestInvalidGetUser(t *testing.T) { } } +// Checks to see if the users list contain the given uids. Order is ignored. +// +// Behaviour is undefined if there are duplicate entries in either of the +// slices. +// +// This function is identical to the one in integration/auth/user_mgt_test.go +func sameUsers(users [](*UserRecord), uids []string) bool { + if len(users) != len(uids) { + return false + } + + sort.Slice(users, func(i, j int) bool { + return users[i].UID < users[j].UID + }) + sort.Slice(uids, func(i, j int) bool { + return uids[i] < uids[j] + }) + + for i := range users { + if users[i].UID != uids[i] { + return false + } + } + + return true +} + +func TestGetUsersExceeds100(t *testing.T) { + client := &Client{ + baseClient: &baseClient{}, + } + + var identifiers [101]UserIdentifier + for i := 0; i < 101; i++ { + identifiers[i] = &UIDIdentifier{UID: fmt.Sprintf("id%d", i)} + } + + getUsersResult, err := client.GetUsers(context.Background(), identifiers[:]) + want := "`identifiers` parameter must have <= 100 entries" + if getUsersResult != nil || err == nil || err.Error() != want { + t.Errorf( + "GetUsers() = (%v, %q); want = (nil, %q)", + getUsersResult, err, want) + } +} + +func TestGetUsersEmpty(t *testing.T) { + client := &Client{ + baseClient: &baseClient{}, + } + + getUsersResult, err := client.GetUsers(context.Background(), [](UserIdentifier){}) + if getUsersResult == nil || err != nil { + t.Fatalf("GetUsers([]) = %q", err) + } + + if len(getUsersResult.Users) != 0 { + t.Errorf("len(GetUsers([]).Users) = %d; want 0", len(getUsersResult.Users)) + } + if len(getUsersResult.NotFound) != 0 { + t.Errorf("len(GetUsers([]).NotFound) = %d; want 0", len(getUsersResult.NotFound)) + } +} + +func TestGetUsersAllNonExisting(t *testing.T) { + resp := `{ + "kind" : "identitytoolkit#GetAccountInfoResponse", + "users" : [] + }` + s := echoServer([]byte(resp), t) + defer s.Close() + + notFoundIds := []UserIdentifier{&UIDIdentifier{"id that doesnt exist"}} + getUsersResult, err := s.Client.GetUsers(context.Background(), notFoundIds) + if err != nil { + t.Fatalf("GetUsers() = %q", err) + } + + if len(getUsersResult.Users) != 0 { + t.Errorf( + "len(GetUsers().Users) = %d; want 0", + len(getUsersResult.Users)) + } + if len(getUsersResult.NotFound) != len(notFoundIds) { + t.Errorf("len(GetUsers()).NotFound) = %d; want %d", + len(getUsersResult.NotFound), len(notFoundIds)) + } else { + for i := range notFoundIds { + if getUsersResult.NotFound[i] != notFoundIds[i] { + t.Errorf("GetUsers().NotFound[%d] = %v; want %v", + i, getUsersResult.NotFound[i], notFoundIds[i]) + } + } + } +} + +func TestGetUsersInvalidUid(t *testing.T) { + client := &Client{ + baseClient: &baseClient{}, + } + + getUsersResult, err := client.GetUsers( + context.Background(), + []UserIdentifier{&UIDIdentifier{"too long " + strings.Repeat(".", 128)}}) + want := "uid string must not be longer than 128 characters" + if getUsersResult != nil || err == nil || err.Error() != want { + t.Errorf("GetUsers() = (%v, %q); want = (nil, %q)", getUsersResult, err, want) + } +} + +func TestGetUsersInvalidEmail(t *testing.T) { + client := &Client{ + baseClient: &baseClient{}, + } + + getUsersResult, err := client.GetUsers( + context.Background(), + []UserIdentifier{EmailIdentifier{"invalid email addr"}}) + want := `malformed email string: "invalid email addr"` + if getUsersResult != nil || err == nil || err.Error() != want { + t.Errorf("GetUsers() = (%v, %q); want = (nil, %q)", getUsersResult, err, want) + } +} + +func TestGetUsersInvalidPhoneNumber(t *testing.T) { + client := &Client{ + baseClient: &baseClient{}, + } + + getUsersResult, err := client.GetUsers(context.Background(), []UserIdentifier{ + PhoneIdentifier{"invalid phone number"}, + }) + want := "phone number must be a valid, E.164 compliant identifier" + if getUsersResult != nil || err == nil || err.Error() != want { + t.Errorf("GetUsers() = (%v, %q); want = (nil, %q)", getUsersResult, err, want) + } +} + +func TestGetUsersInvalidProvider(t *testing.T) { + client := &Client{ + baseClient: &baseClient{}, + } + + getUsersResult, err := client.GetUsers(context.Background(), []UserIdentifier{ + ProviderIdentifier{ProviderID: "", ProviderUID: ""}, + }) + want := "providerID must be a non-empty string" + if getUsersResult != nil || err == nil || err.Error() != want { + t.Errorf("GetUsers() = (%v, %q); want = (nil, %q)", getUsersResult, err, want) + } +} + +func TestGetUsersSingleBadIdentifier(t *testing.T) { + client := &Client{ + baseClient: &baseClient{}, + } + + identifiers := []UserIdentifier{ + UIDIdentifier{"valid_id1"}, + UIDIdentifier{"valid_id2"}, + UIDIdentifier{"invalid id; too long. " + strings.Repeat(".", 128)}, + UIDIdentifier{"valid_id3"}, + UIDIdentifier{"valid_id4"}, + } + + getUsersResult, err := client.GetUsers(context.Background(), identifiers) + want := "uid string must not be longer than 128 characters" + if getUsersResult != nil || err == nil || err.Error() != want { + t.Errorf("GetUsers() = (%v, %q); want = (nil, %q)", getUsersResult, err, want) + } +} + +func TestGetUsersMultipleIdentifierTypes(t *testing.T) { + mockUsers := []byte(` + { + "users": [{ + "localId": "uid1", + "email": "user1@example.com", + "phoneNumber": "+15555550001" + }, { + "localId": "uid2", + "email": "user2@example.com", + "phoneNumber": "+15555550002" + }, { + "localId": "uid3", + "email": "user3@example.com", + "phoneNumber": "+15555550003" + }, { + "localId": "uid4", + "email": "user4@example.com", + "phoneNumber": "+15555550004", + "providerUserInfo": [{ + "providerId": "google.com", + "rawId": "google_uid4" + }] + }] + }`) + s := echoServer(mockUsers, t) + defer s.Close() + + identifiers := []UserIdentifier{ + &UIDIdentifier{"uid1"}, + &EmailIdentifier{"user2@example.com"}, + &PhoneIdentifier{"+15555550003"}, + &ProviderIdentifier{ProviderID: "google.com", ProviderUID: "google_uid4"}, + &UIDIdentifier{"this-user-doesnt-exist"}, + } + + getUsersResult, err := s.Client.GetUsers(context.Background(), identifiers) + if err != nil { + t.Fatalf("GetUsers() = %q", err) + } + + if !sameUsers(getUsersResult.Users, []string{"uid1", "uid2", "uid3", "uid4"}) { + t.Errorf("GetUsers() = %v; want = (uids from) %v (in any order)", + getUsersResult.Users, []string{"uid1", "uid2", "uid3", "uid4"}) + } + if len(getUsersResult.NotFound) != 1 { + t.Errorf("GetUsers() = %d; want = 1", len(getUsersResult.NotFound)) + } else { + if id, ok := getUsersResult.NotFound[0].(*UIDIdentifier); !ok { + t.Errorf("GetUsers().NotFound[0] not a UIDIdentifier") + } else { + if id.UID != "this-user-doesnt-exist" { + t.Errorf("GetUsers().NotFound[0].UID = %s; want = 'this-user-doesnt-exist'", id.UID) + } + } + } +} + func TestGetNonExistingUser(t *testing.T) { resp := `{ "kind" : "identitytoolkit#GetAccountInfoResponse", @@ -1079,6 +1310,110 @@ func TestInvalidDeleteUser(t *testing.T) { } } +func TestDeleteUsers(t *testing.T) { + client := &Client{ + baseClient: &baseClient{}, + } + + t.Run("should succeed given an empty list", func(t *testing.T) { + result, err := client.DeleteUsers(context.Background(), []string{}) + + if err != nil { + t.Fatalf("DeleteUsers([]) error %v; want = nil", err) + } + + if result.SuccessCount != 0 { + t.Errorf("DeleteUsers([]).SuccessCount = %d; want = 0", result.SuccessCount) + } + if result.FailureCount != 0 { + t.Errorf("DeleteUsers([]).FailureCount = %d; want = 0", result.FailureCount) + } + if len(result.Errors) != 0 { + t.Errorf("len(DeleteUsers([]).Errors) = %d; want = 0", len(result.Errors)) + } + }) + + t.Run("should be rejected when given more than 1000 identifiers", func(t *testing.T) { + uids := []string{} + for i := 0; i < 1001; i++ { + uids = append(uids, fmt.Sprintf("id%d", i)) + } + + _, err := client.DeleteUsers(context.Background(), uids) + if err == nil { + t.Fatalf("DeleteUsers([too_many_uids]) error nil; want not nil") + } + + if err.Error() != "`uids` parameter must have <= 1000 entries" { + t.Errorf( + "DeleteUsers([too_many_uids]) returned an error of '%s'; "+ + "expected '`uids` parameter must have <= 1000 entries'", + err.Error()) + } + }) + + t.Run("should immediately fail given an invalid id", func(t *testing.T) { + tooLongUID := "too long " + strings.Repeat(".", 128) + _, err := client.DeleteUsers(context.Background(), []string{tooLongUID}) + + if err == nil { + t.Fatalf("DeleteUsers([too_long_uid]) error nil; want not nil") + } + + if err.Error() != "uid string must not be longer than 128 characters" { + t.Errorf( + "DeleteUsers([too_long_uid]) returned an error of '%s'; "+ + "expected 'uid string must not be longer than 128 characters'", + err.Error()) + } + }) + + t.Run("should index errors correctly in result", func(t *testing.T) { + resp := `{ + "errors": [{ + "index": 0, + "localId": "uid1", + "message": "Error Message 1" + }, { + "index": 2, + "localId": "uid3", + "message": "Error Message 2" + }] + }` + s := echoServer([]byte(resp), t) + defer s.Close() + + result, err := s.Client.DeleteUsers(context.Background(), []string{"uid1", "uid2", "uid3", "uid4"}) + + if err != nil { + t.Fatalf("DeleteUsers([...]) error %v; want = nil", err) + } + + if result.SuccessCount != 2 { + t.Errorf("DeleteUsers([...]).SuccessCount = %d; want 2", result.SuccessCount) + } + if result.FailureCount != 2 { + t.Errorf("DeleteUsers([...]).FailureCount = %d; want 2", result.FailureCount) + } + if len(result.Errors) != 2 { + t.Errorf("len(DeleteUsers([...]).Errors) = %d; want 2", len(result.Errors)) + } else { + if result.Errors[0].Index != 0 { + t.Errorf("DeleteUsers([...]).Errors[0].Index = %d; want 0", result.Errors[0].Index) + } + if result.Errors[0].Reason != "Error Message 1" { + t.Errorf("DeleteUsers([...]).Errors[0].Reason = %s; want Error Message 1", result.Errors[0].Reason) + } + if result.Errors[1].Index != 2 { + t.Errorf("DeleteUsers([...]).Errors[1].Index = %d; want 2", result.Errors[1].Index) + } + if result.Errors[1].Reason != "Error Message 2" { + t.Errorf("DeleteUsers([...]).Errors[1].Reason = %s; want Error Message 2", result.Errors[1].Reason) + } + } + }) +} + func TestMakeExportedUser(t *testing.T) { queryResponse := &userQueryResponse{ UID: "testuser", diff --git a/integration/auth/auth_test.go b/integration/auth/auth_test.go index eab03f22..a9fa7b3a 100644 --- a/integration/auth/auth_test.go +++ b/integration/auth/auth_test.go @@ -239,8 +239,9 @@ func signInWithCustomTokenForTenant(token string, tenantID string) (string, erro func signInWithPassword(email, password string) (string, error) { req, err := json.Marshal(map[string]interface{}{ - "email": email, - "password": password, + "email": email, + "password": password, + "returnSecureToken": true, }) if err != nil { return "", err diff --git a/integration/auth/user_mgt_test.go b/integration/auth/user_mgt_test.go index 1d4ef64e..f5348299 100644 --- a/integration/auth/user_mgt_test.go +++ b/integration/auth/user_mgt_test.go @@ -23,6 +23,7 @@ import ( "math/rand" "net/url" "reflect" + "sort" "strings" "testing" "time" @@ -91,6 +92,179 @@ func TestGetNonExistingUser(t *testing.T) { } } +func TestGetUsers(t *testing.T) { + // Checks to see if the users list contain the given uids. Order is ignored. + // + // Behaviour is undefined if there are duplicate entries in either of the + // slices. + // + // This function is identical to the one in auth/user_mgt_test.go + sameUsers := func(users [](*auth.UserRecord), uids []string) bool { + if len(users) != len(uids) { + return false + } + + sort.Slice(users, func(i, j int) bool { + return users[i].UID < users[j].UID + }) + sort.Slice(uids, func(i, j int) bool { + return uids[i] < uids[j] + }) + + for i := range users { + if users[i].UID != uids[i] { + return false + } + } + + return true + } + + testUser1 := newUserWithParams(t) + defer deleteUser(testUser1.UID) + testUser2 := newUserWithParams(t) + defer deleteUser(testUser2.UID) + testUser3 := newUserWithParams(t) + defer deleteUser(testUser3.UID) + + importUser1UID := randomUID() + importUser1 := (&auth.UserToImport{}). + UID(importUser1UID). + Email(randomEmail(importUser1UID)). + PhoneNumber(randomPhoneNumber()). + ProviderData([](*auth.UserProvider){ + &auth.UserProvider{ + ProviderID: "google.com", + UID: "google_" + importUser1UID, + }, + }) + importUser(t, importUser1UID, importUser1) + defer deleteUser(importUser1UID) + + userRecordsToUIDs := func(users [](*auth.UserRecord)) []string { + results := []string{} + for i := range users { + results = append(results, users[i].UID) + } + return results + } + + t.Run("various identifier types", func(t *testing.T) { + getUsersResult, err := client.GetUsers(context.Background(), []auth.UserIdentifier{ + auth.UIDIdentifier{UID: testUser1.UID}, + auth.EmailIdentifier{Email: testUser2.Email}, + auth.PhoneIdentifier{PhoneNumber: testUser3.PhoneNumber}, + auth.ProviderIdentifier{ProviderID: "google.com", ProviderUID: "google_" + importUser1UID}, + }) + if err != nil { + t.Fatalf("GetUsers() = %q", err) + } + + if !sameUsers(getUsersResult.Users, []string{testUser1.UID, testUser2.UID, testUser3.UID, importUser1UID}) { + t.Errorf("GetUsers() = %v; want = %v (in any order)", + userRecordsToUIDs(getUsersResult.Users), []string{testUser1.UID, testUser2.UID, testUser3.UID, importUser1UID}) + } + }) + + t.Run("mix of existing and non-existing users", func(t *testing.T) { + getUsersResult, err := client.GetUsers(context.Background(), []auth.UserIdentifier{ + auth.UIDIdentifier{UID: testUser1.UID}, + auth.UIDIdentifier{UID: "uid_that_doesnt_exist"}, + auth.UIDIdentifier{UID: testUser3.UID}, + }) + if err != nil { + t.Fatalf("GetUsers() = %q", err) + } + + if !sameUsers(getUsersResult.Users, []string{testUser1.UID, testUser3.UID}) { + t.Errorf("GetUsers() = %v; want = %v (in any order)", + getUsersResult.Users, []string{testUser1.UID, testUser3.UID}) + } + if len(getUsersResult.NotFound) != 1 { + t.Errorf("len(GetUsers().NotFound) = %d; want 1", len(getUsersResult.NotFound)) + } else { + if getUsersResult.NotFound[0].(auth.UIDIdentifier).UID != "uid_that_doesnt_exist" { + t.Errorf("GetUsers().NotFound[0].UID = %s; want 'uid_that_doesnt_exist'", + getUsersResult.NotFound[0].(auth.UIDIdentifier).UID) + } + } + }) + + t.Run("only non-existing users", func(t *testing.T) { + getUsersResult, err := client.GetUsers(context.Background(), []auth.UserIdentifier{ + auth.UIDIdentifier{UID: "non-existing user"}, + }) + if err != nil { + t.Fatalf("GetUsers() = %q", err) + } + + if len(getUsersResult.Users) != 0 { + t.Errorf("len(GetUsers().Users) = %d; want = 0", len(getUsersResult.Users)) + } + if len(getUsersResult.NotFound) != 1 { + t.Errorf("len(GetUsers().NotFound) = %d; want = 1", len(getUsersResult.NotFound)) + } else { + if getUsersResult.NotFound[0].(auth.UIDIdentifier).UID != "non-existing user" { + t.Errorf("GetUsers().NotFound[0].UID = %s; want 'non-existing user'", + getUsersResult.NotFound[0].(auth.UIDIdentifier).UID) + } + } + }) + + t.Run("de-dups duplicate users", func(t *testing.T) { + getUsersResult, err := client.GetUsers(context.Background(), []auth.UserIdentifier{ + auth.UIDIdentifier{UID: testUser1.UID}, + auth.UIDIdentifier{UID: testUser1.UID}, + }) + if err != nil { + t.Fatalf("GetUsers() returned an error: %v", err) + } + + if len(getUsersResult.Users) != 1 { + t.Errorf("len(GetUsers().Users) = %d; want = 1", len(getUsersResult.Users)) + } else { + if getUsersResult.Users[0].UID != testUser1.UID { + t.Errorf("GetUsers().Users[0].UID = %s; want = '%s'", getUsersResult.Users[0].UID, testUser1.UID) + } + } + if len(getUsersResult.NotFound) != 0 { + t.Errorf("len(GetUsers().NotFound) = %d; want = 0", len(getUsersResult.NotFound)) + } + }) +} + +func TestLastRefreshTime(t *testing.T) { + userRecord := newUserWithParams(t) + defer deleteUser(userRecord.UID) + + // New users should not have a LastRefreshTimestamp set. + if userRecord.UserMetadata.LastRefreshTimestamp != 0 { + t.Errorf( + "CreateUser(...).UserMetadata.LastRefreshTimestamp = %d; want = 0", + userRecord.UserMetadata.LastRefreshTimestamp) + } + + // Login to cause the LastRefreshTimestamp to be set + if _, err := signInWithPassword(userRecord.Email, "password"); err != nil { + t.Errorf("signInWithPassword failed: %v", err) + } + + getUsersResult, err := client.GetUser(context.Background(), userRecord.UID) + if err != nil { + t.Fatalf("GetUser(...) failed with error: %v", err) + } + + // Ensure last refresh time is approx now (with tollerance of 10m) + nowMillis := time.Now().Unix() * 1000 + lastRefreshTimestamp := getUsersResult.UserMetadata.LastRefreshTimestamp + if lastRefreshTimestamp < nowMillis-10*60*1000 { + t.Errorf("GetUser(...).UserMetadata.LastRefreshTimestamp = %d; want >= %d", lastRefreshTimestamp, nowMillis-10*60*1000) + } + if nowMillis+10*60*1000 < lastRefreshTimestamp { + t.Errorf("GetUser(...).UserMetadata.LastRefreshTimestamp = %d; want <= %d", lastRefreshTimestamp, nowMillis+10*60*1000) + } +} + func TestUpdateNonExistingUser(t *testing.T) { update := (&auth.UserToUpdate{}).Email("test@example.com") user, err := client.UpdateUser(context.Background(), "non.existing", update) @@ -334,6 +508,108 @@ func TestDeleteUser(t *testing.T) { } } +func TestDeleteUsers(t *testing.T) { + // Ensures the specified users don't exist. Expected to be called after + // deleting the users to ensure the delete method worked. + ensureUsersNotFound := func(t *testing.T, uids []string) { + identifiers := []auth.UserIdentifier{} + for i := range uids { + identifiers = append(identifiers, auth.UIDIdentifier{UID: uids[i]}) + } + + getUsersResult, err := client.GetUsers(context.Background(), identifiers) + if err != nil { + t.Errorf("GetUsers(notfound_ids) error %v; want nil", err) + return + } + + if len(getUsersResult.NotFound) != len(uids) { + t.Errorf("len(GetUsers(notfound_ids).NotFound) = %d; want %d", len(getUsersResult.NotFound), len(uids)) + return + } + + sort.Strings(uids) + notFoundUids := []string{} + for i := range getUsersResult.NotFound { + notFoundUids = append(notFoundUids, getUsersResult.NotFound[i].(auth.UIDIdentifier).UID) + } + sort.Strings(notFoundUids) + for i := range uids { + if notFoundUids[i] != uids[i] { + t.Errorf("GetUsers(deleted_ids).NotFound[%d] = %s; want %s", i, notFoundUids[i], uids[i]) + } + } + } + + t.Run("deletes users", func(t *testing.T) { + uids := []string{ + newUserWithParams(t).UID, newUserWithParams(t).UID, newUserWithParams(t).UID, + } + + result, err := client.DeleteUsers(context.Background(), uids) + if err != nil { + t.Fatalf("DeleteUsers([valid_ids]) error %v; want nil", err) + } + + if result.SuccessCount != 3 { + t.Errorf("DeleteUsers([valid_ids]).SuccessCount = %d; want 3", result.SuccessCount) + } + if result.FailureCount != 0 { + t.Errorf("DeleteUsers([valid_ids]).FailureCount = %d; want 0", result.FailureCount) + } + if len(result.Errors) != 0 { + t.Errorf("len(DeleteUsers([valid_ids]).Errors) = %d; want 0", len(result.Errors)) + } + + ensureUsersNotFound(t, uids) + }) + + t.Run("deletes users that exist even when non-existing users also specified", func(t *testing.T) { + uids := []string{newUserWithParams(t).UID, "uid-that-doesnt-exist"} + result, err := client.DeleteUsers(context.Background(), uids) + if err != nil { + t.Fatalf("DeleteUsers(uids) error %v; want nil", err) + } + + if result.SuccessCount != 2 { + t.Errorf("DeleteUsers(uids).SuccessCount = %d; want 2", result.SuccessCount) + } + if result.FailureCount != 0 { + t.Errorf("DeleteUsers(uids).FailureCount = %d; want 0", result.FailureCount) + } + if len(result.Errors) != 0 { + t.Errorf("len(DeleteUsers(uids).Errors) = %d; want 0", len(result.Errors)) + } + + ensureUsersNotFound(t, uids) + }) + + t.Run("is idempotent", func(t *testing.T) { + deleteUserAndEnsureSuccess := func(t *testing.T, uids []string) { + result, err := client.DeleteUsers(context.Background(), uids) + if err != nil { + t.Fatalf("DeleteUsers(uids) error %v; want nil", err) + } + + if result.SuccessCount != 1 { + t.Errorf("DeleteUsers(uids).SuccessCount = %d; want 1", result.SuccessCount) + } + if result.FailureCount != 0 { + t.Errorf("DeleteUsers(uids).FailureCount = %d; want 0", result.FailureCount) + } + if len(result.Errors) != 0 { + t.Errorf("len(DeleteUsers(uids).Errors) = %d; want 0", len(result.Errors)) + } + } + + uids := []string{newUserWithParams(t).UID} + deleteUserAndEnsureSuccess(t, uids) + + // Delete the user again, ensuring that everything still counts as a success. + deleteUserAndEnsureSuccess(t, uids) + }) +} + func TestImportUsers(t *testing.T) { uid := randomUID() email := randomEmail(uid) @@ -660,3 +936,26 @@ func newUserWithParams(t *testing.T) *auth.UserRecord { } return user } + +// Helper to import a user and return its UserRecord. Upon error, exits via +// t.Fatalf. `uid` must match the UID set on the `userToImport` parameter. +func importUser(t *testing.T, uid string, userToImport *auth.UserToImport) *auth.UserRecord { + userImportResult, err := client.ImportUsers( + context.Background(), [](*auth.UserToImport){userToImport}) + if err != nil { + t.Fatalf("Unable to import user %v (uid %v): %v", *userToImport, uid, err) + } + + if userImportResult.FailureCount > 0 { + t.Fatalf("Unable to import user %v (uid %v): %v", *userToImport, uid, userImportResult.Errors[0].Reason) + } + if userImportResult.SuccessCount != 1 { + t.Fatalf("Import didn't fail, but it didn't succeed either?") + } + + userRecord, err := client.GetUser(context.Background(), uid) + if err != nil { + t.Fatalf("GetUser(%s) for imported user failed: %v", uid, err) + } + return userRecord +}