Skip to content

Commit

Permalink
Added option to access receiver type from session, resolves #86
Browse files Browse the repository at this point in the history
  • Loading branch information
igm committed Apr 17, 2020
1 parent e472422 commit 9d53909
Show file tree
Hide file tree
Showing 17 changed files with 96 additions and 29 deletions.
2 changes: 1 addition & 1 deletion v3/sockjs/eventsource.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ func (h *Handler) eventSource(rw http.ResponseWriter, req *http.Request) {
_, _ = fmt.Fprint(rw, "\r\n")
rw.(http.Flusher).Flush()

recv := newHTTPReceiver(rw, req, h.options.ResponseLimit, new(eventSourceFrameWriter))
recv := newHTTPReceiver(rw, req, h.options.ResponseLimit, new(eventSourceFrameWriter), ReceiverTypeEventSource)
sess, err := h.sessionByRequest(req)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
Expand Down
4 changes: 4 additions & 0 deletions v3/sockjs/eventsource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@ func TestHandler_EventSource(t *testing.T) {
exists = sess.recv != nil
sess.mux.RUnlock()
}
if rt := sess.ReceiverType(); rt != ReceiverTypeEventSource {
t.Errorf("Unexpected recevier type, got '%v', extected '%v'", rt, ReceiverTypeEventSource)
}
sess.mux.RLock()
sess.recv.close()
sess.mux.RUnlock()
}()
h.eventSource(rw, req)

contentType := rw.Header().Get("content-type")
expected := "text/event-stream; charset=UTF-8"
if contentType != expected {
Expand Down
2 changes: 1 addition & 1 deletion v3/sockjs/htmlfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (h *Handler) htmlFile(rw http.ResponseWriter, req *http.Request) {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
recv := newHTTPReceiver(rw, req, h.options.ResponseLimit, new(htmlfileFrameWriter))
recv := newHTTPReceiver(rw, req, h.options.ResponseLimit, new(htmlfileFrameWriter), ReceiverTypeHtmlFile)
if err := sess.attachReceiver(recv); err != nil {
if err := recv.sendFrame(cFrame); err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
Expand Down
5 changes: 4 additions & 1 deletion v3/sockjs/htmlfile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ func TestHandler_htmlFile(t *testing.T) {
if rw.Body.String() != expectedIFrame {
t.Errorf("Unexpected response body, got '%s', expected '%s'", rw.Body, expectedIFrame)
}

sess, _ := h.sessionByRequest(req)
if rt := sess.ReceiverType(); rt != ReceiverTypeHtmlFile {
t.Errorf("Unexpected recevier type, got '%v', extected '%v'", rt, ReceiverTypeHtmlFile)
}
}

func TestHandler_cannotIntoXSS(t *testing.T) {
Expand Down
8 changes: 7 additions & 1 deletion v3/sockjs/httpreceiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,17 @@ type httpReceiver struct {
currentResponseSize uint32
doneCh chan struct{}
interruptCh chan struct{}
recType ReceiverType
}

func newHTTPReceiver(rw http.ResponseWriter, req *http.Request, maxResponse uint32, frameWriter frameWriter) *httpReceiver {
func newHTTPReceiver(rw http.ResponseWriter, req *http.Request, maxResponse uint32, frameWriter frameWriter, receiverType ReceiverType) *httpReceiver {
recv := &httpReceiver{
rw: rw,
frameWriter: frameWriter,
maxResponseSize: maxResponse,
doneCh: make(chan struct{}),
interruptCh: make(chan struct{}),
recType: receiverType,
}
ctx := req.Context()

Expand Down Expand Up @@ -105,3 +107,7 @@ func (recv *httpReceiver) canSend() bool {
defer recv.Unlock()
return recv.state != stateHTTPReceiverClosed
}

func (recv *httpReceiver) receiverType() ReceiverType {
return recv.recType
}
14 changes: 7 additions & 7 deletions v3/sockjs/httpreceiver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (t *testFrameWriter) write(w io.Writer, frame string) (int, error) {
func TestHttpReceiver_Create(t *testing.T) {
rec := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "", nil)
recv := newHTTPReceiver(rec, req, 1024, new(testFrameWriter))
recv := newHTTPReceiver(rec, req, 1024, new(testFrameWriter), ReceiverTypeNone)
if recv.doneCh != recv.doneNotify() {
t.Errorf("Calling done() must return close channel, but it does not")
}
Expand All @@ -36,7 +36,7 @@ func TestHttpReceiver_Create(t *testing.T) {
func TestHttpReceiver_SendEmptyFrames(t *testing.T) {
rec := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "", nil)
recv := newHTTPReceiver(rec, req, 1024, new(testFrameWriter))
recv := newHTTPReceiver(rec, req, 1024, new(testFrameWriter), ReceiverTypeNone)
noError(t, recv.sendBulk())
if rec.Body.String() != "" {
t.Errorf("Incorrect body content received from receiver '%s'", rec.Body.String())
Expand All @@ -47,7 +47,7 @@ func TestHttpReceiver_SendFrame(t *testing.T) {
rec := httptest.NewRecorder()
fw := new(testFrameWriter)
req, _ := http.NewRequest("GET", "", nil)
recv := newHTTPReceiver(rec, req, 1024, fw)
recv := newHTTPReceiver(rec, req, 1024, fw, ReceiverTypeNone)
var frame = "some frame content"
noError(t, recv.sendFrame(frame))
if len(fw.frames) != 1 || fw.frames[0] != frame {
Expand All @@ -60,7 +60,7 @@ func TestHttpReceiver_SendBulk(t *testing.T) {
rec := httptest.NewRecorder()
fw := new(testFrameWriter)
req, _ := http.NewRequest("GET", "", nil)
recv := newHTTPReceiver(rec, req, 1024, fw)
recv := newHTTPReceiver(rec, req, 1024, fw, ReceiverTypeNone)
noError(t, recv.sendBulk("message 1", "message 2", "message 3"))
expected := "a[\"message 1\",\"message 2\",\"message 3\"]"
if len(fw.frames) != 1 || fw.frames[0] != expected {
Expand All @@ -71,7 +71,7 @@ func TestHttpReceiver_SendBulk(t *testing.T) {
func TestHttpReceiver_MaximumResponseSize(t *testing.T) {
rec := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "", nil)
recv := newHTTPReceiver(rec, req, 52, new(testFrameWriter))
recv := newHTTPReceiver(rec, req, 52, new(testFrameWriter), ReceiverTypeNone)
noError(t, recv.sendBulk("message 1", "message 2")) // produces 26 bytes of response in 1 frame
if recv.currentResponseSize != 26 {
t.Errorf("Incorrect response size calcualated, got '%d' expected '%d'", recv.currentResponseSize, 26)
Expand All @@ -92,7 +92,7 @@ func TestHttpReceiver_MaximumResponseSize(t *testing.T) {
func TestHttpReceiver_Close(t *testing.T) {
rec := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "", nil)
recv := newHTTPReceiver(rec, req, 1024, nil)
recv := newHTTPReceiver(rec, req, 1024, nil, ReceiverTypeNone)
recv.close()
if recv.state != stateHTTPReceiverClosed {
t.Errorf("Unexpected state, got '%d', expected '%d'", recv.state, stateHTTPReceiverClosed)
Expand All @@ -104,7 +104,7 @@ func TestHttpReceiver_ConnectionInterrupt(t *testing.T) {
req, _ := http.NewRequest("GET", "", nil)
ctx, cancel := context.WithCancel(req.Context())
req = req.WithContext(ctx)
recv := newHTTPReceiver(rw, req, 1024, nil)
recv := newHTTPReceiver(rw, req, 1024, nil, ReceiverTypeNone)
cancel()
select {
case <-recv.interruptCh:
Expand Down
2 changes: 1 addition & 1 deletion v3/sockjs/jsonp.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (h *Handler) jsonp(rw http.ResponseWriter, req *http.Request) {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
recv := newHTTPReceiver(rw, req, 1, &jsonpFrameWriter{callback})
recv := newHTTPReceiver(rw, req, 1, &jsonpFrameWriter{callback}, ReceiverTypeJSONP)
if err := sess.attachReceiver(recv); err != nil {
if err := recv.sendFrame(cFrame); err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
Expand Down
4 changes: 4 additions & 0 deletions v3/sockjs/jsonp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ func TestHandler_jsonp(t *testing.T) {
if rw.Body.String() != expectedBody {
t.Errorf("Unexpected body, got '%s', expected '%s'", rw.Body, expectedBody)
}
sess, _ := h.sessionByRequest(req)
if rt := sess.ReceiverType(); rt != ReceiverTypeJSONP {
t.Errorf("Unexpected recevier type, got '%v', extected '%v'", rt, ReceiverTypeJSONP)
}
}

func TestHandler_jsonpSendNoPayload(t *testing.T) {
Expand Down
4 changes: 4 additions & 0 deletions v3/sockjs/rawwebsocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ func (w *rawWsReceiver) sendFrame(frame string) error {
return nil
}

func (w *rawWsReceiver) receiverType() ReceiverType {
return ReceiverTypeRawWebsocket
}

func parseCloseFrame(frame string) (status uint32, reason string, err error) {
var items [2]interface{}
if err := json.Unmarshal([]byte(frame)[1:], &items); err != nil {
Expand Down
3 changes: 3 additions & 0 deletions v3/sockjs/rawwebsocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ func TestHandler_RawWebSocketTerminationByServer(t *testing.T) {
url := "ws" + server.URL[4:]
h.handlerFunc = func(conn *session) {
// close the session without sending any message
if rt := conn.ReceiverType(); rt != ReceiverTypeRawWebsocket {
t.Errorf("Unexpected recevier type, got '%v', extected '%v'", rt, ReceiverTypeRawWebsocket)
}
conn.Close(3000, "some close message")
conn.Close(0, "this should be ignored")
}
Expand Down
30 changes: 30 additions & 0 deletions v3/sockjs/receiver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package sockjs

type ReceiverType int

const (
ReceiverTypeNone ReceiverType = iota
ReceiverTypeXHR
ReceiverTypeEventSource
ReceiverTypeHtmlFile
ReceiverTypeJSONP
ReceiverTypeXHRStreaming
ReceiverTypeRawWebsocket
ReceiverTypeWebsocket
)

type receiver interface {
// sendBulk send multiple data messages in frame frame in format: a["msg 1", "msg 2", ....]
sendBulk(...string) error
// sendFrame sends given frame over the wire (with possible chunking depending on receiver)
sendFrame(string) error
// close closes the receiver in a "done" way (idempotent)
close()
canSend() bool
// done notification channel gets closed whenever receiver ends
doneNotify() <-chan struct{}
// interrupted channel gets closed whenever receiver is interrupted (i.e. http connection drops,...)
interruptedNotify() <-chan struct{}
// returns the type of receiver
receiverType() ReceiverType
}
23 changes: 9 additions & 14 deletions v3/sockjs/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,6 @@ type session struct {
closeCh chan struct{}
}

type receiver interface {
// sendBulk send multiple data messages in frame frame in format: a["msg 1", "msg 2", ....]
sendBulk(...string) error
// sendFrame sends given frame over the wire (with possible chunking depending on receiver)
sendFrame(string) error
// close closes the receiver in a "done" way (idempotent)
close()
canSend() bool
// done notification channel gets closed whenever receiver ends
doneNotify() <-chan struct{}
// interrupted channel gets closed whenever receiver is interrupted (i.e. http connection drops,...)
interruptedNotify() <-chan struct{}
}

// session is a central component that handles receiving and sending frames. It maintains internal state
func newSession(req *http.Request, sessionID string, sessionTimeoutInterval, heartbeatInterval time.Duration) *session {
s := &session{
Expand Down Expand Up @@ -225,3 +211,12 @@ func (s *session) GetSessionState() SessionState {
defer s.mux.RUnlock()
return s.state
}

func (s *session) ReceiverType() ReceiverType {
s.mux.RLock()
defer s.mux.RUnlock()
if s.recv != nil {
return s.recv.receiverType()
}
return ReceiverTypeNone
}
4 changes: 4 additions & 0 deletions v3/sockjs/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,7 @@ func noError(t *testing.T, err error) {
t.Fail()
}
}

func (t *testReceiver) receiverType() ReceiverType {
return ReceiverTypeNone
}
1 change: 1 addition & 0 deletions v3/sockjs/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,4 @@ func (w *wsReceiver) canSend() bool {
}
func (w *wsReceiver) doneNotify() <-chan struct{} { return w.closeCh }
func (w *wsReceiver) interruptedNotify() <-chan struct{} { return nil }
func (w *wsReceiver) receiverType() ReceiverType { return ReceiverTypeWebsocket }
7 changes: 6 additions & 1 deletion v3/sockjs/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@ func TestHandler_WebSocket(t *testing.T) {
defer server.CloseClientConnections()
url := "ws" + server.URL[4:]
var connCh = make(chan *session)
h.handlerFunc = func(conn *session) { connCh <- conn }
h.handlerFunc = func(conn *session) {
if rt := conn.ReceiverType(); rt != ReceiverTypeWebsocket {
t.Errorf("Unexpected recevier type, got '%v', extected '%v'", rt, ReceiverTypeWebsocket)
}
connCh <- conn
}
conn, resp, err := websocket.DefaultDialer.Dial(url, nil)
if err != nil {
t.Errorf("Unexpected error '%v'", err)
Expand Down
4 changes: 2 additions & 2 deletions v3/sockjs/xhr.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (h *Handler) xhrPoll(rw http.ResponseWriter, req *http.Request) {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
receiver := newHTTPReceiver(rw, req, 1, new(xhrFrameWriter))
receiver := newHTTPReceiver(rw, req, 1, new(xhrFrameWriter), ReceiverTypeXHR)
if err := sess.attachReceiver(receiver); err != nil {
if err := receiver.sendFrame(cFrame); err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
Expand All @@ -87,7 +87,7 @@ func (h *Handler) xhrStreaming(rw http.ResponseWriter, req *http.Request) {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
receiver := newHTTPReceiver(rw, req, h.options.ResponseLimit, new(xhrFrameWriter))
receiver := newHTTPReceiver(rw, req, h.options.ResponseLimit, new(xhrFrameWriter), ReceiverTypeXHRStreaming)

if err := sess.attachReceiver(receiver); err != nil {
if err := receiver.sendFrame(cFrame); err != nil {
Expand Down
8 changes: 8 additions & 0 deletions v3/sockjs/xhr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ func TestHandler_XhrPoll(t *testing.T) {
if rw.Header().Get("content-type") != "application/javascript; charset=UTF-8" {
t.Errorf("Wrong content type received, got '%s'", rw.Header().Get("content-type"))
}
sess, _ := h.sessionByRequest(req)
if rt := sess.ReceiverType(); rt != ReceiverTypeXHR {
t.Errorf("Unexpected recevier type, got '%v', extected '%v'", rt, ReceiverTypeXHR)
}
}

func TestHandler_XhrPollConnectionInterrupted(t *testing.T) {
Expand Down Expand Up @@ -151,6 +155,10 @@ func TestHandler_XhrStreaming(t *testing.T) {
if rw.Body.String() != expectedBody {
t.Errorf("Unexpected body, got '%s' expected '%s'", rw.Body, expectedBody)
}
sess, _ := h.sessionByRequest(req)
if rt := sess.ReceiverType(); rt != ReceiverTypeXHRStreaming {
t.Errorf("Unexpected recevier type, got '%v', extected '%v'", rt, ReceiverTypeXHRStreaming)
}
}

func TestHandler_XhrStreamingAnotherReceiver(t *testing.T) {
Expand Down

0 comments on commit 9d53909

Please sign in to comment.