diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b31cdb4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +bin/ +reports/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..2744317 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2017 Sergey Kamardin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..c7104fe --- /dev/null +++ b/Makefile @@ -0,0 +1,47 @@ +BENCH ?=. +BENCH_BASE?=master + +autobahn: + go build -o ./bin/autobahn ./example/autobahn + +test: + go test -cover ./... + +testrfc: PID:=$(shell mktemp -t autobahn.XXXX) +testrfc: autobahn + ./bin/autobahn & echo $$! > $(PID) + if [ -z "$$(ps | grep $$(cat $(PID)) | grep autobahn)" ]; then\ + echo "could not start autobahn";\ + exit 1;\ + fi;\ + wstest -m fuzzingclient -s ./example/autobahn/fuzzingclient.json + pkill -9 -F $(PID) + +rfc: testrfc + open ./example/autobahn/reports/servers/index.html + +bench: + go test -run=none -bench=$(BENCH) -benchmem + +benchcmp: BENCH_BRANCH=$(shell git rev-parse --abbrev-ref HEAD) +benchcmp: BENCH_OLD:=$(shell mktemp -t old.XXXX) +benchcmp: BENCH_NEW:=$(shell mktemp -t new.XXXX) +benchcmp: + if [ ! -z "$(shell git status -s)" ]; then\ + echo "could not compare with $(BENCH_BASE) – found unstaged changes";\ + exit 1;\ + fi;\ + if [ "$(BENCH_BRANCH)" == "$(BENCH_BASE)" ]; then\ + echo "comparing the same branches";\ + exit 1;\ + fi;\ + echo "benchmarking $(BENCH_BRANCH)...";\ + go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_NEW);\ + echo "benchmarking $(BENCH_BASE)...";\ + git checkout -q $(BENCH_BASE);\ + go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_OLD);\ + git checkout -q $(BENCH_BRANCH);\ + echo "\nresults:";\ + echo "========\n";\ + benchcmp $(BENCH_OLD) $(BENCH_NEW);\ + diff --git a/check.go b/check.go new file mode 100644 index 0000000..1437de1 --- /dev/null +++ b/check.go @@ -0,0 +1,141 @@ +package ws + +import ( + "fmt" + "unicode/utf8" +) + +// State represents state of websocket endpoint. +// It used by some functions to be more strict when checking compatibility with RFC6455. +type State uint8 + +const ( + // StateServerSide means that endpoint (caller) is a server. + StateServerSide State = 0x1 << iota + // StateServerSide means that endpoint (caller) is a client. + StateClientSide + // StateExtended means that extension was negotiated during handshake. + StateExtended + // StateFragmented means that endpoint (caller) has received fragmented + // frame and waits for continuation parts. + StateFragmented +) + +// Is checks whether the s has v enabled. +func (s State) Is(v State) bool { + return uint8(s)&uint8(v) != 0 +} + +// Set enables v state on s. +func (s State) Set(v State) State { + return s | v +} + +// Clear disables v state on s. +func (s State) Clear(v State) State { + return s & (^v) +} + +// SetOrClearIf enables or disables v state on s depending on cond. +func (s State) SetOrClearIf(cond bool, v State) (ret State) { + if cond { + ret = s.Set(v) + } else { + ret = s.Clear(v) + } + return +} + +// ProtocolError describes error during checking/parsing websocket frames or headers. +type ProtocolError error + +// Errors used by the protocol checkers. +var ( + ErrProtocolOpCodeReserved = ProtocolError(fmt.Errorf("use of reserved op code")) + ErrProtocolControlPayloadOverflow = ProtocolError(fmt.Errorf("control frame payload limit exceeded")) + ErrProtocolControlNotFinal = ProtocolError(fmt.Errorf("control frame is not final")) + ErrProtocolNonZeroRsv = ProtocolError(fmt.Errorf("non-zero rsv bits with no extension negotiated")) + ErrProtocolMaskRequired = ProtocolError(fmt.Errorf("frames from client to server must be masked")) + ErrProtocolMaskUnexpected = ProtocolError(fmt.Errorf("frames from server to client must be not masked")) + ErrProtocolContinuationExpected = ProtocolError(fmt.Errorf("unexpected non-continuation data frame")) + ErrProtocolContinuationUnexpected = ProtocolError(fmt.Errorf("unexpected continuation data frame")) + ErrProtocolStatusCodeNotInUse = ProtocolError(fmt.Errorf("status code is not in use")) + ErrProtocolStatusCodeApplicationLevel = ProtocolError(fmt.Errorf("status code is only application level")) + ErrProtocolStatusCodeNoMeaning = ProtocolError(fmt.Errorf("status code has no meaning yet")) + ErrProtocolStatusCodeUnknown = ProtocolError(fmt.Errorf("status code is not defined in spec")) + ErrProtocolInvalidUTF8 = ProtocolError(fmt.Errorf("invalid utf8 sequence in close reason")) +) + +// CheckHeader checks h to contain valid header data for given state s. +// +// Note that zero state (0) means that state is clean, +// neither server or client side, nor fragmented, nor extended. +func CheckHeader(h Header, s State) error { + if h.OpCode.IsReserved() { + return ErrProtocolOpCodeReserved + } + if h.OpCode.IsControl() { + if h.Length > MaxControlFramePayloadSize { + return ErrProtocolControlPayloadOverflow + } + if !h.Fin { + return ErrProtocolControlNotFinal + } + } + + switch { + // [RFC6455]: MUST be 0 unless an extension is negotiated that defines meanings for + // non-zero values. If a nonzero value is received and none of the + // negotiated extensions defines the meaning of such a nonzero value, the + // receiving endpoint MUST _Fail the WebSocket Connection_. + case h.Rsv != 0 && !s.Is(StateExtended): + return ErrProtocolNonZeroRsv + + // [RFC6455]: The server MUST close the connection upon receiving a frame that is not masked. + // In this case, a server MAY send a Close frame with a status code of 1002 (protocol error) + // as defined in Section 7.4.1. A server MUST NOT mask any frames that it sends to the client. + // A client MUST close a connection if it detects a masked frame. In this case, it MAY use the + // status code 1002 (protocol error) as defined in Section 7.4.1. + case s.Is(StateServerSide) && h.Mask == nil: + return ErrProtocolMaskRequired + case s.Is(StateClientSide) && h.Mask != nil: + return ErrProtocolMaskUnexpected + + // [RFC6455]: See detailed explanation in 5.4 section. + case s.Is(StateFragmented) && !h.OpCode.IsControl() && h.OpCode != OpContinuation: + return ErrProtocolContinuationExpected + case !s.Is(StateFragmented) && h.OpCode == OpContinuation: + return ErrProtocolContinuationUnexpected + } + + return nil +} + +// CheckCloseFrameData checks received close information +// to be valid RFC6455 compatible clsoe info. +// +// Note that code.Empty() or code.IsAppLevel() will raise error. +// +// If endpoint sends close frame without status code (with frame.Length = 0), +// application should not check its payload. +func CheckCloseFrameData(code StatusCode, reason string) error { + switch { + case code.IsNotUsed(): + return ErrProtocolStatusCodeNotInUse + + case code.IsProtocolReserved(): + return ErrProtocolStatusCodeApplicationLevel + + case code == StatusNoMeaningYet: + return ErrProtocolStatusCodeNoMeaning + + case code.IsProtocolSpec() && !code.IsProtocolDefined(): + return ErrProtocolStatusCodeUnknown + + case !utf8.ValidString(reason): + return ErrProtocolInvalidUTF8 + + default: + return nil + } +} diff --git a/cipher.go b/cipher.go new file mode 100644 index 0000000..d06e141 --- /dev/null +++ b/cipher.go @@ -0,0 +1,55 @@ +package ws + +import ( + "reflect" + "unsafe" +) + +// Cipher applies XOR cipher to the payload using mask. +// Offset is used to cipher chunked data (e.g. in io.Reader implementations). +// +// To convert masked data into unmasked data, or vice versa, the following +// algorithm is applied. The same algorithm applies regardless of the +// direction of the translation, e.g., the same steps are applied to +// mask the data as to unmask the data. +func Cipher(payload, mask []byte, offset int) { + if len(mask) != 4 { + return + } + + n := len(payload) + if n < 8 { + for i := 0; i < n; i++ { + payload[i] ^= mask[(offset+i)%4] + } + return + } + + // Calculate position in mask due to previously processed bytes number. + mpos := offset % 4 + // Count number of bytes will processed one by one from the begining of payload. + // Bitwise used to avoid additional if. + ln := (4 - mpos) & 0x0b + // Count number of bytes will processed one by one from the end of payload. + // This is done to process payload by 8 bytes in each iteration of main loop. + rn := (n - ln) % 8 + + for i := 0; i < ln; i++ { + payload[i] ^= mask[(mpos+i)%4] + } + for i := n - rn; i < n; i++ { + payload[i] ^= mask[(mpos+i)%4] + } + + ph := *(*reflect.SliceHeader)(unsafe.Pointer(&payload)) + mh := *(*reflect.SliceHeader)(unsafe.Pointer(&mask)) + + m := *(*uint32)(unsafe.Pointer(mh.Data)) + m2 := uint64(m)<<32 | uint64(m) + + // Process the rest of bytes as uint64. + for i := ln; i+8 <= n-rn; i += 8 { + v := (*uint64)(unsafe.Pointer(ph.Data + uintptr(i))) + *v = *v ^ m2 + } +} diff --git a/cipher_test.go b/cipher_test.go new file mode 100644 index 0000000..1fd3e1a --- /dev/null +++ b/cipher_test.go @@ -0,0 +1,154 @@ +package ws + +import ( + "fmt" + "math/rand" + "reflect" + "testing" +) + +func TestCipher(t *testing.T) { + for i, test := range []struct { + in []byte + mask [4]byte + }{ + { + in: []byte("Hello, XOR!"), + mask: [4]byte{1, 2, 3, 4}, + }, + { + in: []byte("Hello, XOR!"), + mask: [4]byte{255, 255, 255, 255}, + }, + } { + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + // naive implementation of xor-cipher + exp := cipherNaive(test.in, test.mask[:], 0) + + res := make([]byte, len(test.in)) + copy(res, test.in) + Cipher(res, test.mask[:], 0) + + if !reflect.DeepEqual(res, exp) { + t.Errorf("Cipher(%v, %v):\nact:\t%v\nexp:\t%v\n", test.in, test.mask, res, exp) + } + }) + } +} + +func TestCipherChops(t *testing.T) { + for n := 2; n <= 1024; n <<= 1 { + t.Run(fmt.Sprintf("%d", n), func(t *testing.T) { + p := make([]byte, n) + b := make([]byte, n) + m := make([]byte, 4) + + _, err := rand.Read(p) + if err != nil { + t.Fatal(err) + } + _, err = rand.Read(m) + if err != nil { + t.Fatal(err) + } + + exp := cipherNaive(p, m, 0) + + for i := 1; i <= n; i <<= 1 { + copy(b, p) + s := n / i + + for j := s; j <= n; j += s { + l, r := j-s, j + Cipher(b[l:r], m, l) + if !reflect.DeepEqual(b[l:r], exp[l:r]) { + t.Errorf("unexpected Cipher([%d:%d]) = %x; want %x", l, r, b[l:r], exp[l:r]) + return + } + } + } + + l := 0 + copy(b, p) + for l < n { + r := rand.Intn(n-l) + l + 1 + Cipher(b[l:r], m, l) + if !reflect.DeepEqual(b[l:r], exp[l:r]) { + t.Errorf("unexpected Cipher([%d:%d]):\nact:\t%v\nexp:\t%v\nact:\t%#x\nexp:\t%#x\n\n", l, r, b[l:r], exp[l:r], b[l:r], exp[l:r]) + return + } + l = r + } + }) + } +} + +func cipherNaive(p, m []byte, pos int) []byte { + r := make([]byte, len(p)) + copy(r, p) + cipherNaiveNoCp(r, m, pos) + return r +} + +func cipherNaiveNoCp(p, m []byte, pos int) []byte { + for i := 0; i < len(p); i++ { + p[i] ^= m[(pos+i)%4] + } + return p +} + +func BenchmarkCipher(b *testing.B) { + for _, bench := range []struct { + size int + offset int + }{ + { + size: 7, + offset: 1, + }, + { + size: 125, + }, + { + size: 1024, + }, + { + size: 4096, + }, + { + size: 4100, + offset: 4, + }, + { + size: 4099, + offset: 3, + }, + { + size: (1 << 15) + 7, + offset: 49, + }, + } { + bts := make([]byte, bench.size) + _, err := rand.Read(bts) + if err != nil { + b.Fatal(err) + } + + mask := make([]byte, 4) + _, err = rand.Read(mask) + if err != nil { + b.Fatal(err) + } + + //b.Run(fmt.Sprintf("naive_bytes=%d;offset=%d", bench.size, bench.offset), func(b *testing.B) { + // for i := 0; i < b.N; i++ { + // cipherNaiveNoCp(bts, mask, bench.offset) + // } + //}) + b.Run(fmt.Sprintf("bytes=%d;offset=%d", bench.size, bench.offset), func(b *testing.B) { + for i := 0; i < b.N; i++ { + Cipher(bts, mask, bench.offset) + } + }) + } +} diff --git a/client.go b/client.go new file mode 100644 index 0000000..d9ef970 --- /dev/null +++ b/client.go @@ -0,0 +1,253 @@ +package ws + +import ( + "bufio" + "context" + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + "time" + + "github.com/gobwas/pool/pbufio" +) + +// ReaderPool describes object that manages reuse of bufio.Reader instances. +type ReaderPool interface { + Get(io.Reader) *bufio.Reader + Put(*bufio.Reader) +} + +// WriterPool describes object that manages reuse of bufio.Writer instances. +type WriterPool interface { + Get(io.Writer) *bufio.Writer + Put(*bufio.Writer) +} + +// Handshake represents handshake result. +type Handshake struct { + // Protocol is the selected during handshake subprotocol. + Protocol string + + // Extensions is the list of negotiated extensions. + Extensions []string +} + +// Response represents result of dialing. +type Response struct { + *http.Response + Handshake +} + +var ( + defaultWriterPool = writerPool(512) + defaultReaderPool = readerPool(512) +) + +// Errors used by the websocket client. +var ( + ErrBadStatus = fmt.Errorf("unexpected http status") + ErrBadSubProtocol = fmt.Errorf("unexpected protocol in %q header", headerSecProtocol) +) + +// Dialer contains options for establishing websocket connection to an url. +type Dialer struct { + // Header is the set of custom headers that will be sent with the request. + Header http.Header + + // Protocol is the list of subprotocol names the client wishes to speak, ordered by preference. + // See https://tools.ietf.org/html/rfc6455#section-4.1 + Protocol []string + + // Extensions is the list of extensions, that client wishes to speak. + // See https://tools.ietf.org/html/rfc6455#section-4.1 + // See https://tools.ietf.org/html/rfc6455#section-9.1 + Extensions []string + + // NetDial is the function that is used to get plain tcp connection. + // If it is not nil, then it is used instead of net.Dialer. + NetDial func(ctx context.Context, network, addr string) (net.Conn, error) + + // NetDialTLS is the function that is used to get plain tcp connection with tls encryption. + // If it is not nil, then it is used instead of tls.DialWithDialer. + NetDialTLS func(ctx context.Context, network, addr string) (net.Conn, error) + + // TLSConfig is passed to tls.DialWithDialer. + TLSConfig *tls.Config + + // WriterPool is used to reuse bufio.Writers. + WriterPool WriterPool + + // ReaderPool is used to reuse bufio.Readers. + ReaderPool ReaderPool +} + +// Dial connects to the url host and handshakes connection to websocket. +func (d Dialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, resp Response, err error) { + req := getRequest() + defer putRequest(req) + + err = req.Reset(urlstr, d.Header, d.Protocol, d.Extensions) + if err != nil { + return + } + + conn, err = d.dial(ctx, req.URL) + if err != nil { + return + } + + resp.Response, err = d.send(ctx, conn, req) + if err != nil { + return + } + + resp.Protocol, resp.Extensions, err = d.handshake(req, resp) + + return +} + +func (d Dialer) dial(ctx context.Context, u *url.URL) (conn net.Conn, err error) { + addr := hostport(u) + if u.Scheme == "wss" { + if nd := d.NetDialTLS; nd != nil { + return nd(ctx, "tcp", addr) + } + + var nd net.Dialer + if deadline, ok := ctx.Deadline(); ok { + nd.Deadline = deadline + } + return tls.DialWithDialer(&nd, "tcp", addr, d.TLSConfig) + } + + if nd := d.NetDial; nd != nil { + return nd(ctx, "tcp", addr) + } + + var nd net.Dialer + return nd.DialContext(ctx, "tcp", addr) +} + +func (d Dialer) send(ctx context.Context, conn net.Conn, req *request) (resp *http.Response, err error) { + type respAndError struct { + resp *http.Response + err error + } + var ( + wp WriterPool + rp ReaderPool + ) + if wp = d.WriterPool; wp == nil { + wp = defaultWriterPool + } + if rp = d.ReaderPool; rp == nil { + rp = defaultReaderPool + } + + bw := wp.Get(conn) + defer wp.Put(bw) + + if err = req.Write(bw); err != nil { + return + } + if err = bw.Flush(); err != nil { + return + } + + br := rp.Get(conn) + defer rp.Put(br) + + if deadline, ok := ctx.Deadline(); ok { + ch := make(chan respAndError, 2) + time.AfterFunc(deadline.Sub(time.Now()), func() { + ch <- respAndError{nil, timeoutError{}} + }) + go func() { + resp, err = http.ReadResponse(br, nil) + ch <- respAndError{resp, err} + }() + r := <-ch + resp, err = r.resp, r.err + return + } + + return http.ReadResponse(br, nil) +} + +func (d Dialer) handshake(req *request, resp Response) (protocol string, extensions []string, err error) { + if resp.StatusCode != 101 { + err = ErrBadStatus + return + } + if upgrade := resp.Header.Get(headerUpgrade); strings.ToLower(upgrade) != "websocket" { + err = ErrBadUpgrade + return + } + if connection := resp.Header.Get(headerConnection); !strings.Contains(strings.ToLower(connection), "upgrade") { + err = ErrBadConnection + return + } + if !checkNonce(resp.Header.Get(headerSecAccept), req.Nonce) { + err = ErrBadSecAccept + return + } + if extensions := resp.Header.Get(headerSecExtensions); extensions != "" { + // TODO(gobwas): implement extensions logic. + } + + // We check single value of Sec-Websocket-Protocol header according to this: + // RFC6455 1.3: "The server selects one or none of the acceptable protocols and echoes + // that value in its handshake to indicate that it has selected that + // protocol." + if protocol = resp.Header.Get(headerSecProtocol); protocol != "" { + var has bool + for _, p := range req.Protocols { + if has = p == protocol; has { + break + } + } + if !has { + err = ErrBadSubProtocol + return + } + } + return +} + +func hostport(u *url.URL) string { + host, port := split2(u.Host, ':') + if port != "" { + return u.Host + } + if u.Scheme == "wss" { + return host + ":443" + } + return host + ":80" +} + +func split2(s string, sep byte) (a, b string) { + if i := strings.LastIndexByte(s, sep); i != -1 { + return s[:i], s[i+1:] + } + return s, "" +} + +type readerPool int + +func (n readerPool) Get(r io.Reader) *bufio.Reader { return pbufio.GetReader(r, int(n)) } +func (n readerPool) Put(r *bufio.Reader) { pbufio.PutReader(r, int(n)) } + +type writerPool int + +func (n writerPool) Get(w io.Writer) *bufio.Writer { return pbufio.GetWriter(w, int(n)) } +func (n writerPool) Put(w *bufio.Writer) { pbufio.PutWriter(w, int(n)) } + +type timeoutError struct{} + +func (timeoutError) Timeout() bool { return true } +func (timeoutError) Temporary() bool { return true } +func (timeoutError) Error() string { return "client timeout" } diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..c20de73 --- /dev/null +++ b/client_test.go @@ -0,0 +1,245 @@ +package ws + +import ( + "bufio" + "bytes" + "context" + "crypto/rand" + "encoding/base64" + "fmt" + "io" + "net" + "net/http" + "net/url" + "testing" + "time" +) + +func TestDialerHandshake(t *testing.T) { + for i, test := range []struct { + res *http.Response + accept bool + protocols []string + err error + }{ + { + res: &http.Response{ + StatusCode: 101, + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{ + headerConnection: []string{"Upgrade"}, + headerUpgrade: []string{"websocket"}, + }, + }, + accept: true, + }, + { + res: &http.Response{ + StatusCode: 101, + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{ + headerConnection: []string{"Upgrade"}, + headerUpgrade: []string{"websocket"}, + headerSecProtocol: []string{"json"}, + }, + }, + protocols: []string{"xml", "json", "soap"}, + accept: true, + }, + { + res: &http.Response{ + StatusCode: 400, + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{ + headerConnection: []string{"Upgrade"}, + headerUpgrade: []string{"websocket"}, + }, + }, + err: ErrBadStatus, + }, + { + res: &http.Response{ + StatusCode: 101, + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{ + headerConnection: []string{"Error"}, + headerUpgrade: []string{"websocket"}, + }, + }, + err: ErrBadConnection, + }, + { + res: &http.Response{ + StatusCode: 101, + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{ + headerConnection: []string{"Upgrade"}, + headerUpgrade: []string{"iproto"}, + }, + }, + err: ErrBadUpgrade, + }, + { + res: &http.Response{ + StatusCode: 101, + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{ + headerConnection: []string{"Upgrade"}, + headerUpgrade: []string{"websocket"}, + }, + }, + accept: false, + err: ErrBadSecAccept, + }, + { + res: &http.Response{ + StatusCode: 101, + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{ + headerConnection: []string{"Upgrade"}, + headerUpgrade: []string{"websocket"}, + headerSecProtocol: []string{"oops"}, + }, + }, + accept: true, + err: ErrBadSubProtocol, + }, + } { + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + rb := &bytes.Buffer{} + wb := &bytes.Buffer{} + wbuf := bufio.NewReader(wb) + + sig := make(chan struct{}) + go func() { + <-sig + req, err := http.ReadRequest(wbuf) + if err != nil { + t.Fatal(err) + } + var key []byte + if test.accept { + rk := req.Header.Get(headerSecKey) + key = []byte(rk) + } else { + key = make([]byte, 24) + rand.Read(key) + } + + accept := makeAccept([]byte(key)) + test.res.Header.Set(headerSecAccept, string(accept)) + test.res.Request = req + test.res.Write(rb) + + sig <- struct{}{} + }() + + conn := &stubConn{ + read: func(p []byte) (int, error) { + <-sig + return rb.Read(p) + }, + write: func(p []byte) (int, error) { + n, err := wb.Write(p) + sig <- struct{}{} + return n, err + }, + close: func() error { return nil }, + } + + pr := stubReadPool{} + pw := stubWritePool{} + + d := Dialer{ + Protocol: test.protocols, + NetDial: func(_ context.Context, _, _ string) (net.Conn, error) { + return conn, nil + }, + ReaderPool: &pr, + WriterPool: &pw, + } + + _, _, err := d.Dial(context.Background(), "ws://gobwas.com") + if test.err != err { + t.Errorf("unexpected error: %v;\n\twant %v", err, test.err) + } + }) + } +} + +func TestHostPortResolve(t *testing.T) { + for _, test := range []struct { + url *url.URL + ret string + }{ + { + url: &url.URL{Host: "example.com", Scheme: "ws"}, + ret: "example.com:80", + }, + { + url: &url.URL{Host: "example.com", Scheme: "wss"}, + ret: "example.com:443", + }, + { + url: &url.URL{Host: "example.com:3000", Scheme: "wss"}, + ret: "example.com:3000", + }, + } { + t.Run(test.url.String(), func(t *testing.T) { + ret := hostport(test.url) + if test.ret != ret { + t.Errorf("expected %s; got %s", test.ret, ret) + } + }) + } +} + +type stubConn struct { + read func([]byte) (int, error) + write func([]byte) (int, error) + close func() error +} + +func (s stubConn) Read(p []byte) (int, error) { return s.read(p) } +func (s stubConn) Write(p []byte) (int, error) { return s.write(p) } +func (s stubConn) Close() error { return s.close() } +func (s stubConn) LocalAddr() net.Addr { return nil } +func (s stubConn) RemoteAddr() net.Addr { return nil } +func (s stubConn) SetDeadline(t time.Time) error { return nil } +func (s stubConn) SetReadDeadline(t time.Time) error { return nil } +func (s stubConn) SetWriteDeadline(t time.Time) error { return nil } + +func makeNonceFrom(bts []byte) (ret [nonceSize]byte) { + base64.StdEncoding.Encode(ret[:], bts) + return +} + +func nonceAsSlice(bts [nonceSize]byte) []byte { + return bts[:] +} + +type stubPool struct { + getCalls int + putCalls int +} + +type stubWritePool struct { + stubPool +} + +func (s *stubWritePool) Get(w io.Writer) *bufio.Writer { s.getCalls++; return bufio.NewWriter(w) } +func (s *stubWritePool) Put(bw *bufio.Writer) { s.putCalls++ } + +type stubReadPool struct { + stubPool +} + +func (s *stubReadPool) Get(r io.Reader) *bufio.Reader { s.getCalls++; return bufio.NewReader(r) } +func (s *stubReadPool) Put(br *bufio.Reader) { s.putCalls++ } diff --git a/example/autobahn/autobahn.go b/example/autobahn/autobahn.go new file mode 100644 index 0000000..1831f6e --- /dev/null +++ b/example/autobahn/autobahn.go @@ -0,0 +1,251 @@ +package main + +import ( + "bytes" + "flag" + "fmt" + "io" + "io/ioutil" + "log" + "net/http" + "os" + + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" +) + +const dir = "./example/autobahn" + +var ( + addr = flag.String("listen", ":9001", "addr to listen") + reports = flag.String("reports", dir+"/reports", "path to reports directory") + static = flag.String("static", dir+"/static", "path to static assets directory") +) + +func main() { + flag.Parse() + + log.Printf("reports dir is set to: %s", *reports) + log.Printf("static dir is set to: %s", *static) + + http.HandleFunc("/", handlerIndex()) + http.HandleFunc("/library", handlerEcho()) + http.HandleFunc("/utils", handlerEcho2()) + http.Handle("/reports/", http.StripPrefix("/reports/", http.FileServer(http.Dir(*reports)))) + + log.Printf("ready to listen on %s", *addr) + log.Fatal(http.ListenAndServe(*addr, nil)) +} + +var ( + CloseInvalidPayload = ws.MustCompileFrame( + ws.NewCloseFrame(ws.StatusInvalidFramePayloadData, ""), + ) + CloseProtocolError = ws.MustCompileFrame( + ws.NewCloseFrame(ws.StatusProtocolError, ""), + ) +) + +func handlerEcho2() func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + conn, _, _, err := ws.Upgrade(r, w, nil) + if err != nil { + log.Printf("upgrade error: %s", err) + return + } + defer conn.Close() + + ch := wsutil.ControlHandler(conn, 0) + + rd := wsutil.NewReader(conn, ws.StateServerSide) + rd.HandleIntermediate(ch) + + for { + var r io.Reader = rd + var ur *wsutil.UTF8Reader + + h, err := rd.Next() + if err != nil { + log.Printf("next reader error: %s", err) + return + } + + switch { + case h.OpCode.IsControl(): + if err = ch(h, rd); err != nil { + log.Print(err) + return + } + continue + + case h.OpCode == ws.OpText: + ur = wsutil.NewUTF8Reader(r) + r = ur + } + + wr := wsutil.NextWriter(conn, h.OpCode, false) + _, err = io.Copy(wr, r) + if err == nil && ur != nil { + err = ur.Close() + } + if err == nil { + err = wr.Flush() + } + if err != nil { + log.Printf("copy error: %s", err) + return + } + } + } +} + +func handlerEcho() func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + conn, _, _, err := ws.Upgrade(r, w, nil) + if err != nil { + log.Printf("upgrade error: %s", err) + return + } + defer conn.Close() + + state := ws.StateServerSide + + textPending := false + utf8Reader := wsutil.NewUTF8Reader(nil) + cipherReader := wsutil.NewCipherReader(nil, nil) + + for { + header, err := ws.ReadHeader(conn) + if err != nil { + log.Printf("read header error: %s", err) + break + } + if err = ws.CheckHeader(header, state); err != nil { + log.Printf("header check error: %s", err) + conn.Write(CloseProtocolError) + return + } + + var r io.Reader + cipherReader.Reset( + io.LimitReader(conn, header.Length), + header.Mask, + ) + r = cipherReader + + var utf8Fin bool + switch header.OpCode { + case ws.OpPing: + header.OpCode = ws.OpPong + header.Mask = nil + ws.WriteHeader(conn, header) + io.CopyN(conn, cipherReader, header.Length) + continue + + case ws.OpPong: + io.CopyN(ioutil.Discard, conn, header.Length) + continue + + case ws.OpClose: + utf8Fin = true + + case ws.OpContinuation: + if textPending { + utf8Reader.SetSource(cipherReader) + r = utf8Reader + } + if header.Fin { + state = state.Clear(ws.StateFragmented) + textPending = false + utf8Fin = true + } + + case ws.OpText: + utf8Reader.Reset(cipherReader) + r = utf8Reader + + if !header.Fin { + state = state.Set(ws.StateFragmented) + textPending = true + } else { + utf8Fin = true + } + + case ws.OpBinary: + if !header.Fin { + state = state.Set(ws.StateFragmented) + } + } + + payload := make([]byte, header.Length) + _, err = io.ReadFull(r, payload) + if err == nil && utf8Fin { + err = utf8Reader.Close() + } + if err != nil { + log.Printf("read payload error: %s", err) + if err == wsutil.ErrInvalidUtf8 { + conn.Write(CloseInvalidPayload) + } else { + conn.Write(ws.CompiledClose) + } + return + } + + if header.OpCode == ws.OpClose { + code, reason := ws.ParseCloseFrameData(payload) + log.Printf("close frame received: %v %v", code, reason) + + if !code.Empty() { + switch { + case code.IsProtocolSpec() && !code.IsProtocolDefined(): + err = fmt.Errorf("close code from spec range is not defined") + default: + err = ws.CheckCloseFrameData(code, reason) + } + if err != nil { + log.Printf("invalid close data: %s", err) + conn.Write(CloseProtocolError) + } else { + ws.WriteFrame(conn, ws.NewCloseFrame(code, "")) + } + return + } + + conn.Write(ws.CompiledClose) + return + } + + header.Mask = nil + ws.WriteHeader(conn, header) + conn.Write(payload) + } + } +} + +func handlerIndex() func(w http.ResponseWriter, r *http.Request) { + index, err := os.Open(*static + "/index.html") + if err != nil { + log.Fatal(err) + } + bts, err := ioutil.ReadAll(index) + if err != nil { + log.Fatal(err) + } + + return func(w http.ResponseWriter, r *http.Request) { + log.Printf("reqeust to %s", r.URL) + switch r.URL.Path { + case "/": + buf := bytes.NewBuffer(bts) + _, err := buf.WriteTo(w) + if err != nil { + log.Printf("write index bytes error: %s", err) + } + case "/favicon.ico": + w.WriteHeader(http.StatusNotFound) + default: + w.WriteHeader(http.StatusNotFound) + } + } +} diff --git a/example/autobahn/fuzzingclient.json b/example/autobahn/fuzzingclient.json new file mode 100644 index 0000000..39aa808 --- /dev/null +++ b/example/autobahn/fuzzingclient.json @@ -0,0 +1,17 @@ + +{ + "outdir": "./example/autobahn/reports/servers", + "servers": [ + { + "agent": "Library", + "url": "ws://127.0.0.1:9001/library" + }, + { + "agent": "Utils", + "url": "ws://127.0.0.1:9001/utils" + } + ], + "cases": ["*"], + "exclude-cases": [], + "exclude-agent-cases": {} +} diff --git a/example/autobahn/static/index.html b/example/autobahn/static/index.html new file mode 100644 index 0000000..77f9265 --- /dev/null +++ b/example/autobahn/static/index.html @@ -0,0 +1,7 @@ + + +

Welcome to WebSocket test server!

+

Ready to Autobahn!

+Reports + + diff --git a/frame.go b/frame.go new file mode 100644 index 0000000..08e4cc7 --- /dev/null +++ b/frame.go @@ -0,0 +1,339 @@ +package ws + +import ( + "bytes" + "encoding/binary" +) + +// Constants defined by specification. +const ( + MaxControlFramePayloadSize = 125 // All control frames must have a payload length of 125 bytes or less. +) + +// OpCode represents operation code. +type OpCode byte + +// Operation codes defined by specification. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +const ( + OpContinuation OpCode = 0x0 + OpText = 0x1 + OpBinary = 0x2 + OpClose = 0x8 + OpPing = 0x9 + OpPong = 0xa +) + +// IsControl checks wheter the c is control operation code. +// See https://tools.ietf.org/html/rfc6455#section-5.5 +func (c OpCode) IsControl() bool { + // RFC6455: Control frames are identified by opcodes where + // the most significant bit of the opcode is 1. + // + // Note that OpCode is only 4 bit length. + return c&0x8 != 0 +} + +// IsData checks wheter the c is data operation code. +// See https://tools.ietf.org/html/rfc6455#section-5.6 +func (c OpCode) IsData() bool { + // RFC6455: Data frames (e.g., non-control frames) are identified by opcodes + // where the most significant bit of the opcode is 0. + // + // Note that OpCode is only 4 bit length. + return c&0x8 == 0 +} + +// IsReserved checks wheter the c is reserved operation code. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +func (c OpCode) IsReserved() bool { + // RFC6455: + // %x3-7 are reserved for further non-control frames + // %xB-F are reserved for further control frames + return (0x3 <= c && c <= 0x7) || (0xb <= c && c <= 0xf) +} + +// StatusCode represents the encoded reason for closure of websocket connection. +// +// There are few helper methods on StatusCode that helps to define a range in +// which given code is lay in. accordingly to ranges defined in specification. +// +// See https://tools.ietf.org/html/rfc6455#section-7.4 +type StatusCode uint16 + +// StatusCodeRange describes range of StatusCode values. +type StatusCodeRange struct { + Min, Max StatusCode +} + +// Status code ranges defined by specification. +// See https://tools.ietf.org/html/rfc6455#section-7.4.2 +var ( + StatusRangeNotInUse = StatusCodeRange{0, 999} + StatusRangeProtocol = StatusCodeRange{1000, 2999} + StatusRangeApplication = StatusCodeRange{3000, 3999} + StatusRangePrivate = StatusCodeRange{4000, 4999} +) + +// Status codes defined by specification. +// See https://tools.ietf.org/html/rfc6455#section-7.4.1 +const ( + StatusNormalClosure StatusCode = 1000 + StatusGoingAway = 1001 + StatusProtocolError = 1002 + StatusUnsupportedData = 1003 + StatusNoMeaningYet = 1004 + StatusNoStatusRcvd = 1005 + StatusAbnormalClosure = 1006 + StatusInvalidFramePayloadData = 1007 + StatusPolicyViolation = 1008 + StatusMessageTooBig = 1009 + StatusMandatoryExt = 1010 + StatusInternalServerError = 1011 + StatusTLSHandshake = 1015 +) + +// In reports whether the code is defined in given range. +func (s StatusCode) In(r StatusCodeRange) bool { + return r.Min <= s && s <= r.Max +} + +// Empty reports wheter the code is empty. +// Empty code has no any meaning neither app level codes nor other. +// This method is useful just to check that code is golang default value 0. +func (s StatusCode) Empty() bool { + return s == 0 +} + +// IsNotUsed reports whether the code is predefined in not used range. +func (s StatusCode) IsNotUsed() bool { + return s.In(StatusRangeNotInUse) +} + +// IsApplicationSpec reports whether the code should be defined by +// application, framework or libraries specification. +func (s StatusCode) IsApplicationSpec() bool { + return s.In(StatusRangeApplication) +} + +// IsPrivateSpec reports whether the code should be defined privately. +func (s StatusCode) IsPrivateSpec() bool { + return s.In(StatusRangePrivate) +} + +// IsProtocolSpec reports whether the code should be defined by protocol specification. +func (s StatusCode) IsProtocolSpec() bool { + return s.In(StatusRangeProtocol) +} + +// IsProtocolDefined reports whether the code is already defined by protocol specification. +func (s StatusCode) IsProtocolDefined() bool { + switch s { + case StatusNormalClosure, + StatusGoingAway, + StatusProtocolError, + StatusUnsupportedData, + StatusInvalidFramePayloadData, + StatusPolicyViolation, + StatusMessageTooBig, + StatusMandatoryExt, + StatusInternalServerError, + StatusNoStatusRcvd, + StatusAbnormalClosure, + StatusTLSHandshake: + return true + } + return false +} + +// IsProtocolReserved reports whether the code is defined by protocol specification +// to be reserved only for application usage purpose. +func (s StatusCode) IsProtocolReserved() bool { + switch s { + // [RFC6455]: {1005,1006,1015} is a reserved value and MUST NOT be set as a status code in a + // Close control frame by an endpoint. + case StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: + return true + default: + return false + } +} + +// Common frames with no special meaning. +var ( + PingFrame = Frame{Header{Fin: true, OpCode: OpPing}, nil} + PongFrame = Frame{Header{Fin: true, OpCode: OpPong}, nil} + CloseFrame = Frame{Header{Fin: true, OpCode: OpClose}, nil} +) + +// Compiled control frames for common use cases. +// For construct-serialize optimizations. +var ( + CompiledPing = MustCompileFrame(PingFrame) + CompiledPong = MustCompileFrame(PongFrame) + CompiledClose = MustCompileFrame(CloseFrame) +) + +// Header represents websocket frame header. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +type Header struct { + Fin bool + Rsv byte + OpCode OpCode + Length int64 + Mask []byte +} + +// Rsv1 reports whether the header has first rsv bit set. +func (h Header) Rsv1() bool { return h.Rsv&bit5 != 0 } + +// Rsv2 reports whether the header has second rsv bit set. +func (h Header) Rsv2() bool { return h.Rsv&bit6 != 0 } + +// Rsv3 reports whether the header has third rsv bit set. +func (h Header) Rsv3() bool { return h.Rsv&bit7 != 0 } + +// Frame represents websocket frame. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +type Frame struct { + Header Header + Payload []byte +} + +// NewFrame creates frame with given operation code, +// flag of completeness and payload bytes. +func NewFrame(op OpCode, fin bool, p []byte) Frame { + return Frame{ + Header: Header{ + Fin: fin, + OpCode: op, + Length: int64(len(p)), + }, + Payload: p, + } +} + +// NewTextFrame creates text frame with s as payload. +// Note that the s is copied in the returned frame payload. +func NewTextFrame(s string) Frame { + p := make([]byte, len(s)) + copy(p, s) + return NewFrame(OpText, true, p) +} + +// NewBinaryFrame creates binary frame with p as payload. +// Note that p is left as is in the returned frame without copying. +func NewBinaryFrame(p []byte) Frame { + return NewFrame(OpBinary, true, p) +} + +// NewPingFrame creates ping frame with p as payload. +// Note that p is left as is in the returned frame without copying. +func NewPingFrame(p []byte) Frame { + return NewFrame(OpPing, true, p) +} + +// NewPongFrame creates pong frame with p as payload. +// Note that p is left as is in the returned frame. +func NewPongFrame(p []byte) Frame { + return NewFrame(OpPong, true, p) +} + +// NewCloseFrame creates close frame with given closure code and reason. +// Note that it crops reason to fit the limit of control frames payload. +// See https://tools.ietf.org/html/rfc6455#section-5.5 +func NewCloseFrame(code StatusCode, reason string) Frame { + return NewFrame(OpClose, true, NewCloseFrameData(code, reason)) +} + +// NewCloseFrameData makes byte representation of code and reason. +// +// Note that returned slice is at most 125 bytes length. +// If reason is too big it will crop it to fit the limit defined by thte spec. +// +// See https://tools.ietf.org/html/rfc6455#section-5.5 +func NewCloseFrameData(code StatusCode, reason string) []byte { + n := min(2+len(reason), MaxControlFramePayloadSize) // 2 is for status code uint16 encoding. + p := make([]byte, n) + PutCloseFrameData(p, code, reason) + return p +} + +// PutCloseFrameData encodes code and reason into buf and returns the number of bytes written. +// If the buffer is too small to accomodate at least code, PutCloseFrameData will panic. +// Note that it does not checks maximum control frame payload size limit. +func PutCloseFrameData(p []byte, code StatusCode, reason string) int { + binary.BigEndian.PutUint16(p, uint16(code)) + n := copy(p[2:], reason) + return n + 2 +} + +// MaskFrame masks frame and returns frame with masked payload and Mask header's field set. +// Note that it copies f payload to prevent collisions. +// For less allocations you could use MaskFrameInplace or construct frame manually. +func MaskFrame(f Frame) Frame { + return MaskFrameWith(f, NewMask()) +} + +// MaskFrameWith masks frame with given mask and returns frame +// with masked payload and Mask header's field set. +// Note that it copies f payload to prevent collisions. +// For less allocations you could use MaskFrameInplaceWith or construct frame manually. +func MaskFrameWith(f Frame, mask []byte) Frame { + p := make([]byte, len(f.Payload)) + copy(p, f.Payload) + f.Payload = p + return MaskFrameInplaceWith(f, mask) +} + +// MaskFrame masks frame and returns frame with masked payload and Mask header's field set. +// Note that it applies xor cipher to f.Payload without copying, that is, it modifies f.Payload inplace. +func MaskFrameInplace(f Frame) Frame { + return MaskFrameInplaceWith(f, NewMask()) +} + +// MaskFrameInplaceWith masks frame with given mask and returns frame +// with masked payload and Mask header's field set. +// Note that it applies xor cipher to f.Payload without copying, that is, it modifies f.Payload inplace. +func MaskFrameInplaceWith(f Frame, mask []byte) Frame { + f.Header.Mask = mask + Cipher(f.Payload, mask, 0) + return f +} + +// NewMask creates new random mask. +func NewMask() []byte { + return randBytes(4) +} + +// CompileFrame returns byte representation of given frame. +// In terms of memory consumption it is useful to precompile static frames which are often used. +func CompileFrame(f Frame) (bts []byte, err error) { + buf := bytes.NewBuffer(make([]byte, 0, 16)) + err = WriteFrame(buf, f) + bts = buf.Bytes() + return +} + +// MustCompileFrame is like CompileFrame but panics if frame cannot be encoded. +func MustCompileFrame(f Frame) []byte { + bts, err := CompileFrame(f) + if err != nil { + panic(err) + } + return bts +} + +// Rsv creates rsv byte representation. +func Rsv(r1, r2, r3 bool) (rsv byte) { + if r1 { + rsv |= bit5 + } + if r2 { + rsv |= bit6 + } + if r3 { + rsv |= bit7 + } + return rsv +} diff --git a/frame_test.go b/frame_test.go new file mode 100644 index 0000000..981257a --- /dev/null +++ b/frame_test.go @@ -0,0 +1,26 @@ +package ws + +import ( + "fmt" + "testing" +) + +func TestOpCodeIsControl(t *testing.T) { + for _, test := range []struct { + code OpCode + exp bool + }{ + {OpClose, true}, + {OpPing, true}, + {OpPong, true}, + {OpBinary, false}, + {OpText, false}, + {OpContinuation, false}, + } { + t.Run(fmt.Sprintf("0x%02x", test.code), func(t *testing.T) { + if act := test.code.IsControl(); act != test.exp { + t.Errorf("IsControl = %v; want %v", act, test.exp) + } + }) + } +} diff --git a/read.go b/read.go new file mode 100644 index 0000000..6771819 --- /dev/null +++ b/read.go @@ -0,0 +1,149 @@ +package ws + +import ( + "encoding/binary" + "fmt" + "io" +) + +const ( + PlatformSizeLimit = int64(^(uint(0)) >> 1) // Max int value for current platform. +) + +// Errors used by frame reader. +var ( + ErrHeaderLengthMSB = fmt.Errorf("header error: the most significant bit must be 0") + ErrHeaderLengthUnexpected = fmt.Errorf("header error: unexpected payload length bits") +) + +// ReadHeader reads a frame header from r. +func ReadHeader(r io.Reader) (h Header, err error) { + // Make slice with 2 bytes len for header, but with 8 byte capacity. + // The most useful case of reading header is to read header from + // client, that is with mask (4 byte) and some length most cases <= uint16 (2 bytes). + // If such case happened, we will reuse bytes without extra allocation. + bts := make([]byte, 2, 8) + + //var hv uint64 + //hp := uintptr(unsafe.Pointer(&hv)) + //bts := *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{Data: hp, Len: 2, Cap: 8})) + + // Prepare to hold first 2 bytes to choose size of next read. + _, err = io.ReadFull(r, bts) + if err != nil { + return + } + + h.Fin = bts[0]&bit0 != 0 + h.Rsv = (bts[0] & 0x70) >> 4 + h.OpCode = OpCode(bts[0] & 0x0f) + + var extra int + + mask := bts[1]&bit0 != 0 + if mask { + extra += 4 + } + + length := bts[1] & 0x7f + switch { + case length < 126: + h.Length = int64(length) + + case length == 126: + extra += 2 + + case length == 127: + extra += 8 + + default: + err = ErrHeaderLengthUnexpected + return + } + + if extra == 0 { + return + } + + if extra <= 8 { + bts = bts[:extra] + } else { + bts = make([]byte, extra) + } + + _, err = io.ReadFull(r, bts) + if err != nil { + return + } + + switch { + case length == 126: + h.Length = int64(binary.BigEndian.Uint16(bts[:2])) + bts = bts[2:] + + case length == 127: + if bts[0]&0x80 != 0 { + err = ErrHeaderLengthMSB + return + } + h.Length = int64(binary.BigEndian.Uint64(bts[:8])) + bts = bts[8:] + } + + if mask { + // TODO(gobwas): move to type Mask uint32 + h.Mask = bts[:4] + } + + return +} + +// ReadFrame reads a frame from r. +// It is not designed for high optimized use case cause it makes allocation +// for frame.Header.Length size inside to read frame payload into. +// +// Note that ReadFrame does not unmask payload. +func ReadFrame(r io.Reader) (f Frame, err error) { + f.Header, err = ReadHeader(r) + if err != nil { + return + } + + if f.Header.Length > 0 { + // int(f.Header.Length) is safe here cause we have + // checked it for overflow above in ReadHeader. + f.Payload = make([]byte, int(f.Header.Length)) + _, err = io.ReadFull(r, f.Payload) + } + + return +} + +// ParseCloseFrameData parses close frame status code and closure reason if any provided. +// If there is no status code in the payload +// the empty status code is returned (code.Empty()) with empty string as a reason. +func ParseCloseFrameData(payload []byte) (code StatusCode, reason string) { + if len(payload) < 2 { + // We returning empty StatusCode here, preventing the situation + // when endpoint really sent code 1005 and we should return ProtocolError on that. + // + // In other words, we ignoring this rule [RFC6455:7.1.5]: + // If this Close control frame contains no status code, _The WebSocket + // Connection Close Code_ is considered to be 1005. + return + } + code = StatusCode(binary.BigEndian.Uint16(payload)) + reason = string(payload[2:]) + return +} + +// ParseCloseFrameDataUnsafe is like ParseCloseFrameData except the thing +// that it does not copies payload bytes into reason, but prepares unsafe cast. +func ParseCloseFrameDataUnsafe(payload []byte) (code StatusCode, reason string) { + if len(payload) < 2 { + return + } + code = StatusCode(binary.BigEndian.Uint16(payload)) + reason = btsToString(payload[2:]) + return +} diff --git a/read_test.go b/read_test.go new file mode 100644 index 0000000..e540c2e --- /dev/null +++ b/read_test.go @@ -0,0 +1,67 @@ +package ws + +import ( + "bytes" + "fmt" + "io" + "reflect" + "testing" +) + +func TestReadHeader(t *testing.T) { + for i, test := range append([]RWCase{ + { + Data: bits("0000 0000 0 1111111 10000000 00000000 00000000 00000000 00000000 00000000 00000000 00000000"), + // _______________________________________________________________________ + // | + // Length value + Err: true, + }, + }, RWCases...) { + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + r := bytes.NewReader(test.Data) + h, err := ReadHeader(r) + if test.Err && err == nil { + t.Errorf("expected error, got nil") + } + if !test.Err && err != nil { + t.Errorf("unexpected error: %s", err) + } + if test.Err { + return + } + if !reflect.DeepEqual(h, test.Header) { + t.Errorf("ReadHeader()\nread:\n\t%#v\nwant:\n\t%#v", h, test.Header) + } + }) + } +} + +func BenchmarkReadHeader(b *testing.B) { + for i, bench := range []struct { + label string + header Header + }{ + {"t", Header{OpCode: OpText, Fin: true}}, + {"t-m", Header{OpCode: OpText, Fin: true, Mask: NewMask()}}, + {"t-m-u16", Header{OpCode: OpText, Fin: true, Length: len16, Mask: NewMask()}}, + {"t-m-u64", Header{OpCode: OpText, Fin: true, Length: len64, Mask: NewMask()}}, + } { + b.Run(fmt.Sprintf("%s#%d", bench.label, i), func(b *testing.B) { + bts := MustCompileFrame(Frame{Header: bench.header}) + rds := make([]io.Reader, b.N) + for i := 0; i < b.N; i++ { + rds[i] = bytes.NewReader(bts) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := ReadHeader(rds[i]) + if err != nil { + b.Fatal(err) + } + } + }) + } +} diff --git a/request.go b/request.go new file mode 100644 index 0000000..5233524 --- /dev/null +++ b/request.go @@ -0,0 +1,190 @@ +package ws + +import ( + "crypto/sha1" + "encoding/base64" + "fmt" + "hash" + "math/rand" + "net/http" + "net/textproto" + "net/url" + "strings" + "sync" +) + +const ( + nonceSize = 24 + acceptSize = 28 + shaSumSize = 20 +) + +var ( + headerUpgrade = textproto.CanonicalMIMEHeaderKey("Upgrade") + headerConnection = textproto.CanonicalMIMEHeaderKey("Connection") + headerHost = textproto.CanonicalMIMEHeaderKey("Host") + headerOrigin = textproto.CanonicalMIMEHeaderKey("Origin") + headerSecVersion = textproto.CanonicalMIMEHeaderKey("Sec-Websocket-Version") + headerSecProtocol = textproto.CanonicalMIMEHeaderKey("Sec-Websocket-Protocol") + headerSecExtensions = textproto.CanonicalMIMEHeaderKey("Sec-Websocket-Extensions") + headerSecKey = textproto.CanonicalMIMEHeaderKey("Sec-Websocket-Key") + headerSecAccept = textproto.CanonicalMIMEHeaderKey("Sec-Websocket-Accept") +) + +var ErrBadNonce = fmt.Errorf("nonce size is not %d", nonceSize) + +var WebSocketMagic = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +type request struct { + http.Request + Nonce [nonceSize]byte + Protocols []string + Extensions []string +} + +func (req *request) Reset(urlstr string, headers http.Header, protocols, extensions []string) error { + u, err := url.ParseRequestURI(urlstr) + if err != nil { + return err + } + + req.URL = u + + newNonce(req.Nonce[:]) + req.Header.Set(headerSecKey, string(req.Nonce[:])) + + req.Protocols = protocols + if protocols != nil { + req.Header.Set(headerSecProtocol, strings.Join(protocols, ", ")) + } + + req.Extensions = extensions + if extensions != nil { + req.Header.Set(headerSecExtensions, strings.Join(extensions, ", ")) + } + + req.Header.Set("User-Agent", "") // Disable default user-agent header. + + if headers != nil { + for k, v := range headers { + req.Header[k] = v + } + } + + return nil +} + +var requestPool sync.Pool + +func getRequest() *request { + if req := requestPool.Get(); req != nil { + return req.(*request) + } + return newCommonRequest() +} + +func putRequest(req *request) { + req.URL = nil + req.Protocols = nil + req.Extensions = nil + + for k := range req.Header { + switch k { + case headerUpgrade, headerConnection, headerSecVersion: + // leave common headers + default: + delete(req.Header, k) + } + } + + requestPool.Put(req) +} + +func newCommonRequest() *request { + req := &request{ + Request: http.Request{ + Header: make(http.Header), + }, + } + + req.Header.Set(headerUpgrade, "websocket") + req.Header.Set(headerConnection, "Upgrade") + req.Header.Set(headerSecVersion, "13") + + return req +} + +var sha1Pool sync.Pool + +func acquireSha1() hash.Hash { + if h := sha1Pool.Get(); h != nil { + return h.(hash.Hash) + } + return sha1.New() +} + +func releaseSha1(h hash.Hash) { + h.Reset() + sha1Pool.Put(h) +} + +// todo bench put expect to req as array +func checkNonce(accept string, nonce [nonceSize]byte) bool { + if len(accept) != 28 { + return false + } + + expect := makeAccept(nonce[:]) + + return string(expect) == accept +} + +//const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" +// +//func newNonce(dest []byte) { +// for i := 0; i < 22; i++ { +// dest[i] = alphabet[rand.Intn(len(alphabet))] +// } +// dest[22] = '=' +// dest[23] = '=' +//} + +func randBytes(n int) []byte { + bts := make([]byte, n) + if _, err := rand.Read(bts); err != nil { + panic(fmt.Sprintf("rand read error: %s", err)) + } + return bts +} + +func newNonce(dest []byte) { + base64.StdEncoding.Encode(dest, randBytes(16)) +} + +func makeAccept(nonce []byte) []byte { + bts := make([]byte, 0, acceptSize+shaSumSize) + n := putAccept(nonce, bts) + return bts[:n] +} + +func putAccept(nonce, buf []byte) int { + if cap(buf) < acceptSize+shaSumSize { + panic(fmt.Sprintf("buffer cap is %d; want at least %d", len(buf), acceptSize+shaSumSize)) + } + if len(nonce) != nonceSize { + panic(fmt.Sprintf("nonce size is %d; want %d", len(nonce), nonceSize)) + } + + sha := acquireSha1() + defer releaseSha1(sha) + + sha.Write(nonce) + sha.Write(WebSocketMagic) + + buf = buf[:acceptSize] + sum := sha.Sum(buf[acceptSize:]) + + base64.StdEncoding.Encode(buf, sum) + + return acceptSize +} diff --git a/request_test.go b/request_test.go new file mode 100644 index 0000000..a8f3962 --- /dev/null +++ b/request_test.go @@ -0,0 +1,143 @@ +package ws + +import ( + "bufio" + "bytes" + "crypto/rand" + "encoding/base64" + "fmt" + "net/http" + "reflect" + "strings" + "testing" +) + +func TestRequestReset(t *testing.T) { + for i, test := range []struct { + url string + expHost string + expPath string + protocols []string + extensions []string + headers http.Header + err bool + }{ + { + url: "wss://websocket.com/chat", + expHost: "websocket.com", + expPath: "/chat", + protocols: []string{"subproto", "hello"}, + extensions: []string{"foo; bar=1", "baz"}, + headers: http.Header{ + "Origin": []string{"https://websocket.com"}, + }, + }, + { + url: "websocket.com/chat", + err: true, + }, + } { + commonHeaders := map[string]string{ + headerConnection: "Upgrade", + headerUpgrade: "websocket", + headerSecVersion: "13", + } + + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + req := getRequest() + err := req.Reset(test.url, test.headers, test.protocols, test.extensions) + if test.err && err == nil { + t.Errorf("expected error; got nil") + } + if !test.err && err != nil { + t.Errorf("unexpected error: %s", err) + } + if test.err { + return + } + + buf := &bytes.Buffer{} + if err = req.Write(buf); err != nil { + t.Errorf("dumping request error: %s", err) + return + } + + r, err := http.ReadRequest(bufio.NewReader(buf)) + if err != nil { + t.Errorf("read request error: %s", err) + return + } + + if r.Method != "GET" { + t.Errorf("http method is %s; want GET", r.Method) + } + if r.ProtoMinor != 1 { + t.Errorf("http proto minor is %d; want 1", r.ProtoMinor) + } + if r.ProtoMajor != 1 { + t.Errorf("http proto major is %d; want 1", r.ProtoMajor) + } + if r.URL.Path != test.expPath { + t.Errorf("http path is %s; want %s", r.URL.Path, test.expPath) + } + + key := r.Header.Get(headerSecKey) + bts, err := base64.StdEncoding.DecodeString(key) + if err != nil { + t.Errorf("bad %q header: %s", headerSecKey, err) + } + if n := len(bts); n != 16 { + t.Errorf("nonce len is %d; want 16", n) + } + r.Header.Del(headerSecKey) + + sub := r.Header.Get(headerSecProtocol) + protocols := strings.Split(sub, ",") + for i, p := range protocols { + protocols[i] = strings.TrimSpace(p) + } + if !reflect.DeepEqual(protocols, test.protocols) { + t.Errorf("%q headers is %v; want %s", headerSecProtocol, protocols, test.protocols) + } + r.Header.Del(headerSecProtocol) + + ext := r.Header.Get(headerSecExtensions) + extensions := strings.Split(ext, ",") + for i, e := range extensions { + extensions[i] = strings.TrimSpace(e) + } + if !reflect.DeepEqual(extensions, test.extensions) { + t.Errorf("%q headers is %v; want %s", headerSecExtensions, extensions, test.extensions) + } + r.Header.Del(headerSecExtensions) + + for key, exp := range commonHeaders { + if act := r.Header.Get(key); act != exp { + t.Errorf("http %q header is %q; want %q", key, act, exp) + } + r.Header.Del(key) + } + for key, exp := range test.headers { + if act := r.Header.Get(key); act != exp[0] { + t.Errorf("http %q custom header is %q; want %q", key, act, exp[0]) + } + r.Header.Del(key) + } + if len(r.Header) != 0 { + t.Errorf("http request has extra headers:\n\t%v", r.Header) + } + }) + } +} + +func BenchmarkMakeAccept(b *testing.B) { + nonce := make([]byte, nonceSize) + _, err := rand.Read(nonce) + if err != nil { + b.Fatal(err) + } + b.StartTimer() + for i := 0; i < b.N; i++ { + _ = makeAccept(nonce) + } +} diff --git a/rw_test.go b/rw_test.go new file mode 100644 index 0000000..4c3342b --- /dev/null +++ b/rw_test.go @@ -0,0 +1,86 @@ +package ws + +import ( + "fmt" + "strings" +) + +type RWCase struct { + Data []byte + Header Header + Err bool +} + +var RWCases = []RWCase{ + { + Data: bits("1 001 0001 0 1100100"), + // _ ___ ____ _ _______ + // | | | | | + // Fin | | Mask Length + // Rsv | + // TextFrame + Header: Header{ + Fin: true, + Rsv: Rsv(false, false, true), + OpCode: OpText, + Length: 100, + Mask: nil, + }, + }, + { + Data: bits("1 001 0001 1 1100100 00000001 10001000 00000000 11111111"), + // _ ___ ____ _ _______ ___________________________________ + // | | | | | | + // Fin | | Mask Length Mask value + // Rsv | + // TextFrame + Header: Header{ + Fin: true, + Rsv: Rsv(false, false, true), + OpCode: OpText, + Length: 100, + Mask: []byte{0x01, 0x88, 0x00, 0xff}, + }, + }, + { + Data: bits("0 110 0010 0 1111110 00001111 11111111"), + // _ ___ ____ _ _______ _________________ + // | | | | | | + // Fin | | Mask Length Length value + // Rsv | + // BinaryFrame + Header: Header{ + Fin: false, + Rsv: Rsv(true, true, false), + OpCode: OpBinary, + Length: 0x0fff, + Mask: nil, + }, + }, + { + Data: bits("1 000 1010 0 1111111 01111111 00000000 00000000 00000000 00000000 00000000 00000000 00000000"), + // _ ___ ____ _ _______ _______________________________________________________________________ + // | | | | | | + // Fin | | Mask Length Length value + // Rsv | + // PongFrame + Header: Header{ + Fin: true, + Rsv: Rsv(false, false, false), + OpCode: OpPong, + Length: 0x7f00000000000000, + Mask: nil, + }, + }, +} + +func bits(s string) []byte { + s = strings.Replace(s, " ", "", -1) + bts := make([]byte, len(s)/8) + + for i, j := 0, 0; i < len(s); i, j = i+8, j+1 { + fmt.Sscanf(s[i:], "%08b", &bts[j]) + } + + return bts +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..49d145f --- /dev/null +++ b/server.go @@ -0,0 +1,220 @@ +package ws + +import ( + "bufio" + "fmt" + "net" + "net/http" + "strings" + _ "unsafe" // for go:linkname +) + +const ( + textUpgrade = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n" + crlf = "\r\n" + colonAndSpace = ": " +) + +// Errors used by upgraders. +var ( + ErrBadHost = fmt.Errorf("bad %q header", headerHost) + ErrBadUpgrade = fmt.Errorf("bad %q header", headerUpgrade) + ErrBadConnection = fmt.Errorf("bad %q header", headerConnection) + ErrBadSecAccept = fmt.Errorf("bad %q header", headerSecAccept) + ErrBadSecKey = fmt.Errorf("bad %q header", headerSecKey) + ErrBadSecVersion = fmt.Errorf("bad %q header", headerSecVersion) + ErrBadHijacker = fmt.Errorf("given http.ResponseWriter is not a http.Hijacker") +) + +// SelectFromSlice creates accept function that could be used as Protocol/Extension +// select function in the UpgradeConfig. +func SelectFromSlice(accept []string) func(string) bool { + if len(accept) > 16 { + mp := make(map[string]struct{}, len(accept)) + for _, p := range accept { + mp[p] = struct{}{} + } + return func(p string) bool { + _, ok := mp[p] + return ok + } + } + return func(p string) bool { + for _, ok := range accept { + if p == ok { + return true + } + } + return false + } +} + +// UpgradeConfig contains options for upgrading http connection to websocket. +type UpgradeConfig struct { + // Header is the set of custom headers that will be sent with the response. + Header http.Header + + // Protocol is the select function that is used to select subprotocol + // from client passed list. + Protocol func(string) bool + + // Extension is the select function that is used to select extensions + // from client passed list. + Extension func(string) bool +} + +// Upgrade upgrades http connection to websocket. +// It hijacks net.Conn from response writer. +// +// If succeed it returns upgraded connection and Handshake struct describing +// handshake info. +func Upgrade(r *http.Request, w http.ResponseWriter, c *UpgradeConfig) (conn net.Conn, rw *bufio.ReadWriter, hs Handshake, err error) { + if r.Host == "" { + err = ErrBadHost + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if u := getHeader(r.Header, headerUpgrade); u != "websocket" && strings.ToLower(u) != "websocket" { + err = ErrBadUpgrade + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if c := getHeader(r.Header, headerConnection); c != "Upgrade" && !hasToken(c, "upgrade") { + err = ErrBadConnection + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if v := getHeader(r.Header, headerSecVersion); v != "13" { + err = ErrBadSecVersion + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + nonce := getHeader(r.Header, headerSecKey) + if len(nonce) != nonceSize { + err = ErrBadSecKey + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + hj, ok := w.(http.Hijacker) + if !ok { + err = ErrBadHijacker + w.WriteHeader(http.StatusInternalServerError) + return + } + conn, rw, err = hj.Hijack() + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + rw.WriteString(textUpgrade) + + accept := makeAccept(strToBytes(nonce)) + writeHeaderKey(rw.Writer, headerSecAccept) + writeHeaderValueBytes(rw.Writer, accept) + + if c != nil { + if p, check := r.Header[headerSecProtocol], c.Protocol; len(p) > 0 && check != nil { + for _, v := range p { + if check(v) { + hs.Protocol = v + writeHeader(rw.Writer, headerSecProtocol, hs.Protocol) + break + } + } + } + // TODO(gobwas) parse extensions. + //if e, check := r.Header[headerSecExtensions], c.Extension; len(e) > 0 && check != nil { + // hs.Extensions = selectExtensions(e, c.Extension) + // if len(hs.Extensions) > 0 { + // writeHeader(rw.Writer, headerSecExtensions, strings.Join(hs.Extensions, ", ")) + // } + //} + for key, values := range c.Header { + for _, val := range values { + writeHeader(rw.Writer, key, val) + } + } + } + + rw.WriteString(crlf) + + err = rw.Flush() + + return +} + +// getHeader is the same as textproto.MIMEHeader.Get, except the thing, +// that key is already canonical. This helps to increase performance. +func getHeader(h http.Header, key string) string { + if h == nil { + return "" + } + v := h[key] + if len(v) == 0 { + return "" + } + return v[0] +} + +func writeHeader(bw *bufio.Writer, key, value string) { + writeHeaderKey(bw, key) + writeHeaderValue(bw, value) +} + +func writeHeaderKey(bw *bufio.Writer, key string) { + bw.WriteString(key) + bw.WriteString(colonAndSpace) +} + +func writeHeaderValue(bw *bufio.Writer, value string) { + bw.WriteString(value) + bw.WriteString(crlf) +} + +func writeHeaderValueBytes(bw *bufio.Writer, value []byte) { + bw.Write(value) + bw.WriteString(crlf) +} + +func hasToken(header, token string) bool { + var pos int + for i := 0; i <= len(header); i++ { + if i == len(header) || header[i] == ',' { + v := strings.TrimSpace(header[pos:i]) + if len(v) == len(token) && strings.ToLower(v) == token { + return true + } + pos = i + 1 + } + } + return false +} + +func selectProtocol(h string, ok func(string) bool) string { + var start int + for i := 0; i < len(h); i++ { + c := h[i] + // The elements that comprise this value MUST be non-empty strings with characters in the range + // U+0021 to U+007E not including separator characters as defined in [RFC2616] + // and MUST all be unique strings. + if c != ',' && '!' <= c && c <= '~' { + continue + } + if str := h[start:i]; len(str) > 0 && ok(str) { + return str + } + start = i + 1 + } + if str := h[start:]; len(str) > 0 && ok(str) { + return str + } + return "" +} + +func selectExtensions(h []string, ok func(string) bool) []string { + // TODO(gobwas): parse extensions with params + return nil +} diff --git a/server_test.go b/server_test.go new file mode 100644 index 0000000..de1e549 --- /dev/null +++ b/server_test.go @@ -0,0 +1,426 @@ +package ws + +import ( + "bufio" + "bytes" + "fmt" + "io" + "math/rand" + "net" + "net/http" + "net/http/httptest" + "net/http/httputil" + "reflect" + "sort" + "strings" + "testing" + _ "unsafe" // for go:linkname +) + +func TestUpgrade(t *testing.T) { + for i, test := range []struct { + nonce []byte + req *http.Request + res *http.Response + hs Handshake + cfg *UpgradeConfig + err error + }{ + { + nonce: mustMakeNonce(), + req: mustMakeRequest("GET", "ws://example.org", http.Header{ + headerUpgrade: []string{"websocket"}, + headerConnection: []string{"Upgrade"}, + headerSecVersion: []string{"13"}, + }), + res: mustMakeResponse(101, http.Header{ + headerUpgrade: []string{"websocket"}, + headerConnection: []string{"Upgrade"}, + }), + cfg: &UpgradeConfig{ + Protocol: func(sub string) bool { + return true + }, + }, + }, + { + nonce: mustMakeNonce(), + req: mustMakeRequest("GET", "ws://example.org", http.Header{ + headerUpgrade: []string{"WEBSOCKET"}, + headerConnection: []string{"UPGRADE"}, + headerSecVersion: []string{"13"}, + }), + res: mustMakeResponse(101, http.Header{ + headerUpgrade: []string{"websocket"}, + headerConnection: []string{"Upgrade"}, + }), + cfg: &UpgradeConfig{ + Protocol: func(sub string) bool { + return true + }, + }, + }, + { + nonce: mustMakeNonce(), + req: mustMakeRequest("GET", "ws://example.org", http.Header{ + headerUpgrade: []string{"websocket"}, + headerConnection: []string{"Upgrade"}, + headerSecVersion: []string{"13"}, + headerSecProtocol: []string{"a", "b", "c", "d"}, + }), + res: mustMakeResponse(101, http.Header{ + headerUpgrade: []string{"websocket"}, + headerConnection: []string{"Upgrade"}, + headerSecProtocol: []string{"b"}, + }), + hs: Handshake{Protocol: "b"}, + cfg: &UpgradeConfig{ + Protocol: SelectFromSlice([]string{"b", "d"}), + }, + }, + // TODO(gobwas) uncomment after selectExtension is ready. + //{ + // nonce: mustMakeNonce(), + // req: mustMakeRequest("GET", "ws://example.org", http.Header{ + // headerUpgrade: []string{"websocket"}, + // headerConnection: []string{"Upgrade"}, + // headerSecVersion: []string{"13"}, + // headerSecExtensions: []string{"a", "b", "c", "d"}, + // }), + // res: mustMakeResponse(101, http.Header{ + // headerUpgrade: []string{"websocket"}, + // headerConnection: []string{"Upgrade"}, + // headerSecExtensions: []string{"b", "d"}, + // }), + // hs: Handshake{Extensions: ["b", "d"]}, + // cfg: &UpgradeConfig{ + // Extension: SelectFromSlice([]string{"b", "d"}), + // }, + //}, + } { + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + if test.nonce != nil { + test.req.Header.Set(headerSecKey, string(test.nonce)) + test.res.Header.Set(headerSecAccept, string(makeAccept(test.nonce))) + } + + res := newRecorder() + _, _, hs, err := Upgrade(test.req, res, test.cfg) + if test.err != err { + t.Errorf("expected error to be '%v', got '%v'", test.err, err) + return + } + + actRespBts := sortHeaders(res.Bytes()) + expRespBts := sortHeaders(dumpResponse(test.res)) + if !bytes.Equal(actRespBts, expRespBts) { + t.Errorf( + "unexpected http response:\n---- act:\n%s\n---- want:\n%s\n====", + actRespBts, expRespBts, + ) + return + } + + if !reflect.DeepEqual(hs, test.hs) { + t.Errorf("unexpected handshake: %#v; want %#v", hs, test.hs) + } + }) + } +} + +func BenchmarkUpgrade(b *testing.B) { + bts101 := []byte("HTTP/1.1 101") + for _, bench := range []struct { + label string + req *http.Request + cfg *UpgradeConfig + }{ + { + label: "base", + req: mustMakeRequest("GET", "ws://example.org", http.Header{ + headerUpgrade: []string{"websocket"}, + headerConnection: []string{"Upgrade"}, + headerSecVersion: []string{"13"}, + headerSecKey: []string{string(mustMakeNonce())}, + }), + cfg: &UpgradeConfig{ + Protocol: func(sub string) bool { + return true + }, + }, + }, + { + label: "uppercase", + req: mustMakeRequest("GET", "ws://example.org", http.Header{ + headerUpgrade: []string{"WEBSOCKET"}, + headerConnection: []string{"UPGRADE"}, + headerSecVersion: []string{"13"}, + headerSecKey: []string{string(mustMakeNonce())}, + }), + cfg: &UpgradeConfig{ + Protocol: func(sub string) bool { + return true + }, + }, + }, + } { + b.Run(bench.label, func(b *testing.B) { + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + res := newRecorder() + _, _, _, err := Upgrade(bench.req, res, bench.cfg) + if err != nil { + b.Fatal(err) + } + if !bytes.HasPrefix(res.Body.Bytes(), bts101) { + b.Fatalf("unexpected http status code: %v\n%s", res.Code, res.Body.String()) + } + } + }) + }) + } +} + +func TestSelectProtocol(t *testing.T) { + for i, test := range []struct { + header string + }{ + {"jsonrpc, soap, grpc"}, + } { + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + exp := strings.Split(test.header, ",") + for i, p := range exp { + exp[i] = strings.TrimSpace(p) + } + + var calls []string + selectProtocol(test.header, func(s string) bool { + calls = append(calls, s) + return false + }) + + if !reflect.DeepEqual(calls, exp) { + t.Errorf("selectProtocol(%q, fn); called fn with %v; want %v", test.header, calls, exp) + } + }) + } +} + +func TestHasToken(t *testing.T) { + for i, test := range []struct { + header string + token string + exp bool + }{ + {"Keep-Alive, Close, Upgrade", "upgrade", true}, + {"Keep-Alive, Close, upgrade, hello", "upgrade", true}, + {"Keep-Alive, Close, hello", "upgrade", false}, + } { + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + if has := hasToken(test.header, test.token); has != test.exp { + t.Errorf("hasToken(%q, %q) = %v; want %v", test.header, test.token, has, test.exp) + } + }) + } +} + +func BenchmarkHasToken(b *testing.B) { + for i, bench := range []struct { + header string + token string + }{ + {"Keep-Alive, Close, Upgrade", "upgrade"}, + {"Keep-Alive, Close, upgrade, hello", "upgrade"}, + {"Keep-Alive, Close, hello", "upgrade"}, + } { + b.Run(fmt.Sprintf("#%d", i), func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = hasToken(bench.header, bench.token) + } + }) + } +} + +func TestSelectExtensions(t *testing.T) { + +} + +func BenchmarkSelectProtocol(b *testing.B) { + for _, bench := range []struct { + label string + header string + accept func(string) bool + }{ + { + label: "never accept", + header: "jsonrpc, soap, grpc", + accept: func(s string) bool { + return len(s)%2 == 2 // never ok + }, + }, + { + label: "from slice", + header: "a, b, c, d, e, f, g", + accept: SelectFromSlice([]string{"g", "f", "e", "d"}), + }, + { + label: "uniq 1024 from slise", + header: strings.Join(randProtocols(1024, 16), ", "), + accept: SelectFromSlice(randProtocols(1024, 17)), + }, + } { + b.Run(fmt.Sprintf("#%s_optimized", bench.label), func(b *testing.B) { + for i := 0; i < b.N; i++ { + selectProtocol(bench.header, bench.accept) + } + }) + } +} + +func randProtocols(n, m int) []string { + ret := make([]string, n) + bts := make([]byte, m) + uniq := map[string]bool{} + for i := 0; i < n; i++ { + for { + for j := 0; j < m; j++ { + bts[j] = byte(rand.Intn('x'-'a') + 'a') + } + str := string(bts) + if _, has := uniq[str]; !has { + ret[i] = str + break + } + } + } + return ret +} +func dumpRequest(req *http.Request) []byte { + bts, err := httputil.DumpRequest(req, true) + if err != nil { + panic(err) + } + return bts +} + +func dumpResponse(res *http.Response) []byte { + cleanClose := !res.Close + if cleanClose { + for _, v := range res.Header[headerConnection] { + if v == "close" { + cleanClose = false + break + } + } + } + + bts, err := httputil.DumpResponse(res, true) + if err != nil { + panic(err) + } + + if cleanClose { + bts = bytes.Replace(bts, []byte("Connection: close\r\n"), nil, -1) + } + + return bts +} + +type headersBytes [][]byte + +func (h headersBytes) Len() int { return len(h) } +func (h headersBytes) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +func (h headersBytes) Less(i, j int) bool { return string(h[i]) < string(h[j]) } + +func sortHeaders(bts []byte) []byte { + lines := bytes.Split(bts, []byte("\r\n")) + if len(lines) <= 1 { + return bts + } + sort.Sort(headersBytes(lines[1 : len(lines)-2])) + return bytes.Join(lines, []byte("\r\n")) +} + +type recorder struct { + *httptest.ResponseRecorder + hijacked bool +} + +func newRecorder() *recorder { + return &recorder{ + ResponseRecorder: httptest.NewRecorder(), + } +} + +func (r *recorder) Bytes() []byte { + if r.hijacked { + return r.ResponseRecorder.Body.Bytes() + } + return dumpResponse(r.Result()) +} + +//go:linkname httpPutBufioReader net/http.putBufioReader +func httpPutBufioReader(*bufio.Reader) + +//go:linkname httpPutBufioWriter net/http.putBufioWriter +func httpPutBufioWriter(*bufio.Writer) + +//go:linkname httpNewBufioReader net/http.newBufioReader +func httpNewBufioReader(io.Reader) *bufio.Reader + +//go:linkname httpNewBufioWriterSize net/http.newBufioWriterSize +func httpNewBufioWriterSize(io.Writer, int) *bufio.Writer + +func (r *recorder) Hijack() (conn net.Conn, brw *bufio.ReadWriter, err error) { + if r.hijacked { + err = fmt.Errorf("already hijacked") + return + } + + r.hijacked = true + + buf := r.ResponseRecorder.Body + + conn = stubConn{ + read: buf.Read, + write: buf.Write, + close: func() error { return nil }, + } + + // Use httpNewBufio* linked functions here to make + // benchmark more closer to real life usage. + br := httpNewBufioReader(buf) + bw := httpNewBufioWriterSize(buf, 4<<10) + + brw = bufio.NewReadWriter(br, bw) + + return +} + +func mustMakeRequest(method, url string, headers http.Header) *http.Request { + req, err := http.NewRequest(method, url, nil) + if err != nil { + panic(err) + } + req.Header = headers + return req +} + +func mustMakeResponse(code int, headers http.Header) *http.Response { + res := &http.Response{ + StatusCode: code, + Status: http.StatusText(code), + Header: headers, + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: -1, + } + return res +} + +func mustMakeNonce() []byte { + b := make([]byte, nonceSize) + newNonce(b) + return b +} diff --git a/server_test.s b/server_test.s new file mode 100644 index 0000000..e69de29 diff --git a/util.go b/util.go new file mode 100644 index 0000000..e16abc5 --- /dev/null +++ b/util.go @@ -0,0 +1,25 @@ +package ws + +import ( + "reflect" + "unsafe" +) + +func strToBytes(str string) []byte { + s := *(*reflect.StringHeader)(unsafe.Pointer(&str)) + b := &reflect.SliceHeader{Data: s.Data, Len: s.Len, Cap: s.Len} + return *(*[]byte)(unsafe.Pointer(b)) +} + +func btsToString(bts []byte) string { + b := *(*reflect.SliceHeader)(unsafe.Pointer(&bts)) + s := &reflect.StringHeader{Data: b.Data, Len: b.Len} + return *(*string)(unsafe.Pointer(s)) +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/write.go b/write.go new file mode 100644 index 0000000..a3056be --- /dev/null +++ b/write.go @@ -0,0 +1,82 @@ +package ws + +import ( + "encoding/binary" + "io" +) + +const ( + bit0 = 0x80 + bit1 = 0x40 + bit2 = 0x20 + bit3 = 0x10 + bit4 = 0x08 + bit5 = 0x04 + bit6 = 0x02 + bit7 = 0x01 + + len16 = int64(^(uint16(0))) + len64 = int64(^(uint64(0)) >> 1) +) + +func WriteHeader(w io.Writer, h Header) error { + size := 2 + + var lenByte byte + switch { + case h.Length < 126: + lenByte = byte(h.Length) + size += 0 + + case h.Length <= len16: + lenByte = 126 + size += 2 + + case h.Length <= len64: + lenByte = 127 + size += 8 + + default: + return ErrHeaderLengthUnexpected + } + + if h.Mask != nil { + lenByte |= bit0 + size += 4 + } + + bts := make([]byte, size) + + if h.Fin { + bts[0] |= bit0 + } + bts[0] |= h.Rsv << 4 + bts[0] |= byte(h.OpCode) + bts[1] = lenByte + + maskPos := 2 // after fin, rsv and op code byte and length byte. + switch { + case lenByte == 126: + binary.BigEndian.PutUint16(bts[2:], uint16(h.Length)) + maskPos += 2 + case lenByte == 127: + binary.BigEndian.PutUint64(bts[2:], uint64(h.Length)) + maskPos += 8 + } + + if h.Mask != nil { + copy(bts[maskPos:], h.Mask) + } + + _, err := w.Write(bts) + return err +} + +func WriteFrame(w io.Writer, f Frame) error { + err := WriteHeader(w, f.Header) + if err != nil { + return err + } + _, err = w.Write(f.Payload) + return err +} diff --git a/write_test.go b/write_test.go new file mode 100644 index 0000000..041bf7c --- /dev/null +++ b/write_test.go @@ -0,0 +1,73 @@ +package ws + +import ( + "bytes" + "fmt" + "io/ioutil" + "testing" +) + +func TestWriteHeader(t *testing.T) { + for i, test := range RWCases { + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + buf := &bytes.Buffer{} + err := WriteHeader(buf, test.Header) + if test.Err && err == nil { + t.Errorf("expected error, got nil") + } + if !test.Err && err != nil { + t.Errorf("unexpected error: %s", err) + } + if test.Err { + return + } + if bts := buf.Bytes(); !bytes.Equal(bts, test.Data) { + t.Errorf("WriteHeader()\nwrote:\n\t%08b\nwant:\n\t%08b", bts, test.Data) + } + }) + } +} + +func BenchmarkWriteHeader(b *testing.B) { + for _, bench := range []struct { + label string + header Header + }{ + { + "ping", Header{ + OpCode: OpPing, + Fin: true, + }, + }, + { + "text16", Header{ + OpCode: OpText, + Fin: true, + Length: int64(^(uint16(0))), + }, + }, + { + "text64", Header{ + OpCode: OpText, + Fin: true, + Length: int64(^(uint64(0)) >> 1), + }, + }, + { + "text64mask", Header{ + OpCode: OpText, + Fin: true, + Length: int64(^(uint64(0)) >> 1), + Mask: []byte("mask"), + }, + }, + } { + b.Run(bench.label, func(b *testing.B) { + for i := 0; i < b.N; i++ { + if err := WriteHeader(ioutil.Discard, bench.header); err != nil { + b.Fatal(err) + } + } + }) + } +} diff --git a/wsutil/cipher.go b/wsutil/cipher.go new file mode 100644 index 0000000..f0c9d3a --- /dev/null +++ b/wsutil/cipher.go @@ -0,0 +1,30 @@ +package wsutil + +import ( + "io" + + "github.com/gobwas/ws" +) + +type CipherReader struct { + r io.Reader + mask []byte + pos int +} + +func NewCipherReader(r io.Reader, mask []byte) *CipherReader { + return &CipherReader{r, mask, 0} +} + +func (c *CipherReader) Reset(r io.Reader, mask []byte) { + c.r = r + c.mask = mask + c.pos = 0 +} + +func (c *CipherReader) Read(p []byte) (n int, err error) { + n, err = c.r.Read(p) + ws.Cipher(p[:n], c.mask, c.pos) + c.pos += n + return +} diff --git a/wsutil/cipher_test.go b/wsutil/cipher_test.go new file mode 100644 index 0000000..7fcd03c --- /dev/null +++ b/wsutil/cipher_test.go @@ -0,0 +1,50 @@ +package wsutil + +import ( + "bytes" + "fmt" + "io/ioutil" + "reflect" + "testing" + + . "github.com/gobwas/ws" +) + +func TestCipherReader(t *testing.T) { + for i, test := range []struct { + label string + data []byte + chop int + }{ + { + label: "simple", + data: []byte("hello, websockets!"), + chop: 512, + }, + { + label: "chopped", + data: []byte("hello, websockets!"), + chop: 3, + }, + } { + t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) { + mask := NewMask() + masked := make([]byte, len(test.data)) + copy(masked, test.data) + Cipher(masked, mask, 0) + + src := &chopReader{bytes.NewReader(masked), test.chop} + rd := NewCipherReader(src, mask) + + bts, err := ioutil.ReadAll(rd) + if err != nil { + t.Errorf("unexpected error: %s", err) + return + } + if !reflect.DeepEqual(bts, test.data) { + t.Errorf("read data is not equal:\n\tact:\t%#v\n\texp:\t%#x\n", bts, test.data) + return + } + }) + } +} diff --git a/wsutil/reader.go b/wsutil/reader.go new file mode 100644 index 0000000..f611af5 --- /dev/null +++ b/wsutil/reader.go @@ -0,0 +1,155 @@ +package wsutil + +import ( + "io" + "io/ioutil" + "strconv" + + "github.com/gobwas/pool/pbytes" + "github.com/gobwas/ws" +) + +type FrameHandler func(h ws.Header, r io.Reader) error + +func ControlHandler(w io.Writer, state ws.State) FrameHandler { + return func(h ws.Header, rd io.Reader) (err error) { + // int(h.Length) is safe cause control frame could be < 125 bytes length. + p := pbytes.GetBufLen(int(h.Length)) + defer pbytes.PutBuf(p) + + _, err = io.ReadFull(rd, p) + if err != nil { + return + } + + var f ws.Frame + + switch h.OpCode { + default: + return + case ws.OpPing: + f = ws.NewPongFrame(p) + case ws.OpClose: + code, reason := ws.ParseCloseFrameDataUnsafe(p) + if code.Empty() { + code = ws.StatusNoStatusRcvd + f = ws.CloseFrame + } else if err = ws.CheckCloseFrameData(code, reason); err != nil { + code = ws.StatusProtocolError + reason = err.Error() + f = ws.NewCloseFrame(code, reason) + } else { + // [RFC6455:5.5.1]: + // If an endpoint receives a Close frame and did not previously + // send a Close frame, the endpoint MUST send a Close frame in + // response. (When sending a Close frame in response, the endpoint + // typically echos the status code it received.) + f = ws.NewCloseFrame(code, "") + } + err = ErrClosed{code, reason} + } + + if state.Is(ws.StateClientSide) { + f = ws.MaskFrame(f) + } + if ew := ws.WriteFrame(w, f); ew != nil { + err = ew + } + + return + } +} + +func NextReader(r io.Reader, s ws.State) (h ws.Header, rd *Reader, err error) { + rd = NewReader(r, s) + h, err = rd.Next() + return +} + +type handler struct { + continuation FrameHandler + intermediate FrameHandler +} + +type Reader struct { + src io.Reader + state ws.State + payload io.Reader + handler handler +} + +type ErrClosed struct { + code ws.StatusCode + reason string +} + +func (err ErrClosed) Error() string { + return "ws closed: " + strconv.FormatUint(uint64(err.code), 10) + " " + err.reason +} + +func NewReader(r io.Reader, s ws.State) *Reader { + return &Reader{ + src: r, + state: s, + } +} + +func (r *Reader) HandleContinuation(h FrameHandler) { r.handler.continuation = h } +func (r *Reader) HandleIntermediate(h FrameHandler) { r.handler.intermediate = h } + +func (r *Reader) Read(p []byte) (n int, err error) { + if r.payload == nil { + _, err = r.Next() + if err != nil || r.payload == nil { + return + } + } + + n, err = r.payload.Read(p) + + if err == io.EOF { + r.payload = nil + if r.state.Is(ws.StateFragmented) { + err = nil + } + } + + return +} + +func (r *Reader) Next() (h ws.Header, err error) { + h, err = ws.ReadHeader(r.src) + if err != nil { + return + } + if err = ws.CheckHeader(h, r.state); err != nil { + return + } + + src := io.LimitReader(r.src, h.Length) + rd := src + if mask := h.Mask; mask != nil { + rd = NewCipherReader(rd, mask) + } + + if r.state.Is(ws.StateFragmented) && h.OpCode.IsControl() { + if hi := r.handler.intermediate; hi != nil { + err = hi(h, rd) + } + if err == nil { + _, err = io.Copy(ioutil.Discard, src) + } + return + } + + if h.OpCode == ws.OpContinuation { + if hc := r.handler.continuation; hc != nil { + err = hc(h, rd) + } + } + + r.state = r.state.SetOrClearIf(!h.Fin, ws.StateFragmented) + r.payload = rd + + return +} diff --git a/wsutil/reader_test.go b/wsutil/reader_test.go new file mode 100644 index 0000000..4c8f199 --- /dev/null +++ b/wsutil/reader_test.go @@ -0,0 +1,140 @@ +package wsutil + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "testing" + + . "github.com/gobwas/ws" +) + +func TestReader(t *testing.T) { + for i, test := range []struct { + label string + seq []Frame + chop int + exp []byte + err error + }{ + { + label: "empty", + seq: []Frame{}, + err: io.EOF, + }, + { + label: "single", + seq: []Frame{ + NewTextFrame("Привет, Мир!"), + }, + exp: []byte("Привет, Мир!"), + }, + { + label: "single_masked", + seq: []Frame{ + MaskFrame(NewTextFrame("Привет, Мир!")), + }, + exp: []byte("Привет, Мир!"), + }, + { + label: "fragmented", + seq: []Frame{ + NewFrame(OpText, false, []byte("Привет,")), + NewFrame(OpContinuation, false, []byte(" о дивный,")), + NewFrame(OpContinuation, false, []byte(" новый ")), + NewFrame(OpContinuation, true, []byte("Мир!")), + + NewTextFrame("Hello, Brave New World!"), + }, + exp: []byte("Привет, о дивный, новый Мир!"), + }, + { + label: "fragmented_masked", + seq: []Frame{ + MaskFrame(NewFrame(OpText, false, []byte("Привет,"))), + MaskFrame(NewFrame(OpContinuation, false, []byte(" о дивный,"))), + MaskFrame(NewFrame(OpContinuation, false, []byte(" новый "))), + MaskFrame(NewFrame(OpContinuation, true, []byte("Мир!"))), + + MaskFrame(NewTextFrame("Hello, Brave New World!")), + }, + exp: []byte("Привет, о дивный, новый Мир!"), + }, + { + label: "fragmented_and_control", + seq: []Frame{ + NewFrame(OpText, false, []byte("Привет,")), + NewFrame(OpPing, true, nil), + NewFrame(OpContinuation, false, []byte(" о дивный,")), + NewFrame(OpPing, true, nil), + NewFrame(OpContinuation, false, []byte(" новый ")), + NewFrame(OpPing, true, nil), + NewFrame(OpPing, true, []byte("ping info")), + NewFrame(OpContinuation, true, []byte("Мир!")), + }, + exp: []byte("Привет, о дивный, новый Мир!"), + }, + { + label: "fragmented_and_control_mask", + seq: []Frame{ + MaskFrame(NewFrame(OpText, false, []byte("Привет,"))), + MaskFrame(NewFrame(OpPing, true, nil)), + MaskFrame(NewFrame(OpContinuation, false, []byte(" о дивный,"))), + MaskFrame(NewFrame(OpPing, true, nil)), + MaskFrame(NewFrame(OpContinuation, false, []byte(" новый "))), + MaskFrame(NewFrame(OpPing, true, nil)), + MaskFrame(NewFrame(OpPing, true, []byte("ping info"))), + MaskFrame(NewFrame(OpContinuation, true, []byte("Мир!"))), + }, + exp: []byte("Привет, о дивный, новый Мир!"), + }, + } { + t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) { + // Prepare input. + buf := &bytes.Buffer{} + for _, f := range test.seq { + if err := WriteFrame(buf, f); err != nil { + t.Fatal(err) + } + } + + conn := &chopReader{ + src: bytes.NewReader(buf.Bytes()), + sz: test.chop, + } + + var bts []byte + _, reader, err := NextReader(conn, 0) + if err == nil { + bts, err = ioutil.ReadAll(reader) + } + if err != test.err { + t.Errorf("unexpected error; got %v; want %v", err, test.err) + return + } + if test.err == nil && !bytes.Equal(bts, test.exp) { + t.Errorf( + "ReadAll from reader:\nact:\t%#x\nexp:\t%#x\nact:\t%s\nexp:\t%s\n", + bts, test.exp, string(bts), string(test.exp), + ) + } + }) + } +} + +type chopReader struct { + src io.Reader + sz int +} + +func (c chopReader) Read(p []byte) (n int, err error) { + sz := c.sz + if sz == 0 { + sz = 1 + } + if sz > len(p) { + sz = len(p) + } + return c.src.Read(p[:sz]) +} diff --git a/wsutil/utf8.go b/wsutil/utf8.go new file mode 100644 index 0000000..e304b81 --- /dev/null +++ b/wsutil/utf8.go @@ -0,0 +1,116 @@ +package wsutil + +import ( + "fmt" + "io" +) + +var ErrInvalidUtf8 = fmt.Errorf("invalid utf8") + +type UTF8Reader struct { + r io.Reader + state uint32 + codep uint32 + buf []byte +} + +func NewUTF8Reader(r io.Reader) *UTF8Reader { + return &UTF8Reader{r: r} +} + +func (u *UTF8Reader) Reset(r io.Reader) { + u.state = 0 + u.codep = 0 + u.SetSource(r) +} + +func (u *UTF8Reader) SetSource(r io.Reader) { + u.r = r +} + +func (u *UTF8Reader) Read(p []byte) (n int, err error) { + n, err = u.r.Read(p) + + s, c := u.state, u.codep + for i := 0; i < n; i++ { + c, s = decode(s, c, p[i]) + if s == utf8Reject { + u.state = s + return i, ErrInvalidUtf8 + } + } + u.state, u.codep = s, c + + if err == io.EOF && u.state != utf8Accept { + err = ErrInvalidUtf8 + } + + return +} + +func (u *UTF8Reader) Close() error { + if u.state != utf8Accept { + return ErrInvalidUtf8 + } + return nil +} + +// Below is port of UTF-8 decoder from http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ +// +// Copyright (c) 2008-2009 Bjoern Hoehrmann +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal in the Software without restriction, including without limitation the +// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +// sell copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +// IN THE SOFTWARE. + +const ( + utf8Accept = 0 + utf8Reject = 12 +) + +var utf8d = [...]byte{ + // The first part of the table maps bytes to character classes that + // to reduce the size of the transition table and create bitmasks. + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 10, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3, 11, 6, 6, 6, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + + // The second part is a transition table that maps a combination + // of a state of the automaton and a character class to a state. + 0, 12, 24, 36, 60, 96, 84, 12, 12, 12, 48, 72, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 12, 0, 12, 12, 12, 12, 12, 0, 12, 0, 12, 12, 12, 24, 12, 12, 12, 12, 12, 24, 12, 24, 12, 12, + 12, 12, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 12, 12, 24, 12, 12, + 12, 12, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12, 12, 36, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12, + 12, 36, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, +} + +func decode(state, codep uint32, b byte) (uint32, uint32) { + t := uint32(utf8d[b]) + + if state != utf8Accept { + codep = (uint32(b) & 0x3f) | (codep << 6) + } else { + codep = (0xff >> t) & uint32(b) + } + + return codep, uint32(utf8d[256+state+t]) +} diff --git a/wsutil/utf8_test.go b/wsutil/utf8_test.go new file mode 100644 index 0000000..9a4d8fa --- /dev/null +++ b/wsutil/utf8_test.go @@ -0,0 +1,208 @@ +package wsutil + +import ( + "bytes" + "encoding/hex" + "fmt" + "io" + "io/ioutil" + "testing" +) + +func TestUTF8ReaderReadFull(t *testing.T) { + for i, test := range []struct { + hex string + errRead bool + errClose bool + n int + chop int + }{ + { + hex: "cebae1bdb9cf83cebcceb5eda080656469746564", + errClose: true, + errRead: true, + n: 12, + }, + { + hex: "cebae1bdb9cf83cebcceb5eda080656469746564", + errRead: true, + errClose: true, + n: 12, + chop: 1, + }, + { + hex: "7f7f7fdf", + errRead: false, + errClose: true, + n: 4, + }, + { + hex: "dfbf", + n: 2, + }, + } { + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + bts, err := hex.DecodeString(test.hex) + if err != nil { + t.Fatal(err) + } + + chop := test.chop + if chop <= 0 { + chop = len(bts) + } + + src := bytes.NewReader(bts) + r := NewUTF8Reader(chopReader{src, chop}) + + p := make([]byte, src.Len()) + n, err := io.ReadFull(r, p) + + if test.errRead && err == nil { + t.Errorf("expected read error; got nil") + } + if !test.errRead && err != nil { + t.Errorf("unexpected read error: %s", err) + } + if n != test.n { + t.Errorf("ReadFull() read %d; want %d", n, test.n) + } + + err = r.Close() + if test.errClose && err == nil { + t.Errorf("expected close error; got nil") + } + if !test.errClose && err != nil { + t.Errorf("unexpected close error: %s", err) + } + }) + } +} + +func TestUTF8Reader(t *testing.T) { + for i, test := range []struct { + label string + + data []byte + // or + hex string + + chop int + + err bool + at int + }{ + { + data: []byte("hello, world!"), + chop: 2, + }, + { + data: []byte{0x7f, 0xf0}, + err: true, + at: 2, + chop: 1, + }, + { + data: []byte{0x7f, 0xf0}, + err: true, + at: 2, + chop: 1, + }, + { + hex: "48656c6c6f2dc2b540c39fc3b6c3a4c3bcc3a0c3a12d5554462d382121", + chop: 1, + }, + } { + t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) { + data := test.data + if h := test.hex; h != "" { + var err error + if data, err = hex.DecodeString(h); err != nil { + t.Fatal(err) + } + } + + cr := &chopReader{ + src: bytes.NewReader(data), + sz: test.chop, + } + + r := NewUTF8Reader(cr) + + bts := make([]byte, 2*len(data)) + + var ( + i, n int + err error + ) + for { + n, err = r.Read(bts[i:]) + i += n + if err != nil { + if err == io.EOF { + err = nil + } + bts = bts[:i] + break + } + } + if err == nil { + err = r.Close() + } + if test.err && err == nil { + t.Errorf("want error; got nil") + return + } + if !test.err && err != nil { + t.Errorf("unexpected error: %s", err) + return + } + if test.err && err == ErrInvalidUtf8 && i != test.at { + t.Errorf("received error at %d; want at %d", i, test.at) + return + } + if !test.err && !bytes.Equal(bts, data) { + t.Errorf("bytes are not equal") + } + }) + } +} + +func BenchmarkUTF8Reader(b *testing.B) { + for i, bench := range []struct { + label string + data []byte + chop int + err bool + }{ + { + data: bytes.Repeat([]byte("x"), 1024), + chop: 128, + }, + { + data: append( + bytes.Repeat([]byte("x"), 1024), + append( + []byte{0x7f, 0xf0}, + bytes.Repeat([]byte("x"), 128)..., + )..., + ), + err: true, + chop: 7, + }, + } { + b.Run(fmt.Sprintf("%s#%d", bench.label, i), func(b *testing.B) { + for i := 0; i < b.N; i++ { + cr := &chopReader{ + src: bytes.NewReader(bench.data), + sz: bench.chop, + } + r := NewUTF8Reader(cr) + _, err := ioutil.ReadAll(r) + if !bench.err && err != nil { + b.Fatal(err) + } + } + }) + } +} diff --git a/wsutil/writer.go b/wsutil/writer.go new file mode 100644 index 0000000..39fee88 --- /dev/null +++ b/wsutil/writer.go @@ -0,0 +1,162 @@ +package wsutil + +import ( + "io" + + "github.com/gobwas/pool/pbytes" + "github.com/gobwas/ws" +) + +const defaultWriteBuffer = 4096 + +type WriterConfig struct { + Op ws.OpCode + Mask bool +} + +type Writer struct { + wr io.Writer + buf []byte + n int + + dirty bool + frames int + + op ws.OpCode + mask bool +} + +func NextWriter(dst io.Writer, op ws.OpCode, mask bool) *Writer { + return NewWriterSize(dst, 0, WriterConfig{Op: op, Mask: mask}) +} + +func NewWriter(dst io.Writer, c WriterConfig) *Writer { + return NewWriterSize(dst, defaultWriteBuffer, c) +} + +func NewWriterSize(dst io.Writer, n int, c WriterConfig) *Writer { + if n <= 0 { + n = defaultWriteBuffer + } + return NewWriterBuffer(dst, make([]byte, n), c) +} + +func NewWriterBuffer(wr io.Writer, buf []byte, c WriterConfig) *Writer { + return &Writer{ + wr: wr, + buf: buf, + op: c.Op, + mask: c.Mask, + } +} + +func (w *Writer) Write(p []byte) (n int, err error) { + // Even if len(p) == 0 we mark w as dirty, + // cause even empty p (and empty frame) may have a value. + w.dirty = true + + if len(p) > len(w.buf) && w.n == 0 { + // Large write. + return w.write(p) + } + for { + nn := copy(w.buf[w.n:], p) + p = p[nn:] + w.n += nn + n += nn + + if len(p) == 0 { + break + } + + _, err = w.write(w.buf) + if err != nil { + break + } + w.n = 0 + } + return +} + +func (w *Writer) ReadFrom(src io.Reader) (n int64, err error) { + var nn int + for { + if w.n == len(w.buf) { // buffer is full. + if _, err = w.write(w.buf); err != nil { + return + } + w.n = 0 + } + + nn, err = src.Read(w.buf[w.n:]) + w.n += nn + n += int64(nn) + w.dirty = true + + if err != nil { + break + } + } + if err == io.EOF { + err = nil + } + return +} + +func (w *Writer) Flush() error { + _, err := w.flush() + return err +} + +func (w *Writer) opCode() ws.OpCode { + if w.frames > 0 { + return ws.OpContinuation + } else { + return w.op + } +} + +func (w *Writer) flush() (n int, err error) { + if w.n == 0 && !w.dirty { + return 0, nil + } + + n, err = w.writeFrame(w.opCode(), w.buf[:w.n], true) + w.dirty = false + w.n = 0 + w.frames = 0 + + return +} + +func (w *Writer) write(p []byte) (n int, err error) { + return w.writeFrame(w.opCode(), p, false) +} + +func (w *Writer) writeFrame(op ws.OpCode, p []byte, fin bool) (n int, err error) { + header := ws.Header{ + OpCode: op, + Length: int64(len(p)), + Fin: fin, + } + + payload := p + if w.mask { + header.Mask = ws.NewMask() + + payload = pbytes.GetBufLen(len(p)) + defer pbytes.PutBuf(payload) + + copy(payload, p) + ws.Cipher(payload, header.Mask, 0) + } + + err = ws.WriteHeader(w.wr, header) + if err == nil { + n, err = w.wr.Write(payload) + } + + w.frames++ + + return +} diff --git a/wsutil/writer_test.go b/wsutil/writer_test.go new file mode 100644 index 0000000..1247be3 --- /dev/null +++ b/wsutil/writer_test.go @@ -0,0 +1,252 @@ +package wsutil + +import ( + "bytes" + "fmt" + "io" + "reflect" + "testing" + + . "github.com/gobwas/ws" +) + +func TestWriter(t *testing.T) { + for i, test := range []struct { + label string + config WriterConfig + size int + data [][]byte + expFrm []Frame + expBts []byte + }{ + { + config: WriterConfig{Op: OpText}, + }, + { + config: WriterConfig{Op: OpText}, + data: [][]byte{ + []byte{}, + }, + expBts: MustCompileFrame(NewTextFrame("")), + }, + { + config: WriterConfig{Op: OpText}, + data: [][]byte{ + []byte("hello, world!"), + }, + expBts: MustCompileFrame(NewTextFrame("hello, world!")), + }, + { + config: WriterConfig{Op: OpText, Mask: true}, + data: [][]byte{ + []byte("hello, world!"), + }, + expFrm: []Frame{NewTextFrame("hello, world!")}, + }, + { + config: WriterConfig{Op: OpText}, + size: 5, + data: [][]byte{ + []byte("hello"), + []byte(", wor"), + []byte("ld!"), + }, + expBts: bytes.Join( + bts( + MustCompileFrame(Frame{ + Header: Header{ + Fin: false, + OpCode: OpText, + Length: 5, + }, + Payload: []byte("hello"), + }), + MustCompileFrame(Frame{ + Header: Header{ + Fin: false, + OpCode: OpContinuation, + Length: 5, + }, + Payload: []byte(", wor"), + }), + MustCompileFrame(Frame{ + Header: Header{ + Fin: true, + OpCode: OpContinuation, + Length: 3, + }, + Payload: []byte("ld!"), + }), + ), + nil, + ), + }, + { // Large write case. + config: WriterConfig{Op: OpText}, + size: 5, + data: [][]byte{ + []byte("hello, world!"), + }, + expBts: bytes.Join( + bts( + MustCompileFrame(Frame{ + Header: Header{ + Fin: false, + OpCode: OpText, + Length: 13, + }, + Payload: []byte("hello, world!"), + }), + MustCompileFrame(Frame{ + Header: Header{ + Fin: true, + OpCode: OpContinuation, + Length: 0, + }, + }), + ), + nil, + ), + }, + } { + t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) { + buf := &bytes.Buffer{} + w := NewWriterSize(buf, test.size, test.config) + + for _, p := range test.data { + _, err := w.Write(p) + if err != nil { + t.Fatalf("unexpected Write() error: %s", err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("unexpected Flush() error: %s", err) + } + if test.expBts != nil { + if bts := buf.Bytes(); !bytes.Equal(test.expBts, bts) { + t.Errorf( + "wrote bytes:\nact:\t%#x\nexp:\t%#x\nacth:\t%s\nexph:\t%s\n", bts, test.expBts, + pretty(frames(bts)), pretty(frames(test.expBts)), + ) + } + } + if test.expFrm != nil { + act := omitMask(frames(buf.Bytes())) + exp := omitMask(test.expFrm) + + if !reflect.DeepEqual(act, exp) { + t.Errorf( + "wrote frames (mask omitted):\nact:\t%s\nexp:\t%s\n", + pretty(act), pretty(exp), + ) + } + } + }) + } +} + +func TestWriterReadFrom(t *testing.T) { + for i, test := range []struct { + label string + chop int + size int + data []byte + exp []Frame + n int64 + }{ + { + chop: 1, + size: 1, + data: []byte("golang"), + exp: []Frame{ + Frame{Header: Header{Fin: false, Length: 1, OpCode: OpText}, Payload: []byte{'g'}}, + Frame{Header: Header{Fin: false, Length: 1, OpCode: OpContinuation}, Payload: []byte{'o'}}, + Frame{Header: Header{Fin: false, Length: 1, OpCode: OpContinuation}, Payload: []byte{'l'}}, + Frame{Header: Header{Fin: false, Length: 1, OpCode: OpContinuation}, Payload: []byte{'a'}}, + Frame{Header: Header{Fin: false, Length: 1, OpCode: OpContinuation}, Payload: []byte{'n'}}, + Frame{Header: Header{Fin: false, Length: 1, OpCode: OpContinuation}, Payload: []byte{'g'}}, + Frame{Header: Header{Fin: true, Length: 0, OpCode: OpContinuation}}, + }, + n: 6, + }, + { + chop: 1, + size: 4, + data: []byte("golang"), + exp: []Frame{ + Frame{Header: Header{Fin: false, Length: 4, OpCode: OpText}, Payload: []byte("gola")}, + Frame{Header: Header{Fin: true, Length: 2, OpCode: OpContinuation}, Payload: []byte("ng")}, + }, + n: 6, + }, + { + size: 64, + data: []byte{}, + exp: []Frame{ + Frame{Header: Header{Fin: true, Length: 0, OpCode: OpText}}, + }, + n: 0, + }, + } { + t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) { + dst := &bytes.Buffer{} + wr := NewWriterSize(dst, test.size, WriterConfig{Op: OpText}) + + chop := test.chop + if chop == 0 { + chop = 128 + } + src := &chopReader{bytes.NewReader(test.data), chop} + + n, err := wr.ReadFrom(src) + if err == nil { + err = wr.Flush() + } + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if n != test.n { + t.Errorf("ReadFrom() read out %d; want %d", n, test.n) + } + if frames := frames(dst.Bytes()); !reflect.DeepEqual(frames, test.exp) { + t.Errorf("ReadFrom() read frames:\n\tact:\t%s\n\texp:\t%s\n", pretty(frames), pretty(test.exp)) + } + }) + } +} + +func frames(p []byte) (ret []Frame) { + r := bytes.NewReader(p) + for stop := false; !stop; { + f, err := ReadFrame(r) + if err != nil { + if err == io.EOF { + break + } + panic(err) + + } + if mask := f.Header.Mask; mask != nil { + Cipher(f.Payload, mask, 0) + } + ret = append(ret, f) + } + return +} + +func pretty(f []Frame) string { + str := "\n" + for _, f := range f { + str += fmt.Sprintf("\t%#v\n\t%#x (%s)\n\t----\n", f.Header, f.Payload, f.Payload) + } + return str +} + +func omitMask(f []Frame) []Frame { + for i := 0; i < len(f); i++ { + f[i].Header.Mask = []byte{0, 0, 0, 0} + } + return f +} + +func bts(b ...[]byte) [][]byte { return b }