Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions auth/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func NewModule() *module.Module {
SetOverriddenProvider("CredentialRepository", storage.NewDefaultCredentialRepository).
SetOverriddenProvider("IdentityRepository", storage.NewDefaultIdentityRepository).
SetOverriddenProvider("TokenRepository", storage.NewDefaultTokenRepository).
SetOverriddenProvider("AccountRepository", storage.NewDefaultAccountRepository).
SetOverriddenProvider("TokenHashStrategy", hash.NewSha1).
SetOverriddenProvider(
"MiddlewareAuthenticator", func(auth *PlainTokenAuthenticator) Authenticator {
Expand Down Expand Up @@ -58,6 +59,12 @@ func OverrideTokenRepository(authModule *module.Module, repository interface{})
return authModule.SetOverriddenProvider("TokenRepository", repository)
}

// OverrideAccountRepository overrides the default account storage implementation with the custom one.
// repository should be a constructor returning the implementation of the AccountRepository interface.
func OverrideAccountRepository(authModule *module.Module, repository interface{}) *module.Module {
return authModule.SetOverriddenProvider("AccountRepository", repository)
}

// OverrideTokenHashStrategy overrides the default token hash strategy with the custom one.
// strategy should be a constructor returning the implementation of the hash.TokenHashStrategy interface.
// by default, the sha1 hash strategy is used.
Expand All @@ -84,6 +91,10 @@ func NewManifestModule() module.ManifestModule {
SourceUrl: "https://raw.githubusercontent.com/go-modulus/modulus/refs/heads/main/auth/storage/migration/20240214134322_auth.sql",
DestFile: "internal/auth/storage/migration/20240214134322_auth.sql",
},
module.InstalledFile{
SourceUrl: "https://raw.githubusercontent.com/go-modulus/modulus/refs/heads/main/auth/storage/migration/20240320084613_auth_account.sql",
DestFile: "internal/auth/storage/migration/20240320084613_auth_account.sql",
},
module.InstalledFile{
SourceUrl: "https://raw.githubusercontent.com/go-modulus/modulus/refs/heads/main/auth/install/module.go.tmpl",
DestFile: "internal/auth/module.go",
Expand Down
52 changes: 36 additions & 16 deletions auth/password_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@ var ErrInvalidPassword = errors.New("invalid password")
var ErrCannotHashPassword = errors.New("cannot hash password")

type PasswordAuthenticator struct {
accountRepository repository.AccountRepository
identityRepository repository.IdentityRepository
credentialRepository repository.CredentialRepository
}

func NewPasswordAuthenticator(
accountRepository repository.AccountRepository,
identityRepository repository.IdentityRepository,
credentialRepository repository.CredentialRepository,
) *PasswordAuthenticator {
return &PasswordAuthenticator{
accountRepository: accountRepository,
identityRepository: identityRepository,
credentialRepository: credentialRepository,
}
Expand All @@ -49,7 +52,7 @@ func (a *PasswordAuthenticator) Authenticate(ctx context.Context, identity, pass
return Performer{}, errtrace.Wrap(ErrIdentityIsBlocked)
}

cred, err := a.credentialRepository.GetLast(ctx, identityObj.ID, string(repository.CredentialTypePassword))
cred, err := a.credentialRepository.GetLast(ctx, identityObj.AccountID, string(repository.CredentialTypePassword))
if err != nil {
if errors.Is(err, repository.ErrCredentialNotFound) {
return Performer{}, errtrace.Wrap(ErrInvalidPassword)
Expand All @@ -61,7 +64,7 @@ func (a *PasswordAuthenticator) Authenticate(ctx context.Context, identity, pass
if err != nil {
return Performer{}, errtrace.Wrap(ErrInvalidPassword)
}
return Performer{ID: identityObj.UserID, SessionID: uuid.Must(uuid.NewV6()), IdentityID: identityObj.ID}, nil
return Performer{ID: identityObj.AccountID, SessionID: uuid.Must(uuid.NewV6()), IdentityID: identityObj.ID}, nil
}

// Register registers a new user account with the given identity and password.
Expand All @@ -77,47 +80,64 @@ func (a *PasswordAuthenticator) Register(
ctx context.Context,
identity,
password string,
userID uuid.UUID,
identityType repository.IdentityType,
roles []string,
additionalData map[string]interface{},
) (repository.Identity, error) {
) (repository.Account, error) {
identityObj, err := a.identityRepository.Get(ctx, identity)
if err == nil {
if identityObj.IsBlocked() {
return repository.Identity{}, errtrace.Wrap(ErrIdentityIsBlocked)
return repository.Account{}, errtrace.Wrap(ErrIdentityIsBlocked)
}
return repository.Identity{}, errtrace.Wrap(repository.ErrIdentityExists)
return repository.Account{}, errtrace.Wrap(repository.ErrIdentityExists)
} else if !errors.Is(err, repository.ErrIdentityNotFound) {
return repository.Identity{}, errtrace.Wrap(err)
return repository.Account{}, errtrace.Wrap(err)
}

identityObj, err = a.identityRepository.Create(ctx, identity, userID, additionalData)
accountID, err := uuid.NewV6()
if err != nil {
return repository.Identity{}, errtrace.Wrap(err)
return repository.Account{}, errtrace.Wrap(err)
}

account, err := a.accountRepository.Create(ctx, accountID)
if err != nil {
return repository.Account{}, errtrace.Wrap(err)
}

identityObj, err = a.identityRepository.Create(ctx, identity, accountID, identityType, additionalData)
if err != nil {
return repository.Account{}, errtrace.Wrap(err)
}

defer func() {
if err != nil {
_ = a.accountRepository.RemoveAccount(ctx, accountID)
_ = a.identityRepository.RemoveIdentity(ctx, identity)
}
}()

hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return repository.Identity{}, errtrace.Wrap(errors.WithCause(ErrCannotHashPassword, err))
return repository.Account{}, errtrace.Wrap(errors.WithCause(ErrCannotHashPassword, err))
}

_, err = a.credentialRepository.Create(
ctx,
identityObj.ID,
accountID,
string(hash),
string(repository.CredentialTypePassword),
repository.CredentialTypePassword,
nil,
)
if err != nil {
return repository.Identity{}, errtrace.Wrap(err)
return repository.Account{}, errtrace.Wrap(err)
}

if len(roles) > 0 {
err = a.identityRepository.AddRoles(ctx, identityObj.ID, roles...)
err = a.accountRepository.AddRoles(ctx, identityObj.ID, roles...)
if err != nil {
return repository.Identity{}, errtrace.Wrap(err)
return repository.Account{}, errtrace.Wrap(err)
}
}

return identityObj, nil
return account, nil
}
111 changes: 59 additions & 52 deletions auth/password_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package auth_test

import (
"context"
"encoding/json"
"github.com/go-modulus/modulus/auth"
"github.com/go-modulus/modulus/auth/repository"
"github.com/go-modulus/modulus/auth/storage"
Expand All @@ -15,84 +16,89 @@ func TestPasswordAuthenticator_Register(t *testing.T) {
t.Run(
"register identity without additional data", func(t *testing.T) {
t.Parallel()
userId := uuid.Must(uuid.NewV6())
identity, err := passwordAuth.Register(
account, err := passwordAuth.Register(
context.Background(),
"user",
"password",
userId,
repository.IdentityTypeNickname,
[]string{},
nil,
)
require.NoError(t, err)

savedIdentity := fixtureFactory.Identity().ID(identity.ID).PullUpdates(t).Cleanup(t).GetEntity()
fixtureFactory.Credential().IdentityID(identity.ID).CleanupAllOfIdentity(t)
savedAccount := fixtureFactory.Account().ID(account.ID).PullUpdates(t).Cleanup(t).GetEntity()
savedIdentity := fixtureFactory.Identity().AccountID(account.ID).PullUpdatesLastAccountIdentity(t).CleanupAllOfAccount(t).GetEntity()
fixtureFactory.Credential().AccountID(account.ID).CleanupAllOfAccount(t)

t.Log("When the identity is registered")
t.Log(" Then the identity is returned")
t.Log("When the account is registered")
t.Log(" Then the account is returned")
require.NoError(t, err)
require.Equal(t, userId, identity.UserID)
require.Equal(t, "user", identity.Identity)
require.Equal(t, repository.IdentityStatusActive, identity.Status)
require.Empty(t, identity.Data)

t.Log(" And the identity is saved")
require.Equal(t, identity.UserID, savedIdentity.UserID)
require.Equal(t, identity.Identity, savedIdentity.Identity)
require.Equal(t, repository.AccountStatusActive, account.Status)

t.Log(" And the account is saved")
require.Equal(t, "user", savedIdentity.Identity)
require.Equal(t, storage.IdentityStatusActive, savedIdentity.Status)

t.Log(" And the account is created")
require.Equal(t, account.ID, savedAccount.ID)
require.Equal(t, storage.AccountStatusActive, savedAccount.Status)
},
)

t.Run(
"register identity with additional data", func(t *testing.T) {
t.Parallel()
userId := uuid.Must(uuid.NewV6())
identity, err := passwordAuth.Register(
account, err := passwordAuth.Register(
context.Background(),
"user1",
"password",
userId,
repository.IdentityTypeNickname,
[]string{},
map[string]interface{}{
"key": "value",
},
)

savedIdentity := fixtureFactory.Identity().ID(identity.ID).PullUpdates(t).GetEntity()
fixtureFactory.Credential().IdentityID(identity.ID).CleanupAllOfIdentity(t)
fixtureFactory.Identity().UserID(userId).CleanupAllOfUser(t)
savedAccount := fixtureFactory.Account().ID(account.ID).PullUpdates(t).Cleanup(t).GetEntity()
savedIdentity := fixtureFactory.Identity().AccountID(account.ID).PullUpdatesLastAccountIdentity(t).CleanupAllOfAccount(t).GetEntity()
fixtureFactory.Credential().AccountID(account.ID).CleanupAllOfAccount(t)
fixtureFactory.Identity().AccountID(account.ID).CleanupAllOfAccount(t)

t.Log("When the identity is registered")
t.Log(" Then the identity is returned")
var data map[string]interface{}
errUnmarshal := json.Unmarshal(savedIdentity.Data, &data)

t.Log("When the account is registered")
t.Log(" Then the account is returned")
require.NoError(t, err)
require.Equal(t, userId, identity.UserID)
require.Equal(t, "user1", identity.Identity)
require.Equal(t, repository.IdentityStatusActive, identity.Status)
require.Equal(t, "value", identity.Data["key"])

t.Log(" And the identity is saved")
require.Equal(t, identity.UserID, savedIdentity.UserID)
require.Equal(t, identity.Identity, savedIdentity.Identity)
require.Equal(t, repository.AccountStatusActive, account.Status)

t.Log(" And the account is saved")
require.NoError(t, errUnmarshal)
require.Equal(t, "user1", savedIdentity.Identity)
require.Equal(t, storage.IdentityStatusActive, savedIdentity.Status)
require.Equal(t, "value", data["key"])

t.Log(" And the account is created")
require.Equal(t, account.ID, savedAccount.ID)
require.Equal(t, storage.AccountStatusActive, savedAccount.Status)
},
)

t.Run(
"fail on the second registration of identity", func(t *testing.T) {
t.Parallel()
userId := uuid.Must(uuid.NewV6())
account := fixtureFactory.Account().Create(t).GetEntity()
identity := fixtureFactory.Identity().
ID(uuid.Must(uuid.NewV6())).
UserID(userId).
AccountID(account.ID).
Identity("user2").
Create(t).
GetEntity()
_, err := passwordAuth.Register(
context.Background(),
identity.Identity,
"password",
userId,
repository.IdentityTypeNickname,
[]string{},
nil,
)
Expand All @@ -107,10 +113,10 @@ func TestPasswordAuthenticator_Register(t *testing.T) {
t.Run(
"fail if identity is blocked", func(t *testing.T) {
t.Parallel()
userId := uuid.Must(uuid.NewV6())
accountId := uuid.Must(uuid.NewV6())
identity := fixtureFactory.Identity().
ID(uuid.Must(uuid.NewV6())).
UserID(userId).
AccountID(accountId).
Identity("user3").
Status(storage.IdentityStatusBlocked).
Create(t).
Expand All @@ -120,7 +126,7 @@ func TestPasswordAuthenticator_Register(t *testing.T) {
context.Background(),
identity.Identity,
"password",
userId,
repository.IdentityTypeNickname,
[]string{},
nil,
)
Expand All @@ -137,33 +143,34 @@ func TestPasswordAuthenticator_Authenticate(t *testing.T) {
t.Run(
"authenticate", func(t *testing.T) {
t.Parallel()
userId := uuid.Must(uuid.NewV6())
identity, err := passwordAuth.Register(
identity := "user4"
account, err := passwordAuth.Register(
context.Background(),
"user4",
identity,
"password",
userId,
repository.IdentityTypeNickname,
[]string{},
map[string]interface{}{
"key": "value",
},
)
require.NoError(t, err)

fixtureFactory.Identity().ID(identity.ID).Cleanup(t)
fixtureFactory.Credential().IdentityID(identity.ID).CleanupAllOfIdentity(t)
fixtureFactory.Account().ID(account.ID).Cleanup(t)
fixtureFactory.Identity().AccountID(account.ID).CleanupAllOfAccount(t)
fixtureFactory.Credential().AccountID(account.ID).CleanupAllOfAccount(t)

performer, err := passwordAuth.Authenticate(
context.Background(),
identity.Identity,
identity,
"password",
)

t.Log("Given the identity is registered")
t.Log("Given the account is registered")
t.Log("When try to authenticate with the correct password")
t.Log(" Then the performer is returned")
require.NoError(t, err)
require.Equal(t, userId, performer.ID)
require.Equal(t, account.ID, performer.ID)
},
)

Expand All @@ -184,15 +191,15 @@ func TestPasswordAuthenticator_Authenticate(t *testing.T) {
"fail if no password in database", func(t *testing.T) {
t.Parallel()

userId := uuid.Must(uuid.NewV6())
account := fixtureFactory.Account().Create(t).GetEntity()
identity := fixtureFactory.Identity().
ID(uuid.Must(uuid.NewV6())).
Identity("user6").
UserID(userId).
AccountID(account.ID).
Create(t).
GetEntity()
fixtureFactory.Credential().
IdentityID(identity.ID).
AccountID(account.ID).
Type(string(repository.CredentialTypeOTP)).
Hash("ssss").
Create(t)
Expand All @@ -214,12 +221,12 @@ func TestPasswordAuthenticator_Authenticate(t *testing.T) {
identity := fixtureFactory.Identity().
ID(uuid.Must(uuid.NewV6())).
Identity("user7").
UserID(userId).
AccountID(userId).
Create(t).
GetEntity()

fixtureFactory.Credential().
IdentityID(identity.ID).
AccountID(userId).
Hash("ssss2").
Create(t)
_, err := passwordAuth.Authenticate(context.Background(), identity.Identity, "password")
Expand All @@ -240,7 +247,7 @@ func TestPasswordAuthenticator_Authenticate(t *testing.T) {
identity := fixtureFactory.Identity().
ID(uuid.Must(uuid.NewV6())).
Identity("user8").
UserID(userId).
AccountID(userId).
Status(storage.IdentityStatusBlocked).
Create(t).
GetEntity()
Expand Down
Loading