From c6fc0a6f2e7a7975994c83bb80082dc53e2ddc5a Mon Sep 17 00:00:00 2001
From: Ivan Kozlovic
Date: Tue, 13 Dec 2016 12:37:51 -0700
Subject: [PATCH] [ADDED] Rate limiting

Global configuration to limit the per-client ingress message rate.
The limit can be set with rate_msgs (messages per second) and/or
rate_bytes (bytes per second).

Resolves #346
---
 README.md                |   4 +
 main.go                  |   2 +
 server/client.go         |  49 +++++++++++-
 server/client_test.go    | 159 ++++++++++++++++++++++++++++++++++++++-
 server/configs/test.conf |   4 +
 server/opts.go           |  12 +++
 server/opts_test.go      |   6 ++
 7 files changed, 232 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index ee1ddbb29eb..8e956df7e48 100644
--- a/README.md
+++ b/README.md
@@ -185,6 +185,10 @@ max_payload: 65536
 
 # slow consumer threshold
 max_pending_size: 10000000
+
+# rate limit
+rate_msgs: 100
+rate_bytes: 100KB
 ```
 
 ## Variables

diff --git a/main.go b/main.go
index e56318c71c6..f9385ef18ad 100644
--- a/main.go
+++ b/main.go
@@ -116,6 +116,8 @@ func main() {
 	flag.StringVar(&opts.TLSCert, "tlscert", "", "Server certificate file.")
 	flag.StringVar(&opts.TLSKey, "tlskey", "", "Private key for server certificate.")
 	flag.StringVar(&opts.TLSCaCert, "tlscacert", "", "Client certificate CA for verification.")
+	flag.IntVar(&opts.RateMaxMsgs, "rate_msgs", 0, "Per client maximum rate of messages per second")
+	flag.Int64Var(&opts.RateMaxBytes, "rate_bytes", 0, "Per client maximum rate of bytes per second")
 
 	flag.Usage = usage
 

diff --git a/server/client.go b/server/client.go
index fe266914695..2877403d3ab 100644
--- a/server/client.go
+++ b/server/client.go
@@ -107,6 +107,11 @@ type client struct {
 	wfc  int
 	msgb [msgScratchSize]byte
 	last time.Time
+	rate bool
+	ram  int
+	rab  int64
+	rlc  time.Time
+	rq   chan struct{}
 	parseState
 
 	route *route
@@ -192,6 +197,9 @@ func (c *client) initClient() {
 	// after we process inbound msgs from our own connection.
 	c.pcd = make(map[*client]struct{})
 
+	// Channel to kick out a client from a sleep due to rate limit
+	c.rq = make(chan struct{}, 1)
+
 	// snapshot the string version of the connection
 	conn := "-"
 	if ip, ok := c.nc.(*net.TCPConn); ok {
@@ -244,6 +252,10 @@ func (c *client) readLoop() {
 	c.mu.Lock()
 	nc := c.nc
 	s := c.srv
+	c.rate = s.opts.RateMaxMsgs > 0 || s.opts.RateMaxBytes > 0
+	if c.rate {
+		c.rlc = time.Now()
+	}
 	defer s.grWG.Done()
 	c.mu.Unlock()
 
@@ -980,8 +992,9 @@ func (c *client) processMsg(msg []byte) {
 
 	// Update statistics
 	// The msg includes the CR_LF, so pull back out for accounting.
-	c.cache.inMsgs += 1
-	c.cache.inBytes += len(msg) - LEN_CR_LF
+	msgSize := len(msg) - LEN_CR_LF
+	c.cache.inMsgs++
+	c.cache.inBytes += msgSize
 
 	if c.trace {
 		c.traceMsg(msg)
@@ -1032,6 +1045,31 @@ func (c *client) processMsg(msg []byte) {
 		}
 	}
 
+	if c.rate {
+		now := time.Now()
+		delta := now.Sub(c.rlc)
+		if delta < time.Second {
+			c.ram++
+			c.rab += int64(msgSize)
+			if (srv.opts.RateMaxMsgs > 0 && c.ram >= srv.opts.RateMaxMsgs) ||
+				(srv.opts.RateMaxBytes > 0 && c.rab >= srv.opts.RateMaxBytes) {
+				select {
+				case <-c.rq:
+					// Stop rate limiting in case processMsg is called again.
+					// This will allow for fast drainage of the socket.
+					c.mu.Lock()
+					c.rate = false
+					c.mu.Unlock()
+					return
+				case <-time.After(time.Second - delta):
+				}
+				c.ram, c.rab, c.rlc = 0, int64(0), time.Now()
+			}
+		} else {
+			c.ram, c.rab, c.rlc = 1, int64(msgSize), now
+		}
+	}
+
 	if c.opts.Verbose {
 		c.sendOK()
 	}
@@ -1254,6 +1292,13 @@ func (c *client) clearConnection() {
 	}
 	c.nc.Close()
 	c.nc.SetWriteDeadline(time.Time{})
+	if c.rate {
+		// Kick out processMsg() if it is in a sleep (doing rate limiting).
+		select {
+		case c.rq <- struct{}{}:
+		default:
+		}
+	}
 }
 
 func (c *client) typeString() string {
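The heart of the change is the block added to processMsg() above: each client counts messages (ram) and bytes (rab) within a rolling one-second window anchored at rlc, and once either configured ceiling is reached, the reader goroutine sleeps for the remainder of the window unless clearConnection() wakes it through the rq channel. Because the reader simply stops reading, TCP backpressure ends up throttling the publisher itself. The sketch below restates the same logic outside the server; the limiter type and all of its names are illustrative, not part of the patch:

```go
package main

import (
	"fmt"
	"time"
)

// limiter is a hypothetical, stripped-down analogue of the per-client
// state the patch adds to the client struct (rate, ram, rab, rlc, rq).
type limiter struct {
	maxMsgs  int           // 0 disables the message-count limit
	maxBytes int64         // 0 disables the byte-count limit
	msgs     int           // messages seen in the current window
	bytes    int64         // bytes seen in the current window
	start    time.Time     // beginning of the current 1-second window
	kick     chan struct{} // lets another goroutine abort the sleep
}

// accept accounts for one inbound message and sleeps off the remainder
// of the window once a limit is hit. It returns false if the sleep was
// interrupted, mirroring the patch clearing c.rate so the socket can
// drain quickly while the connection is being torn down.
func (l *limiter) accept(size int64) bool {
	now := time.Now()
	delta := now.Sub(l.start)
	if delta >= time.Second {
		// Crossed into a new window: restart the counters.
		l.msgs, l.bytes, l.start = 1, size, now
		return true
	}
	l.msgs++
	l.bytes += size
	if (l.maxMsgs > 0 && l.msgs >= l.maxMsgs) ||
		(l.maxBytes > 0 && l.bytes >= l.maxBytes) {
		select {
		case <-l.kick:
			return false // being closed: stop limiting
		case <-time.After(time.Second - delta):
		}
		l.msgs, l.bytes, l.start = 0, 0, time.Now()
	}
	return true
}

func main() {
	l := &limiter{maxMsgs: 3, kick: make(chan struct{}, 1), start: time.Now()}
	for i := 0; i < 7; i++ {
		l.accept(5)
		fmt.Println(i, time.Now().Format("15:04:05.000"))
	}
}
```

Running the sketch shows the prints arriving in bursts of three per second, which is the behavior TestRateLimiting below asserts against a real server.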
diff --git a/server/client_test.go b/server/client_test.go
index 43a144bf41a..d8067c5cad0 100644
--- a/server/client_test.go
+++ b/server/client_test.go
@@ -5,18 +5,18 @@ package server
 import (
 	"bufio"
 	"bytes"
+	"crypto/tls"
 	"encoding/json"
 	"fmt"
 	"net"
 	"reflect"
 	"regexp"
+	"runtime"
 	"strings"
 	"sync"
 	"testing"
 	"time"
 
-	"crypto/tls"
-
 	"github.com/nats-io/go-nats"
 )
 
@@ -737,3 +737,158 @@ func TestTLSCloseClientConnection(t *testing.T) {
 	cli.closeConnection()
 	ch <- true
 }
+
+func TestRateLimiting(t *testing.T) {
+	var nc *nats.Conn
+	msg := []byte("hello")
+	ch := make(chan struct{}, 1)
+	errCh := make(chan error, 1)
+
+	createConnFunc := func() *nats.Conn {
+		nc, err := nats.Connect(fmt.Sprintf("nats://%s:%d",
+			DefaultOptions.Host, DefaultOptions.Port),
+			nats.NoReconnect())
+		if err != nil {
+			t.Fatalf("Error creating client: %v\n", err)
+		}
+		return nc
+	}
+	sendFunc := func() {
+		for i := 0; i < 300; i++ {
+			if err := nc.Publish("foo", msg); err != nil {
+				errCh <- fmt.Errorf("Error on publish: %v", err)
+				return
+			}
+		}
+		nc.Flush()
+		ch <- struct{}{}
+	}
+	checkRateFunc := func(rateShouldBeLimited bool) {
+		select {
+		case err := <-errCh:
+			stackFatalf(t, err.Error())
+		case <-ch:
+			if rateShouldBeLimited {
+				stackFatalf(t, "Rate should have been limited")
+			}
+		case <-time.After(time.Second):
+			if rateShouldBeLimited {
+				nc.Close()
+				<-ch
+				// Consume possible error
+				select {
+				case <-errCh:
+				default:
+				}
+			} else {
+				stackFatalf(t, "Rate should not have been limited")
+			}
+		}
+		if !rateShouldBeLimited {
+			nc.Close()
+		}
+	}
+
+	// No rate limiting
+	s := RunServer(nil)
+	defer s.Shutdown()
+	nc = createConnFunc()
+	go sendFunc()
+	checkRateFunc(false)
+	s.Shutdown()
+
+	// Rate limited to 100 msgs/sec
+	opts := DefaultOptions
+	opts.RateMaxMsgs = 100
+	s = RunServer(&opts)
+	defer s.Shutdown()
+	nc = createConnFunc()
+	go sendFunc()
+	checkRateFunc(true)
+	s.Shutdown()
+
+	// Rate limited to 500 bytes/sec
+	opts = DefaultOptions
+	opts.RateMaxBytes = 500
+	s = RunServer(&opts)
+	defer s.Shutdown()
+	nc = createConnFunc()
+	go sendFunc()
+	checkRateFunc(true)
+	s.Shutdown()
+
+	// Check that we can kick out processMsg from a sleep
+	opts = DefaultOptions
+	opts.RateMaxMsgs = 1
+	s = RunServer(&opts)
+	defer s.Shutdown()
+	nc = createConnFunc()
+	defer nc.Close()
+	nc.Flush()
+	// There should be 1 client only, with CID==1
+	s.mu.Lock()
+	cli := s.clients[1]
+	s.mu.Unlock()
+	// Since we set the rate to 1, sending the message below
+	// should cause processMsg to sleep
+	start := time.Now()
+	if err := nc.Publish("foo", msg); err != nil {
+		t.Fatalf("Error on publish: %v", err)
+	}
+	// Check that we are in processMsg()
+	buf := make([]byte, 10000)
+	timeout := start.Add(time.Second)
+	inProcessMsg := false
+	for time.Now().Before(timeout) {
+		n := runtime.Stack(buf, true)
+		if strings.Contains(string(buf[:n]), "processMsg") {
+			inProcessMsg = true
+			break
+		}
+		time.Sleep(10 * time.Millisecond)
+	}
+	if !inProcessMsg {
+		t.Fatal("Is not in processMsg")
+	}
+	// Clear the connection - note that nc.Close() would not
+	// help here since the connection is still half open,
+	// which means the server can still read from the socket.
+	cli.mu.Lock()
+	cli.clearConnection()
+	cli.mu.Unlock()
+	// Check the duration, it should be less than a second
+	dur := time.Now().Sub(start)
+	if dur >= 990*time.Millisecond {
+		t.Fatalf("May not have been kicked out from sleep")
+	}
+	s.Shutdown()
+
+	// Check counts are cleared when crossing over the 1 second period
+	opts = DefaultOptions
+	opts.RateMaxMsgs = 10000
+	s = RunServer(&opts)
+	defer s.Shutdown()
+	nc = createConnFunc()
+	defer nc.Close()
+	nc.Flush()
+	// There should be 1 client only, with CID==1
+	s.mu.Lock()
+	cli = s.clients[1]
+	s.mu.Unlock()
+	if err := nc.Publish("foo", msg); err != nil {
+		t.Fatalf("Error on publish: %v", err)
+	}
+	time.Sleep(1100 * time.Millisecond)
+	if err := nc.Publish("foo", msg); err != nil {
+		t.Fatalf("Error on publish: %v", err)
+	}
+	nc.Flush()
+	cli.mu.Lock()
+	accumulatedMsgs := cli.ram
+	cli.mu.Unlock()
+	// Should be 1
+	if accumulatedMsgs != 1 {
+		t.Fatalf("Unexpected accumulated messages: %v", accumulatedMsgs)
+	}
+	s.Shutdown()
+}
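A non-obvious technique in TestRateLimiting above: before kicking the connection, the test must prove that the reader goroutine really is parked inside processMsg(), and it does so by polling runtime.Stack() for a frame with that name. The pattern works on its own; here is a minimal sketch, where waitForFrame is an invented helper name, not part of the patch:

```go
package main

import (
	"fmt"
	"runtime"
	"strings"
	"time"
)

// waitForFrame polls a dump of all goroutine stacks until a frame
// containing name appears, or the timeout expires. This is the same
// trick the test uses to detect the sleeping processMsg() frame.
func waitForFrame(name string, timeout time.Duration) bool {
	buf := make([]byte, 10000)
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		n := runtime.Stack(buf, true) // true = include all goroutines
		if strings.Contains(string(buf[:n]), name) {
			return true
		}
		time.Sleep(10 * time.Millisecond)
	}
	return false
}

func main() {
	go func() { time.Sleep(time.Hour) }()                // a deliberately parked goroutine
	fmt.Println(waitForFrame("time.Sleep", time.Second)) // prints: true
}
```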
diff --git a/server/configs/test.conf b/server/configs/test.conf
index 689a815c8c8..91d22ccba43 100644
--- a/server/configs/test.conf
+++ b/server/configs/test.conf
@@ -40,3 +40,7 @@ max_pending_size: 10000000
 # ping interval and no pong threshold
 ping_interval: 60
 ping_max: 3
+
+# rate limit
+rate_msgs: 1000000
+rate_bytes: 100MB

diff --git a/server/opts.go b/server/opts.go
index 455d496eb3f..f6e8285ad38 100644
--- a/server/opts.go
+++ b/server/opts.go
@@ -81,6 +81,8 @@ type Options struct {
 	TLSKey       string      `json:"-"`
 	TLSCaCert    string      `json:"-"`
 	TLSConfig    *tls.Config `json:"-"`
+	RateMaxMsgs  int         `json:"-"`
+	RateMaxBytes int64       `json:"-"`
 }
 
 // Configuration file authorization section.
@@ -227,6 +229,10 @@ func ProcessConfigFile(configFile string) (*Options, error) {
 				return nil, err
 			}
 			opts.TLSTimeout = tc.Timeout
+		case "rate_msgs":
+			opts.RateMaxMsgs = int(v.(int64))
+		case "rate_bytes":
+			opts.RateMaxBytes = v.(int64)
 		}
 	}
 	return opts, nil
@@ -650,6 +656,12 @@ func MergeOptions(fileOpts, flagOpts *Options) *Options {
 	if flagOpts.RoutesStr != "" {
 		mergeRoutes(&opts, flagOpts)
 	}
+	if flagOpts.RateMaxMsgs != 0 {
+		opts.RateMaxMsgs = flagOpts.RateMaxMsgs
+	}
+	if flagOpts.RateMaxBytes != 0 {
+		opts.RateMaxBytes = flagOpts.RateMaxBytes
+	}
 	return &opts
 }

diff --git a/server/opts_test.go b/server/opts_test.go
index 54ac3c2a018..b0f7c959f01 100644
--- a/server/opts_test.go
+++ b/server/opts_test.go
@@ -69,6 +69,8 @@ func TestConfigFile(t *testing.T) {
 		MaxConn:      100,
 		PingInterval: 60 * time.Second,
 		MaxPingsOut:  3,
+		RateMaxMsgs:  1000000,
+		RateMaxBytes: 100 * 1024 * 1024,
 	}
 
 	opts, err := ProcessConfigFile("./configs/test.conf")
@@ -193,6 +195,8 @@ func TestMergeOverrides(t *testing.T) {
 		Cluster: ClusterOpts{
 			NoAdvertise: true,
 		},
+		RateMaxMsgs:  10000,
+		RateMaxBytes: 10000 * 1024,
 	}
 	fopts, err := ProcessConfigFile("./configs/test.conf")
 	if err != nil {
@@ -209,6 +213,8 @@ func TestMergeOverrides(t *testing.T) {
 		Cluster: ClusterOpts{
 			NoAdvertise: true,
 		},
+		RateMaxMsgs:  10000,
+		RateMaxBytes: 10000 * 1024,
 	}
 	merged := MergeOptions(fopts, opts)
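For reference, once the patch is applied the limits can be set from the command line or from the configuration file, and a non-zero flag value overrides the file value through MergeOptions(). Assuming the server binary is named gnatsd, as in this repository, usage looks like:

```sh
# Command line: at most 100 msgs/sec and 102400 bytes/sec per client.
# A value of 0 (the default) leaves the corresponding limit disabled.
gnatsd -rate_msgs 100 -rate_bytes 102400

# Configuration file equivalent; sizes accept suffixes such as 100KB,
# as shown in the README and server/configs/test.conf changes above:
#   rate_msgs: 100
#   rate_bytes: 100KB
```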