Skip to content

Commit

Permalink
Merge pull request #1257 from nats-io/fix_1256
Browse files Browse the repository at this point in the history
[FIXED] RAFT subscriptions leak in failed connect
  • Loading branch information
kozlovic committed Jul 6, 2022
2 parents cc8b894 + f36eb36 commit 3520b47
Show file tree
Hide file tree
Showing 3 changed files with 346 additions and 30 deletions.
116 changes: 111 additions & 5 deletions server/clustering_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4739,7 +4739,7 @@ type myProxy struct {
connectTo string
addr string
c net.Conn
doPause bool
doPause time.Duration
}

func newProxy(connectTo string) (*myProxy, error) {
Expand Down Expand Up @@ -4776,8 +4776,8 @@ func (p *myProxy) proxy(c net.Conn) {
p.Lock()
pause := p.doPause
p.Unlock()
if pause {
time.Sleep(10 * time.Millisecond)
if pause > 0 {
time.Sleep(pause)
} else {
break
}
Expand Down Expand Up @@ -4821,15 +4821,19 @@ func (p *myProxy) getAddr() string {
}

// pause stalls the proxied traffic using the default pause interval.
// It is a convenience wrapper around pauseFor.
func (p *myProxy) pause() {
	const defaultPauseInterval = 10 * time.Millisecond
	p.pauseFor(defaultPauseInterval)
}

// pauseFor tells the proxy's copy loop to sleep for the given duration
// between transfers instead of giving up, effectively stalling traffic.
func (p *myProxy) pauseFor(dur time.Duration) {
	p.Lock()
	p.doPause = dur
	p.Unlock()
}

// resume clears the pause interval so the proxy's copy loop goes back
// to transferring traffic without artificial delays.
func (p *myProxy) resume() {
	p.Lock()
	p.doPause = 0
	p.Unlock()
}

func (p *myProxy) close() {
Expand Down Expand Up @@ -8939,3 +8943,105 @@ func TestClusteringDoNotReportSubCloseMissingSubjectOnReplay(t *testing.T) {
// OK
}
}

// TestClusteringRaftSubsAndConnsLeak verifies that failed RAFT connect
// attempts (simulated by stalling the route between the two NATS Servers
// through a proxy) do not leave behind leaked RAFT request subscriptions
// or, in NodesConnections mode, leaked node-to-node connections.
func TestClusteringRaftSubsAndConnsLeak(t *testing.T) {
	for _, test := range []struct {
		name             string
		nodesConnections bool
	}{
		{"subs", false},
		{"conns", true},
	} {
		t.Run(test.name, func(t *testing.T) {
			cleanupDatastore(t)
			defer cleanupDatastore(t)
			cleanupRaftLog(t)
			defer cleanupRaftLog(t)

			// This test requires two NATS Servers routed to each other.
			baseOpts := natsdTest.DefaultTestOptions
			ns1Opts := baseOpts.Clone()
			ns1Opts.Cluster.Name = "abc"
			ns1Opts.Cluster.Host = "127.0.0.1"
			ns1Opts.Cluster.Port = -1
			ns1 := natsdTest.RunServer(ns1Opts)
			defer ns1.Shutdown()

			// ns2 routes to ns1 through this proxy so that the two
			// servers can later be split apart on demand.
			proxy, err := newProxy(fmt.Sprintf("%s:%d", ns1Opts.Cluster.Host, ns1Opts.Cluster.Port))
			if err != nil {
				t.Fatalf("Error creating proxy: %v", err)
			}
			defer proxy.close()
			// Give the proxy some time to be ready to accept connections.
			time.Sleep(200 * time.Millisecond)

			ns2Opts := baseOpts.Clone()
			ns2Opts.Port = 4223
			ns2Opts.Cluster.Name = "abc"
			ns2Opts.Cluster.Host = "127.0.0.1"
			ns2Opts.Cluster.Port = -1
			ns2Opts.Routes = natsd.RoutesFromStr(proxy.getAddr())
			ns2 := natsdTest.RunServer(ns2Opts)
			defer ns2.Shutdown()

			// First streaming server, attached to ns1 (embedded).
			s1sOpts := getTestDefaultOptsForClustering("a", true)
			s1sOpts.Clustering.NodesConnections = test.nodesConnections
			s1 := runServerWithOpts(t, s1sOpts, nil)
			defer s1.Shutdown()

			// Second streaming server, connecting to ns2.
			s2sOpts := getTestDefaultOptsForClustering("b", false)
			s2sOpts.NATSServerURL = "nats://127.0.0.1:4223"
			s2sOpts.Clustering.NodesConnections = test.nodesConnections
			s2 := runServerWithOpts(t, s2sOpts, ns2Opts)
			defer s2.Shutdown()

			// Third streaming server, also connecting to ns2.
			s3sOpts := getTestDefaultOptsForClustering("c", false)
			s3sOpts.NATSServerURL = "nats://127.0.0.1:4223"
			s3sOpts.Clustering.NodesConnections = test.nodesConnections
			s3 := runServerWithOpts(t, s3sOpts, ns2Opts)
			defer s3.Shutdown()

			getLeader(t, 10*time.Second, s1, s2, s3)

			// Stall the route long enough for RAFT dial attempts to
			// time out, then let traffic flow again.
			proxy.pauseFor(500 * time.Millisecond)

			time.Sleep(2 * time.Second)

			proxy.resume()

			time.Sleep(time.Second)

			// Inspect ns1's connections (with their subscriptions) and
			// count the leftovers relevant to the mode under test.
			connz, err := ns1.Connz(&natsd.ConnzOptions{Subscriptions: true})
			if err != nil {
				t.Fatalf("Error getting connz: %v", err)
			}
			count := uint32(0)
			for _, c := range connz.Conns {
				if test.nodesConnections {
					// Count node-to-node connections originating from "a".
					if strings.Contains(c.Name, "-a-to-") {
						count++
					}
					continue
				}
				if !strings.HasSuffix(c.Name, "-raft") {
					continue
				}
				// Count RAFT request subscriptions on the raft connection.
				for _, subj := range c.Subs {
					if strings.Contains(subj, ".request.") {
						count++
					}
				}
				break
			}
			// Reaching 10 or more subs/conns (in either mode) is
			// indicative of the leak.
			if count >= 10 {
				t.Fatalf("Unexpected number of subs/conns: %v", count)
			}
		})
	}
}
108 changes: 83 additions & 25 deletions server/raft_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ const (
natsRequestInbox = "raft.%s.request.%s"
timeoutForDialAndFlush = 2 * time.Second
natsLogAppName = "raft-nats"
dialInProgress = "1"
dialComplete = "2"
)

var errTransportShutdown = errors.New("raft-nats: transport is being shutdown")
Expand All @@ -57,6 +59,7 @@ func (n natsAddr) String() string {
type connectRequestProto struct {
ID string `json:"id"`
Inbox string `json:"inbox"`
Check string `json:"check"`
}

type connectResponseProto struct {
Expand Down Expand Up @@ -238,7 +241,7 @@ func (n *natsConn) close(signalRemote bool) error {
return nil
}

if signalRemote {
if signalRemote && n.outbox != "" {
// Send empty message to signal EOF for a graceful disconnect. Not
// concerned with errors here as this is best effort.
n.conn.Publish(n.outbox, nil)
Expand All @@ -250,11 +253,14 @@ func (n *natsConn) close(signalRemote bool) error {
// check for sub != nil because this can be called during setup where
// sub has not been attached.
var err error
if n.streamConn {
if n.sub != nil {
var inbox string
if n.sub != nil {
inbox = n.sub.Subject
if n.streamConn {
err = n.sub.Unsubscribe()
}
} else {
}
if !n.streamConn {
n.conn.Close()
}

Expand All @@ -266,6 +272,7 @@ func (n *natsConn) close(signalRemote bool) error {

stream.mu.Lock()
delete(stream.conns, n)
stream.dialInboxes.Delete(inbox)
stream.mu.Unlock()

return err
Expand Down Expand Up @@ -311,6 +318,10 @@ type natsStreamLayer struct {
// This is the timeout we will use for flush and dial (request timeout),
// not the timeout that RAFT will use to call SetDeadline.
dfTimeout time.Duration
// This is for dial connections so that accept side can check if the
// dial side has timed-out before receiving the response from accept-side.
csub *nats.Subscription
dialInboxes sync.Map
}

func newNATSStreamLayer(id string, conn *nats.Conn, logger hclog.Logger, timeout time.Duration, makeConn natsRaftConnCreator) (*natsStreamLayer, error) {
Expand All @@ -322,19 +333,36 @@ func newNATSStreamLayer(id string, conn *nats.Conn, logger hclog.Logger, timeout
conns: map[*natsConn]struct{}{},
dfTimeout: timeoutForDialAndFlush,
}
csub, err := conn.Subscribe(nats.NewInbox(), func(m *nats.Msg) {
di := string(m.Data)
if len(di) == 0 {
return
}
var resp string
if v, ok := n.dialInboxes.Load(di); ok {
resp = v.(string)
}
m.Respond([]byte(resp))
})
if err != nil {
return nil, err
}
// Could be the case in tests...
if timeout < n.dfTimeout {
n.dfTimeout = timeout
}
sub, err := conn.SubscribeSync(fmt.Sprintf(natsConnectInbox, id))
if err != nil {
csub.Unsubscribe()
return nil, err
}
if err := conn.FlushTimeout(n.dfTimeout); err != nil {
csub.Unsubscribe()
sub.Unsubscribe()
return nil, err
}
n.sub = sub
n.csub = csub
return n, nil
}

Expand Down Expand Up @@ -373,6 +401,7 @@ func (n *natsStreamLayer) Dial(address raft.ServerAddress, timeout time.Duration
connect := &connectRequestProto{
ID: n.localAddr.String(),
Inbox: fmt.Sprintf(natsRequestInbox, n.localAddr.String(), nats.NewInbox()),
Check: n.csub.Subject,
}
data, err := json.Marshal(connect)
if err != nil {
Expand All @@ -385,24 +414,16 @@ func (n *natsStreamLayer) Dial(address raft.ServerAddress, timeout time.Duration
timeout = n.dfTimeout
}

// Make connect request to peer.
msg, err := n.conn.Request(fmt.Sprintf(natsConnectInbox, address), data, timeout)
if err != nil {
return nil, err
}
var resp connectResponseProto
if err := json.Unmarshal(msg.Data, &resp); err != nil {
return nil, err
}

// Success, so now create a new NATS connection...
// Create a new NATS connection...
peerConn, err := n.newNATSConn(string(address))
if err != nil {
return nil, fmt.Errorf("raft-nats: unable to create connection to %q: %v", string(address), err)
}

// Setup inbox.
peerConn.mu.Lock()
// Need to prepare the subscription before sending the request
// in case the accept-side immediately closes the connection.
sub, err := peerConn.conn.Subscribe(connect.Inbox, peerConn.onMsg)
if err != nil {
peerConn.mu.Unlock()
Expand All @@ -411,13 +432,31 @@ func (n *natsStreamLayer) Dial(address raft.ServerAddress, timeout time.Duration
}
sub.SetPendingLimits(-1, -1)
peerConn.sub = sub
peerConn.outbox = resp.Inbox
peerConn.mu.Unlock()

if err := peerConn.conn.FlushTimeout(timeout); err != nil {
n.dialInboxes.Store(connect.Inbox, dialInProgress)

// Make connect request to peer.
msg, err := n.conn.Request(fmt.Sprintf(natsConnectInbox, address), data, timeout)
if err != nil {
n.dialInboxes.Delete(connect.Inbox)
peerConn.Close()
return nil, err
}

var resp connectResponseProto
// Decode the response
if err := json.Unmarshal(msg.Data, &resp); err != nil {
n.dialInboxes.Delete(connect.Inbox)
peerConn.Close()
return nil, err
}
// Keep track of the remote's inbox
peerConn.mu.Lock()
peerConn.outbox = resp.Inbox
peerConn.mu.Unlock()

n.dialInboxes.Store(connect.Inbox, dialComplete)

n.mu.Lock()
if n.closed {
Expand All @@ -438,20 +477,17 @@ func (n *natsStreamLayer) Accept() (net.Conn, error) {
return nil, err
}
if msg.Reply == "" {
n.logger.Error("Invalid connect message (missing reply inbox)")
continue
return nil, fmt.Errorf("invalid connect message (missing reply inbox)")
}

var connect connectRequestProto
if err := json.Unmarshal(msg.Data, &connect); err != nil {
n.logger.Error("Invalid connect message (invalid data)")
continue
return nil, fmt.Errorf("invalid connect message: %v", err)
}

peerConn, err := n.newNATSConn(connect.ID)
if err != nil {
n.logger.Error("Unable to create connection to %q: %v", connect.ID, err)
continue
return nil, fmt.Errorf("unable to create connection to %s: %v", connect.ID, err)
}

// Setup inbox for peer.
Expand All @@ -460,9 +496,8 @@ func (n *natsStreamLayer) Accept() (net.Conn, error) {
sub, err := peerConn.conn.Subscribe(inbox, peerConn.onMsg)
if err != nil {
peerConn.mu.Unlock()
n.logger.Error("Failed to create inbox for remote peer", "error", err)
peerConn.Close()
continue
return nil, fmt.Errorf("unable to create inbox for remote peer %s: %v", connect.ID, err)
}
sub.SetPendingLimits(-1, -1)
peerConn.outbox = connect.Inbox
Expand Down Expand Up @@ -493,6 +528,29 @@ func (n *natsStreamLayer) Accept() (net.Conn, error) {
peerConn.Close()
continue
}
if connect.Check != "" {
var retryAccept bool
for {
resp, err := n.conn.Request(connect.Check, []byte(connect.Inbox), n.dfTimeout)
if err != nil {
retryAccept = true
break
}
if s := string(resp.Data); s == dialInProgress {
continue
} else if s == dialComplete {
break
} else {
n.logger.Warn(fmt.Sprintf("retrying because dialed connection from remote %s is no longer valid", connect.ID))
retryAccept = true
break
}
}
if retryAccept {
peerConn.Close()
continue
}
}
n.mu.Lock()
if n.closed {
n.mu.Unlock()
Expand Down

0 comments on commit 3520b47

Please sign in to comment.