diff --git a/api/v1/lib/httpcli/httpsched/state.go b/api/v1/lib/httpcli/httpsched/state.go index f65565c6..b64318f5 100644 --- a/api/v1/lib/httpcli/httpsched/state.go +++ b/api/v1/lib/httpcli/httpsched/state.go @@ -220,6 +220,10 @@ func doSubscribe(ctx context.Context, ci callerInternal, call *stateCall) (mesos } func mustSubscribe(ctx context.Context, state *state) phase { + return mustSubscribe0(ctx, state, doSubscribe) +} + +func mustSubscribe0(ctx context.Context, state *state, doSubscribe func(context.Context, callerInternal, *stateCall) (string, context.CancelFunc)) phase { // (a) validate call = SUBSCRIBE if t := state.call.GetType(); t != scheduler.Call_SUBSCRIBE { state.call.err = apierrors.CodeUnsubscribed.Error("httpsched: expected SUBSCRIBE instead of " + t.String()) diff --git a/api/v1/lib/httpcli/httpsched/state_test.go b/api/v1/lib/httpcli/httpsched/state_test.go index 54a752c5..16601ee8 100644 --- a/api/v1/lib/httpcli/httpsched/state_test.go +++ b/api/v1/lib/httpcli/httpsched/state_test.go @@ -1,9 +1,12 @@ package httpsched import ( + "context" "errors" + "sync" "testing" + "github.com/mesos/mesos-go/api/v1/lib" "github.com/mesos/mesos-go/api/v1/lib/encoding" "github.com/mesos/mesos-go/api/v1/lib/extras/latch" "github.com/mesos/mesos-go/api/v1/lib/scheduler" @@ -53,3 +56,129 @@ func TestDisconnectionDecoder(t *testing.T) { t.Error("disconnect func was not called") } } + +func TestMustSubscribe(t *testing.T) { + subscribeCall := &scheduler.Call{Type: scheduler.Call_SUBSCRIBE} + type subscription struct { + resp mesos.Response + cancel context.CancelFunc + } + newSubscription := func(err error) subscription { + closed := make(chan struct{}) + var closeOnce sync.Once + cancel := func() { closeOnce.Do(func() { close(closed) }) } + resp := &mesos.ResponseWrapper{ + Decoder: encoding.DecoderFunc(func(encoding.Unmarshaler) (_ error) { + if err == context.Canceled { + select { + case <-closed: + return err + default: + return + } + } + if err != nil { + return err + } + return + }), + Closer: mesos.CloseFunc(func() (_ error) { cancel(); return }), + } + return subscription{ + cancel: cancel, + resp: resp, + } + } + for ti, tc := range map[string]struct { + state *state + streamID string + sub subscription + un encoding.Unmarshaler + //-- wants: + wantsDisconnected bool + }{ + "<>": { + state: &state{call: &stateCall{}}, + sub: subscription{cancel: func() {}}, + wantsDisconnected: true, + }, + "subFailed": { + state: &state{call: &stateCall{Call: subscribeCall}}, + sub: subscription{cancel: func() {}}, + wantsDisconnected: true, + }, + "subWorkedDecoderCanceled": { + state: &state{ + call: &stateCall{Call: subscribeCall}, + client: &client{}, + notifyQueue: make(chan Notification, 1)}, + streamID: "1", + sub: newSubscription(context.Canceled), + un: &scheduler.Event{}, + // response decoder will not return context canceled unless the disconnector has been invoked + }, + "subWorkedDecoderDeadlineExceeded": { + state: &state{ + call: &stateCall{Call: subscribeCall}, + client: &client{}, + notifyQueue: make(chan Notification, 1)}, + streamID: "1", + sub: newSubscription(context.DeadlineExceeded), + }, + "subWorkedDecoderBadObject": { + state: &state{ + call: &stateCall{Call: subscribeCall}, + client: &client{}, + notifyQueue: make(chan Notification, 1)}, + streamID: "1", + sub: newSubscription(nil), + un: &scheduler.Call{}, + }, + "subWorkedDecoderSchedulerError": { + state: &state{ + call: &stateCall{Call: subscribeCall}, + client: &client{}, + notifyQueue: make(chan Notification, 1)}, + streamID: "1", + sub: newSubscription(nil), + un: &scheduler.Event{Type: scheduler.Event_ERROR}, + }, + } { + t.Run(ti, func(t *testing.T) { + p := mustSubscribe0(context.Background(), tc.state, + func(_ context.Context, _ callerInternal, cl *stateCall) (string, context.CancelFunc) { + cl.resp = tc.sub.resp + return tc.streamID, tc.sub.cancel + }) + if tc.wantsDisconnected != p.isDisconnected() { + if tc.wantsDisconnected { + t.Fatal("unexpectedly disconnected") + } + t.Fatal("expected to be disconnected but was not") + } + if tc.wantsDisconnected { + return + } + + // check that state disconnector() doesn't change the phase + tc.state.fn = p + tc.state.disconnector() + if tc.state.fn.isDisconnected() { + t.Fatal("disconnector() incorrectly transitioned phase to disconnected") + } + if tc.state.streamID != tc.streamID { + t.Fatalf("expected stream %q instead of %q", tc.streamID, tc.state.streamID) + } + + // fake response whose decoder disconnects for various reasons + err := tc.state.call.resp.Decode(tc.un) + if _, ok := tc.un.(*scheduler.Event); err == nil && !ok { + t.Fatal("expected error but got none") + } + t.Log(err) + if !tc.state.fn.isDisconnected() { + t.Fatal("disconnector() incorrectly transitioned phase to disconnected") + } + }) + } +}