feat: make agent reconnect in case of connection dropping (#3342)
* feat: make agent reconnect in case of connection dropping

* fix panic

* fix(agent): use mutex for reconnecting client

* fix tests

* remove log

* fix reconnection logic

* add retry-go

* ignore EOF errors

* set max retries to 3 in grpc server mock

* fix: reconnection logic

* remove unused method

* PR patches

* fix build
mathnogueira committed Nov 9, 2023
1 parent d2abf0e commit 3c24e6e
Showing 18 changed files with 219 additions and 38 deletions.
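
At a high level, the commit makes the agent client own its endpoint, guard its gRPC connection with a mutex, and redial through github.com/avast/retry-go when the stream drops. Below is a minimal sketch of that retry pattern only; the dial closure and its failure count are placeholders for illustration, not code from this commit.

package main

import (
	"fmt"
	"time"

	retry "github.com/avast/retry-go"
)

func main() {
	attempts := 0

	// Placeholder dial step: fails twice, then succeeds, standing in for
	// the agent's connect call once the server is reachable again.
	dial := func() error {
		attempts++
		if attempts < 3 {
			return fmt.Errorf("connection refused (attempt %d)", attempts)
		}
		return nil
	}

	// Same shape as (*Client).reconnect in the diff below: retry the dial a
	// bounded number of times with a delay between attempts.
	err := retry.Do(dial,
		retry.Attempts(6),
		retry.Delay(1*time.Second),
	)
	fmt.Println("reconnected:", err == nil)
}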
69 changes: 66 additions & 3 deletions agent/client/client.go
@@ -2,18 +2,30 @@ package client

import (
"context"
"errors"
"fmt"
"io"
"log"
"os"
"strings"
"sync"
"time"

retry "github.com/avast/retry-go"
"github.com/kubeshop/tracetest/agent/proto"
"google.golang.org/grpc"
)

const (
ReconnectRetryAttempts = 6
ReconnectRetryAttemptDelay = 1 * time.Second
defaultPingPeriod = 30 * time.Second
)

type Config struct {
APIKey string
AgentName string
APIKey string
AgentName string
PingPeriod time.Duration
}

type SessionConfig struct {
@@ -22,6 +34,8 @@ type SessionConfig struct {
}

type Client struct {
mutex sync.Mutex
endpoint string
conn *grpc.ClientConn
config Config
sessionConfig *SessionConfig
@@ -69,7 +83,10 @@ func (c *Client) Start(ctx context.Context) error {
return err
}

c.startHearthBeat(ctx)
err = c.startHearthBeat(ctx)
if err != nil {
return err
}

return nil
}
@@ -144,3 +161,49 @@ func (c *Client) getName() (string, error) {
func isCancelledError(err error) bool {
return err != nil && strings.Contains(err.Error(), "context canceled")
}

func (c *Client) reconnect() error {
// connection is not working. We need to reconnect
err := retry.Do(func() error {
return c.connect(context.Background())
}, retry.Attempts(ReconnectRetryAttempts), retry.Delay(ReconnectRetryAttemptDelay))

if err != nil {
return fmt.Errorf("could not reconnect to server: %w", err)
}

return c.Start(context.Background())
}

func (c *Client) handleDisconnectionError(inputErr error) (bool, error) {
if !isConnectionError(inputErr) {
// if it's nil or any error other than the one we care about, return it and let the caller handle it
return false, inputErr
}

err := retry.Do(func() error {
return c.reconnect()
})

if err != nil {
log.Fatal(err)
}

return true, nil
}

func isConnectionError(err error) bool {
return err != nil && strings.Contains(err.Error(), "connection refused")
}

func isEndOfFileError(err error) bool {
if err == nil {
return false
}

if isEof := errors.Is(err, io.EOF); isEof {
return true
}

return strings.Contains(err.Error(), "EOF")
}
37 changes: 27 additions & 10 deletions agent/client/connector.go
@@ -13,16 +13,28 @@ import (
)

func Connect(ctx context.Context, endpoint string, opts ...Option) (*Client, error) {
conn, err := connect(ctx, endpoint)
if err != nil {
return nil, err
config := Config{
PingPeriod: defaultPingPeriod,
}

client := &Client{
endpoint: endpoint,
config: config,
triggerListener: triggerListener,
pollListener: pollListener,
shutdownListener: shutdownListener,
dataStoreConnectionListener: dataStoreConnectionListener,
}

client := &Client{conn: conn}
for _, opt := range opts {
opt(client)
}

err := client.connect(ctx)
if err != nil {
return nil, err
}

return client, nil
}

@@ -42,25 +54,30 @@ var retryPolicy = `{
}]
}`

func connect(ctx context.Context, endpoint string) (*grpc.ClientConn, error) {
func (c *Client) connect(ctx context.Context) error {
c.mutex.Lock()
defer c.mutex.Unlock()

ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()

transportCredentials, err := getTransportCredentialsForEndpoint(endpoint)
transportCredentials, err := getTransportCredentialsForEndpoint(c.endpoint)
if err != nil {
return nil, fmt.Errorf("could not get transport credentials: %w", err)
return fmt.Errorf("could not get transport credentials: %w", err)
}

conn, err := grpc.DialContext(
ctx, endpoint,
ctx, c.endpoint,
grpc.WithTransportCredentials(transportCredentials),
grpc.WithDefaultServiceConfig(retryPolicy),
grpc.WithIdleTimeout(0), // disable grpc idle timeout
)
if err != nil {
return nil, fmt.Errorf("could not connect to server: %w", err)
return fmt.Errorf("could not connect to server: %w", err)
}

return conn, nil
c.conn = conn
return nil
}

func getTransportCredentialsForEndpoint(endpoint string) (credentials.TransportCredentials, error) {
23 changes: 23 additions & 0 deletions agent/client/default_listeners.go
@@ -0,0 +1,23 @@
package client

import (
"context"

"github.com/kubeshop/tracetest/agent/proto"
)

func triggerListener(_ context.Context, _ *proto.TriggerRequest) error {
return nil
}

func pollListener(_ context.Context, _ *proto.PollingRequest) error {
return nil
}

func shutdownListener(_ context.Context, _ *proto.ShutdownRequest) error {
return nil
}

func dataStoreConnectionListener(_ context.Context, _ *proto.DataStoreConnectionTestRequest) error {
return nil
}
34 changes: 27 additions & 7 deletions agent/client/mocks/grpc_server.go
@@ -7,6 +7,8 @@ import (
"net"
"sync"

"github.com/avast/retry-go"
"github.com/kubeshop/tracetest/agent/client"
"github.com/kubeshop/tracetest/agent/proto"
"google.golang.org/grpc"
)
@@ -21,6 +23,8 @@ type GrpcServerMock struct {

lastTriggerResponse *proto.TriggerResponse
lastPollingResponse *proto.PollingResponse

server *grpc.Server
}

func NewGrpcServer() *GrpcServerMock {
@@ -33,7 +37,12 @@ func NewGrpcServer() *GrpcServerMock {
var wg sync.WaitGroup
wg.Add(1)

go server.start(&wg)
err := retry.Do(func() error {
return server.start(&wg, 0)
}, retry.Attempts(client.ReconnectRetryAttempts), retry.Delay(client.ReconnectRetryAttemptDelay))
if err != nil {
log.Fatal(err)
}

wg.Wait()

@@ -44,21 +53,28 @@ func (s *GrpcServerMock) Addr() string {
return fmt.Sprintf("localhost:%d", s.port)
}

func (s *GrpcServerMock) start(wg *sync.WaitGroup) {
lis, err := net.Listen("tcp", ":0")
func (s *GrpcServerMock) start(wg *sync.WaitGroup, port int) error {
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
log.Fatalf("failed to listen: %v", err)
return fmt.Errorf("failed to listen: %w", err)
}

s.port = lis.Addr().(*net.TCPAddr).Port

server := grpc.NewServer()
proto.RegisterOrchestratorServer(server, s)

s.server = server

wg.Done()
if err := server.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}

go func() {
if err := server.Serve(lis); err != nil {
log.Fatal("failed to serve: %w", err)
}
}()

return nil
}

func (s *GrpcServerMock) Connect(ctx context.Context, req *proto.ConnectRequest) (*proto.AgentConfiguration, error) {
@@ -166,3 +182,7 @@ func (s *GrpcServerMock) TerminateConnection(reason string) {
Reason: reason,
}
}

func (s *GrpcServerMock) Stop() {
s.server.Stop()
}
8 changes: 8 additions & 0 deletions agent/client/options.go
@@ -1,5 +1,7 @@
package client

import "time"

type Option func(*Client)

func WithAPIKey(apiKey string) Option {
@@ -13,3 +15,9 @@ func WithAgentName(name string) Option {
c.config.AgentName = name
}
}

func WithPingPeriod(period time.Duration) Option {
return func(c *Client) {
c.config.PingPeriod = period
}
}
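
For orientation, a caller would wire the new WithPingPeriod option in roughly like this; the endpoint, API key, agent name, and ten-second ping period below are illustrative placeholders, not values from this commit.

package main

import (
	"context"
	"log"
	"time"

	"github.com/kubeshop/tracetest/agent/client"
)

func main() {
	ctx := context.Background()

	// Placeholder endpoint and credentials; WithPingPeriod is the option
	// added by this commit, the other two already existed.
	c, err := client.Connect(ctx, "localhost:8091",
		client.WithAPIKey("my-api-key"),
		client.WithAgentName("my-agent"),
		client.WithPingPeriod(10*time.Second),
	)
	if err != nil {
		log.Fatal(err)
	}

	// Start now returns an error if the heartbeat cannot be started
	// (see the client.go hunk above).
	if err := c.Start(ctx); err != nil {
		log.Fatal(err)
	}
}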
14 changes: 10 additions & 4 deletions agent/client/workflow_listen_for_ds_connection_tests.go
@@ -2,10 +2,9 @@ package client

import (
"context"
"errors"
"fmt"
"io"
"log"
"time"

"github.com/kubeshop/tracetest/agent/proto"
)
@@ -22,12 +21,19 @@ func (c *Client) startDataStoreConnectionTestListener(ctx context.Context) error
for {
req := proto.DataStoreConnectionTestRequest{}
err := stream.RecvMsg(&req)
if errors.Is(err, io.EOF) || isCancelledError(err) {
if isEndOfFileError(err) || isCancelledError(err) {
return
}

reconnected, err := c.handleDisconnectionError(err)
if reconnected {
return
}

if err != nil {
log.Fatal("could not get message from ds connection stream: %w", err)
log.Println("could not get message from data store connection stream: %w", err)
time.Sleep(1 * time.Second)
continue
}

// TODO: Get ctx from request
14 changes: 10 additions & 4 deletions agent/client/workflow_listen_for_poll_requests.go
@@ -2,10 +2,9 @@ package client

import (
"context"
"errors"
"fmt"
"io"
"log"
"time"

"github.com/kubeshop/tracetest/agent/proto"
)
@@ -22,12 +21,19 @@ func (c *Client) startPollerListener(ctx context.Context) error {
for {
resp := proto.PollingRequest{}
err := stream.RecvMsg(&resp)
if errors.Is(err, io.EOF) || isCancelledError(err) {
if isEndOfFileError(err) || isCancelledError(err) {
return
}

reconnected, err := c.handleDisconnectionError(err)
if reconnected {
return
}

if err != nil {
log.Fatal("could not get message from polling stream: %w", err)
log.Println("could not get message from poller stream: %w", err)
time.Sleep(1 * time.Second)
continue
}

// TODO: Get ctx from request
1 change: 1 addition & 0 deletions agent/client/workflow_listen_for_poll_requests_test.go
@@ -14,6 +14,7 @@ import (

func TestPollWorkflow(t *testing.T) {
server := mocks.NewGrpcServer()
defer server.Stop()

client, err := client.Connect(context.Background(), server.Addr())
require.NoError(t, err)
