Skip to content

Commit

Permalink
Merge ae3268d into b8f1996
Browse files Browse the repository at this point in the history
  • Loading branch information
Stebalien committed Jan 15, 2018
2 parents b8f1996 + ae3268d commit b09babc
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 74 deletions.
104 changes: 42 additions & 62 deletions lazy.go → lazyClient.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type Multistream interface {
// NewMSSelect returns a new Multistream which is able to perform
// protocol selection with a MultistreamMuxer.
func NewMSSelect(c io.ReadWriteCloser, proto string) Multistream {
return &lazyConn{
return &lazyClientConn{
protos: []string{ProtocolID, proto},
con: c,
}
Expand All @@ -27,111 +27,91 @@ func NewMSSelect(c io.ReadWriteCloser, proto string) Multistream {
// perform any protocol selection. If you are using a MultistreamMuxer, use
// NewMSSelect.
func NewMultistream(c io.ReadWriteCloser, proto string) Multistream {
return &lazyConn{
return &lazyClientConn{
protos: []string{proto},
con: c,
}
}

type lazyConn struct {
rhandshake bool // only accessed by 'Read' should not call read async
type lazyClientConn struct {
rhandshakeOnce sync.Once
rerr error

rhlock sync.Mutex
rhsync bool //protected by mutex
rerr error

whandshake bool

whlock sync.Mutex
whsync bool
werr error
whandshakeOnce sync.Once
werr error

protos []string
con io.ReadWriteCloser
}

func (l *lazyConn) Read(b []byte) (int, error) {
if !l.rhandshake {
go l.writeHandshake()
err := l.readHandshake()
if err != nil {
return 0, err
}

l.rhandshake = true
func (l *lazyClientConn) Read(b []byte) (int, error) {
l.rhandshakeOnce.Do(func() {
go l.whandshakeOnce.Do(l.doWriteHandshake)
l.doReadHandshake()
})
if l.rerr != nil {
return 0, l.rerr
}

if len(b) == 0 {
return 0, nil
}

return l.con.Read(b)
}

func (l *lazyConn) readHandshake() error {
l.rhlock.Lock()
defer l.rhlock.Unlock()

// if we've already done this, exit
if l.rhsync {
return l.rerr
}
l.rhsync = true

func (l *lazyClientConn) doReadHandshake() {
for _, proto := range l.protos {
// read protocol
tok, err := ReadNextToken(l.con)
if err != nil {
l.rerr = err
return err
return
}

if tok != proto {
l.rerr = fmt.Errorf("protocol mismatch in lazy handshake ( %s != %s )", tok, proto)
return l.rerr
return
}
}

return nil
}

func (l *lazyConn) writeHandshake() error {
l.whlock.Lock()
defer l.whlock.Unlock()

if l.whsync {
return l.werr
}

l.whsync = true
func (l *lazyClientConn) doWriteHandshake() {
l.doWriteHandshakeWithExtra(nil)
}

// Perform the write handshake but *also* write some extra data.
func (l *lazyClientConn) doWriteHandshakeWithExtra(extra []byte) int {
buf := bufio.NewWriter(l.con)
for _, proto := range l.protos {
err := delimWrite(buf, []byte(proto))
if err != nil {
l.werr = err
return err
l.werr = delimWrite(buf, []byte(proto))
if l.werr != nil {
return 0
}
}

n := 0
if len(extra) > 0 {
n, l.werr = buf.Write(extra)
if l.werr != nil {
return n
}
}
l.werr = buf.Flush()
return l.werr
return n
}

func (l *lazyConn) Write(b []byte) (int, error) {
if !l.whandshake {
go l.readHandshake()
err := l.writeHandshake()
if err != nil {
return 0, err
}

l.whandshake = true
func (l *lazyClientConn) Write(b []byte) (int, error) {
n := 0
l.whandshakeOnce.Do(func() {
go l.rhandshakeOnce.Do(l.doReadHandshake)
n = l.doWriteHandshakeWithExtra(b)
})
if l.werr != nil || n > 0 {
return n, l.werr
}

return l.con.Write(b)
}

func (l *lazyConn) Close() error {
func (l *lazyClientConn) Close() error {
return l.con.Close()
}
33 changes: 33 additions & 0 deletions lazyServer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package multistream

import (
"io"
"sync"
)

type lazyServerConn struct {
waitForHandshake sync.Once
werr error

con io.ReadWriteCloser
}

func (l *lazyServerConn) Write(b []byte) (int, error) {
l.waitForHandshake.Do(func() { panic("didn't initiate handshake") })
if l.werr != nil {
return 0, l.werr
}
return l.con.Write(b)
}

func (l *lazyServerConn) Read(b []byte) (int, error) {
// TODO: The tests require this for some reason. Not sure if it's correct...
if len(b) == 0 {
return 0, nil
}
return l.con.Read(b)
}

func (l *lazyServerConn) Close() error {
return l.con.Close()
}
22 changes: 10 additions & 12 deletions multistream.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,20 +186,15 @@ func (msm *MultistreamMuxer) NegotiateLazy(rwc io.ReadWriteCloser) (Multistream,
writeErr := make(chan error, 1)
defer close(pval)

lzc := &lazyConn{
con: rwc,
rhandshake: true,
rhsync: true,
lzc := &lazyServerConn{
con: rwc,
}

// take lock here to prevent a race condition where the reads below from
// finishing and taking the write lock before this goroutine can
lzc.whlock.Lock()
started := make(chan struct{})
go lzc.waitForHandshake.Do(func() {
close(started)

go func() {
defer close(writeErr)
defer lzc.whlock.Unlock()
lzc.whsync = true

if err := delimWriteBuffered(rwc, []byte(ProtocolID)); err != nil {
lzc.werr = err
Expand All @@ -214,8 +209,8 @@ func (msm *MultistreamMuxer) NegotiateLazy(rwc io.ReadWriteCloser) (Multistream,
return
}
}
lzc.whandshake = true
}()
})
<-started

line, err := ReadNextToken(rwc)
if err != nil {
Expand All @@ -232,6 +227,7 @@ loop:
// Now read and respond to commands until they send a valid protocol id
tok, err := ReadNextToken(rwc)
if err != nil {
rwc.Close()
return nil, "", nil, err
}

Expand All @@ -240,6 +236,7 @@ loop:
select {
case pval <- "ls":
case err := <-writeErr:
rwc.Close()
return nil, "", nil, err
}
default:
Expand All @@ -248,6 +245,7 @@ loop:
select {
case pval <- "na":
case err := <-writeErr:
rwc.Close()
return nil, "", nil, err
}
continue loop
Expand Down
1 change: 1 addition & 0 deletions multistream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ func TestNegLazyStressWrite(t *testing.T) {
t.Error(err)
return
}

}
}()

Expand Down

0 comments on commit b09babc

Please sign in to comment.