diff --git a/client.go b/client.go index b03145cd..db11c80d 100644 --- a/client.go +++ b/client.go @@ -160,6 +160,9 @@ func (c *Client) Dial(ctx context.Context) error { ctx = context.Background() } + ctx, cancel := context.WithTimeout(ctx, c.cfg.DialTimeout) + defer cancel() + c.once.Do(func() { c.session.Store((*Session)(nil)) }) if c.sechan != nil { return errors.Errorf("secure channel already connected") diff --git a/config.go b/config.go index 2deeaf6f..24dc6374 100644 --- a/config.go +++ b/config.go @@ -28,6 +28,7 @@ func DefaultClientConfig() *uasc.Config { SecurityMode: ua.MessageSecurityModeNone, Lifetime: uint32(time.Hour / time.Millisecond), RequestTimeout: 10 * time.Second, + DialTimeout: 10 * time.Second, } } @@ -416,3 +417,10 @@ func RequestTimeout(t time.Duration) Option { c.RequestTimeout = t } } + +// DialTimeout sets the timeout for name resolution and establishment of a network connection +func DialTimeout(t time.Duration) Option { + return func(c *uasc.Config, sc *uasc.SessionConfig) { + c.DialTimeout = t + } +} diff --git a/examples/server/server.go b/examples/server/server.go index 83bd0ac5..f833c8c4 100644 --- a/examples/server/server.go +++ b/examples/server/server.go @@ -21,7 +21,7 @@ func main() { ctx := context.Background() log.Printf("Listening on %s", *endpoint) - l, err := uacp.Listen(*endpoint, nil) + l, err := uacp.Listen(ctx, *endpoint, nil) if err != nil { log.Fatal(err) } diff --git a/uacp/conn.go b/uacp/conn.go index 03b2798f..13db65b3 100644 --- a/uacp/conn.go +++ b/uacp/conn.go @@ -9,6 +9,7 @@ import ( "io" "net" "sync/atomic" + "time" "github.com/gopcua/opcua/debug" "github.com/gopcua/opcua/errors" @@ -23,6 +24,7 @@ const ( DefaultSendBufSize = 0xffff DefaultMaxChunkCount = 512 DefaultMaxMessageSize = 2 * MB + DefaultDialTimeout = time.Second * 10 ) // connid stores the current connection id. updated with atomic.AddUint32 @@ -35,12 +37,15 @@ func nextid() uint32 { func Dial(ctx context.Context, endpoint string) (*Conn, error) { debug.Printf("Connect to %s", endpoint) - _, raddr, err := ResolveEndpoint(endpoint) + + _, raddr, err := ResolveEndpoint(ctx, endpoint) if err != nil { return nil, err } + var dialer net.Dialer - c, err := dialer.DialContext(ctx, "tcp", raddr.String()) + + c, err := dialer.DialContext(ctx, "tcp", raddr.Host) if err != nil { return nil, err } @@ -79,7 +84,7 @@ type Listener struct { // If the IP field of laddr is nil or an unspecified IP address, Listen listens // on all available unicast and anycast IP addresses of the local system. // If the Port field of laddr is 0, a port number is automatically chosen. -func Listen(endpoint string, ack *Acknowledge) (*Listener, error) { +func Listen(ctx context.Context, endpoint string, ack *Acknowledge) (*Listener, error) { if ack == nil { ack = &Acknowledge{ ReceiveBufSize: DefaultReceiveBufSize, @@ -89,16 +94,18 @@ func Listen(endpoint string, ack *Acknowledge) (*Listener, error) { } } - network, laddr, err := ResolveEndpoint(endpoint) + _, laddr, err := ResolveEndpoint(ctx, endpoint) if err != nil { return nil, err } - l, err := net.ListenTCP(network, laddr) + + var lc net.ListenConfig + l, err := lc.Listen(ctx, "tcp", laddr.Host) if err != nil { return nil, err } return &Listener{ - l: l, + l: l.(*net.TCPListener), ack: ack, endpoint: endpoint, }, nil diff --git a/uacp/conn_test.go b/uacp/conn_test.go index decd211a..392002f5 100644 --- a/uacp/conn_test.go +++ b/uacp/conn_test.go @@ -16,7 +16,7 @@ import ( func TestConn(t *testing.T) { t.Run("server exists ", func(t *testing.T) { ep := "opc.tcp://127.0.0.1:4840/foo/bar" - ln, err := Listen(ep, nil) + ln, err := Listen(context.Background(), ep, nil) if err != nil { t.Fatal(err) } @@ -64,7 +64,7 @@ func TestConn(t *testing.T) { func TestClientWrite(t *testing.T) { ep := "opc.tcp://127.0.0.1:4840/foo/bar" - ln, err := Listen(ep, nil) + ln, err := Listen(context.Background(), ep, nil) if err != nil { t.Fatal(err) } @@ -127,7 +127,7 @@ NEXT: func TestServerWrite(t *testing.T) { ep := "opc.tcp://127.0.0.1:4840/foo/bar" - ln, err := Listen(ep, nil) + ln, err := Listen(context.Background(), ep, nil) if err != nil { t.Fatal(err) } diff --git a/uacp/endpoint.go b/uacp/endpoint.go index b1b58e98..95a7f004 100644 --- a/uacp/endpoint.go +++ b/uacp/endpoint.go @@ -5,31 +5,49 @@ package uacp import ( + "context" "net" - "strings" + "net/url" "github.com/gopcua/opcua/errors" ) -// ResolveEndpoint returns network type, address, and error splitted from EndpointURL. +const defaultPort = "4840" + +// ResolveEndpoint returns network type, address, and error split from EndpointURL. // // Expected format of input is "opc.tcp://