diff --git a/transport/memory.go b/transport/memory.go index 09f1f9d97e..4cfe47a364 100644 --- a/transport/memory.go +++ b/transport/memory.go @@ -2,8 +2,10 @@ package transport import ( "context" + "encoding/gob" "errors" "fmt" + "io" "math/rand" "net" "sync" @@ -14,8 +16,16 @@ import ( ) type memorySocket struct { - recv chan *Message - send chan *Message + // True server mode, False client mode + server bool + // Client receiver of io.Pipe with gob + crecv *gob.Decoder + // Client sender of the io.Pipe with gob + csend *gob.Encoder + // Server receiver of the io.Pip with gob + srecv *gob.Decoder + // Server sender of the io.Pip with gob + ssend *gob.Encoder // sock exit exit chan bool // listener exit @@ -27,7 +37,6 @@ type memorySocket struct { // for send/recv Timeout timeout time.Duration ctx context.Context - sync.RWMutex } type memoryClient struct { @@ -52,9 +61,6 @@ type memoryTransport struct { } func (ms *memorySocket) Recv(m *Message) error { - ms.RLock() - defer ms.RUnlock() - ctx := ms.ctx if ms.timeout > 0 { var cancel context.CancelFunc @@ -66,12 +72,23 @@ func (ms *memorySocket) Recv(m *Message) error { case <-ctx.Done(): return ctx.Err() case <-ms.exit: - return errors.New("connection closed") + // connection closed + return io.EOF case <-ms.lexit: - return errors.New("server connection closed") - case cm := <-ms.recv: - *m = *cm + // Server connection closed + return io.EOF + default: + if ms.server { + if err := ms.srecv.Decode(m); err != nil { + return err + } + } else { + if err := ms.crecv.Decode(m); err != nil { + return err + } + } } + return nil } @@ -84,9 +101,6 @@ func (ms *memorySocket) Remote() string { } func (ms *memorySocket) Send(m *Message) error { - ms.RLock() - defer ms.RUnlock() - ctx := ms.ctx if ms.timeout > 0 { var cancel context.CancelFunc @@ -98,17 +112,27 @@ func (ms *memorySocket) Send(m *Message) error { case <-ctx.Done(): return ctx.Err() case <-ms.exit: - return errors.New("connection closed") + // connection closed + return io.EOF case <-ms.lexit: - return errors.New("server connection closed") - case ms.send <- m: + // Server connection closed + return io.EOF + default: + if ms.server { + if err := ms.ssend.Encode(m); err != nil { + return err + } + } else { + if err := ms.csend.Encode(m); err != nil { + return err + } + } } + return nil } func (ms *memorySocket) Close() error { - ms.Lock() - defer ms.Unlock() select { case <-ms.exit: return nil @@ -141,10 +165,11 @@ func (m *memoryListener) Accept(fn func(Socket)) error { return nil case c := <-m.conn: go fn(&memorySocket{ + server: true, lexit: c.lexit, exit: c.exit, - send: c.recv, - recv: c.send, + ssend: c.ssend, + srecv: c.srecv, local: c.Remote(), remote: c.Local(), timeout: m.topts.Timeout, @@ -168,11 +193,16 @@ func (m *memoryTransport) Dial(addr string, opts ...DialOption) (Client, error) o(&options) } + creader, swriter := io.Pipe() + sreader, cwriter := io.Pipe() + client := &memoryClient{ &memorySocket{ - send: make(chan *Message), - recv: make(chan *Message), - exit: make(chan bool), + server: false, + csend: gob.NewEncoder(cwriter), + crecv: gob.NewDecoder(creader), + ssend: gob.NewEncoder(swriter), + srecv: gob.NewDecoder(sreader), exit: make(chan bool), lexit: listener.exit, local: addr, remote: addr,