Skip to content

Commit

Permalink
simplify State api
Browse files Browse the repository at this point in the history
  • Loading branch information
gobwas committed Aug 5, 2018
1 parent 85980ff commit a81a09b
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 20 deletions.
22 changes: 17 additions & 5 deletions check.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,18 @@ func (s State) Clear(v State) State {
return s & (^v)
}

// ServerSide reports whether states represents server side.
func (s State) ServerSide() bool { return s.Is(StateServerSide) }

// ClientSide reports whether state represents client side.
func (s State) ClientSide() bool { return s.Is(StateClientSide) }

// Extended reports whether state is extended.
func (s State) Extended() bool { return s.Is(StateExtended) }

// Fragmented reports whether state is fragmented.
func (s State) Fragmented() bool { return s.Is(StateFragmented) }

// ProtocolError describes error during checking/parsing websocket frames or
// headers.
type ProtocolError string
Expand Down Expand Up @@ -79,23 +91,23 @@ func CheckHeader(h Header, s State) error {
// non-zero values. If a nonzero value is received and none of the
// negotiated extensions defines the meaning of such a nonzero value, the
// receiving endpoint MUST _Fail the WebSocket Connection_.
case h.Rsv != 0 && !s.Is(StateExtended):
case h.Rsv != 0 && !s.Extended():
return ErrProtocolNonZeroRsv

// [RFC6455]: The server MUST close the connection upon receiving a frame that is not masked.
// In this case, a server MAY send a Close frame with a status code of 1002 (protocol error)
// as defined in Section 7.4.1. A server MUST NOT mask any frames that it sends to the client.
// A client MUST close a connection if it detects a masked frame. In this case, it MAY use the
// status code 1002 (protocol error) as defined in Section 7.4.1.
case s.Is(StateServerSide) && !h.Masked:
case s.ServerSide() && !h.Masked:
return ErrProtocolMaskRequired
case s.Is(StateClientSide) && h.Masked:
case s.ClientSide() && h.Masked:
return ErrProtocolMaskUnexpected

// [RFC6455]: See detailed explanation in 5.4 section.
case s.Is(StateFragmented) && !h.OpCode.IsControl() && h.OpCode != OpContinuation:
case s.Fragmented() && !h.OpCode.IsControl() && h.OpCode != OpContinuation:
return ErrProtocolContinuationExpected
case !s.Is(StateFragmented) && h.OpCode == OpContinuation:
case !s.Fragmented() && h.OpCode == OpContinuation:
return ErrProtocolContinuationUnexpected
}

Expand Down
4 changes: 3 additions & 1 deletion dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,9 @@ func TestDialerCancelation(t *testing.T) {

ctx := context.Background()
if t := test.ctxTimeout; t != 0 {
ctx, _ = context.WithTimeout(ctx, t)
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, t)
defer cancel()
}
if t := test.ctxCancelAfter; t != 0 {
var cancel context.CancelFunc
Expand Down
14 changes: 7 additions & 7 deletions wsutil/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ func (c ControlHandler) HandlePing(h ws.Header) error {
return ws.WriteHeader(c.Dst, ws.Header{
Fin: true,
OpCode: ws.OpPong,
Masked: c.State.Is(ws.StateClientSide),
Masked: c.State.ClientSide(),
})
}

// In other way reply with Pong frame with copied payload.
p := pbytes.GetLen(int(h.Length) + ws.HeaderSize(ws.Header{
Length: h.Length,
Masked: c.State.Is(ws.StateClientSide),
Masked: c.State.ClientSide(),
}))
defer pbytes.Put(p)

Expand All @@ -98,7 +98,7 @@ func (c ControlHandler) HandlePing(h ws.Header) error {
// ws.WriteHeader because it performs one syscall instead of two.
w := NewControlWriterBuffer(c.Dst, c.State, ws.OpPong, p)
r := c.Src
if c.State.Is(ws.StateServerSide) && !c.DisableSrcCiphering {
if c.State.ServerSide() && !c.DisableSrcCiphering {
r = NewCipherReader(r, h.Mask)
}

Expand Down Expand Up @@ -135,7 +135,7 @@ func (c ControlHandler) HandleClose(h ws.Header) error {
err := ws.WriteHeader(c.Dst, ws.Header{
Fin: true,
OpCode: ws.OpClose,
Masked: c.State.Is(ws.StateClientSide),
Masked: c.State.ClientSide(),
})
if err != nil {
return err
Expand All @@ -155,15 +155,15 @@ func (c ControlHandler) HandleClose(h ws.Header) error {
// Prepare bytes both for reading reason and sending response.
p := pbytes.GetLen(int(h.Length) + ws.HeaderSize(ws.Header{
Length: h.Length,
Masked: c.State.Is(ws.StateClientSide),
Masked: c.State.ClientSide(),
}))
defer pbytes.Put(p)

// Get the subslice to read the frame payload out.
subp := p[:h.Length]

r := c.Src
if c.State.Is(ws.StateServerSide) && !c.DisableSrcCiphering {
if c.State.ServerSide() && !c.DisableSrcCiphering {
r = NewCipherReader(r, h.Mask)
}
if _, err := io.ReadFull(r, subp); err != nil {
Expand Down Expand Up @@ -212,7 +212,7 @@ func (c ControlHandler) closeWithProtocolError(reason error) error {
f := ws.NewCloseFrame(ws.NewCloseFrameBody(
ws.StatusProtocolError, reason.Error(),
))
if c.State.Is(ws.StateClientSide) {
if c.State.ClientSide() {
ws.MaskFrameInPlace(f)
}
return ws.WriteFrame(c.Dst, f)
Expand Down
2 changes: 1 addition & 1 deletion wsutil/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ func HandleControlMessage(conn io.Writer, state ws.State, msg Message) error {
Length: int64(len(msg.Payload)),
OpCode: msg.OpCode,
Fin: true,
Masked: state.Is(ws.StateServerSide),
Masked: state.ServerSide(),
})
}

Expand Down
2 changes: 1 addition & 1 deletion wsutil/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func (r *Reader) NextFrame() (hdr ws.Header, err error) {
}

func (r *Reader) fragmented() bool {
return r.State.Is(ws.StateFragmented)
return r.State.Fragmented()
}

func (r *Reader) resetFragment() {
Expand Down
8 changes: 4 additions & 4 deletions wsutil/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func NewWriterBuffer(dest io.Writer, state ws.State, op ws.OpCode, buf []byte) *

func reserve(state ws.State, n int) (offset int) {
var mask int
if state.Is(ws.StateClientSide) {
if state.ClientSide() {
mask = 4
}

Expand All @@ -221,7 +221,7 @@ func reserve(state ws.State, n int) (offset int) {
func headerSize(s ws.State, n int) int {
return ws.HeaderSize(ws.Header{
Length: int64(n),
Masked: s.Is(ws.StateClientSide),
Masked: s.ClientSide(),
})
}

Expand Down Expand Up @@ -388,7 +388,7 @@ func (w *Writer) FlushFragment() error {

func (w *Writer) flushFragment(fin bool) error {
frame := ws.NewFrame(w.opCode(), fin, w.buf[:w.n])
if w.state.Is(ws.StateClientSide) {
if w.state.ClientSide() {
frame = ws.MaskFrameInPlace(frame)
}

Expand Down Expand Up @@ -433,7 +433,7 @@ func (w *bytesWriter) Write(p []byte) (int, error) {

func writeFrame(w io.Writer, s ws.State, op ws.OpCode, fin bool, p []byte) error {
var frame ws.Frame
if s.Is(ws.StateClientSide) {
if s.ClientSide() {
// Should copy bytes to prevent corruption of caller data.
payload := pbytes.GetLen(len(p))
defer pbytes.Put(payload)
Expand Down
2 changes: 1 addition & 1 deletion wsutil/writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func genReserveTestCases(s ws.State, n, m, exp int) []reserveTestCase {
ret := make([]reserveTestCase, m-n)
for i := n; i < m; i++ {
var suffix string
if s.Is(ws.StateClientSide) {
if s.ClientSide() {
suffix = " masked"
}

Expand Down

0 comments on commit a81a09b

Please sign in to comment.