Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#85 Added a DialContext method to all sockets. #98

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 11 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ var ErrClosedConn = errors.New("zmq4: read/write on closed connection")
// Conn implements the ZeroMQ Message Transport Protocol as defined
// in https://rfc.zeromq.org/spec:23/ZMTP/.
type Conn struct {
ep string
typ SocketType
id SocketIdentity
rw net.Conn
Expand Down Expand Up @@ -64,7 +65,15 @@ func (c *Conn) Write(p []byte) (int, error) {
// Open opens a ZMTP connection over rw with the given security, socket type and identity.
// An optional onCloseErrorCB can be provided to inform the caller when this Conn is closed.
// Open performs a complete ZMTP handshake.
func Open(rw net.Conn, sec Security, sockType SocketType, sockID SocketIdentity, server bool, onCloseErrorCB func(c *Conn)) (*Conn, error) {
func Open(
ep string,
rw net.Conn,
sec Security,
sockType SocketType,
sockID SocketIdentity,
server bool,
onCloseErrorCB func(c *Conn),
) (*Conn, error) {
if rw == nil {
return nil, fmt.Errorf("zmq4: invalid nil read-writer")
}
Expand All @@ -74,6 +83,7 @@ func Open(rw net.Conn, sec Security, sockType SocketType, sockID SocketIdentity,
}

conn := &Conn{
ep: ep,
typ: sockType,
id: sockID,
rw: rw,
Expand Down
5 changes: 5 additions & 0 deletions dealer.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ func (dealer *dealerSocket) Dial(ep string) error {
return dealer.sck.Dial(ep)
}

// Dial connects a remote endpoint to the Socket.
func (dealer *dealerSocket) DialContext(ctx context.Context, ep string) error {
return dealer.sck.DialContext(ctx, ep)
}

// Type returns the type of this Socket (PUB, SUB, ...)
func (dealer *dealerSocket) Type() SocketType {
return dealer.sck.Type()
Expand Down
4 changes: 2 additions & 2 deletions msgio.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ func newQReader(ctx context.Context) *qreader {
}

func (q *qreader) Close() error {
q.mu.RLock()
q.mu.Lock()
var err error
var grp errgroup.Group
for i := range q.rs {
grp.Go(q.rs[i].Close)
}
err = grp.Wait()
q.rs = nil
q.mu.RUnlock()
q.mu.Unlock()
return err
}

Expand Down
5 changes: 5 additions & 0 deletions pair.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ func (pair *pairSocket) Dial(ep string) error {
return pair.sck.Dial(ep)
}

// DialContext connects a remote endpoint to the Socket.
func (pair *pairSocket) DialContext(ctx context.Context, ep string) error {
return pair.sck.DialContext(ctx, ep)
}

// Type returns the type of this Socket (PUB, SUB, ...)
func (pair *pairSocket) Type() SocketType {
return pair.sck.Type()
Expand Down
5 changes: 5 additions & 0 deletions pub.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ func (pub *pubSocket) Dial(ep string) error {
return pub.sck.Dial(ep)
}

// DialContext connects a remote endpoint to the Socket.
func (pub *pubSocket) DialContext(ctx context.Context, ep string) error {
return pub.sck.DialContext(ctx, ep)
}

// Type returns the type of this Socket (PUB, SUB, ...)
func (pub *pubSocket) Type() SocketType {
return pub.sck.Type()
Expand Down
5 changes: 5 additions & 0 deletions pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ func (pull *pullSocket) Dial(ep string) error {
return pull.sck.Dial(ep)
}

// DialContext connects a remote endpoint to the Socket.
func (pull *pullSocket) DialContext(ctx context.Context, ep string) error {
return pull.sck.DialContext(ctx, ep)
}

// Type returns the type of this Socket (PUB, SUB, ...)
func (pull *pullSocket) Type() SocketType {
return pull.sck.Type()
Expand Down
5 changes: 5 additions & 0 deletions push.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ func (push *pushSocket) Dial(ep string) error {
return push.sck.Dial(ep)
}

// DialContext connects a remote endpoint to the Socket.
func (push *pushSocket) DialContext(ctx context.Context, ep string) error {
return push.sck.DialContext(ctx, ep)
}

// Type returns the type of this Socket (PUB, SUB, ...)
func (push *pushSocket) Type() SocketType {
return push.sck.Type()
Expand Down
5 changes: 5 additions & 0 deletions rep.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ func (rep *repSocket) Dial(ep string) error {
return rep.sck.Dial(ep)
}

// DialContext connects a remote endpoint to the Socket.
func (rep *repSocket) DialContext(ctx context.Context, ep string) error {
return rep.sck.DialContext(ctx, ep)
}

// Type returns the type of this Socket (PUB, SUB, ...)
func (rep *repSocket) Type() SocketType {
return rep.sck.Type()
Expand Down
5 changes: 5 additions & 0 deletions req.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ func (req *reqSocket) Dial(ep string) error {
return req.sck.Dial(ep)
}

// DialContext connects a remote endpoint to the Socket.
func (req *reqSocket) DialContext(ctx context.Context, ep string) error {
return req.sck.DialContext(ctx, ep)
}

// Type returns the type of this Socket (PUB, SUB, ...)
func (req *reqSocket) Type() SocketType {
return req.sck.Type()
Expand Down
5 changes: 5 additions & 0 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ func (router *routerSocket) Dial(ep string) error {
return router.sck.Dial(ep)
}

// DialContext connects a remote endpoint to the Socket.
func (router *routerSocket) DialContext(ctx context.Context, ep string) error {
return router.sck.DialContext(ctx, ep)
}

// Type returns the type of this Socket (PUB, SUB, ...)
func (router *routerSocket) Type() SocketType {
return router.sck.Type()
Expand Down
35 changes: 34 additions & 1 deletion security_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ func TestNullHandshakeReqRep(t *testing.T) {
})

grp.Go(func() error {
err := req.Dial(ep)
// err := req.Dial(ep)
err := req.DialContext(ctx, ep)
if err != nil {
return fmt.Errorf("could not dial: %w", err)
}
Expand All @@ -104,6 +105,38 @@ func TestNullHandshakeReqRep(t *testing.T) {
}
}

func TestNullHandshakeRRFail(t *testing.T) {

sec := nullSecurity{}
ctx, timeout := context.WithTimeout(context.Background(), 1*time.Second)
defer timeout()

ep := "ipc://ipc-req-rep-null-sec"
cleanUp(ep)

req := NewReq(ctx, WithSecurity(sec), WithLogger(Devnull))
defer req.Close()

rep := NewRep(ctx, WithSecurity(sec), WithLogger(Devnull))
defer rep.Close()

grp, _ := errgroup.WithContext(ctx)
grp.Go(func() error {
err := req.DialContext(ctx, ep)
if err != nil {
return fmt.Errorf("could not dial: %w", err)
}
return nil
})

// make Dial above fail
time.Sleep(1050 * time.Millisecond)

if err := grp.Wait(); err == nil {
t.Error("error: timeout not detected")
}
}

func cleanUp(ep string) {
if strings.HasPrefix(ep, "ipc://") {
os.Remove(ep[len("ipc://"):])
Expand Down
70 changes: 40 additions & 30 deletions socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@ import (
)

const (
defaultRetry = 250 * time.Millisecond
defaultTimeout = 5 * time.Minute
defaultRetry = 250 * time.Millisecond
// defaultTimeout = 5 * time.Minute
defaultTimeout = 15 * time.Second
)

var (
errInvalidAddress = errors.New("zmq4: invalid address")

ErrBadProperty = errors.New("zmq4: bad property")
ErrBadProperty = errors.New("zmq4: bad property")
ErrUnknownTransport = errors.New("zmq4: unknown transport")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of a value, I'd rather have it as a type, TransportError or something, so we could provide the name of the unknown transport.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make it a feature request.
Not related to this PR at all.

)

// socket implements the ZeroMQ socket interface
Expand Down Expand Up @@ -88,6 +90,7 @@ func newSocket(ctx context.Context, sockType SocketType, opts ...Option) *socket
sck.log = log.New(os.Stderr, "zmq4: ", 0)
}

go sck.connReaper()
return sck
}

Expand Down Expand Up @@ -122,8 +125,9 @@ func (sck *socket) Close() error {
defer sck.listener.Close()
}

sck.mu.RLock()
defer sck.mu.RUnlock()
// state change, write lock!
sck.mu.Lock()
defer sck.mu.Unlock()

var err error
for _, conn := range sck.conns {
Expand Down Expand Up @@ -179,25 +183,22 @@ func (sck *socket) Listen(endpoint string) error {
var l net.Listener

trans, ok := drivers.get(network)
switch {
case ok:
l, err = trans.Listen(sck.ctx, addr)
default:
panic("zmq4: unknown protocol " + network)
if !ok {
return ErrUnknownTransport
}

l, err = trans.Listen(sck.ctx, addr)
if err != nil {
return fmt.Errorf("zmq4: could not listen to %q: %w", endpoint, err)
}
sck.listener = l

go sck.accept()
go sck.connReaper()
go sck.accept(endpoint)

return nil
}

func (sck *socket) accept() {
func (sck *socket) accept(ep string) {
ctx, cancel := context.WithCancel(sck.ctx)
defer cancel()
for {
Expand All @@ -212,10 +213,10 @@ func (sck *socket) accept() {
continue
}

zconn, err := Open(conn, sck.sec, sck.typ, sck.id, true, sck.scheduleRmConn)
zconn, err := Open(ep, conn, sck.sec, sck.typ, sck.id, true, sck.scheduleRmConn)
if err != nil {
// FIXME(sbinet): maybe bubble up this error to application code?
sck.log.Printf("could not open a ZMTP connection with %q: %+v", sck.ep, err)
sck.log.Printf("could not open a ZMTP connection with %q: %+v", ep, err)
continue
}

Expand All @@ -224,8 +225,16 @@ func (sck *socket) accept() {
}
}

// Dial connects a remote endpoint to the Socket.
// Dial connects a remote endpoint to the socket using default timeout.
func (sck *socket) Dial(endpoint string) error {
ctx, cancel := context.WithTimeout(sck.ctx, sck.timeout())
defer cancel()
return sck.DialContext(ctx, endpoint)
}

// DialContext connects a remote endpoint to the Socket.
// Uses the contexts timeout.
func (sck *socket) DialContext(ctx context.Context, endpoint string) error {
sck.ep = endpoint

network, addr, err := splitAddr(endpoint)
Expand All @@ -236,23 +245,24 @@ func (sck *socket) Dial(endpoint string) error {
var (
conn net.Conn
trans, ok = drivers.get(network)
retries = 0
)
if !ok {
return ErrUnknownTransport
}

connect:
switch {
case ok:
conn, err = trans.Dial(sck.ctx, &sck.dialer, addr)
select {
case <-ctx.Done():
return ctx.Err()
default:
panic("zmq4: unknown protocol " + network)
// fall through
}

conn, err = trans.Dial(ctx, &sck.dialer, addr)
if err != nil {
if retries < 10 {
retries++
time.Sleep(sck.retry)
goto connect
}
return fmt.Errorf("zmq4: could not dial to %q (retry=%v): %w", endpoint, sck.retry, err)
td := sck.retry
// do the wait
time.Sleep(td)
goto connect
}

if conn == nil {
Expand All @@ -267,13 +277,14 @@ connect:
return fmt.Errorf("zmq4: got a nil ZMTP connection to %q", endpoint)
}

go sck.connReaper()
sck.addConn(zconn)
return nil
}

func (sck *socket) addConn(c *Conn) {
sck.mu.Lock()
defer sck.mu.Unlock()

sck.conns = append(sck.conns, c)
uuid, ok := c.Peer.Meta[sysSockID]
if !ok {
Expand All @@ -287,7 +298,6 @@ func (sck *socket) addConn(c *Conn) {
if sck.w != nil {
sck.w.addConn(c)
}
sck.mu.Unlock()
}

func (sck *socket) rmConn(c *Conn) {
Expand Down
11 changes: 6 additions & 5 deletions sub.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,12 @@ func (sub *subSocket) Listen(ep string) error {

// Dial connects a remote endpoint to the Socket.
func (sub *subSocket) Dial(ep string) error {
err := sub.sck.Dial(ep)
if err != nil {
return err
}
return nil
return sub.sck.Dial(ep)
}

// DialContext connects a remote endpoint to the Socket.
func (sub *subSocket) DialContext(ctx context.Context, ep string) error {
return sub.sck.DialContext(ctx, ep)
}

// Type returns the type of this Socket (PUB, SUB, ...)
Expand Down
5 changes: 5 additions & 0 deletions xpub.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ func (xpub *xpubSocket) Dial(ep string) error {
return xpub.sck.Dial(ep)
}

// DialContext connects a remote endpoint to the Socket.
func (xpub *xpubSocket) DialContext(ctx context.Context, ep string) error {
return xpub.sck.DialContext(ctx, ep)
}

// Type returns the type of this Socket (PUB, SUB, ...)
func (xpub *xpubSocket) Type() SocketType {
return xpub.sck.Type()
Expand Down