Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

apiserver: rate-limit logsink receives #7474

Merged
merged 1 commit into from Jun 19, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions agent/agent.go
Expand Up @@ -183,6 +183,8 @@ const (

LogSinkDBLoggerBufferSize = "LOGSINK_DBLOGGER_BUFFER_SIZE"
LogSinkDBLoggerFlushInterval = "LOGSINK_DBLOGGER_FLUSH_INTERVAL"
LogSinkRateLimitBurst = "LOGSINK_RATELIMIT_BURST"
LogSinkRateLimitRefill = "LOGSINK_RATELIMIT_REFILL"
)

// The Config interface is the sole way that the agent gets access to the
Expand Down
103 changes: 65 additions & 38 deletions apiserver/apiserver.go
Expand Up @@ -51,46 +51,49 @@ var defaultHTTPMethods = []string{"GET", "POST", "HEAD", "PUT", "DELETE", "OPTIO

// These vars define how we rate limit incoming connections.
const (
defaultLoginRateLimit = 10 // concurrent login operations
defaultLoginMinPause = 100 * time.Millisecond
defaultLoginMaxPause = 1 * time.Second
defaultLoginRetryPause = 5 * time.Second
defaultConnMinPause = 0 * time.Millisecond
defaultConnMaxPause = 5 * time.Second
defaultConnLookbackWindow = 1 * time.Second
defaultConnLowerThreshold = 1000 // connections per second
defaultConnUpperThreshold = 100000 // connections per second
defaultLoginRateLimit = 10 // concurrent login operations
defaultLoginMinPause = 100 * time.Millisecond
defaultLoginMaxPause = 1 * time.Second
defaultLoginRetryPause = 5 * time.Second
defaultConnMinPause = 0 * time.Millisecond
defaultConnMaxPause = 5 * time.Second
defaultConnLookbackWindow = 1 * time.Second
defaultConnLowerThreshold = 1000 // connections per second
defaultConnUpperThreshold = 100000 // connections per second
defaultLogSinkRateLimitBurst = 1000
defaultLogSinkRateLimitRefill = time.Millisecond
)

// Server holds the server side of the API.
type Server struct {
tomb tomb.Tomb
clock clock.Clock
pingClock clock.Clock
wg sync.WaitGroup
state *state.State
statePool *state.StatePool
lis net.Listener
tag names.Tag
dataDir string
logDir string
limiter utils.Limiter
loginRetryPause time.Duration
validator LoginValidator
facades *facade.Registry
modelUUID string
authCtxt *authContext
lastConnectionID uint64
centralHub *pubsub.StructuredHub
newObserver observer.ObserverFactory
connCount int64
totalConn int64
loginAttempts int64
certChanged <-chan params.StateServingInfo
tlsConfig *tls.Config
allowModelAccess bool
logSinkWriter io.WriteCloser
dbloggers dbloggers
tomb tomb.Tomb
clock clock.Clock
pingClock clock.Clock
wg sync.WaitGroup
state *state.State
statePool *state.StatePool
lis net.Listener
tag names.Tag
dataDir string
logDir string
limiter utils.Limiter
loginRetryPause time.Duration
validator LoginValidator
facades *facade.Registry
modelUUID string
authCtxt *authContext
lastConnectionID uint64
centralHub *pubsub.StructuredHub
newObserver observer.ObserverFactory
connCount int64
totalConn int64
loginAttempts int64
certChanged <-chan params.StateServingInfo
tlsConfig *tls.Config
allowModelAccess bool
logSinkWriter io.WriteCloser
logsinkRateLimitConfig logsink.RateLimitConfig
dbloggers dbloggers

// mu guards the fields below it.
mu sync.Mutex
Expand Down Expand Up @@ -175,7 +178,8 @@ type ServerConfig struct {
PrometheusRegisterer prometheus.Registerer
}

func (c *ServerConfig) Validate() error {
// Validate validates the API server configuration.
func (c ServerConfig) Validate() error {
if c.Hub == nil {
return errors.NotValidf("missing Hub")
}
Expand All @@ -199,7 +203,7 @@ func (c *ServerConfig) Validate() error {
return nil
}

func (c *ServerConfig) pingClock() clock.Clock {
func (c ServerConfig) pingClock() clock.Clock {
if c.PingClock == nil {
return c.Clock
}
Expand Down Expand Up @@ -273,6 +277,14 @@ type LogSinkConfig struct {
// DBLoggerFlushInterval is the amount of time to allow a log record
// to sit in the buffer before being flushed to the database.
DBLoggerFlushInterval time.Duration

// RateLimitBurst defines the number of log messages that will be let
// through before we start rate limiting.
RateLimitBurst int64

// RateLimitRefill defines the rate at which log messages will be let
// through once the initial burst amount has been depleted.
RateLimitRefill time.Duration
}

// Validate validates the logsink endpoint configuration.
Expand All @@ -283,6 +295,12 @@ func (cfg LogSinkConfig) Validate() error {
if cfg.DBLoggerFlushInterval <= 0 || cfg.DBLoggerFlushInterval > 10*time.Second {
return errors.NotValidf("DBLoggerFlushInterval %s <= 0 or > 10 seconds", cfg.DBLoggerFlushInterval)
}
if cfg.RateLimitBurst <= 0 {
return errors.NotValidf("RateLimitBurst %d <= 0", cfg.RateLimitBurst)
}
if cfg.RateLimitRefill <= 0 {
return errors.NotValidf("RateLimitRefill %s <= 0", cfg.RateLimitRefill)
}
return nil
}

Expand All @@ -291,6 +309,8 @@ func DefaultLogSinkConfig() LogSinkConfig {
return LogSinkConfig{
DBLoggerBufferSize: defaultDBLoggerBufferSize,
DBLoggerFlushInterval: defaultDBLoggerFlushInterval,
RateLimitBurst: defaultLogSinkRateLimitBurst,
RateLimitRefill: defaultLogSinkRateLimitRefill,
}
}

Expand Down Expand Up @@ -350,6 +370,11 @@ func newServer(s *state.State, lis net.Listener, cfg ServerConfig) (_ *Server, e
allowModelAccess: cfg.AllowModelAccess,
publicDNSName_: cfg.AutocertDNSName,
registerIntrospectionHandlers: cfg.RegisterIntrospectionHandlers,
logsinkRateLimitConfig: logsink.RateLimitConfig{
Refill: cfg.LogSinkConfig.RateLimitRefill,
Burst: cfg.LogSinkConfig.RateLimitBurst,
Clock: cfg.Clock,
},
dbloggers: dbloggers{
clock: cfg.Clock,
dbLoggerBufferSize: cfg.LogSinkConfig.DBLoggerBufferSize,
Expand Down Expand Up @@ -611,13 +636,15 @@ func (srv *Server) endpoints() []apihttp.Endpoint {
logSinkHandler := logsink.NewHTTPHandler(
newAgentLogWriteCloserFunc(httpCtxt, srv.logSinkWriter, &srv.dbloggers),
httpCtxt.stop(),
&srv.logsinkRateLimitConfig,
)
add("/model/:modeluuid/logsink", srv.trackRequests(logSinkHandler))

// We don't need to save the migrated logs to a logfile as well as to the DB.
logTransferHandler := logsink.NewHTTPHandler(
newMigrationLogWriteCloserFunc(httpCtxt, &srv.dbloggers),
httpCtxt.stop(),
nil, // no rate-limiting
)
add("/migrate/logtransfer", srv.trackRequests(logTransferHandler))

Expand Down
58 changes: 57 additions & 1 deletion apiserver/logsink/logsink.go
Expand Up @@ -11,6 +11,8 @@ import (
gorillaws "github.com/gorilla/websocket"
"github.com/juju/errors"
"github.com/juju/loggo"
"github.com/juju/ratelimit"
"github.com/juju/utils/clock"
"github.com/juju/utils/featureflag"
"github.com/juju/version"

Expand All @@ -34,21 +36,43 @@ type LogWriteCloser interface {
// NewLogWriteCloserFunc returns a new LogWriteCloser for the given http.Request.
type NewLogWriteCloserFunc func(*http.Request) (LogWriteCloser, error)

// RateLimitConfig contains the rate-limit configuration for the logsink
// handler.
type RateLimitConfig struct {
// Burst is the number of log messages that will be let through before
// we start rate limiting.
Burst int64

// Refill is the rate at which log messages will be let through once
// the initial burst amount has been depleted.
Refill time.Duration

// Clock is the clock used to wait when rate-limiting log receives.
Clock clock.Clock
}

// NewHTTPHandler returns a new http.Handler for receiving log messages over a
// websocket.
// websocket, using the given NewLogWriteCloserFunc to obtain a writer to which
// the log messages will be written.
//
// ratelimit defines an optional rate-limit configuration. If nil, no rate-
// limiting will be applied.
func NewHTTPHandler(
newLogWriteCloser NewLogWriteCloserFunc,
abort <-chan struct{},
ratelimit *RateLimitConfig,
) http.Handler {
return &logSinkHandler{
newLogWriteCloser: newLogWriteCloser,
abort: abort,
ratelimit: ratelimit,
}
}

type logSinkHandler struct {
newLogWriteCloser NewLogWriteCloserFunc
abort <-chan struct{}
ratelimit *RateLimitConfig
}

// Since the logsink only receives messages, it is possible for the other end
Expand Down Expand Up @@ -157,6 +181,15 @@ func (h *logSinkHandler) getVersion(req *http.Request) (int, error) {
func (h *logSinkHandler) receiveLogs(socket *websocket.Conn, endpointVersion int) <-chan params.LogRecord {
logCh := make(chan params.LogRecord)

var tokenBucket *ratelimit.Bucket
if h.ratelimit != nil {
tokenBucket = ratelimit.NewBucketWithClock(
h.ratelimit.Refill,
h.ratelimit.Burst,
ratelimitClock{h.ratelimit.Clock},
)
}

go func() {
// Close the channel to signal ServeHTTP to finish. Otherwise
// we leak goroutines on client disconnect, because the server
Expand All @@ -180,6 +213,19 @@ func (h *logSinkHandler) receiveLogs(socket *websocket.Conn, endpointVersion int
return
}

// Rate-limit receipt of log messages. We rate-limit
// each connection individually to prevent one noisy
// individual from drowning out the others.
if tokenBucket != nil {
if d := tokenBucket.Take(1); d > 0 {
select {
case <-h.ratelimit.Clock.After(d):
case <-h.abort:
return
}
}
}

// Send the log message.
select {
case <-h.abort:
Expand Down Expand Up @@ -223,3 +269,13 @@ func JujuClientVersionFromRequest(req *http.Request) (version.Number, error) {
}
return ver, nil
}

// ratelimitClock adapts clock.Clock to ratelimit.Clock.
type ratelimitClock struct {
clock.Clock
}

// Sleep is defined by the ratelimit.Clock interface.
func (c ratelimitClock) Sleep(d time.Duration) {
<-c.Clock.After(d)
}
65 changes: 65 additions & 0 deletions apiserver/logsink/logsink_test.go
Expand Up @@ -65,6 +65,7 @@ func (s *logsinkSuite) SetUpTest(c *gc.C) {
}, s.stub.NextErr()
},
s.abort,
nil, // no rate-limiting
))
s.AddCleanup(func(*gc.C) { s.srv.Close() })
}
Expand Down Expand Up @@ -168,6 +169,70 @@ func (s *logsinkSuite) TestReceiveErrorBreaksConn(c *gc.C) {
websockettest.AssertWebsocketClosed(c, conn)
}

func (s *logsinkSuite) TestRateLimit(c *gc.C) {
testClock := testing.NewClock(time.Time{})
s.srv.Close()
s.srv = httptest.NewServer(logsink.NewHTTPHandler(
func(req *http.Request) (logsink.LogWriteCloser, error) {
s.stub.AddCall("Open")
return &mockLogWriteCloser{
&s.stub,
s.written,
}, s.stub.NextErr()
},
s.abort,
&logsink.RateLimitConfig{
Burst: 2,
Refill: time.Second,
Clock: testClock,
},
))

conn := s.dialWebsocket(c)
websockettest.AssertJSONInitialErrorNil(c, conn)

record := params.LogRecord{
Time: time.Date(2015, time.June, 1, 23, 2, 1, 0, time.UTC),
Module: "some.where",
Location: "foo.go:42",
Level: loggo.INFO.String(),
Message: "all is well",
}
for i := 0; i < 4; i++ {
err := conn.WriteJSON(&record)
c.Assert(err, jc.ErrorIsNil)
}

expectRecord := func() {
select {
case written, ok := <-s.written:
c.Assert(ok, jc.IsTrue)
c.Assert(written, jc.DeepEquals, record)
case <-time.After(coretesting.LongWait):
c.Fatal("timed out waiting for log record to be written")
}
}
expectNoRecord := func() {
select {
case <-s.written:
c.Fatal("unexpected log record")
case <-time.After(coretesting.ShortWait):
}
}

// There should be 2 records received immediately,
// and then rate-limiting should kick in.
expectRecord()
expectRecord()
expectNoRecord()
testClock.WaitAdvance(time.Second, coretesting.LongWait, 1)
expectRecord()
expectNoRecord()
testClock.WaitAdvance(time.Second, coretesting.LongWait, 1)
expectRecord()
expectNoRecord()
}

type mockLogWriteCloser struct {
*testing.Stub
written chan<- params.LogRecord
Expand Down
8 changes: 8 additions & 0 deletions apiserver/logsink_test.go
Expand Up @@ -230,6 +230,14 @@ func (s *logsinkSuite) TestNewServerValidatesLogSinkConfig(c *gc.C) {
cfg.LogSinkConfig.DBLoggerFlushInterval = 30 * time.Second
_, err = apiserver.NewServer(s.State, dummyListener{}, cfg)
c.Assert(err, gc.ErrorMatches, "validating logsink configuration: DBLoggerFlushInterval 30s <= 0 or > 10 seconds not valid")

cfg.LogSinkConfig.DBLoggerFlushInterval = 10 * time.Second
_, err = apiserver.NewServer(s.State, dummyListener{}, cfg)
c.Assert(err, gc.ErrorMatches, "validating logsink configuration: RateLimitBurst 0 <= 0 not valid")

cfg.LogSinkConfig.RateLimitBurst = 1000
_, err = apiserver.NewServer(s.State, dummyListener{}, cfg)
c.Assert(err, gc.ErrorMatches, "validating logsink configuration: RateLimitRefill 0s <= 0 not valid")
}

func (s *logsinkSuite) dialWebsocket(c *gc.C) *websocket.Conn {
Expand Down