Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a race condition in the TCP input #13038

Merged
merged 7 commits into from
Jul 25, 2019
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.next.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ https://github.com/elastic/beats/compare/v7.0.0-alpha2...master[Check the HEAD d

- Add read_buffer configuration option. {pull}11739[11739]
- `convert_timezone` option is removed and locale is always added to the event so timezone is used when parsing the timestamp, this behaviour can be overriden with processors. {pull}12410[12410]
- Fix a race condition in the TCP input when close the client socket. {pull}13038[13038]

*Heartbeat*

Expand Down
28 changes: 12 additions & 16 deletions filebeat/inputsource/tcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import (

// splitHandler is a TCP client that has splitting capabilities.
type splitHandler struct {
conn net.Conn
callback inputsource.NetworkFunc
done chan struct{}
metadata inputsource.NetworkMetadata
Expand All @@ -42,19 +41,23 @@ type splitHandler struct {
timeout time.Duration
}

// ClientFactory returns a ConnectionHandler func
type ClientFactory func(config Config) ConnectionHandler
// HandlerFactory returns a ConnectionHandler func
type HandlerFactory func(config Config) ConnectionHandler

// ConnectionHandler interface provides mechanisms for handling of incoming TCP connections
type ConnectionHandler interface {
Handle(conn net.Conn) error
Close()
Handle(closeRef, net.Conn) error
}

// SplitHandlerFactory allows creation of a ConnectionHandler that can do splitting of messages received on a TCP connection.
func SplitHandlerFactory(callback inputsource.NetworkFunc, splitFunc bufio.SplitFunc) ClientFactory {
func SplitHandlerFactory(callback inputsource.NetworkFunc, splitFunc bufio.SplitFunc) HandlerFactory {
return func(config Config) ConnectionHandler {
return newSplitHandler(callback, splitFunc, uint64(config.MaxMessageSize), config.Timeout)
return newSplitHandler(
callback,
splitFunc,
uint64(config.MaxMessageSize),
config.Timeout,
)
}
}

Expand All @@ -76,8 +79,7 @@ func newSplitHandler(
}

// Handle takes a connection as input and processes data received on it.
func (c *splitHandler) Handle(conn net.Conn) error {
c.conn = conn
func (c *splitHandler) Handle(closer closeRef, conn net.Conn) error {
c.metadata = inputsource.NetworkMetadata{
RemoteAddr: conn.RemoteAddr(),
TLS: extractSSLInformation(conn),
Expand All @@ -97,7 +99,7 @@ func (c *splitHandler) Handle(conn net.Conn) error {
if err != nil {
// we are forcing a Close on the socket, lets ignore any error that could happen.
select {
case <-c.done:
case <-closer.Done():
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the splitter returns an error due to us closing the connection, shall we supress the error then and just return nil?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is currently what I do I suppress errors in that scenario, in the beats input, I've decided otherwise and it created a lot of confusion for the users.

break
default:
}
Expand All @@ -121,12 +123,6 @@ func (c *splitHandler) Handle(conn net.Conn) error {
return nil
}

// Close is used to perform clean up before the client is released.
func (c *splitHandler) Close() {
close(c.done)
c.conn.Close()
}

func extractSSLInformation(c net.Conn) *inputsource.TLSMetadata {
if tls, ok := c.(*tls.Conn); ok {
state := tls.ConnectionState()
Expand Down
164 changes: 113 additions & 51 deletions filebeat/inputsource/tcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,45 +27,50 @@ import (

"golang.org/x/net/netutil"

"github.com/pkg/errors"

"github.com/elastic/beats/libbeat/common/atomic"
"github.com/elastic/beats/libbeat/common/transport/tlscommon"
"github.com/elastic/beats/libbeat/logp"
"github.com/elastic/beats/libbeat/outputs/transport"
)

var errClosed = errors.New("connection closed")

// Server represent a TCP server
type Server struct {
sync.RWMutex
config *Config
Listener net.Listener
clients map[ConnectionHandler]struct{}
wg sync.WaitGroup
done chan struct{}
factory ClientFactory
log *logp.Logger
tlsConfig *transport.TLSConfig
config *Config
Listener net.Listener
wg sync.WaitGroup
done chan struct{}
factory HandlerFactory
log *logp.Logger
tlsConfig *transport.TLSConfig
closer *closer
clientsCount atomic.Int
}

// New creates a new tcp server
func New(
config *Config,
factory ClientFactory,
factory HandlerFactory,
) (*Server, error) {
tlsConfig, err := tlscommon.LoadTLSServerConfig(config.TLS)
if err != nil {
return nil, err
}

if factory == nil {
return nil, fmt.Errorf("ClientFactory can't be empty")
return nil, fmt.Errorf("HandlerFactory can't be empty")
}

return &Server{
config: config,
clients: make(map[ConnectionHandler]struct{}, 0),
done: make(chan struct{}),
factory: factory,
log: logp.NewLogger("tcp").With("address", config.Host),
tlsConfig: tlsConfig,
closer: &closer{done: make(chan struct{})},
}, nil
}

Expand All @@ -77,6 +82,7 @@ func (s *Server) Start() error {
return err
}

s.closer.callback = func() { s.Listener.Close() }
s.log.Info("Started listening for TCP connection")

s.wg.Add(1)
Expand All @@ -97,27 +103,28 @@ func (s *Server) run() {
conn, err := s.Listener.Accept()
if err != nil {
select {
case <-s.done:
case <-s.closer.Done():
return
default:
s.log.Debugw("Can not accept the connection", "error", err)
continue
}
}

client := s.factory(*s.config)
handler := s.factory(*s.config)
closer := withCloser(s.closer, conn)

s.wg.Add(1)
go func() {
defer logp.Recover("recovering from a tcp client crash")
defer s.wg.Done()
defer conn.Close()
defer closer.Close()

s.registerClient(client)
defer s.unregisterClient(client)
s.log.Debugw("New client", "remote_address", conn.RemoteAddr(), "total", s.clientsCount())
s.registerHandler()
defer s.unregisterHandler()
s.log.Debugw("New client", "remote_address", conn.RemoteAddr(), "total", s.clientsCount.Load())

err := client.Handle(conn)
err := handler.Handle(closer, conn)
if err != nil {
s.log.Debugw("client error", "error", err)
}
Expand All @@ -127,7 +134,7 @@ func (s *Server) run() {
"remote_address",
conn.RemoteAddr(),
"total",
s.clientsCount(),
s.clientsCount.Load(),
)
}()
}
Expand All @@ -136,37 +143,17 @@ func (s *Server) run() {
// Stop stops accepting new incoming TCP connection and Close any active clients
func (s *Server) Stop() {
s.log.Info("Stopping TCP server")
close(s.done)
s.Listener.Close()
for _, client := range s.allClients() {
client.Close()
}
s.closer.Close()
s.wg.Wait()
s.log.Info("TCP server stopped")
}

func (s *Server) registerClient(client ConnectionHandler) {
s.Lock()
defer s.Unlock()
s.clients[client] = struct{}{}
func (s *Server) registerHandler() {
s.clientsCount.Inc()
}

func (s *Server) unregisterClient(client ConnectionHandler) {
s.Lock()
defer s.Unlock()
delete(s.clients, client)
}

func (s *Server) allClients() []ConnectionHandler {
s.RLock()
defer s.RUnlock()
currentClients := make([]ConnectionHandler, len(s.clients))
idx := 0
for client := range s.clients {
currentClients[idx] = client
idx++
}
return currentClients
func (s *Server) unregisterHandler() {
s.clientsCount.Dec()
}

func (s *Server) createServer() (net.Listener, error) {
Expand All @@ -192,12 +179,6 @@ func (s *Server) createServer() (net.Listener, error) {
return l, nil
}

func (s *Server) clientsCount() int {
s.RLock()
defer s.RUnlock()
return len(s.clients)
}

// SplitFunc allows to create a `bufio.SplitFunc` based on a delimiter provided.
func SplitFunc(lineDelimiter []byte) bufio.SplitFunc {
ld := []byte(lineDelimiter)
Expand All @@ -209,3 +190,84 @@ func SplitFunc(lineDelimiter []byte) bufio.SplitFunc {
}
return factoryDelimiter(ld)
}

func withCloser(parent *closer, conn net.Conn) *closer {
child := &closer{
done: make(chan struct{}),
parent: parent,
callback: func() {
conn.Close()
},
}
parent.addChild(child)
return child
}

type closeRef interface {
ph marked this conversation as resolved.
Show resolved Hide resolved
Done() <-chan struct{}
Err() error
}

type closer struct {
mu sync.Mutex
done chan struct{}
err error
parent *closer
children map[*closer]struct{}
callback func()
}

func (c *closer) Close() {
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return
}

if c.callback != nil {
c.callback()
}

close(c.done)

// propagate close to children.
if c.children != nil {
for child := range c.children {
child.Close()
}
c.children = nil
}

c.err = errClosed
c.mu.Unlock()

if c.parent != nil {
c.removeChild(c)
}
}

func (c *closer) Done() <-chan struct{} {
return c.done
}

func (c *closer) Err() error {
c.mu.Lock()
err := c.err
c.mu.Unlock()
return err
}

func (c *closer) removeChild(child *closer) {
c.mu.Lock()
delete(c.children, child)
c.mu.Unlock()
}

func (c *closer) addChild(child *closer) {
c.mu.Lock()
if c.children == nil {
c.children = make(map[*closer]struct{})
}
c.children[child] = struct{}{}
c.mu.Unlock()
}
3 changes: 2 additions & 1 deletion filebeat/inputsource/tcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

"github.com/dustin/go-humanize"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/elastic/beats/filebeat/inputsource"
"github.com/elastic/beats/libbeat/common"
Expand Down Expand Up @@ -180,7 +181,7 @@ func TestReceiveEventsAndMetadata(t *testing.T) {
defer server.Stop()

conn, err := net.Dial("tcp", server.Listener.Addr().String())
assert.NoError(t, err)
require.NoError(t, err)
fmt.Fprint(conn, test.messageSent)
conn.Close()

Expand Down