Skip to content

Commit

Permalink
chore(userAccounts): register new user credentials, delete on user de…
Browse files Browse the repository at this point in the history
…lete
  • Loading branch information
benjohns1 committed Mar 14, 2024
1 parent 19ac83c commit e080c4f
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 31 deletions.
10 changes: 8 additions & 2 deletions app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type (
SessionRepo
FileRepo
UserRepo
CredentialRepo
GenerateToken func() (Token, error)
Clock
PasswordHasher
Expand Down Expand Up @@ -47,14 +48,19 @@ type (
Delete(context.Context, blinkfile.UserID) error
}

CredentialRepo interface {
Set(context.Context, Credentials) error
Remove(context.Context, blinkfile.UserID) error
}

PasswordHasher interface {
Hash(data []byte) (hash string)
Match(hash string, data []byte) (matched bool, err error)
}

App struct {
cfg Config
credentials map[blinkfile.Username]Credentials
cfg Config
adminCredentials map[blinkfile.Username]Credentials
Log
}

Expand Down
22 changes: 22 additions & 0 deletions app/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ func AppConfigDefaults(cfg app.Config) app.Config {
out.UserRepo = &StubUserRepo{}
}

if cfg.CredentialRepo == nil {
out.CredentialRepo = &StubCredentialRepo{}
}

if cfg.Log == nil {
out.Log = log.New(log.Config{})
}
Expand Down Expand Up @@ -172,6 +176,24 @@ func (ur *StubUserRepo) Delete(ctx context.Context, userID blinkfile.UserID) err
return nil
}

type StubCredentialRepo struct {
SetFunc func(context.Context, app.Credentials) error
RemoveFunc func(context.Context, blinkfile.UserID) error
}

func (cr *StubCredentialRepo) Set(ctx context.Context, cred app.Credentials) error {
if cr.SetFunc != nil {
return cr.SetFunc(ctx, cred)
}
return nil
}
func (cr *StubCredentialRepo) Remove(ctx context.Context, userID blinkfile.UserID) error {
if cr.RemoveFunc != nil {
return cr.RemoveFunc(ctx, userID)
}
return nil
}

func TestNew(t *testing.T) {
ctx := context.Background()
type args struct {
Expand Down
31 changes: 17 additions & 14 deletions app/authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ import (

type Credentials struct {
blinkfile.UserID
username blinkfile.Username
encodedPasswordHash string
blinkfile.Username
PasswordHash string
}

const passwordMinLength = 16

var ErrPasswordTooShort = fmt.Errorf("password must be at least %d characters long", passwordMinLength)

func newPasswordCredentials(userID blinkfile.UserID, user blinkfile.Username, pass string, hash func([]byte) string) (Credentials, error) {
if userID == "" {
return Credentials{}, fmt.Errorf("user ID cannot be empty")
Expand All @@ -24,13 +26,13 @@ func newPasswordCredentials(userID blinkfile.UserID, user blinkfile.Username, pa
return Credentials{}, fmt.Errorf("username cannot be empty")
}
if len(pass) < passwordMinLength {
return Credentials{}, fmt.Errorf("password must be at least %d characters long", passwordMinLength)
return Credentials{}, ErrPasswordTooShort
}
encodedHash := hash([]byte(pass))
return Credentials{
UserID: userID,
username: user,
encodedPasswordHash: encodedHash,
UserID: userID,
Username: user,
PasswordHash: encodedHash,
}, nil
}

Expand Down Expand Up @@ -108,44 +110,45 @@ func (a *App) authenticate(username blinkfile.Username, password string) (blinkf
if password == "" {
return "", Err(ErrAuthnFailed, fmt.Errorf("invalid credentials: password cannot be empty"))
}
credentials, found, err := a.getCredentials(username)
cred, found, err := a.getAdminCredentials(username)
if err != nil {
return "", Err(ErrInternal, fmt.Errorf("error retrieving credentials for %q: %w", username, err))
}
if !found {
return "", Err(ErrAuthnFailed, fmt.Errorf("invalid credentials: no username %q found", username))
}
match, err := credentialsMatch(credentials, username, password, a.cfg.PasswordHasher.Match)
match, err := credentialsMatch(cred, username, password, a.cfg.PasswordHasher.Match)
if err != nil {
return "", Err(ErrInternal, fmt.Errorf("error matching credentials: %w", err))
}
if !match {
return "", Err(ErrAuthnFailed, fmt.Errorf("invalid credentials: passwords do not match"))
}
return credentials.UserID, nil
return cred.UserID, nil
}

func credentialsMatch(c Credentials, username blinkfile.Username, password string, passwordMatcher func(hash string, data []byte) (matched bool, err error)) (bool, error) {
if !stringsAreEqual(string(c.username), string(username)) {
if !stringsAreEqual(string(c.Username), string(username)) {
return false, nil
}
return passwordMatcher(c.encodedPasswordHash, []byte(password))
return passwordMatcher(c.PasswordHash, []byte(password))
}

func (a *App) getCredentials(username blinkfile.Username) (Credentials, bool, error) {
creds, found := a.credentials[username]
func (a *App) getAdminCredentials(username blinkfile.Username) (Credentials, bool, error) {
creds, found := a.adminCredentials[username]
if !found {
return Credentials{}, false, nil
}
return creds, true, nil
}

func (a *App) userIsValid(userID blinkfile.UserID) bool {
for _, creds := range a.credentials {
for _, creds := range a.adminCredentials {
if creds.UserID == userID {
return true
}
}

return false
}

Expand Down
11 changes: 10 additions & 1 deletion app/testautomation/testautomation.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ type (
Log app.Log
Clock
FileRepo
UserRepo UserRepo
UserRepo UserRepo
CredentialRepo CredentialRepo
}

Args struct {
Expand All @@ -38,6 +39,10 @@ type (
ListAll(context.Context) ([]blinkfile.User, error)
Delete(context.Context, blinkfile.UserID) error
}

CredentialRepo interface {
Remove(context.Context, blinkfile.UserID) error
}
)

func (a *Automator) TestAutomation(ctx context.Context, args Args) error {
Expand Down Expand Up @@ -74,6 +79,10 @@ func (a *Automator) TestAutomation(ctx context.Context, args Args) error {
if err != nil {
return err
}
err = a.CredentialRepo.Remove(ctx, user.ID)
if err != nil {
return err
}
}
a.Log.Printf(ctx, "deleted all users")
}
Expand Down
44 changes: 35 additions & 9 deletions app/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@ type (
}
)

var ErrDuplicateUsername = fmt.Errorf("username already exists")
var (
ErrDuplicateUsername = fmt.Errorf("username already exists")
ErrUsernameTaken = fmt.Errorf("username already taken")
ErrCredentialNotFound = fmt.Errorf("credential not found")
)

func (a *App) CreateUser(ctx context.Context, args CreateUserArgs) error {
_, found, err := a.getCredentials(args.Username)
_, found, err := a.getAdminCredentials(args.Username)
if err != nil {
return Err(ErrInternal, err)
}
Expand All @@ -43,6 +47,26 @@ func (a *App) CreateUser(ctx context.Context, args CreateUserArgs) error {
}
return Err(ErrRepo, err)
}
err = a.registerCredentials(ctx, user.ID, user.Username, args.Password)
if err != nil {
if deleteErr := a.cfg.UserRepo.Delete(ctx, user.ID); deleteErr != nil {
a.Errorf(ctx, "deleting user after failure to register credentials: %v", deleteErr)
}
return err
}

return nil
}

func (a *App) registerCredentials(ctx context.Context, userID blinkfile.UserID, username blinkfile.Username, password string) error {
cred, err := newPasswordCredentials(userID, username, password, a.cfg.PasswordHasher.Hash)
if err != nil {
return ErrUser("Error creating user credentials", fmt.Sprintf("Credential error: %s", err), err)
}
err = a.cfg.CredentialRepo.Set(ctx, cred)
if err != nil {
return Err(ErrRepo, err)
}
return nil
}

Expand All @@ -60,6 +84,10 @@ func (a *App) DeleteUsers(ctx context.Context, userIDs []blinkfile.UserID) error
if err != nil {
return Err(ErrRepo, err)
}
err = a.cfg.CredentialRepo.Remove(ctx, userID)
if err != nil {
return Err(ErrRepo, err)
}
}
return nil
}
Expand All @@ -70,24 +98,22 @@ func (a *App) registerAdminUser(ctx context.Context, username blinkfile.Username
if username == "" {
return nil
}
creds, err := newPasswordCredentials(AdminUserID, username, password, a.cfg.PasswordHasher.Hash)
cred, err := newPasswordCredentials(AdminUserID, username, password, a.cfg.PasswordHasher.Hash)
if err != nil {
return err
}
err = a.registerUserCredentials(creds)
err = a.registerAdminCredentials(cred)
if err != nil {
return err
}
a.Printf(ctx, "Registered admin credentials for username %q", username)
return nil
}

var ErrUsernameTaken = fmt.Errorf("username already taken")

func (a *App) registerUserCredentials(creds Credentials) error {
if _, exists := a.credentials[creds.username]; exists {
func (a *App) registerAdminCredentials(cred Credentials) error {
if _, exists := a.adminCredentials[cred.Username]; exists {
return ErrUsernameTaken
}
a.credentials[creds.username] = creds
a.adminCredentials[cred.Username] = cred
return nil
}
47 changes: 46 additions & 1 deletion app/users_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,39 @@ func TestApp_CreateUser(t *testing.T) {
},
},
{
name: "should create a new user",
name: "should fail if password credential is too short",
args: app.CreateUserArgs{
Username: "user1",
Password: "",
},
wantErr: &app.Error{
Type: app.ErrBadRequest,
Title: "Error creating user credentials",
Detail: fmt.Sprintf("Credential error: %s", app.ErrPasswordTooShort),
Err: app.ErrPasswordTooShort,
},
},
{
name: "should fail if credentials cannot be stored",
cfg: app.Config{
CredentialRepo: &StubCredentialRepo{SetFunc: func(context.Context, app.Credentials) error {
return fmt.Errorf("cred repo err")
}},
},
args: app.CreateUserArgs{
Username: "user1",
Password: "1234567812345678",
},
wantErr: &app.Error{
Type: app.ErrRepo,
Err: fmt.Errorf("cred repo err"),
},
},
{
name: "should successfully create a user",
args: app.CreateUserArgs{
Username: "user1",
Password: "1234567812345678",
},
},
}
Expand Down Expand Up @@ -181,6 +211,21 @@ func TestApp_DeleteUsers(t *testing.T) {
Err: fmt.Errorf("user repo delete err"),
},
},
{
name: "should fail if credential repo returns an error",
args: args{
userIDs: []blinkfile.UserID{"u1"},
},
cfg: app.Config{
CredentialRepo: &StubCredentialRepo{RemoveFunc: func(context.Context, blinkfile.UserID) error {
return fmt.Errorf("cred repo err")
}},
},
wantErr: &app.Error{
Type: app.ErrRepo,
Err: fmt.Errorf("cred repo err"),
},
},
{
name: "should delete a user",
args: args{
Expand Down
17 changes: 13 additions & 4 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ func run(ctx context.Context) (err error) {
return err
}

credentialRepo, err := repo.NewCredentialRepo(ctx, repo.CredentialRepoConfig{
Dir: fmt.Sprintf("%s/credentials", cfg.DataDir),
})
if err != nil {
return err
}

l := log.New(log.Config{GetRequestID: request.GetID})
l.Printf(ctx, "Running build %q", build)

Expand All @@ -80,6 +87,7 @@ func run(ctx context.Context) (err error) {
SessionRepo: sessionRepo,
FileRepo: fileRepo,
UserRepo: userRepo,
CredentialRepo: credentialRepo,
PasswordHasher: &hash.Argon2idDefault,
}

Expand All @@ -88,10 +96,11 @@ func run(ctx context.Context) (err error) {
l.Printf(ctx, "WARNING: Server running with test automation enabled! DO NOT RUN IN PRODUCTION!")
testClock := &testautomation.TestClock{}
automator = &testautomation.Automator{
Log: l,
Clock: testClock,
FileRepo: fileRepo,
UserRepo: userRepo,
Log: l,
Clock: testClock,
FileRepo: fileRepo,
UserRepo: userRepo,
CredentialRepo: credentialRepo,
}
appConfig.Clock = testClock
}
Expand Down

0 comments on commit e080c4f

Please sign in to comment.