diff --git a/stan.go b/stan.go index f925255..db3e256 100644 --- a/stan.go +++ b/stan.go @@ -331,6 +331,7 @@ type conn struct { pubNUID *nuid.NUID // NUID generator for published messages. connLostCB ConnectionLostHandler closed bool + fullyClosed bool ping pingInfo } @@ -672,21 +673,17 @@ func (sc *conn) Close() error { sc.Lock() defer sc.Unlock() - if sc.closed { - // We are already closed. + // If we are fully closed, simply return. + if sc.fullyClosed { return nil } - // Signals we are closed. - sc.closed = true - - // Capture for NATS calls below. - if sc.ncOwned { - defer sc.nc.Close() + // If this is the very first Close() call, do some internal cleanup, + // otherwise, simply send the close protocol message. + if !sc.closed { + sc.closed = true + sc.cleanupOnClose(ErrConnectionClosed) } - // Now close ourselves. - sc.cleanupOnClose(ErrConnectionClosed) - req := &pb.CloseRequest{ClientID: sc.clientID} b, _ := req.Marshal() reply, err := sc.nc.Request(sc.closeRequests, b, sc.opts.ConnectTimeout) @@ -701,6 +698,11 @@ func (sc *conn) Close() error { if err != nil { return err } + // As long as we got a valid response, we consider the connection fully closed. + sc.fullyClosed = true + if sc.ncOwned { + sc.nc.Close() + } if cr.Error != "" { return errors.New(cr.Error) } @@ -900,14 +902,17 @@ func (sc *conn) processMsg(raw *nats.Msg) { msg.Sub = sub sub.RLock() + if sub.closed { + sub.RUnlock() + return + } cb := sub.cb ackSubject := sub.ackInbox isManualAck := sub.opts.ManualAcks - subsc := sub.sc // Can be nil if sub has been unsubscribed. sub.RUnlock() // Perform the callback - if cb != nil && subsc != nil { + if cb != nil { cb(msg) } diff --git a/stan_test.go b/stan_test.go index 1c0b734..03a9bc7 100644 --- a/stan_test.go +++ b/stan_test.go @@ -1008,6 +1008,83 @@ func TestClose(t *testing.T) { } } +func TestConnCloseError(t *testing.T) { + s := RunServer(clusterName) + defer s.Shutdown() + + sc := NewDefaultConnection(t) + defer sc.Close() + + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + // Alter the close subject so the request fails + scc := sc.(*conn) + scc.Lock() + closeSubj := scc.closeRequests + scc.closeRequests = "dummy" + scc.Unlock() + if err := sc.Close(); err == nil { + t.Fatal("Expected error, got none") + } + + checkInternalConnClosed := func(expectedClosed bool) { + t.Helper() + scc.RLock() + defer scc.RUnlock() + closed := scc.nc.IsClosed() + if expectedClosed && !closed { + t.Fatalf("Expected internal NATS connection to be closed, but it wasn't") + } else if !expectedClosed && closed { + t.Fatalf("Expected internal NATS connection to be not be closed, but it was") + } + } + // Internal NATS connection should not have been closed + checkInternalConnClosed(false) + + // Now setup a subscription to check if library is sending the close protocol. + crsub, err := nc.SubscribeSync(closeSubj) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + // Flush since this is a different connection than the one used by sc. + nc.Flush() + + // We should be able to call Close() again + if err := sc.Close(); err == nil { + t.Fatal("Expected error, did not get one") + } + // Connection still not closed + checkInternalConnClosed(false) + + // Fix close subject + scc.Lock() + scc.closeRequests = closeSubj + scc.Unlock() + // Now close should work. + if err := sc.Close(); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + // Now internal connection should have been closed + checkInternalConnClosed(true) + + // Check protocol was sent + if _, err := crsub.NextMsg(time.Second); err != nil { + t.Fatalf("Did not get close protocol: %v", err) + } + // Now, another call to Close() should just return nil + if err := sc.Close(); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + // And this time, no protocol should be sent + if _, err := crsub.NextMsg(100 * time.Millisecond); err == nil { + t.Fatal("Close protocol should not have been sent") + } +} + func TestDoubleClose(t *testing.T) { s := RunServer(clusterName) defer s.Shutdown() @@ -2737,6 +2814,67 @@ func TestSubTimeout(t *testing.T) { } } +func TestSubCloseError(t *testing.T) { + s := RunServer(clusterName) + defer s.Shutdown() + + sc := NewDefaultConnection(t) + defer sc.Close() + + sub, err := sc.Subscribe("foo", func(_ *Msg) {}) + if err != nil { + t.Fatalf("Error on sub: %v", err) + } + + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatalf("Error on connect: ") + } + defer nc.Close() + + scc := sc.(*conn) + scc.Lock() + closeSubj := scc.subCloseRequests + // alter the subCloseRequests so that the sub close fails + scc.subCloseRequests = "dummy" + scc.Unlock() + + // Now setup a subscription to check if library is sending the close protocol. + crsub, err := nc.SubscribeSync(closeSubj) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + // Flush since this is a different connection than the one used by sc. + nc.Flush() + + if err := sub.Close(); err == nil { + t.Fatal("Expected error, got none") + } + + // Fix close subject + scc.Lock() + scc.subCloseRequests = closeSubj + scc.Unlock() + + // Try again, it should work + if err := sub.Close(); err != nil { + t.Fatalf("Error on close: %v", err) + } + // Check protocol was sent + if _, err := crsub.NextMsg(time.Second); err != nil { + t.Fatalf("Did not get close protocol: %v", err) + } + + // Now another call to Close() should return BadSubscription + if err := sub.Close(); err != ErrBadSubscription { + t.Fatalf("Expected %v, got %v", ErrBadSubscription, err) + } + // And this time, no protocol should be sent + if _, err := crsub.NextMsg(100 * time.Millisecond); err == nil { + t.Fatal("Close protocol should not have been sent") + } +} + func TestNatsOptions(t *testing.T) { snopts := natsd.DefaultTestOptions snopts.Username = "foo" diff --git a/sub.go b/sub.go index e612c39..f078e75 100644 --- a/sub.go +++ b/sub.go @@ -101,6 +101,12 @@ type subscription struct { inboxSub *nats.Subscription opts SubscriptionOptions cb MsgHandler + // closed indicate that sub.Close() was invoked, but fullyClosed + // is only set if the close/unsub protocol was successful. This + // allow the user to be able to call sub.Close() several times + // in case an error is returned. + closed bool + fullyClosed bool } // SubscriptionOption is a function on the options for a subscription. @@ -414,15 +420,22 @@ func (sub *subscription) SetPendingLimits(msgLimit, bytesLimit int) error { // given boolean. func (sub *subscription) closeOrUnsubscribe(doClose bool) error { sub.Lock() - sc := sub.sc - if sc == nil { - // Already closed. + // If we are fully closed, return error indicating that the + // subscription is invalid. Note that conn.Close() in this case + // returns nil, but keeping behavior same so we don't have breaking change. + if sub.fullyClosed { sub.Unlock() return ErrBadSubscription } - sub.sc = nil - sub.inboxSub.Unsubscribe() - sub.inboxSub = nil + wasClosed := sub.closed + // If this is the very first Close() call, do some internal cleanup, + // otherwise, simply send the close protocol message. + if !wasClosed { + sub.closed = true + sub.inboxSub.Unsubscribe() + sub.inboxSub = nil + } + sc := sub.sc sub.Unlock() sc.Lock() @@ -430,8 +443,9 @@ func (sub *subscription) closeOrUnsubscribe(doClose bool) error { sc.Unlock() return ErrConnectionClosed } - - delete(sc.subMap, sub.inbox) + if !wasClosed { + delete(sc.subMap, sub.inbox) + } reqSubject := sc.unsubRequests if doClose { reqSubject = sc.subCloseRequests @@ -464,10 +478,13 @@ func (sub *subscription) closeOrUnsubscribe(doClose bool) error { if err := r.Unmarshal(reply.Data); err != nil { return err } + // As long as we got a valid response, we consider the subscription fully closed. + sub.Lock() + sub.fullyClosed = true + sub.Unlock() if r.Error != "" { return errors.New(r.Error) } - return nil } @@ -493,13 +510,14 @@ func (msg *Msg) Ack() error { ackSubject := sub.ackInbox isManualAck := sub.opts.ManualAcks sc := sub.sc + closed := sub.closed sub.RUnlock() // Check for error conditions. if !isManualAck { return ErrManualAck } - if sc == nil { + if closed { return ErrBadSubscription }