diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..b31cdb4
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+bin/
+reports/
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..2744317
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+The MIT License (MIT)
+
+Copyright (c) 2017 Sergey Kamardin
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..c7104fe
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,47 @@
+BENCH ?=.
+BENCH_BASE?=master
+
+autobahn:
+ go build -o ./bin/autobahn ./example/autobahn
+
+test:
+ go test -cover ./...
+
+testrfc: PID:=$(shell mktemp -t autobahn.XXXX)
+testrfc: autobahn
+ ./bin/autobahn & echo $$! > $(PID)
+ if [ -z "$$(ps | grep $$(cat $(PID)) | grep autobahn)" ]; then\
+ echo "could not start autobahn";\
+ exit 1;\
+ fi;\
+ wstest -m fuzzingclient -s ./example/autobahn/fuzzingclient.json
+ pkill -9 -F $(PID)
+
+rfc: testrfc
+ open ./example/autobahn/reports/servers/index.html
+
+bench:
+ go test -run=none -bench=$(BENCH) -benchmem
+
+benchcmp: BENCH_BRANCH=$(shell git rev-parse --abbrev-ref HEAD)
+benchcmp: BENCH_OLD:=$(shell mktemp -t old.XXXX)
+benchcmp: BENCH_NEW:=$(shell mktemp -t new.XXXX)
+benchcmp:
+ if [ ! -z "$(shell git status -s)" ]; then\
+ echo "could not compare with $(BENCH_BASE) – found unstaged changes";\
+ exit 1;\
+ fi;\
+ if [ "$(BENCH_BRANCH)" == "$(BENCH_BASE)" ]; then\
+ echo "comparing the same branches";\
+ exit 1;\
+ fi;\
+ echo "benchmarking $(BENCH_BRANCH)...";\
+ go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_NEW);\
+ echo "benchmarking $(BENCH_BASE)...";\
+ git checkout -q $(BENCH_BASE);\
+ go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_OLD);\
+ git checkout -q $(BENCH_BRANCH);\
+ echo "\nresults:";\
+ echo "========\n";\
+ benchcmp $(BENCH_OLD) $(BENCH_NEW);\
+
diff --git a/check.go b/check.go
new file mode 100644
index 0000000..1437de1
--- /dev/null
+++ b/check.go
@@ -0,0 +1,141 @@
+package ws
+
+import (
+ "fmt"
+ "unicode/utf8"
+)
+
+// State represents state of websocket endpoint.
+// It used by some functions to be more strict when checking compatibility with RFC6455.
+type State uint8
+
+const (
+ // StateServerSide means that endpoint (caller) is a server.
+ StateServerSide State = 0x1 << iota
+ // StateServerSide means that endpoint (caller) is a client.
+ StateClientSide
+ // StateExtended means that extension was negotiated during handshake.
+ StateExtended
+ // StateFragmented means that endpoint (caller) has received fragmented
+ // frame and waits for continuation parts.
+ StateFragmented
+)
+
+// Is checks whether the s has v enabled.
+func (s State) Is(v State) bool {
+ return uint8(s)&uint8(v) != 0
+}
+
+// Set enables v state on s.
+func (s State) Set(v State) State {
+ return s | v
+}
+
+// Clear disables v state on s.
+func (s State) Clear(v State) State {
+ return s & (^v)
+}
+
+// SetOrClearIf enables or disables v state on s depending on cond.
+func (s State) SetOrClearIf(cond bool, v State) (ret State) {
+ if cond {
+ ret = s.Set(v)
+ } else {
+ ret = s.Clear(v)
+ }
+ return
+}
+
+// ProtocolError describes error during checking/parsing websocket frames or headers.
+type ProtocolError error
+
+// Errors used by the protocol checkers.
+var (
+ ErrProtocolOpCodeReserved = ProtocolError(fmt.Errorf("use of reserved op code"))
+ ErrProtocolControlPayloadOverflow = ProtocolError(fmt.Errorf("control frame payload limit exceeded"))
+ ErrProtocolControlNotFinal = ProtocolError(fmt.Errorf("control frame is not final"))
+ ErrProtocolNonZeroRsv = ProtocolError(fmt.Errorf("non-zero rsv bits with no extension negotiated"))
+ ErrProtocolMaskRequired = ProtocolError(fmt.Errorf("frames from client to server must be masked"))
+ ErrProtocolMaskUnexpected = ProtocolError(fmt.Errorf("frames from server to client must be not masked"))
+ ErrProtocolContinuationExpected = ProtocolError(fmt.Errorf("unexpected non-continuation data frame"))
+ ErrProtocolContinuationUnexpected = ProtocolError(fmt.Errorf("unexpected continuation data frame"))
+ ErrProtocolStatusCodeNotInUse = ProtocolError(fmt.Errorf("status code is not in use"))
+ ErrProtocolStatusCodeApplicationLevel = ProtocolError(fmt.Errorf("status code is only application level"))
+ ErrProtocolStatusCodeNoMeaning = ProtocolError(fmt.Errorf("status code has no meaning yet"))
+ ErrProtocolStatusCodeUnknown = ProtocolError(fmt.Errorf("status code is not defined in spec"))
+ ErrProtocolInvalidUTF8 = ProtocolError(fmt.Errorf("invalid utf8 sequence in close reason"))
+)
+
+// CheckHeader checks h to contain valid header data for given state s.
+//
+// Note that zero state (0) means that state is clean,
+// neither server or client side, nor fragmented, nor extended.
+func CheckHeader(h Header, s State) error {
+ if h.OpCode.IsReserved() {
+ return ErrProtocolOpCodeReserved
+ }
+ if h.OpCode.IsControl() {
+ if h.Length > MaxControlFramePayloadSize {
+ return ErrProtocolControlPayloadOverflow
+ }
+ if !h.Fin {
+ return ErrProtocolControlNotFinal
+ }
+ }
+
+ switch {
+ // [RFC6455]: MUST be 0 unless an extension is negotiated that defines meanings for
+ // non-zero values. If a nonzero value is received and none of the
+ // negotiated extensions defines the meaning of such a nonzero value, the
+ // receiving endpoint MUST _Fail the WebSocket Connection_.
+ case h.Rsv != 0 && !s.Is(StateExtended):
+ return ErrProtocolNonZeroRsv
+
+ // [RFC6455]: The server MUST close the connection upon receiving a frame that is not masked.
+ // In this case, a server MAY send a Close frame with a status code of 1002 (protocol error)
+ // as defined in Section 7.4.1. A server MUST NOT mask any frames that it sends to the client.
+ // A client MUST close a connection if it detects a masked frame. In this case, it MAY use the
+ // status code 1002 (protocol error) as defined in Section 7.4.1.
+ case s.Is(StateServerSide) && h.Mask == nil:
+ return ErrProtocolMaskRequired
+ case s.Is(StateClientSide) && h.Mask != nil:
+ return ErrProtocolMaskUnexpected
+
+ // [RFC6455]: See detailed explanation in 5.4 section.
+ case s.Is(StateFragmented) && !h.OpCode.IsControl() && h.OpCode != OpContinuation:
+ return ErrProtocolContinuationExpected
+ case !s.Is(StateFragmented) && h.OpCode == OpContinuation:
+ return ErrProtocolContinuationUnexpected
+ }
+
+ return nil
+}
+
+// CheckCloseFrameData checks received close information
+// to be valid RFC6455 compatible clsoe info.
+//
+// Note that code.Empty() or code.IsAppLevel() will raise error.
+//
+// If endpoint sends close frame without status code (with frame.Length = 0),
+// application should not check its payload.
+func CheckCloseFrameData(code StatusCode, reason string) error {
+ switch {
+ case code.IsNotUsed():
+ return ErrProtocolStatusCodeNotInUse
+
+ case code.IsProtocolReserved():
+ return ErrProtocolStatusCodeApplicationLevel
+
+ case code == StatusNoMeaningYet:
+ return ErrProtocolStatusCodeNoMeaning
+
+ case code.IsProtocolSpec() && !code.IsProtocolDefined():
+ return ErrProtocolStatusCodeUnknown
+
+ case !utf8.ValidString(reason):
+ return ErrProtocolInvalidUTF8
+
+ default:
+ return nil
+ }
+}
diff --git a/cipher.go b/cipher.go
new file mode 100644
index 0000000..d06e141
--- /dev/null
+++ b/cipher.go
@@ -0,0 +1,55 @@
+package ws
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// Cipher applies XOR cipher to the payload using mask.
+// Offset is used to cipher chunked data (e.g. in io.Reader implementations).
+//
+// To convert masked data into unmasked data, or vice versa, the following
+// algorithm is applied. The same algorithm applies regardless of the
+// direction of the translation, e.g., the same steps are applied to
+// mask the data as to unmask the data.
+func Cipher(payload, mask []byte, offset int) {
+ if len(mask) != 4 {
+ return
+ }
+
+ n := len(payload)
+ if n < 8 {
+ for i := 0; i < n; i++ {
+ payload[i] ^= mask[(offset+i)%4]
+ }
+ return
+ }
+
+ // Calculate position in mask due to previously processed bytes number.
+ mpos := offset % 4
+ // Count number of bytes will processed one by one from the begining of payload.
+ // Bitwise used to avoid additional if.
+ ln := (4 - mpos) & 0x0b
+ // Count number of bytes will processed one by one from the end of payload.
+ // This is done to process payload by 8 bytes in each iteration of main loop.
+ rn := (n - ln) % 8
+
+ for i := 0; i < ln; i++ {
+ payload[i] ^= mask[(mpos+i)%4]
+ }
+ for i := n - rn; i < n; i++ {
+ payload[i] ^= mask[(mpos+i)%4]
+ }
+
+ ph := *(*reflect.SliceHeader)(unsafe.Pointer(&payload))
+ mh := *(*reflect.SliceHeader)(unsafe.Pointer(&mask))
+
+ m := *(*uint32)(unsafe.Pointer(mh.Data))
+ m2 := uint64(m)<<32 | uint64(m)
+
+ // Process the rest of bytes as uint64.
+ for i := ln; i+8 <= n-rn; i += 8 {
+ v := (*uint64)(unsafe.Pointer(ph.Data + uintptr(i)))
+ *v = *v ^ m2
+ }
+}
diff --git a/cipher_test.go b/cipher_test.go
new file mode 100644
index 0000000..1fd3e1a
--- /dev/null
+++ b/cipher_test.go
@@ -0,0 +1,154 @@
+package ws
+
+import (
+ "fmt"
+ "math/rand"
+ "reflect"
+ "testing"
+)
+
+func TestCipher(t *testing.T) {
+ for i, test := range []struct {
+ in []byte
+ mask [4]byte
+ }{
+ {
+ in: []byte("Hello, XOR!"),
+ mask: [4]byte{1, 2, 3, 4},
+ },
+ {
+ in: []byte("Hello, XOR!"),
+ mask: [4]byte{255, 255, 255, 255},
+ },
+ } {
+ t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
+ // naive implementation of xor-cipher
+ exp := cipherNaive(test.in, test.mask[:], 0)
+
+ res := make([]byte, len(test.in))
+ copy(res, test.in)
+ Cipher(res, test.mask[:], 0)
+
+ if !reflect.DeepEqual(res, exp) {
+ t.Errorf("Cipher(%v, %v):\nact:\t%v\nexp:\t%v\n", test.in, test.mask, res, exp)
+ }
+ })
+ }
+}
+
+func TestCipherChops(t *testing.T) {
+ for n := 2; n <= 1024; n <<= 1 {
+ t.Run(fmt.Sprintf("%d", n), func(t *testing.T) {
+ p := make([]byte, n)
+ b := make([]byte, n)
+ m := make([]byte, 4)
+
+ _, err := rand.Read(p)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = rand.Read(m)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ exp := cipherNaive(p, m, 0)
+
+ for i := 1; i <= n; i <<= 1 {
+ copy(b, p)
+ s := n / i
+
+ for j := s; j <= n; j += s {
+ l, r := j-s, j
+ Cipher(b[l:r], m, l)
+ if !reflect.DeepEqual(b[l:r], exp[l:r]) {
+ t.Errorf("unexpected Cipher([%d:%d]) = %x; want %x", l, r, b[l:r], exp[l:r])
+ return
+ }
+ }
+ }
+
+ l := 0
+ copy(b, p)
+ for l < n {
+ r := rand.Intn(n-l) + l + 1
+ Cipher(b[l:r], m, l)
+ if !reflect.DeepEqual(b[l:r], exp[l:r]) {
+ t.Errorf("unexpected Cipher([%d:%d]):\nact:\t%v\nexp:\t%v\nact:\t%#x\nexp:\t%#x\n\n", l, r, b[l:r], exp[l:r], b[l:r], exp[l:r])
+ return
+ }
+ l = r
+ }
+ })
+ }
+}
+
+func cipherNaive(p, m []byte, pos int) []byte {
+ r := make([]byte, len(p))
+ copy(r, p)
+ cipherNaiveNoCp(r, m, pos)
+ return r
+}
+
+func cipherNaiveNoCp(p, m []byte, pos int) []byte {
+ for i := 0; i < len(p); i++ {
+ p[i] ^= m[(pos+i)%4]
+ }
+ return p
+}
+
+func BenchmarkCipher(b *testing.B) {
+ for _, bench := range []struct {
+ size int
+ offset int
+ }{
+ {
+ size: 7,
+ offset: 1,
+ },
+ {
+ size: 125,
+ },
+ {
+ size: 1024,
+ },
+ {
+ size: 4096,
+ },
+ {
+ size: 4100,
+ offset: 4,
+ },
+ {
+ size: 4099,
+ offset: 3,
+ },
+ {
+ size: (1 << 15) + 7,
+ offset: 49,
+ },
+ } {
+ bts := make([]byte, bench.size)
+ _, err := rand.Read(bts)
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ mask := make([]byte, 4)
+ _, err = rand.Read(mask)
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ //b.Run(fmt.Sprintf("naive_bytes=%d;offset=%d", bench.size, bench.offset), func(b *testing.B) {
+ // for i := 0; i < b.N; i++ {
+ // cipherNaiveNoCp(bts, mask, bench.offset)
+ // }
+ //})
+ b.Run(fmt.Sprintf("bytes=%d;offset=%d", bench.size, bench.offset), func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ Cipher(bts, mask, bench.offset)
+ }
+ })
+ }
+}
diff --git a/client.go b/client.go
new file mode 100644
index 0000000..d9ef970
--- /dev/null
+++ b/client.go
@@ -0,0 +1,253 @@
+package ws
+
+import (
+ "bufio"
+ "context"
+ "crypto/tls"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/gobwas/pool/pbufio"
+)
+
+// ReaderPool describes object that manages reuse of bufio.Reader instances.
+type ReaderPool interface {
+ Get(io.Reader) *bufio.Reader
+ Put(*bufio.Reader)
+}
+
+// WriterPool describes object that manages reuse of bufio.Writer instances.
+type WriterPool interface {
+ Get(io.Writer) *bufio.Writer
+ Put(*bufio.Writer)
+}
+
+// Handshake represents handshake result.
+type Handshake struct {
+ // Protocol is the selected during handshake subprotocol.
+ Protocol string
+
+ // Extensions is the list of negotiated extensions.
+ Extensions []string
+}
+
+// Response represents result of dialing.
+type Response struct {
+ *http.Response
+ Handshake
+}
+
+var (
+ defaultWriterPool = writerPool(512)
+ defaultReaderPool = readerPool(512)
+)
+
+// Errors used by the websocket client.
+var (
+ ErrBadStatus = fmt.Errorf("unexpected http status")
+ ErrBadSubProtocol = fmt.Errorf("unexpected protocol in %q header", headerSecProtocol)
+)
+
+// Dialer contains options for establishing websocket connection to an url.
+type Dialer struct {
+ // Header is the set of custom headers that will be sent with the request.
+ Header http.Header
+
+ // Protocol is the list of subprotocol names the client wishes to speak, ordered by preference.
+ // See https://tools.ietf.org/html/rfc6455#section-4.1
+ Protocol []string
+
+ // Extensions is the list of extensions, that client wishes to speak.
+ // See https://tools.ietf.org/html/rfc6455#section-4.1
+ // See https://tools.ietf.org/html/rfc6455#section-9.1
+ Extensions []string
+
+ // NetDial is the function that is used to get plain tcp connection.
+ // If it is not nil, then it is used instead of net.Dialer.
+ NetDial func(ctx context.Context, network, addr string) (net.Conn, error)
+
+ // NetDialTLS is the function that is used to get plain tcp connection with tls encryption.
+ // If it is not nil, then it is used instead of tls.DialWithDialer.
+ NetDialTLS func(ctx context.Context, network, addr string) (net.Conn, error)
+
+ // TLSConfig is passed to tls.DialWithDialer.
+ TLSConfig *tls.Config
+
+ // WriterPool is used to reuse bufio.Writers.
+ WriterPool WriterPool
+
+ // ReaderPool is used to reuse bufio.Readers.
+ ReaderPool ReaderPool
+}
+
+// Dial connects to the url host and handshakes connection to websocket.
+func (d Dialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, resp Response, err error) {
+ req := getRequest()
+ defer putRequest(req)
+
+ err = req.Reset(urlstr, d.Header, d.Protocol, d.Extensions)
+ if err != nil {
+ return
+ }
+
+ conn, err = d.dial(ctx, req.URL)
+ if err != nil {
+ return
+ }
+
+ resp.Response, err = d.send(ctx, conn, req)
+ if err != nil {
+ return
+ }
+
+ resp.Protocol, resp.Extensions, err = d.handshake(req, resp)
+
+ return
+}
+
+func (d Dialer) dial(ctx context.Context, u *url.URL) (conn net.Conn, err error) {
+ addr := hostport(u)
+ if u.Scheme == "wss" {
+ if nd := d.NetDialTLS; nd != nil {
+ return nd(ctx, "tcp", addr)
+ }
+
+ var nd net.Dialer
+ if deadline, ok := ctx.Deadline(); ok {
+ nd.Deadline = deadline
+ }
+ return tls.DialWithDialer(&nd, "tcp", addr, d.TLSConfig)
+ }
+
+ if nd := d.NetDial; nd != nil {
+ return nd(ctx, "tcp", addr)
+ }
+
+ var nd net.Dialer
+ return nd.DialContext(ctx, "tcp", addr)
+}
+
+func (d Dialer) send(ctx context.Context, conn net.Conn, req *request) (resp *http.Response, err error) {
+ type respAndError struct {
+ resp *http.Response
+ err error
+ }
+ var (
+ wp WriterPool
+ rp ReaderPool
+ )
+ if wp = d.WriterPool; wp == nil {
+ wp = defaultWriterPool
+ }
+ if rp = d.ReaderPool; rp == nil {
+ rp = defaultReaderPool
+ }
+
+ bw := wp.Get(conn)
+ defer wp.Put(bw)
+
+ if err = req.Write(bw); err != nil {
+ return
+ }
+ if err = bw.Flush(); err != nil {
+ return
+ }
+
+ br := rp.Get(conn)
+ defer rp.Put(br)
+
+ if deadline, ok := ctx.Deadline(); ok {
+ ch := make(chan respAndError, 2)
+ time.AfterFunc(deadline.Sub(time.Now()), func() {
+ ch <- respAndError{nil, timeoutError{}}
+ })
+ go func() {
+ resp, err = http.ReadResponse(br, nil)
+ ch <- respAndError{resp, err}
+ }()
+ r := <-ch
+ resp, err = r.resp, r.err
+ return
+ }
+
+ return http.ReadResponse(br, nil)
+}
+
+func (d Dialer) handshake(req *request, resp Response) (protocol string, extensions []string, err error) {
+ if resp.StatusCode != 101 {
+ err = ErrBadStatus
+ return
+ }
+ if upgrade := resp.Header.Get(headerUpgrade); strings.ToLower(upgrade) != "websocket" {
+ err = ErrBadUpgrade
+ return
+ }
+ if connection := resp.Header.Get(headerConnection); !strings.Contains(strings.ToLower(connection), "upgrade") {
+ err = ErrBadConnection
+ return
+ }
+ if !checkNonce(resp.Header.Get(headerSecAccept), req.Nonce) {
+ err = ErrBadSecAccept
+ return
+ }
+ if extensions := resp.Header.Get(headerSecExtensions); extensions != "" {
+ // TODO(gobwas): implement extensions logic.
+ }
+
+ // We check single value of Sec-Websocket-Protocol header according to this:
+ // RFC6455 1.3: "The server selects one or none of the acceptable protocols and echoes
+ // that value in its handshake to indicate that it has selected that
+ // protocol."
+ if protocol = resp.Header.Get(headerSecProtocol); protocol != "" {
+ var has bool
+ for _, p := range req.Protocols {
+ if has = p == protocol; has {
+ break
+ }
+ }
+ if !has {
+ err = ErrBadSubProtocol
+ return
+ }
+ }
+ return
+}
+
+func hostport(u *url.URL) string {
+ host, port := split2(u.Host, ':')
+ if port != "" {
+ return u.Host
+ }
+ if u.Scheme == "wss" {
+ return host + ":443"
+ }
+ return host + ":80"
+}
+
+func split2(s string, sep byte) (a, b string) {
+ if i := strings.LastIndexByte(s, sep); i != -1 {
+ return s[:i], s[i+1:]
+ }
+ return s, ""
+}
+
+type readerPool int
+
+func (n readerPool) Get(r io.Reader) *bufio.Reader { return pbufio.GetReader(r, int(n)) }
+func (n readerPool) Put(r *bufio.Reader) { pbufio.PutReader(r, int(n)) }
+
+type writerPool int
+
+func (n writerPool) Get(w io.Writer) *bufio.Writer { return pbufio.GetWriter(w, int(n)) }
+func (n writerPool) Put(w *bufio.Writer) { pbufio.PutWriter(w, int(n)) }
+
+type timeoutError struct{}
+
+func (timeoutError) Timeout() bool { return true }
+func (timeoutError) Temporary() bool { return true }
+func (timeoutError) Error() string { return "client timeout" }
diff --git a/client_test.go b/client_test.go
new file mode 100644
index 0000000..c20de73
--- /dev/null
+++ b/client_test.go
@@ -0,0 +1,245 @@
+package ws
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "crypto/rand"
+ "encoding/base64"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/url"
+ "testing"
+ "time"
+)
+
+func TestDialerHandshake(t *testing.T) {
+ for i, test := range []struct {
+ res *http.Response
+ accept bool
+ protocols []string
+ err error
+ }{
+ {
+ res: &http.Response{
+ StatusCode: 101,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: http.Header{
+ headerConnection: []string{"Upgrade"},
+ headerUpgrade: []string{"websocket"},
+ },
+ },
+ accept: true,
+ },
+ {
+ res: &http.Response{
+ StatusCode: 101,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: http.Header{
+ headerConnection: []string{"Upgrade"},
+ headerUpgrade: []string{"websocket"},
+ headerSecProtocol: []string{"json"},
+ },
+ },
+ protocols: []string{"xml", "json", "soap"},
+ accept: true,
+ },
+ {
+ res: &http.Response{
+ StatusCode: 400,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: http.Header{
+ headerConnection: []string{"Upgrade"},
+ headerUpgrade: []string{"websocket"},
+ },
+ },
+ err: ErrBadStatus,
+ },
+ {
+ res: &http.Response{
+ StatusCode: 101,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: http.Header{
+ headerConnection: []string{"Error"},
+ headerUpgrade: []string{"websocket"},
+ },
+ },
+ err: ErrBadConnection,
+ },
+ {
+ res: &http.Response{
+ StatusCode: 101,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: http.Header{
+ headerConnection: []string{"Upgrade"},
+ headerUpgrade: []string{"iproto"},
+ },
+ },
+ err: ErrBadUpgrade,
+ },
+ {
+ res: &http.Response{
+ StatusCode: 101,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: http.Header{
+ headerConnection: []string{"Upgrade"},
+ headerUpgrade: []string{"websocket"},
+ },
+ },
+ accept: false,
+ err: ErrBadSecAccept,
+ },
+ {
+ res: &http.Response{
+ StatusCode: 101,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: http.Header{
+ headerConnection: []string{"Upgrade"},
+ headerUpgrade: []string{"websocket"},
+ headerSecProtocol: []string{"oops"},
+ },
+ },
+ accept: true,
+ err: ErrBadSubProtocol,
+ },
+ } {
+ t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
+ rb := &bytes.Buffer{}
+ wb := &bytes.Buffer{}
+ wbuf := bufio.NewReader(wb)
+
+ sig := make(chan struct{})
+ go func() {
+ <-sig
+ req, err := http.ReadRequest(wbuf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ var key []byte
+ if test.accept {
+ rk := req.Header.Get(headerSecKey)
+ key = []byte(rk)
+ } else {
+ key = make([]byte, 24)
+ rand.Read(key)
+ }
+
+ accept := makeAccept([]byte(key))
+ test.res.Header.Set(headerSecAccept, string(accept))
+ test.res.Request = req
+ test.res.Write(rb)
+
+ sig <- struct{}{}
+ }()
+
+ conn := &stubConn{
+ read: func(p []byte) (int, error) {
+ <-sig
+ return rb.Read(p)
+ },
+ write: func(p []byte) (int, error) {
+ n, err := wb.Write(p)
+ sig <- struct{}{}
+ return n, err
+ },
+ close: func() error { return nil },
+ }
+
+ pr := stubReadPool{}
+ pw := stubWritePool{}
+
+ d := Dialer{
+ Protocol: test.protocols,
+ NetDial: func(_ context.Context, _, _ string) (net.Conn, error) {
+ return conn, nil
+ },
+ ReaderPool: &pr,
+ WriterPool: &pw,
+ }
+
+ _, _, err := d.Dial(context.Background(), "ws://gobwas.com")
+ if test.err != err {
+ t.Errorf("unexpected error: %v;\n\twant %v", err, test.err)
+ }
+ })
+ }
+}
+
+func TestHostPortResolve(t *testing.T) {
+ for _, test := range []struct {
+ url *url.URL
+ ret string
+ }{
+ {
+ url: &url.URL{Host: "example.com", Scheme: "ws"},
+ ret: "example.com:80",
+ },
+ {
+ url: &url.URL{Host: "example.com", Scheme: "wss"},
+ ret: "example.com:443",
+ },
+ {
+ url: &url.URL{Host: "example.com:3000", Scheme: "wss"},
+ ret: "example.com:3000",
+ },
+ } {
+ t.Run(test.url.String(), func(t *testing.T) {
+ ret := hostport(test.url)
+ if test.ret != ret {
+ t.Errorf("expected %s; got %s", test.ret, ret)
+ }
+ })
+ }
+}
+
+type stubConn struct {
+ read func([]byte) (int, error)
+ write func([]byte) (int, error)
+ close func() error
+}
+
+func (s stubConn) Read(p []byte) (int, error) { return s.read(p) }
+func (s stubConn) Write(p []byte) (int, error) { return s.write(p) }
+func (s stubConn) Close() error { return s.close() }
+func (s stubConn) LocalAddr() net.Addr { return nil }
+func (s stubConn) RemoteAddr() net.Addr { return nil }
+func (s stubConn) SetDeadline(t time.Time) error { return nil }
+func (s stubConn) SetReadDeadline(t time.Time) error { return nil }
+func (s stubConn) SetWriteDeadline(t time.Time) error { return nil }
+
+func makeNonceFrom(bts []byte) (ret [nonceSize]byte) {
+ base64.StdEncoding.Encode(ret[:], bts)
+ return
+}
+
+func nonceAsSlice(bts [nonceSize]byte) []byte {
+ return bts[:]
+}
+
+type stubPool struct {
+ getCalls int
+ putCalls int
+}
+
+type stubWritePool struct {
+ stubPool
+}
+
+func (s *stubWritePool) Get(w io.Writer) *bufio.Writer { s.getCalls++; return bufio.NewWriter(w) }
+func (s *stubWritePool) Put(bw *bufio.Writer) { s.putCalls++ }
+
+type stubReadPool struct {
+ stubPool
+}
+
+func (s *stubReadPool) Get(r io.Reader) *bufio.Reader { s.getCalls++; return bufio.NewReader(r) }
+func (s *stubReadPool) Put(br *bufio.Reader) { s.putCalls++ }
diff --git a/example/autobahn/autobahn.go b/example/autobahn/autobahn.go
new file mode 100644
index 0000000..1831f6e
--- /dev/null
+++ b/example/autobahn/autobahn.go
@@ -0,0 +1,251 @@
+package main
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net/http"
+ "os"
+
+ "github.com/gobwas/ws"
+ "github.com/gobwas/ws/wsutil"
+)
+
+const dir = "./example/autobahn"
+
+var (
+ addr = flag.String("listen", ":9001", "addr to listen")
+ reports = flag.String("reports", dir+"/reports", "path to reports directory")
+ static = flag.String("static", dir+"/static", "path to static assets directory")
+)
+
+func main() {
+ flag.Parse()
+
+ log.Printf("reports dir is set to: %s", *reports)
+ log.Printf("static dir is set to: %s", *static)
+
+ http.HandleFunc("/", handlerIndex())
+ http.HandleFunc("/library", handlerEcho())
+ http.HandleFunc("/utils", handlerEcho2())
+ http.Handle("/reports/", http.StripPrefix("/reports/", http.FileServer(http.Dir(*reports))))
+
+ log.Printf("ready to listen on %s", *addr)
+ log.Fatal(http.ListenAndServe(*addr, nil))
+}
+
+var (
+ CloseInvalidPayload = ws.MustCompileFrame(
+ ws.NewCloseFrame(ws.StatusInvalidFramePayloadData, ""),
+ )
+ CloseProtocolError = ws.MustCompileFrame(
+ ws.NewCloseFrame(ws.StatusProtocolError, ""),
+ )
+)
+
+func handlerEcho2() func(w http.ResponseWriter, r *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ conn, _, _, err := ws.Upgrade(r, w, nil)
+ if err != nil {
+ log.Printf("upgrade error: %s", err)
+ return
+ }
+ defer conn.Close()
+
+ ch := wsutil.ControlHandler(conn, 0)
+
+ rd := wsutil.NewReader(conn, ws.StateServerSide)
+ rd.HandleIntermediate(ch)
+
+ for {
+ var r io.Reader = rd
+ var ur *wsutil.UTF8Reader
+
+ h, err := rd.Next()
+ if err != nil {
+ log.Printf("next reader error: %s", err)
+ return
+ }
+
+ switch {
+ case h.OpCode.IsControl():
+ if err = ch(h, rd); err != nil {
+ log.Print(err)
+ return
+ }
+ continue
+
+ case h.OpCode == ws.OpText:
+ ur = wsutil.NewUTF8Reader(r)
+ r = ur
+ }
+
+ wr := wsutil.NextWriter(conn, h.OpCode, false)
+ _, err = io.Copy(wr, r)
+ if err == nil && ur != nil {
+ err = ur.Close()
+ }
+ if err == nil {
+ err = wr.Flush()
+ }
+ if err != nil {
+ log.Printf("copy error: %s", err)
+ return
+ }
+ }
+ }
+}
+
+func handlerEcho() func(w http.ResponseWriter, r *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ conn, _, _, err := ws.Upgrade(r, w, nil)
+ if err != nil {
+ log.Printf("upgrade error: %s", err)
+ return
+ }
+ defer conn.Close()
+
+ state := ws.StateServerSide
+
+ textPending := false
+ utf8Reader := wsutil.NewUTF8Reader(nil)
+ cipherReader := wsutil.NewCipherReader(nil, nil)
+
+ for {
+ header, err := ws.ReadHeader(conn)
+ if err != nil {
+ log.Printf("read header error: %s", err)
+ break
+ }
+ if err = ws.CheckHeader(header, state); err != nil {
+ log.Printf("header check error: %s", err)
+ conn.Write(CloseProtocolError)
+ return
+ }
+
+ var r io.Reader
+ cipherReader.Reset(
+ io.LimitReader(conn, header.Length),
+ header.Mask,
+ )
+ r = cipherReader
+
+ var utf8Fin bool
+ switch header.OpCode {
+ case ws.OpPing:
+ header.OpCode = ws.OpPong
+ header.Mask = nil
+ ws.WriteHeader(conn, header)
+ io.CopyN(conn, cipherReader, header.Length)
+ continue
+
+ case ws.OpPong:
+ io.CopyN(ioutil.Discard, conn, header.Length)
+ continue
+
+ case ws.OpClose:
+ utf8Fin = true
+
+ case ws.OpContinuation:
+ if textPending {
+ utf8Reader.SetSource(cipherReader)
+ r = utf8Reader
+ }
+ if header.Fin {
+ state = state.Clear(ws.StateFragmented)
+ textPending = false
+ utf8Fin = true
+ }
+
+ case ws.OpText:
+ utf8Reader.Reset(cipherReader)
+ r = utf8Reader
+
+ if !header.Fin {
+ state = state.Set(ws.StateFragmented)
+ textPending = true
+ } else {
+ utf8Fin = true
+ }
+
+ case ws.OpBinary:
+ if !header.Fin {
+ state = state.Set(ws.StateFragmented)
+ }
+ }
+
+ payload := make([]byte, header.Length)
+ _, err = io.ReadFull(r, payload)
+ if err == nil && utf8Fin {
+ err = utf8Reader.Close()
+ }
+ if err != nil {
+ log.Printf("read payload error: %s", err)
+ if err == wsutil.ErrInvalidUtf8 {
+ conn.Write(CloseInvalidPayload)
+ } else {
+ conn.Write(ws.CompiledClose)
+ }
+ return
+ }
+
+ if header.OpCode == ws.OpClose {
+ code, reason := ws.ParseCloseFrameData(payload)
+ log.Printf("close frame received: %v %v", code, reason)
+
+ if !code.Empty() {
+ switch {
+ case code.IsProtocolSpec() && !code.IsProtocolDefined():
+ err = fmt.Errorf("close code from spec range is not defined")
+ default:
+ err = ws.CheckCloseFrameData(code, reason)
+ }
+ if err != nil {
+ log.Printf("invalid close data: %s", err)
+ conn.Write(CloseProtocolError)
+ } else {
+ ws.WriteFrame(conn, ws.NewCloseFrame(code, ""))
+ }
+ return
+ }
+
+ conn.Write(ws.CompiledClose)
+ return
+ }
+
+ header.Mask = nil
+ ws.WriteHeader(conn, header)
+ conn.Write(payload)
+ }
+ }
+}
+
+func handlerIndex() func(w http.ResponseWriter, r *http.Request) {
+ index, err := os.Open(*static + "/index.html")
+ if err != nil {
+ log.Fatal(err)
+ }
+ bts, err := ioutil.ReadAll(index)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ return func(w http.ResponseWriter, r *http.Request) {
+ log.Printf("reqeust to %s", r.URL)
+ switch r.URL.Path {
+ case "/":
+ buf := bytes.NewBuffer(bts)
+ _, err := buf.WriteTo(w)
+ if err != nil {
+ log.Printf("write index bytes error: %s", err)
+ }
+ case "/favicon.ico":
+ w.WriteHeader(http.StatusNotFound)
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ }
+}
diff --git a/example/autobahn/fuzzingclient.json b/example/autobahn/fuzzingclient.json
new file mode 100644
index 0000000..39aa808
--- /dev/null
+++ b/example/autobahn/fuzzingclient.json
@@ -0,0 +1,17 @@
+
+{
+ "outdir": "./example/autobahn/reports/servers",
+ "servers": [
+ {
+ "agent": "Library",
+ "url": "ws://127.0.0.1:9001/library"
+ },
+ {
+ "agent": "Utils",
+ "url": "ws://127.0.0.1:9001/utils"
+ }
+ ],
+ "cases": ["*"],
+ "exclude-cases": [],
+ "exclude-agent-cases": {}
+}
diff --git a/example/autobahn/static/index.html b/example/autobahn/static/index.html
new file mode 100644
index 0000000..77f9265
--- /dev/null
+++ b/example/autobahn/static/index.html
@@ -0,0 +1,7 @@
+
+
+Welcome to WebSocket test server!
+Ready to Autobahn!
+Reports
+
+
diff --git a/frame.go b/frame.go
new file mode 100644
index 0000000..08e4cc7
--- /dev/null
+++ b/frame.go
@@ -0,0 +1,339 @@
+package ws
+
+import (
+ "bytes"
+ "encoding/binary"
+)
+
+// Constants defined by specification.
+const (
+ MaxControlFramePayloadSize = 125 // All control frames must have a payload length of 125 bytes or less.
+)
+
+// OpCode represents operation code.
+type OpCode byte
+
+// Operation codes defined by specification.
+// See https://tools.ietf.org/html/rfc6455#section-5.2
+const (
+ OpContinuation OpCode = 0x0
+ OpText = 0x1
+ OpBinary = 0x2
+ OpClose = 0x8
+ OpPing = 0x9
+ OpPong = 0xa
+)
+
+// IsControl checks wheter the c is control operation code.
+// See https://tools.ietf.org/html/rfc6455#section-5.5
+func (c OpCode) IsControl() bool {
+ // RFC6455: Control frames are identified by opcodes where
+ // the most significant bit of the opcode is 1.
+ //
+ // Note that OpCode is only 4 bit length.
+ return c&0x8 != 0
+}
+
+// IsData checks wheter the c is data operation code.
+// See https://tools.ietf.org/html/rfc6455#section-5.6
+func (c OpCode) IsData() bool {
+ // RFC6455: Data frames (e.g., non-control frames) are identified by opcodes
+ // where the most significant bit of the opcode is 0.
+ //
+ // Note that OpCode is only 4 bit length.
+ return c&0x8 == 0
+}
+
+// IsReserved checks wheter the c is reserved operation code.
+// See https://tools.ietf.org/html/rfc6455#section-5.2
+func (c OpCode) IsReserved() bool {
+ // RFC6455:
+ // %x3-7 are reserved for further non-control frames
+ // %xB-F are reserved for further control frames
+ return (0x3 <= c && c <= 0x7) || (0xb <= c && c <= 0xf)
+}
+
+// StatusCode represents the encoded reason for closure of websocket connection.
+//
+// There are few helper methods on StatusCode that helps to define a range in
+// which given code is lay in. accordingly to ranges defined in specification.
+//
+// See https://tools.ietf.org/html/rfc6455#section-7.4
+type StatusCode uint16
+
+// StatusCodeRange describes range of StatusCode values.
+type StatusCodeRange struct {
+ Min, Max StatusCode
+}
+
+// Status code ranges defined by specification.
+// See https://tools.ietf.org/html/rfc6455#section-7.4.2
+var (
+ StatusRangeNotInUse = StatusCodeRange{0, 999}
+ StatusRangeProtocol = StatusCodeRange{1000, 2999}
+ StatusRangeApplication = StatusCodeRange{3000, 3999}
+ StatusRangePrivate = StatusCodeRange{4000, 4999}
+)
+
+// Status codes defined by specification.
+// See https://tools.ietf.org/html/rfc6455#section-7.4.1
+const (
+ StatusNormalClosure StatusCode = 1000
+ StatusGoingAway = 1001
+ StatusProtocolError = 1002
+ StatusUnsupportedData = 1003
+ StatusNoMeaningYet = 1004
+ StatusNoStatusRcvd = 1005
+ StatusAbnormalClosure = 1006
+ StatusInvalidFramePayloadData = 1007
+ StatusPolicyViolation = 1008
+ StatusMessageTooBig = 1009
+ StatusMandatoryExt = 1010
+ StatusInternalServerError = 1011
+ StatusTLSHandshake = 1015
+)
+
+// In reports whether the code is defined in given range.
+func (s StatusCode) In(r StatusCodeRange) bool {
+ return r.Min <= s && s <= r.Max
+}
+
+// Empty reports wheter the code is empty.
+// Empty code has no any meaning neither app level codes nor other.
+// This method is useful just to check that code is golang default value 0.
+func (s StatusCode) Empty() bool {
+ return s == 0
+}
+
+// IsNotUsed reports whether the code is predefined in not used range.
+func (s StatusCode) IsNotUsed() bool {
+ return s.In(StatusRangeNotInUse)
+}
+
+// IsApplicationSpec reports whether the code should be defined by
+// application, framework or libraries specification.
+func (s StatusCode) IsApplicationSpec() bool {
+ return s.In(StatusRangeApplication)
+}
+
+// IsPrivateSpec reports whether the code should be defined privately.
+func (s StatusCode) IsPrivateSpec() bool {
+ return s.In(StatusRangePrivate)
+}
+
+// IsProtocolSpec reports whether the code should be defined by protocol specification.
+func (s StatusCode) IsProtocolSpec() bool {
+ return s.In(StatusRangeProtocol)
+}
+
+// IsProtocolDefined reports whether the code is already defined by protocol specification.
+func (s StatusCode) IsProtocolDefined() bool {
+ switch s {
+ case StatusNormalClosure,
+ StatusGoingAway,
+ StatusProtocolError,
+ StatusUnsupportedData,
+ StatusInvalidFramePayloadData,
+ StatusPolicyViolation,
+ StatusMessageTooBig,
+ StatusMandatoryExt,
+ StatusInternalServerError,
+ StatusNoStatusRcvd,
+ StatusAbnormalClosure,
+ StatusTLSHandshake:
+ return true
+ }
+ return false
+}
+
+// IsProtocolReserved reports whether the code is defined by protocol specification
+// to be reserved only for application usage purpose.
+func (s StatusCode) IsProtocolReserved() bool {
+ switch s {
+ // [RFC6455]: {1005,1006,1015} is a reserved value and MUST NOT be set as a status code in a
+ // Close control frame by an endpoint.
+ case StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
+ return true
+ default:
+ return false
+ }
+}
+
+// Common frames with no special meaning.
+var (
+ PingFrame = Frame{Header{Fin: true, OpCode: OpPing}, nil}
+ PongFrame = Frame{Header{Fin: true, OpCode: OpPong}, nil}
+ CloseFrame = Frame{Header{Fin: true, OpCode: OpClose}, nil}
+)
+
+// Compiled control frames for common use cases.
+// For construct-serialize optimizations.
+var (
+ CompiledPing = MustCompileFrame(PingFrame)
+ CompiledPong = MustCompileFrame(PongFrame)
+ CompiledClose = MustCompileFrame(CloseFrame)
+)
+
+// Header represents websocket frame header.
+// See https://tools.ietf.org/html/rfc6455#section-5.2
+type Header struct {
+ Fin bool
+ Rsv byte
+ OpCode OpCode
+ Length int64
+ Mask []byte
+}
+
+// Rsv1 reports whether the header has first rsv bit set.
+func (h Header) Rsv1() bool { return h.Rsv&bit5 != 0 }
+
+// Rsv2 reports whether the header has second rsv bit set.
+func (h Header) Rsv2() bool { return h.Rsv&bit6 != 0 }
+
+// Rsv3 reports whether the header has third rsv bit set.
+func (h Header) Rsv3() bool { return h.Rsv&bit7 != 0 }
+
+// Frame represents websocket frame.
+// See https://tools.ietf.org/html/rfc6455#section-5.2
+type Frame struct {
+ Header Header
+ Payload []byte
+}
+
+// NewFrame creates frame with given operation code,
+// flag of completeness and payload bytes.
+func NewFrame(op OpCode, fin bool, p []byte) Frame {
+ return Frame{
+ Header: Header{
+ Fin: fin,
+ OpCode: op,
+ Length: int64(len(p)),
+ },
+ Payload: p,
+ }
+}
+
+// NewTextFrame creates text frame with s as payload.
+// Note that the s is copied in the returned frame payload.
+func NewTextFrame(s string) Frame {
+ p := make([]byte, len(s))
+ copy(p, s)
+ return NewFrame(OpText, true, p)
+}
+
+// NewBinaryFrame creates binary frame with p as payload.
+// Note that p is left as is in the returned frame without copying.
+func NewBinaryFrame(p []byte) Frame {
+ return NewFrame(OpBinary, true, p)
+}
+
+// NewPingFrame creates ping frame with p as payload.
+// Note that p is left as is in the returned frame without copying.
+func NewPingFrame(p []byte) Frame {
+ return NewFrame(OpPing, true, p)
+}
+
+// NewPongFrame creates pong frame with p as payload.
+// Note that p is left as is in the returned frame.
+func NewPongFrame(p []byte) Frame {
+ return NewFrame(OpPong, true, p)
+}
+
+// NewCloseFrame creates close frame with given closure code and reason.
+// Note that it crops reason to fit the limit of control frames payload.
+// See https://tools.ietf.org/html/rfc6455#section-5.5
+func NewCloseFrame(code StatusCode, reason string) Frame {
+ return NewFrame(OpClose, true, NewCloseFrameData(code, reason))
+}
+
+// NewCloseFrameData makes byte representation of code and reason.
+//
+// Note that returned slice is at most 125 bytes length.
+// If reason is too big it will crop it to fit the limit defined by thte spec.
+//
+// See https://tools.ietf.org/html/rfc6455#section-5.5
+func NewCloseFrameData(code StatusCode, reason string) []byte {
+ n := min(2+len(reason), MaxControlFramePayloadSize) // 2 is for status code uint16 encoding.
+ p := make([]byte, n)
+ PutCloseFrameData(p, code, reason)
+ return p
+}
+
+// PutCloseFrameData encodes code and reason into buf and returns the number of bytes written.
+// If the buffer is too small to accomodate at least code, PutCloseFrameData will panic.
+// Note that it does not checks maximum control frame payload size limit.
+func PutCloseFrameData(p []byte, code StatusCode, reason string) int {
+ binary.BigEndian.PutUint16(p, uint16(code))
+ n := copy(p[2:], reason)
+ return n + 2
+}
+
+// MaskFrame masks frame and returns frame with masked payload and Mask header's field set.
+// Note that it copies f payload to prevent collisions.
+// For less allocations you could use MaskFrameInplace or construct frame manually.
+func MaskFrame(f Frame) Frame {
+ return MaskFrameWith(f, NewMask())
+}
+
+// MaskFrameWith masks frame with given mask and returns frame
+// with masked payload and Mask header's field set.
+// Note that it copies f payload to prevent collisions.
+// For less allocations you could use MaskFrameInplaceWith or construct frame manually.
+func MaskFrameWith(f Frame, mask []byte) Frame {
+ p := make([]byte, len(f.Payload))
+ copy(p, f.Payload)
+ f.Payload = p
+ return MaskFrameInplaceWith(f, mask)
+}
+
+// MaskFrame masks frame and returns frame with masked payload and Mask header's field set.
+// Note that it applies xor cipher to f.Payload without copying, that is, it modifies f.Payload inplace.
+func MaskFrameInplace(f Frame) Frame {
+ return MaskFrameInplaceWith(f, NewMask())
+}
+
+// MaskFrameInplaceWith masks frame with given mask and returns frame
+// with masked payload and Mask header's field set.
+// Note that it applies xor cipher to f.Payload without copying, that is, it modifies f.Payload inplace.
+func MaskFrameInplaceWith(f Frame, mask []byte) Frame {
+ f.Header.Mask = mask
+ Cipher(f.Payload, mask, 0)
+ return f
+}
+
+// NewMask creates new random mask.
+func NewMask() []byte {
+ return randBytes(4)
+}
+
+// CompileFrame returns byte representation of given frame.
+// In terms of memory consumption it is useful to precompile static frames which are often used.
+func CompileFrame(f Frame) (bts []byte, err error) {
+ buf := bytes.NewBuffer(make([]byte, 0, 16))
+ err = WriteFrame(buf, f)
+ bts = buf.Bytes()
+ return
+}
+
+// MustCompileFrame is like CompileFrame but panics if frame cannot be encoded.
+func MustCompileFrame(f Frame) []byte {
+ bts, err := CompileFrame(f)
+ if err != nil {
+ panic(err)
+ }
+ return bts
+}
+
+// Rsv creates rsv byte representation.
+func Rsv(r1, r2, r3 bool) (rsv byte) {
+ if r1 {
+ rsv |= bit5
+ }
+ if r2 {
+ rsv |= bit6
+ }
+ if r3 {
+ rsv |= bit7
+ }
+ return rsv
+}
diff --git a/frame_test.go b/frame_test.go
new file mode 100644
index 0000000..981257a
--- /dev/null
+++ b/frame_test.go
@@ -0,0 +1,26 @@
+package ws
+
+import (
+ "fmt"
+ "testing"
+)
+
+func TestOpCodeIsControl(t *testing.T) {
+ for _, test := range []struct {
+ code OpCode
+ exp bool
+ }{
+ {OpClose, true},
+ {OpPing, true},
+ {OpPong, true},
+ {OpBinary, false},
+ {OpText, false},
+ {OpContinuation, false},
+ } {
+ t.Run(fmt.Sprintf("0x%02x", test.code), func(t *testing.T) {
+ if act := test.code.IsControl(); act != test.exp {
+ t.Errorf("IsControl = %v; want %v", act, test.exp)
+ }
+ })
+ }
+}
diff --git a/read.go b/read.go
new file mode 100644
index 0000000..6771819
--- /dev/null
+++ b/read.go
@@ -0,0 +1,149 @@
+package ws
+
+import (
+ "encoding/binary"
+ "fmt"
+ "io"
+)
+
+const (
+ PlatformSizeLimit = int64(^(uint(0)) >> 1) // Max int value for current platform.
+)
+
+// Errors used by frame reader.
+var (
+ ErrHeaderLengthMSB = fmt.Errorf("header error: the most significant bit must be 0")
+ ErrHeaderLengthUnexpected = fmt.Errorf("header error: unexpected payload length bits")
+)
+
+// ReadHeader reads a frame header from r.
+func ReadHeader(r io.Reader) (h Header, err error) {
+ // Make slice with 2 bytes len for header, but with 8 byte capacity.
+ // The most useful case of reading header is to read header from
+ // client, that is with mask (4 byte) and some length most cases <= uint16 (2 bytes).
+ // If such case happened, we will reuse bytes without extra allocation.
+ bts := make([]byte, 2, 8)
+
+ //var hv uint64
+ //hp := uintptr(unsafe.Pointer(&hv))
+ //bts := *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{Data: hp, Len: 2, Cap: 8}))
+
+ // Prepare to hold first 2 bytes to choose size of next read.
+ _, err = io.ReadFull(r, bts)
+ if err != nil {
+ return
+ }
+
+ h.Fin = bts[0]&bit0 != 0
+ h.Rsv = (bts[0] & 0x70) >> 4
+ h.OpCode = OpCode(bts[0] & 0x0f)
+
+ var extra int
+
+ mask := bts[1]&bit0 != 0
+ if mask {
+ extra += 4
+ }
+
+ length := bts[1] & 0x7f
+ switch {
+ case length < 126:
+ h.Length = int64(length)
+
+ case length == 126:
+ extra += 2
+
+ case length == 127:
+ extra += 8
+
+ default:
+ err = ErrHeaderLengthUnexpected
+ return
+ }
+
+ if extra == 0 {
+ return
+ }
+
+ if extra <= 8 {
+ bts = bts[:extra]
+ } else {
+ bts = make([]byte, extra)
+ }
+
+ _, err = io.ReadFull(r, bts)
+ if err != nil {
+ return
+ }
+
+ switch {
+ case length == 126:
+ h.Length = int64(binary.BigEndian.Uint16(bts[:2]))
+ bts = bts[2:]
+
+ case length == 127:
+ if bts[0]&0x80 != 0 {
+ err = ErrHeaderLengthMSB
+ return
+ }
+ h.Length = int64(binary.BigEndian.Uint64(bts[:8]))
+ bts = bts[8:]
+ }
+
+ if mask {
+ // TODO(gobwas): move to type Mask uint32
+ h.Mask = bts[:4]
+ }
+
+ return
+}
+
+// ReadFrame reads a frame from r.
+// It is not designed for high optimized use case cause it makes allocation
+// for frame.Header.Length size inside to read frame payload into.
+//
+// Note that ReadFrame does not unmask payload.
+func ReadFrame(r io.Reader) (f Frame, err error) {
+ f.Header, err = ReadHeader(r)
+ if err != nil {
+ return
+ }
+
+ if f.Header.Length > 0 {
+ // int(f.Header.Length) is safe here cause we have
+ // checked it for overflow above in ReadHeader.
+ f.Payload = make([]byte, int(f.Header.Length))
+ _, err = io.ReadFull(r, f.Payload)
+ }
+
+ return
+}
+
+// ParseCloseFrameData parses close frame status code and closure reason if any provided.
+// If there is no status code in the payload
+// the empty status code is returned (code.Empty()) with empty string as a reason.
+func ParseCloseFrameData(payload []byte) (code StatusCode, reason string) {
+ if len(payload) < 2 {
+ // We returning empty StatusCode here, preventing the situation
+ // when endpoint really sent code 1005 and we should return ProtocolError on that.
+ //
+ // In other words, we ignoring this rule [RFC6455:7.1.5]:
+ // If this Close control frame contains no status code, _The WebSocket
+ // Connection Close Code_ is considered to be 1005.
+ return
+ }
+ code = StatusCode(binary.BigEndian.Uint16(payload))
+ reason = string(payload[2:])
+ return
+}
+
+// ParseCloseFrameDataUnsafe is like ParseCloseFrameData except the thing
+// that it does not copies payload bytes into reason, but prepares unsafe cast.
+func ParseCloseFrameDataUnsafe(payload []byte) (code StatusCode, reason string) {
+ if len(payload) < 2 {
+ return
+ }
+ code = StatusCode(binary.BigEndian.Uint16(payload))
+ reason = btsToString(payload[2:])
+ return
+}
diff --git a/read_test.go b/read_test.go
new file mode 100644
index 0000000..e540c2e
--- /dev/null
+++ b/read_test.go
@@ -0,0 +1,67 @@
+package ws
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "reflect"
+ "testing"
+)
+
+func TestReadHeader(t *testing.T) {
+ for i, test := range append([]RWCase{
+ {
+ Data: bits("0000 0000 0 1111111 10000000 00000000 00000000 00000000 00000000 00000000 00000000 00000000"),
+ // _______________________________________________________________________
+ // |
+ // Length value
+ Err: true,
+ },
+ }, RWCases...) {
+ t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
+ r := bytes.NewReader(test.Data)
+ h, err := ReadHeader(r)
+ if test.Err && err == nil {
+ t.Errorf("expected error, got nil")
+ }
+ if !test.Err && err != nil {
+ t.Errorf("unexpected error: %s", err)
+ }
+ if test.Err {
+ return
+ }
+ if !reflect.DeepEqual(h, test.Header) {
+ t.Errorf("ReadHeader()\nread:\n\t%#v\nwant:\n\t%#v", h, test.Header)
+ }
+ })
+ }
+}
+
+func BenchmarkReadHeader(b *testing.B) {
+ for i, bench := range []struct {
+ label string
+ header Header
+ }{
+ {"t", Header{OpCode: OpText, Fin: true}},
+ {"t-m", Header{OpCode: OpText, Fin: true, Mask: NewMask()}},
+ {"t-m-u16", Header{OpCode: OpText, Fin: true, Length: len16, Mask: NewMask()}},
+ {"t-m-u64", Header{OpCode: OpText, Fin: true, Length: len64, Mask: NewMask()}},
+ } {
+ b.Run(fmt.Sprintf("%s#%d", bench.label, i), func(b *testing.B) {
+ bts := MustCompileFrame(Frame{Header: bench.header})
+ rds := make([]io.Reader, b.N)
+ for i := 0; i < b.N; i++ {
+ rds[i] = bytes.NewReader(bts)
+ }
+
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ _, err := ReadHeader(rds[i])
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+ })
+ }
+}
diff --git a/request.go b/request.go
new file mode 100644
index 0000000..5233524
--- /dev/null
+++ b/request.go
@@ -0,0 +1,190 @@
+package ws
+
+import (
+ "crypto/sha1"
+ "encoding/base64"
+ "fmt"
+ "hash"
+ "math/rand"
+ "net/http"
+ "net/textproto"
+ "net/url"
+ "strings"
+ "sync"
+)
+
+const (
+ nonceSize = 24
+ acceptSize = 28
+ shaSumSize = 20
+)
+
+var (
+ headerUpgrade = textproto.CanonicalMIMEHeaderKey("Upgrade")
+ headerConnection = textproto.CanonicalMIMEHeaderKey("Connection")
+ headerHost = textproto.CanonicalMIMEHeaderKey("Host")
+ headerOrigin = textproto.CanonicalMIMEHeaderKey("Origin")
+ headerSecVersion = textproto.CanonicalMIMEHeaderKey("Sec-Websocket-Version")
+ headerSecProtocol = textproto.CanonicalMIMEHeaderKey("Sec-Websocket-Protocol")
+ headerSecExtensions = textproto.CanonicalMIMEHeaderKey("Sec-Websocket-Extensions")
+ headerSecKey = textproto.CanonicalMIMEHeaderKey("Sec-Websocket-Key")
+ headerSecAccept = textproto.CanonicalMIMEHeaderKey("Sec-Websocket-Accept")
+)
+
+var ErrBadNonce = fmt.Errorf("nonce size is not %d", nonceSize)
+
+var WebSocketMagic = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
+
+type request struct {
+ http.Request
+ Nonce [nonceSize]byte
+ Protocols []string
+ Extensions []string
+}
+
+func (req *request) Reset(urlstr string, headers http.Header, protocols, extensions []string) error {
+ u, err := url.ParseRequestURI(urlstr)
+ if err != nil {
+ return err
+ }
+
+ req.URL = u
+
+ newNonce(req.Nonce[:])
+ req.Header.Set(headerSecKey, string(req.Nonce[:]))
+
+ req.Protocols = protocols
+ if protocols != nil {
+ req.Header.Set(headerSecProtocol, strings.Join(protocols, ", "))
+ }
+
+ req.Extensions = extensions
+ if extensions != nil {
+ req.Header.Set(headerSecExtensions, strings.Join(extensions, ", "))
+ }
+
+ req.Header.Set("User-Agent", "") // Disable default user-agent header.
+
+ if headers != nil {
+ for k, v := range headers {
+ req.Header[k] = v
+ }
+ }
+
+ return nil
+}
+
+var requestPool sync.Pool
+
+func getRequest() *request {
+ if req := requestPool.Get(); req != nil {
+ return req.(*request)
+ }
+ return newCommonRequest()
+}
+
+func putRequest(req *request) {
+ req.URL = nil
+ req.Protocols = nil
+ req.Extensions = nil
+
+ for k := range req.Header {
+ switch k {
+ case headerUpgrade, headerConnection, headerSecVersion:
+ // leave common headers
+ default:
+ delete(req.Header, k)
+ }
+ }
+
+ requestPool.Put(req)
+}
+
+func newCommonRequest() *request {
+ req := &request{
+ Request: http.Request{
+ Header: make(http.Header),
+ },
+ }
+
+ req.Header.Set(headerUpgrade, "websocket")
+ req.Header.Set(headerConnection, "Upgrade")
+ req.Header.Set(headerSecVersion, "13")
+
+ return req
+}
+
+var sha1Pool sync.Pool
+
+func acquireSha1() hash.Hash {
+ if h := sha1Pool.Get(); h != nil {
+ return h.(hash.Hash)
+ }
+ return sha1.New()
+}
+
+func releaseSha1(h hash.Hash) {
+ h.Reset()
+ sha1Pool.Put(h)
+}
+
+// todo bench put expect to req as array
+func checkNonce(accept string, nonce [nonceSize]byte) bool {
+ if len(accept) != 28 {
+ return false
+ }
+
+ expect := makeAccept(nonce[:])
+
+ return string(expect) == accept
+}
+
+//const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
+//
+//func newNonce(dest []byte) {
+// for i := 0; i < 22; i++ {
+// dest[i] = alphabet[rand.Intn(len(alphabet))]
+// }
+// dest[22] = '='
+// dest[23] = '='
+//}
+
+func randBytes(n int) []byte {
+ bts := make([]byte, n)
+ if _, err := rand.Read(bts); err != nil {
+ panic(fmt.Sprintf("rand read error: %s", err))
+ }
+ return bts
+}
+
+func newNonce(dest []byte) {
+ base64.StdEncoding.Encode(dest, randBytes(16))
+}
+
+func makeAccept(nonce []byte) []byte {
+ bts := make([]byte, 0, acceptSize+shaSumSize)
+ n := putAccept(nonce, bts)
+ return bts[:n]
+}
+
+func putAccept(nonce, buf []byte) int {
+ if cap(buf) < acceptSize+shaSumSize {
+ panic(fmt.Sprintf("buffer cap is %d; want at least %d", len(buf), acceptSize+shaSumSize))
+ }
+ if len(nonce) != nonceSize {
+ panic(fmt.Sprintf("nonce size is %d; want %d", len(nonce), nonceSize))
+ }
+
+ sha := acquireSha1()
+ defer releaseSha1(sha)
+
+ sha.Write(nonce)
+ sha.Write(WebSocketMagic)
+
+ buf = buf[:acceptSize]
+ sum := sha.Sum(buf[acceptSize:])
+
+ base64.StdEncoding.Encode(buf, sum)
+
+ return acceptSize
+}
diff --git a/request_test.go b/request_test.go
new file mode 100644
index 0000000..a8f3962
--- /dev/null
+++ b/request_test.go
@@ -0,0 +1,143 @@
+package ws
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/rand"
+ "encoding/base64"
+ "fmt"
+ "net/http"
+ "reflect"
+ "strings"
+ "testing"
+)
+
+func TestRequestReset(t *testing.T) {
+ for i, test := range []struct {
+ url string
+ expHost string
+ expPath string
+ protocols []string
+ extensions []string
+ headers http.Header
+ err bool
+ }{
+ {
+ url: "wss://websocket.com/chat",
+ expHost: "websocket.com",
+ expPath: "/chat",
+ protocols: []string{"subproto", "hello"},
+ extensions: []string{"foo; bar=1", "baz"},
+ headers: http.Header{
+ "Origin": []string{"https://websocket.com"},
+ },
+ },
+ {
+ url: "websocket.com/chat",
+ err: true,
+ },
+ } {
+ commonHeaders := map[string]string{
+ headerConnection: "Upgrade",
+ headerUpgrade: "websocket",
+ headerSecVersion: "13",
+ }
+
+ t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
+ req := getRequest()
+ err := req.Reset(test.url, test.headers, test.protocols, test.extensions)
+ if test.err && err == nil {
+ t.Errorf("expected error; got nil")
+ }
+ if !test.err && err != nil {
+ t.Errorf("unexpected error: %s", err)
+ }
+ if test.err {
+ return
+ }
+
+ buf := &bytes.Buffer{}
+ if err = req.Write(buf); err != nil {
+ t.Errorf("dumping request error: %s", err)
+ return
+ }
+
+ r, err := http.ReadRequest(bufio.NewReader(buf))
+ if err != nil {
+ t.Errorf("read request error: %s", err)
+ return
+ }
+
+ if r.Method != "GET" {
+ t.Errorf("http method is %s; want GET", r.Method)
+ }
+ if r.ProtoMinor != 1 {
+ t.Errorf("http proto minor is %d; want 1", r.ProtoMinor)
+ }
+ if r.ProtoMajor != 1 {
+ t.Errorf("http proto major is %d; want 1", r.ProtoMajor)
+ }
+ if r.URL.Path != test.expPath {
+ t.Errorf("http path is %s; want %s", r.URL.Path, test.expPath)
+ }
+
+ key := r.Header.Get(headerSecKey)
+ bts, err := base64.StdEncoding.DecodeString(key)
+ if err != nil {
+ t.Errorf("bad %q header: %s", headerSecKey, err)
+ }
+ if n := len(bts); n != 16 {
+ t.Errorf("nonce len is %d; want 16", n)
+ }
+ r.Header.Del(headerSecKey)
+
+ sub := r.Header.Get(headerSecProtocol)
+ protocols := strings.Split(sub, ",")
+ for i, p := range protocols {
+ protocols[i] = strings.TrimSpace(p)
+ }
+ if !reflect.DeepEqual(protocols, test.protocols) {
+ t.Errorf("%q headers is %v; want %s", headerSecProtocol, protocols, test.protocols)
+ }
+ r.Header.Del(headerSecProtocol)
+
+ ext := r.Header.Get(headerSecExtensions)
+ extensions := strings.Split(ext, ",")
+ for i, e := range extensions {
+ extensions[i] = strings.TrimSpace(e)
+ }
+ if !reflect.DeepEqual(extensions, test.extensions) {
+ t.Errorf("%q headers is %v; want %s", headerSecExtensions, extensions, test.extensions)
+ }
+ r.Header.Del(headerSecExtensions)
+
+ for key, exp := range commonHeaders {
+ if act := r.Header.Get(key); act != exp {
+ t.Errorf("http %q header is %q; want %q", key, act, exp)
+ }
+ r.Header.Del(key)
+ }
+ for key, exp := range test.headers {
+ if act := r.Header.Get(key); act != exp[0] {
+ t.Errorf("http %q custom header is %q; want %q", key, act, exp[0])
+ }
+ r.Header.Del(key)
+ }
+ if len(r.Header) != 0 {
+ t.Errorf("http request has extra headers:\n\t%v", r.Header)
+ }
+ })
+ }
+}
+
+func BenchmarkMakeAccept(b *testing.B) {
+ nonce := make([]byte, nonceSize)
+ _, err := rand.Read(nonce)
+ if err != nil {
+ b.Fatal(err)
+ }
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ _ = makeAccept(nonce)
+ }
+}
diff --git a/rw_test.go b/rw_test.go
new file mode 100644
index 0000000..4c3342b
--- /dev/null
+++ b/rw_test.go
@@ -0,0 +1,86 @@
+package ws
+
+import (
+ "fmt"
+ "strings"
+)
+
+type RWCase struct {
+ Data []byte
+ Header Header
+ Err bool
+}
+
+var RWCases = []RWCase{
+ {
+ Data: bits("1 001 0001 0 1100100"),
+ // _ ___ ____ _ _______
+ // | | | | |
+ // Fin | | Mask Length
+ // Rsv |
+ // TextFrame
+ Header: Header{
+ Fin: true,
+ Rsv: Rsv(false, false, true),
+ OpCode: OpText,
+ Length: 100,
+ Mask: nil,
+ },
+ },
+ {
+ Data: bits("1 001 0001 1 1100100 00000001 10001000 00000000 11111111"),
+ // _ ___ ____ _ _______ ___________________________________
+ // | | | | | |
+ // Fin | | Mask Length Mask value
+ // Rsv |
+ // TextFrame
+ Header: Header{
+ Fin: true,
+ Rsv: Rsv(false, false, true),
+ OpCode: OpText,
+ Length: 100,
+ Mask: []byte{0x01, 0x88, 0x00, 0xff},
+ },
+ },
+ {
+ Data: bits("0 110 0010 0 1111110 00001111 11111111"),
+ // _ ___ ____ _ _______ _________________
+ // | | | | | |
+ // Fin | | Mask Length Length value
+ // Rsv |
+ // BinaryFrame
+ Header: Header{
+ Fin: false,
+ Rsv: Rsv(true, true, false),
+ OpCode: OpBinary,
+ Length: 0x0fff,
+ Mask: nil,
+ },
+ },
+ {
+ Data: bits("1 000 1010 0 1111111 01111111 00000000 00000000 00000000 00000000 00000000 00000000 00000000"),
+ // _ ___ ____ _ _______ _______________________________________________________________________
+ // | | | | | |
+ // Fin | | Mask Length Length value
+ // Rsv |
+ // PongFrame
+ Header: Header{
+ Fin: true,
+ Rsv: Rsv(false, false, false),
+ OpCode: OpPong,
+ Length: 0x7f00000000000000,
+ Mask: nil,
+ },
+ },
+}
+
+func bits(s string) []byte {
+ s = strings.Replace(s, " ", "", -1)
+ bts := make([]byte, len(s)/8)
+
+ for i, j := 0, 0; i < len(s); i, j = i+8, j+1 {
+ fmt.Sscanf(s[i:], "%08b", &bts[j])
+ }
+
+ return bts
+}
diff --git a/server.go b/server.go
new file mode 100644
index 0000000..49d145f
--- /dev/null
+++ b/server.go
@@ -0,0 +1,220 @@
+package ws
+
+import (
+ "bufio"
+ "fmt"
+ "net"
+ "net/http"
+ "strings"
+ _ "unsafe" // for go:linkname
+)
+
+const (
+ textUpgrade = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n"
+ crlf = "\r\n"
+ colonAndSpace = ": "
+)
+
+// Errors used by upgraders.
+var (
+ ErrBadHost = fmt.Errorf("bad %q header", headerHost)
+ ErrBadUpgrade = fmt.Errorf("bad %q header", headerUpgrade)
+ ErrBadConnection = fmt.Errorf("bad %q header", headerConnection)
+ ErrBadSecAccept = fmt.Errorf("bad %q header", headerSecAccept)
+ ErrBadSecKey = fmt.Errorf("bad %q header", headerSecKey)
+ ErrBadSecVersion = fmt.Errorf("bad %q header", headerSecVersion)
+ ErrBadHijacker = fmt.Errorf("given http.ResponseWriter is not a http.Hijacker")
+)
+
+// SelectFromSlice creates accept function that could be used as Protocol/Extension
+// select function in the UpgradeConfig.
+func SelectFromSlice(accept []string) func(string) bool {
+ if len(accept) > 16 {
+ mp := make(map[string]struct{}, len(accept))
+ for _, p := range accept {
+ mp[p] = struct{}{}
+ }
+ return func(p string) bool {
+ _, ok := mp[p]
+ return ok
+ }
+ }
+ return func(p string) bool {
+ for _, ok := range accept {
+ if p == ok {
+ return true
+ }
+ }
+ return false
+ }
+}
+
+// UpgradeConfig contains options for upgrading http connection to websocket.
+type UpgradeConfig struct {
+ // Header is the set of custom headers that will be sent with the response.
+ Header http.Header
+
+ // Protocol is the select function that is used to select subprotocol
+ // from client passed list.
+ Protocol func(string) bool
+
+ // Extension is the select function that is used to select extensions
+ // from client passed list.
+ Extension func(string) bool
+}
+
+// Upgrade upgrades http connection to websocket.
+// It hijacks net.Conn from response writer.
+//
+// If succeed it returns upgraded connection and Handshake struct describing
+// handshake info.
+func Upgrade(r *http.Request, w http.ResponseWriter, c *UpgradeConfig) (conn net.Conn, rw *bufio.ReadWriter, hs Handshake, err error) {
+ if r.Host == "" {
+ err = ErrBadHost
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+ if u := getHeader(r.Header, headerUpgrade); u != "websocket" && strings.ToLower(u) != "websocket" {
+ err = ErrBadUpgrade
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+ if c := getHeader(r.Header, headerConnection); c != "Upgrade" && !hasToken(c, "upgrade") {
+ err = ErrBadConnection
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+ if v := getHeader(r.Header, headerSecVersion); v != "13" {
+ err = ErrBadSecVersion
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+
+ nonce := getHeader(r.Header, headerSecKey)
+ if len(nonce) != nonceSize {
+ err = ErrBadSecKey
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+
+ hj, ok := w.(http.Hijacker)
+ if !ok {
+ err = ErrBadHijacker
+ w.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+ conn, rw, err = hj.Hijack()
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+
+ rw.WriteString(textUpgrade)
+
+ accept := makeAccept(strToBytes(nonce))
+ writeHeaderKey(rw.Writer, headerSecAccept)
+ writeHeaderValueBytes(rw.Writer, accept)
+
+ if c != nil {
+ if p, check := r.Header[headerSecProtocol], c.Protocol; len(p) > 0 && check != nil {
+ for _, v := range p {
+ if check(v) {
+ hs.Protocol = v
+ writeHeader(rw.Writer, headerSecProtocol, hs.Protocol)
+ break
+ }
+ }
+ }
+ // TODO(gobwas) parse extensions.
+ //if e, check := r.Header[headerSecExtensions], c.Extension; len(e) > 0 && check != nil {
+ // hs.Extensions = selectExtensions(e, c.Extension)
+ // if len(hs.Extensions) > 0 {
+ // writeHeader(rw.Writer, headerSecExtensions, strings.Join(hs.Extensions, ", "))
+ // }
+ //}
+ for key, values := range c.Header {
+ for _, val := range values {
+ writeHeader(rw.Writer, key, val)
+ }
+ }
+ }
+
+ rw.WriteString(crlf)
+
+ err = rw.Flush()
+
+ return
+}
+
+// getHeader is the same as textproto.MIMEHeader.Get, except the thing,
+// that key is already canonical. This helps to increase performance.
+func getHeader(h http.Header, key string) string {
+ if h == nil {
+ return ""
+ }
+ v := h[key]
+ if len(v) == 0 {
+ return ""
+ }
+ return v[0]
+}
+
+func writeHeader(bw *bufio.Writer, key, value string) {
+ writeHeaderKey(bw, key)
+ writeHeaderValue(bw, value)
+}
+
+func writeHeaderKey(bw *bufio.Writer, key string) {
+ bw.WriteString(key)
+ bw.WriteString(colonAndSpace)
+}
+
+func writeHeaderValue(bw *bufio.Writer, value string) {
+ bw.WriteString(value)
+ bw.WriteString(crlf)
+}
+
+func writeHeaderValueBytes(bw *bufio.Writer, value []byte) {
+ bw.Write(value)
+ bw.WriteString(crlf)
+}
+
+func hasToken(header, token string) bool {
+ var pos int
+ for i := 0; i <= len(header); i++ {
+ if i == len(header) || header[i] == ',' {
+ v := strings.TrimSpace(header[pos:i])
+ if len(v) == len(token) && strings.ToLower(v) == token {
+ return true
+ }
+ pos = i + 1
+ }
+ }
+ return false
+}
+
+func selectProtocol(h string, ok func(string) bool) string {
+ var start int
+ for i := 0; i < len(h); i++ {
+ c := h[i]
+ // The elements that comprise this value MUST be non-empty strings with characters in the range
+ // U+0021 to U+007E not including separator characters as defined in [RFC2616]
+ // and MUST all be unique strings.
+ if c != ',' && '!' <= c && c <= '~' {
+ continue
+ }
+ if str := h[start:i]; len(str) > 0 && ok(str) {
+ return str
+ }
+ start = i + 1
+ }
+ if str := h[start:]; len(str) > 0 && ok(str) {
+ return str
+ }
+ return ""
+}
+
+func selectExtensions(h []string, ok func(string) bool) []string {
+ // TODO(gobwas): parse extensions with params
+ return nil
+}
diff --git a/server_test.go b/server_test.go
new file mode 100644
index 0000000..de1e549
--- /dev/null
+++ b/server_test.go
@@ -0,0 +1,426 @@
+package ws
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io"
+ "math/rand"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "net/http/httputil"
+ "reflect"
+ "sort"
+ "strings"
+ "testing"
+ _ "unsafe" // for go:linkname
+)
+
+func TestUpgrade(t *testing.T) {
+ for i, test := range []struct {
+ nonce []byte
+ req *http.Request
+ res *http.Response
+ hs Handshake
+ cfg *UpgradeConfig
+ err error
+ }{
+ {
+ nonce: mustMakeNonce(),
+ req: mustMakeRequest("GET", "ws://example.org", http.Header{
+ headerUpgrade: []string{"websocket"},
+ headerConnection: []string{"Upgrade"},
+ headerSecVersion: []string{"13"},
+ }),
+ res: mustMakeResponse(101, http.Header{
+ headerUpgrade: []string{"websocket"},
+ headerConnection: []string{"Upgrade"},
+ }),
+ cfg: &UpgradeConfig{
+ Protocol: func(sub string) bool {
+ return true
+ },
+ },
+ },
+ {
+ nonce: mustMakeNonce(),
+ req: mustMakeRequest("GET", "ws://example.org", http.Header{
+ headerUpgrade: []string{"WEBSOCKET"},
+ headerConnection: []string{"UPGRADE"},
+ headerSecVersion: []string{"13"},
+ }),
+ res: mustMakeResponse(101, http.Header{
+ headerUpgrade: []string{"websocket"},
+ headerConnection: []string{"Upgrade"},
+ }),
+ cfg: &UpgradeConfig{
+ Protocol: func(sub string) bool {
+ return true
+ },
+ },
+ },
+ {
+ nonce: mustMakeNonce(),
+ req: mustMakeRequest("GET", "ws://example.org", http.Header{
+ headerUpgrade: []string{"websocket"},
+ headerConnection: []string{"Upgrade"},
+ headerSecVersion: []string{"13"},
+ headerSecProtocol: []string{"a", "b", "c", "d"},
+ }),
+ res: mustMakeResponse(101, http.Header{
+ headerUpgrade: []string{"websocket"},
+ headerConnection: []string{"Upgrade"},
+ headerSecProtocol: []string{"b"},
+ }),
+ hs: Handshake{Protocol: "b"},
+ cfg: &UpgradeConfig{
+ Protocol: SelectFromSlice([]string{"b", "d"}),
+ },
+ },
+ // TODO(gobwas) uncomment after selectExtension is ready.
+ //{
+ // nonce: mustMakeNonce(),
+ // req: mustMakeRequest("GET", "ws://example.org", http.Header{
+ // headerUpgrade: []string{"websocket"},
+ // headerConnection: []string{"Upgrade"},
+ // headerSecVersion: []string{"13"},
+ // headerSecExtensions: []string{"a", "b", "c", "d"},
+ // }),
+ // res: mustMakeResponse(101, http.Header{
+ // headerUpgrade: []string{"websocket"},
+ // headerConnection: []string{"Upgrade"},
+ // headerSecExtensions: []string{"b", "d"},
+ // }),
+ // hs: Handshake{Extensions: ["b", "d"]},
+ // cfg: &UpgradeConfig{
+ // Extension: SelectFromSlice([]string{"b", "d"}),
+ // },
+ //},
+ } {
+ t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
+ if test.nonce != nil {
+ test.req.Header.Set(headerSecKey, string(test.nonce))
+ test.res.Header.Set(headerSecAccept, string(makeAccept(test.nonce)))
+ }
+
+ res := newRecorder()
+ _, _, hs, err := Upgrade(test.req, res, test.cfg)
+ if test.err != err {
+ t.Errorf("expected error to be '%v', got '%v'", test.err, err)
+ return
+ }
+
+ actRespBts := sortHeaders(res.Bytes())
+ expRespBts := sortHeaders(dumpResponse(test.res))
+ if !bytes.Equal(actRespBts, expRespBts) {
+ t.Errorf(
+ "unexpected http response:\n---- act:\n%s\n---- want:\n%s\n====",
+ actRespBts, expRespBts,
+ )
+ return
+ }
+
+ if !reflect.DeepEqual(hs, test.hs) {
+ t.Errorf("unexpected handshake: %#v; want %#v", hs, test.hs)
+ }
+ })
+ }
+}
+
+func BenchmarkUpgrade(b *testing.B) {
+ bts101 := []byte("HTTP/1.1 101")
+ for _, bench := range []struct {
+ label string
+ req *http.Request
+ cfg *UpgradeConfig
+ }{
+ {
+ label: "base",
+ req: mustMakeRequest("GET", "ws://example.org", http.Header{
+ headerUpgrade: []string{"websocket"},
+ headerConnection: []string{"Upgrade"},
+ headerSecVersion: []string{"13"},
+ headerSecKey: []string{string(mustMakeNonce())},
+ }),
+ cfg: &UpgradeConfig{
+ Protocol: func(sub string) bool {
+ return true
+ },
+ },
+ },
+ {
+ label: "uppercase",
+ req: mustMakeRequest("GET", "ws://example.org", http.Header{
+ headerUpgrade: []string{"WEBSOCKET"},
+ headerConnection: []string{"UPGRADE"},
+ headerSecVersion: []string{"13"},
+ headerSecKey: []string{string(mustMakeNonce())},
+ }),
+ cfg: &UpgradeConfig{
+ Protocol: func(sub string) bool {
+ return true
+ },
+ },
+ },
+ } {
+ b.Run(bench.label, func(b *testing.B) {
+ b.ReportAllocs()
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ res := newRecorder()
+ _, _, _, err := Upgrade(bench.req, res, bench.cfg)
+ if err != nil {
+ b.Fatal(err)
+ }
+ if !bytes.HasPrefix(res.Body.Bytes(), bts101) {
+ b.Fatalf("unexpected http status code: %v\n%s", res.Code, res.Body.String())
+ }
+ }
+ })
+ })
+ }
+}
+
+func TestSelectProtocol(t *testing.T) {
+ for i, test := range []struct {
+ header string
+ }{
+ {"jsonrpc, soap, grpc"},
+ } {
+ t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
+ exp := strings.Split(test.header, ",")
+ for i, p := range exp {
+ exp[i] = strings.TrimSpace(p)
+ }
+
+ var calls []string
+ selectProtocol(test.header, func(s string) bool {
+ calls = append(calls, s)
+ return false
+ })
+
+ if !reflect.DeepEqual(calls, exp) {
+ t.Errorf("selectProtocol(%q, fn); called fn with %v; want %v", test.header, calls, exp)
+ }
+ })
+ }
+}
+
+func TestHasToken(t *testing.T) {
+ for i, test := range []struct {
+ header string
+ token string
+ exp bool
+ }{
+ {"Keep-Alive, Close, Upgrade", "upgrade", true},
+ {"Keep-Alive, Close, upgrade, hello", "upgrade", true},
+ {"Keep-Alive, Close, hello", "upgrade", false},
+ } {
+ t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
+ if has := hasToken(test.header, test.token); has != test.exp {
+ t.Errorf("hasToken(%q, %q) = %v; want %v", test.header, test.token, has, test.exp)
+ }
+ })
+ }
+}
+
+func BenchmarkHasToken(b *testing.B) {
+ for i, bench := range []struct {
+ header string
+ token string
+ }{
+ {"Keep-Alive, Close, Upgrade", "upgrade"},
+ {"Keep-Alive, Close, upgrade, hello", "upgrade"},
+ {"Keep-Alive, Close, hello", "upgrade"},
+ } {
+ b.Run(fmt.Sprintf("#%d", i), func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ _ = hasToken(bench.header, bench.token)
+ }
+ })
+ }
+}
+
+func TestSelectExtensions(t *testing.T) {
+
+}
+
+func BenchmarkSelectProtocol(b *testing.B) {
+ for _, bench := range []struct {
+ label string
+ header string
+ accept func(string) bool
+ }{
+ {
+ label: "never accept",
+ header: "jsonrpc, soap, grpc",
+ accept: func(s string) bool {
+ return len(s)%2 == 2 // never ok
+ },
+ },
+ {
+ label: "from slice",
+ header: "a, b, c, d, e, f, g",
+ accept: SelectFromSlice([]string{"g", "f", "e", "d"}),
+ },
+ {
+ label: "uniq 1024 from slise",
+ header: strings.Join(randProtocols(1024, 16), ", "),
+ accept: SelectFromSlice(randProtocols(1024, 17)),
+ },
+ } {
+ b.Run(fmt.Sprintf("#%s_optimized", bench.label), func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ selectProtocol(bench.header, bench.accept)
+ }
+ })
+ }
+}
+
+func randProtocols(n, m int) []string {
+ ret := make([]string, n)
+ bts := make([]byte, m)
+ uniq := map[string]bool{}
+ for i := 0; i < n; i++ {
+ for {
+ for j := 0; j < m; j++ {
+ bts[j] = byte(rand.Intn('x'-'a') + 'a')
+ }
+ str := string(bts)
+ if _, has := uniq[str]; !has {
+ ret[i] = str
+ break
+ }
+ }
+ }
+ return ret
+}
+func dumpRequest(req *http.Request) []byte {
+ bts, err := httputil.DumpRequest(req, true)
+ if err != nil {
+ panic(err)
+ }
+ return bts
+}
+
+func dumpResponse(res *http.Response) []byte {
+ cleanClose := !res.Close
+ if cleanClose {
+ for _, v := range res.Header[headerConnection] {
+ if v == "close" {
+ cleanClose = false
+ break
+ }
+ }
+ }
+
+ bts, err := httputil.DumpResponse(res, true)
+ if err != nil {
+ panic(err)
+ }
+
+ if cleanClose {
+ bts = bytes.Replace(bts, []byte("Connection: close\r\n"), nil, -1)
+ }
+
+ return bts
+}
+
+type headersBytes [][]byte
+
+func (h headersBytes) Len() int { return len(h) }
+func (h headersBytes) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
+func (h headersBytes) Less(i, j int) bool { return string(h[i]) < string(h[j]) }
+
+func sortHeaders(bts []byte) []byte {
+ lines := bytes.Split(bts, []byte("\r\n"))
+ if len(lines) <= 1 {
+ return bts
+ }
+ sort.Sort(headersBytes(lines[1 : len(lines)-2]))
+ return bytes.Join(lines, []byte("\r\n"))
+}
+
+type recorder struct {
+ *httptest.ResponseRecorder
+ hijacked bool
+}
+
+func newRecorder() *recorder {
+ return &recorder{
+ ResponseRecorder: httptest.NewRecorder(),
+ }
+}
+
+func (r *recorder) Bytes() []byte {
+ if r.hijacked {
+ return r.ResponseRecorder.Body.Bytes()
+ }
+ return dumpResponse(r.Result())
+}
+
+//go:linkname httpPutBufioReader net/http.putBufioReader
+func httpPutBufioReader(*bufio.Reader)
+
+//go:linkname httpPutBufioWriter net/http.putBufioWriter
+func httpPutBufioWriter(*bufio.Writer)
+
+//go:linkname httpNewBufioReader net/http.newBufioReader
+func httpNewBufioReader(io.Reader) *bufio.Reader
+
+//go:linkname httpNewBufioWriterSize net/http.newBufioWriterSize
+func httpNewBufioWriterSize(io.Writer, int) *bufio.Writer
+
+func (r *recorder) Hijack() (conn net.Conn, brw *bufio.ReadWriter, err error) {
+ if r.hijacked {
+ err = fmt.Errorf("already hijacked")
+ return
+ }
+
+ r.hijacked = true
+
+ buf := r.ResponseRecorder.Body
+
+ conn = stubConn{
+ read: buf.Read,
+ write: buf.Write,
+ close: func() error { return nil },
+ }
+
+ // Use httpNewBufio* linked functions here to make
+ // benchmark more closer to real life usage.
+ br := httpNewBufioReader(buf)
+ bw := httpNewBufioWriterSize(buf, 4<<10)
+
+ brw = bufio.NewReadWriter(br, bw)
+
+ return
+}
+
+func mustMakeRequest(method, url string, headers http.Header) *http.Request {
+ req, err := http.NewRequest(method, url, nil)
+ if err != nil {
+ panic(err)
+ }
+ req.Header = headers
+ return req
+}
+
+func mustMakeResponse(code int, headers http.Header) *http.Response {
+ res := &http.Response{
+ StatusCode: code,
+ Status: http.StatusText(code),
+ Header: headers,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: -1,
+ }
+ return res
+}
+
+func mustMakeNonce() []byte {
+ b := make([]byte, nonceSize)
+ newNonce(b)
+ return b
+}
diff --git a/server_test.s b/server_test.s
new file mode 100644
index 0000000..e69de29
diff --git a/util.go b/util.go
new file mode 100644
index 0000000..e16abc5
--- /dev/null
+++ b/util.go
@@ -0,0 +1,25 @@
+package ws
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+func strToBytes(str string) []byte {
+ s := *(*reflect.StringHeader)(unsafe.Pointer(&str))
+ b := &reflect.SliceHeader{Data: s.Data, Len: s.Len, Cap: s.Len}
+ return *(*[]byte)(unsafe.Pointer(b))
+}
+
+func btsToString(bts []byte) string {
+ b := *(*reflect.SliceHeader)(unsafe.Pointer(&bts))
+ s := &reflect.StringHeader{Data: b.Data, Len: b.Len}
+ return *(*string)(unsafe.Pointer(s))
+}
+
+func min(a, b int) int {
+ if a < b {
+ return a
+ }
+ return b
+}
diff --git a/write.go b/write.go
new file mode 100644
index 0000000..a3056be
--- /dev/null
+++ b/write.go
@@ -0,0 +1,82 @@
+package ws
+
+import (
+ "encoding/binary"
+ "io"
+)
+
+const (
+ bit0 = 0x80
+ bit1 = 0x40
+ bit2 = 0x20
+ bit3 = 0x10
+ bit4 = 0x08
+ bit5 = 0x04
+ bit6 = 0x02
+ bit7 = 0x01
+
+ len16 = int64(^(uint16(0)))
+ len64 = int64(^(uint64(0)) >> 1)
+)
+
+func WriteHeader(w io.Writer, h Header) error {
+ size := 2
+
+ var lenByte byte
+ switch {
+ case h.Length < 126:
+ lenByte = byte(h.Length)
+ size += 0
+
+ case h.Length <= len16:
+ lenByte = 126
+ size += 2
+
+ case h.Length <= len64:
+ lenByte = 127
+ size += 8
+
+ default:
+ return ErrHeaderLengthUnexpected
+ }
+
+ if h.Mask != nil {
+ lenByte |= bit0
+ size += 4
+ }
+
+ bts := make([]byte, size)
+
+ if h.Fin {
+ bts[0] |= bit0
+ }
+ bts[0] |= h.Rsv << 4
+ bts[0] |= byte(h.OpCode)
+ bts[1] = lenByte
+
+ maskPos := 2 // after fin, rsv and op code byte and length byte.
+ switch {
+ case lenByte == 126:
+ binary.BigEndian.PutUint16(bts[2:], uint16(h.Length))
+ maskPos += 2
+ case lenByte == 127:
+ binary.BigEndian.PutUint64(bts[2:], uint64(h.Length))
+ maskPos += 8
+ }
+
+ if h.Mask != nil {
+ copy(bts[maskPos:], h.Mask)
+ }
+
+ _, err := w.Write(bts)
+ return err
+}
+
+func WriteFrame(w io.Writer, f Frame) error {
+ err := WriteHeader(w, f.Header)
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(f.Payload)
+ return err
+}
diff --git a/write_test.go b/write_test.go
new file mode 100644
index 0000000..041bf7c
--- /dev/null
+++ b/write_test.go
@@ -0,0 +1,73 @@
+package ws
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "testing"
+)
+
+func TestWriteHeader(t *testing.T) {
+ for i, test := range RWCases {
+ t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
+ buf := &bytes.Buffer{}
+ err := WriteHeader(buf, test.Header)
+ if test.Err && err == nil {
+ t.Errorf("expected error, got nil")
+ }
+ if !test.Err && err != nil {
+ t.Errorf("unexpected error: %s", err)
+ }
+ if test.Err {
+ return
+ }
+ if bts := buf.Bytes(); !bytes.Equal(bts, test.Data) {
+ t.Errorf("WriteHeader()\nwrote:\n\t%08b\nwant:\n\t%08b", bts, test.Data)
+ }
+ })
+ }
+}
+
+func BenchmarkWriteHeader(b *testing.B) {
+ for _, bench := range []struct {
+ label string
+ header Header
+ }{
+ {
+ "ping", Header{
+ OpCode: OpPing,
+ Fin: true,
+ },
+ },
+ {
+ "text16", Header{
+ OpCode: OpText,
+ Fin: true,
+ Length: int64(^(uint16(0))),
+ },
+ },
+ {
+ "text64", Header{
+ OpCode: OpText,
+ Fin: true,
+ Length: int64(^(uint64(0)) >> 1),
+ },
+ },
+ {
+ "text64mask", Header{
+ OpCode: OpText,
+ Fin: true,
+ Length: int64(^(uint64(0)) >> 1),
+ Mask: []byte("mask"),
+ },
+ },
+ } {
+ b.Run(bench.label, func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ if err := WriteHeader(ioutil.Discard, bench.header); err != nil {
+ b.Fatal(err)
+ }
+ }
+ })
+ }
+}
diff --git a/wsutil/cipher.go b/wsutil/cipher.go
new file mode 100644
index 0000000..f0c9d3a
--- /dev/null
+++ b/wsutil/cipher.go
@@ -0,0 +1,30 @@
+package wsutil
+
+import (
+ "io"
+
+ "github.com/gobwas/ws"
+)
+
+type CipherReader struct {
+ r io.Reader
+ mask []byte
+ pos int
+}
+
+func NewCipherReader(r io.Reader, mask []byte) *CipherReader {
+ return &CipherReader{r, mask, 0}
+}
+
+func (c *CipherReader) Reset(r io.Reader, mask []byte) {
+ c.r = r
+ c.mask = mask
+ c.pos = 0
+}
+
+func (c *CipherReader) Read(p []byte) (n int, err error) {
+ n, err = c.r.Read(p)
+ ws.Cipher(p[:n], c.mask, c.pos)
+ c.pos += n
+ return
+}
diff --git a/wsutil/cipher_test.go b/wsutil/cipher_test.go
new file mode 100644
index 0000000..7fcd03c
--- /dev/null
+++ b/wsutil/cipher_test.go
@@ -0,0 +1,50 @@
+package wsutil
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "reflect"
+ "testing"
+
+ . "github.com/gobwas/ws"
+)
+
+func TestCipherReader(t *testing.T) {
+ for i, test := range []struct {
+ label string
+ data []byte
+ chop int
+ }{
+ {
+ label: "simple",
+ data: []byte("hello, websockets!"),
+ chop: 512,
+ },
+ {
+ label: "chopped",
+ data: []byte("hello, websockets!"),
+ chop: 3,
+ },
+ } {
+ t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) {
+ mask := NewMask()
+ masked := make([]byte, len(test.data))
+ copy(masked, test.data)
+ Cipher(masked, mask, 0)
+
+ src := &chopReader{bytes.NewReader(masked), test.chop}
+ rd := NewCipherReader(src, mask)
+
+ bts, err := ioutil.ReadAll(rd)
+ if err != nil {
+ t.Errorf("unexpected error: %s", err)
+ return
+ }
+ if !reflect.DeepEqual(bts, test.data) {
+ t.Errorf("read data is not equal:\n\tact:\t%#v\n\texp:\t%#x\n", bts, test.data)
+ return
+ }
+ })
+ }
+}
diff --git a/wsutil/reader.go b/wsutil/reader.go
new file mode 100644
index 0000000..f611af5
--- /dev/null
+++ b/wsutil/reader.go
@@ -0,0 +1,155 @@
+package wsutil
+
+import (
+ "io"
+ "io/ioutil"
+ "strconv"
+
+ "github.com/gobwas/pool/pbytes"
+ "github.com/gobwas/ws"
+)
+
+type FrameHandler func(h ws.Header, r io.Reader) error
+
+func ControlHandler(w io.Writer, state ws.State) FrameHandler {
+ return func(h ws.Header, rd io.Reader) (err error) {
+ // int(h.Length) is safe cause control frame could be < 125 bytes length.
+ p := pbytes.GetBufLen(int(h.Length))
+ defer pbytes.PutBuf(p)
+
+ _, err = io.ReadFull(rd, p)
+ if err != nil {
+ return
+ }
+
+ var f ws.Frame
+
+ switch h.OpCode {
+ default:
+ return
+ case ws.OpPing:
+ f = ws.NewPongFrame(p)
+ case ws.OpClose:
+ code, reason := ws.ParseCloseFrameDataUnsafe(p)
+ if code.Empty() {
+ code = ws.StatusNoStatusRcvd
+ f = ws.CloseFrame
+ } else if err = ws.CheckCloseFrameData(code, reason); err != nil {
+ code = ws.StatusProtocolError
+ reason = err.Error()
+ f = ws.NewCloseFrame(code, reason)
+ } else {
+ // [RFC6455:5.5.1]:
+ // If an endpoint receives a Close frame and did not previously
+ // send a Close frame, the endpoint MUST send a Close frame in
+ // response. (When sending a Close frame in response, the endpoint
+ // typically echos the status code it received.)
+ f = ws.NewCloseFrame(code, "")
+ }
+ err = ErrClosed{code, reason}
+ }
+
+ if state.Is(ws.StateClientSide) {
+ f = ws.MaskFrame(f)
+ }
+ if ew := ws.WriteFrame(w, f); ew != nil {
+ err = ew
+ }
+
+ return
+ }
+}
+
+func NextReader(r io.Reader, s ws.State) (h ws.Header, rd *Reader, err error) {
+ rd = NewReader(r, s)
+ h, err = rd.Next()
+ return
+}
+
+type handler struct {
+ continuation FrameHandler
+ intermediate FrameHandler
+}
+
+type Reader struct {
+ src io.Reader
+ state ws.State
+ payload io.Reader
+ handler handler
+}
+
+type ErrClosed struct {
+ code ws.StatusCode
+ reason string
+}
+
+func (err ErrClosed) Error() string {
+ return "ws closed: " + strconv.FormatUint(uint64(err.code), 10) + " " + err.reason
+}
+
+func NewReader(r io.Reader, s ws.State) *Reader {
+ return &Reader{
+ src: r,
+ state: s,
+ }
+}
+
+func (r *Reader) HandleContinuation(h FrameHandler) { r.handler.continuation = h }
+func (r *Reader) HandleIntermediate(h FrameHandler) { r.handler.intermediate = h }
+
+func (r *Reader) Read(p []byte) (n int, err error) {
+ if r.payload == nil {
+ _, err = r.Next()
+ if err != nil || r.payload == nil {
+ return
+ }
+ }
+
+ n, err = r.payload.Read(p)
+
+ if err == io.EOF {
+ r.payload = nil
+ if r.state.Is(ws.StateFragmented) {
+ err = nil
+ }
+ }
+
+ return
+}
+
+func (r *Reader) Next() (h ws.Header, err error) {
+ h, err = ws.ReadHeader(r.src)
+ if err != nil {
+ return
+ }
+ if err = ws.CheckHeader(h, r.state); err != nil {
+ return
+ }
+
+ src := io.LimitReader(r.src, h.Length)
+ rd := src
+ if mask := h.Mask; mask != nil {
+ rd = NewCipherReader(rd, mask)
+ }
+
+ if r.state.Is(ws.StateFragmented) && h.OpCode.IsControl() {
+ if hi := r.handler.intermediate; hi != nil {
+ err = hi(h, rd)
+ }
+ if err == nil {
+ _, err = io.Copy(ioutil.Discard, src)
+ }
+ return
+ }
+
+ if h.OpCode == ws.OpContinuation {
+ if hc := r.handler.continuation; hc != nil {
+ err = hc(h, rd)
+ }
+ }
+
+ r.state = r.state.SetOrClearIf(!h.Fin, ws.StateFragmented)
+ r.payload = rd
+
+ return
+}
diff --git a/wsutil/reader_test.go b/wsutil/reader_test.go
new file mode 100644
index 0000000..4c8f199
--- /dev/null
+++ b/wsutil/reader_test.go
@@ -0,0 +1,140 @@
+package wsutil
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "testing"
+
+ . "github.com/gobwas/ws"
+)
+
+func TestReader(t *testing.T) {
+ for i, test := range []struct {
+ label string
+ seq []Frame
+ chop int
+ exp []byte
+ err error
+ }{
+ {
+ label: "empty",
+ seq: []Frame{},
+ err: io.EOF,
+ },
+ {
+ label: "single",
+ seq: []Frame{
+ NewTextFrame("Привет, Мир!"),
+ },
+ exp: []byte("Привет, Мир!"),
+ },
+ {
+ label: "single_masked",
+ seq: []Frame{
+ MaskFrame(NewTextFrame("Привет, Мир!")),
+ },
+ exp: []byte("Привет, Мир!"),
+ },
+ {
+ label: "fragmented",
+ seq: []Frame{
+ NewFrame(OpText, false, []byte("Привет,")),
+ NewFrame(OpContinuation, false, []byte(" о дивный,")),
+ NewFrame(OpContinuation, false, []byte(" новый ")),
+ NewFrame(OpContinuation, true, []byte("Мир!")),
+
+ NewTextFrame("Hello, Brave New World!"),
+ },
+ exp: []byte("Привет, о дивный, новый Мир!"),
+ },
+ {
+ label: "fragmented_masked",
+ seq: []Frame{
+ MaskFrame(NewFrame(OpText, false, []byte("Привет,"))),
+ MaskFrame(NewFrame(OpContinuation, false, []byte(" о дивный,"))),
+ MaskFrame(NewFrame(OpContinuation, false, []byte(" новый "))),
+ MaskFrame(NewFrame(OpContinuation, true, []byte("Мир!"))),
+
+ MaskFrame(NewTextFrame("Hello, Brave New World!")),
+ },
+ exp: []byte("Привет, о дивный, новый Мир!"),
+ },
+ {
+ label: "fragmented_and_control",
+ seq: []Frame{
+ NewFrame(OpText, false, []byte("Привет,")),
+ NewFrame(OpPing, true, nil),
+ NewFrame(OpContinuation, false, []byte(" о дивный,")),
+ NewFrame(OpPing, true, nil),
+ NewFrame(OpContinuation, false, []byte(" новый ")),
+ NewFrame(OpPing, true, nil),
+ NewFrame(OpPing, true, []byte("ping info")),
+ NewFrame(OpContinuation, true, []byte("Мир!")),
+ },
+ exp: []byte("Привет, о дивный, новый Мир!"),
+ },
+ {
+ label: "fragmented_and_control_mask",
+ seq: []Frame{
+ MaskFrame(NewFrame(OpText, false, []byte("Привет,"))),
+ MaskFrame(NewFrame(OpPing, true, nil)),
+ MaskFrame(NewFrame(OpContinuation, false, []byte(" о дивный,"))),
+ MaskFrame(NewFrame(OpPing, true, nil)),
+ MaskFrame(NewFrame(OpContinuation, false, []byte(" новый "))),
+ MaskFrame(NewFrame(OpPing, true, nil)),
+ MaskFrame(NewFrame(OpPing, true, []byte("ping info"))),
+ MaskFrame(NewFrame(OpContinuation, true, []byte("Мир!"))),
+ },
+ exp: []byte("Привет, о дивный, новый Мир!"),
+ },
+ } {
+ t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) {
+ // Prepare input.
+ buf := &bytes.Buffer{}
+ for _, f := range test.seq {
+ if err := WriteFrame(buf, f); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ conn := &chopReader{
+ src: bytes.NewReader(buf.Bytes()),
+ sz: test.chop,
+ }
+
+ var bts []byte
+ _, reader, err := NextReader(conn, 0)
+ if err == nil {
+ bts, err = ioutil.ReadAll(reader)
+ }
+ if err != test.err {
+ t.Errorf("unexpected error; got %v; want %v", err, test.err)
+ return
+ }
+ if test.err == nil && !bytes.Equal(bts, test.exp) {
+ t.Errorf(
+ "ReadAll from reader:\nact:\t%#x\nexp:\t%#x\nact:\t%s\nexp:\t%s\n",
+ bts, test.exp, string(bts), string(test.exp),
+ )
+ }
+ })
+ }
+}
+
+type chopReader struct {
+ src io.Reader
+ sz int
+}
+
+func (c chopReader) Read(p []byte) (n int, err error) {
+ sz := c.sz
+ if sz == 0 {
+ sz = 1
+ }
+ if sz > len(p) {
+ sz = len(p)
+ }
+ return c.src.Read(p[:sz])
+}
diff --git a/wsutil/utf8.go b/wsutil/utf8.go
new file mode 100644
index 0000000..e304b81
--- /dev/null
+++ b/wsutil/utf8.go
@@ -0,0 +1,116 @@
+package wsutil
+
+import (
+ "fmt"
+ "io"
+)
+
+var ErrInvalidUtf8 = fmt.Errorf("invalid utf8")
+
+type UTF8Reader struct {
+ r io.Reader
+ state uint32
+ codep uint32
+ buf []byte
+}
+
+func NewUTF8Reader(r io.Reader) *UTF8Reader {
+ return &UTF8Reader{r: r}
+}
+
+func (u *UTF8Reader) Reset(r io.Reader) {
+ u.state = 0
+ u.codep = 0
+ u.SetSource(r)
+}
+
+func (u *UTF8Reader) SetSource(r io.Reader) {
+ u.r = r
+}
+
+func (u *UTF8Reader) Read(p []byte) (n int, err error) {
+ n, err = u.r.Read(p)
+
+ s, c := u.state, u.codep
+ for i := 0; i < n; i++ {
+ c, s = decode(s, c, p[i])
+ if s == utf8Reject {
+ u.state = s
+ return i, ErrInvalidUtf8
+ }
+ }
+ u.state, u.codep = s, c
+
+ if err == io.EOF && u.state != utf8Accept {
+ err = ErrInvalidUtf8
+ }
+
+ return
+}
+
+func (u *UTF8Reader) Close() error {
+ if u.state != utf8Accept {
+ return ErrInvalidUtf8
+ }
+ return nil
+}
+
+// Below is port of UTF-8 decoder from http://bjoern.hoehrmann.de/utf-8/decoder/dfa/
+//
+// Copyright (c) 2008-2009 Bjoern Hoehrmann
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to
+// deal in the Software without restriction, including without limitation the
+// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+// sell copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+// IN THE SOFTWARE.
+
+const (
+ utf8Accept = 0
+ utf8Reject = 12
+)
+
+var utf8d = [...]byte{
+ // The first part of the table maps bytes to character classes that
+ // to reduce the size of the transition table and create bitmasks.
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+ 8, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 10, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3, 11, 6, 6, 6, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
+
+ // The second part is a transition table that maps a combination
+ // of a state of the automaton and a character class to a state.
+ 0, 12, 24, 36, 60, 96, 84, 12, 12, 12, 48, 72, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
+ 12, 0, 12, 12, 12, 12, 12, 0, 12, 0, 12, 12, 12, 24, 12, 12, 12, 12, 12, 24, 12, 24, 12, 12,
+ 12, 12, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 12, 12, 24, 12, 12,
+ 12, 12, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12, 12, 36, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12,
+ 12, 36, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
+}
+
+func decode(state, codep uint32, b byte) (uint32, uint32) {
+ t := uint32(utf8d[b])
+
+ if state != utf8Accept {
+ codep = (uint32(b) & 0x3f) | (codep << 6)
+ } else {
+ codep = (0xff >> t) & uint32(b)
+ }
+
+ return codep, uint32(utf8d[256+state+t])
+}
diff --git a/wsutil/utf8_test.go b/wsutil/utf8_test.go
new file mode 100644
index 0000000..9a4d8fa
--- /dev/null
+++ b/wsutil/utf8_test.go
@@ -0,0 +1,208 @@
+package wsutil
+
+import (
+ "bytes"
+ "encoding/hex"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "testing"
+)
+
+func TestUTF8ReaderReadFull(t *testing.T) {
+ for i, test := range []struct {
+ hex string
+ errRead bool
+ errClose bool
+ n int
+ chop int
+ }{
+ {
+ hex: "cebae1bdb9cf83cebcceb5eda080656469746564",
+ errClose: true,
+ errRead: true,
+ n: 12,
+ },
+ {
+ hex: "cebae1bdb9cf83cebcceb5eda080656469746564",
+ errRead: true,
+ errClose: true,
+ n: 12,
+ chop: 1,
+ },
+ {
+ hex: "7f7f7fdf",
+ errRead: false,
+ errClose: true,
+ n: 4,
+ },
+ {
+ hex: "dfbf",
+ n: 2,
+ },
+ } {
+ t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
+ bts, err := hex.DecodeString(test.hex)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ chop := test.chop
+ if chop <= 0 {
+ chop = len(bts)
+ }
+
+ src := bytes.NewReader(bts)
+ r := NewUTF8Reader(chopReader{src, chop})
+
+ p := make([]byte, src.Len())
+ n, err := io.ReadFull(r, p)
+
+ if test.errRead && err == nil {
+ t.Errorf("expected read error; got nil")
+ }
+ if !test.errRead && err != nil {
+ t.Errorf("unexpected read error: %s", err)
+ }
+ if n != test.n {
+ t.Errorf("ReadFull() read %d; want %d", n, test.n)
+ }
+
+ err = r.Close()
+ if test.errClose && err == nil {
+ t.Errorf("expected close error; got nil")
+ }
+ if !test.errClose && err != nil {
+ t.Errorf("unexpected close error: %s", err)
+ }
+ })
+ }
+}
+
+func TestUTF8Reader(t *testing.T) {
+ for i, test := range []struct {
+ label string
+
+ data []byte
+ // or
+ hex string
+
+ chop int
+
+ err bool
+ at int
+ }{
+ {
+ data: []byte("hello, world!"),
+ chop: 2,
+ },
+ {
+ data: []byte{0x7f, 0xf0},
+ err: true,
+ at: 2,
+ chop: 1,
+ },
+ {
+ data: []byte{0x7f, 0xf0},
+ err: true,
+ at: 2,
+ chop: 1,
+ },
+ {
+ hex: "48656c6c6f2dc2b540c39fc3b6c3a4c3bcc3a0c3a12d5554462d382121",
+ chop: 1,
+ },
+ } {
+ t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) {
+ data := test.data
+ if h := test.hex; h != "" {
+ var err error
+ if data, err = hex.DecodeString(h); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ cr := &chopReader{
+ src: bytes.NewReader(data),
+ sz: test.chop,
+ }
+
+ r := NewUTF8Reader(cr)
+
+ bts := make([]byte, 2*len(data))
+
+ var (
+ i, n int
+ err error
+ )
+ for {
+ n, err = r.Read(bts[i:])
+ i += n
+ if err != nil {
+ if err == io.EOF {
+ err = nil
+ }
+ bts = bts[:i]
+ break
+ }
+ }
+ if err == nil {
+ err = r.Close()
+ }
+ if test.err && err == nil {
+ t.Errorf("want error; got nil")
+ return
+ }
+ if !test.err && err != nil {
+ t.Errorf("unexpected error: %s", err)
+ return
+ }
+ if test.err && err == ErrInvalidUtf8 && i != test.at {
+ t.Errorf("received error at %d; want at %d", i, test.at)
+ return
+ }
+ if !test.err && !bytes.Equal(bts, data) {
+ t.Errorf("bytes are not equal")
+ }
+ })
+ }
+}
+
+func BenchmarkUTF8Reader(b *testing.B) {
+ for i, bench := range []struct {
+ label string
+ data []byte
+ chop int
+ err bool
+ }{
+ {
+ data: bytes.Repeat([]byte("x"), 1024),
+ chop: 128,
+ },
+ {
+ data: append(
+ bytes.Repeat([]byte("x"), 1024),
+ append(
+ []byte{0x7f, 0xf0},
+ bytes.Repeat([]byte("x"), 128)...,
+ )...,
+ ),
+ err: true,
+ chop: 7,
+ },
+ } {
+ b.Run(fmt.Sprintf("%s#%d", bench.label, i), func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ cr := &chopReader{
+ src: bytes.NewReader(bench.data),
+ sz: bench.chop,
+ }
+ r := NewUTF8Reader(cr)
+ _, err := ioutil.ReadAll(r)
+ if !bench.err && err != nil {
+ b.Fatal(err)
+ }
+ }
+ })
+ }
+}
diff --git a/wsutil/writer.go b/wsutil/writer.go
new file mode 100644
index 0000000..39fee88
--- /dev/null
+++ b/wsutil/writer.go
@@ -0,0 +1,162 @@
+package wsutil
+
+import (
+ "io"
+
+ "github.com/gobwas/pool/pbytes"
+ "github.com/gobwas/ws"
+)
+
+const defaultWriteBuffer = 4096
+
+type WriterConfig struct {
+ Op ws.OpCode
+ Mask bool
+}
+
+type Writer struct {
+ wr io.Writer
+ buf []byte
+ n int
+
+ dirty bool
+ frames int
+
+ op ws.OpCode
+ mask bool
+}
+
+func NextWriter(dst io.Writer, op ws.OpCode, mask bool) *Writer {
+ return NewWriterSize(dst, 0, WriterConfig{Op: op, Mask: mask})
+}
+
+func NewWriter(dst io.Writer, c WriterConfig) *Writer {
+ return NewWriterSize(dst, defaultWriteBuffer, c)
+}
+
+func NewWriterSize(dst io.Writer, n int, c WriterConfig) *Writer {
+ if n <= 0 {
+ n = defaultWriteBuffer
+ }
+ return NewWriterBuffer(dst, make([]byte, n), c)
+}
+
+func NewWriterBuffer(wr io.Writer, buf []byte, c WriterConfig) *Writer {
+ return &Writer{
+ wr: wr,
+ buf: buf,
+ op: c.Op,
+ mask: c.Mask,
+ }
+}
+
+func (w *Writer) Write(p []byte) (n int, err error) {
+ // Even if len(p) == 0 we mark w as dirty,
+ // cause even empty p (and empty frame) may have a value.
+ w.dirty = true
+
+ if len(p) > len(w.buf) && w.n == 0 {
+ // Large write.
+ return w.write(p)
+ }
+ for {
+ nn := copy(w.buf[w.n:], p)
+ p = p[nn:]
+ w.n += nn
+ n += nn
+
+ if len(p) == 0 {
+ break
+ }
+
+ _, err = w.write(w.buf)
+ if err != nil {
+ break
+ }
+ w.n = 0
+ }
+ return
+}
+
+func (w *Writer) ReadFrom(src io.Reader) (n int64, err error) {
+ var nn int
+ for {
+ if w.n == len(w.buf) { // buffer is full.
+ if _, err = w.write(w.buf); err != nil {
+ return
+ }
+ w.n = 0
+ }
+
+ nn, err = src.Read(w.buf[w.n:])
+ w.n += nn
+ n += int64(nn)
+ w.dirty = true
+
+ if err != nil {
+ break
+ }
+ }
+ if err == io.EOF {
+ err = nil
+ }
+ return
+}
+
+func (w *Writer) Flush() error {
+ _, err := w.flush()
+ return err
+}
+
+func (w *Writer) opCode() ws.OpCode {
+ if w.frames > 0 {
+ return ws.OpContinuation
+ } else {
+ return w.op
+ }
+}
+
+func (w *Writer) flush() (n int, err error) {
+ if w.n == 0 && !w.dirty {
+ return 0, nil
+ }
+
+ n, err = w.writeFrame(w.opCode(), w.buf[:w.n], true)
+ w.dirty = false
+ w.n = 0
+ w.frames = 0
+
+ return
+}
+
+func (w *Writer) write(p []byte) (n int, err error) {
+ return w.writeFrame(w.opCode(), p, false)
+}
+
+func (w *Writer) writeFrame(op ws.OpCode, p []byte, fin bool) (n int, err error) {
+ header := ws.Header{
+ OpCode: op,
+ Length: int64(len(p)),
+ Fin: fin,
+ }
+
+ payload := p
+ if w.mask {
+ header.Mask = ws.NewMask()
+
+ payload = pbytes.GetBufLen(len(p))
+ defer pbytes.PutBuf(payload)
+
+ copy(payload, p)
+ ws.Cipher(payload, header.Mask, 0)
+ }
+
+ err = ws.WriteHeader(w.wr, header)
+ if err == nil {
+ n, err = w.wr.Write(payload)
+ }
+
+ w.frames++
+
+ return
+}
diff --git a/wsutil/writer_test.go b/wsutil/writer_test.go
new file mode 100644
index 0000000..1247be3
--- /dev/null
+++ b/wsutil/writer_test.go
@@ -0,0 +1,252 @@
+package wsutil
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "reflect"
+ "testing"
+
+ . "github.com/gobwas/ws"
+)
+
+func TestWriter(t *testing.T) {
+ for i, test := range []struct {
+ label string
+ config WriterConfig
+ size int
+ data [][]byte
+ expFrm []Frame
+ expBts []byte
+ }{
+ {
+ config: WriterConfig{Op: OpText},
+ },
+ {
+ config: WriterConfig{Op: OpText},
+ data: [][]byte{
+ []byte{},
+ },
+ expBts: MustCompileFrame(NewTextFrame("")),
+ },
+ {
+ config: WriterConfig{Op: OpText},
+ data: [][]byte{
+ []byte("hello, world!"),
+ },
+ expBts: MustCompileFrame(NewTextFrame("hello, world!")),
+ },
+ {
+ config: WriterConfig{Op: OpText, Mask: true},
+ data: [][]byte{
+ []byte("hello, world!"),
+ },
+ expFrm: []Frame{NewTextFrame("hello, world!")},
+ },
+ {
+ config: WriterConfig{Op: OpText},
+ size: 5,
+ data: [][]byte{
+ []byte("hello"),
+ []byte(", wor"),
+ []byte("ld!"),
+ },
+ expBts: bytes.Join(
+ bts(
+ MustCompileFrame(Frame{
+ Header: Header{
+ Fin: false,
+ OpCode: OpText,
+ Length: 5,
+ },
+ Payload: []byte("hello"),
+ }),
+ MustCompileFrame(Frame{
+ Header: Header{
+ Fin: false,
+ OpCode: OpContinuation,
+ Length: 5,
+ },
+ Payload: []byte(", wor"),
+ }),
+ MustCompileFrame(Frame{
+ Header: Header{
+ Fin: true,
+ OpCode: OpContinuation,
+ Length: 3,
+ },
+ Payload: []byte("ld!"),
+ }),
+ ),
+ nil,
+ ),
+ },
+ { // Large write case.
+ config: WriterConfig{Op: OpText},
+ size: 5,
+ data: [][]byte{
+ []byte("hello, world!"),
+ },
+ expBts: bytes.Join(
+ bts(
+ MustCompileFrame(Frame{
+ Header: Header{
+ Fin: false,
+ OpCode: OpText,
+ Length: 13,
+ },
+ Payload: []byte("hello, world!"),
+ }),
+ MustCompileFrame(Frame{
+ Header: Header{
+ Fin: true,
+ OpCode: OpContinuation,
+ Length: 0,
+ },
+ }),
+ ),
+ nil,
+ ),
+ },
+ } {
+ t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) {
+ buf := &bytes.Buffer{}
+ w := NewWriterSize(buf, test.size, test.config)
+
+ for _, p := range test.data {
+ _, err := w.Write(p)
+ if err != nil {
+ t.Fatalf("unexpected Write() error: %s", err)
+ }
+ }
+ if err := w.Flush(); err != nil {
+ t.Fatalf("unexpected Flush() error: %s", err)
+ }
+ if test.expBts != nil {
+ if bts := buf.Bytes(); !bytes.Equal(test.expBts, bts) {
+ t.Errorf(
+ "wrote bytes:\nact:\t%#x\nexp:\t%#x\nacth:\t%s\nexph:\t%s\n", bts, test.expBts,
+ pretty(frames(bts)), pretty(frames(test.expBts)),
+ )
+ }
+ }
+ if test.expFrm != nil {
+ act := omitMask(frames(buf.Bytes()))
+ exp := omitMask(test.expFrm)
+
+ if !reflect.DeepEqual(act, exp) {
+ t.Errorf(
+ "wrote frames (mask omitted):\nact:\t%s\nexp:\t%s\n",
+ pretty(act), pretty(exp),
+ )
+ }
+ }
+ })
+ }
+}
+
+func TestWriterReadFrom(t *testing.T) {
+ for i, test := range []struct {
+ label string
+ chop int
+ size int
+ data []byte
+ exp []Frame
+ n int64
+ }{
+ {
+ chop: 1,
+ size: 1,
+ data: []byte("golang"),
+ exp: []Frame{
+ Frame{Header: Header{Fin: false, Length: 1, OpCode: OpText}, Payload: []byte{'g'}},
+ Frame{Header: Header{Fin: false, Length: 1, OpCode: OpContinuation}, Payload: []byte{'o'}},
+ Frame{Header: Header{Fin: false, Length: 1, OpCode: OpContinuation}, Payload: []byte{'l'}},
+ Frame{Header: Header{Fin: false, Length: 1, OpCode: OpContinuation}, Payload: []byte{'a'}},
+ Frame{Header: Header{Fin: false, Length: 1, OpCode: OpContinuation}, Payload: []byte{'n'}},
+ Frame{Header: Header{Fin: false, Length: 1, OpCode: OpContinuation}, Payload: []byte{'g'}},
+ Frame{Header: Header{Fin: true, Length: 0, OpCode: OpContinuation}},
+ },
+ n: 6,
+ },
+ {
+ chop: 1,
+ size: 4,
+ data: []byte("golang"),
+ exp: []Frame{
+ Frame{Header: Header{Fin: false, Length: 4, OpCode: OpText}, Payload: []byte("gola")},
+ Frame{Header: Header{Fin: true, Length: 2, OpCode: OpContinuation}, Payload: []byte("ng")},
+ },
+ n: 6,
+ },
+ {
+ size: 64,
+ data: []byte{},
+ exp: []Frame{
+ Frame{Header: Header{Fin: true, Length: 0, OpCode: OpText}},
+ },
+ n: 0,
+ },
+ } {
+ t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) {
+ dst := &bytes.Buffer{}
+ wr := NewWriterSize(dst, test.size, WriterConfig{Op: OpText})
+
+ chop := test.chop
+ if chop == 0 {
+ chop = 128
+ }
+ src := &chopReader{bytes.NewReader(test.data), chop}
+
+ n, err := wr.ReadFrom(src)
+ if err == nil {
+ err = wr.Flush()
+ }
+ if err != nil {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ if n != test.n {
+ t.Errorf("ReadFrom() read out %d; want %d", n, test.n)
+ }
+ if frames := frames(dst.Bytes()); !reflect.DeepEqual(frames, test.exp) {
+ t.Errorf("ReadFrom() read frames:\n\tact:\t%s\n\texp:\t%s\n", pretty(frames), pretty(test.exp))
+ }
+ })
+ }
+}
+
+func frames(p []byte) (ret []Frame) {
+ r := bytes.NewReader(p)
+ for stop := false; !stop; {
+ f, err := ReadFrame(r)
+ if err != nil {
+ if err == io.EOF {
+ break
+ }
+ panic(err)
+
+ }
+ if mask := f.Header.Mask; mask != nil {
+ Cipher(f.Payload, mask, 0)
+ }
+ ret = append(ret, f)
+ }
+ return
+}
+
+func pretty(f []Frame) string {
+ str := "\n"
+ for _, f := range f {
+ str += fmt.Sprintf("\t%#v\n\t%#x (%s)\n\t----\n", f.Header, f.Payload, f.Payload)
+ }
+ return str
+}
+
+func omitMask(f []Frame) []Frame {
+ for i := 0; i < len(f); i++ {
+ f[i].Header.Mask = []byte{0, 0, 0, 0}
+ }
+ return f
+}
+
+func bts(b ...[]byte) [][]byte { return b }