Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: make agent reconnect in case of connection dropping #3342

Merged
merged 13 commits into from
Nov 9, 2023
23 changes: 20 additions & 3 deletions agent/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package client
import (
"context"
"fmt"
"log"
"os"
"strings"
"time"
Expand All @@ -12,8 +13,9 @@ import (
)

type Config struct {
APIKey string
AgentName string
APIKey string
AgentName string
PingPeriod time.Duration
}

type SessionConfig struct {
Expand All @@ -22,6 +24,7 @@ type SessionConfig struct {
}

type Client struct {
endpoint string
conn *grpc.ClientConn
config Config
sessionConfig *SessionConfig
Expand Down Expand Up @@ -69,7 +72,10 @@ func (c *Client) Start(ctx context.Context) error {
return err
}

c.startHearthBeat(ctx)
err = c.startHearthBeat(ctx)
if err != nil {
return err
}

return nil
}
Expand Down Expand Up @@ -144,3 +150,14 @@ func (c *Client) getName() (string, error) {
func isCancelledError(err error) bool {
return err != nil && strings.Contains(err.Error(), "context canceled")
}

func (c *Client) reconnect() error {
// connection is not working. We need to reconnect
newConn, err := connect(context.Background(), c.endpoint)
if err != nil {
log.Fatal(fmt.Errorf("could not reconnect to server: %w", err))
}

c.conn = newConn
return nil
}
41 changes: 41 additions & 0 deletions agent/client/client_reconnection_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package client_test

import (
"context"
"testing"
"time"

"github.com/kubeshop/tracetest/agent/client"
"github.com/kubeshop/tracetest/agent/client/mocks"
"github.com/kubeshop/tracetest/agent/proto"
"github.com/stretchr/testify/require"
"gotest.tools/v3/assert"
)

func TestClientReconnection(t *testing.T) {
server := mocks.NewGrpcServer()

client, err := client.Connect(context.Background(), server.Addr(), client.WithPingPeriod(time.Second))
require.NoError(t, err)

client.Start(context.Background())

err = client.Start(context.Background())
require.NoError(t, err)

server.Stop()

err = client.SendTriggerResponse(context.Background(), &proto.TriggerResponse{RequestID: "my-request-id"})
require.NotNil(t, err)

time.Sleep(2 * time.Second)

server.Restart()

err = client.SendTriggerResponse(context.Background(), &proto.TriggerResponse{RequestID: "my-request-id"})
require.NoError(t, err)

triggerResponse := server.GetLastTriggerResponse()
require.NotNil(t, triggerResponse)
assert.Equal(t, "my-request-id", triggerResponse.RequestID)
}
7 changes: 6 additions & 1 deletion agent/client/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ func Connect(ctx context.Context, endpoint string, opts ...Option) (*Client, err
return nil, err
}

client := &Client{conn: conn}
config := Config{
PingPeriod: 30 * time.Second,
}

client := &Client{endpoint: endpoint, conn: conn, config: config}
for _, opt := range opts {
opt(client)
}
Expand Down Expand Up @@ -55,6 +59,7 @@ func connect(ctx context.Context, endpoint string) (*grpc.ClientConn, error) {
ctx, endpoint,
grpc.WithTransportCredentials(transportCredentials),
grpc.WithDefaultServiceConfig(retryPolicy),
grpc.WithIdleTimeout(0), // disable grpc idle timeout
)
if err != nil {
return nil, fmt.Errorf("could not connect to server: %w", err)
Expand Down
23 changes: 20 additions & 3 deletions agent/client/mocks/grpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ type GrpcServerMock struct {

lastTriggerResponse *proto.TriggerResponse
lastPollingResponse *proto.PollingResponse

server *grpc.Server
}

func NewGrpcServer() *GrpcServerMock {
Expand All @@ -33,7 +35,8 @@ func NewGrpcServer() *GrpcServerMock {
var wg sync.WaitGroup
wg.Add(1)

go server.start(&wg)
anyPort := 0
go server.start(&wg, anyPort)

wg.Wait()

Expand All @@ -44,8 +47,8 @@ func (s *GrpcServerMock) Addr() string {
return fmt.Sprintf("localhost:%d", s.port)
}

func (s *GrpcServerMock) start(wg *sync.WaitGroup) {
lis, err := net.Listen("tcp", ":0")
func (s *GrpcServerMock) start(wg *sync.WaitGroup, port int) {
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
Expand All @@ -55,6 +58,8 @@ func (s *GrpcServerMock) start(wg *sync.WaitGroup) {
server := grpc.NewServer()
proto.RegisterOrchestratorServer(server, s)

s.server = server

wg.Done()
if err := server.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
Expand Down Expand Up @@ -166,3 +171,15 @@ func (s *GrpcServerMock) TerminateConnection(reason string) {
Reason: reason,
}
}

func (s *GrpcServerMock) Restart() {
var wg sync.WaitGroup
wg.Add(1)
go s.start(&wg, s.port)

wg.Wait()
}

func (s *GrpcServerMock) Stop() {
s.server.Stop()
}
8 changes: 8 additions & 0 deletions agent/client/options.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package client

import "time"

type Option func(*Client)

func WithAPIKey(apiKey string) Option {
Expand All @@ -13,3 +15,9 @@ func WithAgentName(name string) Option {
c.config.AgentName = name
}
}

func WithPingPeriod(period time.Duration) Option {
return func(c *Client) {
c.config.PingPeriod = period
}
}
7 changes: 6 additions & 1 deletion agent/client/workflow_listen_for_ds_connection_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ func (c *Client) startDataStoreConnectionTestListener(ctx context.Context) error
}

if err != nil {
log.Fatal("could not get message from ds connection stream: %w", err)
c.reconnect()
}

if c.dataStoreConnectionListener == nil {
log.Println("warning: datastore connection listener is nil")
continue
}

// TODO: Get ctx from request
Expand Down
7 changes: 6 additions & 1 deletion agent/client/workflow_listen_for_poll_requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ func (c *Client) startPollerListener(ctx context.Context) error {
}

if err != nil {
log.Fatal("could not get message from polling stream: %w", err)
c.reconnect()
}

if c.pollListener == nil {
log.Println("warning: polling listener is nil")
continue
}

// TODO: Get ctx from request
Expand Down
7 changes: 6 additions & 1 deletion agent/client/workflow_listen_for_trigger_requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ func (c *Client) startTriggerListener(ctx context.Context) error {
}

if err != nil {
log.Fatal("could not get message from trigger stream: %w", err)
c.reconnect()
}

if c.triggerListener == nil {
log.Println("warning: trigger listener is nil")
continue
}

// TODO: get context from request
Expand Down
8 changes: 6 additions & 2 deletions agent/client/workflow_ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,15 @@ import (

func (c *Client) startHearthBeat(ctx context.Context) error {
client := proto.NewOrchestratorClient(c.conn)
ticker := time.NewTicker(2 * time.Minute)
ticker := time.NewTicker(c.config.PingPeriod)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the ping to 30s


go func() {
for range ticker.C {
client.Ping(ctx, c.sessionConfig.AgentIdentification)
_, err := client.Ping(ctx, c.sessionConfig.AgentIdentification)
if err != nil {
// Something is wrong with the connection
c.reconnect()
}
}
}()

Expand Down
7 changes: 6 additions & 1 deletion agent/client/workflow_shutdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ func (c *Client) startShutdownListener(ctx context.Context) error {
}

if err != nil {
log.Fatal("could not get shutdown listener: %w", err)
c.reconnect()
}

if c.shutdownListener == nil {
log.Println("warning: shutdown listener is nil")
continue
}

// TODO: get context from request
Expand Down