Skip to content

Commit

Permalink
Merge ad9bf2e into bd2ed36
Browse files Browse the repository at this point in the history
  • Loading branch information
igm committed Nov 16, 2019
2 parents bd2ed36 + ad9bf2e commit c1d44cc
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 28 deletions.
28 changes: 19 additions & 9 deletions sockjs/eventsource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,20 @@ func TestHandler_EventSource(t *testing.T) {
h := newTestHandler()
h.options.ResponseLimit = 1024
go func() {
time.Sleep(1 * time.Millisecond)
h.sessionsMux.Lock()
defer h.sessionsMux.Unlock()
sess := h.sessions["session"]
sess.Lock()
defer sess.Unlock()
recv := sess.recv
recv.close()
var sess *session
for exists := false; !exists; {
h.sessionsMux.Lock()
sess, exists = h.sessions["session"]
h.sessionsMux.Unlock()
}
for exists := false; !exists; {
sess.RLock()
exists = sess.recv != nil
sess.RUnlock()
}
sess.RLock()
sess.recv.close()
sess.RUnlock()
}()
h.eventSource(rw, req)
contentType := rw.Header().Get("content-type")
Expand Down Expand Up @@ -65,7 +71,11 @@ func TestHandler_EventSourceConnectionInterrupted(t *testing.T) {
rw := newClosableRecorder()
close(rw.closeNotifCh)
h.eventSource(rw, req)
time.Sleep(1 * time.Millisecond)
select {
case <-sess.closeCh:
case <-time.After(1 * time.Second):
t.Errorf("session close channel should be closed")
}
sess.Lock()
if sess.state != SessionClosed {
t.Errorf("Session should be closed")
Expand Down
13 changes: 7 additions & 6 deletions sockjs/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"net/url"
"testing"
"time"

"github.com/stretchr/testify/require"
)

var testOptions = DefaultOptions
Expand All @@ -28,9 +30,8 @@ func TestHandler_Create(t *testing.T) {
defer server.Close()

resp, err := http.Get(server.URL + "/echo")
if err != nil {
t.Errorf("There should not be any error, got '%s'", err)
}
require.NoError(t, err)
require.NotNil(t, resp)
if resp.StatusCode != http.StatusOK {
t.Errorf("Unexpected status code receiver, got '%d' expected '%d'", resp.StatusCode, http.StatusOK)
}
Expand All @@ -45,9 +46,9 @@ func TestHandler_RootPrefixInfoHandler(t *testing.T) {
defer server.Close()

resp, err := http.Get(server.URL + "/info")
if err != nil {
t.Errorf("There should not be any error, got '%s'", err)
}
require.NoError(t, err)
require.NotNil(t, resp)

if resp.StatusCode != http.StatusOK {
t.Errorf("Unexpected status code receiver, got '%d' expected '%d'", resp.StatusCode, http.StatusOK)
}
Expand Down
7 changes: 6 additions & 1 deletion sockjs/httpreceiver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"io"
"net/http/httptest"
"testing"
"time"
)

type testFrameWriter struct {
Expand Down Expand Up @@ -94,7 +95,11 @@ func TestHttpReceiver_ConnectionInterrupt(t *testing.T) {
rw := newClosableRecorder()
recv := newHTTPReceiver(rw, 1024, nil)
rw.closeNotifCh <- true
recv.Lock()
select {
case <-recv.interruptCh:
case <-time.After(1 * time.Second):
t.Errorf("should interrupt")
}
if recv.state != stateHTTPReceiverClosed {
t.Errorf("Unexpected state, got '%d', expected '%d'", recv.state, stateHTTPReceiverClosed)
}
Expand Down
7 changes: 6 additions & 1 deletion sockjs/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,12 @@ func TestSession_Timeout(t *testing.T) {
select {
case <-sess.closeCh:
case <-time.After(20 * time.Millisecond):
t.Errorf("sess close notification channel should close")
select {
case <-sess.closeCh:
// still ok
default:
t.Errorf("sess close notification channel should close")
}
}
if sess.GetSessionState() != SessionClosed {
t.Errorf("Session did not timeout")
Expand Down
23 changes: 12 additions & 11 deletions sockjs/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
)

func TestHandler_WebSocketHandshakeError(t *testing.T) {
Expand All @@ -16,7 +17,9 @@ func TestHandler_WebSocketHandshakeError(t *testing.T) {
defer server.Close()
req, _ := http.NewRequest("GET", server.URL, nil)
req.Header.Set("origin", "https"+server.URL[4:])
resp, _ := http.DefaultClient.Do(req)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
require.NotNil(t, resp)
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("Unexpected response code, got '%d', expected '%d'", resp.StatusCode, http.StatusBadRequest)
}
Expand All @@ -30,12 +33,9 @@ func TestHandler_WebSocket(t *testing.T) {
var connCh = make(chan Session)
h.handlerFunc = func(conn Session) { connCh <- conn }
conn, resp, err := websocket.DefaultDialer.Dial(url, nil)
if conn == nil {
t.Errorf("Connection should not be nil")
}
if err != nil {
t.Errorf("Unexpected error '%v'", err)
}
require.NoError(t, err)
require.NotNil(t, conn)
require.NotNil(t, resp)
if resp.StatusCode != http.StatusSwitchingProtocols {
t.Errorf("Wrong response code returned, got '%d', expected '%d'", resp.StatusCode, http.StatusSwitchingProtocols)
}
Expand All @@ -56,9 +56,7 @@ func TestHandler_WebSocketTerminationByServer(t *testing.T) {
conn.Close(0, "this should be ignored")
}
conn, _, err := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}})
if err != nil {
t.Fatalf("websocket dial failed: %v", err)
}
require.NoError(t, err)
_, msg, err := conn.ReadMessage()
if string(msg) != "o" || err != nil {
t.Errorf("Open frame expected, got '%s' and error '%v', expected '%s' without error", msg, err, "o")
Expand Down Expand Up @@ -88,7 +86,8 @@ func TestHandler_WebSocketTerminationByClient(t *testing.T) {
}
close(done)
}
conn, _, _ := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}})
conn, _, _ := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": {server.URL}})
require.NotNil(t, conn)
conn.Close()
<-done
}
Expand All @@ -111,6 +110,7 @@ func TestHandler_WebSocketCommunication(t *testing.T) {
close(done)
}
conn, _, _ := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}})
require.NotNil(t, conn)
conn.WriteJSON([]string{"message 3"})
var expected = []string{"o", `a["message 1"]`, `a["message 2"]`, `c[123,"close"]`}
for _, exp := range expected {
Expand Down Expand Up @@ -145,6 +145,7 @@ func TestHandler_CustomWebSocketCommunication(t *testing.T) {
close(done)
}
conn, _, _ := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}})
require.NotNil(t, conn)
conn.WriteJSON([]string{"message 3"})
var expected = []string{"o", `a["message 1"]`, `a["message 2"]`, `c[123,"close"]`}
for _, exp := range expected {
Expand Down

0 comments on commit c1d44cc

Please sign in to comment.