From 5e1b54c24eac3ddbace6a299035fd0c0d9370f7f Mon Sep 17 00:00:00 2001 From: Arno Geurts Date: Mon, 14 Sep 2020 12:02:30 +0200 Subject: [PATCH] Add WithContext methods to the Fluent struct including handling of context deadlines Signed-off-by: Arno Geurts --- fluent/fluent.go | 143 +++++++++++++++++++++++++++++++----------- fluent/fluent_test.go | 54 ++++++++++++++++ 2 files changed, 161 insertions(+), 36 deletions(-) diff --git a/fluent/fluent.go b/fluent/fluent.go index 2428672..5889abb 100644 --- a/fluent/fluent.go +++ b/fluent/fluent.go @@ -1,6 +1,7 @@ package fluent import ( + "context" "encoding/json" "errors" "fmt" @@ -78,17 +79,24 @@ func NewErrUnknownNetwork(network string) error { } type msgToSend struct { + ctx context.Context data []byte ack string } +type bufferInput struct { + msg *msgToSend + result chan<- error +} + type Fluent struct { Config dialer dialer stopRunning chan bool - pending chan *msgToSend + pending chan bufferInput wg sync.WaitGroup + resultPool sync.Pool muconn sync.Mutex conn net.Conn @@ -108,6 +116,10 @@ type dialer interface { Dial(string, string) (net.Conn, error) } +type dialerWithContext interface { + DialContext(context.Context, string, string) (net.Conn, error) +} + func newWithDialer(config Config, d dialer) (f *Fluent, err error) { if config.FluentNetwork == "" { config.FluentNetwork = defaultNetwork @@ -140,22 +152,24 @@ func newWithDialer(config Config, d dialer) (f *Fluent, err error) { fmt.Fprintf(os.Stderr, "fluent#New: AsyncConnect is now deprecated, please use Async instead") config.Async = config.Async || config.AsyncConnect } - - if config.Async { - f = &Fluent{ - Config: config, - dialer: d, - pending: make(chan *msgToSend, config.BufferLimit), - } - f.wg.Add(1) - go f.run() - } else { - f = &Fluent{ - Config: config, - dialer: d, + f = &Fluent{ + Config: config, + dialer: d, + pending: make(chan bufferInput, config.BufferLimit), + resultPool: sync.Pool{ + New: func() interface{} { + return make(chan error, 1) + }, + }, + } + if !config.Async { + if err = f.connect(context.Background()); err != nil { + return } - err = f.connect() } + + f.wg.Add(1) + go f.run() return } @@ -185,17 +199,25 @@ func newWithDialer(config Config, d dialer) (f *Fluent, err error) { // f.Post("tag_name", structData) // func (f *Fluent) Post(tag string, message interface{}) error { + return f.PostWithContext(context.Background(), tag, message) +} + +func (f *Fluent) PostWithContext(ctx context.Context, tag string, message interface{}) error { timeNow := time.Now() - return f.PostWithTime(tag, timeNow, message) + return f.PostWithTimeAndContext(ctx, tag, timeNow, message) } func (f *Fluent) PostWithTime(tag string, tm time.Time, message interface{}) error { + return f.PostWithTimeAndContext(context.Background(), tag, tm, message) +} + +func (f *Fluent) PostWithTimeAndContext(ctx context.Context, tag string, tm time.Time, message interface{}) error { if len(f.TagPrefix) > 0 { tag = f.TagPrefix + "." + tag } if m, ok := message.(msgp.Marshaler); ok { - return f.EncodeAndPostData(tag, tm, m) + return f.EncodeAndPostDataWithContext(ctx, tag, tm, m) } msg := reflect.ValueOf(message) @@ -215,7 +237,7 @@ func (f *Fluent) PostWithTime(tag string, tm time.Time, message interface{}) err } kv[name] = msg.FieldByIndex(field.Index).Interface() } - return f.EncodeAndPostData(tag, tm, kv) + return f.EncodeAndPostDataWithContext(ctx, tag, tm, kv) } if msgtype.Kind() != reflect.Map { @@ -229,13 +251,17 @@ func (f *Fluent) PostWithTime(tag string, tm time.Time, message interface{}) err kv[k.String()] = msg.MapIndex(k).Interface() } - return f.EncodeAndPostData(tag, tm, kv) + return f.EncodeAndPostDataWithContext(ctx, tag, tm, kv) } func (f *Fluent) EncodeAndPostData(tag string, tm time.Time, message interface{}) error { + return f.EncodeAndPostDataWithContext(context.Background(), tag, tm, message) +} + +func (f *Fluent) EncodeAndPostDataWithContext(ctx context.Context, tag string, tm time.Time, message interface{}) error { var msg *msgToSend var err error - if msg, err = f.EncodeData(tag, tm, message); err != nil { + if msg, err = f.EncodeDataWithContext(ctx, tag, tm, message); err != nil { return fmt.Errorf("fluent#EncodeAndPostData: can't convert '%#v' to msgpack:%v", message, err) } return f.postRawData(msg) @@ -251,7 +277,7 @@ func (f *Fluent) postRawData(msg *msgToSend) error { return f.appendBuffer(msg) } // Synchronous write - return f.write(msg) + return f.appendBufferBlocking(msg) } // For sending forward protocol adopted JSON @@ -296,8 +322,12 @@ func getUniqueID(timeUnix int64) (string, error) { } func (f *Fluent) EncodeData(tag string, tm time.Time, message interface{}) (msg *msgToSend, err error) { + return f.EncodeDataWithContext(context.Background(), tag, tm, message) +} + +func (f *Fluent) EncodeDataWithContext(ctx context.Context, tag string, tm time.Time, message interface{}) (msg *msgToSend, err error) { option := make(map[string]string) - msg = &msgToSend{} + msg = &msgToSend{ctx: ctx} timeUnix := tm.Unix() if f.Config.RequestAck { var err error @@ -338,13 +368,37 @@ func (f *Fluent) Close() (err error) { // appendBuffer appends data to buffer with lock. func (f *Fluent) appendBuffer(msg *msgToSend) error { select { - case f.pending <- msg: + case f.pending <- bufferInput{msg: msg}: default: return fmt.Errorf("fluent#appendBuffer: Buffer full, limit %v", f.Config.BufferLimit) } return nil } +// appendBufferWithFeedback appends data to buffer and waits for the result +func (f *Fluent) appendBufferBlocking(msg *msgToSend) error { + result := f.resultPool.Get().(chan error) + // write the data to the buffer and block if the buffer is full + select { + case f.pending <- bufferInput{msg: msg, result: result}: + // don't do anything + case <-msg.ctx.Done(): + // because the result channel is not used, it can safely be returned to the sync pool. + f.resultPool.Put(result) + return msg.ctx.Err() + } + + select { + case err := <-result: + f.resultPool.Put(result) + return err + case <-msg.ctx.Done(): + // the context deadline has exceeded, but there is no result yet. So the result channel cannot be returned to + // the pool, as it might be written later. + return msg.ctx.Err() + } +} + // close closes the connection. func (f *Fluent) close(c net.Conn) { f.muconn.Lock() @@ -356,19 +410,23 @@ func (f *Fluent) close(c net.Conn) { } // connect establishes a new connection using the specified transport. -func (f *Fluent) connect() (err error) { +func (f *Fluent) connect(ctx context.Context) (err error) { + var address string switch f.Config.FluentNetwork { case "tcp": - f.conn, err = f.dialer.Dial( - f.Config.FluentNetwork, - f.Config.FluentHost+":"+strconv.Itoa(f.Config.FluentPort)) + address = f.Config.FluentHost + ":" + strconv.Itoa(f.Config.FluentPort) case "unix": - f.conn, err = f.dialer.Dial( - f.Config.FluentNetwork, - f.Config.FluentSocketPath) + address = f.Config.FluentSocketPath default: err = NewErrUnknownNetwork(f.Config.FluentNetwork) + return } + if d, ok := f.dialer.(dialerWithContext); ok { + f.conn, err = d.DialContext(ctx, f.Config.FluentNetwork, address) + } else { + f.conn, err = f.dialer.Dial(f.Config.FluentNetwork, address) + } + return err } @@ -386,7 +444,11 @@ func (f *Fluent) run() { emitEventDrainMsg.Do(func() { fmt.Fprintf(os.Stderr, "[%s] Discarding queued events...\n", time.Now().Format(time.RFC3339)) }) continue } - err := f.write(entry) + err := f.write(entry.msg) + if entry.result != nil { + entry.result <- err + continue + } if err != nil { fmt.Fprintf(os.Stderr, "[%s] Unable to send logs to fluentd, reconnecting...\n", time.Now().Format(time.RFC3339)) } @@ -413,7 +475,7 @@ func (f *Fluent) write(msg *msgToSend) error { if c == nil { f.muconn.Lock() if f.conn == nil { - err := f.connect() + err := f.connect(msg.ctx) if err != nil { f.muconn.Unlock() @@ -425,7 +487,13 @@ func (f *Fluent) write(msg *msgToSend) error { if waitTime > f.Config.MaxRetryWait { waitTime = f.Config.MaxRetryWait } - time.Sleep(time.Duration(waitTime) * time.Millisecond) + waitDuration := time.Duration(waitTime) * time.Millisecond + if deadline, hasDeadLine := msg.ctx.Deadline(); hasDeadLine && deadline.Before(time.Now().Add(waitDuration)) { + // the context deadline is within the wait time, so after the sleep the deadline will have been + // exceeded. It is a waste of time to wait on that. + return context.DeadlineExceeded + } + time.Sleep(waitDuration) continue } } @@ -435,11 +503,14 @@ func (f *Fluent) write(msg *msgToSend) error { // We're connected, write msg t := f.Config.WriteTimeout + var deadline time.Time if time.Duration(0) < t { - c.SetWriteDeadline(time.Now().Add(t)) - } else { - c.SetWriteDeadline(time.Time{}) + deadline = time.Now().Add(t) + } + if ctxDeadline, hasDeadline := msg.ctx.Deadline(); hasDeadline && (deadline.IsZero() || ctxDeadline.Before(deadline)) { + deadline = ctxDeadline } + c.SetWriteDeadline(deadline) _, err := c.Write(msg.data) if err != nil { f.close(c) diff --git a/fluent/fluent_test.go b/fluent/fluent_test.go index 6db33e5..2d87fe0 100644 --- a/fluent/fluent_test.go +++ b/fluent/fluent_test.go @@ -2,6 +2,7 @@ package fluent import ( "bytes" + "context" "encoding/json" "errors" "io/ioutil" @@ -449,6 +450,59 @@ func TestPostWithTime(t *testing.T) { } } +func TestPostWithTimeAndContext(t *testing.T) { + testcases := map[string]Config{ + "with Async": { + Async: true, + MarshalAsJSON: true, + TagPrefix: "acme", + }, + "without Async": { + Async: false, + MarshalAsJSON: true, + TagPrefix: "acme", + }, + } + + for tcname := range testcases { + t.Run(tcname, func(t *testing.T) { + tc := testcases[tcname] + t.Parallel() + + d := newTestDialer() + var f *Fluent + defer func() { + if f != nil { + f.Close() + } + }() + deadline := time.Now().Add(1 * time.Second) + + go func() { + var err error + if f, err = newWithDialer(tc, d); err != nil { + t.Errorf("Unexpected error: %v", err) + } + ctx, cancelFunc := context.WithDeadline(context.Background(), deadline) + defer cancelFunc() + + _ = f.PostWithTimeAndContext(ctx, "tag_name", time.Unix(1482493046, 0), map[string]string{"foo": "bar"}) + _ = f.PostWithTimeAndContext(ctx, "tag_name", time.Unix(1482493050, 0), map[string]string{"fluentd": "is awesome"}) + }() + + conn := d.waitForNextDialing(true) + assertReceived(t, + conn.waitForNextWrite(true, ""), + "[\"acme.tag_name\",1482493046,{\"foo\":\"bar\"},{}]") + + assertReceived(t, + conn.waitForNextWrite(true, ""), + "[\"acme.tag_name\",1482493050,{\"fluentd\":\"is awesome\"},{}]") + assert.Equal(t, conn.writeDeadline, deadline) + }) + } +} + func TestReconnectAndResendAfterTransientFailure(t *testing.T) { testcases := map[string]Config{ "with Async": {