Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(abciclient): support timeouts in abci calls #749

Merged
merged 13 commits into from
Mar 7, 2024
7 changes: 5 additions & 2 deletions abci/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,17 @@ type requestAndResponse struct {
*types.Request
*types.Response

mtx sync.Mutex
mtx sync.Mutex
// context for the request; we check if it's not expired before sending
ctx context.Context
signal chan struct{}
}

func makeReqRes(req *types.Request) *requestAndResponse {
func makeReqRes(ctx context.Context, req *types.Request) *requestAndResponse {
return &requestAndResponse{
Request: req,
Response: nil,
ctx: ctx,
signal: make(chan struct{}),
}
}
Expand Down
82 changes: 55 additions & 27 deletions abci/client/socket_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type socketClient struct {
mustConnect bool
conn net.Conn

// Requests queue
reqQueue chan *requestAndResponse

mtx sync.Mutex
Expand Down Expand Up @@ -116,37 +117,62 @@ func (cli *socketClient) Error() error {

//----------------------------------------

// enqueue puts reqres onto the pending-requests queue, blocking until
// there is room in the queue or the context is done.
//
// Returns ctx.Err() when `ctx` is canceled before the item was queued,
// nil otherwise.
func (cli *socketClient) enqueue(ctx context.Context, reqres *requestAndResponse) error {
	select {
	case cli.reqQueue <- reqres:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// dequeue blocks until a request is available on the queue and returns it.
//
// A nil result means `ctx` was canceled before any request arrived.
func (cli *socketClient) dequeue(ctx context.Context) *requestAndResponse {
	select {
	case <-ctx.Done():
		return nil
	case next := <-cli.reqQueue:
		return next
	}
}

func (cli *socketClient) sendRequestsRoutine(ctx context.Context, conn io.Writer) {
bw := bufio.NewWriter(conn)
for {
select {
case <-ctx.Done():
// dequeue will block until a message arrives
for reqres := cli.dequeue(ctx); reqres != nil && ctx.Err() == nil; reqres = cli.dequeue(ctx) {
if err := reqres.ctx.Err(); err != nil {
// request expired, skip it
cli.logger.Debug("abci.socketClient request expired, skipping", "req", reqres.Request.Value, "error", err)
continue
}

// N.B. We must track request before sending it out, otherwise the
// server may reply before we do it, and the receiver will fail for an
// unsolicited reply.
cli.trackRequest(reqres)

if err := types.WriteMessage(reqres.Request, bw); err != nil {
cli.stopForError(fmt.Errorf("write to buffer: %w", err))
return
case reqres := <-cli.reqQueue:
// N.B. We must enqueue before sending out the request, otherwise the
// server may reply before we do it, and the receiver will fail for an
// unsolicited reply.
cli.trackRequest(reqres)

if err := types.WriteMessage(reqres.Request, bw); err != nil {
cli.stopForError(fmt.Errorf("write to buffer: %w", err))
return
}
}

if err := bw.Flush(); err != nil {
cli.stopForError(fmt.Errorf("flush buffer: %w", err))
return
}
if err := bw.Flush(); err != nil {
cli.stopForError(fmt.Errorf("flush buffer: %w", err))
return
}
}

cli.logger.Debug("context canceled, stopping sendRequestsRoutine")
}

func (cli *socketClient) recvResponseRoutine(ctx context.Context, conn io.Reader) {
r := bufio.NewReader(conn)
for {
if ctx.Err() != nil {
return
}
for ctx.Err() == nil {
res := &types.Response{}

if err := types.ReadMessage(r, res); err != nil {
Expand All @@ -166,6 +192,8 @@ func (cli *socketClient) recvResponseRoutine(ctx context.Context, conn io.Reader
}
}
}

cli.logger.Debug("context canceled, stopping recvResponseRoutine")
}

func (cli *socketClient) trackRequest(reqres *requestAndResponse) {
Expand Down Expand Up @@ -209,15 +237,15 @@ func (cli *socketClient) doRequest(ctx context.Context, req *types.Request) (*ty
return nil, errors.New("client has stopped")
}

reqres := makeReqRes(req)

select {
case cli.reqQueue <- reqres:
case <-ctx.Done():
return nil, fmt.Errorf("can't queue req: %w", ctx.Err())
reqres := makeReqRes(ctx, req)
if err := cli.enqueue(ctx, reqres); err != nil {
return nil, err
}

// wait for response for our request
select {
case <-reqres.ctx.Done():
return nil, reqres.ctx.Err()
case <-reqres.signal:
if err := cli.Error(); err != nil {
return nil, err
Expand Down
120 changes: 120 additions & 0 deletions abci/client/socket_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package abciclient

import (
"context"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"

"github.com/dashpay/tenderdash/abci/server"
"github.com/dashpay/tenderdash/abci/types"
"github.com/dashpay/tenderdash/abci/types/mocks"
"github.com/dashpay/tenderdash/libs/log"
)

// TestSocketClientTimeout tests that the socket client times out correctly.
//
// Each case drives an Info call (used purely to occupy the send pipeline for
// `enqueueSleep`) followed by a CheckTx with a per-request timeout, and then
// asserts whether CheckTx succeeded, timed out before being sent, or timed
// out while the server was processing it.
func TestSocketClientTimeout(t *testing.T) {
	const (
		// expected outcomes of a test case
		Success              = 0
		FailDuringEnqueue    = 1
		FailDuringProcessing = 2

		// base unit for all timeouts/delays in the table below
		baseTime = 10 * time.Millisecond
	)
	type testCase struct {
		name            string
		timeout         time.Duration // CheckTx request timeout
		enqueueSleep    time.Duration // how long Info blocks the pipeline
		processingSleep time.Duration // how long CheckTx takes server-side
		expect          int
	}
	testCases := []testCase{
		{name: "immediate", timeout: baseTime, enqueueSleep: 0, processingSleep: 0, expect: Success},
		{name: "small enqueue delay", timeout: 4 * baseTime, enqueueSleep: 1 * baseTime, processingSleep: 0, expect: Success},
		{name: "small processing delay", timeout: 4 * baseTime, enqueueSleep: 0, processingSleep: 1 * baseTime, expect: Success},
		{name: "within timeout", timeout: 4 * baseTime, enqueueSleep: 1 * baseTime, processingSleep: 1 * baseTime, expect: Success},
		{name: "timeout during enqueue", timeout: 3 * baseTime, enqueueSleep: 4 * baseTime, processingSleep: 1 * baseTime, expect: FailDuringEnqueue},
		{name: "timeout during processing", timeout: 4 * baseTime, enqueueSleep: 1 * baseTime, processingSleep: 4 * baseTime, expect: FailDuringProcessing},
	}

	logger := log.NewTestingLogger(t)

	for i, tc := range testCases {
		i := i
		tc := tc
		t.Run(tc.name, func(t *testing.T) {

			// wait until all threads end, otherwise we'll get data race in t.Log()
			wg := sync.WaitGroup{}
			defer wg.Wait()

			// hard upper bound for the whole test case
			mainCtx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
			defer cancel()

			// unique unix socket per test case
			socket := "unix://" + t.TempDir() + "/socket." + strconv.Itoa(i)

			// records whether the CheckTx handler actually ran on the server
			checkTxExecuted := atomic.Bool{}

			app := mocks.NewApplication(t)
			app.On("Echo", mock.Anything, mock.Anything).Return(&types.ResponseEcho{}, nil).Maybe()
			// Info stalls for enqueueSleep to keep the client's send queue busy
			app.On("Info", mock.Anything, mock.Anything).Run(func(_ mock.Arguments) {
				wg.Add(1)
				logger.Debug("Info before sleep")
				time.Sleep(tc.enqueueSleep)
				logger.Debug("Info after sleep")
				wg.Done()
			}).Return(&types.ResponseInfo{}, nil).Maybe()
			// CheckTx stalls for processingSleep to simulate slow server-side processing
			app.On("CheckTx", mock.Anything, mock.Anything).Run(func(_ mock.Arguments) {
				wg.Add(1)
				logger.Debug("CheckTx before sleep")
				checkTxExecuted.Store(true)
				time.Sleep(tc.processingSleep)
				logger.Debug("CheckTx after sleep")
				wg.Done()
			}).Return(&types.ResponseCheckTx{}, nil).Maybe()

			service, err := server.NewServer(logger, socket, "socket", app)
			require.NoError(t, err)
			svr := service.(*server.SocketServer)
			err = svr.Start(mainCtx)
			require.NoError(t, err)
			defer svr.Stop()

			cli := NewSocketClient(logger, socket, true).(*socketClient)

			err = cli.Start(mainCtx)
			require.NoError(t, err)
			defer cli.Stop()

			// per-request timeout under test
			reqCtx, reqCancel := context.WithTimeout(context.Background(), tc.timeout)
			defer reqCancel()
			// Info is here just to block for some time, so we don't want to enforce timeout on it

			wg.Add(1)
			go func() {
				_, _ = cli.Info(mainCtx, &types.RequestInfo{})
				wg.Done()
			}()

			time.Sleep(1 * time.Millisecond) // ensure the goroutine has started

			_, err = cli.CheckTx(reqCtx, &types.RequestCheckTx{})
			switch tc.expect {
			case Success:
				require.NoError(t, err)
				require.True(t, checkTxExecuted.Load())
			case FailDuringEnqueue:
				// timed out before the request was sent; server never saw it
				require.Error(t, err)
				require.False(t, checkTxExecuted.Load())
			case FailDuringProcessing:
				// request reached the server but the reply came too late
				require.Error(t, err)
				require.True(t, checkTxExecuted.Load())
			}
		})
	}
}
24 changes: 16 additions & 8 deletions internal/mempool/p2p_msg_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"time"

"github.com/dashpay/tenderdash/internal/p2p"
"github.com/dashpay/tenderdash/internal/p2p/client"
Expand All @@ -12,6 +13,12 @@ import (
"github.com/dashpay/tenderdash/types"
)

const (
	// CheckTxTimeout is the maximum time we wait for a single CheckTx
	// ABCI call to return before giving up on the transaction.
	// TODO: Change to config option
	CheckTxTimeout = 1 * time.Second
)

type (
mempoolP2PMessageHandler struct {
logger log.Logger
Expand Down Expand Up @@ -53,7 +60,10 @@ func (h *mempoolP2PMessageHandler) Handle(ctx context.Context, _ *client.Client,
SenderNodeID: envelope.From,
}
for _, tx := range protoTxs {
if err := h.checker.CheckTx(ctx, tx, nil, txInfo); err != nil {
subCtx, subCtxCancel := context.WithTimeout(ctx, CheckTxTimeout)
defer subCtxCancel()

if err := h.checker.CheckTx(subCtx, tx, nil, txInfo); err != nil {
if errors.Is(err, types.ErrTxInCache) {
// if the tx is in the cache,
// then we've been gossiped a
Expand All @@ -63,13 +73,11 @@ func (h *mempoolP2PMessageHandler) Handle(ctx context.Context, _ *client.Client,
// problem.
continue
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
// Do not propagate context
// cancellation errors, but do
// not continue to check
// transactions from this
// message if we are shutting down.
return err

// In case of ctx cancelation, we return error as we are most likely shutting down.
// Otherwise we just reject the tx.
if errCtx := ctx.Err(); errCtx != nil {
return errCtx
}
logger.Error("checktx failed for tx",
"tx", fmt.Sprintf("%X", types.Tx(tx).Hash()),
Expand Down
9 changes: 7 additions & 2 deletions internal/p2p/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,15 @@ func (c *Client) GetSyncStatus(ctx context.Context) error {
}

// SendTxs sends a transaction to the peer
func (c *Client) SendTxs(ctx context.Context, peerID types.NodeID, tx types.Tx) error {
func (c *Client) SendTxs(ctx context.Context, peerID types.NodeID, tx ...types.Tx) error {
txs := make([][]byte, len(tx))
for i := 0; i < len(tx); i++ {
txs[i] = tx[i]
}

return c.Send(ctx, p2p.Envelope{
To: peerID,
Message: &protomem.Txs{Txs: [][]byte{tx}},
Message: &protomem.Txs{Txs: txs},
})
}

Expand Down
16 changes: 14 additions & 2 deletions internal/rpc/core/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,17 @@ import (
// More:
// https://docs.tendermint.com/master/rpc/#/Tx/broadcast_tx_async
// Deprecated and should be removed in 0.37
func (env *Environment) BroadcastTxAsync(ctx context.Context, req *coretypes.RequestBroadcastTx) (*coretypes.ResultBroadcastTx, error) {
go func() { _ = env.Mempool.CheckTx(ctx, req.Tx, nil, mempool.TxInfo{}) }()
func (env *Environment) BroadcastTxAsync(_ctx context.Context, req *coretypes.RequestBroadcastTx) (*coretypes.ResultBroadcastTx, error) {
go func() {
// We need to create a new context here, because the original context
// may be canceled after parent function returns.
ctx, cancel := context.WithTimeout(context.Background(), mempool.CheckTxTimeout)
defer cancel()

if res, err := env.BroadcastTx(ctx, req); err != nil || res.Code != abci.CodeTypeOK {
env.Logger.Error("error on broadcastTxAsync", "err", err, "result", res, "tx", req.Tx.Hash())
}
}()

return &coretypes.ResultBroadcastTx{Hash: req.Tx.Hash()}, nil
}
Expand All @@ -37,6 +46,9 @@ func (env *Environment) BroadcastTxSync(ctx context.Context, req *coretypes.Requ
// DeliverTx result.
// More: https://docs.tendermint.com/master/rpc/#/Tx/broadcast_tx_sync
func (env *Environment) BroadcastTx(ctx context.Context, req *coretypes.RequestBroadcastTx) (*coretypes.ResultBroadcastTx, error) {
ctx, cancel := context.WithTimeout(ctx, mempool.CheckTxTimeout)
defer cancel()

resCh := make(chan *abci.ResponseCheckTx, 1)
err := env.Mempool.CheckTx(
ctx,
Expand Down
Loading