Skip to content

Commit

Permalink
remove need for manually mocking websocket server
Browse files Browse the repository at this point in the history
  • Loading branch information
iansuvak committed Apr 6, 2023
1 parent 85e31ea commit 660d682
Showing 1 changed file with 13 additions and 107 deletions.
120 changes: 13 additions & 107 deletions network/wsNetwork_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"math/rand"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"runtime"
Expand All @@ -42,7 +41,6 @@ import (
"github.com/stretchr/testify/require"

"github.com/algorand/go-deadlock"
"github.com/algorand/websocket"

"github.com/algorand/go-algorand/config"
"github.com/algorand/go-algorand/crypto"
Expand Down Expand Up @@ -3762,138 +3760,46 @@ func TestWebsocketNetworkTelemetryTCP(t *testing.T) {
t.Log("closed detailsB", string(pcdB))
}

type mockServer struct {
*httptest.Server
URL string
t *testing.T

waitForClientClose bool
sendClose bool
sendCloseWC bool

gotClientClose chan struct{}
}

type mockHandler struct {
*testing.T
s *mockServer
}

var mockUpgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
EnableCompression: true,
Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {
http.Error(w, reason.Error(), status)
},
}

func (t mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Set the required headers to successfully establish a connection
responseHeader := http.Header{}
responseHeader.Add(ProtocolVersionHeader, ProtocolVersion)
responseHeader.Add(GenesisHeader, genesisID)
responseHeader.Add(NodeRandomHeader, "randomHeader")
ws, err := mockUpgrader.Upgrade(w, r, responseHeader)
if err != nil {
t.Logf("Upgrade: %v", err)
return
}
defer ws.Close()

for true {
// echo a message back to the client
op, rd, err := ws.NextReader()
if err != nil {
if _, ok := err.(*websocket.CloseError); ok && t.s.waitForClientClose {
t.Log("got client close")
close(t.s.gotClientClose)
return
}
t.Logf("NextReader: %v", err)
return
}
wr, err := ws.NextWriter(op)
if err != nil {
t.Logf("NextWriter: %v", err)
return
}
if _, err = io.Copy(wr, rd); err != nil {
t.Logf("NextWriter: %v", err)
return
}
if err := wr.Close(); err != nil {
t.Logf("Close: %v", err)
return
}
t.Log("sent message")
if !t.s.waitForClientClose {
break
}
}
if t.s.sendClose {
err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
t.Logf("WriteMessage(CloseMessage): %v", err)
return
}
t.Log("sent close")
} else if t.s.sendCloseWC {
err = ws.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(5*time.Second))
if err != nil {
t.Logf("WriteControl(CloseMessage): %v", err)
return
}
t.Log("sent close")
}
}

func makeWsProto(s string) string {
return "ws" + strings.TrimPrefix(s, "http")
}

func newServer(t *testing.T) *mockServer {
var s mockServer
s.Server = httptest.NewServer(mockHandler{t, &s})
s.Server.URL += ""
s.URL = makeWsProto(s.Server.URL)
return &s
}

func TestMaxHeaderSize(t *testing.T) {
partitiontest.PartitionTest(t)

netA := makeTestWebsocketNode(t, testWebsocketLogNameOption{"netA"})
netA.config.GossipFanout = 1

s := newServer(t)
s.waitForClientClose = true
defer s.Close()
netB := makeTestWebsocketNode(t, testWebsocketLogNameOption{"netB"})
netB.config.GossipFanout = 1

netA.Start()
defer netA.Stop()
netB.Start()
defer netB.Stop()

addrB, ok := netB.Address()
require.True(t, ok)
gossipB, err := netB.addrToGossipAddr(addrB)
require.NoError(t, err)

// First make sure that the regular connection with default max header size works
netA.wsMaxHeaderBytes = wsMaxHeaderBytes
netA.wg.Add(1)
netA.tryConnect(s.URL, s.URL)
netA.tryConnect(addrB, gossipB)
time.Sleep(250 * time.Millisecond)
assert.Equal(t, 1, len(netA.peers))

netA.removePeer(netA.peers[0], disconnectReasonNone)
assert.Zero(t, len(netA.peers))

// Now try to connect with a max header size that is too small
netA.wsMaxHeaderBytes = 64
netA.wsMaxHeaderBytes = 128
netA.wg.Add(1)
netA.tryConnect(s.URL, s.URL)
netA.tryConnect(addrB, gossipB)
time.Sleep(250 * time.Millisecond)
assert.Zero(t, len(netA.peers))

// Test that setting 0 disables the max header size check
netA.wsMaxHeaderBytes = 0
netA.wg.Add(1)
netA.tryConnect(s.URL, s.URL)
netA.tryConnect(addrB, gossipB)
time.Sleep(250 * time.Millisecond)
assert.Equal(t, 1, len(netA.peers))
}

0 comments on commit 660d682

Please sign in to comment.