Skip to content

Commit

Permalink
ACS: Handle heartbeats vs idle correctly
Browse files Browse the repository at this point in the history
Previously a heartbeat message was required to consider the channel
active.
In reality, heartbeat messages were only sent when the channel was
inactive and no other messages were being sent.
This avoids treating a lack of heartbeats as an idle channel and closing
it unless there are also no other messages.

In addition, this tweaks how backoffs are handled (time, resets, etc) a
bit to be more forgiving to these sorts of issues (where the connection
is lost, but can be re-established).

Relates to aws#103
  • Loading branch information
euank committed Jun 9, 2015
1 parent 3393896 commit 943ba89
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 22 deletions.
83 changes: 66 additions & 17 deletions agent/acs/handler/acs_handler.go
Expand Up @@ -21,6 +21,8 @@ import (
"strconv"
"time"

"golang.org/x/net/context"

acsclient "github.com/aws/amazon-ecs-agent/agent/acs/client"
"github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs"
"github.com/aws/amazon-ecs-agent/agent/acs/update_handler"
Expand Down Expand Up @@ -51,35 +53,85 @@ const payloadMessageBufferSize = 10
// the last sequence number successfully handled.
var SequenceNumber = utilatomic.NewIncreasingInt64(1)

// StartSessionArguments bundles everything StartSession needs besides the
// context. It exists so that callers can pass arguments by name; there are
// too many of them for a positional parameter list to stay manageable.
//
// NOTE: the tests in this package build this struct with a positional
// (unkeyed) literal, so reordering the fields would silently break them.
type StartSessionArguments struct {
// ContainerInstanceArn is passed to DiscoverPollEndpoint and is used when
// building the ACS websocket URL.
ContainerInstanceArn string
// CredentialProvider is handed to the ACS client constructor.
CredentialProvider credentials.AWSCredentialProvider
// Config supplies the cluster name and AWS region used for the session.
Config *config.Config
// TaskEngine is given to the payload and agent-update handlers and is used
// when building the ACS websocket URL.
TaskEngine engine.TaskEngine
// ECSClient is used to discover the poll endpoint before each connection
// attempt and is given to the payload handler.
ECSClient api.ECSClient
// StateManager is given to the payload and agent-update handlers.
StateManager statemanager.StateManager
// AcceptInvalidCert is forwarded to the ACS client constructor; presumably
// it disables TLS certificate verification (confirm in acsclient.New).
AcceptInvalidCert bool
}

// StartSession creates a session with ACS and handles requests using the passed
// in arguments.
func StartSession(containerInstanceArn string, credentialProvider credentials.AWSCredentialProvider, cfg *config.Config, taskEngine engine.TaskEngine, ecsclient api.ECSClient, stateManager statemanager.StateManager, acceptInvalidCert bool) error {
backoff := utils.NewSimpleBackoff(time.Second, 2*time.Minute, 0.2, 2)
func StartSession(ctx context.Context, args StartSessionArguments) error {
ecsclient := args.ECSClient
cfg := args.Config
backoff := utils.NewSimpleBackoff(250*time.Millisecond, 2*time.Minute, 0.2, 1.5)
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}

acsError := func() error {
acsEndpoint, err := ecsclient.DiscoverPollEndpoint(containerInstanceArn)
acsEndpoint, err := ecsclient.DiscoverPollEndpoint(args.ContainerInstanceArn)
if err != nil {
log.Error("Unable to discover poll endpoint", "err", err)
return err
}
log.Debug("Connecting to ACS endpoint " + acsEndpoint)

url := AcsWsUrl(acsEndpoint, cfg.Cluster, containerInstanceArn, taskEngine)
url := AcsWsUrl(acsEndpoint, cfg.Cluster, args.ContainerInstanceArn, args.TaskEngine)

client := acsclient.New(url, cfg.AWSRegion, credentialProvider, acceptInvalidCert)
client := acsclient.New(url, cfg.AWSRegion, args.CredentialProvider, args.AcceptInvalidCert)
defer client.Close()

client.AddRequestHandler(payloadMessageHandler(client, cfg.Cluster, containerInstanceArn, taskEngine, ecsclient, stateManager))
client.AddRequestHandler(heartbeatHandler(client))
timer := ttime.AfterFunc(utils.AddJitter(heartbeatTimeout, heartbeatJitter), func() {
log.Warn("ACS Connection hasn't had any activity for too long; closing connection")
closeErr := client.Close()
if closeErr != nil {
log.Warn("Error disconnecting: " + closeErr.Error())
}
})
defer timer.Stop()
// Any message from the server resets the disconnect timeout
client.SetAnyRequestHandler(anyMessageHandler(timer))
client.AddRequestHandler(payloadMessageHandler(client, cfg.Cluster, args.ContainerInstanceArn, args.TaskEngine, args.ECSClient, args.StateManager))
// Ignore heartbeat messages; anyMessageHandler gets 'em
client.AddRequestHandler(func(*ecsacs.HeartbeatMessage) {})

updater.AddAgentUpdateHandlers(client, cfg, stateManager, taskEngine)
updater.AddAgentUpdateHandlers(client, cfg, args.StateManager, args.TaskEngine)

err = client.Connect()
if err != nil {
log.Error("Error connecting to ACS: " + err.Error())
return err
}
return client.Serve()
ttime.AfterFunc(utils.AddJitter(heartbeatTimeout, heartbeatJitter), func() {
// If we do not have an error connecting and remain connected for at
// least 5 or so minutes, reset the backoff. This prevents disconnect
// errors that only happen infrequently from damaging the
// reconnectability as significantly.
backoff.Reset()
})

serveErr := make(chan error, 1)
go func() {
serveErr <- client.Serve()
}()

select {
case <-ctx.Done():
return ctx.Err()
case err := <-serveErr:
return err
}
}()
if acsError == nil || acsError == io.EOF {
backoff.Reset()
Expand All @@ -90,14 +142,11 @@ func StartSession(containerInstanceArn string, credentialProvider credentials.AW
}
}

// heartbeatHandler starts a timer and listens for acs heartbeats. If there are
// none for unexpectedly long, it closes the passed in connection.
func heartbeatHandler(acsConnection io.Closer) func(*ecsacs.HeartbeatMessage) {
timer := time.AfterFunc(utils.AddJitter(heartbeatTimeout, heartbeatJitter), func() {
log.Debug("ACS Connection hasn't had a heartbeat in too long of a timeout; disconnecting")
acsConnection.Close()
})
return func(*ecsacs.HeartbeatMessage) {
// anyMessageHandler builds the callback that is invoked for every message the
// server sends, whatever its type. Receiving a message proves the connection is
// alive, so each invocation logs the activity and pushes the supplied timer
// back by a freshly jittered heartbeat timeout.
func anyMessageHandler(timer *time.Timer) func(interface{}) {
	onMessage := func(_ interface{}) {
		log.Debug("ACS activity occured")
		nextTimeout := utils.AddJitter(heartbeatTimeout, heartbeatJitter)
		timer.Reset(nextTimeout)
	}
	return onMessage
}
Expand Down
64 changes: 60 additions & 4 deletions agent/acs/handler/acs_handler_test.go
@@ -1,24 +1,30 @@
package handler_test

import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"

"golang.org/x/net/context"

"github.com/aws/amazon-ecs-agent/agent/acs/handler"
"github.com/aws/amazon-ecs-agent/agent/api/mocks"
"github.com/aws/amazon-ecs-agent/agent/config"
"github.com/aws/amazon-ecs-agent/agent/ecs_client/authv4/credentials"
"github.com/aws/amazon-ecs-agent/agent/engine/mocks"
"github.com/aws/amazon-ecs-agent/agent/statemanager"
"github.com/aws/amazon-ecs-agent/agent/utils/ttime"
"github.com/aws/amazon-ecs-agent/agent/version"
"github.com/gorilla/websocket"

"code.google.com/p/gomock/gomock"
)

const samplePayloadMessage = `{"type":"PayloadMessage","message":{"messageId":"123","tasks":[{"taskDefinitionAccountId":"123","containers":[{"environment":{},"name":"name","cpu":1,"essential":true,"memory":1,"portMappings":[],"overrides":"{}","image":"i","mountPoints":[],"volumesFrom":[]}],"version":"3","volumes":[],"family":"f","arn":"arn","desiredStatus":"RUNNING"}],"generatedAt":1,"clusterArn":"1","containerInstanceArn":"1","seqNum":1}}`

func TestAcsWsUrl(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
Expand Down Expand Up @@ -68,13 +74,13 @@ func TestHandlerReconnects(t *testing.T) {
t.Fatal(err)
}

ecsclient.EXPECT().DiscoverPollEndpoint("myArn").Return(server.URL, nil).AnyTimes()
ecsclient.EXPECT().DiscoverPollEndpoint("myArn").Return(server.URL, nil).Times(10)
taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes()

ctx, cancel := context.WithCancel(context.Background())
ended := make(chan bool, 1)
go func() {
handler.StartSession("myArn", credentials.NewCredentialProvider("", ""), &config.Config{Cluster: "someCluster"}, taskEngine, ecsclient, statemanager, true)
// This should never return
handler.StartSession(ctx, handler.StartSessionArguments{"myArn", credentials.NewCredentialProvider("", ""), &config.Config{Cluster: "someCluster"}, taskEngine, ecsclient, statemanager, true})
ended <- true
}()
start := time.Now()
Expand All @@ -88,9 +94,58 @@ func TestHandlerReconnects(t *testing.T) {

select {
case <-ended:
t.Fatal("Should never stop session")
t.Fatal("Should not have stopped session")
default:
}
cancel()
<-ended
}

// TestHeartbeatOnlyWhenIdle checks that ordinary server messages are enough to
// keep an ACS session alive. It sends ten payload messages, warping the mock
// clock forward one minute after each, and asserts that the session neither
// ends nor reconnects (DiscoverPollEndpoint must be called exactly once) even
// though no heartbeat message is ever sent.
func TestHeartbeatOnlyWhenIdle(t *testing.T) {
	testTime := ttime.NewTestTime()
	ttime.SetTime(testTime)
	// SetTime swaps the package-wide clock; put the normal implementation back
	// so the mock clock does not leak into tests that run after this one.
	defer ttime.SetTime(new(ttime.DefaultTime))
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	taskEngine := mock_engine.NewMockTaskEngine(ctrl)
	ecsclient := mock_api.NewMockECSClient(ctrl)
	statemanager := statemanager.NewNoopStateManager()

	// closeWS is never signalled: the mock server must keep the websocket open
	// for the whole test.
	closeWS := make(chan bool)
	server, serverIn, _, errChan, err := startMockAcsServer(t, closeWS)
	if err != nil {
		t.Fatal(err)
	}

	// We're testing that it does not reconnect here; must be the case
	ecsclient.EXPECT().DiscoverPollEndpoint("myArn").Return(server.URL, nil).Times(1)
	taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes()

	ctx, cancel := context.WithCancel(context.Background())
	ended := make(chan bool, 1)
	go func() {
		// Keyed fields keep this literal valid if the struct gains or reorders
		// fields, and match how agent.go calls StartSession.
		handler.StartSession(ctx, handler.StartSessionArguments{
			ContainerInstanceArn: "myArn",
			CredentialProvider:   credentials.NewCredentialProvider("", ""),
			Config:               &config.Config{Cluster: "someCluster"},
			TaskEngine:           taskEngine,
			ECSClient:            ecsclient,
			StateManager:         statemanager,
			AcceptInvalidCert:    true,
		})
		ended <- true
	}()

	// taskAdded is unbuffered, so each loop iteration below waits until the
	// handler has actually processed the payload it just sent.
	taskAdded := make(chan bool)
	taskEngine.EXPECT().AddTask(gomock.Any()).Do(func(interface{}) {
		taskAdded <- true
	}).Times(10)
	for i := 0; i < 10; i++ {
		serverIn <- samplePayloadMessage
		testTime.Warp(1 * time.Minute)
		<-taskAdded
	}

	// Non-blocking check: by now the session must still be running and the
	// server must not have reported an error.
	select {
	case <-ended:
		t.Fatal("Should not have stop session")
	case err := <-errChan:
		t.Fatal("Error should not have been returned from server", err)
	default:
	}
	// Cancelling the context is the only way the session is expected to end.
	cancel()
	<-ended
}

func startMockAcsServer(t *testing.T, closeWS <-chan bool) (*httptest.Server, chan<- string, <-chan string, <-chan error, error) {
Expand All @@ -105,6 +160,7 @@ func startMockAcsServer(t *testing.T, closeWS <-chan bool) (*httptest.Server, ch
<-closeWS
ws.WriteMessage(websocket.CloseMessage, nil)
ws.Close()
errChan <- io.EOF
}()
if err != nil {
errChan <- err
Expand Down
14 changes: 13 additions & 1 deletion agent/agent.go
Expand Up @@ -19,6 +19,8 @@ import (
"os"
"time"

"golang.org/x/net/context"

acshandler "github.com/aws/amazon-ecs-agent/agent/acs/handler"
"github.com/aws/amazon-ecs-agent/agent/api"
"github.com/aws/amazon-ecs-agent/agent/auth"
Expand Down Expand Up @@ -68,6 +70,8 @@ func _main() int {
return exitcodes.ExitSuccess
}

ctx := context.Background()

if err != nil {
log.Criticalf("Error loading config: %v", err)
// All required config values can be inferred from EC2 Metadata, so this error could be transient.
Expand Down Expand Up @@ -192,7 +196,15 @@ func _main() int {
go tcshandler.StartMetricsSession(telemetrySessionParams)

log.Info("Beginning Polling for updates")
err = acshandler.StartSession(containerInstanceArn, credentialProvider, cfg, taskEngine, client, stateManager, *acceptInsecureCert)
err = acshandler.StartSession(ctx, acshandler.StartSessionArguments{
AcceptInvalidCert: *acceptInsecureCert,
Config: cfg,
ContainerInstanceArn: containerInstanceArn,
CredentialProvider: credentialProvider,
ECSClient: client,
StateManager: stateManager,
TaskEngine: taskEngine,
})
if err != nil {
log.Criticalf("Unretriable error starting communicating with ACS: %v", err)
return exitcodes.ExitTerminal
Expand Down
11 changes: 11 additions & 0 deletions agent/utils/ttime/test_time.go
Expand Up @@ -54,6 +54,17 @@ func (t *TestTime) After(d time.Duration) <-chan time.Time {
return done
}

// AfterFunc arranges for f to be called after d and returns the timer that
// controls the call, taking time-warping into account.
//
// A real timer is armed for d. In parallel a helper goroutine sleeps for d in
// mock time and, if the timer is still pending when that sleep ends, fires it
// immediately. f therefore runs as soon as d has passed on either clock.
//
// The helper only fires a timer that is still pending: a timer the caller has
// stopped stays stopped, and a timer that already fired in real time does not
// run f a second time. A timer the caller has Reset is still pending and will
// be fired when the original mock-time duration elapses.
func (t *TestTime) AfterFunc(d time.Duration, f func()) *time.Timer {
	timer := time.AfterFunc(d, f)
	go func() {
		t.Sleep(d)
		// Stop reports true only if the timer was still waiting to fire; an
		// unconditional Reset(0) here would re-arm stopped or expired timers.
		if timer.Stop() {
			timer.Reset(0)
		}
	}()
	return timer
}

// Sleep sleeps the given duration in mock-time; that is to say that Warps will
// reduce the amount of time slept and LudicrousSpeed will cause instant
// success.
Expand Down
13 changes: 13 additions & 0 deletions agent/utils/ttime/ttime.go
Expand Up @@ -8,6 +8,7 @@ type Time interface {
Now() time.Time
Sleep(d time.Duration)
After(d time.Duration) <-chan time.Time
AfterFunc(d time.Duration, f func()) *time.Timer
}

// DefaultTime is a Time that behaves normally
Expand All @@ -30,6 +31,13 @@ func (*DefaultTime) After(d time.Duration) <-chan time.Time {
return time.After(d)
}

// AfterFunc delegates to the standard library's time.AfterFunc: once d has
// elapsed, f is run in its own goroutine. The returned Timer's Stop method can
// be used to cancel the pending call.
func (*DefaultTime) AfterFunc(d time.Duration, f func()) *time.Timer {
	realTimer := time.AfterFunc(d, f)
	return realTimer
}

// SetTime configures what 'Time' implementation to use for each of the
// package-level methods.
func SetTime(t Time) {
Expand All @@ -55,3 +63,8 @@ func Since(t time.Time) time.Duration {
func After(t time.Duration) <-chan time.Time {
return _time.After(t)
}

// AfterFunc forwards to the AfterFunc method of the currently configured Time
// implementation and returns the timer it produces.
func AfterFunc(d time.Duration, f func()) *time.Timer {
	timer := _time.AfterFunc(d, f)
	return timer
}
17 changes: 17 additions & 0 deletions agent/wsclient/client.go
Expand Up @@ -83,6 +83,11 @@ type RequestHandler interface{}
// ClientServer is a combined client and server for the backend websocket connection
type ClientServer interface {
AddRequestHandler(RequestHandler)
// SetAnyRequestHandler takes a function with the signature 'func(i
// interface{})' and calls it with every message the server passes down.
// Only a single 'AnyRequestHandler' will be active at a given time for a
// ClientServer
SetAnyRequestHandler(RequestHandler)
MakeRequest(input interface{}) error
Connect() error
Serve() error
Expand All @@ -101,6 +106,10 @@ type ClientServerImpl struct {
// form:
// "FooMessage": func(message *ecsacs.FooMessage)
RequestHandlers map[string]RequestHandler
// AnyRequestHandler is a request handler that, if set, is called on every
// message with said message. It will be called before a RequestHandler is
// called. It must take a single interface{} argument.
AnyRequestHandler RequestHandler
// URL is the full url to the backend, including path, querystring, and so on.
URL string
ClientServer
Expand Down Expand Up @@ -184,6 +193,10 @@ func (cs *ClientServerImpl) AddRequestHandler(f RequestHandler) {
cs.RequestHandlers[firstArgTypeStr] = f
}

// SetAnyRequestHandler installs f as the handler that is called with every
// message received from the server, replacing any handler set previously.
// f is stored as-is without validation; handleMessage invokes it through
// reflection with the decoded message as its only argument, so f must be a
// function taking a single interface{} argument.
func (cs *ClientServerImpl) SetAnyRequestHandler(f RequestHandler) {
cs.AnyRequestHandler = f
}

// MakeRequest makes a request using the given input. Note, the input *MUST* be
// a pointer to a valid backend type that this client recognises
func (cs *ClientServerImpl) MakeRequest(input interface{}) error {
Expand Down Expand Up @@ -262,6 +275,10 @@ func (cs *ClientServerImpl) handleMessage(data []byte) {
return
}

if cs.AnyRequestHandler != nil {
reflect.ValueOf(cs.AnyRequestHandler).Call([]reflect.Value{reflect.ValueOf(typedMessage)})
}

if handler, ok := cs.RequestHandlers[typeStr]; ok {
reflect.ValueOf(handler).Call([]reflect.Value{reflect.ValueOf(typedMessage)})
} else {
Expand Down
8 changes: 8 additions & 0 deletions agent/wsclient/mock/client.go
Expand Up @@ -89,3 +89,11 @@ func (_m *MockClientServer) Serve() error {
func (_mr *_MockClientServerRecorder) Serve() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "Serve")
}

// SetAnyRequestHandler is the mock implementation of
// ClientServer.SetAnyRequestHandler: it reports the call and its argument to
// the mock controller.
func (_m *MockClientServer) SetAnyRequestHandler(_param0 wsclient.RequestHandler) {
_m.ctrl.Call(_m, "SetAnyRequestHandler", _param0)
}

// SetAnyRequestHandler records an expected call to SetAnyRequestHandler on the
// mock, with arg0 as the expected argument (or matcher), and returns the
// resulting gomock.Call.
func (_mr *_MockClientServerRecorder) SetAnyRequestHandler(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "SetAnyRequestHandler", arg0)
}

0 comments on commit 943ba89

Please sign in to comment.