Skip to content

Commit

Permalink
Fix race condition in Queue.Protect
Browse files Browse the repository at this point in the history
  • Loading branch information
Evgeniy Gavryushin committed Mar 23, 2022
1 parent c136c33 commit b44ba4a
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 31 deletions.
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ var seleniumPaths = struct {

func selenium() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc(seleniumPaths.CreateSession, post(queue.Try(queue.Check(queue.Protect(create)))))
mux.HandleFunc(seleniumPaths.CreateSession, post(queue.Protect(create)))
mux.HandleFunc(seleniumPaths.ProxySession, proxy)
mux.HandleFunc(paths.Status, status)
mux.HandleFunc(paths.Welcome, welcome)
Expand Down
47 changes: 18 additions & 29 deletions protect/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,32 +20,24 @@ type Queue struct {
used chan struct{}
}

// Try - when X-Selenoid-No-Wait header is set
// reply to client immediately if queue is full
func (q *Queue) Try(next http.HandlerFunc) http.HandlerFunc {
func (q *Queue) Protect(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
_, noWait := r.Header["X-Selenoid-No-Wait"]
tokenAcquired := false
select {
case q.limit <- struct{}{}:
<-q.limit
tokenAcquired = true
default:
tokenAcquired = false
}

if ! tokenAcquired {
_, noWait := r.Header["X-Selenoid-No-Wait"]
if noWait {
err := errors.New(http.StatusText(http.StatusTooManyRequests))
jsonerror.UnknownError(err).Encode(w)
return
}
}
next.ServeHTTP(w, r)
}
}

// Check - if queue disabled
func (q *Queue) Check(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
select {
case q.limit <- struct{}{}:
<-q.limit
default:
if q.disabled {
user, remote := util.RequestInfo(r)
log.Printf("[-] [QUEUE_IS_FULL] [%s] [%s]", user, remote)
Expand All @@ -54,27 +46,24 @@ func (q *Queue) Check(next http.HandlerFunc) http.HandlerFunc {
return
}
}
next.ServeHTTP(w, r)
}
}

// Protect - handler to control limit of sessions
func (q *Queue) Protect(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, remote := util.RequestInfo(r)
log.Printf("[-] [NEW_REQUEST] [%s] [%s]", user, remote)
s := time.Now()
go func() {
q.queued <- struct{}{}
}()
select {
case <-r.Context().Done():
<-q.queued
log.Printf("[-] [CLIENT_DISCONNECTED] [%s] [%s] [%s]", user, remote, time.Since(s))
return
case q.limit <- struct{}{}:
q.pending <- struct{}{}
if ! tokenAcquired {
select {
case <-r.Context().Done():
<-q.queued
log.Printf("[-] [CLIENT_DISCONNECTED] [%s] [%s] [%s]", user, remote, time.Since(s))
return
case q.limit <- struct{}{}:
// Do nothing
}
}
q.pending <- struct{}{}
<-q.queued
log.Printf("[-] [NEW_REQUEST_ACCEPTED] [%s] [%s]", user, remote)
next.ServeHTTP(w, r)
Expand Down
2 changes: 1 addition & 1 deletion utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ func TestSumUsedTotalGreaterThanPending(t *testing.T) {
hf := func(_ http.ResponseWriter, _ *http.Request) {
time.Sleep(50 * time.Millisecond)
}
queuedHandlerFunc := queue.Try(queue.Check(queue.Protect(hf)))
queuedHandlerFunc := queue.Protect(hf)
mux := http.NewServeMux()
mux.HandleFunc("/", queuedHandlerFunc)

Expand Down

0 comments on commit b44ba4a

Please sign in to comment.