Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent hanging on TCP connections that won't return [any/full] OPC UA data #628

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func NewDialer(cfg *Config) *uacp.Dialer {
// ApplyConfig applies the config options to the default configuration.
// todo(fs): Can we find a better name?
//
// Note: Starting with v0.5 this function will will return an error.
// Note: Starting with v0.5 this function will return an error.
func ApplyConfig(opts ...Option) *Config {
cfg := &Config{
sechan: DefaultClientConfig(),
Expand Down Expand Up @@ -501,6 +501,22 @@ func DialTimeout(d time.Duration) Option {
}
}

// ReadTimeout sets the timeout for every read operation.
func ReadTimeout(d time.Duration) Option {
return func(cfg *Config) {
initDialer(cfg)
cfg.dialer.ReadTimeout = d
}
}

// WriteTimeout sets the timeout for every write operation.
func WriteTimeout(d time.Duration) Option {
return func(cfg *Config) {
initDialer(cfg)
cfg.dialer.WriteTimeout = d
}
}

// MaxMessageSize sets the maximum message size for the UACP handshake.
func MaxMessageSize(n uint32) Option {
return func(cfg *Config) {
Expand Down
53 changes: 46 additions & 7 deletions uacp/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ type Dialer struct {
// ClientACK defines the connection parameters requested by the client.
// Defaults to DefaultClientACK.
ClientACK *Acknowledge

// ReadTimeout sets a read timeout for reading a full response from the
// underlying network connection. ReadTimeout is ignored if it is <= 0.
ReadTimeout time.Duration

// WriteTimeout sets a write timeout for sending a request on the
// underlying network connection. WriteTimeout is ignored if it is <= 0.
WriteTimeout time.Duration
}

func (d *Dialer) Dial(ctx context.Context, endpoint string) (*Conn, error) {
Expand All @@ -88,6 +96,8 @@ func (d *Dialer) Dial(ctx context.Context, endpoint string) (*Conn, error) {
c.Close()
return nil, err
}
conn.readTimeout = d.ReadTimeout
conn.writeTimeout = d.WriteTimeout

debug.Printf("uacp %d: start HEL/ACK handshake", conn.id)
if err := conn.Handshake(ctx, endpoint); err != nil {
Expand Down Expand Up @@ -174,7 +184,9 @@ type Conn struct {
id uint32
ack *Acknowledge

closeOnce sync.Once
closeOnce sync.Once
readTimeout time.Duration
writeTimeout time.Duration
}

func NewConn(c *net.TCPConn, ack *Acknowledge) (*Conn, error) {
Expand Down Expand Up @@ -351,15 +363,25 @@ const hdrlen = 8
// The size of b must be at least ReceiveBufSize. Otherwise,
// the function returns an error.
func (c *Conn) Receive() ([]byte, error) {
// TODO(kung-foo): allow user-specified buffer
// TODO(kung-foo): sync.Pool
// todo(kung-foo): allow user-specified buffer
// todo(kung-foo): sync.Pool
b := make([]byte, c.ack.ReceiveBufSize)

if _, err := io.ReadFull(c, b[:hdrlen]); err != nil {
if c.readTimeout > 0 {
if err := c.SetReadDeadline(time.Now().Add(c.readTimeout)); err != nil {
return nil, errors.Errorf("uacp: failed to set read timeout: %w", err)
}
}

n, err := c.Read(b[:hdrlen])
if err != nil {
// todo(fs): do not wrap this error since it hides io.EOF
// todo(fs): use golang.org/x/xerrors
return nil, err
}
if n != hdrlen {
return nil, errors.Errorf("uacp: short read on header. got %d bytes, want %d ", n, hdrlen)
}

var h Header
if _, err := h.Decode(b[:hdrlen]); err != nil {
Expand All @@ -370,18 +392,26 @@ func (c *Conn) Receive() ([]byte, error) {
return nil, errors.Errorf("uacp: message too large: %d > %d bytes", h.MessageSize, c.ack.ReceiveBufSize)
}

if _, err := io.ReadFull(c, b[hdrlen:h.MessageSize]); err != nil {
n, err = c.Read(b[hdrlen:h.MessageSize])
if err != nil {
// todo(fs): do not wrap this error since it hides io.EOF
// todo(fs): use golang.org/x/xerrors
return nil, err
}

// clear the deadline
c.SetReadDeadline(time.Time{})

if uint32(n) != h.MessageSize-hdrlen {
return nil, fmt.Errorf("uacp %d: short read on message. got %d bytes, want %d", c.id, n, h.MessageSize-hdrlen)
}

debug.Printf("uacp %d: recv %s%c with %d bytes", c.id, h.MessageType, h.ChunkType, h.MessageSize)

if h.MessageType == "ERR" {
errf := new(Error)
if _, err := errf.Decode(b[hdrlen:h.MessageSize]); err != nil {
return nil, errors.Errorf("uacp: failed to decode ERRF message: %s", err)
return nil, errors.Errorf("uacp: failed to decode ERRF message: %w", err)
}
return nil, errf
}
Expand All @@ -395,7 +425,7 @@ func (c *Conn) Send(typ string, msg interface{}) error {

body, err := ua.Encode(msg)
if err != nil {
return errors.Errorf("encode msg failed: %s", err)
return errors.Errorf("encode msg failed: %w", err)
}

h := Header{
Expand All @@ -413,12 +443,21 @@ func (c *Conn) Send(typ string, msg interface{}) error {
return errors.Errorf("encode hdr failed: %s", err)
}

if c.writeTimeout > 0 {
if err := c.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil {
return errors.Errorf("failed to set write timeout: %w", err)
}
}

b := append(hdr, body...)
if _, err := c.Write(b); err != nil {
return errors.Errorf("write failed: %s", err)
}
debug.Printf("uacp %d: sent %s with %d bytes", c.id, typ, len(b))

// clear the deadline
c.SetWriteDeadline(time.Time{})

return nil
}

Expand Down