From f08042aae4bd62ccba1fa1109b2ad90e25b0b019 Mon Sep 17 00:00:00 2001 From: Andrii Yanakov Date: Thu, 20 Mar 2025 13:26:38 +0200 Subject: [PATCH] Refactoring of the auth module adding the account entity in the logic. --- auth/module.go | 11 ++ auth/password_auth.go | 52 ++++-- auth/password_auth_test.go | 111 ++++++------ auth/plain_token_auth.go | 15 +- auth/plain_token_auth_test.go | 26 +-- auth/repository/account.go | 67 ++++++++ auth/repository/credential.go | 28 +-- auth/repository/identity.go | 51 ++++-- auth/repository/token.go | 8 +- auth/storage/access_token.sql.go | 44 ++--- auth/storage/account.sql.go | 106 ++++++++++++ auth/storage/account_repository.go | 137 +++++++++++++++ auth/storage/credential.sql.go | 51 +++--- auth/storage/credential_repository.go | 32 ++-- auth/storage/fixture/access_token.go | 20 +-- auth/storage/fixture/account.go | 159 +++++++++++++++++ auth/storage/fixture/credential.ext.go | 8 +- auth/storage/fixture/credential.go | 20 +-- auth/storage/fixture/factory.go | 30 +++- auth/storage/fixture/identity.ext.go | 33 +++- auth/storage/fixture/identity.go | 46 ++--- auth/storage/fixture/session.go | 20 +-- auth/storage/identity.sql.go | 161 ++++++++++-------- auth/storage/identity_repository.go | 86 ++++++---- .../migration/20240320084613_auth_account.sql | 53 ++++++ auth/storage/models.go | 82 +++++++-- auth/storage/query/access_token.sql | 12 +- auth/storage/query/account.sql | 33 ++++ auth/storage/query/credential.sql | 14 +- auth/storage/query/identity.sql | 34 ++-- auth/storage/query/refresh_token.sql | 2 +- auth/storage/refresh_token.sql.go | 6 +- auth/storage/token_repository.go | 27 ++- docs/graphql_server_example.md | 25 ++- examples/blog/.env | 4 + examples/blog/go.mod | 4 +- examples/blog/go.sum | 4 +- .../migration/20240320084613_auth_account.sql | 53 ++++++ .../internal/blog/storage/fixture/post.go | 32 ++-- .../internal/user/action/register_user.go | 28 +-- modules.json | 4 + 41 files changed, 1309 insertions(+), 430 deletions(-) create mode 100644 auth/repository/account.go create mode 100644 auth/storage/account.sql.go create mode 100644 auth/storage/account_repository.go create mode 100644 auth/storage/fixture/account.go create mode 100644 auth/storage/migration/20240320084613_auth_account.sql create mode 100644 auth/storage/query/account.sql create mode 100644 examples/blog/internal/auth/storage/migration/20240320084613_auth_account.sql diff --git a/auth/module.go b/auth/module.go index c45631d..668f340 100644 --- a/auth/module.go +++ b/auth/module.go @@ -31,6 +31,7 @@ func NewModule() *module.Module { SetOverriddenProvider("CredentialRepository", storage.NewDefaultCredentialRepository). SetOverriddenProvider("IdentityRepository", storage.NewDefaultIdentityRepository). SetOverriddenProvider("TokenRepository", storage.NewDefaultTokenRepository). + SetOverriddenProvider("AccountRepository", storage.NewDefaultAccountRepository). SetOverriddenProvider("TokenHashStrategy", hash.NewSha1). SetOverriddenProvider( "MiddlewareAuthenticator", func(auth *PlainTokenAuthenticator) Authenticator { @@ -58,6 +59,12 @@ func OverrideTokenRepository(authModule *module.Module, repository interface{}) return authModule.SetOverriddenProvider("TokenRepository", repository) } +// OverrideAccountRepository overrides the default account storage implementation with the custom one. +// repository should be a constructor returning the implementation of the AccountRepository interface. +func OverrideAccountRepository(authModule *module.Module, repository interface{}) *module.Module { + return authModule.SetOverriddenProvider("AccountRepository", repository) +} + // OverrideTokenHashStrategy overrides the default token hash strategy with the custom one. // strategy should be a constructor returning the implementation of the hash.TokenHashStrategy interface. // by default, the sha1 hash strategy is used. @@ -84,6 +91,10 @@ func NewManifestModule() module.ManifestModule { SourceUrl: "https://raw.githubusercontent.com/go-modulus/modulus/refs/heads/main/auth/storage/migration/20240214134322_auth.sql", DestFile: "internal/auth/storage/migration/20240214134322_auth.sql", }, + module.InstalledFile{ + SourceUrl: "https://raw.githubusercontent.com/go-modulus/modulus/refs/heads/main/auth/storage/migration/20240320084613_auth_account.sql", + DestFile: "internal/auth/storage/migration/20240320084613_auth_account.sql", + }, module.InstalledFile{ SourceUrl: "https://raw.githubusercontent.com/go-modulus/modulus/refs/heads/main/auth/install/module.go.tmpl", DestFile: "internal/auth/module.go", diff --git a/auth/password_auth.go b/auth/password_auth.go index 2bc12be..6e34097 100644 --- a/auth/password_auth.go +++ b/auth/password_auth.go @@ -15,15 +15,18 @@ var ErrInvalidPassword = errors.New("invalid password") var ErrCannotHashPassword = errors.New("cannot hash password") type PasswordAuthenticator struct { + accountRepository repository.AccountRepository identityRepository repository.IdentityRepository credentialRepository repository.CredentialRepository } func NewPasswordAuthenticator( + accountRepository repository.AccountRepository, identityRepository repository.IdentityRepository, credentialRepository repository.CredentialRepository, ) *PasswordAuthenticator { return &PasswordAuthenticator{ + accountRepository: accountRepository, identityRepository: identityRepository, credentialRepository: credentialRepository, } @@ -49,7 +52,7 @@ func (a *PasswordAuthenticator) Authenticate(ctx context.Context, identity, pass return Performer{}, errtrace.Wrap(ErrIdentityIsBlocked) } - cred, err := a.credentialRepository.GetLast(ctx, identityObj.ID, string(repository.CredentialTypePassword)) + cred, err := a.credentialRepository.GetLast(ctx, identityObj.AccountID, string(repository.CredentialTypePassword)) if err != nil { if errors.Is(err, repository.ErrCredentialNotFound) { return Performer{}, errtrace.Wrap(ErrInvalidPassword) @@ -61,7 +64,7 @@ func (a *PasswordAuthenticator) Authenticate(ctx context.Context, identity, pass if err != nil { return Performer{}, errtrace.Wrap(ErrInvalidPassword) } - return Performer{ID: identityObj.UserID, SessionID: uuid.Must(uuid.NewV6()), IdentityID: identityObj.ID}, nil + return Performer{ID: identityObj.AccountID, SessionID: uuid.Must(uuid.NewV6()), IdentityID: identityObj.ID}, nil } // Register registers a new user account with the given identity and password. @@ -77,47 +80,64 @@ func (a *PasswordAuthenticator) Register( ctx context.Context, identity, password string, - userID uuid.UUID, + identityType repository.IdentityType, roles []string, additionalData map[string]interface{}, -) (repository.Identity, error) { +) (repository.Account, error) { identityObj, err := a.identityRepository.Get(ctx, identity) if err == nil { if identityObj.IsBlocked() { - return repository.Identity{}, errtrace.Wrap(ErrIdentityIsBlocked) + return repository.Account{}, errtrace.Wrap(ErrIdentityIsBlocked) } - return repository.Identity{}, errtrace.Wrap(repository.ErrIdentityExists) + return repository.Account{}, errtrace.Wrap(repository.ErrIdentityExists) } else if !errors.Is(err, repository.ErrIdentityNotFound) { - return repository.Identity{}, errtrace.Wrap(err) + return repository.Account{}, errtrace.Wrap(err) } - identityObj, err = a.identityRepository.Create(ctx, identity, userID, additionalData) + accountID, err := uuid.NewV6() if err != nil { - return repository.Identity{}, errtrace.Wrap(err) + return repository.Account{}, errtrace.Wrap(err) } + account, err := a.accountRepository.Create(ctx, accountID) + if err != nil { + return repository.Account{}, errtrace.Wrap(err) + } + + identityObj, err = a.identityRepository.Create(ctx, identity, accountID, identityType, additionalData) + if err != nil { + return repository.Account{}, errtrace.Wrap(err) + } + + defer func() { + if err != nil { + _ = a.accountRepository.RemoveAccount(ctx, accountID) + _ = a.identityRepository.RemoveIdentity(ctx, identity) + } + }() + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { - return repository.Identity{}, errtrace.Wrap(errors.WithCause(ErrCannotHashPassword, err)) + return repository.Account{}, errtrace.Wrap(errors.WithCause(ErrCannotHashPassword, err)) } _, err = a.credentialRepository.Create( ctx, - identityObj.ID, + accountID, string(hash), - string(repository.CredentialTypePassword), + repository.CredentialTypePassword, nil, ) if err != nil { - return repository.Identity{}, errtrace.Wrap(err) + return repository.Account{}, errtrace.Wrap(err) } if len(roles) > 0 { - err = a.identityRepository.AddRoles(ctx, identityObj.ID, roles...) + err = a.accountRepository.AddRoles(ctx, identityObj.ID, roles...) if err != nil { - return repository.Identity{}, errtrace.Wrap(err) + return repository.Account{}, errtrace.Wrap(err) } } - return identityObj, nil + return account, nil } diff --git a/auth/password_auth_test.go b/auth/password_auth_test.go index 3573b21..0c64768 100644 --- a/auth/password_auth_test.go +++ b/auth/password_auth_test.go @@ -2,6 +2,7 @@ package auth_test import ( "context" + "encoding/json" "github.com/go-modulus/modulus/auth" "github.com/go-modulus/modulus/auth/repository" "github.com/go-modulus/modulus/auth/storage" @@ -15,76 +16,81 @@ func TestPasswordAuthenticator_Register(t *testing.T) { t.Run( "register identity without additional data", func(t *testing.T) { t.Parallel() - userId := uuid.Must(uuid.NewV6()) - identity, err := passwordAuth.Register( + account, err := passwordAuth.Register( context.Background(), "user", "password", - userId, + repository.IdentityTypeNickname, []string{}, nil, ) require.NoError(t, err) - savedIdentity := fixtureFactory.Identity().ID(identity.ID).PullUpdates(t).Cleanup(t).GetEntity() - fixtureFactory.Credential().IdentityID(identity.ID).CleanupAllOfIdentity(t) + savedAccount := fixtureFactory.Account().ID(account.ID).PullUpdates(t).Cleanup(t).GetEntity() + savedIdentity := fixtureFactory.Identity().AccountID(account.ID).PullUpdatesLastAccountIdentity(t).CleanupAllOfAccount(t).GetEntity() + fixtureFactory.Credential().AccountID(account.ID).CleanupAllOfAccount(t) - t.Log("When the identity is registered") - t.Log(" Then the identity is returned") + t.Log("When the account is registered") + t.Log(" Then the account is returned") require.NoError(t, err) - require.Equal(t, userId, identity.UserID) - require.Equal(t, "user", identity.Identity) - require.Equal(t, repository.IdentityStatusActive, identity.Status) - require.Empty(t, identity.Data) - - t.Log(" And the identity is saved") - require.Equal(t, identity.UserID, savedIdentity.UserID) - require.Equal(t, identity.Identity, savedIdentity.Identity) + require.Equal(t, repository.AccountStatusActive, account.Status) + + t.Log(" And the account is saved") + require.Equal(t, "user", savedIdentity.Identity) require.Equal(t, storage.IdentityStatusActive, savedIdentity.Status) + + t.Log(" And the account is created") + require.Equal(t, account.ID, savedAccount.ID) + require.Equal(t, storage.AccountStatusActive, savedAccount.Status) }, ) t.Run( "register identity with additional data", func(t *testing.T) { t.Parallel() - userId := uuid.Must(uuid.NewV6()) - identity, err := passwordAuth.Register( + account, err := passwordAuth.Register( context.Background(), "user1", "password", - userId, + repository.IdentityTypeNickname, []string{}, map[string]interface{}{ "key": "value", }, ) - savedIdentity := fixtureFactory.Identity().ID(identity.ID).PullUpdates(t).GetEntity() - fixtureFactory.Credential().IdentityID(identity.ID).CleanupAllOfIdentity(t) - fixtureFactory.Identity().UserID(userId).CleanupAllOfUser(t) + savedAccount := fixtureFactory.Account().ID(account.ID).PullUpdates(t).Cleanup(t).GetEntity() + savedIdentity := fixtureFactory.Identity().AccountID(account.ID).PullUpdatesLastAccountIdentity(t).CleanupAllOfAccount(t).GetEntity() + fixtureFactory.Credential().AccountID(account.ID).CleanupAllOfAccount(t) + fixtureFactory.Identity().AccountID(account.ID).CleanupAllOfAccount(t) - t.Log("When the identity is registered") - t.Log(" Then the identity is returned") + var data map[string]interface{} + errUnmarshal := json.Unmarshal(savedIdentity.Data, &data) + + t.Log("When the account is registered") + t.Log(" Then the account is returned") require.NoError(t, err) - require.Equal(t, userId, identity.UserID) - require.Equal(t, "user1", identity.Identity) - require.Equal(t, repository.IdentityStatusActive, identity.Status) - require.Equal(t, "value", identity.Data["key"]) - - t.Log(" And the identity is saved") - require.Equal(t, identity.UserID, savedIdentity.UserID) - require.Equal(t, identity.Identity, savedIdentity.Identity) + require.Equal(t, repository.AccountStatusActive, account.Status) + + t.Log(" And the account is saved") + require.NoError(t, errUnmarshal) + require.Equal(t, "user1", savedIdentity.Identity) require.Equal(t, storage.IdentityStatusActive, savedIdentity.Status) + require.Equal(t, "value", data["key"]) + + t.Log(" And the account is created") + require.Equal(t, account.ID, savedAccount.ID) + require.Equal(t, storage.AccountStatusActive, savedAccount.Status) }, ) t.Run( "fail on the second registration of identity", func(t *testing.T) { t.Parallel() - userId := uuid.Must(uuid.NewV6()) + account := fixtureFactory.Account().Create(t).GetEntity() identity := fixtureFactory.Identity(). ID(uuid.Must(uuid.NewV6())). - UserID(userId). + AccountID(account.ID). Identity("user2"). Create(t). GetEntity() @@ -92,7 +98,7 @@ func TestPasswordAuthenticator_Register(t *testing.T) { context.Background(), identity.Identity, "password", - userId, + repository.IdentityTypeNickname, []string{}, nil, ) @@ -107,10 +113,10 @@ func TestPasswordAuthenticator_Register(t *testing.T) { t.Run( "fail if identity is blocked", func(t *testing.T) { t.Parallel() - userId := uuid.Must(uuid.NewV6()) + accountId := uuid.Must(uuid.NewV6()) identity := fixtureFactory.Identity(). ID(uuid.Must(uuid.NewV6())). - UserID(userId). + AccountID(accountId). Identity("user3"). Status(storage.IdentityStatusBlocked). Create(t). @@ -120,7 +126,7 @@ func TestPasswordAuthenticator_Register(t *testing.T) { context.Background(), identity.Identity, "password", - userId, + repository.IdentityTypeNickname, []string{}, nil, ) @@ -137,12 +143,12 @@ func TestPasswordAuthenticator_Authenticate(t *testing.T) { t.Run( "authenticate", func(t *testing.T) { t.Parallel() - userId := uuid.Must(uuid.NewV6()) - identity, err := passwordAuth.Register( + identity := "user4" + account, err := passwordAuth.Register( context.Background(), - "user4", + identity, "password", - userId, + repository.IdentityTypeNickname, []string{}, map[string]interface{}{ "key": "value", @@ -150,20 +156,21 @@ func TestPasswordAuthenticator_Authenticate(t *testing.T) { ) require.NoError(t, err) - fixtureFactory.Identity().ID(identity.ID).Cleanup(t) - fixtureFactory.Credential().IdentityID(identity.ID).CleanupAllOfIdentity(t) + fixtureFactory.Account().ID(account.ID).Cleanup(t) + fixtureFactory.Identity().AccountID(account.ID).CleanupAllOfAccount(t) + fixtureFactory.Credential().AccountID(account.ID).CleanupAllOfAccount(t) performer, err := passwordAuth.Authenticate( context.Background(), - identity.Identity, + identity, "password", ) - t.Log("Given the identity is registered") + t.Log("Given the account is registered") t.Log("When try to authenticate with the correct password") t.Log(" Then the performer is returned") require.NoError(t, err) - require.Equal(t, userId, performer.ID) + require.Equal(t, account.ID, performer.ID) }, ) @@ -184,15 +191,15 @@ func TestPasswordAuthenticator_Authenticate(t *testing.T) { "fail if no password in database", func(t *testing.T) { t.Parallel() - userId := uuid.Must(uuid.NewV6()) + account := fixtureFactory.Account().Create(t).GetEntity() identity := fixtureFactory.Identity(). ID(uuid.Must(uuid.NewV6())). Identity("user6"). - UserID(userId). + AccountID(account.ID). Create(t). GetEntity() fixtureFactory.Credential(). - IdentityID(identity.ID). + AccountID(account.ID). Type(string(repository.CredentialTypeOTP)). Hash("ssss"). Create(t) @@ -214,12 +221,12 @@ func TestPasswordAuthenticator_Authenticate(t *testing.T) { identity := fixtureFactory.Identity(). ID(uuid.Must(uuid.NewV6())). Identity("user7"). - UserID(userId). + AccountID(userId). Create(t). GetEntity() fixtureFactory.Credential(). - IdentityID(identity.ID). + AccountID(userId). Hash("ssss2"). Create(t) _, err := passwordAuth.Authenticate(context.Background(), identity.Identity, "password") @@ -240,7 +247,7 @@ func TestPasswordAuthenticator_Authenticate(t *testing.T) { identity := fixtureFactory.Identity(). ID(uuid.Must(uuid.NewV6())). Identity("user8"). - UserID(userId). + AccountID(userId). Status(storage.IdentityStatusBlocked). Create(t). GetEntity() diff --git a/auth/plain_token_auth.go b/auth/plain_token_auth.go index 305922c..1aa5dfa 100644 --- a/auth/plain_token_auth.go +++ b/auth/plain_token_auth.go @@ -23,17 +23,20 @@ type TokenPair struct { } type PlainTokenAuthenticator struct { + accountRepository repository.AccountRepository tokenRepository repository.TokenRepository identityRepository repository.IdentityRepository config ModuleConfig } func NewPlainTokenAuthenticator( + accountRepository repository.AccountRepository, tokenRepository repository.TokenRepository, identityRepository repository.IdentityRepository, config ModuleConfig, ) *PlainTokenAuthenticator { return &PlainTokenAuthenticator{ + accountRepository: accountRepository, tokenRepository: tokenRepository, identityRepository: identityRepository, config: config, @@ -61,7 +64,7 @@ func (a *PlainTokenAuthenticator) Authenticate(ctx context.Context, token string } return Performer{ - ID: accessToken.UserID, + ID: accessToken.AccountID, SessionID: accessToken.SessionID, Roles: accessToken.Roles, }, nil @@ -197,12 +200,18 @@ func (a *PlainTokenAuthenticator) createAccessToken( err = errtrace.Wrap(errors.WithCause(ErrCannotCreateAccessToken, err)) return } + + account, err := a.accountRepository.Get(ctx, identity.AccountID) + if err != nil { + err = errtrace.Wrap(err) + return + } + accessToken, err = a.tokenRepository.CreateAccessToken( ctx, accessTokenStr, identity.ID, - identity.UserID, - identity.Roles, + account.Roles, sessionID, additionalData, time.Now().Add(a.config.AccessTokenTTL), diff --git a/auth/plain_token_auth_test.go b/auth/plain_token_auth_test.go index b8e026c..a0d8367 100644 --- a/auth/plain_token_auth_test.go +++ b/auth/plain_token_auth_test.go @@ -15,7 +15,8 @@ func TestPlainTokenAuthenticator_StartSession(t *testing.T) { t.Run( "should return a valid pair", func(t *testing.T) { t.Parallel() - identity := fixtureFactory.Identity().Create(t).GetEntity() + account := fixtureFactory.Account().Create(t).GetEntity() + identity := fixtureFactory.Identity().AccountID(account.ID).Create(t).GetEntity() pair, err := plainTokenAuth.IssueTokens( context.Background(), @@ -31,7 +32,7 @@ func TestPlainTokenAuthenticator_StartSession(t *testing.T) { t.Log("When the session is started") t.Log(" Then the access and refresh tokens should be created") require.NoError(t, err) - require.Equal(t, at.UserID, savedAt.UserID) + require.Equal(t, at.AccountID, savedAt.AccountID) require.Equal(t, at.SessionID, savedAt.SessionID) require.Equal(t, rt.SessionID, savedRt.SessionID) @@ -49,7 +50,8 @@ func TestPlainTokenAuthenticator_Authenticate(t *testing.T) { t.Run( "should return a valid performer", func(t *testing.T) { t.Parallel() - identity := fixtureFactory.Identity().Create(t).GetEntity() + account := fixtureFactory.Account().Create(t).GetEntity() + identity := fixtureFactory.Identity().AccountID(account.ID).Create(t).GetEntity() pair, err := plainTokenAuth.IssueTokens( context.Background(), @@ -73,9 +75,9 @@ func TestPlainTokenAuthenticator_Authenticate(t *testing.T) { t.Log("When authenticate the user") t.Log(" Then valid performer should be returned") require.NoError(t, err) - require.Equal(t, at.UserID, performer.ID) + require.Equal(t, at.AccountID, performer.ID) require.Equal(t, at.SessionID, performer.SessionID) - require.Equal(t, identity.Roles, performer.Roles) + require.Equal(t, account.Roles, performer.Roles) }, ) @@ -144,7 +146,8 @@ func TestPlainTokenAuthenticator_IssueNewAccessToken(t *testing.T) { t.Run( "should return a new access token", func(t *testing.T) { t.Parallel() - identity := fixtureFactory.Identity().Create(t).GetEntity() + account := fixtureFactory.Account().Create(t).GetEntity() + identity := fixtureFactory.Identity().AccountID(account.ID).Create(t).GetEntity() pair, err := plainTokenAuth.IssueTokens( context.Background(), @@ -176,9 +179,9 @@ func TestPlainTokenAuthenticator_IssueNewAccessToken(t *testing.T) { t.Log("When authenticate the user") t.Log(" Then valid performer should be returned") require.NoError(t, err) - require.Equal(t, at.UserID, performer.ID) + require.Equal(t, at.AccountID, performer.ID) require.Equal(t, at.SessionID, performer.SessionID) - require.Equal(t, identity.Roles, performer.Roles) + require.Equal(t, account.Roles, performer.Roles) }, ) } @@ -188,7 +191,8 @@ func TestPlainTokenAuthenticator_RefreshAccessToken(t *testing.T) { t.Run( "should return a new access token", func(t *testing.T) { t.Parallel() - identity := fixtureFactory.Identity().Create(t).GetEntity() + account := fixtureFactory.Account().Create(t).GetEntity() + identity := fixtureFactory.Identity().AccountID(account.ID).Create(t).GetEntity() pair, err := plainTokenAuth.IssueTokens( context.Background(), @@ -221,9 +225,9 @@ func TestPlainTokenAuthenticator_RefreshAccessToken(t *testing.T) { t.Log("When authenticate the user") t.Log(" Then valid performer should be returned") require.NoError(t, err) - require.Equal(t, at.UserID, performer.ID) + require.Equal(t, at.AccountID, performer.ID) require.Equal(t, at.SessionID, performer.SessionID) - require.Equal(t, identity.Roles, performer.Roles) + require.Equal(t, account.Roles, performer.Roles) require.True(t, oldToken.ExpiresAt.Before(time.Now())) }, ) diff --git a/auth/repository/account.go b/auth/repository/account.go new file mode 100644 index 0000000..bf108a4 --- /dev/null +++ b/auth/repository/account.go @@ -0,0 +1,67 @@ +package repository + +import ( + "context" + "github.com/go-modulus/modulus/errors" + "github.com/gofrs/uuid" +) + +var ErrAccountExists = errors.New("account exists") +var ErrAccountNotFound = errors.New("account not found") +var ErrCannotCreateAccount = errors.New("cannot create account") + +type Account struct { + ID uuid.UUID `db:"id" json:"id"` + Roles []string `db:"roles" json:"roles"` + Status AccountStatus `db:"status" json:"status"` +} + +func (i Account) IsBlocked() bool { + return i.Status == AccountStatusBlocked +} + +type AccountStatus string + +const ( + AccountStatusActive AccountStatus = "active" + AccountStatusBlocked AccountStatus = "blocked" +) + +type AccountRepository interface { + // Create creates a single new authorization account for the user. + Create( + ctx context.Context, + ID uuid.UUID, + ) (Account, error) + + // Get returns the account by its ID. + // If the identity does not exist, it returns github.com/go-modulus/modulus/auth.ErrAccountNotFound. + Get( + ctx context.Context, + ID uuid.UUID, + ) (Account, error) + + AddRoles( + ctx context.Context, + ID uuid.UUID, + roles ...string, + ) error + + RemoveRoles( + ctx context.Context, + ID uuid.UUID, + roles ...string, + ) error + + // RemoveAccount removes the identity. + RemoveAccount( + ctx context.Context, + ID uuid.UUID, + ) error + + // BlockAccount blocks the identity. + BlockAccount( + ctx context.Context, + ID uuid.UUID, + ) error +} diff --git a/auth/repository/credential.go b/auth/repository/credential.go index c4c64fc..3dc3fa4 100644 --- a/auth/repository/credential.go +++ b/auth/repository/credential.go @@ -12,25 +12,19 @@ var ErrCannotCreateCredential = errors.New("cannot create credential") var ErrCredentialNotFound = errors.New("credential not found") type Credential struct { - Hash string `json:"hash"` - IdentityID uuid.UUID `json:"identityId"` - Type string `json:"type"` - ExpiredAt null.Time `json:"expiredAt"` + Hash string `json:"hash"` + AccountID uuid.UUID `json:"accountId"` + Type CredentialType `json:"type"` + ExpiredAt null.Time `json:"expiredAt"` } type CredentialRepository interface { - // Create creates a new identity for the given user ID. - // If the identity already exists, it returns github.com/go-modulus/modulus/auth/errors.ErrCredentialExists. - // Otherwise, it returns nil. - // The identity is a unique string that represents the user. - // It is used for login and other operations. - // It may be an email, username, or other unique identifier. - // You are able to create multiple identities for a single user. + // Create creates a new credential for the given account ID. Create( ctx context.Context, - identityID uuid.UUID, + accountID uuid.UUID, credentialHash string, - credType string, + credType CredentialType, expiredAt *time.Time, ) (Credential, error) @@ -38,9 +32,15 @@ type CredentialRepository interface { // If the credential does not exist, it returns github.com/go-modulus/modulus/auth.ErrCredentialNotFound. GetLast( ctx context.Context, - identityID uuid.UUID, + accountID uuid.UUID, credType string, ) (Credential, error) + + // RemoveCredentials removes all credentials of the given account ID. + RemoveCredentials( + ctx context.Context, + accountID uuid.UUID, + ) error } type CredentialType string diff --git a/auth/repository/identity.go b/auth/repository/identity.go index f208e84..9f68d56 100644 --- a/auth/repository/identity.go +++ b/auth/repository/identity.go @@ -6,17 +6,25 @@ import ( "github.com/gofrs/uuid" ) +type IdentityType string + +const ( + IdentityTypeEmail IdentityType = "email" + IdentityTypePhone IdentityType = "phone" + IdentityTypeNickname IdentityType = "nickname" +) + var ErrIdentityExists = errors.New("identity exists") var ErrIdentityNotFound = errors.New("identity not found") var ErrCannotCreateIdentity = errors.New("cannot create identity") type Identity struct { - ID uuid.UUID `db:"id" json:"id"` - Identity string `db:"identity" json:"identity"` - UserID uuid.UUID `db:"user_id" json:"userId"` - Roles []string `db:"roles" json:"roles"` - Status IdentityStatus `db:"status" json:"status"` - Data map[string]interface{} `db:"data" json:"data"` + ID uuid.UUID `db:"id" json:"id"` + Identity string `db:"identity" json:"identity"` + AccountID uuid.UUID `db:"user_id" json:"accountId"` + Type IdentityType `db:"type" json:"type"` + Status IdentityStatus `db:"status" json:"status"` + Data map[string]interface{} `db:"data" json:"data"` } func (i Identity) IsBlocked() bool { @@ -31,18 +39,19 @@ const ( ) type IdentityRepository interface { - // Create creates a new identity for the given user ID. + // Create creates a new identity for the given account ID. // If the identity already exists, it returns github.com/go-modulus/modulus/auth.ErrIdentityExists. // If the identity cannot be created, it returns github.com/go-modulus/modulus/auth.ErrCannotCreateIdentity. // // The identity is a unique string that represents the user. // It is used for login and other operations. // It may be an email, username, or other unique identifier. - // You are able to create multiple identities for a single user. + // You are able to create multiple identities for a single account. Create( ctx context.Context, identity string, - userID uuid.UUID, + accountID uuid.UUID, + identityType IdentityType, additionalData map[string]interface{}, ) (Identity, error) @@ -59,15 +68,27 @@ type IdentityRepository interface { id uuid.UUID, ) (Identity, error) - AddRoles( + // GetByAccountID returns the identities with the given account ID. + GetByAccountID( + ctx context.Context, + accountID uuid.UUID, + ) ([]Identity, error) + + // RemoveAccountIdentities removes the identities with the given account ID. + RemoveAccountIdentities( ctx context.Context, - identityID uuid.UUID, - roles ...string, + accountID uuid.UUID, ) error - RemoveRoles( + // RemoveIdentity removes the identity. + RemoveIdentity( ctx context.Context, - identityID uuid.UUID, - roles ...string, + identity string, + ) error + + // BlockIdentity blocks the identity. + BlockIdentity( + ctx context.Context, + identity string, ) error } diff --git a/auth/repository/token.go b/auth/repository/token.go index d933c71..bf74cdf 100644 --- a/auth/repository/token.go +++ b/auth/repository/token.go @@ -25,7 +25,7 @@ type AccessToken struct { Hash string `json:"hash"` IdentityID uuid.UUID `json:"identityId"` SessionID uuid.UUID `json:"sessionId"` - UserID uuid.UUID `json:"userId"` + AccountID uuid.UUID `json:"accountId"` Roles []string `json:"roles"` Data map[string]interface{} `json:"data"` RevokedAt null.Time `json:"revokedAt"` @@ -44,6 +44,7 @@ type RefreshToken struct { type TokenRepository interface { // CreateAccessToken creates an access token. // It returns the created access token. + // Roles in token are obtained from the account of identity. // // Errors: // * ErrCannotCreateAccessToken - if the access token cannot be created. @@ -51,7 +52,6 @@ type TokenRepository interface { ctx context.Context, accessToken string, identityId uuid.UUID, - userId uuid.UUID, roles []string, sessionId uuid.UUID, data map[string]interface{}, @@ -98,7 +98,7 @@ type TokenRepository interface { // It can be used from the user settings to log out from some devices. RevokeSessionTokens(ctx context.Context, sessionId uuid.UUID) error - // RevokeUserTokens revokes all tokens of the user by the given user ID. + // RevokeAccountTokens revokes all tokens of the user by the given account ID. // It can be used from the user settings to log out from all devices. - RevokeUserTokens(ctx context.Context, userId uuid.UUID) error + RevokeAccountTokens(ctx context.Context, accountId uuid.UUID) error } diff --git a/auth/storage/access_token.sql.go b/auth/storage/access_token.sql.go index 16855e1..879421d 100644 --- a/auth/storage/access_token.sql.go +++ b/auth/storage/access_token.sql.go @@ -13,14 +13,14 @@ import ( ) const createAccessToken = `-- name: CreateAccessToken :one -INSERT INTO auth.access_token (hash, identity_id, session_id, user_id, roles, data, expires_at) -VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING hash, identity_id, session_id, user_id, roles, data, revoked_at, expires_at, created_at` +INSERT INTO auth.access_token (hash, identity_id, session_id, account_id, roles, data, expires_at) +VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING hash, identity_id, session_id, account_id, roles, data, revoked_at, expires_at, created_at` type CreateAccessTokenParams struct { Hash string `db:"hash" json:"hash"` IdentityID uuid.UUID `db:"identity_id" json:"identityId"` SessionID uuid.UUID `db:"session_id" json:"sessionId"` - UserID uuid.UUID `db:"user_id" json:"userId"` + AccountID uuid.UUID `db:"account_id" json:"accountId"` Roles []string `db:"roles" json:"roles"` Data []byte `db:"data" json:"data"` ExpiresAt time.Time `db:"expires_at" json:"expiresAt"` @@ -31,7 +31,7 @@ func (q *Queries) CreateAccessToken(ctx context.Context, arg CreateAccessTokenPa arg.Hash, arg.IdentityID, arg.SessionID, - arg.UserID, + arg.AccountID, arg.Roles, arg.Data, arg.ExpiresAt, @@ -41,7 +41,7 @@ func (q *Queries) CreateAccessToken(ctx context.Context, arg CreateAccessTokenPa &i.Hash, &i.IdentityID, &i.SessionID, - &i.UserID, + &i.AccountID, &i.Roles, &i.Data, &i.RevokedAt, @@ -67,19 +67,19 @@ func (q *Queries) ExpireSessionAccessTokens(ctx context.Context, arg ExpireSessi return err } -const getAccessTokenByHash = `-- name: GetAccessTokenByHash :one -SELECT hash, identity_id, session_id, user_id, roles, data, revoked_at, expires_at, created_at +const findAccessTokenByHash = `-- name: FindAccessTokenByHash :one +SELECT hash, identity_id, session_id, account_id, roles, data, revoked_at, expires_at, created_at FROM auth.access_token WHERE hash = $1` -func (q *Queries) GetAccessTokenByHash(ctx context.Context, hash string) (AccessToken, error) { - row := q.db.QueryRow(ctx, getAccessTokenByHash, hash) +func (q *Queries) FindAccessTokenByHash(ctx context.Context, hash string) (AccessToken, error) { + row := q.db.QueryRow(ctx, findAccessTokenByHash, hash) var i AccessToken err := row.Scan( &i.Hash, &i.IdentityID, &i.SessionID, - &i.UserID, + &i.AccountID, &i.Roles, &i.Data, &i.RevokedAt, @@ -89,14 +89,14 @@ func (q *Queries) GetAccessTokenByHash(ctx context.Context, hash string) (Access return i, err } -const getUserNotRevokedSessionIds = `-- name: GetUserNotRevokedSessionIds :many +const findAccountNotRevokedSessionIds = `-- name: FindAccountNotRevokedSessionIds :many SELECT session_id FROM auth.access_token WHERE revoked_at IS NULL - AND user_id = $1` + AND account_id = $1` -func (q *Queries) GetUserNotRevokedSessionIds(ctx context.Context, userID uuid.UUID) ([]uuid.UUID, error) { - rows, err := q.db.Query(ctx, getUserNotRevokedSessionIds, userID) +func (q *Queries) FindAccountNotRevokedSessionIds(ctx context.Context, accountID uuid.UUID) ([]uuid.UUID, error) { + rows, err := q.db.Query(ctx, findAccountNotRevokedSessionIds, accountID) if err != nil { return nil, err } @@ -125,22 +125,22 @@ func (q *Queries) RevokeAccessToken(ctx context.Context, hash string) error { return err } -const revokeSessionAccessTokens = `-- name: RevokeSessionAccessTokens :exec +const revokeAccountAccessTokens = `-- name: RevokeAccountAccessTokens :exec UPDATE auth.access_token SET revoked_at = now() -WHERE session_id = $1 AND revoked_at IS NULL` +WHERE account_id = $1 AND revoked_at IS NULL` -func (q *Queries) RevokeSessionAccessTokens(ctx context.Context, sessionID uuid.UUID) error { - _, err := q.db.Exec(ctx, revokeSessionAccessTokens, sessionID) +func (q *Queries) RevokeAccountAccessTokens(ctx context.Context, accountID uuid.UUID) error { + _, err := q.db.Exec(ctx, revokeAccountAccessTokens, accountID) return err } -const revokeUserAccessTokens = `-- name: RevokeUserAccessTokens :exec +const revokeSessionAccessTokens = `-- name: RevokeSessionAccessTokens :exec UPDATE auth.access_token SET revoked_at = now() -WHERE user_id = $1 AND revoked_at IS NULL` +WHERE session_id = $1 AND revoked_at IS NULL` -func (q *Queries) RevokeUserAccessTokens(ctx context.Context, userID uuid.UUID) error { - _, err := q.db.Exec(ctx, revokeUserAccessTokens, userID) +func (q *Queries) RevokeSessionAccessTokens(ctx context.Context, sessionID uuid.UUID) error { + _, err := q.db.Exec(ctx, revokeSessionAccessTokens, sessionID) return err } diff --git a/auth/storage/account.sql.go b/auth/storage/account.sql.go new file mode 100644 index 0000000..3e94e53 --- /dev/null +++ b/auth/storage/account.sql.go @@ -0,0 +1,106 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.28.0 +// source: account.sql + +package storage + +import ( + "context" + + uuid "github.com/gofrs/uuid" +) + +const addRoles = `-- name: AddRoles :exec +update "auth"."account" +set roles = array(select distinct unnest(roles || $1::text[])) +where id = $2::uuid` + +type AddRolesParams struct { + Roles []string `db:"roles" json:"roles"` + ID uuid.UUID `db:"id" json:"id"` +} + +func (q *Queries) AddRoles(ctx context.Context, arg AddRolesParams) error { + _, err := q.db.Exec(ctx, addRoles, arg.Roles, arg.ID) + return err +} + +const blockAccount = `-- name: BlockAccount :exec +UPDATE auth.account +SET status = 'blocked'::auth.account_status +WHERE id = $1` + +func (q *Queries) BlockAccount(ctx context.Context, id uuid.UUID) error { + _, err := q.db.Exec(ctx, blockAccount, id) + return err +} + +const findAccount = `-- name: FindAccount :one +SELECT id, status, roles, updated_at, created_at +FROM auth.account +WHERE id = $1` + +func (q *Queries) FindAccount(ctx context.Context, id uuid.UUID) (Account, error) { + row := q.db.QueryRow(ctx, findAccount, id) + var i Account + err := row.Scan( + &i.ID, + &i.Status, + &i.Roles, + &i.UpdatedAt, + &i.CreatedAt, + ) + return i, err +} + +const registerAccount = `-- name: RegisterAccount :one +INSERT INTO auth.account (id) +VALUES ($1) RETURNING id, status, roles, updated_at, created_at` + +func (q *Queries) RegisterAccount(ctx context.Context, id uuid.UUID) (Account, error) { + row := q.db.QueryRow(ctx, registerAccount, id) + var i Account + err := row.Scan( + &i.ID, + &i.Status, + &i.Roles, + &i.UpdatedAt, + &i.CreatedAt, + ) + return i, err +} + +const removeAccount = `-- name: RemoveAccount :exec +DELETE FROM auth.account +WHERE id = $1` + +func (q *Queries) RemoveAccount(ctx context.Context, id uuid.UUID) error { + _, err := q.db.Exec(ctx, removeAccount, id) + return err +} + +const removeRoles = `-- name: RemoveRoles :exec +update "auth"."account" +set roles = array(select distinct unnest(roles) except select distinct unnest($1::text[])) +where id = $2::uuid` + +type RemoveRolesParams struct { + Roles []string `db:"roles" json:"roles"` + ID uuid.UUID `db:"id" json:"id"` +} + +func (q *Queries) RemoveRoles(ctx context.Context, arg RemoveRolesParams) error { + _, err := q.db.Exec(ctx, removeRoles, arg.Roles, arg.ID) + return err +} + +const unblockAccount = `-- name: UnblockAccount :exec +UPDATE auth.account +SET status = 'active'::auth.account_status +WHERE id = $1` + +func (q *Queries) UnblockAccount(ctx context.Context, id uuid.UUID) error { + _, err := q.db.Exec(ctx, unblockAccount, id) + return err +} diff --git a/auth/storage/account_repository.go b/auth/storage/account_repository.go new file mode 100644 index 0000000..b95b8aa --- /dev/null +++ b/auth/storage/account_repository.go @@ -0,0 +1,137 @@ +package storage + +import ( + "braces.dev/errtrace" + "context" + "github.com/go-modulus/modulus/auth/repository" + "github.com/go-modulus/modulus/errors" + "github.com/gofrs/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +type DefaultAccountRepository struct { + queries *Queries + db *pgxpool.Pool +} + +func NewDefaultAccountRepository(db *pgxpool.Pool) repository.AccountRepository { + return &DefaultAccountRepository{ + queries: New(db), + db: db, + } +} + +func (r *DefaultAccountRepository) Create(ctx context.Context, ID uuid.UUID) (repository.Account, error) { + _, err := r.Get(ctx, ID) + if err == nil { + return repository.Account{}, repository.ErrAccountExists + } else if !errors.Is(err, repository.ErrAccountNotFound) { + return repository.Account{}, errtrace.Wrap(err) + } + + storedAccount, err := r.queries.RegisterAccount( + ctx, ID, + ) + + if err != nil { + return repository.Account{}, errtrace.Wrap(errors.WithCause(repository.ErrCannotCreateAccount, err)) + } + + return r.Transform(storedAccount), nil +} + +func (r *DefaultAccountRepository) Get(ctx context.Context, ID uuid.UUID) (repository.Account, error) { + res, err := r.queries.FindAccount(ctx, ID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return repository.Account{}, repository.ErrAccountNotFound + } + return repository.Account{}, errtrace.Wrap(err) + } + return r.Transform(res), nil +} + +func (r *DefaultAccountRepository) RemoveAccount(ctx context.Context, ID uuid.UUID) error { + _, err := r.Get(ctx, ID) + if err == nil { + return errtrace.Wrap(err) + } + tx, err := r.db.Begin(ctx) + if err != nil { + return errtrace.Wrap(err) + } + defer func() { _ = tx.Rollback(ctx) }() + qtx := New(tx) + err = qtx.RemoveAccount(ctx, ID) + if err != nil { + return errtrace.Wrap(err) + } + err = qtx.RemoveCredentialsOfAccount(ctx, ID) + if err != nil { + return errtrace.Wrap(err) + } + err = qtx.RemoveIdentitiesOfAccount(ctx, ID) + if err != nil { + return errtrace.Wrap(err) + } + return tx.Commit(ctx) +} + +func (r *DefaultAccountRepository) BlockAccount(ctx context.Context, ID uuid.UUID) error { + _, err := r.Get(ctx, ID) + if err == nil { + return errtrace.Wrap(err) + } + tx, err := r.db.Begin(ctx) + if err != nil { + return errtrace.Wrap(err) + } + defer func() { _ = tx.Rollback(ctx) }() + qtx := New(tx) + err = qtx.BlockAccount(ctx, ID) + if err != nil { + return errtrace.Wrap(err) + } + err = qtx.BlockIdentitiesOfAccount(ctx, ID) + if err != nil { + return errtrace.Wrap(err) + } + return tx.Commit(ctx) +} + +func (r *DefaultAccountRepository) Transform( + account Account, +) repository.Account { + return repository.Account{ + ID: account.ID, + Roles: account.Roles, + Status: repository.AccountStatus(account.Status), + } +} + +func (r *DefaultAccountRepository) AddRoles( + ctx context.Context, + accountID uuid.UUID, + roles ...string, +) error { + return r.queries.AddRoles( + ctx, AddRolesParams{ + ID: accountID, + Roles: roles, + }, + ) +} + +func (r *DefaultAccountRepository) RemoveRoles( + ctx context.Context, + accountID uuid.UUID, + roles ...string, +) error { + return r.queries.RemoveRoles( + ctx, RemoveRolesParams{ + ID: accountID, + Roles: roles, + }, + ) +} diff --git a/auth/storage/credential.sql.go b/auth/storage/credential.sql.go index d742f83..1d9449e 100644 --- a/auth/storage/credential.sql.go +++ b/auth/storage/credential.sql.go @@ -14,20 +14,20 @@ import ( const createCredential = `-- name: CreateCredential :one INSERT INTO "auth"."credential" - (identity_id, type, hash, expired_at) + (account_id, type, hash, expired_at) VALUES ($1::uuid, $2::text, $3::text, $4) -RETURNING hash, identity_id, type, expired_at, created_at` +RETURNING hash, account_id, type, expired_at, created_at` type CreateCredentialParams struct { - IdentityID uuid.UUID `db:"identity_id" json:"identityId"` - Type string `db:"type" json:"type"` - Hash string `db:"hash" json:"hash"` - ExpiredAt null.Time `db:"expired_at" json:"expiredAt"` + AccountID uuid.UUID `db:"account_id" json:"accountId"` + Type string `db:"type" json:"type"` + Hash string `db:"hash" json:"hash"` + ExpiredAt null.Time `db:"expired_at" json:"expiredAt"` } func (q *Queries) CreateCredential(ctx context.Context, arg CreateCredentialParams) (Credential, error) { row := q.db.QueryRow(ctx, createCredential, - arg.IdentityID, + arg.AccountID, arg.Type, arg.Hash, arg.ExpiredAt, @@ -35,7 +35,7 @@ func (q *Queries) CreateCredential(ctx context.Context, arg CreateCredentialPara var i Credential err := row.Scan( &i.Hash, - &i.IdentityID, + &i.AccountID, &i.Type, &i.ExpiredAt, &i.CreatedAt, @@ -44,7 +44,7 @@ func (q *Queries) CreateCredential(ctx context.Context, arg CreateCredentialPara } const findAllCredentialsOfType = `-- name: FindAllCredentialsOfType :many -SELECT hash, identity_id, type, expired_at, created_at +SELECT hash, account_id, type, expired_at, created_at FROM "auth"."credential" WHERE type = $1::text ORDER BY created_at DESC` @@ -60,7 +60,7 @@ func (q *Queries) FindAllCredentialsOfType(ctx context.Context, type_ string) ([ var i Credential if err := rows.Scan( &i.Hash, - &i.IdentityID, + &i.AccountID, &i.Type, &i.ExpiredAt, &i.CreatedAt, @@ -76,17 +76,17 @@ func (q *Queries) FindAllCredentialsOfType(ctx context.Context, type_ string) ([ } const findLastCredential = `-- name: FindLastCredential :one -SELECT hash, identity_id, type, expired_at, created_at +SELECT hash, account_id, type, expired_at, created_at FROM "auth"."credential" -WHERE identity_id = $1::uuid +WHERE account_id = $1::uuid ORDER BY created_at DESC` -func (q *Queries) FindLastCredential(ctx context.Context, identityID uuid.UUID) (Credential, error) { - row := q.db.QueryRow(ctx, findLastCredential, identityID) +func (q *Queries) FindLastCredential(ctx context.Context, accountID uuid.UUID) (Credential, error) { + row := q.db.QueryRow(ctx, findLastCredential, accountID) var i Credential err := row.Scan( &i.Hash, - &i.IdentityID, + &i.AccountID, &i.Type, &i.ExpiredAt, &i.CreatedAt, @@ -95,26 +95,35 @@ func (q *Queries) FindLastCredential(ctx context.Context, identityID uuid.UUID) } const findLastCredentialOfType = `-- name: FindLastCredentialOfType :one -SELECT hash, identity_id, type, expired_at, created_at +SELECT hash, account_id, type, expired_at, created_at FROM "auth"."credential" -WHERE identity_id = $1::uuid +WHERE account_id = $1::uuid AND type = $2::text ORDER BY created_at DESC` type FindLastCredentialOfTypeParams struct { - IdentityID uuid.UUID `db:"identity_id" json:"identityId"` - Type string `db:"type" json:"type"` + AccountID uuid.UUID `db:"account_id" json:"accountId"` + Type string `db:"type" json:"type"` } func (q *Queries) FindLastCredentialOfType(ctx context.Context, arg FindLastCredentialOfTypeParams) (Credential, error) { - row := q.db.QueryRow(ctx, findLastCredentialOfType, arg.IdentityID, arg.Type) + row := q.db.QueryRow(ctx, findLastCredentialOfType, arg.AccountID, arg.Type) var i Credential err := row.Scan( &i.Hash, - &i.IdentityID, + &i.AccountID, &i.Type, &i.ExpiredAt, &i.CreatedAt, ) return i, err } + +const removeCredentialsOfAccount = `-- name: RemoveCredentialsOfAccount :exec +DELETE FROM "auth"."credential" +WHERE account_id = $1::uuid` + +func (q *Queries) RemoveCredentialsOfAccount(ctx context.Context, accountID uuid.UUID) error { + _, err := q.db.Exec(ctx, removeCredentialsOfAccount, accountID) + return err +} diff --git a/auth/storage/credential_repository.go b/auth/storage/credential_repository.go index a2ecc67..5b6744f 100644 --- a/auth/storage/credential_repository.go +++ b/auth/storage/credential_repository.go @@ -24,18 +24,18 @@ func NewDefaultCredentialRepository(db *pgxpool.Pool) repository.CredentialRepos func (r *DefaultCredentialRepository) Create( ctx context.Context, - identityID uuid.UUID, - credHash string, - credType string, + accountID uuid.UUID, + credentialHash string, + credType repository.CredentialType, expiredAt *time.Time, ) (repository.Credential, error) { expAt := null.TimeFromPtr(expiredAt) cred, err := r.queries.CreateCredential( ctx, CreateCredentialParams{ - IdentityID: identityID, - Type: credType, - Hash: credHash, - ExpiredAt: expAt, + AccountID: accountID, + Type: string(credType), + Hash: credentialHash, + ExpiredAt: expAt, }, ) @@ -46,24 +46,28 @@ func (r *DefaultCredentialRepository) Create( return r.transform(cred), nil } +func (r *DefaultCredentialRepository) RemoveCredentials(ctx context.Context, accountID uuid.UUID) error { + return errtrace.Wrap(r.queries.RemoveCredentialsOfAccount(ctx, accountID)) +} + func (r *DefaultCredentialRepository) transform(res Credential) repository.Credential { return repository.Credential{ - IdentityID: res.IdentityID, - Hash: res.Hash, - Type: res.Type, - ExpiredAt: res.ExpiredAt, + AccountID: res.AccountID, + Hash: res.Hash, + Type: repository.CredentialType(res.Type), + ExpiredAt: res.ExpiredAt, } } func (r *DefaultCredentialRepository) GetLast( ctx context.Context, - identityID uuid.UUID, + accountID uuid.UUID, credType string, ) (repository.Credential, error) { res, err := r.queries.FindLastCredentialOfType( ctx, FindLastCredentialOfTypeParams{ - IdentityID: identityID, - Type: credType, + AccountID: accountID, + Type: credType, }, ) if err != nil { diff --git a/auth/storage/fixture/access_token.go b/auth/storage/fixture/access_token.go index 5c6a6f6..96d88ba 100644 --- a/auth/storage/fixture/access_token.go +++ b/auth/storage/fixture/access_token.go @@ -41,9 +41,9 @@ func (f *AccessTokenFixture) SessionID(sessionID uuid.UUID) *AccessTokenFixture return c } -func (f *AccessTokenFixture) UserID(userID uuid.UUID) *AccessTokenFixture { +func (f *AccessTokenFixture) AccountID(accountID uuid.UUID) *AccessTokenFixture { c := f.clone() - c.entity.UserID = userID + c.entity.AccountID = accountID return c } @@ -86,15 +86,15 @@ func (f *AccessTokenFixture) clone() *AccessTokenFixture { func (f *AccessTokenFixture) save(ctx context.Context) error { query := `INSERT INTO "auth"."access_token" - ("hash", "identity_id", "session_id", "user_id", "roles", "data", "revoked_at", "expires_at", "created_at") + ("hash", "identity_id", "session_id", "account_id", "roles", "data", "revoked_at", "expires_at", "created_at") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) - RETURNING "hash", "identity_id", "session_id", "user_id", "roles", "data", "revoked_at", "expires_at", "created_at" + RETURNING "hash", "identity_id", "session_id", "account_id", "roles", "data", "revoked_at", "expires_at", "created_at" ` row := f.db.QueryRow(ctx, query, f.entity.Hash, f.entity.IdentityID, f.entity.SessionID, - f.entity.UserID, + f.entity.AccountID, f.entity.Roles, f.entity.Data, f.entity.RevokedAt, @@ -105,7 +105,7 @@ func (f *AccessTokenFixture) save(ctx context.Context) error { &f.entity.Hash, &f.entity.IdentityID, &f.entity.SessionID, - &f.entity.UserID, + &f.entity.AccountID, &f.entity.Roles, &f.entity.Data, &f.entity.RevokedAt, @@ -148,7 +148,7 @@ func (f *AccessTokenFixture) Cleanup(tb testing.TB) *AccessTokenFixture { func (f *AccessTokenFixture) PullUpdates(tb testing.TB) *AccessTokenFixture { c := f.clone() ctx := context.Background() - query := `SELECT "hash", "identity_id", "session_id", "user_id", "roles", "data", "revoked_at", "expires_at", "created_at" FROM "auth"."access_token" WHERE hash = $1` + query := `SELECT "hash", "identity_id", "session_id", "account_id", "roles", "data", "revoked_at", "expires_at", "created_at" FROM "auth"."access_token" WHERE hash = $1` row := f.db.QueryRow(ctx, query, c.entity.Hash, ) @@ -157,7 +157,7 @@ func (f *AccessTokenFixture) PullUpdates(tb testing.TB) *AccessTokenFixture { &c.entity.Hash, &c.entity.IdentityID, &c.entity.SessionID, - &c.entity.UserID, + &c.entity.AccountID, &c.entity.Roles, &c.entity.Data, &c.entity.RevokedAt, @@ -176,7 +176,7 @@ func (f *AccessTokenFixture) PushUpdates(tb testing.TB) *AccessTokenFixture { UPDATE "auth"."access_token" SET "identity_id" = $2, "session_id" = $3, - "user_id" = $4, + "account_id" = $4, "roles" = $5, "data" = $6, "revoked_at" = $7, @@ -190,7 +190,7 @@ func (f *AccessTokenFixture) PushUpdates(tb testing.TB) *AccessTokenFixture { f.entity.Hash, f.entity.IdentityID, f.entity.SessionID, - f.entity.UserID, + f.entity.AccountID, f.entity.Roles, f.entity.Data, f.entity.RevokedAt, diff --git a/auth/storage/fixture/account.go b/auth/storage/fixture/account.go new file mode 100644 index 0000000..d700b69 --- /dev/null +++ b/auth/storage/fixture/account.go @@ -0,0 +1,159 @@ +// Code generated by sqlc-fixture plugin for SQLc. DO NOT EDIT. + +package fixture + +import ( + "context" + "github.com/go-modulus/modulus/auth/storage" + uuid "github.com/gofrs/uuid" + "testing" + "time" +) + +type AccountFixture struct { + entity storage.Account + db storage.DBTX +} + +func NewAccountFixture(db storage.DBTX, defaultEntity storage.Account) *AccountFixture { + return &AccountFixture{ + db: db, + entity: defaultEntity, + } +} + +func (f *AccountFixture) ID(iD uuid.UUID) *AccountFixture { + c := f.clone() + c.entity.ID = iD + return c +} + +func (f *AccountFixture) Status(status storage.AccountStatus) *AccountFixture { + c := f.clone() + c.entity.Status = status + return c +} + +func (f *AccountFixture) Roles(roles []string) *AccountFixture { + c := f.clone() + c.entity.Roles = roles + return c +} + +func (f *AccountFixture) UpdatedAt(updatedAt time.Time) *AccountFixture { + c := f.clone() + c.entity.UpdatedAt = updatedAt + return c +} + +func (f *AccountFixture) CreatedAt(createdAt time.Time) *AccountFixture { + c := f.clone() + c.entity.CreatedAt = createdAt + return c +} + +func (f *AccountFixture) clone() *AccountFixture { + return &AccountFixture{ + db: f.db, + entity: f.entity, + } +} + +func (f *AccountFixture) save(ctx context.Context) error { + query := `INSERT INTO "auth"."account" + ("id", "status", "roles", "updated_at", "created_at") + VALUES ($1, $2, $3, $4, $5) + RETURNING "id", "status", "roles", "updated_at", "created_at" + ` + row := f.db.QueryRow(ctx, query, + f.entity.ID, + f.entity.Status, + f.entity.Roles, + f.entity.UpdatedAt, + f.entity.CreatedAt, + ) + err := row.Scan( + &f.entity.ID, + &f.entity.Status, + &f.entity.Roles, + &f.entity.UpdatedAt, + &f.entity.CreatedAt, + ) + return err +} + +func (f *AccountFixture) GetEntity() storage.Account { + return f.entity +} + +func (f *AccountFixture) Create(tb testing.TB) *AccountFixture { + err := f.save(context.Background()) + if err != nil { + tb.Fatalf("failed to create Account: %v", err) + } + f.Cleanup(tb) + c := f.clone() + return c +} + +// Cleanup calls testing.TB.Cleanup() function with providing a callback inside it. +// This callback will delete a record from the table by primary key when test will be finished. +func (f *AccountFixture) Cleanup(tb testing.TB) *AccountFixture { + tb.Cleanup( + func() { + query := `DELETE FROM "auth"."account" WHERE id = $1` + _, err := f.db.Exec(context.Background(), query, f.entity.ID) + + if err != nil { + tb.Fatalf("failed to cleanup Account: %v", err) + } + }) + + return f +} + +func (f *AccountFixture) PullUpdates(tb testing.TB) *AccountFixture { + c := f.clone() + ctx := context.Background() + query := `SELECT "id", "status", "roles", "updated_at", "created_at" FROM "auth"."account" WHERE id = $1` + row := f.db.QueryRow(ctx, query, + c.entity.ID, + ) + + err := row.Scan( + &c.entity.ID, + &c.entity.Status, + &c.entity.Roles, + &c.entity.UpdatedAt, + &c.entity.CreatedAt, + ) + if err != nil { + tb.Fatalf("failed to actualize data Account: %v", err) + } + return c +} + +func (f *AccountFixture) PushUpdates(tb testing.TB) *AccountFixture { + c := f.clone() + query := ` + UPDATE "auth"."account" SET + "status" = $2, + "roles" = $3, + "updated_at" = $4, + "created_at" = $5 + WHERE "id" = $1 + ` + _, err := f.db.Exec( + context.Background(), + query, + f.entity.ID, + f.entity.Status, + f.entity.Roles, + f.entity.UpdatedAt, + f.entity.CreatedAt, + ) + if err != nil { + tb.Fatalf("failed to push the data Account: %v", err) + } + return c +} diff --git a/auth/storage/fixture/credential.ext.go b/auth/storage/fixture/credential.ext.go index d63538d..96a88c7 100644 --- a/auth/storage/fixture/credential.ext.go +++ b/auth/storage/fixture/credential.ext.go @@ -5,13 +5,13 @@ import ( "testing" ) -// CleanupAllOfIdentity calls testing.TB.Cleanup() function with providing a callback inside it. +// CleanupAllOfAccount calls testing.TB.Cleanup() function with providing a callback inside it. // This callback will delete all records from the table by the IdentityID field. -func (f *CredentialFixture) CleanupAllOfIdentity(tb testing.TB) *CredentialFixture { +func (f *CredentialFixture) CleanupAllOfAccount(tb testing.TB) *CredentialFixture { tb.Cleanup( func() { - query := `DELETE FROM auth.credential WHERE credential.identity_id = $1` - _, err := f.db.Exec(context.Background(), query, f.entity.IdentityID) + query := `DELETE FROM auth.credential WHERE credential.account_id = $1` + _, err := f.db.Exec(context.Background(), query, f.entity.AccountID) if err != nil { tb.Fatalf("failed to cleanup Credentials of identity: %v", err) diff --git a/auth/storage/fixture/credential.go b/auth/storage/fixture/credential.go index 9418b06..42c289a 100644 --- a/auth/storage/fixture/credential.go +++ b/auth/storage/fixture/credential.go @@ -29,9 +29,9 @@ func (f *CredentialFixture) Hash(hash string) *CredentialFixture { return c } -func (f *CredentialFixture) IdentityID(identityID uuid.UUID) *CredentialFixture { +func (f *CredentialFixture) AccountID(accountID uuid.UUID) *CredentialFixture { c := f.clone() - c.entity.IdentityID = identityID + c.entity.AccountID = accountID return c } @@ -62,20 +62,20 @@ func (f *CredentialFixture) clone() *CredentialFixture { func (f *CredentialFixture) save(ctx context.Context) error { query := `INSERT INTO "auth"."credential" - ("hash", "identity_id", "type", "expired_at", "created_at") + ("hash", "account_id", "type", "expired_at", "created_at") VALUES ($1, $2, $3, $4, $5) - RETURNING "hash", "identity_id", "type", "expired_at", "created_at" + RETURNING "hash", "account_id", "type", "expired_at", "created_at" ` row := f.db.QueryRow(ctx, query, f.entity.Hash, - f.entity.IdentityID, + f.entity.AccountID, f.entity.Type, f.entity.ExpiredAt, f.entity.CreatedAt, ) err := row.Scan( &f.entity.Hash, - &f.entity.IdentityID, + &f.entity.AccountID, &f.entity.Type, &f.entity.ExpiredAt, &f.entity.CreatedAt, @@ -116,14 +116,14 @@ func (f *CredentialFixture) Cleanup(tb testing.TB) *CredentialFixture { func (f *CredentialFixture) PullUpdates(tb testing.TB) *CredentialFixture { c := f.clone() ctx := context.Background() - query := `SELECT "hash", "identity_id", "type", "expired_at", "created_at" FROM "auth"."credential" WHERE hash = $1` + query := `SELECT "hash", "account_id", "type", "expired_at", "created_at" FROM "auth"."credential" WHERE hash = $1` row := f.db.QueryRow(ctx, query, c.entity.Hash, ) err := row.Scan( &c.entity.Hash, - &c.entity.IdentityID, + &c.entity.AccountID, &c.entity.Type, &c.entity.ExpiredAt, &c.entity.CreatedAt, @@ -138,7 +138,7 @@ func (f *CredentialFixture) PushUpdates(tb testing.TB) *CredentialFixture { c := f.clone() query := ` UPDATE "auth"."credential" SET - "identity_id" = $2, + "account_id" = $2, "type" = $3, "expired_at" = $4, "created_at" = $5 @@ -148,7 +148,7 @@ func (f *CredentialFixture) PushUpdates(tb testing.TB) *CredentialFixture { context.Background(), query, f.entity.Hash, - f.entity.IdentityID, + f.entity.AccountID, f.entity.Type, f.entity.ExpiredAt, f.entity.CreatedAt, diff --git a/auth/storage/fixture/factory.go b/auth/storage/fixture/factory.go index d186ac6..5948701 100644 --- a/auth/storage/fixture/factory.go +++ b/auth/storage/fixture/factory.go @@ -25,11 +25,11 @@ func (f *FixturesFactory) Credential() *CredentialFixture { hash := base64.StdEncoding.EncodeToString(id.Bytes())[:16] return NewCredentialFixture( f.db, storage.Credential{ - IdentityID: uuid.Must(uuid.NewV6()), - Hash: hash, - Type: string(repository.CredentialTypePassword), - ExpiredAt: null.Time{}, - CreatedAt: time.Now(), + AccountID: uuid.Must(uuid.NewV6()), + Hash: hash, + Type: string(repository.CredentialTypePassword), + ExpiredAt: null.Time{}, + CreatedAt: time.Now(), }, ) } @@ -40,12 +40,12 @@ func (f *FixturesFactory) Identity() *IdentityFixture { f.db, storage.Identity{ ID: id, Identity: "test" + id.String(), - UserID: uuid.Must(uuid.NewV6()), - Roles: []string{}, + AccountID: uuid.Must(uuid.NewV6()), Status: storage.IdentityStatusActive, Data: nil, UpdatedAt: time.Now(), CreatedAt: time.Now(), + Type: "test", }, ) } @@ -74,7 +74,7 @@ func (f *FixturesFactory) AccessToken() *AccessTokenFixture { Hash: hash, IdentityID: uuid.Must(uuid.NewV6()), SessionID: uuid.Must(uuid.NewV6()), - UserID: uuid.Must(uuid.NewV6()), + AccountID: uuid.Must(uuid.NewV6()), Roles: []string{}, Data: nil, RevokedAt: null.Time{}, @@ -88,7 +88,7 @@ func (f *FixturesFactory) Session() *SessionFixture { return NewSessionFixture( f.db, storage.Session{ ID: uuid.Must(uuid.NewV6()), - UserID: uuid.Must(uuid.NewV6()), + AccountID: uuid.Must(uuid.NewV6()), IdentityID: uuid.Must(uuid.NewV6()), Data: nil, ExpiresAt: time.Now().Add(time.Hour), @@ -96,3 +96,15 @@ func (f *FixturesFactory) Session() *SessionFixture { }, ) } + +func (f *FixturesFactory) Account() *AccountFixture { + return NewAccountFixture( + f.db, storage.Account{ + ID: uuid.Must(uuid.NewV6()), + Status: "active", + Roles: []string{"test"}, + UpdatedAt: time.Now(), + CreatedAt: time.Now(), + }, + ) +} diff --git a/auth/storage/fixture/identity.ext.go b/auth/storage/fixture/identity.ext.go index 8c48b37..41f513c 100644 --- a/auth/storage/fixture/identity.ext.go +++ b/auth/storage/fixture/identity.ext.go @@ -5,13 +5,13 @@ import ( "testing" ) -// CleanupAllOfUser calls testing.TB.Cleanup() function with providing a callback inside it. +// CleanupAllOfAccount calls testing.TB.Cleanup() function with providing a callback inside it. // This callback will delete all records from the table by the UserID field. -func (f *IdentityFixture) CleanupAllOfUser(tb testing.TB) *IdentityFixture { +func (f *IdentityFixture) CleanupAllOfAccount(tb testing.TB) *IdentityFixture { tb.Cleanup( func() { - query := `DELETE FROM auth.identity WHERE identity.user_id = $1` - _, err := f.db.Exec(context.Background(), query, f.entity.UserID) + query := `DELETE FROM auth.identity WHERE identity.account_id = $1` + _, err := f.db.Exec(context.Background(), query, f.entity.AccountID) if err != nil { tb.Fatalf("failed to cleanup Identities of user: %v", err) @@ -21,3 +21,28 @@ func (f *IdentityFixture) CleanupAllOfUser(tb testing.TB) *IdentityFixture { return f } + +// PullUpdatesLastAccountIdentity gets the last Identity of the account and updates the fixture entity. +// This method is useful when you need to get the data from the database after registering the account, when you don't have identity ID +func (f *IdentityFixture) PullUpdatesLastAccountIdentity(tb testing.TB) *IdentityFixture { + + query := `SELECT id FROM auth.identity WHERE identity.account_id = $1 ORDER BY created_at DESC LIMIT 1` + rows, err := f.db.Query(context.Background(), query, f.entity.AccountID) + + if err != nil { + tb.Fatalf("failed to get the last identity of the account %s: %v", f.entity.AccountID, err) + } + + defer rows.Close() + + if !rows.Next() { + tb.Fatalf("no identity found for the account %s", f.entity.AccountID) + } + + err = rows.Scan(&f.entity.ID) + if err != nil { + tb.Fatalf("failed to scan the last identity of the account %s: %v", f.entity.AccountID, err) + } + + return f.PullUpdates(tb) +} diff --git a/auth/storage/fixture/identity.go b/auth/storage/fixture/identity.go index 9bb1407..54da43c 100644 --- a/auth/storage/fixture/identity.go +++ b/auth/storage/fixture/identity.go @@ -34,9 +34,9 @@ func (f *IdentityFixture) Identity(identity string) *IdentityFixture { return c } -func (f *IdentityFixture) UserID(userID uuid.UUID) *IdentityFixture { +func (f *IdentityFixture) AccountID(accountID uuid.UUID) *IdentityFixture { c := f.clone() - c.entity.UserID = userID + c.entity.AccountID = accountID return c } @@ -52,12 +52,6 @@ func (f *IdentityFixture) Data(data []byte) *IdentityFixture { return c } -func (f *IdentityFixture) Roles(roles []string) *IdentityFixture { - c := f.clone() - c.entity.Roles = roles - return c -} - func (f *IdentityFixture) UpdatedAt(updatedAt time.Time) *IdentityFixture { c := f.clone() c.entity.UpdatedAt = updatedAt @@ -70,6 +64,12 @@ func (f *IdentityFixture) CreatedAt(createdAt time.Time) *IdentityFixture { return c } +func (f *IdentityFixture) Type(typ string) *IdentityFixture { + c := f.clone() + c.entity.Type = typ + return c +} + func (f *IdentityFixture) clone() *IdentityFixture { return &IdentityFixture{ db: f.db, @@ -79,29 +79,29 @@ func (f *IdentityFixture) clone() *IdentityFixture { func (f *IdentityFixture) save(ctx context.Context) error { query := `INSERT INTO "auth"."identity" - ("id", "identity", "user_id", "status", "data", "roles", "updated_at", "created_at") + ("id", "identity", "account_id", "status", "data", "updated_at", "created_at", "type") VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - RETURNING "id", "identity", "user_id", "status", "data", "roles", "updated_at", "created_at" + RETURNING "id", "identity", "account_id", "status", "data", "updated_at", "created_at", "type" ` row := f.db.QueryRow(ctx, query, f.entity.ID, f.entity.Identity, - f.entity.UserID, + f.entity.AccountID, f.entity.Status, f.entity.Data, - f.entity.Roles, f.entity.UpdatedAt, f.entity.CreatedAt, + f.entity.Type, ) err := row.Scan( &f.entity.ID, &f.entity.Identity, - &f.entity.UserID, + &f.entity.AccountID, &f.entity.Status, &f.entity.Data, - &f.entity.Roles, &f.entity.UpdatedAt, &f.entity.CreatedAt, + &f.entity.Type, ) return err } @@ -139,7 +139,7 @@ func (f *IdentityFixture) Cleanup(tb testing.TB) *IdentityFixture { func (f *IdentityFixture) PullUpdates(tb testing.TB) *IdentityFixture { c := f.clone() ctx := context.Background() - query := `SELECT "id", "identity", "user_id", "status", "data", "roles", "updated_at", "created_at" FROM "auth"."identity" WHERE id = $1` + query := `SELECT "id", "identity", "account_id", "status", "data", "updated_at", "created_at", "type" FROM "auth"."identity" WHERE id = $1` row := f.db.QueryRow(ctx, query, c.entity.ID, ) @@ -147,12 +147,12 @@ func (f *IdentityFixture) PullUpdates(tb testing.TB) *IdentityFixture { err := row.Scan( &c.entity.ID, &c.entity.Identity, - &c.entity.UserID, + &c.entity.AccountID, &c.entity.Status, &c.entity.Data, - &c.entity.Roles, &c.entity.UpdatedAt, &c.entity.CreatedAt, + &c.entity.Type, ) if err != nil { tb.Fatalf("failed to actualize data Identity: %v", err) @@ -165,12 +165,12 @@ func (f *IdentityFixture) PushUpdates(tb testing.TB) *IdentityFixture { query := ` UPDATE "auth"."identity" SET "identity" = $2, - "user_id" = $3, + "account_id" = $3, "status" = $4, "data" = $5, - "roles" = $6, - "updated_at" = $7, - "created_at" = $8 + "updated_at" = $6, + "created_at" = $7, + "type" = $8 WHERE "id" = $1 ` _, err := f.db.Exec( @@ -178,12 +178,12 @@ func (f *IdentityFixture) PushUpdates(tb testing.TB) *IdentityFixture { query, f.entity.ID, f.entity.Identity, - f.entity.UserID, + f.entity.AccountID, f.entity.Status, f.entity.Data, - f.entity.Roles, f.entity.UpdatedAt, f.entity.CreatedAt, + f.entity.Type, ) if err != nil { tb.Fatalf("failed to push the data Identity: %v", err) diff --git a/auth/storage/fixture/session.go b/auth/storage/fixture/session.go index f9ac9f4..534ec31 100644 --- a/auth/storage/fixture/session.go +++ b/auth/storage/fixture/session.go @@ -28,9 +28,9 @@ func (f *SessionFixture) ID(iD uuid.UUID) *SessionFixture { return c } -func (f *SessionFixture) UserID(userID uuid.UUID) *SessionFixture { +func (f *SessionFixture) AccountID(accountID uuid.UUID) *SessionFixture { c := f.clone() - c.entity.UserID = userID + c.entity.AccountID = accountID return c } @@ -67,13 +67,13 @@ func (f *SessionFixture) clone() *SessionFixture { func (f *SessionFixture) save(ctx context.Context) error { query := `INSERT INTO "auth"."session" - ("id", "user_id", "identity_id", "data", "expires_at", "created_at") + ("id", "account_id", "identity_id", "data", "expires_at", "created_at") VALUES ($1, $2, $3, $4, $5, $6) - RETURNING "id", "user_id", "identity_id", "data", "expires_at", "created_at" + RETURNING "id", "account_id", "identity_id", "data", "expires_at", "created_at" ` row := f.db.QueryRow(ctx, query, f.entity.ID, - f.entity.UserID, + f.entity.AccountID, f.entity.IdentityID, f.entity.Data, f.entity.ExpiresAt, @@ -81,7 +81,7 @@ func (f *SessionFixture) save(ctx context.Context) error { ) err := row.Scan( &f.entity.ID, - &f.entity.UserID, + &f.entity.AccountID, &f.entity.IdentityID, &f.entity.Data, &f.entity.ExpiresAt, @@ -123,14 +123,14 @@ func (f *SessionFixture) Cleanup(tb testing.TB) *SessionFixture { func (f *SessionFixture) PullUpdates(tb testing.TB) *SessionFixture { c := f.clone() ctx := context.Background() - query := `SELECT "id", "user_id", "identity_id", "data", "expires_at", "created_at" FROM "auth"."session" WHERE id = $1` + query := `SELECT "id", "account_id", "identity_id", "data", "expires_at", "created_at" FROM "auth"."session" WHERE id = $1` row := f.db.QueryRow(ctx, query, c.entity.ID, ) err := row.Scan( &c.entity.ID, - &c.entity.UserID, + &c.entity.AccountID, &c.entity.IdentityID, &c.entity.Data, &c.entity.ExpiresAt, @@ -146,7 +146,7 @@ func (f *SessionFixture) PushUpdates(tb testing.TB) *SessionFixture { c := f.clone() query := ` UPDATE "auth"."session" SET - "user_id" = $2, + "account_id" = $2, "identity_id" = $3, "data" = $4, "expires_at" = $5, @@ -157,7 +157,7 @@ func (f *SessionFixture) PushUpdates(tb testing.TB) *SessionFixture { context.Background(), query, f.entity.ID, - f.entity.UserID, + f.entity.AccountID, f.entity.IdentityID, f.entity.Data, f.entity.ExpiresAt, diff --git a/auth/storage/identity.sql.go b/auth/storage/identity.sql.go index 5313fb1..7313236 100644 --- a/auth/storage/identity.sql.go +++ b/auth/storage/identity.sql.go @@ -11,66 +11,108 @@ import ( uuid "github.com/gofrs/uuid" ) -const addRoles = `-- name: AddRoles :exec +const activateIdentity = `-- name: ActivateIdentity :exec update "auth"."identity" -set roles = array(select distinct unnest(roles || $1::text[])) -where id = $2::uuid` +set status = 'active'::auth.identity_status +where id = $1::uuid` -type AddRolesParams struct { - Roles []string `db:"roles" json:"roles"` - ID uuid.UUID `db:"id" json:"id"` +func (q *Queries) ActivateIdentity(ctx context.Context, id uuid.UUID) error { + _, err := q.db.Exec(ctx, activateIdentity, id) + return err } -func (q *Queries) AddRoles(ctx context.Context, arg AddRolesParams) error { - _, err := q.db.Exec(ctx, addRoles, arg.Roles, arg.ID) +const blockIdentitiesOfAccount = `-- name: BlockIdentitiesOfAccount :exec +update "auth"."identity" +set status = 'blocked'::auth.identity_status +where account_id = $1::uuid` + +func (q *Queries) BlockIdentitiesOfAccount(ctx context.Context, accountID uuid.UUID) error { + _, err := q.db.Exec(ctx, blockIdentitiesOfAccount, accountID) + return err +} + +const blockIdentity = `-- name: BlockIdentity :exec +update "auth"."identity" +set status = 'blocked'::auth.identity_status +where id = $1::uuid` + +func (q *Queries) BlockIdentity(ctx context.Context, id uuid.UUID) error { + _, err := q.db.Exec(ctx, blockIdentity, id) return err } const createIdentity = `-- name: CreateIdentity :one insert into "auth"."identity" - (id, identity, user_id, "data") -values ($1::uuid, $2::text, $3::uuid, $4::jsonb) -RETURNING id, identity, user_id, status, data, roles, updated_at, created_at` + (id, identity, account_id, "data", "type") +values ($1::uuid, $2::text, $3::uuid, $4::jsonb, $5::text) +RETURNING id, identity, account_id, status, data, updated_at, created_at, type` type CreateIdentityParams struct { - ID uuid.UUID `db:"id" json:"id"` - Identity string `db:"identity" json:"identity"` - UserID uuid.UUID `db:"user_id" json:"userId"` - Data []byte `db:"data" json:"data"` + ID uuid.UUID `db:"id" json:"id"` + Identity string `db:"identity" json:"identity"` + AccountID uuid.UUID `db:"account_id" json:"accountId"` + Data []byte `db:"data" json:"data"` + Type string `db:"type" json:"type"` } func (q *Queries) CreateIdentity(ctx context.Context, arg CreateIdentityParams) (Identity, error) { row := q.db.QueryRow(ctx, createIdentity, arg.ID, arg.Identity, - arg.UserID, + arg.AccountID, arg.Data, + arg.Type, ) var i Identity err := row.Scan( &i.ID, &i.Identity, - &i.UserID, + &i.AccountID, &i.Status, &i.Data, - &i.Roles, &i.UpdatedAt, &i.CreatedAt, + &i.Type, ) return i, err } -const deleteIdentity = `-- name: DeleteIdentity :exec -delete from "auth"."identity" -where id = $1` +const findAccountIdentities = `-- name: FindAccountIdentities :many +select id, identity, account_id, status, data, updated_at, created_at, type +from "auth"."identity" +where account_id = $1::uuid` -func (q *Queries) DeleteIdentity(ctx context.Context, id uuid.UUID) error { - _, err := q.db.Exec(ctx, deleteIdentity, id) - return err +func (q *Queries) FindAccountIdentities(ctx context.Context, accountID uuid.UUID) ([]Identity, error) { + rows, err := q.db.Query(ctx, findAccountIdentities, accountID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Identity + for rows.Next() { + var i Identity + if err := rows.Scan( + &i.ID, + &i.Identity, + &i.AccountID, + &i.Status, + &i.Data, + &i.UpdatedAt, + &i.CreatedAt, + &i.Type, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil } const findIdentity = `-- name: FindIdentity :one -select id, identity, user_id, status, data, roles, updated_at, created_at +select id, identity, account_id, status, data, updated_at, created_at, type from "auth"."identity" where identity = $1::text` @@ -80,18 +122,18 @@ func (q *Queries) FindIdentity(ctx context.Context, identity string) (Identity, err := row.Scan( &i.ID, &i.Identity, - &i.UserID, + &i.AccountID, &i.Status, &i.Data, - &i.Roles, &i.UpdatedAt, &i.CreatedAt, + &i.Type, ) return i, err } const findIdentityById = `-- name: FindIdentityById :one -select id, identity, user_id, status, data, roles, updated_at, created_at +select id, identity, account_id, status, data, updated_at, created_at, type from "auth"."identity" where id = $1::uuid` @@ -101,61 +143,40 @@ func (q *Queries) FindIdentityById(ctx context.Context, id uuid.UUID) (Identity, err := row.Scan( &i.ID, &i.Identity, - &i.UserID, + &i.AccountID, &i.Status, &i.Data, - &i.Roles, &i.UpdatedAt, &i.CreatedAt, + &i.Type, ) return i, err } -const findUserIdentities = `-- name: FindUserIdentities :many -select id, identity, user_id, status, data, roles, updated_at, created_at -from "auth"."identity" -where user_id = $1::uuid` +const removeIdentitiesOfAccount = `-- name: RemoveIdentitiesOfAccount :exec +delete from "auth"."identity" +where account_id = $1::uuid` -func (q *Queries) FindUserIdentities(ctx context.Context, userID uuid.UUID) ([]Identity, error) { - rows, err := q.db.Query(ctx, findUserIdentities, userID) - if err != nil { - return nil, err - } - defer rows.Close() - var items []Identity - for rows.Next() { - var i Identity - if err := rows.Scan( - &i.ID, - &i.Identity, - &i.UserID, - &i.Status, - &i.Data, - &i.Roles, - &i.UpdatedAt, - &i.CreatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil +func (q *Queries) RemoveIdentitiesOfAccount(ctx context.Context, accountID uuid.UUID) error { + _, err := q.db.Exec(ctx, removeIdentitiesOfAccount, accountID) + return err } -const removeRoles = `-- name: RemoveRoles :exec -update "auth"."identity" -set roles = array(select distinct unnest(roles) except select distinct unnest($1::text[])) -where id = $2::uuid` +const removeIdentity = `-- name: RemoveIdentity :exec +delete from "auth"."identity" +where id = $1` -type RemoveRolesParams struct { - Roles []string `db:"roles" json:"roles"` - ID uuid.UUID `db:"id" json:"id"` +func (q *Queries) RemoveIdentity(ctx context.Context, id uuid.UUID) error { + _, err := q.db.Exec(ctx, removeIdentity, id) + return err } -func (q *Queries) RemoveRoles(ctx context.Context, arg RemoveRolesParams) error { - _, err := q.db.Exec(ctx, removeRoles, arg.Roles, arg.ID) +const requestIdentityVerification = `-- name: RequestIdentityVerification :exec +update "auth"."identity" +set status = 'not-verified'::auth.identity_status +where id = $1::uuid` + +func (q *Queries) RequestIdentityVerification(ctx context.Context, id uuid.UUID) error { + _, err := q.db.Exec(ctx, requestIdentityVerification, id) return err } diff --git a/auth/storage/identity_repository.go b/auth/storage/identity_repository.go index 6d278f9..1b43670 100644 --- a/auth/storage/identity_repository.go +++ b/auth/storage/identity_repository.go @@ -24,7 +24,8 @@ func NewDefaultIdentityRepository(db *pgxpool.Pool) repository.IdentityRepositor func (r *DefaultIdentityRepository) Create( ctx context.Context, identity string, - userID uuid.UUID, + accountID uuid.UUID, + identityType repository.IdentityType, additionalData map[string]interface{}, ) (repository.Identity, error) { _, err := r.Get(ctx, identity) @@ -43,10 +44,11 @@ func (r *DefaultIdentityRepository) Create( } storedIdentity, err := r.queries.CreateIdentity( ctx, CreateIdentityParams{ - ID: id, - Identity: identity, - UserID: userID, - Data: dataVal, + ID: id, + Identity: identity, + AccountID: accountID, + Data: dataVal, + Type: string(identityType), }, ) @@ -57,6 +59,43 @@ func (r *DefaultIdentityRepository) Create( return r.Transform(storedIdentity), nil } +func (r *DefaultIdentityRepository) GetByAccountID(ctx context.Context, accountID uuid.UUID) ( + []repository.Identity, + error, +) { + idents, err := r.queries.FindAccountIdentities(ctx, accountID) + if err != nil { + return nil, errtrace.Wrap(err) + } + var res []repository.Identity + for _, ident := range idents { + res = append(res, r.Transform(ident)) + } + return res, nil +} + +func (r *DefaultIdentityRepository) RemoveAccountIdentities(ctx context.Context, accountID uuid.UUID) error { + return errtrace.Wrap(r.queries.RemoveIdentitiesOfAccount(ctx, accountID)) +} + +func (r *DefaultIdentityRepository) RemoveIdentity(ctx context.Context, identity string) error { + ident, err := r.Get(ctx, identity) + if err != nil { + return errtrace.Wrap(err) + } + + return errtrace.Wrap(r.queries.RemoveIdentity(ctx, ident.ID)) +} + +func (r *DefaultIdentityRepository) BlockIdentity(ctx context.Context, identity string) error { + ident, err := r.Get(ctx, identity) + if err != nil { + return errtrace.Wrap(err) + } + + return errtrace.Wrap(r.queries.BlockIdentity(ctx, ident.ID)) +} + func (r *DefaultIdentityRepository) Transform( identity Identity, ) repository.Identity { @@ -65,12 +104,11 @@ func (r *DefaultIdentityRepository) Transform( data = make(map[string]interface{}) } return repository.Identity{ - ID: identity.ID, - Identity: identity.Identity, - UserID: identity.UserID, - Roles: identity.Roles, - Status: repository.IdentityStatus(identity.Status), - Data: data, + ID: identity.ID, + Identity: identity.Identity, + AccountID: identity.AccountID, + Status: repository.IdentityStatus(identity.Status), + Data: data, } } @@ -101,29 +139,3 @@ func (r *DefaultIdentityRepository) GetById( } return r.Transform(res), nil } - -func (r *DefaultIdentityRepository) AddRoles( - ctx context.Context, - identityID uuid.UUID, - roles ...string, -) error { - return r.queries.AddRoles( - ctx, AddRolesParams{ - ID: identityID, - Roles: roles, - }, - ) -} - -func (r *DefaultIdentityRepository) RemoveRoles( - ctx context.Context, - identityID uuid.UUID, - roles ...string, -) error { - return r.queries.RemoveRoles( - ctx, RemoveRolesParams{ - ID: identityID, - Roles: roles, - }, - ) -} diff --git a/auth/storage/migration/20240320084613_auth_account.sql b/auth/storage/migration/20240320084613_auth_account.sql new file mode 100644 index 0000000..d1643d8 --- /dev/null +++ b/auth/storage/migration/20240320084613_auth_account.sql @@ -0,0 +1,53 @@ +-- migrate:up + +CREATE TYPE auth.account_status AS ENUM ( + 'active', + 'blocked' + ); + +ALTER TABLE auth.identity + RENAME COLUMN user_id TO account_id; +ALTER TABLE auth.identity + DROP COLUMN roles; +ALTER TABLE auth.identity + ADD COLUMN type text NOT NULL DEFAULT 'not-set'; + +COMMENT ON COLUMN auth.identity.type IS 'Type of the identity (eg. email, phone, google-auth, etc.).'; + +ALTER TYPE auth.identity_status ADD VALUE IF NOT EXISTS 'not-verified'; + +CREATE TABLE auth.account +( + id uuid PRIMARY KEY, + status auth.account_status NOT NULL DEFAULT 'active'::auth.account_status, + roles text[] NOT NULL DEFAULT '{}', + updated_at timestamptz NOT NULL DEFAULT NOW(), + created_at timestamptz NOT NULL DEFAULT NOW() +); + +ALTER TABLE auth.access_token + RENAME COLUMN user_id TO account_id; +ALTER TABLE auth.session + RENAME COLUMN user_id TO account_id; + +ALTER TABLE auth.credential + RENAME COLUMN identity_id TO account_id; + +-- migrate:down +ALTER TABLE auth.identity + RENAME COLUMN account_id TO user_id; +ALTER TABLE auth.identity + ADD COLUMN roles text[] NOT NULL DEFAULT '{}'; +ALTER TABLE auth.identity + DROP COLUMN type; + +ALTER TABLE auth.access_token + RENAME COLUMN account_id TO user_id; +ALTER TABLE auth.session + RENAME COLUMN account_id TO user_id; + +ALTER TABLE auth.credential + RENAME COLUMN account_id TO identity_id; + +DROP TABLE auth.account; +DROP TYPE auth.account_status; \ No newline at end of file diff --git a/auth/storage/models.go b/auth/storage/models.go index 160962d..ef8d075 100644 --- a/auth/storage/models.go +++ b/auth/storage/models.go @@ -13,11 +13,61 @@ import ( null "gopkg.in/guregu/null.v4" ) +type AccountStatus string + +const ( + AccountStatusActive AccountStatus = "active" + AccountStatusBlocked AccountStatus = "blocked" +) + +func (e *AccountStatus) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = AccountStatus(s) + case string: + *e = AccountStatus(s) + default: + return fmt.Errorf("unsupported scan type for AccountStatus: %T", src) + } + return nil +} + +type NullAccountStatus struct { + AccountStatus AccountStatus `json:"accountStatus"` + Valid bool `json:"valid"` // Valid is true if AccountStatus is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullAccountStatus) Scan(value interface{}) error { + if value == nil { + ns.AccountStatus, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.AccountStatus.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullAccountStatus) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.AccountStatus), nil +} + +func AllAccountStatusValues() []AccountStatus { + return []AccountStatus{ + AccountStatusActive, + AccountStatusBlocked, + } +} + type IdentityStatus string const ( - IdentityStatusActive IdentityStatus = "active" - IdentityStatusBlocked IdentityStatus = "blocked" + IdentityStatusActive IdentityStatus = "active" + IdentityStatusBlocked IdentityStatus = "blocked" + IdentityStatusNotVerified IdentityStatus = "not-verified" ) func (e *IdentityStatus) Scan(src interface{}) error { @@ -59,6 +109,7 @@ func AllIdentityStatusValues() []IdentityStatus { return []IdentityStatus{ IdentityStatusActive, IdentityStatusBlocked, + IdentityStatusNotVerified, } } @@ -66,7 +117,7 @@ type AccessToken struct { Hash string `db:"hash" json:"hash"` IdentityID uuid.UUID `db:"identity_id" json:"identityId"` SessionID uuid.UUID `db:"session_id" json:"sessionId"` - UserID uuid.UUID `db:"user_id" json:"userId"` + AccountID uuid.UUID `db:"account_id" json:"accountId"` Roles []string `db:"roles" json:"roles"` Data []byte `db:"data" json:"data"` RevokedAt null.Time `db:"revoked_at" json:"revokedAt"` @@ -74,23 +125,32 @@ type AccessToken struct { CreatedAt time.Time `db:"created_at" json:"createdAt"` } +type Account struct { + ID uuid.UUID `db:"id" json:"id"` + Status AccountStatus `db:"status" json:"status"` + Roles []string `db:"roles" json:"roles"` + UpdatedAt time.Time `db:"updated_at" json:"updatedAt"` + CreatedAt time.Time `db:"created_at" json:"createdAt"` +} + type Credential struct { - Hash string `db:"hash" json:"hash"` - IdentityID uuid.UUID `db:"identity_id" json:"identityId"` - Type string `db:"type" json:"type"` - ExpiredAt null.Time `db:"expired_at" json:"expiredAt"` - CreatedAt time.Time `db:"created_at" json:"createdAt"` + Hash string `db:"hash" json:"hash"` + AccountID uuid.UUID `db:"account_id" json:"accountId"` + Type string `db:"type" json:"type"` + ExpiredAt null.Time `db:"expired_at" json:"expiredAt"` + CreatedAt time.Time `db:"created_at" json:"createdAt"` } type Identity struct { ID uuid.UUID `db:"id" json:"id"` Identity string `db:"identity" json:"identity"` - UserID uuid.UUID `db:"user_id" json:"userId"` + AccountID uuid.UUID `db:"account_id" json:"accountId"` Status IdentityStatus `db:"status" json:"status"` Data []byte `db:"data" json:"data"` - Roles []string `db:"roles" json:"roles"` UpdatedAt time.Time `db:"updated_at" json:"updatedAt"` CreatedAt time.Time `db:"created_at" json:"createdAt"` + // Type of the identity (eg. email, phone, google-auth, etc.). + Type string `db:"type" json:"type"` } type RefreshToken struct { @@ -104,7 +164,7 @@ type RefreshToken struct { type Session struct { ID uuid.UUID `db:"id" json:"id"` - UserID uuid.UUID `db:"user_id" json:"userId"` + AccountID uuid.UUID `db:"account_id" json:"accountId"` IdentityID uuid.UUID `db:"identity_id" json:"identityId"` Data []byte `db:"data" json:"data"` ExpiresAt time.Time `db:"expires_at" json:"expiresAt"` diff --git a/auth/storage/query/access_token.sql b/auth/storage/query/access_token.sql index ddabc30..0cb43dc 100644 --- a/auth/storage/query/access_token.sql +++ b/auth/storage/query/access_token.sql @@ -1,8 +1,8 @@ -- name: CreateAccessToken :one -INSERT INTO auth.access_token (hash, identity_id, session_id, user_id, roles, data, expires_at) +INSERT INTO auth.access_token (hash, identity_id, session_id, account_id, roles, data, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING *; --- name: GetAccessTokenByHash :one +-- name: FindAccessTokenByHash :one SELECT * FROM auth.access_token WHERE hash = $1; @@ -17,16 +17,16 @@ UPDATE auth.access_token SET revoked_at = now() WHERE session_id = $1 AND revoked_at IS NULL; --- name: RevokeUserAccessTokens :exec +-- name: RevokeAccountAccessTokens :exec UPDATE auth.access_token SET revoked_at = now() -WHERE user_id = $1 AND revoked_at IS NULL; +WHERE account_id = $1 AND revoked_at IS NULL; --- name: GetUserNotRevokedSessionIds :many +-- name: FindAccountNotRevokedSessionIds :many SELECT session_id FROM auth.access_token WHERE revoked_at IS NULL - AND user_id = $1; + AND account_id = $1; -- name: ExpireSessionAccessTokens :exec UPDATE auth.access_token diff --git a/auth/storage/query/account.sql b/auth/storage/query/account.sql new file mode 100644 index 0000000..c5d75bb --- /dev/null +++ b/auth/storage/query/account.sql @@ -0,0 +1,33 @@ +-- name: RegisterAccount :one +INSERT INTO auth.account (id) +VALUES ($1) RETURNING *; + +-- name: FindAccount :one +SELECT * +FROM auth.account +WHERE id = $1; + +-- name: BlockAccount :exec +UPDATE auth.account +SET status = 'blocked'::auth.account_status +WHERE id = $1; + +-- name: UnblockAccount :exec +UPDATE auth.account +SET status = 'active'::auth.account_status +WHERE id = $1; + + +-- name: AddRoles :exec +update "auth"."account" +set roles = array(select distinct unnest(roles || @roles::text[])) +where id = @id::uuid; + +-- name: RemoveRoles :exec +update "auth"."account" +set roles = array(select distinct unnest(roles) except select distinct unnest(@roles::text[])) +where id = @id::uuid; + +-- name: RemoveAccount :exec +DELETE FROM auth.account +WHERE id = $1; \ No newline at end of file diff --git a/auth/storage/query/credential.sql b/auth/storage/query/credential.sql index 99a0e32..4453629 100644 --- a/auth/storage/query/credential.sql +++ b/auth/storage/query/credential.sql @@ -1,19 +1,19 @@ -- name: CreateCredential :one INSERT INTO "auth"."credential" - (identity_id, type, hash, expired_at) -VALUES (@identity_id::uuid, @type::text, @hash::text, @expired_at) + (account_id, type, hash, expired_at) +VALUES (@account_id::uuid, @type::text, @hash::text, @expired_at) RETURNING *; -- name: FindLastCredential :one SELECT * FROM "auth"."credential" -WHERE identity_id = @identity_id::uuid +WHERE account_id = @account_id::uuid ORDER BY created_at DESC; -- name: FindLastCredentialOfType :one SELECT * FROM "auth"."credential" -WHERE identity_id = @identity_id::uuid +WHERE account_id = @account_id::uuid AND type = @type::text ORDER BY created_at DESC; @@ -21,4 +21,8 @@ ORDER BY created_at DESC; SELECT * FROM "auth"."credential" WHERE type = @type::text -ORDER BY created_at DESC; \ No newline at end of file +ORDER BY created_at DESC; + +-- name: RemoveCredentialsOfAccount :exec +DELETE FROM "auth"."credential" +WHERE account_id = @account_id::uuid; \ No newline at end of file diff --git a/auth/storage/query/identity.sql b/auth/storage/query/identity.sql index 15643b0..5ef3a79 100644 --- a/auth/storage/query/identity.sql +++ b/auth/storage/query/identity.sql @@ -1,10 +1,10 @@ -- name: CreateIdentity :one insert into "auth"."identity" - (id, identity, user_id, "data") -values (@id::uuid, @identity::text, @user_id::uuid, @data::jsonb) + (id, identity, account_id, "data", "type") +values (@id::uuid, @identity::text, @account_id::uuid, @data::jsonb, @type::text) RETURNING *; --- name: DeleteIdentity :exec +-- name: RemoveIdentity :exec delete from "auth"."identity" where id = @id; @@ -13,22 +13,36 @@ select * from "auth"."identity" where identity = @identity::text; --- name: FindUserIdentities :many +-- name: FindAccountIdentities :many select * from "auth"."identity" -where user_id = @user_id::uuid; +where account_id = @account_id::uuid; -- name: FindIdentityById :one select * from "auth"."identity" where id = @id::uuid; --- name: AddRoles :exec +-- name: BlockIdentity :exec update "auth"."identity" -set roles = array(select distinct unnest(roles || @roles::text[])) +set status = 'blocked'::auth.identity_status where id = @id::uuid; --- name: RemoveRoles :exec +-- name: BlockIdentitiesOfAccount :exec update "auth"."identity" -set roles = array(select distinct unnest(roles) except select distinct unnest(@roles::text[])) -where id = @id::uuid; \ No newline at end of file +set status = 'blocked'::auth.identity_status +where account_id = @account_id::uuid; + +-- name: ActivateIdentity :exec +update "auth"."identity" +set status = 'active'::auth.identity_status +where id = @id::uuid; + +-- name: RequestIdentityVerification :exec +update "auth"."identity" +set status = 'not-verified'::auth.identity_status +where id = @id::uuid; + +-- name: RemoveIdentitiesOfAccount :exec +delete from "auth"."identity" +where account_id = @account_id::uuid; \ No newline at end of file diff --git a/auth/storage/query/refresh_token.sql b/auth/storage/query/refresh_token.sql index 84dcb45..df26f26 100644 --- a/auth/storage/query/refresh_token.sql +++ b/auth/storage/query/refresh_token.sql @@ -18,7 +18,7 @@ UPDATE auth.refresh_token SET revoked_at = now() WHERE session_id = @session_ids::uuid[] AND revoked_at IS NULL; --- name: GetRefreshTokenByHash :one +-- name: FindRefreshTokenByHash :one SELECT * FROM auth.refresh_token WHERE hash = $1; diff --git a/auth/storage/refresh_token.sql.go b/auth/storage/refresh_token.sql.go index a46c91d..3b69167 100644 --- a/auth/storage/refresh_token.sql.go +++ b/auth/storage/refresh_token.sql.go @@ -59,13 +59,13 @@ func (q *Queries) ExpireSessionRefreshTokens(ctx context.Context, arg ExpireSess return err } -const getRefreshTokenByHash = `-- name: GetRefreshTokenByHash :one +const findRefreshTokenByHash = `-- name: FindRefreshTokenByHash :one SELECT hash, session_id, identity_id, revoked_at, expires_at, created_at FROM auth.refresh_token WHERE hash = $1` -func (q *Queries) GetRefreshTokenByHash(ctx context.Context, hash string) (RefreshToken, error) { - row := q.db.QueryRow(ctx, getRefreshTokenByHash, hash) +func (q *Queries) FindRefreshTokenByHash(ctx context.Context, hash string) (RefreshToken, error) { + row := q.db.QueryRow(ctx, findRefreshTokenByHash, hash) var i RefreshToken err := row.Scan( &i.Hash, diff --git a/auth/storage/token_repository.go b/auth/storage/token_repository.go index 2d2a518..fb0436a 100644 --- a/auth/storage/token_repository.go +++ b/auth/storage/token_repository.go @@ -18,6 +18,11 @@ type DefaultTokenRepository struct { hashStrategy hash.TokenHashStrategy } +func (r *DefaultTokenRepository) RevokeAccountTokens(ctx context.Context, accountId uuid.UUID) error { + //TODO implement me + panic("implement me") +} + func NewDefaultTokenRepository( db *pgxpool.Pool, hashStrategy hash.TokenHashStrategy, @@ -32,7 +37,6 @@ func (r *DefaultTokenRepository) CreateAccessToken( ctx context.Context, accessToken string, identityId uuid.UUID, - userId uuid.UUID, roles []string, sessionId uuid.UUID, data map[string]interface{}, @@ -48,12 +52,19 @@ func (r *DefaultTokenRepository) CreateAccessToken( return repository.AccessToken{}, errtrace.Wrap(errors.WithCause(repository.ErrCannotCreateAccessToken, err)) } + ident, err := r.queries.FindIdentityById(ctx, identityId) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return repository.AccessToken{}, errtrace.Wrap(repository.ErrIdentityNotFound) + } + return repository.AccessToken{}, errtrace.Wrap(err) + } storedAccessToken, err := r.queries.CreateAccessToken( ctx, CreateAccessTokenParams{ Hash: accessToken, IdentityID: identityId, - UserID: userId, + AccountID: ident.AccountID, Roles: roles, SessionID: sessionId, Data: dataJson, @@ -94,7 +105,7 @@ func (r *DefaultTokenRepository) GetRefreshToken(ctx context.Context, refreshTok error, ) { refreshToken = r.hashToken(refreshToken) - storedRefreshToken, err := r.queries.GetRefreshTokenByHash(ctx, refreshToken) + storedRefreshToken, err := r.queries.FindRefreshTokenByHash(ctx, refreshToken) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return repository.RefreshToken{}, errtrace.Wrap(repository.ErrTokenNotExist) @@ -109,7 +120,7 @@ func (r *DefaultTokenRepository) GetAccessToken(ctx context.Context, accessToken error, ) { accessToken = r.hashToken(accessToken) - storedAccessToken, err := r.queries.GetAccessTokenByHash(ctx, accessToken) + storedAccessToken, err := r.queries.FindAccessTokenByHash(ctx, accessToken) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return repository.AccessToken{}, errtrace.Wrap(repository.ErrTokenNotExist) @@ -180,8 +191,8 @@ func (r *DefaultTokenRepository) RevokeSessionTokens(ctx context.Context, sessio return nil } -func (r *DefaultTokenRepository) RevokeUserTokens(ctx context.Context, userId uuid.UUID) error { - sessionIds, err := r.queries.GetUserNotRevokedSessionIds(ctx, userId) +func (r *DefaultTokenRepository) RevokeUserTokens(ctx context.Context, accountId uuid.UUID) error { + sessionIds, err := r.queries.FindAccountNotRevokedSessionIds(ctx, accountId) if err != nil { return errtrace.Wrap(err) } @@ -189,7 +200,7 @@ func (r *DefaultTokenRepository) RevokeUserTokens(ctx context.Context, userId uu if err != nil { return errtrace.Wrap(err) } - err = r.queries.RevokeUserAccessTokens(ctx, userId) + err = r.queries.RevokeAccountAccessTokens(ctx, accountId) if err != nil { return errtrace.Wrap(err) } @@ -204,7 +215,7 @@ func (r *DefaultTokenRepository) transformAccessToken(storedAccessToken AccessTo return repository.AccessToken{ Hash: storedAccessToken.Hash, IdentityID: storedAccessToken.IdentityID, - UserID: storedAccessToken.UserID, + AccountID: storedAccessToken.AccountID, Roles: storedAccessToken.Roles, SessionID: storedAccessToken.SessionID, Data: data, diff --git a/docs/graphql_server_example.md b/docs/graphql_server_example.md index dba8300..f978456 100644 --- a/docs/graphql_server_example.md +++ b/docs/graphql_server_example.md @@ -980,14 +980,31 @@ func NewRegisterUser( func (r *RegisterUser) Execute(ctx context.Context, input RegisterUserInput) (storage.User, error) { ... - _, err = r.passwordAuth.Register( + // register the new account with identity + // it also creates a new identity for this account + account, err := r.passwordAuth.Register( ctx, input.Email, input.Password, - user.ID, - // the authenticated user role that will be used in the future + // type of the created identity + repository.IdentityTypeEmail, + // the authenticated user role that will be used in the future []string{"user"}, - nil, + nil, + ) + if err != nil { + return storage.User{}, errtrace.Wrap(err) + } + user, err := r.userDb.RegisterUser( + ctx, storage.RegisterUserParams{ + // store somewhere the ID of created account. It will be used in the future authentication requests as a Performer.ID field + // it is a good idea to store the ID of the account as a primary key of the user table + // it will be relation 1:1 between the account and the user + ID: account.ID, + // store data or the registered user + Email: input.Email, + Name: input.Name, + }, ) if err != nil { return storage.User{}, errtrace.Wrap(err) diff --git a/examples/blog/.env b/examples/blog/.env index f41a558..f04deeb 100644 --- a/examples/blog/.env +++ b/examples/blog/.env @@ -36,3 +36,7 @@ LOGGER_LEVEL=debug # Use either "console" or "json" value LOGGER_TYPE=console + + +AUTH_ACCESS_TOKEN_TTL=1h0m0s +AUTH_REFRESH_TOKEN_TTL=720h0m0s diff --git a/examples/blog/go.mod b/examples/blog/go.mod index 1244555..9e2aa1e 100644 --- a/examples/blog/go.mod +++ b/examples/blog/go.mod @@ -8,7 +8,7 @@ require ( braces.dev/errtrace v0.3.0 github.com/99designs/gqlgen v0.17.66 github.com/debugger84/sqlc-dataloader v0.1.4 - github.com/go-modulus/modulus v0.2.5 + github.com/go-modulus/modulus v0.3.5 github.com/go-ozzo/ozzo-validation/v4 v4.3.0 github.com/gofrs/uuid v4.4.0+incompatible github.com/graph-gophers/dataloader/v7 v7.1.0 @@ -103,4 +103,4 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -replace github.com/go-modulus/modulus v0.2.5 => ../../ +replace github.com/go-modulus/modulus v0.3.5 => ../../ diff --git a/examples/blog/go.sum b/examples/blog/go.sum index c22d1a3..0c3060a 100644 --- a/examples/blog/go.sum +++ b/examples/blog/go.sum @@ -69,8 +69,8 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/go-modulus/modulus v0.2.5 h1:9edbRiqi06VVoQUdKHsDuMpzT0YMiOOtu97jDN1wfXw= -github.com/go-modulus/modulus v0.2.5/go.mod h1:HEvT13JZCkFhap60kOnPeTLbv8lEvjzGwx/Yl82XVOg= +github.com/go-modulus/modulus v0.3.5 h1:cX3FX+5YB3kMp31sKDRo6wTNkdt1TIEwIbf9AaSeQjI= +github.com/go-modulus/modulus v0.3.5/go.mod h1:HEvT13JZCkFhap60kOnPeTLbv8lEvjzGwx/Yl82XVOg= github.com/go-ozzo/ozzo-validation/v4 v4.3.0 h1:byhDUpfEwjsVQb1vBunvIjh2BHQ9ead57VkAEY4V+Es= github.com/go-ozzo/ozzo-validation/v4 v4.3.0/go.mod h1:2NKgrcHl3z6cJs+3Oo940FPRiTzuqKbvfrL2RxCj6Ew= github.com/go-sql-driver/mysql v1.9.0 h1:Y0zIbQXhQKmQgTp44Y1dp3wTXcn804QoTptLZT1vtvo= diff --git a/examples/blog/internal/auth/storage/migration/20240320084613_auth_account.sql b/examples/blog/internal/auth/storage/migration/20240320084613_auth_account.sql new file mode 100644 index 0000000..d1643d8 --- /dev/null +++ b/examples/blog/internal/auth/storage/migration/20240320084613_auth_account.sql @@ -0,0 +1,53 @@ +-- migrate:up + +CREATE TYPE auth.account_status AS ENUM ( + 'active', + 'blocked' + ); + +ALTER TABLE auth.identity + RENAME COLUMN user_id TO account_id; +ALTER TABLE auth.identity + DROP COLUMN roles; +ALTER TABLE auth.identity + ADD COLUMN type text NOT NULL DEFAULT 'not-set'; + +COMMENT ON COLUMN auth.identity.type IS 'Type of the identity (eg. email, phone, google-auth, etc.).'; + +ALTER TYPE auth.identity_status ADD VALUE IF NOT EXISTS 'not-verified'; + +CREATE TABLE auth.account +( + id uuid PRIMARY KEY, + status auth.account_status NOT NULL DEFAULT 'active'::auth.account_status, + roles text[] NOT NULL DEFAULT '{}', + updated_at timestamptz NOT NULL DEFAULT NOW(), + created_at timestamptz NOT NULL DEFAULT NOW() +); + +ALTER TABLE auth.access_token + RENAME COLUMN user_id TO account_id; +ALTER TABLE auth.session + RENAME COLUMN user_id TO account_id; + +ALTER TABLE auth.credential + RENAME COLUMN identity_id TO account_id; + +-- migrate:down +ALTER TABLE auth.identity + RENAME COLUMN account_id TO user_id; +ALTER TABLE auth.identity + ADD COLUMN roles text[] NOT NULL DEFAULT '{}'; +ALTER TABLE auth.identity + DROP COLUMN type; + +ALTER TABLE auth.access_token + RENAME COLUMN account_id TO user_id; +ALTER TABLE auth.session + RENAME COLUMN account_id TO user_id; + +ALTER TABLE auth.credential + RENAME COLUMN account_id TO identity_id; + +DROP TABLE auth.account; +DROP TYPE auth.account_status; \ No newline at end of file diff --git a/examples/blog/internal/blog/storage/fixture/post.go b/examples/blog/internal/blog/storage/fixture/post.go index 2958690..39b2d77 100644 --- a/examples/blog/internal/blog/storage/fixture/post.go +++ b/examples/blog/internal/blog/storage/fixture/post.go @@ -91,10 +91,10 @@ func (f *PostFixture) clone() *PostFixture { } func (f *PostFixture) save(ctx context.Context) error { - query := `INSERT INTO blog.post - (id, title, preview, content, status, created_at, updated_at, published_at, deleted_at, author_id) + query := `INSERT INTO "blog"."post" + ("id", "title", "preview", "content", "status", "created_at", "updated_at", "published_at", "deleted_at", "author_id") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) - RETURNING id, title, preview, content, status, created_at, updated_at, published_at, deleted_at, author_id + RETURNING "id", "title", "preview", "content", "status", "created_at", "updated_at", "published_at", "deleted_at", "author_id" ` row := f.db.QueryRow(ctx, query, f.entity.ID, @@ -142,7 +142,7 @@ func (f *PostFixture) Create(tb testing.TB) *PostFixture { func (f *PostFixture) Cleanup(tb testing.TB) *PostFixture { tb.Cleanup( func() { - query := `DELETE FROM blog.post WHERE id = $1` + query := `DELETE FROM "blog"."post" WHERE id = $1` _, err := f.db.Exec(context.Background(), query, f.entity.ID) if err != nil { @@ -156,7 +156,7 @@ func (f *PostFixture) Cleanup(tb testing.TB) *PostFixture { func (f *PostFixture) PullUpdates(tb testing.TB) *PostFixture { c := f.clone() ctx := context.Background() - query := `SELECT * FROM blog.post WHERE id = $1` + query := `SELECT "id", "title", "preview", "content", "status", "created_at", "updated_at", "published_at", "deleted_at", "author_id" FROM "blog"."post" WHERE id = $1` row := f.db.QueryRow(ctx, query, c.entity.ID, ) @@ -182,17 +182,17 @@ func (f *PostFixture) PullUpdates(tb testing.TB) *PostFixture { func (f *PostFixture) PushUpdates(tb testing.TB) *PostFixture { c := f.clone() query := ` - UPDATE blog.post SET - title = $2, - preview = $3, - content = $4, - status = $5, - created_at = $6, - updated_at = $7, - published_at = $8, - deleted_at = $9, - author_id = $10 - WHERE id = $1 + UPDATE "blog"."post" SET + "title" = $2, + "preview" = $3, + "content" = $4, + "status" = $5, + "created_at" = $6, + "updated_at" = $7, + "published_at" = $8, + "deleted_at" = $9, + "author_id" = $10 + WHERE "id" = $1 ` _, err := f.db.Exec( context.Background(), diff --git a/examples/blog/internal/user/action/register_user.go b/examples/blog/internal/user/action/register_user.go index ee9037c..3a90c3a 100644 --- a/examples/blog/internal/user/action/register_user.go +++ b/examples/blog/internal/user/action/register_user.go @@ -6,11 +6,11 @@ import ( "context" "errors" "github.com/go-modulus/modulus/auth" + "github.com/go-modulus/modulus/auth/repository" "github.com/go-modulus/modulus/errors/erruser" "github.com/go-modulus/modulus/validator" validation "github.com/go-ozzo/ozzo-validation/v4" "github.com/go-ozzo/ozzo-validation/v4/is" - "github.com/gofrs/uuid" "github.com/jackc/pgx/v5" ) @@ -79,21 +79,13 @@ func (r *RegisterUser) Execute(ctx context.Context, input RegisterUserInput) (st } else { return storage.User{}, ErrUserAlreadyExists } - user, err := r.userDb.RegisterUser( - ctx, storage.RegisterUserParams{ - ID: uuid.Must(uuid.NewV6()), - Email: input.Email, - Name: input.Name, - }, - ) - if err != nil { - return storage.User{}, errtrace.Wrap(err) - } - _, err = r.passwordAuth.Register( + + account, err := r.passwordAuth.Register( ctx, input.Email, input.Password, - user.ID, + // type of the created account + repository.IdentityTypeEmail, // the authenticated user role that will be used in the future []string{"user"}, nil, @@ -101,5 +93,15 @@ func (r *RegisterUser) Execute(ctx context.Context, input RegisterUserInput) (st if err != nil { return storage.User{}, errtrace.Wrap(err) } + user, err := r.userDb.RegisterUser( + ctx, storage.RegisterUserParams{ + ID: account.ID, + Email: input.Email, + Name: input.Name, + }, + ) + if err != nil { + return storage.User{}, errtrace.Wrap(err) + } return user, nil } diff --git a/modules.json b/modules.json index 31555ad..dd4689c 100644 --- a/modules.json +++ b/modules.json @@ -270,6 +270,10 @@ "sourceUrl": "https://raw.githubusercontent.com/go-modulus/modulus/refs/heads/main/auth/storage/migration/20240214134322_auth.sql", "destFile": "internal/auth/storage/migration/20240214134322_auth.sql" }, + { + "sourceUrl": "https://raw.githubusercontent.com/go-modulus/modulus/refs/heads/main/auth/storage/migration/20240320084613_auth_account.sql", + "destFile": "internal/auth/storage/migration/20240320084613_auth_account.sql" + }, { "sourceUrl": "https://raw.githubusercontent.com/go-modulus/modulus/refs/heads/main/auth/install/module.go.tmpl", "destFile": "internal/auth/module.go"