Commit

Merge bcb4320 into 18ec07d
kozlovic authored Feb 5, 2019
2 parents 18ec07d + bcb4320 commit 044aa72
Showing 3 changed files with 169 additions and 34 deletions.
50 changes: 47 additions & 3 deletions server/client.go
@@ -21,12 +21,18 @@ import (
"github.com/nats-io/nats-streaming-server/stores"
)

const (
maxKnownInvalidConns = 256
pruneKnownInvalidConns = 32
)

// This is a proxy to the store interface.
type clientStore struct {
sync.RWMutex
clients map[string]*client
connIDs map[string]*client
waitOnRegister map[string]chan struct{}
knownInvalid map[string]struct{}
store stores.Store
}

@@ -43,9 +49,10 @@ type client struct {
// newClientStore creates a new clientStore instance using `store` as the backing storage.
func newClientStore(store stores.Store) *clientStore {
return &clientStore{
clients: make(map[string]*client),
connIDs: make(map[string]*client),
store: store,
clients: make(map[string]*client),
connIDs: make(map[string]*client),
knownInvalid: make(map[string]struct{}),
store: store,
}
}

@@ -73,7 +80,10 @@ func (cs *clientStore) register(info *spb.ClientInfo) (*client, error) {
cs.clients[c.info.ID] = c
if len(c.info.ConnID) > 0 {
cs.connIDs[string(c.info.ConnID)] = c
// Remove the connID from knownInvalid in case it was previously recorded there.
delete(cs.knownInvalid, string(c.info.ConnID))
}
delete(cs.knownInvalid, info.ID)
if cs.waitOnRegister != nil {
ch := cs.waitOnRegister[c.info.ID]
if ch != nil {
@@ -132,6 +142,10 @@ func (cs *clientStore) isValidWithTimeout(ID string, connID []byte, timeout time
cs.Unlock()
return true
}
if cs.knownToBeInvalid(ID, connID) {
cs.Unlock()
return false
}
if cs.waitOnRegister == nil {
cs.waitOnRegister = make(map[string]chan struct{})
}
@@ -145,11 +159,41 @@ func (cs *clientStore) isValidWithTimeout(ID string, connID []byte, timeout time
// We timed out, remove the entry in the map
cs.Lock()
delete(cs.waitOnRegister, ID)
cs.addToKnownInvalid(ID, connID)
cs.Unlock()
return false
}
}

func (cs *clientStore) addToKnownInvalid(ID string, connID []byte) {
cID := string(connID)
if cID == "" {
cID = ID
}
cs.knownInvalid[cID] = struct{}{}
if len(cs.knownInvalid) >= maxKnownInvalidConns {
r := 0
for id := range cs.knownInvalid {
if id != cID {
delete(cs.knownInvalid, id)
if r++; r > pruneKnownInvalidConns {
break
}
}
}
}
}

func (cs *clientStore) knownToBeInvalid(ID string, connID []byte) bool {
var invalid bool
if len(connID) > 0 {
_, invalid = cs.knownInvalid[string(connID)]
} else {
_, invalid = cs.knownInvalid[ID]
}
return invalid
}

// Lookup client by ConnID if not nil, otherwise by clientID.
// Assume at least clientStore RLock is held on entry.
func (cs *clientStore) lookupByConnIDOrID(ID string, connID []byte) *client {
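The `knownInvalid` map added above is a bounded negative cache: a client ID (or ConnID) that fails validation after `clientCheckTimeout` is remembered, so the next request for that client is rejected immediately instead of blocking again, and the map is pruned once it reaches `maxKnownInvalidConns` so it cannot grow without bound. Below is a minimal, standalone Go sketch of that pattern; the `negativeCache` type and its method names are illustrative only, not the server's API.

package main

import (
	"fmt"
	"sync"
)

// Illustrative caps mirroring the diff's maxKnownInvalidConns (256)
// and pruneKnownInvalidConns (32).
const (
	maxEntries   = 256
	pruneEntries = 32
)

// negativeCache is a hypothetical, self-contained version of the
// knownInvalid map: it remembers IDs that recently failed validation
// and prunes itself so it never grows without bound.
type negativeCache struct {
	sync.Mutex
	entries map[string]struct{}
}

func newNegativeCache() *negativeCache {
	return &negativeCache{entries: make(map[string]struct{})}
}

// add records an invalid ID; once the cap is reached it evicts a handful
// of other entries (Go map iteration order is effectively random).
func (c *negativeCache) add(id string) {
	c.Lock()
	defer c.Unlock()
	c.entries[id] = struct{}{}
	if len(c.entries) >= maxEntries {
		removed := 0
		for e := range c.entries {
			if e == id {
				continue
			}
			delete(c.entries, e)
			if removed++; removed > pruneEntries {
				break
			}
		}
	}
}

// known reports whether the ID was previously recorded as invalid.
func (c *negativeCache) known(id string) bool {
	c.Lock()
	defer c.Unlock()
	_, ok := c.entries[id]
	return ok
}

func main() {
	c := newNegativeCache()
	for i := 0; i < 1000; i++ {
		c.add(fmt.Sprintf("client-%d", i))
	}
	// The most recently added ID is kept, and the map stays bounded.
	fmt.Println(c.known("client-999"), len(c.entries) <= maxEntries)
}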
143 changes: 115 additions & 28 deletions server/partitions_test.go
@@ -654,11 +654,31 @@ func checkWaitOnRegisterMap(t tLogger, s *StanServer, size int) {
}
}

func checkKnownInvalidMap(t *testing.T, s *StanServer, size int, id string) {
t.Helper()
waitFor(t, clientCheckTimeout+50*time.Millisecond, 15*time.Millisecond, func() error {
var ki bool
s.clients.RLock()
if id != "" {
_, ki = s.clients.knownInvalid[id]
}
mlen := len(s.clients.knownInvalid)
s.clients.RUnlock()
if size != mlen {
return fmt.Errorf("expected map size to be %v, got %v", size, mlen)
}
if size > 0 && !ki {
return fmt.Errorf("expected %q to be in the map, it was not", id)
}
return nil
})
}

func TestPartitionsRaceOnPub(t *testing.T) {
setPartitionsVarsForTest()
defer resetDefaultPartitionsVars()

clientCheckTimeout = 150 * time.Millisecond
clientCheckTimeout = 250 * time.Millisecond
defer func() { clientCheckTimeout = defaultClientCheckTimeout }()

opts := GetDefaultOptions()
@@ -690,11 +710,6 @@ func TestPartitionsRaceOnPub(t *testing.T) {
pubReq := &pb.PubMsg{ClientID: clientName, Subject: "foo", Data: []byte("hello")}
pubNuid := nuid.New()

pubSub, err := nc.SubscribeSync(nats.NewInbox())
if err != nil {
t.Fatalf("Error creating sub on pub response: %v", err)
}

// Repeat the test, because even with bug, it would be possible
// that the connection request is still processed first, which
// would make the test pass.
@@ -716,23 +731,40 @@ func TestPartitionsRaceOnPub(t *testing.T) {
}
// Ensure that the notification map has been created, but is empty.
checkWaitOnRegisterMap(t, s, 0)
checkKnownInvalidMap(t, s, 1, clientName)

// Now resend a message, but this time don't wait for the response here,
// instead connect, which should cause the PubMsg to be processed correctly.
if err := nc.PublishRequest(pubSubj, pubSub.Subject, pubBytes); err != nil {
t.Fatalf("Error sending PubMsg: %v", err)
// Now resend a message and it should fail quickly since server
// should have recorded this as an invalid client ID.
start := time.Now()
resp, err = nc.Request(pubSubj, pubBytes, clientCheckTimeout)
if err != nil {
t.Fatalf("Error on request: %v", err)
}
checkWaitOnRegisterMap(t, s, 1)
dur := time.Since(start)
pubResp = &pb.PubAck{}
pubResp.Unmarshal(resp.Data)
if pubResp.Error != ErrInvalidPubReq.Error() {
t.Fatalf("Expected error %q, got %q", ErrInvalidPubReq, pubResp.Error)
}
// It is expected to have taken less than the clientCheckTimeout this time
if dur > clientCheckTimeout {
t.Fatalf("Second failure should not have take longer than %v, took %v", clientCheckTimeout, dur)
}
checkKnownInvalidMap(t, s, 1, clientName)

// Now connect
sc, err := stan.Connect(clusterName, clientName, stan.NatsConn(nc))
if err != nil {
t.Fatalf("Error on connect: %v", err)
}
defer sc.Close()
// This should clear the knownInvalid map
checkKnownInvalidMap(t, s, 0, "")

// Now we should get the OK for the PubMsg.
resp, err = pubSub.NextMsg(clientCheckTimeout + 100*time.Millisecond)
// Verify that we can send OK now.
resp, err = nc.Request(pubSubj, pubBytes, clientCheckTimeout)
if err != nil {
t.Fatalf("Error waiting for pub response: %v", err)
t.Fatalf("Error on request: %v", err)
}
pubResp = &pb.PubAck{}
pubResp.Unmarshal(resp.Data)
@@ -748,7 +780,7 @@ func TestPartitionsRaceOnSub(t *testing.T) {
setPartitionsVarsForTest()
defer resetDefaultPartitionsVars()

clientCheckTimeout = 150 * time.Millisecond
clientCheckTimeout = 250 * time.Millisecond
defer func() { clientCheckTimeout = defaultClientCheckTimeout }()

opts := GetDefaultOptions()
@@ -771,11 +803,6 @@ func TestPartitionsRaceOnSub(t *testing.T) {
subSubj := s.info.Subscribe
subReq := &pb.SubscriptionRequest{ClientID: clientName, Subject: "foo", AckWaitInSecs: 30, MaxInFlight: 1}

subSub, err := nc.SubscribeSync(nats.NewInbox())
if err != nil {
t.Fatalf("Error creating sub on sub response: %v", err)
}

// Repeat the test, because even with bug, it would be possible
// that the connection request is still processed first, which
// would make the test pass.
@@ -797,23 +824,40 @@ func TestPartitionsRaceOnSub(t *testing.T) {
}
// Ensure that the notification map has been created, but is empty.
checkWaitOnRegisterMap(t, s, 0)
checkKnownInvalidMap(t, s, 1, clientName)

// Now resend the subscription, but this time don't wait for the response here,
// instead connect, which should cause the SubscriptionRequest to be processed correctly.
if err := nc.PublishRequest(subSubj, subSub.Subject, subBytes); err != nil {
t.Fatalf("Error sending PubMsg: %v", err)
// Now resend the subscription and it should fail quickly since server
// should have recorded this as an invalid client ID.
start := time.Now()
resp, err = nc.Request(subSubj, subBytes, clientCheckTimeout)
if err != nil {
t.Fatalf("Error on request: %v", err)
}
checkWaitOnRegisterMap(t, s, 1)
dur := time.Since(start)
subResp = &pb.SubscriptionResponse{}
subResp.Unmarshal(resp.Data)
if subResp.Error != ErrInvalidSubReq.Error() {
t.Fatalf("Expected error %q, got %q", ErrInvalidSubReq, subResp.Error)
}
// It is expected to have taken less than the clientCheckTimeout this time
if dur > clientCheckTimeout {
t.Fatalf("Second failure should not have take longer than %v, took %v", clientCheckTimeout, dur)
}
checkKnownInvalidMap(t, s, 1, clientName)

// Now connect
sc, err := stan.Connect(clusterName, clientName, stan.NatsConn(nc))
if err != nil {
t.Fatalf("Error on connect: %v", err)
}
defer sc.Close()
// Should be removed from the map
checkKnownInvalidMap(t, s, 0, "")

// Now we should get the OK for the PubMsg.
resp, err = subSub.NextMsg(clientCheckTimeout + 100*time.Millisecond)
// Now we should get the OK for the subscription.
resp, err = nc.Request(subSubj, subBytes, clientCheckTimeout)
if err != nil {
t.Fatalf("Error waiting for pub response: %v", err)
t.Fatalf("Error on request: %v", err)
}
subResp = &pb.SubscriptionResponse{}
subResp.Unmarshal(resp.Data)
@@ -892,3 +936,46 @@ func TestPartitionsClientPings(t *testing.T) {

testClientPings(t, s1)
}

func TestPartitionsCleanInvalidConns(t *testing.T) {
setPartitionsVarsForTest()
defer resetDefaultPartitionsVars()

clientCheckTimeout = 1 * time.Millisecond
defer func() { clientCheckTimeout = defaultClientCheckTimeout }()

opts := GetDefaultOptions()
opts.Partitioning = true
opts.AddPerChannel("foo", &stores.ChannelLimits{})
s := runServerWithOpts(t, opts, nil)
defer s.Shutdown()

// Create a direct NATS connection
nc, err := nats.Connect(nats.DefaultURL)
if err != nil {
t.Fatalf("Unable to connect: %v", err)
}
defer nc.Close()

total := 300

pubSubj := fmt.Sprintf("%s.foo", s.info.Publish)
for i := 0; i < total; i++ {
pubReq := &pb.PubMsg{
ClientID: fmt.Sprintf("%s%d", clientName, i),
Subject: "foo",
Data: []byte("hello"),
Guid: fmt.Sprintf("guid%d", i),
}
b, _ := pubReq.Marshal()
nc.Request(pubSubj, b, time.Second)
}
// The map should not have grown to 300
s.clients.RLock()
mlen := len(s.clients.knownInvalid)
s.clients.RUnlock()

if mlen > maxKnownInvalidConns {
t.Fatalf("Should not be more than %v, got %v", maxKnownInvalidConns, mlen)
}
}
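
The `checkKnownInvalidMap` helper above polls the server with the test suite's `waitFor`-style retry loop. The sketch below shows the general shape of such a helper under the assumption that it retries until a deadline; the real helper's signature (it takes a `tLogger`) may differ slightly.

package example

import (
	"errors"
	"testing"
	"time"
)

// waitFor is a hypothetical stand-in for the repository's polling test
// helper: it re-runs check every interval until it returns nil or the
// timeout elapses, then fails the test with the last error.
func waitFor(t *testing.T, timeout, interval time.Duration, check func() error) {
	t.Helper()
	deadline := time.Now().Add(timeout)
	var err error
	for time.Now().Before(deadline) {
		if err = check(); err == nil {
			return
		}
		time.Sleep(interval)
	}
	t.Fatalf("condition not met after %v: %v", timeout, err)
}

// Example usage: wait for a condition that becomes true asynchronously,
// much like waiting for knownInvalid to reach the expected size.
func TestWaitForFlag(t *testing.T) {
	done := make(chan struct{})
	go func() {
		time.Sleep(20 * time.Millisecond)
		close(done)
	}()
	waitFor(t, time.Second, 5*time.Millisecond, func() error {
		select {
		case <-done:
			return nil
		default:
			return errors.New("not done yet")
		}
	})
}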
10 changes: 7 additions & 3 deletions server/server.go
@@ -118,7 +118,7 @@ const (
// To prevent that, when checking if a client exists, in this particular
// mode we will possibly wait to be notified when the client has been
// registered. This is the default duration for this wait.
defaultClientCheckTimeout = 4 * time.Second
defaultClientCheckTimeout = time.Second

// Interval at which server goes through list of subscriptions with
// pending sent/ack operations that needs to be replicated.
@@ -2931,11 +2931,15 @@ func (s *StanServer) processClientPublish(m *nats.Msg) {
return
}

if s.debug {
s.log.Tracef("[Client:%s] Received message from publisher subj=%s guid=%s", pm.ClientID, pm.Subject, pm.Guid)
}

// Check if the client is valid. We do this after the clustered check so
// that only the leader performs this check.
valid := false
if s.partitions != nil || s.isClustered {
// In partitioning or clustering it is possible that we get there
if s.partitions != nil {
// In partitioning mode it is possible that we get there
// before the connect request is processed. If so, make sure we wait
// for conn request to be processed first. Check clientCheckTimeout
// doc for details.
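The comment above describes why, in partitioning mode, the server may wait for the connect request before rejecting a publish: the request handler parks until the client is registered or `clientCheckTimeout` fires. A minimal, hypothetical sketch of that wait-or-timeout shape (not the server's actual `isValidWithTimeout`):

package main

import (
	"fmt"
	"time"
)

// waitForRegistration parks until the registered channel is closed (the
// connect request was processed) or the timeout elapses (the client is
// then treated as invalid and recorded in the negative cache).
func waitForRegistration(registered <-chan struct{}, timeout time.Duration) bool {
	select {
	case <-registered:
		return true
	case <-time.After(timeout):
		return false
	}
}

func main() {
	ch := make(chan struct{})
	// Simulate the connect request arriving slightly after the publish.
	go func() {
		time.Sleep(50 * time.Millisecond)
		close(ch)
	}()
	fmt.Println(waitForRegistration(ch, time.Second))                          // true: client registered in time
	fmt.Println(waitForRegistration(make(chan struct{}), 10*time.Millisecond)) // false: timed out
}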
