Skip to content

Commit

Permalink
Fix goroutine leak with async session update (TykTechnologies#2237)
Browse files Browse the repository at this point in the history
Session update worker pool was startng for every "DefaultSessionStore" object. And it was happening for every API initialization. So if you have 1000 APIs, and your worker pool is 100, each API reload will trigger at least 100k goroutines.

This change maintains single worker pool which initilized only once.

Issue itself was introduced by TykTechnologies#1757

Fix TykTechnologies#2236
  • Loading branch information
buger committed Apr 27, 2019
1 parent ab2334a commit ef56da7
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 74 deletions.
144 changes: 70 additions & 74 deletions auth_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"encoding/json"
"strings"
"sync"
"sync/atomic"
"time"

uuid "github.com/satori/go.uuid"
Expand Down Expand Up @@ -43,6 +42,72 @@ type SessionHandler interface {
const sessionPoolDefaultSize = 50
const sessionBufferDefaultSize = 1000

type sessionUpdater struct {
store storage.Handler
once sync.Once
updateChan chan *SessionUpdate
poolSize int
bufferSize int
keyPrefix string
}

var defaultSessionUpdater *sessionUpdater

func init() {
defaultSessionUpdater = &sessionUpdater{}
}

func (s *sessionUpdater) Init(store storage.Handler) {
s.once.Do(func() {
s.store = store
// check pool size in config and set to 50 if unset
s.poolSize = config.Global().SessionUpdatePoolSize
if s.poolSize <= 0 {
s.poolSize = sessionPoolDefaultSize
}
//check size for channel buffer and set to 1000 if unset
s.bufferSize = config.Global().SessionUpdateBufferSize
if s.bufferSize <= 0 {
s.bufferSize = sessionBufferDefaultSize
}

log.WithField("pool_size", s.poolSize).Debug("Session update async pool size")

s.updateChan = make(chan *SessionUpdate, s.bufferSize)

s.keyPrefix = s.store.GetKeyPrefix()

for i := 0; i < s.poolSize; i++ {
go s.updateWorker()
}
})
}

func (s *sessionUpdater) updateWorker() {
for u := range s.updateChan {
v, err := json.Marshal(u.session)
if err != nil {
log.WithError(err).Error("Error marshalling session for async session update")
continue
}

if u.isHashed {
u.keyVal = s.keyPrefix + u.keyVal
err := s.store.SetRawKey(u.keyVal, string(v), u.ttl)
if err != nil {
log.WithError(err).Error("Error updating hashed key")
}
continue

}

err = s.store.SetKey(u.keyVal, string(v), u.ttl)
if err != nil {
log.WithError(err).Error("Error updating key")
}
}
}

// DefaultAuthorisationManager implements AuthorisationHandler,
// requires a storage.Handler to interact with key store
type DefaultAuthorisationManager struct {
Expand All @@ -53,12 +118,6 @@ type DefaultSessionManager struct {
store storage.Handler
asyncWrites bool
disableCacheSessionState bool
updateChan chan *SessionUpdate
poolSize int
shouldStop uint32
poolWG sync.WaitGroup
bufferSize int
keyPrefix string
}

type SessionUpdate struct {
Expand Down Expand Up @@ -114,68 +173,7 @@ func (b *DefaultSessionManager) Init(store storage.Handler) {
}

if b.asyncWrites {
// check pool size in config and set to 50 if unset
b.poolSize = config.Global().SessionUpdatePoolSize
if b.poolSize <= 0 {
b.poolSize = sessionPoolDefaultSize
}
//check size for channel buffer and set to 1000 if unset
b.bufferSize = config.Global().SessionUpdateBufferSize
if b.bufferSize <= 0 {
b.bufferSize = sessionBufferDefaultSize
}

log.WithField("SessionManager poolsize", b.poolSize).Debug("Session update async pool size")

b.updateChan = make(chan *SessionUpdate, b.bufferSize)

b.keyPrefix = b.store.GetKeyPrefix()

//start worker pool
atomic.SwapUint32(&b.shouldStop, 0)
for i := 0; i < b.poolSize; i++ {
b.poolWG.Add(1)
go b.updateWorker()
}
}
}

func (b *DefaultSessionManager) updateWorker() {
defer b.poolWG.Done()

for u := range b.updateChan {

v, err := json.Marshal(u.session)
if err != nil {
log.WithError(err).Error("Error marshalling session for async session update")
continue
}

if u.isHashed {
u.keyVal = b.keyPrefix + u.keyVal
err := b.store.SetRawKey(u.keyVal, string(v), u.ttl)
if err != nil {
log.WithError(err).Error("Error updating hashed key")
}
continue

}

err = b.store.SetKey(u.keyVal, string(v), u.ttl)
if err != nil {
log.WithError(err).Error("Error updating key")
}
}
}

func (b *DefaultSessionManager) Stop() {
if atomic.LoadUint32(&b.shouldStop) == 0 {
// flag to stop adding data to chan
atomic.SwapUint32(&b.shouldStop, 1)
// close update channel
close(b.updateChan)
// wait for workers to finish
b.poolWG.Wait()
defaultSessionUpdater.Init(store)
}
}

Expand Down Expand Up @@ -215,10 +213,6 @@ func (b *DefaultSessionManager) UpdateSession(keyName string, session *user.Sess

// async update and return if needed
if b.asyncWrites {
if atomic.LoadUint32(&b.shouldStop) > 0 {
return nil
}

sessionUpdate := &SessionUpdate{
isHashed: hashed,
keyVal: keyName,
Expand All @@ -227,7 +221,7 @@ func (b *DefaultSessionManager) UpdateSession(keyName string, session *user.Sess
}

// send sessionupdate object through channel to pool
b.updateChan <- sessionUpdate
defaultSessionUpdater.updateChan <- sessionUpdate

return nil
}
Expand Down Expand Up @@ -288,6 +282,8 @@ func (b *DefaultSessionManager) SessionDetail(keyName string, hashed bool) (user
return session, true
}

func (b *DefaultSessionManager) Stop() {}

// Sessions returns all sessions in the key store that match a filter key (a prefix)
func (b *DefaultSessionManager) Sessions(filter string) []string {
return b.store.GetKeys(filter)
Expand Down
26 changes: 26 additions & 0 deletions gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"net/url"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -911,6 +912,31 @@ func TestListenPathTykPrefix(t *testing.T) {
})
}

func TestReloadGoroutineLeakWithAsyncWrites(t *testing.T) {
ts := newTykTestServer()
defer ts.Close()

globalConf := config.Global()
globalConf.UseAsyncSessionWrite = true
config.SetGlobal(globalConf)
defer resetTestConfig()

buildAndLoadAPI(func(spec *APISpec) {
spec.Proxy.ListenPath = "/"
})

before := runtime.NumGoroutine()
doReload()

time.Sleep(100 * time.Millisecond)

after := runtime.NumGoroutine()

if before != after {
t.Errorf("Goroutine leak, was: %d, after reload: %d", before, after)
}
}

func TestProxyUserAgent(t *testing.T) {
ts := newTykTestServer()
defer ts.Close()
Expand Down

0 comments on commit ef56da7

Please sign in to comment.