Skip to content

Commit

Permalink
watchapi: Terminate RPCs when the manager shuts down
Browse files Browse the repository at this point in the history
Because we use GracefulStop, RPCs need to finish on their own when the
manager shuts down.

The watchapi server formerly was completely stateless, and didn't have a
notion of starting or stopping.

This adds Start and Stop functions that control a context, adapted from
the logbroker code. When the context is cancelled, outstanding RPCs will
finish.

There are also some related cleanups and fixes in the logbroker code.

Signed-off-by: Aaron Lehmann <aaron.lehmann@docker.com>
  • Loading branch information
aaronlehmann committed Jul 14, 2017
1 parent df54e4f commit 6a9946b
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 26 deletions.
37 changes: 23 additions & 14 deletions manager/logbroker/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ func New(store *store.MemoryStore) *LogBroker {
}
}

// Run the log broker
func (lb *LogBroker) Run(ctx context.Context) error {
// Start starts the log broker
func (lb *LogBroker) Start(ctx context.Context) error {
lb.mu.Lock()
defer lb.mu.Unlock()

if lb.cancelAll != nil {
lb.mu.Unlock()
return errAlreadyRunning
}

Expand All @@ -71,12 +71,7 @@ func (lb *LogBroker) Run(ctx context.Context) error {
lb.subscriptionQueue = watch.NewQueue()
lb.registeredSubscriptions = make(map[string]*subscription)
lb.subscriptionsByNode = make(map[string]map[*subscription]struct{})
lb.mu.Unlock()

select {
case <-lb.pctx.Done():
return lb.pctx.Err()
}
return nil
}

// Stop stops the log broker
Expand Down Expand Up @@ -234,8 +229,15 @@ func (lb *LogBroker) SubscribeLogs(request *api.SubscribeLogsRequest, stream api
return err
}

lb.mu.Lock()
pctx := lb.pctx
lb.mu.Unlock()
if pctx == nil {
return errNotRunning
}

subscription := lb.newSubscription(request.Selector, request.Options)
subscription.Run(lb.pctx)
subscription.Run(pctx)
defer subscription.Stop()

log := log.G(ctx).WithFields(
Expand All @@ -257,8 +259,8 @@ func (lb *LogBroker) SubscribeLogs(request *api.SubscribeLogsRequest, stream api
select {
case <-ctx.Done():
return ctx.Err()
case <-lb.pctx.Done():
return lb.pctx.Err()
case <-pctx.Done():
return pctx.Err()
case event := <-publishCh:
publish := event.(*logMessage)
if publish.completed {
Expand Down Expand Up @@ -308,6 +310,13 @@ func (lb *LogBroker) ListenSubscriptions(request *api.ListenSubscriptionsRequest
return err
}

lb.mu.Lock()
pctx := lb.pctx
lb.mu.Unlock()
if pctx == nil {
return errNotRunning
}

lb.nodeConnected(remote.NodeID)
defer lb.nodeDisconnected(remote.NodeID)

Expand All @@ -329,7 +338,7 @@ func (lb *LogBroker) ListenSubscriptions(request *api.ListenSubscriptionsRequest
select {
case <-stream.Context().Done():
return stream.Context().Err()
case <-lb.pctx.Done():
case <-pctx.Done():
return nil
default:
}
Expand Down Expand Up @@ -362,7 +371,7 @@ func (lb *LogBroker) ListenSubscriptions(request *api.ListenSubscriptionsRequest
}
case <-stream.Context().Done():
return stream.Context().Err()
case <-lb.pctx.Done():
case <-pctx.Done():
return nil
}
}
Expand Down
6 changes: 3 additions & 3 deletions manager/logbroker/broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ func TestLogBrokerLogs(t *testing.T) {

wg.Wait()

// Make sure double Run throws an error
require.EqualError(t, broker.Run(ctx), errAlreadyRunning.Error())
// Make sure double Start throws an error
require.EqualError(t, broker.Start(ctx), errAlreadyRunning.Error())
// Stop should work
require.NoError(t, broker.Stop())
// Double stopping should fail
Expand Down Expand Up @@ -780,7 +780,7 @@ func testLogBrokerEnv(t *testing.T) (context.Context, *testutils.TestCA, *LogBro
}
}()

go broker.Run(ctx)
require.NoError(t, broker.Start(ctx))

return ctx, tca, broker, logListener.Addr().String(), brokerListener.Addr().String(), func() {
broker.Stop()
Expand Down
20 changes: 12 additions & 8 deletions manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ type Manager struct {
caserver *ca.Server
dispatcher *dispatcher.Dispatcher
logbroker *logbroker.LogBroker
watchServer *watchapi.Server
replicatedOrchestrator *replicated.Orchestrator
globalOrchestrator *global.Orchestrator
taskReaper *taskreaper.TaskReaper
Expand Down Expand Up @@ -220,6 +221,7 @@ func New(config *Config) (*Manager, error) {
caserver: ca.NewServer(raftNode.MemoryStore(), config.SecurityConfig, config.RootCAPaths),
dispatcher: dispatcher.New(raftNode, dispatcher.DefaultConfig()),
logbroker: logbroker.New(raftNode.MemoryStore()),
watchServer: watchapi.NewServer(raftNode.MemoryStore()),
server: grpc.NewServer(opts...),
localserver: grpc.NewServer(opts...),
raftNode: raftNode,
Expand Down Expand Up @@ -397,13 +399,12 @@ func (m *Manager) Run(parent context.Context) error {
}

baseControlAPI := controlapi.NewServer(m.raftNode.MemoryStore(), m.raftNode, m.config.SecurityConfig, m.caserver, m.config.PluginGetter)
baseWatchAPI := watchapi.NewServer(m.raftNode.MemoryStore())
baseResourceAPI := resourceapi.New(m.raftNode.MemoryStore())
healthServer := health.NewHealthServer()
localHealthServer := health.NewHealthServer()

authenticatedControlAPI := api.NewAuthenticatedWrapperControlServer(baseControlAPI, authorize)
authenticatedWatchAPI := api.NewAuthenticatedWrapperWatchServer(baseWatchAPI, authorize)
authenticatedWatchAPI := api.NewAuthenticatedWrapperWatchServer(m.watchServer, authorize)
authenticatedResourceAPI := api.NewAuthenticatedWrapperResourceAllocatorServer(baseResourceAPI, authorize)
authenticatedLogsServerAPI := api.NewAuthenticatedWrapperLogsServer(m.logbroker, authorize)
authenticatedLogBrokerAPI := api.NewAuthenticatedWrapperLogBrokerServer(m.logbroker, authorize)
Expand Down Expand Up @@ -476,7 +477,7 @@ func (m *Manager) Run(parent context.Context) error {
grpc_prometheus.Register(m.server)

api.RegisterControlServer(m.localserver, localProxyControlAPI)
api.RegisterWatchServer(m.localserver, baseWatchAPI)
api.RegisterWatchServer(m.localserver, m.watchServer)
api.RegisterLogsServer(m.localserver, localProxyLogsAPI)
api.RegisterHealthServer(m.localserver, localHealthServer)
api.RegisterDispatcherServer(m.localserver, localProxyDispatcherAPI)
Expand Down Expand Up @@ -1000,11 +1001,13 @@ func (m *Manager) becomeLeader(ctx context.Context) {
}
}(m.dispatcher)

go func(lb *logbroker.LogBroker) {
if err := lb.Run(ctx); err != nil {
log.G(ctx).WithError(err).Error("LogBroker exited with an error")
}
}(m.logbroker)
if err := m.logbroker.Start(ctx); err != nil {
log.G(ctx).WithError(err).Error("LogBroker failed to start")
}

if err := m.watchServer.Start(ctx); err != nil {
log.G(ctx).WithError(err).Error("watch server failed to start")
}

go func(server *ca.Server) {
if err := server.Run(ctx); err != nil {
Expand Down Expand Up @@ -1058,6 +1061,7 @@ func (m *Manager) becomeLeader(ctx context.Context) {
func (m *Manager) becomeFollower() {
m.dispatcher.Stop()
m.logbroker.Stop()
m.watchServer.Stop()
m.caserver.Stop()

if m.allocator != nil {
Expand Down
41 changes: 40 additions & 1 deletion manager/watchapi/server.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
package watchapi

import (
"errors"
"sync"

"github.com/docker/swarmkit/manager/state/store"
"golang.org/x/net/context"
)

var (
errAlreadyRunning = errors.New("broker is already running")
errNotRunning = errors.New("broker is not running")
)

// Server is the store API gRPC server.
type Server struct {
store *store.MemoryStore
store *store.MemoryStore
mu sync.Mutex
pctx context.Context
cancelAll func()
}

// NewServer creates a store API server.
Expand All @@ -15,3 +27,30 @@ func NewServer(store *store.MemoryStore) *Server {
store: store,
}
}

// Start starts the watch server.
func (s *Server) Start(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()

if s.cancelAll != nil {
return errAlreadyRunning
}

s.pctx, s.cancelAll = context.WithCancel(ctx)
return nil
}

// Stop stops the watch server.
func (s *Server) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()

if s.cancelAll == nil {
return errNotRunning
}
s.cancelAll()
s.cancelAll = nil

return nil
}
5 changes: 5 additions & 0 deletions manager/watchapi/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"github.com/docker/swarmkit/manager/state/store"
stateutils "github.com/docker/swarmkit/manager/state/testutils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/grpclog"
)
Expand All @@ -30,6 +32,7 @@ type testServer struct {
}

func (ts *testServer) Stop() {
ts.Server.Stop()
ts.clientConn.Close()
ts.grpcServer.Stop()
ts.Store.Close()
Expand All @@ -48,6 +51,8 @@ func newTestServer(t *testing.T) *testServer {
ts.Server = NewServer(ts.Store)
assert.NotNil(t, ts.Server)

require.NoError(t, ts.Server.Start(context.Background()))

temp, err := ioutil.TempFile("", "test-socket")
assert.NoError(t, err)
assert.NoError(t, temp.Close())
Expand Down
9 changes: 9 additions & 0 deletions manager/watchapi/watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ import (
func (s *Server) Watch(request *api.WatchRequest, stream api.Watch_WatchServer) error {
ctx := stream.Context()

s.mu.Lock()
pctx := s.pctx
s.mu.Unlock()
if pctx == nil {
return errNotRunning
}

watchArgs, err := api.ConvertWatchArgs(request.Entries)
if err != nil {
return grpc.Errorf(codes.InvalidArgument, "%s", err.Error())
Expand All @@ -39,6 +46,8 @@ func (s *Server) Watch(request *api.WatchRequest, stream api.Watch_WatchServer)
select {
case <-ctx.Done():
return ctx.Err()
case <-pctx.Done():
return pctx.Err()
case event := <-watch:
if commitEvent, ok := event.(state.EventCommit); ok && len(events) > 0 {
if err := stream.Send(&api.WatchMessage{Events: events, Version: commitEvent.Version}); err != nil {
Expand Down

0 comments on commit 6a9946b

Please sign in to comment.