Skip to content

Commit

Permalink
made websocket package more testable
Browse files Browse the repository at this point in the history
  • Loading branch information
kelindar committed Nov 16, 2017
1 parent 6b57bb7 commit f62c673
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 13 deletions.
33 changes: 20 additions & 13 deletions network/websocket/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,20 @@ import (
"github.com/gorilla/websocket"
)

type websocketConn interface {
NextReader() (messageType int, r io.Reader, err error)
NextWriter(messageType int) (io.WriteCloser, error)
Close() error
LocalAddr() net.Addr
RemoteAddr() net.Addr
SetReadDeadline(t time.Time) error
SetWriteDeadline(t time.Time) error
}

// websocketConn represents a websocket connection.
type websocketTransport struct {
sync.Mutex
socket *websocket.Conn
socket websocketConn
reader io.Reader
closing chan bool
}
Expand Down Expand Up @@ -59,7 +69,7 @@ func TryUpgrade(w http.ResponseWriter, r *http.Request) (net.Conn, bool) {
}

// newConn creates a new transport from websocket.
func newConn(ws *websocket.Conn) net.Conn {
func newConn(ws websocketConn) net.Conn {
conn := &websocketTransport{
socket: ws,
closing: make(chan bool),
Expand Down Expand Up @@ -125,13 +135,11 @@ func (c *websocketTransport) Write(b []byte) (n int, err error) {
defer c.Unlock()

var w io.WriteCloser
if w, err = c.socket.NextWriter(websocket.BinaryMessage); err != nil {
return
}
if n, err = w.Write(b); err != nil {
return
if w, err = c.socket.NextWriter(websocket.BinaryMessage); err == nil {
if n, err = w.Write(b); err == nil {
err = w.Close()
}
}
err = w.Close()
return
}

Expand All @@ -153,12 +161,11 @@ func (c *websocketTransport) RemoteAddr() net.Addr {
// SetDeadline sets the read and write deadlines associated
// with the connection. It is equivalent to calling both
// SetReadDeadline and SetWriteDeadline.
func (c *websocketTransport) SetDeadline(t time.Time) error {
if err := c.socket.SetReadDeadline(t); err != nil {
return err
func (c *websocketTransport) SetDeadline(t time.Time) (err error) {
if err = c.socket.SetReadDeadline(t); err == nil {
err = c.socket.SetWriteDeadline(t)
}

return c.socket.SetWriteDeadline(t)
return
}

// SetReadDeadline sets the deadline for future Read calls
Expand Down
99 changes: 99 additions & 0 deletions network/websocket/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,49 @@ package websocket

import (
"bytes"
"io"
"net"
"net/http/httptest"
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
)

type writer bytes.Buffer

func (w *writer) Close() error { return nil }
func (w *writer) Write(data []byte) (n int, err error) { return ((*bytes.Buffer)(w)).Write(data) }

type conn struct {
read []byte
write *writer
}

func (c *conn) NextReader() (messageType int, r io.Reader, err error) {
messageType = websocket.BinaryMessage
r = bytes.NewBuffer(c.read)
if c.read == nil {
err = io.EOF
}
return
}

func (c *conn) NextWriter(messageType int) (w io.WriteCloser, err error) {
w = c.write
if c.write == nil {
err = io.EOF
}

return
}
func (c *conn) Close() error { return nil }
func (c *conn) LocalAddr() net.Addr { return &net.IPAddr{} }
func (c *conn) RemoteAddr() net.Addr { return &net.IPAddr{} }
func (c *conn) SetReadDeadline(t time.Time) error { return nil }
func (c *conn) SetWriteDeadline(t time.Time) error { return nil }

func TestTryUpgradeNil(t *testing.T) {
_, ok := TryUpgrade(nil, nil)
assert.Equal(t, false, ok)
Expand All @@ -34,3 +71,65 @@ func TestTryUpgrade(t *testing.T) {
//assert.NotNil(t, ws)
//assert.True(t, ok)
}

func TestRead_EOF(t *testing.T) {
c := newConn(new(conn))

_, err := c.Read([]byte{})
assert.Error(t, io.EOF, err)
}

func TestRead(t *testing.T) {
message := []byte("hello world")
c := &websocketTransport{
socket: &conn{
read: message,
},
closing: make(chan bool),
}

buffer := make([]byte, 64)
n, err := c.Read(buffer)
assert.NoError(t, err)
assert.Equal(t, message, buffer[:n])
}

func TestWrite(t *testing.T) {
message := []byte("hello world")
buffer := new(bytes.Buffer)
c := &websocketTransport{
socket: &conn{
write: (*writer)(buffer),
},
closing: make(chan bool),
}

_, err := c.Write(message)
assert.NoError(t, err)
assert.Equal(t, message, buffer.Bytes())
}

func TestMisc(t *testing.T) {
c := &websocketTransport{
socket: &conn{},
closing: make(chan bool),
}

err := c.Close()
assert.NoError(t, err)

err = c.SetDeadline(time.Now())
assert.NoError(t, err)

err = c.SetReadDeadline(time.Now())
assert.NoError(t, err)

err = c.SetWriteDeadline(time.Now())
assert.NoError(t, err)

addr1 := c.LocalAddr()
assert.Equal(t, "", addr1.String())

addr2 := c.RemoteAddr()
assert.Equal(t, "", addr2.String())
}

0 comments on commit f62c673

Please sign in to comment.