Skip to content

Commit

Permalink
Refactor server state management
Browse files Browse the repository at this point in the history
  • Loading branch information
hibiken committed May 31, 2020
1 parent 69ad583 commit a38f628
Show file tree
Hide file tree
Showing 12 changed files with 477 additions and 463 deletions.
113 changes: 99 additions & 14 deletions heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
package asynq

import (
"os"
"sync"
"time"

"github.com/hibiken/asynq/internal/base"
"github.com/hibiken/asynq/internal/log"
"github.com/rs/xid"
)

// heartbeater is responsible for writing process info to redis periodically to
Expand All @@ -18,29 +20,69 @@ type heartbeater struct {
logger *log.Logger
broker base.Broker

ss *base.ServerState

// channel to communicate back to the long running "heartbeater" goroutine.
done chan struct{}

// interval between heartbeats.
interval time.Duration

// following fields are initialized at construction time and are immutable.
host string
pid int
serverID string
concurrency int
queues map[string]int
strictPriority bool

// following fields are mutable and should be accessed only by the
// heartbeater goroutine. In other words, confine these variables
// to this goroutine only.
started time.Time
workers map[string]workerStat

// status is shared with other goroutines but is safe for concurrent use.
status *base.ServerStatus

// channels to receive updates on active workers.
starting <-chan *base.TaskMessage
finished <-chan *base.TaskMessage
}

// heartbeaterParams bundles the arguments accepted by newHeartbeater.
type heartbeaterParams struct {
	logger *log.Logger
	broker base.Broker

	// interval is the time between heartbeats.
	interval time.Duration

	// Server configuration reported in every heartbeat.
	concurrency    int
	queues         map[string]int
	strictPriority bool

	// status is shared with other goroutines; its String method is read on
	// every heartbeat.
	status *base.ServerStatus

	// starting and finished deliver updates on active workers.
	starting <-chan *base.TaskMessage
	finished <-chan *base.TaskMessage
}

// newHeartbeater returns a heartbeater configured from params.
//
// The host name and process ID are captured once here; if the host name
// cannot be determined, "unknown-host" is used instead. A fresh server ID is
// generated with xid for each heartbeater. The returned heartbeater is idle
// until start is called.
func newHeartbeater(params heartbeaterParams) *heartbeater {
	host, err := os.Hostname()
	if err != nil {
		// Best effort: the host name is only used to identify this server.
		host = "unknown-host"
	}

	return &heartbeater{
		logger:   params.logger,
		broker:   params.broker,
		done:     make(chan struct{}),
		interval: params.interval,

		host:           host,
		pid:            os.Getpid(),
		serverID:       xid.New().String(),
		concurrency:    params.concurrency,
		queues:         params.queues,
		strictPriority: params.strictPriority,

		status:   params.status,
		workers:  make(map[string]workerStat),
		starting: params.starting,
		finished: params.finished,
	}
}

Expand All @@ -50,31 +92,74 @@ func (h *heartbeater) terminate() {
h.done <- struct{}{}
}

// A workerStat records the message a worker is working on
// and the time the worker has started processing the message.
//
// Values are stored in heartbeater.workers, keyed by the message ID string.
type workerStat struct {
// started is the time the heartbeater was notified that processing began.
started time.Time
// msg is the task message being processed.
msg *base.TaskMessage
}

// start launches the heartbeater goroutine and registers it with wg.
//
// The goroutine records its start time, writes an initial heartbeat, and then
// loops until a value is received on h.done:
//   - every h.interval it writes another heartbeat;
//   - messages received on h.starting are added to the active worker set;
//   - messages received on h.finished are removed from it.
//
// On shutdown it clears this server's state from the broker before returning.
func (h *heartbeater) start(wg *sync.WaitGroup) {
	wg.Add(1)
	go func() {
		defer wg.Done()

		h.started = time.Now()

		// Write once immediately so the server is visible without waiting
		// for the first interval to elapse.
		h.beat()

		timer := time.NewTimer(h.interval)
		defer timer.Stop()
		for {
			select {
			case <-h.done:
				h.broker.ClearServerState(h.host, h.pid, h.serverID)
				h.logger.Debug("Heartbeater done")
				return

			case <-timer.C:
				h.beat()
				// The channel has just been drained, so Reset is safe here.
				timer.Reset(h.interval)

			case msg := <-h.starting:
				h.workers[msg.ID.String()] = workerStat{started: time.Now(), msg: msg}

			case msg := <-h.finished:
				delete(h.workers, msg.ID.String())
			}
		}
	}()
}

func (h *heartbeater) beat() {
info := base.ServerInfo{
Host: h.host,
PID: h.pid,
ServerID: h.serverID,
Concurrency: h.concurrency,
Queues: h.queues,
StrictPriority: h.strictPriority,
Status: h.status.String(),
Started: h.started,
ActiveWorkerCount: len(h.workers),
}

var ws []*base.WorkerInfo
for id, stat := range h.workers {
ws = append(ws, &base.WorkerInfo{
Host: h.host,
PID: h.pid,
ID: id,
Type: stat.msg.Type,
Queue: stat.msg.Queue,
Payload: stat.msg.Payload,
Started: stat.started,
})
}

// Note: Set TTL to be long enough so that it won't expire before we write again
// and short enough to expire quickly once the process is shut down or killed.
err := h.broker.WriteServerState(h.ss, h.interval*2)
if err != nil {
h.logger.Errorf("could not write heartbeat data: %v", err)
if err := h.broker.WriteServerState(&info, ws, h.interval*2); err != nil {
h.logger.Errorf("could not write server state data: %v", err)
}
}
36 changes: 25 additions & 11 deletions heartbeat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,24 @@ func TestHeartbeater(t *testing.T) {
for _, tc := range tests {
h.FlushDB(t, r)

state := base.NewServerState(tc.host, tc.pid, tc.concurrency, tc.queues, false)
status := base.NewServerStatus(base.StatusIdle)
hb := newHeartbeater(heartbeaterParams{
logger: testLogger,
broker: rdbClient,
serverState: state,
interval: tc.interval,
logger: testLogger,
broker: rdbClient,
interval: tc.interval,
concurrency: tc.concurrency,
queues: tc.queues,
strictPriority: false,
status: status,
starting: make(chan *base.TaskMessage),
finished: make(chan *base.TaskMessage),
})

// Change host and pid fields for testing purpose.
hb.host = tc.host
hb.pid = tc.pid

status.Set(base.StatusRunning)
var wg sync.WaitGroup
hb.start(&wg)

Expand Down Expand Up @@ -80,7 +90,7 @@ func TestHeartbeater(t *testing.T) {
}

// status change
state.SetStatus(base.StatusStopped)
status.Set(base.StatusStopped)

// allow for heartbeater to write to redis
time.Sleep(tc.interval * 2)
Expand Down Expand Up @@ -119,12 +129,16 @@ func TestHeartbeaterWithRedisDown(t *testing.T) {
}()
r := rdb.NewRDB(setup(t))
testBroker := testbroker.NewTestBroker(r)
ss := base.NewServerState("localhost", 1234, 10, map[string]int{"default": 1}, false)
hb := newHeartbeater(heartbeaterParams{
logger: testLogger,
broker: testBroker,
serverState: ss,
interval: time.Second,
logger: testLogger,
broker: testBroker,
interval: time.Second,
concurrency: 10,
queues: map[string]int{"default": 1},
strictPriority: false,
status: base.NewServerStatus(base.StatusRunning),
starting: make(chan *base.TaskMessage),
finished: make(chan *base.TaskMessage),
})

testBroker.Sleep()
Expand Down
2 changes: 1 addition & 1 deletion internal/asynqtest/asynqtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ var SortServerInfoOpt = cmp.Transformer("SortServerInfo", func(in []*base.Server
// SortWorkerInfoOpt is a cmp transformer that returns a copy of the given
// worker info slice sorted by ID in ascending order, so that comparisons do
// not depend on element order.
var SortWorkerInfoOpt = cmp.Transformer("SortWorkerInfo", func(in []*base.WorkerInfo) []*base.WorkerInfo {
	out := append([]*base.WorkerInfo(nil), in...) // Copy input to avoid mutating it
	sort.Slice(out, func(i, j int) bool {
		// ID is compared directly as a string.
		return out[i].ID < out[j].ID
	})
	return out
})
Expand Down
Loading

0 comments on commit a38f628

Please sign in to comment.