Skip to content

Commit

Permalink
[fix]: task client blocked when exiting
Browse files Browse the repository at this point in the history
  • Loading branch information
CodingCrush committed Sep 19, 2021
1 parent fc86b2a commit 4c60b87
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 24 deletions.
18 changes: 13 additions & 5 deletions app/broker/runtime.go
Expand Up @@ -156,11 +156,11 @@ func (r *runtime) Run() error {
return err
}

tackClientFct := rpc.NewTaskClientFactory(r.node)
tackClientFct := rpc.NewTaskClientFactory(r.ctx, r.node)
r.factory = factory{
taskClient: tackClientFct,
taskServer: rpc.NewTaskServerFactory(),
connectionMgr: rpc.NewConnectionManager(tackClientFct), //TODO close connections
connectionMgr: rpc.NewConnectionManager(tackClientFct),
}

r.buildServiceDependency()
Expand All @@ -170,8 +170,11 @@ func (r *runtime) Run() error {

discoveryFactory := discovery.NewFactory(r.repo)

r.stateMgr = broker.NewStateManager(r.ctx, *r.node,
r.factory.connectionMgr, r.factory.taskClient,
r.stateMgr = broker.NewStateManager(
r.ctx,
*r.node,
r.factory.connectionMgr,
r.factory.taskClient,
r.srv.channelManager)
// finally start all state machine
r.stateMachineFactory = newStateMachineFactory(r.ctx, discoveryFactory, r.stateMgr)
Expand Down Expand Up @@ -270,6 +273,11 @@ func (r *runtime) Stop() {
r.log.Info("closed write channel successfully")
}

if r.factory.connectionMgr != nil {
_ = r.factory.connectionMgr.Close()
}
r.log.Info("close connections successfully")

// finally shutdown rpc server
if r.grpcServer != nil {
r.log.Info("stopping grpc server...")
Expand Down Expand Up @@ -336,7 +344,7 @@ func (r *runtime) buildServiceDependency() {
// todo watch stateMachine states change.

// hard code create channel first.
cm := replica.NewChannelManager(r.ctx, rpc.NewClientStreamFactory(r.node))
cm := replica.NewChannelManager(r.ctx, rpc.NewClientStreamFactory(r.ctx, r.node))

taskManager := brokerQuery.NewTaskManager(
r.ctx,
Expand Down
7 changes: 6 additions & 1 deletion app/storage/runtime.go
Expand Up @@ -215,6 +215,11 @@ func (r *runtime) MustRegisterStateFulNode() error {
)
// sometimes lease isn't expired when storage restarts, retry registering is necessary
for attempt := 1; attempt <= maxRetries; attempt++ {
select {
case <-r.ctx.Done(): // no more retries when context is done
return nil
default:
}
ok, _, err = r.repo.Elect(
r.ctx,
constants.GetLiveNodePath(strconv.Itoa(int(r.node.ID))),
Expand Down Expand Up @@ -389,7 +394,7 @@ func (r *runtime) bindRPCHandlers() {
r.ctx,
r.config.StorageBase.WAL,
r.node.ID, r.engine,
rpc.NewClientStreamFactory(r.node),
rpc.NewClientStreamFactory(r.ctx, r.node),
r.stateMgr,
)
r.rpcHandler = &rpcHandler{
Expand Down
9 changes: 6 additions & 3 deletions rpc/rpc.go
Expand Up @@ -115,13 +115,15 @@ type ClientStreamFactory interface {

// clientStreamFactory implements ClientStreamFactory.
//
// Streams created by this factory are bound to ctx, so cancelling ctx
// terminates every stream the factory has created (a gRPC stream lives
// only as long as its context).
type clientStreamFactory struct {
// ctx bounds the lifetime of all created streams; it is used as the
// outgoing stream context in CreateTaskClient.
ctx context.Context
// logicNode identifies the local node; its indicator is attached to
// outgoing stream metadata under constants.RPCMetaKeyLogicNode.
logicNode models.Node
// connFct creates the underlying client connections
// (initialized from GetClientConnFactory in NewClientStreamFactory).
connFct ClientConnFactory
}

// NewClientStreamFactory returns a factory to get clientStream.
func NewClientStreamFactory(logicNode models.Node) ClientStreamFactory {
func NewClientStreamFactory(ctx context.Context, logicNode models.Node) ClientStreamFactory {
return &clientStreamFactory{
ctx: ctx,
logicNode: logicNode,
connFct: GetClientConnFactory(),
}
Expand All @@ -140,8 +142,9 @@ func (w *clientStreamFactory) CreateTaskClient(target models.Node) (protoCommonV
}

node := w.LogicNode()
//TODO handle context?????
ctx := CreateOutgoingContextWithPairs(context.TODO(), constants.RPCMetaKeyLogicNode, node.Indicator())
// https://pkg.go.dev/google.golang.org/grpc#ClientConn.NewStream
// context is the lifetime of stream
ctx := CreateOutgoingContextWithPairs(w.ctx, constants.RPCMetaKeyLogicNode, node.Indicator())
cli, err := protoCommonV1.NewTaskServiceClient(conn).Handle(ctx)
return cli, err
}
Expand Down
6 changes: 5 additions & 1 deletion rpc/rpc_test.go
Expand Up @@ -18,6 +18,7 @@
package rpc

import (
"context"
"testing"
"time"

Expand Down Expand Up @@ -68,7 +69,10 @@ func TestClientStreamFactory_CreateTaskClient(t *testing.T) {

handler := protoCommonV1.NewMockTaskServiceServer(ctrl)

factory := NewClientStreamFactory(&models.StatelessNode{HostIP: "127.0.0.2", GRPCPort: 9000})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

factory := NewClientStreamFactory(ctx, &models.StatelessNode{HostIP: "127.0.0.2", GRPCPort: 9000})
target := models.StatelessNode{HostIP: "127.0.0.1", GRPCPort: 9000}

client, err := factory.CreateTaskClient(&target)
Expand Down
35 changes: 23 additions & 12 deletions rpc/task_transport.go
Expand Up @@ -33,8 +33,6 @@ import (

//go:generate mockgen -source ./task_transport.go -destination=./task_transport_mock.go -package=rpc

var log = logger.GetLogger("rpc", "TaskClient")

// TaskClientFactory represents the task stream manage
type TaskClientFactory interface {
// CreateTaskClient creates a task client stream if not exist
Expand All @@ -57,6 +55,7 @@ type taskClient struct {

// taskClientFactory implements TaskClientFactory interface
type taskClientFactory struct {
ctx context.Context
currentNode models.Node
taskReceiver TaskReceiver
// target node ID => client stream
Expand All @@ -65,15 +64,18 @@ type taskClientFactory struct {

newTaskServiceClientFunc func(cc *grpc.ClientConn) protoCommonV1.TaskServiceClient
connFct ClientConnFactory
logger *logger.Logger
}

// NewTaskClientFactory creates a task client factory
func NewTaskClientFactory(currentNode models.Node) TaskClientFactory {
func NewTaskClientFactory(ctx context.Context, currentNode models.Node) TaskClientFactory {
return &taskClientFactory{
ctx: ctx,
currentNode: currentNode,
connFct: GetClientConnFactory(),
taskStreams: make(map[string]*taskClient),
newTaskServiceClientFunc: protoCommonV1.NewTaskServiceClient,
logger: logger.GetLogger("rpc", "TaskClient"),
}
}

Expand Down Expand Up @@ -140,7 +142,7 @@ func (f *taskClientFactory) initTaskClient(client *taskClient) error {

if client.cli != nil {
if err := client.cli.CloseSend(); err != nil {
log.Error("close task client error", logger.Error(err))
f.logger.Error("close task client error", logger.Error(err))
}
client.cli = nil
}
Expand All @@ -163,22 +165,29 @@ func (f *taskClientFactory) initTaskClient(client *taskClient) error {
func (f *taskClientFactory) handleTaskResponse(client *taskClient) {
var attempt int32 = 0
for client.running.Load() {
select {
case <-f.ctx.Done():
// if client is not ready, this goroutine may be blocked without ctx.Done()
return
default:
}

if !client.ready.Load() {
attempt++
log.Info("initializing task client",
f.logger.Info("initializing task client",
logger.String("target", client.targetID),
logger.Int32("attempt", attempt),
)
if err := f.initTaskClient(client); err != nil {
log.Error("failed to initialize task client",
f.logger.Error("failed to initialize task client",
logger.Error(err),
logger.String("target", client.targetID),
logger.Int32("attempt", attempt),
)
time.Sleep(time.Second)
continue
} else {
log.Info("initialized task client successfully",
f.logger.Info("initialized task client successfully",
logger.String("target", client.targetID),
logger.Int32("attempt", attempt))
client.ready.Store(true)
Expand All @@ -187,13 +196,13 @@ func (f *taskClientFactory) handleTaskResponse(client *taskClient) {
resp, err := client.cli.Recv()
if err != nil {
client.ready.Store(false)
log.Error("receive task error from stream", logger.Error(err))
// todo: suppress errors before shard assignment
f.logger.Error("receive task error from stream", logger.Error(err))
continue
}

err = f.taskReceiver.Receive(resp, client.targetID)
if err != nil {
log.Error("receive task response",
if err = f.taskReceiver.Receive(resp, client.targetID); err != nil {
f.logger.Error("receive task response",
logger.String("taskID", resp.TaskID),
logger.String("taskType", resp.Type.String()),
logger.Error(err))
Expand Down Expand Up @@ -223,12 +232,14 @@ type taskServerFactory struct {
nodeMap map[string]*taskService
epoch atomic.Int64
lock sync.RWMutex
logger *logger.Logger
}

// NewTaskServerFactory returns the singleton server stream factory,
// which tracks server-side task streams keyed by node identifier.
func NewTaskServerFactory() TaskServerFactory {
	fct := &taskServerFactory{
		nodeMap: make(map[string]*taskService),
		logger:  logger.GetLogger("rpc", "TaskServer"),
	}
	return fct
}

Expand Down Expand Up @@ -265,7 +276,7 @@ func (fct *taskServerFactory) Nodes() []models.Node {
for nodeID := range fct.nodeMap {
node, err := models.ParseNode(nodeID)
if err != nil {
log.Warn("parse node error", logger.Error(err))
fct.logger.Warn("parse node error", logger.Error(err))
continue
}
nodes = append(nodes, node)
Expand Down
8 changes: 6 additions & 2 deletions rpc/task_transport_test.go
Expand Up @@ -18,6 +18,7 @@
package rpc

import (
"context"
"fmt"
"testing"

Expand Down Expand Up @@ -72,7 +73,7 @@ func TestTaskClientFactory(t *testing.T) {
mockTaskClient.EXPECT().CloseSend().Return(fmt.Errorf("err")).AnyTimes()
taskService := protoCommonV1.NewMockTaskServiceClient(ctl)

fct := NewTaskClientFactory(&models.StatelessNode{HostIP: "127.0.0.1", GRPCPort: 123})
fct := NewTaskClientFactory(context.TODO(), &models.StatelessNode{HostIP: "127.0.0.1", GRPCPort: 123})
receiver := NewMockTaskReceiver(ctl)
receiver.EXPECT().Receive(gomock.Any(), gomock.Any()).Return(fmt.Errorf("err")).AnyTimes()
fct.SetTaskReceiver(receiver)
Expand Down Expand Up @@ -120,8 +121,11 @@ func TestTaskClientFactory_handler(t *testing.T) {
ctrl.Finish()
}()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

receiver := NewMockTaskReceiver(ctrl)
fct := NewTaskClientFactory(&models.StatelessNode{HostIP: "127.0.0.1", GRPCPort: 123})
fct := NewTaskClientFactory(ctx, &models.StatelessNode{HostIP: "127.0.0.1", GRPCPort: 123})
fct.SetTaskReceiver(receiver)

target := models.StatelessNode{HostIP: "127.0.0.1", GRPCPort: 321}
Expand Down

0 comments on commit 4c60b87

Please sign in to comment.