diff --git a/conn.go b/conn.go index 2da23bd..7efaa4f 100644 --- a/conn.go +++ b/conn.go @@ -21,6 +21,7 @@ var ErrClosedConn = errors.New("zmq4: read/write on closed connection") // Conn implements the ZeroMQ Message Transport Protocol as defined // in https://rfc.zeromq.org/spec:23/ZMTP/. type Conn struct { + ep string typ SocketType id SocketIdentity rw net.Conn @@ -64,7 +65,15 @@ func (c *Conn) Write(p []byte) (int, error) { // Open opens a ZMTP connection over rw with the given security, socket type and identity. // An optional onCloseErrorCB can be provided to inform the caller when this Conn is closed. // Open performs a complete ZMTP handshake. -func Open(rw net.Conn, sec Security, sockType SocketType, sockID SocketIdentity, server bool, onCloseErrorCB func(c *Conn)) (*Conn, error) { +func Open( + ep string, + rw net.Conn, + sec Security, + sockType SocketType, + sockID SocketIdentity, + server bool, + onCloseErrorCB func(c *Conn), +) (*Conn, error) { if rw == nil { return nil, fmt.Errorf("zmq4: invalid nil read-writer") } @@ -74,6 +83,7 @@ func Open(rw net.Conn, sec Security, sockType SocketType, sockID SocketIdentity, } conn := &Conn{ + ep: ep, typ: sockType, id: sockID, rw: rw, diff --git a/dealer.go b/dealer.go index 67ade32..d548902 100644 --- a/dealer.go +++ b/dealer.go @@ -54,6 +54,11 @@ func (dealer *dealerSocket) Dial(ep string) error { return dealer.sck.Dial(ep) } +// Dial connects a remote endpoint to the Socket. +func (dealer *dealerSocket) DialContext(ctx context.Context, ep string) error { + return dealer.sck.DialContext(ctx, ep) +} + // Type returns the type of this Socket (PUB, SUB, ...) func (dealer *dealerSocket) Type() SocketType { return dealer.sck.Type() diff --git a/msgio.go b/msgio.go index c9b868c..fbd2839 100644 --- a/msgio.go +++ b/msgio.go @@ -50,7 +50,7 @@ func newQReader(ctx context.Context) *qreader { } func (q *qreader) Close() error { - q.mu.RLock() + q.mu.Lock() var err error var grp errgroup.Group for i := range q.rs { @@ -58,7 +58,7 @@ func (q *qreader) Close() error { } err = grp.Wait() q.rs = nil - q.mu.RUnlock() + q.mu.Unlock() return err } diff --git a/pair.go b/pair.go index cfd691d..0e85baa 100644 --- a/pair.go +++ b/pair.go @@ -54,6 +54,11 @@ func (pair *pairSocket) Dial(ep string) error { return pair.sck.Dial(ep) } +// DialContext connects a remote endpoint to the Socket. +func (pair *pairSocket) DialContext(ctx context.Context, ep string) error { + return pair.sck.DialContext(ctx, ep) +} + // Type returns the type of this Socket (PUB, SUB, ...) func (pair *pairSocket) Type() SocketType { return pair.sck.Type() diff --git a/pub.go b/pub.go index ebdd05e..412d8d1 100644 --- a/pub.go +++ b/pub.go @@ -70,6 +70,11 @@ func (pub *pubSocket) Dial(ep string) error { return pub.sck.Dial(ep) } +// DialContext connects a remote endpoint to the Socket. +func (pub *pubSocket) DialContext(ctx context.Context, ep string) error { + return pub.sck.DialContext(ctx, ep) +} + // Type returns the type of this Socket (PUB, SUB, ...) func (pub *pubSocket) Type() SocketType { return pub.sck.Type() diff --git a/pull.go b/pull.go index d3e91f3..55ab10a 100644 --- a/pull.go +++ b/pull.go @@ -56,6 +56,11 @@ func (pull *pullSocket) Dial(ep string) error { return pull.sck.Dial(ep) } +// DialContext connects a remote endpoint to the Socket. +func (pull *pullSocket) DialContext(ctx context.Context, ep string) error { + return pull.sck.DialContext(ctx, ep) +} + // Type returns the type of this Socket (PUB, SUB, ...) func (pull *pullSocket) Type() SocketType { return pull.sck.Type() diff --git a/push.go b/push.go index 08a346a..0e88d15 100644 --- a/push.go +++ b/push.go @@ -56,6 +56,11 @@ func (push *pushSocket) Dial(ep string) error { return push.sck.Dial(ep) } +// DialContext connects a remote endpoint to the Socket. +func (push *pushSocket) DialContext(ctx context.Context, ep string) error { + return push.sck.DialContext(ctx, ep) +} + // Type returns the type of this Socket (PUB, SUB, ...) func (push *pushSocket) Type() SocketType { return push.sck.Type() diff --git a/rep.go b/rep.go index f251041..a89cad1 100644 --- a/rep.go +++ b/rep.go @@ -71,6 +71,11 @@ func (rep *repSocket) Dial(ep string) error { return rep.sck.Dial(ep) } +// DialContext connects a remote endpoint to the Socket. +func (rep *repSocket) DialContext(ctx context.Context, ep string) error { + return rep.sck.DialContext(ctx, ep) +} + // Type returns the type of this Socket (PUB, SUB, ...) func (rep *repSocket) Type() SocketType { return rep.sck.Type() diff --git a/req.go b/req.go index d0de956..af8df70 100644 --- a/req.go +++ b/req.go @@ -69,6 +69,11 @@ func (req *reqSocket) Dial(ep string) error { return req.sck.Dial(ep) } +// DialContext connects a remote endpoint to the Socket. +func (req *reqSocket) DialContext(ctx context.Context, ep string) error { + return req.sck.DialContext(ctx, ep) +} + // Type returns the type of this Socket (PUB, SUB, ...) func (req *reqSocket) Type() SocketType { return req.sck.Type() diff --git a/router.go b/router.go index b0175c9..458f39c 100644 --- a/router.go +++ b/router.go @@ -63,6 +63,11 @@ func (router *routerSocket) Dial(ep string) error { return router.sck.Dial(ep) } +// DialContext connects a remote endpoint to the Socket. +func (router *routerSocket) DialContext(ctx context.Context, ep string) error { + return router.sck.DialContext(ctx, ep) +} + // Type returns the type of this Socket (PUB, SUB, ...) func (router *routerSocket) Type() SocketType { return router.sck.Type() diff --git a/security_test.go b/security_test.go index f293b1d..e171b43 100644 --- a/security_test.go +++ b/security_test.go @@ -87,7 +87,8 @@ func TestNullHandshakeReqRep(t *testing.T) { }) grp.Go(func() error { - err := req.Dial(ep) + // err := req.Dial(ep) + err := req.DialContext(ctx, ep) if err != nil { return fmt.Errorf("could not dial: %w", err) } @@ -104,6 +105,38 @@ func TestNullHandshakeReqRep(t *testing.T) { } } +func TestNullHandshakeRRFail(t *testing.T) { + + sec := nullSecurity{} + ctx, timeout := context.WithTimeout(context.Background(), 1*time.Second) + defer timeout() + + ep := "ipc://ipc-req-rep-null-sec" + cleanUp(ep) + + req := NewReq(ctx, WithSecurity(sec), WithLogger(Devnull)) + defer req.Close() + + rep := NewRep(ctx, WithSecurity(sec), WithLogger(Devnull)) + defer rep.Close() + + grp, _ := errgroup.WithContext(ctx) + grp.Go(func() error { + err := req.DialContext(ctx, ep) + if err != nil { + return fmt.Errorf("could not dial: %w", err) + } + return nil + }) + + // make Dial above fail + time.Sleep(1050 * time.Millisecond) + + if err := grp.Wait(); err == nil { + t.Error("error: timeout not detected") + } +} + func cleanUp(ep string) { if strings.HasPrefix(ep, "ipc://") { os.Remove(ep[len("ipc://"):]) diff --git a/socket.go b/socket.go index 4814efc..e3bf6e2 100644 --- a/socket.go +++ b/socket.go @@ -18,14 +18,16 @@ import ( ) const ( - defaultRetry = 250 * time.Millisecond - defaultTimeout = 5 * time.Minute + defaultRetry = 250 * time.Millisecond + // defaultTimeout = 5 * time.Minute + defaultTimeout = 15 * time.Second ) var ( errInvalidAddress = errors.New("zmq4: invalid address") - ErrBadProperty = errors.New("zmq4: bad property") + ErrBadProperty = errors.New("zmq4: bad property") + ErrUnknownTransport = errors.New("zmq4: unknown transport") ) // socket implements the ZeroMQ socket interface @@ -88,6 +90,7 @@ func newSocket(ctx context.Context, sockType SocketType, opts ...Option) *socket sck.log = log.New(os.Stderr, "zmq4: ", 0) } + go sck.connReaper() return sck } @@ -122,8 +125,9 @@ func (sck *socket) Close() error { defer sck.listener.Close() } - sck.mu.RLock() - defer sck.mu.RUnlock() + // state change, write lock! + sck.mu.Lock() + defer sck.mu.Unlock() var err error for _, conn := range sck.conns { @@ -179,25 +183,22 @@ func (sck *socket) Listen(endpoint string) error { var l net.Listener trans, ok := drivers.get(network) - switch { - case ok: - l, err = trans.Listen(sck.ctx, addr) - default: - panic("zmq4: unknown protocol " + network) + if !ok { + return ErrUnknownTransport } + l, err = trans.Listen(sck.ctx, addr) if err != nil { return fmt.Errorf("zmq4: could not listen to %q: %w", endpoint, err) } sck.listener = l - go sck.accept() - go sck.connReaper() + go sck.accept(endpoint) return nil } -func (sck *socket) accept() { +func (sck *socket) accept(ep string) { ctx, cancel := context.WithCancel(sck.ctx) defer cancel() for { @@ -212,10 +213,10 @@ func (sck *socket) accept() { continue } - zconn, err := Open(conn, sck.sec, sck.typ, sck.id, true, sck.scheduleRmConn) + zconn, err := Open(ep, conn, sck.sec, sck.typ, sck.id, true, sck.scheduleRmConn) if err != nil { // FIXME(sbinet): maybe bubble up this error to application code? - sck.log.Printf("could not open a ZMTP connection with %q: %+v", sck.ep, err) + sck.log.Printf("could not open a ZMTP connection with %q: %+v", ep, err) continue } @@ -224,8 +225,16 @@ func (sck *socket) accept() { } } -// Dial connects a remote endpoint to the Socket. +// Dial connects a remote endpoint to the socket using default timeout. func (sck *socket) Dial(endpoint string) error { + ctx, cancel := context.WithTimeout(sck.ctx, sck.timeout()) + defer cancel() + return sck.DialContext(ctx, endpoint) +} + +// DialContext connects a remote endpoint to the Socket. +// Uses the contexts timeout. +func (sck *socket) DialContext(ctx context.Context, endpoint string) error { sck.ep = endpoint network, addr, err := splitAddr(endpoint) @@ -236,23 +245,24 @@ func (sck *socket) Dial(endpoint string) error { var ( conn net.Conn trans, ok = drivers.get(network) - retries = 0 ) + if !ok { + return ErrUnknownTransport + } + connect: - switch { - case ok: - conn, err = trans.Dial(sck.ctx, &sck.dialer, addr) + select { + case <-ctx.Done(): + return ctx.Err() default: - panic("zmq4: unknown protocol " + network) + // fall through } - + conn, err = trans.Dial(ctx, &sck.dialer, addr) if err != nil { - if retries < 10 { - retries++ - time.Sleep(sck.retry) - goto connect - } - return fmt.Errorf("zmq4: could not dial to %q (retry=%v): %w", endpoint, sck.retry, err) + td := sck.retry + // do the wait + time.Sleep(td) + goto connect } if conn == nil { @@ -267,13 +277,14 @@ connect: return fmt.Errorf("zmq4: got a nil ZMTP connection to %q", endpoint) } - go sck.connReaper() sck.addConn(zconn) return nil } func (sck *socket) addConn(c *Conn) { sck.mu.Lock() + defer sck.mu.Unlock() + sck.conns = append(sck.conns, c) uuid, ok := c.Peer.Meta[sysSockID] if !ok { @@ -287,7 +298,6 @@ func (sck *socket) addConn(c *Conn) { if sck.w != nil { sck.w.addConn(c) } - sck.mu.Unlock() } func (sck *socket) rmConn(c *Conn) { diff --git a/sub.go b/sub.go index c530226..d4b4de6 100644 --- a/sub.go +++ b/sub.go @@ -58,11 +58,12 @@ func (sub *subSocket) Listen(ep string) error { // Dial connects a remote endpoint to the Socket. func (sub *subSocket) Dial(ep string) error { - err := sub.sck.Dial(ep) - if err != nil { - return err - } - return nil + return sub.sck.Dial(ep) +} + +// DialContext connects a remote endpoint to the Socket. +func (sub *subSocket) DialContext(ctx context.Context, ep string) error { + return sub.sck.DialContext(ctx, ep) } // Type returns the type of this Socket (PUB, SUB, ...) diff --git a/xpub.go b/xpub.go index 8bea95e..552165d 100644 --- a/xpub.go +++ b/xpub.go @@ -56,6 +56,11 @@ func (xpub *xpubSocket) Dial(ep string) error { return xpub.sck.Dial(ep) } +// DialContext connects a remote endpoint to the Socket. +func (xpub *xpubSocket) DialContext(ctx context.Context, ep string) error { + return xpub.sck.DialContext(ctx, ep) +} + // Type returns the type of this Socket (PUB, SUB, ...) func (xpub *xpubSocket) Type() SocketType { return xpub.sck.Type() diff --git a/xsub.go b/xsub.go index 7eb06cd..9f2e6aa 100644 --- a/xsub.go +++ b/xsub.go @@ -54,6 +54,11 @@ func (xsub *xsubSocket) Dial(ep string) error { return xsub.sck.Dial(ep) } +// DialContext connects a remote endpoint to the Socket. +func (xsub *xsubSocket) DialContext(ctx context.Context, ep string) error { + return xsub.sck.DialContext(ctx, ep) +} + // Type returns the type of this Socket (PUB, SUB, ...) func (xsub *xsubSocket) Type() SocketType { return xsub.sck.Type() diff --git a/zmq4.go b/zmq4.go index eb23837..10f9a95 100644 --- a/zmq4.go +++ b/zmq4.go @@ -7,7 +7,10 @@ // For more informations, see http://zeromq.org. package zmq4 -import "net" +import ( + "context" + "net" +) // Socket represents a ZeroMQ socket. type Socket interface { @@ -32,6 +35,9 @@ type Socket interface { // Dial connects a remote endpoint to the Socket. Dial(ep string) error + // Dial connects a remote endpoint to the Socket. + DialContext(ctx context.Context, ep string) error + // Type returns the type of this Socket (PUB, SUB, ...) Type() SocketType