Skip to content

Commit

Permalink
[VIAM-660] Fix memory leaks on kratos anti-password brute forcing pro…
Browse files Browse the repository at this point in the history
…tection (ory#44)
  • Loading branch information
mcjimenez committed Jul 8, 2023
1 parent cb38a91 commit 9f1a520
Showing 1 changed file with 23 additions and 26 deletions.
49 changes: 23 additions & 26 deletions selfservice/strategy/password/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import (
"context"
"encoding/json"
"net/http"
"sync"
"time"

"github.com/Vonage/go-viam-utils/utils"
errors2 "github.com/ory/kratos/schema/errors"
"github.com/ory/kratos/selfservice/flowhelpers"

Expand Down Expand Up @@ -57,8 +57,21 @@ type passCheckStatus struct {
numTries uint
}

var passCheckCache = make(map[string]passCheckStatus, 10000)
var cacheMutex sync.RWMutex
// This is the validity checker for cache elements. A element is valid if it hasn't expired yet.
func inWindow(v interface{}) bool {
return v.(*passCheckStatus).checkExpiresAt.After(time.Now())
}

// passCheckCache is ExpiringCache[string, passCheckStatus)
var passCheckCache = utils.NewCheckedExpiringCache(inWindow, delayReset, 10000).Start()

func passCheckCacheGet(id string) (passCheckStatus, bool) {
asInt, exists := passCheckCache.SyncGet(id)
if exists {
return asInt.(passCheckStatus), true
}
return passCheckStatus{}, false
}

func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, identityID uuid.UUID) (i *identity.Identity, err error) {
if err := login.CheckAAL(f, identity.AuthenticatorAssuranceLevel1); err != nil {
Expand All @@ -81,21 +94,11 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
return nil, s.handleLoginError(w, r, f, &p, err)
}

cacheMutex.RLock()
lastCheckResult, exists := passCheckCache[p.Identifier]
cacheMutex.RUnlock()
if exists && lastCheckResult.checkExpiresAt.Before(time.Now()) {
cacheMutex.Lock()
delete(passCheckCache, p.Identifier)
cacheMutex.Unlock()
exists = false
}
lastCheckResult, exists := passCheckCacheGet(p.Identifier)
if exists && lastCheckResult.numTries >= delayAfterNumTries {
expireAt := lastCheckResult.checkExpiresAt
time.Sleep(delayTry)
cacheMutex.RLock()
lastCheckResult, exists = passCheckCache[p.Identifier]
cacheMutex.RUnlock()
lastCheckResult, exists = passCheckCacheGet(p.Identifier)
if exists && !expireAt.Equal(lastCheckResult.checkExpiresAt) {
time.Sleep(delayTry) // Note that this will probably mean the request will time out. Too bad, so sad.
}
Expand All @@ -104,12 +107,10 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
id, hashedPwd, exists, cacheEnabled := s.d.CheckPwdCache(p.Identifier)

invalidUserDelay := func() {
cacheMutex.Lock()
passCheckCache[p.Identifier] = passCheckStatus{
passCheckCache.SyncSet(p.Identifier, passCheckStatus{
checkExpiresAt: time.Now().Add(delayReset),
numTries: lastCheckResult.numTries + 1,
}
cacheMutex.Unlock()
})
time.Sleep(x.RandomDelay(s.d.Config().HasherArgon2(r.Context()).ExpectedDuration, s.d.Config().HasherArgon2(r.Context()).ExpectedDeviation))
i = nil
err = s.handleLoginError(w, r, f, &p, errors.WithStack(errors2.NewInvalidCredentialsError()))
Expand All @@ -133,12 +134,10 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
}

if err := hash.Compare(r.Context(), []byte(p.Password), []byte(hashedPwd)); err != nil {
cacheMutex.Lock()
passCheckCache[p.Identifier] = passCheckStatus{
passCheckCache.SyncSet(p.Identifier, passCheckStatus{
checkExpiresAt: time.Now().Add(delayReset),
numTries: lastCheckResult.numTries + 1,
}
cacheMutex.Unlock()
})
return nil, s.handleLoginError(w, r, f, &p, errors.WithStack(errors2.NewInvalidCredentialsError()))
}

Expand All @@ -156,9 +155,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
if err = s.d.LoginFlowPersister().UpdateLoginFlow(r.Context(), f); err != nil {
return nil, s.handleLoginError(w, r, f, &p, errors.WithStack(herodot.ErrInternalServerError.WithReason("Could not update flow").WithDebug(err.Error())))
}
cacheMutex.Lock()
delete(passCheckCache, p.Identifier)
cacheMutex.Unlock()
passCheckCache.SyncRemove(p.Identifier)
if i == nil {
userId, _ := uuid.FromString(id)
i, err = s.d.PrivilegedIdentityPool().GetIdentity(r.Context(), userId, identity.ExpandDefault)
Expand Down

0 comments on commit 9f1a520

Please sign in to comment.