Skip to content

Commit

Permalink
changed logic to wait for interrupt channel to close before checking …
Browse files Browse the repository at this point in the history
…receiver state rather to rely on lock mechanism

replaced time.Sleep with active loop to get session and receiver (relates to #66)
some test updates to use testify to check for nil
(relates to #66)
  • Loading branch information
igm committed Nov 16, 2019
1 parent bd2ed36 commit 84f9918
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 28 deletions.
28 changes: 19 additions & 9 deletions sockjs/eventsource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,20 @@ func TestHandler_EventSource(t *testing.T) {
h := newTestHandler()
h.options.ResponseLimit = 1024
go func() {
time.Sleep(1 * time.Millisecond)
h.sessionsMux.Lock()
defer h.sessionsMux.Unlock()
sess := h.sessions["session"]
sess.Lock()
defer sess.Unlock()
recv := sess.recv
recv.close()
var sess *session
for exists := false; !exists; {
h.sessionsMux.Lock()
sess, exists = h.sessions["session"]
h.sessionsMux.Unlock()
}
for exists := false; !exists; {
sess.RLock()
exists = sess.recv != nil
sess.RUnlock()
}
sess.RLock()
sess.recv.close()
sess.RUnlock()
}()
h.eventSource(rw, req)
contentType := rw.Header().Get("content-type")
Expand Down Expand Up @@ -65,7 +71,11 @@ func TestHandler_EventSourceConnectionInterrupted(t *testing.T) {
rw := newClosableRecorder()
close(rw.closeNotifCh)
h.eventSource(rw, req)
time.Sleep(1 * time.Millisecond)
select {
case <-sess.closeCh:
case <-time.After(1 * time.Second):
t.Errorf("session close channel should be closed")
}
sess.Lock()
if sess.state != SessionClosed {
t.Errorf("Session should be closed")
Expand Down
13 changes: 7 additions & 6 deletions sockjs/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"net/url"
"testing"
"time"

"github.com/stretchr/testify/require"
)

var testOptions = DefaultOptions
Expand All @@ -28,9 +30,8 @@ func TestHandler_Create(t *testing.T) {
defer server.Close()

resp, err := http.Get(server.URL + "/echo")
if err != nil {
t.Errorf("There should not be any error, got '%s'", err)
}
require.NoError(t, err)
require.NotNil(t, resp)
if resp.StatusCode != http.StatusOK {
t.Errorf("Unexpected status code receiver, got '%d' expected '%d'", resp.StatusCode, http.StatusOK)
}
Expand All @@ -45,9 +46,9 @@ func TestHandler_RootPrefixInfoHandler(t *testing.T) {
defer server.Close()

resp, err := http.Get(server.URL + "/info")
if err != nil {
t.Errorf("There should not be any error, got '%s'", err)
}
require.NoError(t, err)
require.NotNil(t, resp)

if resp.StatusCode != http.StatusOK {
t.Errorf("Unexpected status code receiver, got '%d' expected '%d'", resp.StatusCode, http.StatusOK)
}
Expand Down
7 changes: 6 additions & 1 deletion sockjs/httpreceiver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"io"
"net/http/httptest"
"testing"
"time"
)

type testFrameWriter struct {
Expand Down Expand Up @@ -94,7 +95,11 @@ func TestHttpReceiver_ConnectionInterrupt(t *testing.T) {
rw := newClosableRecorder()
recv := newHTTPReceiver(rw, 1024, nil)
rw.closeNotifCh <- true
recv.Lock()
select {
case <-recv.interruptCh:
case <-time.After(1 * time.Second):
t.Errorf("should interrupt")
}
if recv.state != stateHTTPReceiverClosed {
t.Errorf("Unexpected state, got '%d', expected '%d'", recv.state, stateHTTPReceiverClosed)
}
Expand Down
7 changes: 6 additions & 1 deletion sockjs/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,12 @@ func TestSession_Timeout(t *testing.T) {
select {
case <-sess.closeCh:
case <-time.After(20 * time.Millisecond):
t.Errorf("sess close notification channel should close")
select {
case <-sess.closeCh:
// still ok
default:
t.Errorf("sess close notification channel should close")
}
}
if sess.GetSessionState() != SessionClosed {
t.Errorf("Session did not timeout")
Expand Down
23 changes: 12 additions & 11 deletions sockjs/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
)

func TestHandler_WebSocketHandshakeError(t *testing.T) {
Expand All @@ -16,7 +17,9 @@ func TestHandler_WebSocketHandshakeError(t *testing.T) {
defer server.Close()
req, _ := http.NewRequest("GET", server.URL, nil)
req.Header.Set("origin", "https"+server.URL[4:])
resp, _ := http.DefaultClient.Do(req)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
require.NotNil(t, resp)
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("Unexpected response code, got '%d', expected '%d'", resp.StatusCode, http.StatusBadRequest)
}
Expand All @@ -30,12 +33,9 @@ func TestHandler_WebSocket(t *testing.T) {
var connCh = make(chan Session)
h.handlerFunc = func(conn Session) { connCh <- conn }
conn, resp, err := websocket.DefaultDialer.Dial(url, nil)
if conn == nil {
t.Errorf("Connection should not be nil")
}
if err != nil {
t.Errorf("Unexpected error '%v'", err)
}
require.NoError(t, err)
require.NotNil(t, conn)
require.NotNil(t, resp)
if resp.StatusCode != http.StatusSwitchingProtocols {
t.Errorf("Wrong response code returned, got '%d', expected '%d'", resp.StatusCode, http.StatusSwitchingProtocols)
}
Expand All @@ -56,9 +56,7 @@ func TestHandler_WebSocketTerminationByServer(t *testing.T) {
conn.Close(0, "this should be ignored")
}
conn, _, err := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}})
if err != nil {
t.Fatalf("websocket dial failed: %v", err)
}
require.NoError(t, err)
_, msg, err := conn.ReadMessage()
if string(msg) != "o" || err != nil {
t.Errorf("Open frame expected, got '%s' and error '%v', expected '%s' without error", msg, err, "o")
Expand Down Expand Up @@ -88,7 +86,8 @@ func TestHandler_WebSocketTerminationByClient(t *testing.T) {
}
close(done)
}
conn, _, _ := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}})
conn, _, _ := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": {server.URL}})
require.NotNil(t, conn)
conn.Close()
<-done
}
Expand All @@ -111,6 +110,7 @@ func TestHandler_WebSocketCommunication(t *testing.T) {
close(done)
}
conn, _, _ := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}})
require.NotNil(t, conn)
conn.WriteJSON([]string{"message 3"})
var expected = []string{"o", `a["message 1"]`, `a["message 2"]`, `c[123,"close"]`}
for _, exp := range expected {
Expand Down Expand Up @@ -145,6 +145,7 @@ func TestHandler_CustomWebSocketCommunication(t *testing.T) {
close(done)
}
conn, _, _ := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}})
require.NotNil(t, conn)
conn.WriteJSON([]string{"message 3"})
var expected = []string{"o", `a["message 1"]`, `a["message 2"]`, `c[123,"close"]`}
for _, exp := range expected {
Expand Down

0 comments on commit 84f9918

Please sign in to comment.