From 8a8329be6851394829fe7b50f97f580e2b86f0d0 Mon Sep 17 00:00:00 2001 From: misu Date: Thu, 18 Jan 2018 01:41:59 +0900 Subject: [PATCH 01/52] modify: header option --- server.go | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/server.go b/server.go index 9a9dfebb..6f949d7f 100644 --- a/server.go +++ b/server.go @@ -7,6 +7,7 @@ package websocket import ( "bufio" "errors" + "fmt" "net" "net/http" "net/url" @@ -142,14 +143,25 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade subprotocol := u.selectSubprotocol(r, responseHeader) // Negotiate PMCE - var compress bool + var ( + compress bool + contextTakeover bool + ) if u.EnableCompression { for _, ext := range parseExtensions(r.Header) { - if ext[""] != "permessage-deflate" { + // map[string]string{"":"permessage-deflate", "client_max_window_bits":""} + // context-takeoverをclient_max_window_bitsから判定する + fmt.Printf("%#v\n", ext) + if ext[""] == "permessage-deflate" { + compress = true + continue + } + + if _, ok := ext["client_max_window_bits"]; ok { + // Todo: validation. window size level only allow 15. + contextTakeover = true continue } - compress = true - break } } @@ -177,6 +189,9 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade c.subprotocol = subprotocol if compress { + if contextTakeover { + + } c.newCompressionWriter = compressNoContextTakeover c.newDecompressionReader = decompressNoContextTakeover } From 053a62020704aad3013ac768f0e160ffccab485a Mon Sep 17 00:00:00 2001 From: misu Date: Fri, 19 Jan 2018 01:43:04 +0900 Subject: [PATCH 02/52] add: validation --- server.go | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/server.go b/server.go index 6f949d7f..3953c9e5 100644 --- a/server.go +++ b/server.go @@ -11,6 +11,7 @@ import ( "net" "net/http" "net/url" + "strconv" "strings" "time" ) @@ -150,15 +151,21 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade if u.EnableCompression { for _, ext := range parseExtensions(r.Header) { // map[string]string{"":"permessage-deflate", "client_max_window_bits":""} - // context-takeoverをclient_max_window_bitsから判定する + // detect context-takeover from client_max_window_bits fmt.Printf("%#v\n", ext) if ext[""] == "permessage-deflate" { compress = true continue } - if _, ok := ext["client_max_window_bits"]; ok { - // Todo: validation. window size level only allow 15. + if level, ok := ext["client_max_window_bits"]; ok { + l, err := strconv.Atoi(level) + if err != nil { + return nil, errors.New(err.Error()) + } + if l != 15 { + return u.returnError(w, r, http.StatusBadRequest, "client_max_window_bits level only allow 15.") + } contextTakeover = true continue } From 785fa70cb3b11738f45b29b18acacefa17f15b5f Mon Sep 17 00:00:00 2001 From: misu Date: Sun, 21 Jan 2018 15:10:50 +0900 Subject: [PATCH 03/52] add server header option --- server.go | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/server.go b/server.go index 3953c9e5..d34bd2a8 100644 --- a/server.go +++ b/server.go @@ -11,7 +11,6 @@ import ( "net" "net/http" "net/url" - "strconv" "strings" "time" ) @@ -152,22 +151,12 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade for _, ext := range parseExtensions(r.Header) { // map[string]string{"":"permessage-deflate", "client_max_window_bits":""} // detect context-takeover from client_max_window_bits - fmt.Printf("%#v\n", ext) if ext[""] == "permessage-deflate" { compress = true - continue } - if level, ok := ext["client_max_window_bits"]; ok { - l, err := strconv.Atoi(level) - if err != nil { - return nil, errors.New(err.Error()) - } - if l != 15 { - return u.returnError(w, r, http.StatusBadRequest, "client_max_window_bits level only allow 15.") - } + if _, ok := ext["client_max_window_bits"]; ok { contextTakeover = true - continue } } } @@ -196,11 +185,15 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade c.subprotocol = subprotocol if compress { - if contextTakeover { - + switch { + case contextTakeover: + fmt.Println("contextTakeover strategy is set...") + c.newCompressionWriter = compressContextTakeover + c.newDecompressionReader = decompressContextTakeover + default: + c.newCompressionWriter = compressNoContextTakeover + c.newDecompressionReader = decompressNoContextTakeover } - c.newCompressionWriter = compressNoContextTakeover - c.newDecompressionReader = decompressNoContextTakeover } p := c.writeBuf[:0] @@ -213,7 +206,12 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade p = append(p, "\r\n"...) } if compress { - p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) + switch { + case contextTakeover: + p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_max_window_bits=15; client_max_window_bits=15\r\n"...) + default: + p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) + } } for k, vs := range responseHeader { if k == "Sec-Websocket-Protocol" { From 9177fe8645a7a14acfbbc220e8753d0effec9d9d Mon Sep 17 00:00:00 2001 From: misu Date: Sun, 21 Jan 2018 15:11:52 +0900 Subject: [PATCH 04/52] mod: comment out --- server.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/server.go b/server.go index d34bd2a8..8c20621d 100644 --- a/server.go +++ b/server.go @@ -7,7 +7,6 @@ package websocket import ( "bufio" "errors" - "fmt" "net" "net/http" "net/url" @@ -187,7 +186,6 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade if compress { switch { case contextTakeover: - fmt.Println("contextTakeover strategy is set...") c.newCompressionWriter = compressContextTakeover c.newDecompressionReader = decompressContextTakeover default: From e0da4e377f2ba19af8898209fdab7e3cd4184c6d Mon Sep 17 00:00:00 2001 From: misu Date: Sun, 21 Jan 2018 17:56:13 +0900 Subject: [PATCH 05/52] mod: .gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index ac710204..b13d2d78 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,5 @@ _testmain.go *.exe .idea/ -*.iml \ No newline at end of file +*.iml +.vscode/ \ No newline at end of file From 36c43970ee5272c5f8a6f1dcf5229348062fcb04 Mon Sep 17 00:00:00 2001 From: misu Date: Wed, 24 Jan 2018 16:52:47 +0900 Subject: [PATCH 06/52] impl: decompressContextTakeover --- compression.go | 30 ++++++++++++++++++++++++++++-- conn.go | 42 ++++++++++++++++++++++++++++++++++++++---- server.go | 1 + 3 files changed, 67 insertions(+), 6 deletions(-) diff --git a/compression.go b/compression.go index 813ffb1e..64cbb447 100644 --- a/compression.go +++ b/compression.go @@ -25,7 +25,7 @@ var ( }} ) -func decompressNoContextTakeover(r io.Reader) io.ReadCloser { +func decompressNoContextTakeover(r io.Reader, b []byte) io.ReadCloser { const tail = // Add four bytes as specified in RFC "\x00\x00\xff\xff" + @@ -37,6 +37,18 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser { return &flateReadWrapper{fr} } +func decompressContextTakeover(r io.Reader, dict []byte) io.ReadCloser { + const tail = + // Add four bytes as specified in RFC + "\x00\x00\xff\xff" + + // Add final block to squelch unexpected EOF error from flate reader. + "\x01\x00\x00\xff\xff" + + fr, _ := flateReaderPool.Get().(io.ReadCloser) + fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), dict) + return &flateReadWrapper{fr} +} + func isValidCompressionLevel(level int) bool { return minCompressionLevel <= level && level <= maxCompressionLevel } @@ -53,6 +65,18 @@ func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { return &flateWriteWrapper{fw: fw, tw: tw, p: p} } +func compressContextTakeover(w io.WriteCloser, level int) io.WriteCloser { + p := &flateWriterPools[level-minCompressionLevel] + tw := &truncWriter{w: w} + fw, _ := p.Get().(*flate.Writer) + if fw == nil { + fw, _ = flate.NewWriter(tw, level) + } else { + fw.Reset(tw) + } + return &flateWriteWrapper{fw: fw, tw: tw, p: p} +} + // truncWriter is an io.Writer that writes all but the last four bytes of the // stream to another io.Writer. type truncWriter struct { @@ -120,14 +144,16 @@ func (w *flateWriteWrapper) Close() error { } type flateReadWrapper struct { - fr io.ReadCloser + fr io.ReadCloser // flate.NewReader } func (r *flateReadWrapper) Read(p []byte) (int, error) { if r.fr == nil { return 0, io.ErrClosedPipe } + n, err := r.fr.Read(p) + if err == io.EOF { // Preemptively place the reader back in the pool. This helps with // scenarios where the application does not call NextReader() soon after diff --git a/conn.go b/conn.go index cd3569d5..cf4a9af2 100644 --- a/conn.go +++ b/conn.go @@ -38,6 +38,8 @@ const ( continuationFrame = 0 noFrame = -1 + + maxWindowBits = 1 << 15 ) // Close codes defined in RFC 6455, section 11.7. @@ -259,8 +261,12 @@ type Conn struct { readErrCount int messageReader *messageReader // the current low-level reader - readDecompress bool // whether last read frame had RSV1 set - newDecompressionReader func(io.Reader) io.ReadCloser + readDecompress bool // whether last read frame had RSV1 set + newDecompressionReader func(io.Reader, []byte) io.ReadCloser // arges may flateReadWrapper struct + + contextTakeover bool + dict []byte + mutex sync.RWMutex } func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { @@ -945,9 +951,14 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { if frameType == TextMessage || frameType == BinaryMessage { c.messageReader = &messageReader{c} c.reader = c.messageReader - if c.readDecompress { - c.reader = c.newDecompressionReader(c.reader) + + switch { + case c.readDecompress && c.contextTakeover: + c.reader = c.newDecompressionReader(c.reader, c.dict) + case c.readDecompress: + c.reader = c.newDecompressionReader(c.reader, nil) } + return frameType, c.reader, nil } } @@ -974,9 +985,11 @@ func (r *messageReader) Read(b []byte) (int, error) { for c.readErr == nil { if c.readRemaining > 0 { + // Determine the size of the data to be read. if int64(len(b)) > c.readRemaining { b = b[:c.readRemaining] } + n, err := c.br.Read(b) c.readErr = hideTempErr(err) if c.isServer { @@ -986,6 +999,7 @@ func (r *messageReader) Read(b []byte) (int, error) { if c.readRemaining > 0 && c.readErr == io.EOF { c.readErr = errUnexpectedEOF } + return n, c.readErr } @@ -1023,6 +1037,12 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { return messageType, nil, err } p, err = ioutil.ReadAll(r) + + // if context-takeover add payload to dictionary + if c.contextTakeover { + c.AddDict(p) + } + return messageType, p, err } @@ -1139,6 +1159,20 @@ func (c *Conn) SetCompressionLevel(level int) error { return nil } +func (c *Conn) AddDict(b []byte) { + c.mutex.Lock() + defer c.mutex.Unlock() + + // Todo I do not know whether to leave the dictionary with 32768 bytes or more + // If it is recognized as a duplicate character string, + // deleting a part of the character may make it impossible to decrypt it. + c.dict = append(b, c.dict...) + + if len(c.dict) > maxWindowBits { + c.dict = c.dict[:maxWindowBits] + } +} + // FormatCloseMessage formats closeCode and text as a WebSocket close message. // An empty message is returned for code CloseNoStatusReceived. func FormatCloseMessage(closeCode int, text string) []byte { diff --git a/server.go b/server.go index 8c20621d..67c82b7d 100644 --- a/server.go +++ b/server.go @@ -186,6 +186,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade if compress { switch { case contextTakeover: + c.contextTakeover = contextTakeover c.newCompressionWriter = compressContextTakeover c.newDecompressionReader = decompressContextTakeover default: From a366cdf6168ee6d9c12d2babca11f8e7b2a4cf0e Mon Sep 17 00:00:00 2001 From: misu Date: Wed, 24 Jan 2018 17:53:38 +0900 Subject: [PATCH 07/52] impl: compressContextTakeover --- compression.go | 15 +++++++-------- conn.go | 14 +++++++++++--- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/compression.go b/compression.go index 64cbb447..4c5f0aad 100644 --- a/compression.go +++ b/compression.go @@ -53,7 +53,7 @@ func isValidCompressionLevel(level int) bool { return minCompressionLevel <= level && level <= maxCompressionLevel } -func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { +func compressNoContextTakeover(w io.WriteCloser, level int, dict []byte) io.WriteCloser { p := &flateWriterPools[level-minCompressionLevel] tw := &truncWriter{w: w} fw, _ := p.Get().(*flate.Writer) @@ -65,15 +65,14 @@ func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { return &flateWriteWrapper{fw: fw, tw: tw, p: p} } -func compressContextTakeover(w io.WriteCloser, level int) io.WriteCloser { +func compressContextTakeover(w io.WriteCloser, level int, dict []byte) io.WriteCloser { p := &flateWriterPools[level-minCompressionLevel] tw := &truncWriter{w: w} - fw, _ := p.Get().(*flate.Writer) - if fw == nil { - fw, _ = flate.NewWriter(tw, level) - } else { - fw.Reset(tw) - } + + // WriterDict's Reset just restores the dictionary. + // Initialization is done with New. (If possible get struct from sync.Pool) + fw, _ := flate.NewWriterDict(tw, level, dict) + return &flateWriteWrapper{fw: fw, tw: tw, p: p} } diff --git a/conn.go b/conn.go index cf4a9af2..c740e013 100644 --- a/conn.go +++ b/conn.go @@ -243,7 +243,7 @@ type Conn struct { enableWriteCompression bool compressionLevel int - newCompressionWriter func(io.WriteCloser, int) io.WriteCloser + newCompressionWriter func(io.WriteCloser, int, []byte) io.WriteCloser // Read fields reader io.ReadCloser // the current reader returned to the application @@ -505,9 +505,14 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { } c.writer = mw if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { - w := c.newCompressionWriter(c.writer, c.compressionLevel) mw.compress = true - c.writer = w + switch { + case c.contextTakeover: + c.writer = c.newCompressionWriter(c.writer, c.compressionLevel, c.dict) + // no-context-takeover + default: + c.writer = c.newCompressionWriter(c.writer, c.compressionLevel, nil) + } } return c.writer, nil } @@ -758,6 +763,9 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error { if _, err = w.Write(data); err != nil { return err } + if c.contextTakeover { + c.AddDict(data) + } return w.Close() } From f7abc95255075dcca87e133541db7ca42cf40104 Mon Sep 17 00:00:00 2001 From: misu Date: Wed, 24 Jan 2018 18:01:48 +0900 Subject: [PATCH 08/52] mod: method args name --- compression.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compression.go b/compression.go index 4c5f0aad..d2084ec7 100644 --- a/compression.go +++ b/compression.go @@ -25,7 +25,7 @@ var ( }} ) -func decompressNoContextTakeover(r io.Reader, b []byte) io.ReadCloser { +func decompressNoContextTakeover(r io.Reader, dict []byte) io.ReadCloser { const tail = // Add four bytes as specified in RFC "\x00\x00\xff\xff" + From eec63bf9bc327b88b0fa34461b897bfd5913cfae Mon Sep 17 00:00:00 2001 From: misu Date: Wed, 24 Jan 2018 18:08:43 +0900 Subject: [PATCH 09/52] mod: split dict to rx and tx --- conn.go | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/conn.go b/conn.go index c740e013..46af7726 100644 --- a/conn.go +++ b/conn.go @@ -265,7 +265,8 @@ type Conn struct { newDecompressionReader func(io.Reader, []byte) io.ReadCloser // arges may flateReadWrapper struct contextTakeover bool - dict []byte + txDict []byte + rxDict []byte mutex sync.RWMutex } @@ -508,7 +509,7 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { mw.compress = true switch { case c.contextTakeover: - c.writer = c.newCompressionWriter(c.writer, c.compressionLevel, c.dict) + c.writer = c.newCompressionWriter(c.writer, c.compressionLevel, c.txDict) // no-context-takeover default: c.writer = c.newCompressionWriter(c.writer, c.compressionLevel, nil) @@ -764,7 +765,7 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error { return err } if c.contextTakeover { - c.AddDict(data) + c.AddTxDict(data) } return w.Close() } @@ -962,7 +963,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { switch { case c.readDecompress && c.contextTakeover: - c.reader = c.newDecompressionReader(c.reader, c.dict) + c.reader = c.newDecompressionReader(c.reader, c.rxDict) case c.readDecompress: c.reader = c.newDecompressionReader(c.reader, nil) } @@ -1048,7 +1049,7 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { // if context-takeover add payload to dictionary if c.contextTakeover { - c.AddDict(p) + c.AddRxDict(p) } return messageType, p, err @@ -1167,17 +1168,31 @@ func (c *Conn) SetCompressionLevel(level int) error { return nil } -func (c *Conn) AddDict(b []byte) { +func (c *Conn) AddTxDict(b []byte) { c.mutex.Lock() defer c.mutex.Unlock() // Todo I do not know whether to leave the dictionary with 32768 bytes or more // If it is recognized as a duplicate character string, // deleting a part of the character may make it impossible to decrypt it. - c.dict = append(b, c.dict...) + c.txDict = append(b, c.txDict...) - if len(c.dict) > maxWindowBits { - c.dict = c.dict[:maxWindowBits] + if len(c.txDict) > maxWindowBits { + c.txDict = c.txDict[:maxWindowBits] + } +} + +func (c *Conn) AddRxDict(b []byte) { + c.mutex.Lock() + defer c.mutex.Unlock() + + // Todo I do not know whether to leave the dictionary with 32768 bytes or more + // If it is recognized as a duplicate character string, + // deleting a part of the character may make it impossible to decrypt it. + c.rxDict = append(b, c.rxDict...) + + if len(c.rxDict) > maxWindowBits { + c.rxDict = c.rxDict[:maxWindowBits] } } From dfe35b789be4bdddad6d92f52c7d9095bb788bd6 Mon Sep 17 00:00:00 2001 From: misu Date: Wed, 24 Jan 2018 20:00:09 +0900 Subject: [PATCH 10/52] add: context-takeover option to client --- client.go | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/client.go b/client.go index 934e28e9..45363fdb 100644 --- a/client.go +++ b/client.go @@ -78,6 +78,11 @@ type Dialer struct { // takeover" modes are supported. EnableCompression bool + // EnableContextTakeover specifies specifies if the client should attempt to negotiate + // per message compression with context-takeover (RFC 7692). + // but window bits is allowed only 15, because go's flate library support 15 bits only. + EnableContextTakeover bool + // Jar specifies the cookie jar. // If Jar is nil, cookies are not sent in requests and ignored // in responses. @@ -196,7 +201,10 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } } - if d.EnableCompression { + switch { + case d.EnableCompression && d.EnableContextTakeover: + req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_max_window_bits=15; client_max_window_bits=15") + case d.EnableCompression: req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") } @@ -307,13 +315,20 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re if ext[""] != "permessage-deflate" { continue } - _, snct := ext["server_no_context_takeover"] - _, cnct := ext["client_no_context_takeover"] - if !snct || !cnct { - return nil, resp, errInvalidCompression + + _, cmwb := ext["client_max_window_bits"] + _, smwb := ext["server_max_window_bits"] + + switch { + case cmwb && smwb: + conn.contextTakeover = true + conn.newCompressionWriter = compressContextTakeover + conn.newDecompressionReader = decompressContextTakeover + default: + conn.newCompressionWriter = compressNoContextTakeover + conn.newDecompressionReader = decompressNoContextTakeover } - conn.newCompressionWriter = compressNoContextTakeover - conn.newDecompressionReader = decompressNoContextTakeover + break } From cecdf7ced452efe19c46e5d5f19f341292b8d368 Mon Sep 17 00:00:00 2001 From: misu Date: Wed, 24 Jan 2018 23:09:03 +0900 Subject: [PATCH 11/52] mod: remove mutex. --- conn.go | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/conn.go b/conn.go index 46af7726..46c53840 100644 --- a/conn.go +++ b/conn.go @@ -267,7 +267,6 @@ type Conn struct { contextTakeover bool txDict []byte rxDict []byte - mutex sync.RWMutex } func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { @@ -1168,13 +1167,8 @@ func (c *Conn) SetCompressionLevel(level int) error { return nil } +// AddTxDict adds payload to txDict. func (c *Conn) AddTxDict(b []byte) { - c.mutex.Lock() - defer c.mutex.Unlock() - - // Todo I do not know whether to leave the dictionary with 32768 bytes or more - // If it is recognized as a duplicate character string, - // deleting a part of the character may make it impossible to decrypt it. c.txDict = append(b, c.txDict...) if len(c.txDict) > maxWindowBits { @@ -1182,13 +1176,8 @@ func (c *Conn) AddTxDict(b []byte) { } } +// AddTxDict adds payload to rxDict. func (c *Conn) AddRxDict(b []byte) { - c.mutex.Lock() - defer c.mutex.Unlock() - - // Todo I do not know whether to leave the dictionary with 32768 bytes or more - // If it is recognized as a duplicate character string, - // deleting a part of the character may make it impossible to decrypt it. c.rxDict = append(b, c.rxDict...) if len(c.rxDict) > maxWindowBits { From 80a809717daf405db0272dfa2f427c491eba9de9 Mon Sep 17 00:00:00 2001 From: misu Date: Thu, 25 Jan 2018 13:55:52 +0900 Subject: [PATCH 12/52] mod: add dict process --- conn.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/conn.go b/conn.go index 46c53840..6ac7e755 100644 --- a/conn.go +++ b/conn.go @@ -1169,19 +1169,21 @@ func (c *Conn) SetCompressionLevel(level int) error { // AddTxDict adds payload to txDict. func (c *Conn) AddTxDict(b []byte) { - c.txDict = append(b, c.txDict...) + c.txDict = append(c.txDict, b...) if len(c.txDict) > maxWindowBits { - c.txDict = c.txDict[:maxWindowBits] + offset := len(c.txDict) - maxWindowBits + c.txDict = c.txDict[offset:] } } // AddTxDict adds payload to rxDict. func (c *Conn) AddRxDict(b []byte) { - c.rxDict = append(b, c.rxDict...) + c.rxDict = append(c.rxDict, b...) if len(c.rxDict) > maxWindowBits { - c.rxDict = c.rxDict[:maxWindowBits] + offset := len(c.rxDict) - maxWindowBits + c.rxDict = c.rxDict[offset:] } } From 70abf7287b83609f268fc9b9f75da91e0783e052 Mon Sep 17 00:00:00 2001 From: misu Date: Thu, 25 Jan 2018 19:01:58 +0900 Subject: [PATCH 13/52] mod: .gitignore to ignore .test file --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index b13d2d78..f0a5301e 100644 --- a/.gitignore +++ b/.gitignore @@ -23,4 +23,5 @@ _testmain.go .idea/ *.iml -.vscode/ \ No newline at end of file +.vscode/ +*.test \ No newline at end of file From e873eb1ee058f08bfae0070a71c2984276694c6d Mon Sep 17 00:00:00 2001 From: misu Date: Fri, 26 Jan 2018 17:52:23 +0900 Subject: [PATCH 14/52] mod: test and write compression --- client_server_test.go | 56 +++++++++++++++++++++++++++++++++++++++++++ compression.go | 17 ++++++++----- compression_test.go | 14 +++++++++++ conn_test.go | 11 +++++++++ 4 files changed, 92 insertions(+), 6 deletions(-) diff --git a/client_server_test.go b/client_server_test.go index 50063b7e..97631eb3 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -142,6 +142,45 @@ func sendRecv(t *testing.T, ws *Conn) { } } +func multipleSendRecv(t *testing.T, ws *Conn) { + message := "Hello World!" + if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("SetWriteDeadline: %v", err) + } + if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("WriteMessage: %v", err) + } + if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("SetReadDeadline: %v", err) + } + _, p, err := ws.ReadMessage() + if err != nil { + t.Fatalf("ReadMessage: %v", err) + } + if string(p) != message { + t.Fatalf("message=%s, want %s", p, message) + } + + message_2 := "Can you read message?" + if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("SetWriteDeadline: %v", err) + } + if err := ws.WriteMessage(TextMessage, []byte(message_2)); err != nil { + t.Fatalf("_WriteMessage: %v", err) + } + if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("_SetReadDeadline: %v", err) + } + + _, p, err = ws.ReadMessage() + if err != nil { + t.Fatalf("_ReadMessage: %v", err) // _ReadMessage: websocket: close 1006 (abnormal closure): unexpected EOF + } + if string(p) != message { + t.Fatalf("_message=%s, want %s", p, message_2) + } +} + func TestProxyDial(t *testing.T) { s := newServer(t) @@ -522,6 +561,23 @@ func TestDialCompression(t *testing.T) { sendRecv(t, ws) } +func TestDialCompressionOfContextTakeover(t *testing.T) { + s := newServer(t) + defer s.Close() + + dialer := cstDialer + dialer.EnableCompression = true + dialer.EnableContextTakeover = true + ws, _, err := dialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + + // Todo multiple send and receive. + multipleSendRecv(t, ws) +} + func TestSocksProxyDial(t *testing.T) { s := newServer(t) defer s.Close() diff --git a/compression.go b/compression.go index d2084ec7..c34be891 100644 --- a/compression.go +++ b/compression.go @@ -19,8 +19,9 @@ const ( ) var ( - flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool - flateReaderPool = sync.Pool{New: func() interface{} { + flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool + flateWriterDictPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool + flateReaderPool = sync.Pool{New: func() interface{} { return flate.NewReader(nil) }} ) @@ -66,12 +67,16 @@ func compressNoContextTakeover(w io.WriteCloser, level int, dict []byte) io.Writ } func compressContextTakeover(w io.WriteCloser, level int, dict []byte) io.WriteCloser { - p := &flateWriterPools[level-minCompressionLevel] + p := &flateWriterDictPools[level-minCompressionLevel] tw := &truncWriter{w: w} - // WriterDict's Reset just restores the dictionary. - // Initialization is done with New. (If possible get struct from sync.Pool) - fw, _ := flate.NewWriterDict(tw, level, dict) + fw, _ := p.Get().(*flate.Writer) + if fw == nil { + // use WriterDict + fw, _ = flate.NewWriterDict(tw, level, dict) + } else { + fw.Reset(tw) + } return &flateWriteWrapper{fw: fw, tw: tw, p: p} } diff --git a/compression_test.go b/compression_test.go index 659cf421..aecb23f5 100644 --- a/compression_test.go +++ b/compression_test.go @@ -65,6 +65,20 @@ func BenchmarkWriteWithCompression(b *testing.B) { b.ReportAllocs() } +func BenchmarkWriteWithCompressionOfContextTakeover(b *testing.B) { + w := ioutil.Discard + c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + messages := textMessages(100) + c.enableWriteCompression = true + c.contextTakeover = true + c.newCompressionWriter = compressContextTakeover + b.ResetTimer() + for i := 0; i < b.N; i++ { + c.WriteMessage(TextMessage, messages[i%len(messages)]) + } + b.ReportAllocs() +} + func TestValidCompressionLevel(t *testing.T) { c := newConn(fakeNetConn{}, false, 1024, 1024) for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { diff --git a/conn_test.go b/conn_test.go index 5fda7b5c..5e101636 100644 --- a/conn_test.go +++ b/conn_test.go @@ -494,3 +494,14 @@ func TestBufioReuse(t *testing.T) { } } + +func BenchmarkAddDict(b *testing.B) { + w := ioutil.Discard + c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + messages := textMessages(100) + b.ResetTimer() + for i := 0; i < b.N; i++ { + c.AddRxDict(messages[i%len(messages)]) + } + b.ReportAllocs() +} From e2dd00db3d95ec014b57c6947a2e4161a2bf1997 Mon Sep 17 00:00:00 2001 From: misu Date: Mon, 29 Jan 2018 15:10:19 +0900 Subject: [PATCH 15/52] mod: dict strategy --- compression.go | 83 ++++++++++++++++++++++++++++++++++++++++---------- conn.go | 41 +++++-------------------- conn_test.go | 11 ------- 3 files changed, 75 insertions(+), 60 deletions(-) diff --git a/compression.go b/compression.go index c34be891..2640a916 100644 --- a/compression.go +++ b/compression.go @@ -26,7 +26,7 @@ var ( }} ) -func decompressNoContextTakeover(r io.Reader, dict []byte) io.ReadCloser { +func decompressNoContextTakeover(r io.Reader, dict *[]byte) io.ReadCloser { const tail = // Add four bytes as specified in RFC "\x00\x00\xff\xff" + @@ -35,10 +35,10 @@ func decompressNoContextTakeover(r io.Reader, dict []byte) io.ReadCloser { fr, _ := flateReaderPool.Get().(io.ReadCloser) fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) - return &flateReadWrapper{fr} + return &flateReadWrapper{fr: fr} } -func decompressContextTakeover(r io.Reader, dict []byte) io.ReadCloser { +func decompressContextTakeover(r io.Reader, dict *[]byte) io.ReadCloser { const tail = // Add four bytes as specified in RFC "\x00\x00\xff\xff" + @@ -46,15 +46,21 @@ func decompressContextTakeover(r io.Reader, dict []byte) io.ReadCloser { "\x01\x00\x00\xff\xff" fr, _ := flateReaderPool.Get().(io.ReadCloser) - fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), dict) - return &flateReadWrapper{fr} + + if dict != nil { + fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), *dict) + } else { + fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) + } + + return &flateReadWrapper{fr: fr, hasDict: true, dict: dict} } func isValidCompressionLevel(level int) bool { return minCompressionLevel <= level && level <= maxCompressionLevel } -func compressNoContextTakeover(w io.WriteCloser, level int, dict []byte) io.WriteCloser { +func compressNoContextTakeover(w io.WriteCloser, level int, dict *[]byte) io.WriteCloser { p := &flateWriterPools[level-minCompressionLevel] tw := &truncWriter{w: w} fw, _ := p.Get().(*flate.Writer) @@ -66,19 +72,18 @@ func compressNoContextTakeover(w io.WriteCloser, level int, dict []byte) io.Writ return &flateWriteWrapper{fw: fw, tw: tw, p: p} } -func compressContextTakeover(w io.WriteCloser, level int, dict []byte) io.WriteCloser { - p := &flateWriterDictPools[level-minCompressionLevel] +func compressContextTakeover(w io.WriteCloser, level int, dict *[]byte) io.WriteCloser { tw := &truncWriter{w: w} - fw, _ := p.Get().(*flate.Writer) - if fw == nil { - // use WriterDict - fw, _ = flate.NewWriterDict(tw, level, dict) + var fw *flate.Writer + + if dict != nil { + fw, _ = flate.NewWriterDict(tw, level, *dict) } else { - fw.Reset(tw) + fw, _ = flate.NewWriterDict(tw, level, nil) } - return &flateWriteWrapper{fw: fw, tw: tw, p: p} + return &flateWriteWrapper{fw: fw, tw: tw, hasDict: true, dict: dict} } // truncWriter is an io.Writer that writes all but the last four bytes of the @@ -121,12 +126,20 @@ type flateWriteWrapper struct { fw *flate.Writer tw *truncWriter p *sync.Pool + + hasDict bool + dict *[]byte } func (w *flateWriteWrapper) Write(p []byte) (int, error) { if w.fw == nil { return 0, errWriteClosed } + + if w.hasDict { + w.addDict(p) + } + return w.fw.Write(p) } @@ -135,7 +148,11 @@ func (w *flateWriteWrapper) Close() error { return errWriteClosed } err1 := w.fw.Flush() - w.p.Put(w.fw) + + if !w.hasDict { + w.p.Put(w.fw) + } + w.fw = nil if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { return errors.New("websocket: internal error, unexpected bytes at end of flate stream") @@ -147,8 +164,21 @@ func (w *flateWriteWrapper) Close() error { return err2 } +// addDict adds payload to dict. +func (w *flateWriteWrapper) addDict(b []byte) { + *w.dict = append(*w.dict, b...) + + if len(*w.dict) > maxWindowBits { + offset := len(*w.dict) - maxWindowBits + *w.dict = (*w.dict)[offset:] + } +} + type flateReadWrapper struct { fr io.ReadCloser // flate.NewReader + + hasDict bool + dict *[]byte } func (r *flateReadWrapper) Read(p []byte) (int, error) { @@ -164,6 +194,13 @@ func (r *flateReadWrapper) Read(p []byte) (int, error) { // this final read. r.Close() } + + if r.hasDict { + if n > 0 { + r.addDict(p[:n]) + } + } + return n, err } @@ -172,7 +209,21 @@ func (r *flateReadWrapper) Close() error { return io.ErrClosedPipe } err := r.fr.Close() - flateReaderPool.Put(r.fr) + + if !r.hasDict { + flateReaderPool.Put(r.fr) + } + r.fr = nil return err } + +// addDict adds payload to dict. +func (r *flateReadWrapper) addDict(b []byte) { + *r.dict = append(*r.dict, b...) + + if len(*r.dict) > maxWindowBits { + offset := len(*r.dict) - maxWindowBits + *r.dict = (*r.dict)[offset:] + } +} diff --git a/conn.go b/conn.go index 6ac7e755..174f94b6 100644 --- a/conn.go +++ b/conn.go @@ -243,7 +243,7 @@ type Conn struct { enableWriteCompression bool compressionLevel int - newCompressionWriter func(io.WriteCloser, int, []byte) io.WriteCloser + newCompressionWriter func(io.WriteCloser, int, *[]byte) io.WriteCloser // Read fields reader io.ReadCloser // the current reader returned to the application @@ -261,12 +261,12 @@ type Conn struct { readErrCount int messageReader *messageReader // the current low-level reader - readDecompress bool // whether last read frame had RSV1 set - newDecompressionReader func(io.Reader, []byte) io.ReadCloser // arges may flateReadWrapper struct + readDecompress bool // whether last read frame had RSV1 set + newDecompressionReader func(io.Reader, *[]byte) io.ReadCloser // arges may flateReadWrapper struct contextTakeover bool - txDict []byte - rxDict []byte + txDict *[]byte + rxDict *[]byte } func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { @@ -336,6 +336,9 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in writeBuf: writeBuf, enableWriteCompression: true, compressionLevel: defaultCompressionLevel, + + txDict: &[]byte{}, + rxDict: &[]byte{}, } c.SetCloseHandler(nil) c.SetPingHandler(nil) @@ -763,9 +766,6 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error { if _, err = w.Write(data); err != nil { return err } - if c.contextTakeover { - c.AddTxDict(data) - } return w.Close() } @@ -1046,11 +1046,6 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { } p, err = ioutil.ReadAll(r) - // if context-takeover add payload to dictionary - if c.contextTakeover { - c.AddRxDict(p) - } - return messageType, p, err } @@ -1167,26 +1162,6 @@ func (c *Conn) SetCompressionLevel(level int) error { return nil } -// AddTxDict adds payload to txDict. -func (c *Conn) AddTxDict(b []byte) { - c.txDict = append(c.txDict, b...) - - if len(c.txDict) > maxWindowBits { - offset := len(c.txDict) - maxWindowBits - c.txDict = c.txDict[offset:] - } -} - -// AddTxDict adds payload to rxDict. -func (c *Conn) AddRxDict(b []byte) { - c.rxDict = append(c.rxDict, b...) - - if len(c.rxDict) > maxWindowBits { - offset := len(c.rxDict) - maxWindowBits - c.rxDict = c.rxDict[offset:] - } -} - // FormatCloseMessage formats closeCode and text as a WebSocket close message. // An empty message is returned for code CloseNoStatusReceived. func FormatCloseMessage(closeCode int, text string) []byte { diff --git a/conn_test.go b/conn_test.go index 5e101636..5fda7b5c 100644 --- a/conn_test.go +++ b/conn_test.go @@ -494,14 +494,3 @@ func TestBufioReuse(t *testing.T) { } } - -func BenchmarkAddDict(b *testing.B) { - w := ioutil.Discard - c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) - messages := textMessages(100) - b.ResetTimer() - for i := 0; i < b.N; i++ { - c.AddRxDict(messages[i%len(messages)]) - } - b.ReportAllocs() -} From a9475f2ccfb8ac305fd32422587aca0b808a31b9 Mon Sep 17 00:00:00 2001 From: misu Date: Mon, 29 Jan 2018 18:12:43 +0900 Subject: [PATCH 16/52] mod: remove judge dict nil --- compression.go | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/compression.go b/compression.go index 2640a916..9b9158d6 100644 --- a/compression.go +++ b/compression.go @@ -46,12 +46,7 @@ func decompressContextTakeover(r io.Reader, dict *[]byte) io.ReadCloser { "\x01\x00\x00\xff\xff" fr, _ := flateReaderPool.Get().(io.ReadCloser) - - if dict != nil { - fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), *dict) - } else { - fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) - } + fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), *dict) return &flateReadWrapper{fr: fr, hasDict: true, dict: dict} } @@ -75,13 +70,7 @@ func compressNoContextTakeover(w io.WriteCloser, level int, dict *[]byte) io.Wri func compressContextTakeover(w io.WriteCloser, level int, dict *[]byte) io.WriteCloser { tw := &truncWriter{w: w} - var fw *flate.Writer - - if dict != nil { - fw, _ = flate.NewWriterDict(tw, level, *dict) - } else { - fw, _ = flate.NewWriterDict(tw, level, nil) - } + fw, _ := flate.NewWriterDict(tw, level, *dict) return &flateWriteWrapper{fw: fw, tw: tw, hasDict: true, dict: dict} } From f2a68e2d216a4c6a9deefc147993d13d7464b73c Mon Sep 17 00:00:00 2001 From: misu Date: Tue, 30 Jan 2018 18:37:52 +0900 Subject: [PATCH 17/52] add: call compress method test --- compression_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/compression_test.go b/compression_test.go index aecb23f5..afa7de7c 100644 --- a/compression_test.go +++ b/compression_test.go @@ -79,6 +79,76 @@ func BenchmarkWriteWithCompressionOfContextTakeover(b *testing.B) { b.ReportAllocs() } +func BenchmarkCallWriteWithCompressionOfContextTakeover(b *testing.B) { + w := ioutil.Discard + c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + // messages := textMessages(100) + c.enableWriteCompression = true + c.contextTakeover = true + c.newCompressionWriter = compressContextTakeover + mw := &messageWriter{ + c: c, + frameType: 2, + pos: maxFrameHeaderSize, + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + // c.txDict = &messages[i%len(messages)] + c.newCompressionWriter(mw, 2, &[]byte{}) + } + b.ReportAllocs() +} + +func BenchmarkCallWriteWithCompression(b *testing.B) { + w := ioutil.Discard + c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + // messages := textMessages(100) + c.enableWriteCompression = true + c.newCompressionWriter = compressNoContextTakeover + mw := &messageWriter{ + c: c, + frameType: 2, + pos: maxFrameHeaderSize, + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + // c.txDict = &messages[i%len(messages)] + c.newCompressionWriter(mw, 2, nil) + } + b.ReportAllocs() +} + +func BenchmarkReadWithCompression(b *testing.B) { + w := ioutil.Discard + c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + c.enableWriteCompression = true + c.newDecompressionReader = decompressNoContextTakeover + messages := textMessages(100) + b.ResetTimer() + for i := 0; i < b.N; i++ { + r := bytes.NewReader(messages[i%len(messages)]) + reader := c.newDecompressionReader(r, nil) + ioutil.ReadAll(reader) + } + b.ReportAllocs() +} + +func BenchmarkReadWithCompressionOfContextTakeover(b *testing.B) { + w := ioutil.Discard + c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + c.enableWriteCompression = true + c.contextTakeover = true + c.newDecompressionReader = decompressContextTakeover + messages := textMessages(100) + b.ResetTimer() + for i := 0; i < b.N; i++ { + r := bytes.NewReader(messages[i%len(messages)]) + reader := c.newDecompressionReader(r, c.rxDict) + ioutil.ReadAll(reader) + } + b.ReportAllocs() +} + func TestValidCompressionLevel(t *testing.T) { c := newConn(fakeNetConn{}, false, 1024, 1024) for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { From 2472e6d8009c4ae57088e56d3eaeb5c4fd732a4e Mon Sep 17 00:00:00 2001 From: misu Date: Tue, 30 Jan 2018 19:02:50 +0900 Subject: [PATCH 18/52] upgrade: use pool for writerdict --- compression.go | 40 +++++++++++++--------------------------- compression_test.go | 39 --------------------------------------- conn.go | 12 ++---------- 3 files changed, 15 insertions(+), 76 deletions(-) diff --git a/compression.go b/compression.go index 9b9158d6..eda70ab3 100644 --- a/compression.go +++ b/compression.go @@ -5,11 +5,12 @@ package websocket import ( - "compress/flate" "errors" "io" "strings" "sync" + + "compress/flate" ) const ( @@ -55,7 +56,7 @@ func isValidCompressionLevel(level int) bool { return minCompressionLevel <= level && level <= maxCompressionLevel } -func compressNoContextTakeover(w io.WriteCloser, level int, dict *[]byte) io.WriteCloser { +func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { p := &flateWriterPools[level-minCompressionLevel] tw := &truncWriter{w: w} fw, _ := p.Get().(*flate.Writer) @@ -67,12 +68,16 @@ func compressNoContextTakeover(w io.WriteCloser, level int, dict *[]byte) io.Wri return &flateWriteWrapper{fw: fw, tw: tw, p: p} } -func compressContextTakeover(w io.WriteCloser, level int, dict *[]byte) io.WriteCloser { +func compressContextTakeover(w io.WriteCloser, level int) io.WriteCloser { + p := &flateWriterDictPools[level-minCompressionLevel] tw := &truncWriter{w: w} - - fw, _ := flate.NewWriterDict(tw, level, *dict) - - return &flateWriteWrapper{fw: fw, tw: tw, hasDict: true, dict: dict} + fw, _ := p.Get().(*flate.Writer) + if fw == nil { + fw, _ = flate.NewWriterDict(tw, level, []byte{}) + } else { + fw.Reset(tw) + } + return &flateWriteWrapper{fw: fw, tw: tw, p: p} } // truncWriter is an io.Writer that writes all but the last four bytes of the @@ -115,9 +120,6 @@ type flateWriteWrapper struct { fw *flate.Writer tw *truncWriter p *sync.Pool - - hasDict bool - dict *[]byte } func (w *flateWriteWrapper) Write(p []byte) (int, error) { @@ -125,10 +127,6 @@ func (w *flateWriteWrapper) Write(p []byte) (int, error) { return 0, errWriteClosed } - if w.hasDict { - w.addDict(p) - } - return w.fw.Write(p) } @@ -138,9 +136,7 @@ func (w *flateWriteWrapper) Close() error { } err1 := w.fw.Flush() - if !w.hasDict { - w.p.Put(w.fw) - } + w.p.Put(w.fw) w.fw = nil if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { @@ -153,16 +149,6 @@ func (w *flateWriteWrapper) Close() error { return err2 } -// addDict adds payload to dict. -func (w *flateWriteWrapper) addDict(b []byte) { - *w.dict = append(*w.dict, b...) - - if len(*w.dict) > maxWindowBits { - offset := len(*w.dict) - maxWindowBits - *w.dict = (*w.dict)[offset:] - } -} - type flateReadWrapper struct { fr io.ReadCloser // flate.NewReader diff --git a/compression_test.go b/compression_test.go index afa7de7c..5a479e3c 100644 --- a/compression_test.go +++ b/compression_test.go @@ -79,45 +79,6 @@ func BenchmarkWriteWithCompressionOfContextTakeover(b *testing.B) { b.ReportAllocs() } -func BenchmarkCallWriteWithCompressionOfContextTakeover(b *testing.B) { - w := ioutil.Discard - c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) - // messages := textMessages(100) - c.enableWriteCompression = true - c.contextTakeover = true - c.newCompressionWriter = compressContextTakeover - mw := &messageWriter{ - c: c, - frameType: 2, - pos: maxFrameHeaderSize, - } - b.ResetTimer() - for i := 0; i < b.N; i++ { - // c.txDict = &messages[i%len(messages)] - c.newCompressionWriter(mw, 2, &[]byte{}) - } - b.ReportAllocs() -} - -func BenchmarkCallWriteWithCompression(b *testing.B) { - w := ioutil.Discard - c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) - // messages := textMessages(100) - c.enableWriteCompression = true - c.newCompressionWriter = compressNoContextTakeover - mw := &messageWriter{ - c: c, - frameType: 2, - pos: maxFrameHeaderSize, - } - b.ResetTimer() - for i := 0; i < b.N; i++ { - // c.txDict = &messages[i%len(messages)] - c.newCompressionWriter(mw, 2, nil) - } - b.ReportAllocs() -} - func BenchmarkReadWithCompression(b *testing.B) { w := ioutil.Discard c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) diff --git a/conn.go b/conn.go index 174f94b6..078d25de 100644 --- a/conn.go +++ b/conn.go @@ -243,7 +243,7 @@ type Conn struct { enableWriteCompression bool compressionLevel int - newCompressionWriter func(io.WriteCloser, int, *[]byte) io.WriteCloser + newCompressionWriter func(io.WriteCloser, int) io.WriteCloser // Read fields reader io.ReadCloser // the current reader returned to the application @@ -265,7 +265,6 @@ type Conn struct { newDecompressionReader func(io.Reader, *[]byte) io.ReadCloser // arges may flateReadWrapper struct contextTakeover bool - txDict *[]byte rxDict *[]byte } @@ -337,7 +336,6 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in enableWriteCompression: true, compressionLevel: defaultCompressionLevel, - txDict: &[]byte{}, rxDict: &[]byte{}, } c.SetCloseHandler(nil) @@ -509,13 +507,7 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { c.writer = mw if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { mw.compress = true - switch { - case c.contextTakeover: - c.writer = c.newCompressionWriter(c.writer, c.compressionLevel, c.txDict) - // no-context-takeover - default: - c.writer = c.newCompressionWriter(c.writer, c.compressionLevel, nil) - } + c.writer = c.newCompressionWriter(c.writer, c.compressionLevel) } return c.writer, nil } From fededdd187d6df5b40e08cd1bc6cb01d4a2fc8a1 Mon Sep 17 00:00:00 2001 From: misu Date: Wed, 31 Jan 2018 19:29:01 +0900 Subject: [PATCH 19/52] mod: detach compressContextTakeover method and add debug comment. --- client.go | 2 +- compression.go | 46 +++++++++++++++++++++++++++++++++------------- conn.go | 35 ++++++++++++++++++++++++++++++++++- server.go | 2 +- 4 files changed, 69 insertions(+), 16 deletions(-) diff --git a/client.go b/client.go index 45363fdb..6cc897d6 100644 --- a/client.go +++ b/client.go @@ -322,7 +322,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re switch { case cmwb && smwb: conn.contextTakeover = true - conn.newCompressionWriter = compressContextTakeover + // conn.newCompressionWriter = compressContextTakeover conn.newDecompressionReader = decompressContextTakeover default: conn.newCompressionWriter = compressNoContextTakeover diff --git a/compression.go b/compression.go index eda70ab3..b2a0c0d2 100644 --- a/compression.go +++ b/compression.go @@ -6,6 +6,7 @@ package websocket import ( "errors" + "fmt" "io" "strings" "sync" @@ -68,17 +69,17 @@ func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { return &flateWriteWrapper{fw: fw, tw: tw, p: p} } -func compressContextTakeover(w io.WriteCloser, level int) io.WriteCloser { - p := &flateWriterDictPools[level-minCompressionLevel] - tw := &truncWriter{w: w} - fw, _ := p.Get().(*flate.Writer) - if fw == nil { - fw, _ = flate.NewWriterDict(tw, level, []byte{}) - } else { - fw.Reset(tw) - } - return &flateWriteWrapper{fw: fw, tw: tw, p: p} -} +// func compressContextTakeover(w io.WriteCloser, level int) io.WriteCloser { +// p := &flateWriterDictPools[level-minCompressionLevel] +// tw := &truncWriter{w: w} +// fw, _ := p.Get().(*flate.Writer) +// if fw == nil { +// fw, _ = flate.NewWriterDict(tw, level, nil) +// } else { +// fw.Reset(tw) +// } +// return &flateWriteWrapper{fw: fw, tw: tw, p: p} +// } // truncWriter is an io.Writer that writes all but the last four bytes of the // stream to another io.Writer. @@ -90,6 +91,8 @@ type truncWriter struct { func (w *truncWriter) Write(p []byte) (int, error) { n := 0 + fmt.Printf("\x1b[32m Start truncWriter.Write %#v \x1b[0m\n", p) + fmt.Printf("\x1b[32m truncWriter w.n -> len %#v \x1b[0m\n", w.n) // fill buffer first for simplicity. if w.n < len(w.p) { @@ -106,13 +109,17 @@ func (w *truncWriter) Write(p []byte) (int, error) { m = len(w.p) } + fmt.Printf("\x1b[32m Write will truncWriter.Write %#v \x1b[0m\n", w.p[:m]) + if nn, err := w.w.Write(w.p[:m]); err != nil { + fmt.Printf("\x1b[32m w.w.Write Error truncWriter.Write %#v \x1b[0m\n", err) return n + nn, err } copy(w.p[:], w.p[m:]) copy(w.p[len(w.p)-m:], p[len(p)-m:]) nn, err := w.w.Write(p[:len(p)-m]) + fmt.Printf("\x1b[32m End truncWriter.Write %#v \x1b[0m\n", p) return n + nn, err } @@ -120,6 +127,8 @@ type flateWriteWrapper struct { fw *flate.Writer tw *truncWriter p *sync.Pool + + isDictWriter bool } func (w *flateWriteWrapper) Write(p []byte) (int, error) { @@ -127,6 +136,8 @@ func (w *flateWriteWrapper) Write(p []byte) (int, error) { return 0, errWriteClosed } + fmt.Printf("flateWriteWrapper will Write %#v \n", p) + return w.fw.Write(p) } @@ -136,16 +147,25 @@ func (w *flateWriteWrapper) Close() error { } err1 := w.fw.Flush() - w.p.Put(w.fw) + fmt.Printf("w.tw.n -> -> %#v \n", w.tw.n) + fmt.Printf("w.tw.p -> -> %#v \n", w.tw.p) + + if !w.isDictWriter { + w.p.Put(w.fw) + w.fw = nil + } - w.fw = nil if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { return errors.New("websocket: internal error, unexpected bytes at end of flate stream") } + err2 := w.tw.w.Close() if err1 != nil { return err1 } + + fmt.Printf("err2 %#v \n", err2) + return err2 } diff --git a/conn.go b/conn.go index 078d25de..50339d92 100644 --- a/conn.go +++ b/conn.go @@ -6,8 +6,10 @@ package websocket import ( "bufio" + "compress/flate" "encoding/binary" "errors" + "fmt" "io" "io/ioutil" "math/rand" @@ -244,6 +246,7 @@ type Conn struct { enableWriteCompression bool compressionLevel int newCompressionWriter func(io.WriteCloser, int) io.WriteCloser + compressionWriters map[int]*flateWriteWrapper // Read fields reader io.ReadCloser // the current reader returned to the application @@ -326,6 +329,8 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize) } + cw := make(map[int]*flateWriteWrapper, 2) + c := &Conn{ isServer: isServer, br: br, @@ -336,6 +341,8 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in enableWriteCompression: true, compressionLevel: defaultCompressionLevel, + compressionWriters: cw, + rxDict: &[]byte{}, } c.SetCloseHandler(nil) @@ -507,6 +514,27 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { c.writer = mw if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { mw.compress = true + // For context-takeover, frate.Writer must be set to conn. + if c.contextTakeover { + if fww, ok := c.compressionWriters[messageType]; ok { + // tw := &truncWriter{w: c.writer} + + // Todo: if write + + // fw, _ := flate.NewWriterDict(tw, c.compressionLevel, []byte("Hello")) + // fww.fw.Reset(tw) + // fww.fw = fw + fww.tw.w = c.writer + return fww, nil + } else { + tw := &truncWriter{w: c.writer} + fw, _ := flate.NewWriterDict(tw, c.compressionLevel, nil) + fww := &flateWriteWrapper{fw: fw, tw: tw, isDictWriter: true} + c.compressionWriters[messageType] = fww + return fww, nil + } + } + c.writer = c.newCompressionWriter(c.writer, c.compressionLevel) } return c.writer, nil @@ -753,11 +781,16 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error { w, err := c.NextWriter(messageType) if err != nil { + fmt.Printf("\x1b[31m c.NextWriter %v \x1b[0m\n", err) return err } - if _, err = w.Write(data); err != nil { + if _, err := w.Write(data); err != nil { + fmt.Printf("\x1b[31m w.Write %v \x1b[0m\n", err) return err } + + fmt.Printf("WriteMessage data %v \x1b[0m\n", data) + fmt.Printf("\x1b[31m w.Write err %v \x1b[0m\n", err) return w.Close() } diff --git a/server.go b/server.go index 67c82b7d..5841381c 100644 --- a/server.go +++ b/server.go @@ -187,7 +187,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade switch { case contextTakeover: c.contextTakeover = contextTakeover - c.newCompressionWriter = compressContextTakeover + // c.newCompressionWriter = compressContextTakeover c.newDecompressionReader = decompressContextTakeover default: c.newCompressionWriter = compressNoContextTakeover From 62df16a55d4214b56ba0e7a08d72bbbc075dafb8 Mon Sep 17 00:00:00 2001 From: misu Date: Wed, 31 Jan 2018 19:40:12 +0900 Subject: [PATCH 20/52] mod: compressContextTakeover method. --- client.go | 2 +- compression.go | 22 +++++++++++----------- server.go | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/client.go b/client.go index 6cc897d6..45363fdb 100644 --- a/client.go +++ b/client.go @@ -322,7 +322,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re switch { case cmwb && smwb: conn.contextTakeover = true - // conn.newCompressionWriter = compressContextTakeover + conn.newCompressionWriter = compressContextTakeover conn.newDecompressionReader = decompressContextTakeover default: conn.newCompressionWriter = compressNoContextTakeover diff --git a/compression.go b/compression.go index b2a0c0d2..97b8610a 100644 --- a/compression.go +++ b/compression.go @@ -69,17 +69,17 @@ func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { return &flateWriteWrapper{fw: fw, tw: tw, p: p} } -// func compressContextTakeover(w io.WriteCloser, level int) io.WriteCloser { -// p := &flateWriterDictPools[level-minCompressionLevel] -// tw := &truncWriter{w: w} -// fw, _ := p.Get().(*flate.Writer) -// if fw == nil { -// fw, _ = flate.NewWriterDict(tw, level, nil) -// } else { -// fw.Reset(tw) -// } -// return &flateWriteWrapper{fw: fw, tw: tw, p: p} -// } +func compressContextTakeover(w io.WriteCloser, level int) io.WriteCloser { + // p := &flateWriterDictPools[level-minCompressionLevel] + // tw := &truncWriter{w: w} + // fw, _ := p.Get().(*flate.Writer) + // if fw == nil { + // fw, _ = flate.NewWriterDict(tw, level, nil) + // } else { + // fw.Reset(tw) + // } + return &flateWriteWrapper{} +} // truncWriter is an io.Writer that writes all but the last four bytes of the // stream to another io.Writer. diff --git a/server.go b/server.go index 5841381c..67c82b7d 100644 --- a/server.go +++ b/server.go @@ -187,7 +187,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade switch { case contextTakeover: c.contextTakeover = contextTakeover - // c.newCompressionWriter = compressContextTakeover + c.newCompressionWriter = compressContextTakeover c.newDecompressionReader = decompressContextTakeover default: c.newCompressionWriter = compressNoContextTakeover From ee46f8548a106a02264f711a1838887fd3cf58cf Mon Sep 17 00:00:00 2001 From: misu Date: Wed, 31 Jan 2018 20:52:04 +0900 Subject: [PATCH 21/52] add: comment --- compression.go | 7 +++++++ conn.go | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/compression.go b/compression.go index 97b8610a..edf8cb63 100644 --- a/compression.go +++ b/compression.go @@ -93,6 +93,7 @@ func (w *truncWriter) Write(p []byte) (int, error) { n := 0 fmt.Printf("\x1b[32m Start truncWriter.Write %#v \x1b[0m\n", p) fmt.Printf("\x1b[32m truncWriter w.n -> len %#v \x1b[0m\n", w.n) + fmt.Printf("\x1b[32m truncWriter w.p %#v \x1b[0m\n", w.p) // fill buffer first for simplicity. if w.n < len(w.p) { @@ -120,6 +121,7 @@ func (w *truncWriter) Write(p []byte) (int, error) { copy(w.p[len(w.p)-m:], p[len(p)-m:]) nn, err := w.w.Write(p[:len(p)-m]) fmt.Printf("\x1b[32m End truncWriter.Write %#v \x1b[0m\n", p) + fmt.Printf("\x1b[32m End truncWriter w.p %#v \x1b[0m\n", w.p) return n + nn, err } @@ -159,6 +161,11 @@ func (w *flateWriteWrapper) Close() error { return errors.New("websocket: internal error, unexpected bytes at end of flate stream") } + if !w.isDictWriter { + w.tw.p = [4]byte{} + w.tw.n = 0 + } + err2 := w.tw.w.Close() if err1 != nil { return err1 diff --git a/conn.go b/conn.go index 50339d92..3e402e2c 100644 --- a/conn.go +++ b/conn.go @@ -514,12 +514,12 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { c.writer = mw if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { mw.compress = true - // For context-takeover, frate.Writer must be set to conn. + // For context-takeover if c.contextTakeover { if fww, ok := c.compressionWriters[messageType]; ok { // tw := &truncWriter{w: c.writer} - // Todo: if write + //Todo reset trunkwriter inside flate.Writer. // fw, _ := flate.NewWriterDict(tw, c.compressionLevel, []byte("Hello")) // fww.fw.Reset(tw) From fb7d67a34a0160ed35d57827e67023099218081a Mon Sep 17 00:00:00 2001 From: misu Date: Thu, 1 Feb 2018 11:56:17 +0900 Subject: [PATCH 22/52] mod: fmt variables of conn.go --- conn.go | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/conn.go b/conn.go index 3e402e2c..72423d6d 100644 --- a/conn.go +++ b/conn.go @@ -9,7 +9,6 @@ import ( "compress/flate" "encoding/binary" "errors" - "fmt" "io" "io/ioutil" "math/rand" @@ -85,20 +84,14 @@ const ( PongMessage = 10 ) -// ErrCloseSent is returned when the application writes a message to the -// connection after sending a close message. -var ErrCloseSent = errors.New("websocket: close sent") - -// ErrReadLimit is returned when reading a message that is larger than the -// read limit set for the connection. -var ErrReadLimit = errors.New("websocket: read limit exceeded") - -// netError satisfies the net Error interface. -type netError struct { - msg string - temporary bool - timeout bool -} +type ( + // netError satisfies the net Error interface. + netError struct { + msg string + temporary bool + timeout bool + } +) func (e *netError) Error() string { return e.msg } func (e *netError) Temporary() bool { return e.temporary } @@ -182,6 +175,14 @@ var ( errBadWriteOpCode = errors.New("websocket: bad write message type") errWriteClosed = errors.New("websocket: write closed") errInvalidControlFrame = errors.New("websocket: invalid control frame") + + // ErrCloseSent is returned when the application writes a message to the + // connection after sending a close message. + ErrCloseSent = errors.New("websocket: close sent") + + // ErrReadLimit is returned when reading a message that is larger than the + // read limit set for the connection. + ErrReadLimit = errors.New("websocket: read limit exceeded") ) func newMaskKey() [4]byte { @@ -781,16 +782,12 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error { w, err := c.NextWriter(messageType) if err != nil { - fmt.Printf("\x1b[31m c.NextWriter %v \x1b[0m\n", err) return err } if _, err := w.Write(data); err != nil { - fmt.Printf("\x1b[31m w.Write %v \x1b[0m\n", err) return err } - fmt.Printf("WriteMessage data %v \x1b[0m\n", data) - fmt.Printf("\x1b[31m w.Write err %v \x1b[0m\n", err) return w.Close() } From f8b4a0f71ddd24f926ea583713371e178be41674 Mon Sep 17 00:00:00 2001 From: misu Date: Thu, 1 Feb 2018 12:59:30 +0900 Subject: [PATCH 23/52] mod: flate.writer for context-takeover --- client.go | 7 ++++- compression.go | 69 +++++++++++++++++++++++++++------------------ compression_test.go | 5 +++- conn.go | 21 ++------------ server.go | 7 ++++- 5 files changed, 60 insertions(+), 49 deletions(-) diff --git a/client.go b/client.go index 45363fdb..3890298d 100644 --- a/client.go +++ b/client.go @@ -6,6 +6,7 @@ package websocket import ( "bytes" + "compress/flate" "crypto/tls" "errors" "io" @@ -322,7 +323,11 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re switch { case cmwb && smwb: conn.contextTakeover = true - conn.newCompressionWriter = compressContextTakeover + + var f contextTakeoverWriterFactory + f.fw, _ = flate.NewWriter(&f.tw, 2) // level is specified in Dialer, Upgrader + conn.newCompressionWriter = f.newCompressionWriter + conn.newDecompressionReader = decompressContextTakeover default: conn.newCompressionWriter = compressNoContextTakeover diff --git a/compression.go b/compression.go index edf8cb63..c9067c4d 100644 --- a/compression.go +++ b/compression.go @@ -6,7 +6,6 @@ package websocket import ( "errors" - "fmt" "io" "strings" "sync" @@ -69,18 +68,6 @@ func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { return &flateWriteWrapper{fw: fw, tw: tw, p: p} } -func compressContextTakeover(w io.WriteCloser, level int) io.WriteCloser { - // p := &flateWriterDictPools[level-minCompressionLevel] - // tw := &truncWriter{w: w} - // fw, _ := p.Get().(*flate.Writer) - // if fw == nil { - // fw, _ = flate.NewWriterDict(tw, level, nil) - // } else { - // fw.Reset(tw) - // } - return &flateWriteWrapper{} -} - // truncWriter is an io.Writer that writes all but the last four bytes of the // stream to another io.Writer. type truncWriter struct { @@ -91,9 +78,6 @@ type truncWriter struct { func (w *truncWriter) Write(p []byte) (int, error) { n := 0 - fmt.Printf("\x1b[32m Start truncWriter.Write %#v \x1b[0m\n", p) - fmt.Printf("\x1b[32m truncWriter w.n -> len %#v \x1b[0m\n", w.n) - fmt.Printf("\x1b[32m truncWriter w.p %#v \x1b[0m\n", w.p) // fill buffer first for simplicity. if w.n < len(w.p) { @@ -110,18 +94,13 @@ func (w *truncWriter) Write(p []byte) (int, error) { m = len(w.p) } - fmt.Printf("\x1b[32m Write will truncWriter.Write %#v \x1b[0m\n", w.p[:m]) - if nn, err := w.w.Write(w.p[:m]); err != nil { - fmt.Printf("\x1b[32m w.w.Write Error truncWriter.Write %#v \x1b[0m\n", err) return n + nn, err } copy(w.p[:], w.p[m:]) copy(w.p[len(w.p)-m:], p[len(p)-m:]) nn, err := w.w.Write(p[:len(p)-m]) - fmt.Printf("\x1b[32m End truncWriter.Write %#v \x1b[0m\n", p) - fmt.Printf("\x1b[32m End truncWriter w.p %#v \x1b[0m\n", w.p) return n + nn, err } @@ -138,8 +117,6 @@ func (w *flateWriteWrapper) Write(p []byte) (int, error) { return 0, errWriteClosed } - fmt.Printf("flateWriteWrapper will Write %#v \n", p) - return w.fw.Write(p) } @@ -149,9 +126,6 @@ func (w *flateWriteWrapper) Close() error { } err1 := w.fw.Flush() - fmt.Printf("w.tw.n -> -> %#v \n", w.tw.n) - fmt.Printf("w.tw.p -> -> %#v \n", w.tw.p) - if !w.isDictWriter { w.p.Put(w.fw) w.fw = nil @@ -171,8 +145,6 @@ func (w *flateWriteWrapper) Close() error { return err1 } - fmt.Printf("err2 %#v \n", err2) - return err2 } @@ -229,3 +201,44 @@ func (r *flateReadWrapper) addDict(b []byte) { *r.dict = (*r.dict)[offset:] } } + +type ( + contextTakeoverWriterFactory struct { + fw *flate.Writer + tw truncWriter + } + + flateTakeoverWriteWrapper struct { + f *contextTakeoverWriterFactory + } +) + +func (f *contextTakeoverWriterFactory) newCompressionWriter(w io.WriteCloser, level int) io.WriteCloser { + f.tw.w = w + f.tw.n = 0 + return &flateTakeoverWriteWrapper{f} +} + +func (w *flateTakeoverWriteWrapper) Write(p []byte) (int, error) { + if w.f == nil { + return 0, errWriteClosed + } + return w.f.fw.Write(p) +} + +func (w *flateTakeoverWriteWrapper) Close() error { + if w.f == nil { + return errWriteClosed + } + f := w.f + w.f = nil + err1 := f.fw.Flush() + if f.tw.p != [4]byte{0, 0, 0xff, 0xff} { + return errors.New("websocket: internal error, unexpected bytes at end of flate stream") + } + err2 := f.tw.w.Close() + if err1 != nil { + return err1 + } + return err2 +} diff --git a/compression_test.go b/compression_test.go index 5a479e3c..1b9218b4 100644 --- a/compression_test.go +++ b/compression_test.go @@ -2,6 +2,7 @@ package websocket import ( "bytes" + "compress/flate" "fmt" "io" "io/ioutil" @@ -71,7 +72,9 @@ func BenchmarkWriteWithCompressionOfContextTakeover(b *testing.B) { messages := textMessages(100) c.enableWriteCompression = true c.contextTakeover = true - c.newCompressionWriter = compressContextTakeover + var f contextTakeoverWriterFactory + f.fw, _ = flate.NewWriter(&f.tw, 2) // level is specified in Dialer, Upgrader + c.newCompressionWriter = f.newCompressionWriter b.ResetTimer() for i := 0; i < b.N; i++ { c.WriteMessage(TextMessage, messages[i%len(messages)]) diff --git a/conn.go b/conn.go index 72423d6d..2f112941 100644 --- a/conn.go +++ b/conn.go @@ -6,7 +6,6 @@ package websocket import ( "bufio" - "compress/flate" "encoding/binary" "errors" "io" @@ -346,6 +345,7 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in rxDict: &[]byte{}, } + c.SetCloseHandler(nil) c.SetPingHandler(nil) c.SetPongHandler(nil) @@ -517,23 +517,8 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { mw.compress = true // For context-takeover if c.contextTakeover { - if fww, ok := c.compressionWriters[messageType]; ok { - // tw := &truncWriter{w: c.writer} - - //Todo reset trunkwriter inside flate.Writer. - - // fw, _ := flate.NewWriterDict(tw, c.compressionLevel, []byte("Hello")) - // fww.fw.Reset(tw) - // fww.fw = fw - fww.tw.w = c.writer - return fww, nil - } else { - tw := &truncWriter{w: c.writer} - fw, _ := flate.NewWriterDict(tw, c.compressionLevel, nil) - fww := &flateWriteWrapper{fw: fw, tw: tw, isDictWriter: true} - c.compressionWriters[messageType] = fww - return fww, nil - } + c.writer = c.newCompressionWriter(c.writer, c.compressionLevel) + return c.writer, nil } c.writer = c.newCompressionWriter(c.writer, c.compressionLevel) diff --git a/server.go b/server.go index 67c82b7d..9296e602 100644 --- a/server.go +++ b/server.go @@ -6,6 +6,7 @@ package websocket import ( "bufio" + "compress/flate" "errors" "net" "net/http" @@ -187,7 +188,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade switch { case contextTakeover: c.contextTakeover = contextTakeover - c.newCompressionWriter = compressContextTakeover + + var f contextTakeoverWriterFactory + f.fw, _ = flate.NewWriter(&f.tw, 2) // level is specified in Dialer, Upgrader + c.newCompressionWriter = f.newCompressionWriter + c.newDecompressionReader = decompressContextTakeover default: c.newCompressionWriter = compressNoContextTakeover From 3452ab8a615d9d185c13154e046cc04ba94a6620 Mon Sep 17 00:00:00 2001 From: misu Date: Thu, 1 Feb 2018 14:57:00 +0900 Subject: [PATCH 24/52] mod: client & server setting for context-takeover --- client.go | 16 ++++++++++------ conn.go | 11 ----------- server.go | 14 +++++++++++--- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/client.go b/client.go index 3890298d..f927a324 100644 --- a/client.go +++ b/client.go @@ -42,6 +42,7 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS NetDial: func(net, addr string) (net.Conn, error) { return netConn, nil }, + CompressionLevel: defaultCompressionLevel, } return d.Dial(u.String(), requestHeader) } @@ -79,15 +80,18 @@ type Dialer struct { // takeover" modes are supported. EnableCompression bool - // EnableContextTakeover specifies specifies if the client should attempt to negotiate - // per message compression with context-takeover (RFC 7692). - // but window bits is allowed only 15, because go's flate library support 15 bits only. - EnableContextTakeover bool - // Jar specifies the cookie jar. // If Jar is nil, cookies are not sent in requests and ignored // in responses. Jar http.CookieJar + + // CompressionLeval is set for contextTakeoer. + CompressionLevel int + + // EnableContextTakeover specifies specifies if the client should attempt to negotiate + // per message compression with context-takeover (RFC 7692). + // but window bits is allowed only 15, because go's flate library support 15 bits only. + EnableContextTakeover bool } var errMalformedURL = errors.New("malformed ws or wss URL") @@ -325,7 +329,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re conn.contextTakeover = true var f contextTakeoverWriterFactory - f.fw, _ = flate.NewWriter(&f.tw, 2) // level is specified in Dialer, Upgrader + f.fw, _ = flate.NewWriter(&f.tw, d.CompressionLevel) // level is specified in Dialer, Upgrader conn.newCompressionWriter = f.newCompressionWriter conn.newDecompressionReader = decompressContextTakeover diff --git a/conn.go b/conn.go index 2f112941..38451ec1 100644 --- a/conn.go +++ b/conn.go @@ -246,7 +246,6 @@ type Conn struct { enableWriteCompression bool compressionLevel int newCompressionWriter func(io.WriteCloser, int) io.WriteCloser - compressionWriters map[int]*flateWriteWrapper // Read fields reader io.ReadCloser // the current reader returned to the application @@ -329,8 +328,6 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize) } - cw := make(map[int]*flateWriteWrapper, 2) - c := &Conn{ isServer: isServer, br: br, @@ -341,8 +338,6 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in enableWriteCompression: true, compressionLevel: defaultCompressionLevel, - compressionWriters: cw, - rxDict: &[]byte{}, } @@ -515,12 +510,6 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { c.writer = mw if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { mw.compress = true - // For context-takeover - if c.contextTakeover { - c.writer = c.newCompressionWriter(c.writer, c.compressionLevel) - return c.writer, nil - } - c.writer = c.newCompressionWriter(c.writer, c.compressionLevel) } return c.writer, nil diff --git a/server.go b/server.go index 9296e602..f7e73591 100644 --- a/server.go +++ b/server.go @@ -54,6 +54,14 @@ type Upgrader struct { // guarantee that compression will be supported. Currently only "no context // takeover" modes are supported. EnableCompression bool + + // CompressionLeval is set for contextTakeoer. + CompressionLevel int + + // EnableContextTakeover specifies specifies if the client should attempt to negotiate + // per message compression with context-takeover (RFC 7692). + // but window bits is allowed only 15, because go's flate library support 15 bits only. + EnableContextTakeover bool } func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { @@ -186,11 +194,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade if compress { switch { - case contextTakeover: + case contextTakeover && u.EnableContextTakeover: c.contextTakeover = contextTakeover var f contextTakeoverWriterFactory - f.fw, _ = flate.NewWriter(&f.tw, 2) // level is specified in Dialer, Upgrader + f.fw, _ = flate.NewWriter(&f.tw, u.CompressionLevel) // level is specified in Dialer, Upgrader c.newCompressionWriter = f.newCompressionWriter c.newDecompressionReader = decompressContextTakeover @@ -211,7 +219,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade } if compress { switch { - case contextTakeover: + case contextTakeover && u.EnableContextTakeover: p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_max_window_bits=15; client_max_window_bits=15\r\n"...) default: p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) From b9e2343c7d1106c3d35125ea97b99f2e757f4f99 Mon Sep 17 00:00:00 2001 From: misu Date: Thu, 1 Feb 2018 15:05:47 +0900 Subject: [PATCH 25/52] mod: set compressionLevel to conn, if EnabeleCompression --- client.go | 4 ++++ server.go | 2 ++ 2 files changed, 6 insertions(+) diff --git a/client.go b/client.go index f927a324..e84a6706 100644 --- a/client.go +++ b/client.go @@ -288,6 +288,10 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize) + if d.EnableCompression { + conn.compressionLevel = d.CompressionLevel + } + if err := req.Write(netConn); err != nil { return nil, nil, err } diff --git a/server.go b/server.go index f7e73591..1f41b610 100644 --- a/server.go +++ b/server.go @@ -193,6 +193,8 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade c.subprotocol = subprotocol if compress { + c.compressionLevel = u.CompressionLevel + switch { case contextTakeover && u.EnableContextTakeover: c.contextTakeover = contextTakeover From e7575e215d53e5779a95067b54f6ec237b4b3872 Mon Sep 17 00:00:00 2001 From: misu Date: Thu, 1 Feb 2018 18:26:02 +0900 Subject: [PATCH 26/52] upgrade: compressContextTakeover reader --- client.go | 7 ++- compression.go | 117 +++++++++++++++++++++++--------------------- compression_test.go | 9 ++-- conn.go | 19 +++---- server.go | 13 +++-- 5 files changed, 91 insertions(+), 74 deletions(-) diff --git a/client.go b/client.go index e84a6706..2b8e04ed 100644 --- a/client.go +++ b/client.go @@ -333,10 +333,13 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re conn.contextTakeover = true var f contextTakeoverWriterFactory - f.fw, _ = flate.NewWriter(&f.tw, d.CompressionLevel) // level is specified in Dialer, Upgrader + f.fw, _ = flate.NewWriter(&f.tw, d.CompressionLevel) conn.newCompressionWriter = f.newCompressionWriter - conn.newDecompressionReader = decompressContextTakeover + var frf contextTakeoverReaderFactory + fr := flate.NewReader(nil) + frf.fr = fr + conn.newDecompressionReader = frf.newDeCompressionReader default: conn.newCompressionWriter = compressNoContextTakeover conn.newDecompressionReader = decompressNoContextTakeover diff --git a/compression.go b/compression.go index c9067c4d..5a143910 100644 --- a/compression.go +++ b/compression.go @@ -17,41 +17,27 @@ const ( minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 maxCompressionLevel = flate.BestCompression defaultCompressionLevel = 1 + + tail = + // Add four bytes as specified in RFC + "\x00\x00\xff\xff" + + // Add final block to squelch unexpected EOF error from flate reader. + "\x01\x00\x00\xff\xff" ) var ( - flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool - flateWriterDictPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool - flateReaderPool = sync.Pool{New: func() interface{} { + flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool + flateReaderPool = sync.Pool{New: func() interface{} { return flate.NewReader(nil) }} ) -func decompressNoContextTakeover(r io.Reader, dict *[]byte) io.ReadCloser { - const tail = - // Add four bytes as specified in RFC - "\x00\x00\xff\xff" + - // Add final block to squelch unexpected EOF error from flate reader. - "\x01\x00\x00\xff\xff" - +func decompressNoContextTakeover(r io.Reader) io.ReadCloser { fr, _ := flateReaderPool.Get().(io.ReadCloser) fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) return &flateReadWrapper{fr: fr} } -func decompressContextTakeover(r io.Reader, dict *[]byte) io.ReadCloser { - const tail = - // Add four bytes as specified in RFC - "\x00\x00\xff\xff" + - // Add final block to squelch unexpected EOF error from flate reader. - "\x01\x00\x00\xff\xff" - - fr, _ := flateReaderPool.Get().(io.ReadCloser) - fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), *dict) - - return &flateReadWrapper{fr: fr, hasDict: true, dict: dict} -} - func isValidCompressionLevel(level int) bool { return minCompressionLevel <= level && level <= maxCompressionLevel } @@ -108,8 +94,6 @@ type flateWriteWrapper struct { fw *flate.Writer tw *truncWriter p *sync.Pool - - isDictWriter bool } func (w *flateWriteWrapper) Write(p []byte) (int, error) { @@ -126,19 +110,15 @@ func (w *flateWriteWrapper) Close() error { } err1 := w.fw.Flush() - if !w.isDictWriter { - w.p.Put(w.fw) - w.fw = nil - } + w.p.Put(w.fw) + w.fw = nil if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { return errors.New("websocket: internal error, unexpected bytes at end of flate stream") } - if !w.isDictWriter { - w.tw.p = [4]byte{} - w.tw.n = 0 - } + w.tw.p = [4]byte{} + w.tw.n = 0 err2 := w.tw.w.Close() if err1 != nil { @@ -150,9 +130,6 @@ func (w *flateWriteWrapper) Close() error { type flateReadWrapper struct { fr io.ReadCloser // flate.NewReader - - hasDict bool - dict *[]byte } func (r *flateReadWrapper) Read(p []byte) (int, error) { @@ -169,12 +146,6 @@ func (r *flateReadWrapper) Read(p []byte) (int, error) { r.Close() } - if r.hasDict { - if n > 0 { - r.addDict(p[:n]) - } - } - return n, err } @@ -184,24 +155,12 @@ func (r *flateReadWrapper) Close() error { } err := r.fr.Close() - if !r.hasDict { - flateReaderPool.Put(r.fr) - } + flateReaderPool.Put(r.fr) r.fr = nil return err } -// addDict adds payload to dict. -func (r *flateReadWrapper) addDict(b []byte) { - *r.dict = append(*r.dict, b...) - - if len(*r.dict) > maxWindowBits { - offset := len(*r.dict) - maxWindowBits - *r.dict = (*r.dict)[offset:] - } -} - type ( contextTakeoverWriterFactory struct { fw *flate.Writer @@ -242,3 +201,51 @@ func (w *flateTakeoverWriteWrapper) Close() error { } return err2 } + +type ( + contextTakeoverReaderFactory struct { + fr io.ReadCloser + window []byte + } + + flateTakeoverReadWrapper struct { + f *contextTakeoverReaderFactory + } +) + +func (f *contextTakeoverReaderFactory) newDeCompressionReader(r io.Reader) io.ReadCloser { + f.fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), f.window) + return &flateTakeoverReadWrapper{f} +} + +func (r *flateTakeoverReadWrapper) Read(p []byte) (int, error) { + if r.f.fr == nil { + return 0, io.ErrClosedPipe + } + + n, err := r.f.fr.Read(p) + + // add dictionary + r.f.window = append(r.f.window, p[:n]...) + if len(r.f.window) > maxWindowBits { + offset := len(r.f.window) - maxWindowBits + r.f.window = r.f.window[offset:] + } + + if err == io.EOF { + // Preemptively place the reader back in the pool. This helps with + // scenarios where the application does not call NextReader() soon after + // this final read. + r.Close() + } + + return n, err +} + +func (r *flateTakeoverReadWrapper) Close() error { + if r.f.fr == nil { + return io.ErrClosedPipe + } + err := r.f.fr.Close() + return err +} diff --git a/compression_test.go b/compression_test.go index 1b9218b4..dedc3864 100644 --- a/compression_test.go +++ b/compression_test.go @@ -91,7 +91,7 @@ func BenchmarkReadWithCompression(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { r := bytes.NewReader(messages[i%len(messages)]) - reader := c.newDecompressionReader(r, nil) + reader := c.newDecompressionReader(r) ioutil.ReadAll(reader) } b.ReportAllocs() @@ -102,12 +102,15 @@ func BenchmarkReadWithCompressionOfContextTakeover(b *testing.B) { c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) c.enableWriteCompression = true c.contextTakeover = true - c.newDecompressionReader = decompressContextTakeover + var frf contextTakeoverReaderFactory + fr := flate.NewReader(nil) + frf.fr = fr + c.newDecompressionReader = frf.newDeCompressionReader messages := textMessages(100) b.ResetTimer() for i := 0; i < b.N; i++ { r := bytes.NewReader(messages[i%len(messages)]) - reader := c.newDecompressionReader(r, c.rxDict) + reader := c.newDecompressionReader(r) ioutil.ReadAll(reader) } b.ReportAllocs() diff --git a/conn.go b/conn.go index 38451ec1..dc6b081a 100644 --- a/conn.go +++ b/conn.go @@ -263,8 +263,8 @@ type Conn struct { readErrCount int messageReader *messageReader // the current low-level reader - readDecompress bool // whether last read frame had RSV1 set - newDecompressionReader func(io.Reader, *[]byte) io.ReadCloser // arges may flateReadWrapper struct + readDecompress bool // whether last read frame had RSV1 set + newDecompressionReader func(io.Reader) io.ReadCloser // arges may flateReadWrapper struct contextTakeover bool rxDict *[]byte @@ -955,13 +955,14 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { if frameType == TextMessage || frameType == BinaryMessage { c.messageReader = &messageReader{c} c.reader = c.messageReader - - switch { - case c.readDecompress && c.contextTakeover: - c.reader = c.newDecompressionReader(c.reader, c.rxDict) - case c.readDecompress: - c.reader = c.newDecompressionReader(c.reader, nil) - } + c.reader = c.newDecompressionReader(c.reader) + + // switch { + // case c.readDecompress && c.contextTakeover: + // c.reader = c.newDecompressionReader(c.reader, c.rxDict) + // case c.readDecompress: + // c.reader = c.newDecompressionReader(c.reader, nil) + // } return frameType, c.reader, nil } diff --git a/server.go b/server.go index 1f41b610..e006083e 100644 --- a/server.go +++ b/server.go @@ -199,11 +199,14 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade case contextTakeover && u.EnableContextTakeover: c.contextTakeover = contextTakeover - var f contextTakeoverWriterFactory - f.fw, _ = flate.NewWriter(&f.tw, u.CompressionLevel) // level is specified in Dialer, Upgrader - c.newCompressionWriter = f.newCompressionWriter - - c.newDecompressionReader = decompressContextTakeover + var fwf contextTakeoverWriterFactory + fwf.fw, _ = flate.NewWriter(&fwf.tw, u.CompressionLevel) + c.newCompressionWriter = fwf.newCompressionWriter + + var frf contextTakeoverReaderFactory + fr := flate.NewReader(nil) + frf.fr = fr + c.newDecompressionReader = frf.newDeCompressionReader default: c.newCompressionWriter = compressNoContextTakeover c.newDecompressionReader = decompressNoContextTakeover From 83a133897cc923c6b1d20f7357f34f55334d309c Mon Sep 17 00:00:00 2001 From: misu Date: Thu, 1 Feb 2018 18:34:03 +0900 Subject: [PATCH 27/52] mod: NextReader if compress --- client_server_test.go | 2 +- conn.go | 11 +++-------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/client_server_test.go b/client_server_test.go index 97631eb3..266db508 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -575,7 +575,7 @@ func TestDialCompressionOfContextTakeover(t *testing.T) { defer ws.Close() // Todo multiple send and receive. - multipleSendRecv(t, ws) + sendRecv(t, ws) } func TestSocksProxyDial(t *testing.T) { diff --git a/conn.go b/conn.go index dc6b081a..ed7db85f 100644 --- a/conn.go +++ b/conn.go @@ -955,14 +955,9 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { if frameType == TextMessage || frameType == BinaryMessage { c.messageReader = &messageReader{c} c.reader = c.messageReader - c.reader = c.newDecompressionReader(c.reader) - - // switch { - // case c.readDecompress && c.contextTakeover: - // c.reader = c.newDecompressionReader(c.reader, c.rxDict) - // case c.readDecompress: - // c.reader = c.newDecompressionReader(c.reader, nil) - // } + if c.readDecompress { + c.reader = c.newDecompressionReader(c.reader) + } return frameType, c.reader, nil } From 7c6832cc73378653343eee89960d534de93dac1b Mon Sep 17 00:00:00 2001 From: misu Date: Thu, 1 Feb 2018 18:38:15 +0900 Subject: [PATCH 28/52] mod: NextWriter if compress --- conn.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index ed7db85f..8458fbfe 100644 --- a/conn.go +++ b/conn.go @@ -509,8 +509,9 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { } c.writer = mw if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { + w := c.newCompressionWriter(c.writer, c.compressionLevel) mw.compress = true - c.writer = c.newCompressionWriter(c.writer, c.compressionLevel) + c.writer = w } return c.writer, nil } From 8e146c3b73ab5ef2c9b9583fb512c1080925a864 Mon Sep 17 00:00:00 2001 From: misu Date: Thu, 1 Feb 2018 19:04:12 +0900 Subject: [PATCH 29/52] mod: remove unnecessary field --- compression.go | 2 +- conn.go | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/compression.go b/compression.go index 5a143910..98a9b2a3 100644 --- a/compression.go +++ b/compression.go @@ -129,7 +129,7 @@ func (w *flateWriteWrapper) Close() error { } type flateReadWrapper struct { - fr io.ReadCloser // flate.NewReader + fr io.ReadCloser } func (r *flateReadWrapper) Read(p []byte) (int, error) { diff --git a/conn.go b/conn.go index 8458fbfe..f6d64912 100644 --- a/conn.go +++ b/conn.go @@ -267,7 +267,6 @@ type Conn struct { newDecompressionReader func(io.Reader) io.ReadCloser // arges may flateReadWrapper struct contextTakeover bool - rxDict *[]byte } func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { @@ -337,8 +336,6 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in writeBuf: writeBuf, enableWriteCompression: true, compressionLevel: defaultCompressionLevel, - - rxDict: &[]byte{}, } c.SetCloseHandler(nil) From c83088956fd8252de291888a9cda2a9b032f2282 Mon Sep 17 00:00:00 2001 From: misu Date: Fri, 2 Feb 2018 12:15:47 +0900 Subject: [PATCH 30/52] upgrade: TestFraming --- client.go | 16 +++++++++------- compression.go | 2 +- conn_test.go | 39 ++++++++++++++++++++++++++++++++++++--- server.go | 16 ++++++++++------ 4 files changed, 56 insertions(+), 17 deletions(-) diff --git a/client.go b/client.go index 2b8e04ed..1f630965 100644 --- a/client.go +++ b/client.go @@ -42,7 +42,6 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS NetDial: func(net, addr string) (net.Conn, error) { return netConn, nil }, - CompressionLevel: defaultCompressionLevel, } return d.Dial(u.String(), requestHeader) } @@ -289,6 +288,9 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize) if d.EnableCompression { + if !isValidCompressionLevel(d.CompressionLevel) { + return nil, nil, errors.New("websocket: invalid compression level") + } conn.compressionLevel = d.CompressionLevel } @@ -332,14 +334,14 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re case cmwb && smwb: conn.contextTakeover = true - var f contextTakeoverWriterFactory - f.fw, _ = flate.NewWriter(&f.tw, d.CompressionLevel) - conn.newCompressionWriter = f.newCompressionWriter + var wf contextTakeoverWriterFactory + wf.fw, _ = flate.NewWriter(&wf.tw, d.CompressionLevel) + conn.newCompressionWriter = wf.newCompressionWriter - var frf contextTakeoverReaderFactory + var rf contextTakeoverReaderFactory fr := flate.NewReader(nil) - frf.fr = fr - conn.newDecompressionReader = frf.newDeCompressionReader + rf.fr = fr + conn.newDecompressionReader = rf.newDeCompressionReader default: conn.newCompressionWriter = compressNoContextTakeover conn.newDecompressionReader = decompressNoContextTakeover diff --git a/compression.go b/compression.go index 98a9b2a3..56ee05a4 100644 --- a/compression.go +++ b/compression.go @@ -225,7 +225,7 @@ func (r *flateTakeoverReadWrapper) Read(p []byte) (int, error) { n, err := r.f.fr.Read(p) - // add dictionary + // add window r.f.window = append(r.f.window, p[:n]...) if len(r.f.window) > maxWindowBits { offset := len(r.f.window) - maxWindowBits diff --git a/conn_test.go b/conn_test.go index 5fda7b5c..7542a3ae 100644 --- a/conn_test.go +++ b/conn_test.go @@ -7,6 +7,7 @@ package websocket import ( "bufio" "bytes" + "compress/flate" "errors" "fmt" "io" @@ -77,20 +78,52 @@ func TestFraming(t *testing.T) { }}, } - for _, compress := range []bool{false, true} { + compressConditions := []struct { + compress bool + contextTakeover bool + }{ + { + compress: false, + contextTakeover: false, + }, + { + compress: true, + contextTakeover: false, + }, + { + compress: true, + contextTakeover: true, + }, + } + + for _, compressCondition := range compressConditions { for _, isServer := range []bool{true, false} { for _, chunker := range readChunkers { var connBuf bytes.Buffer wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024) - if compress { + switch { + case compressCondition.compress && compressCondition.contextTakeover: + + var wf contextTakeoverWriterFactory + wf.fw, _ = flate.NewWriter(&wf.tw, defaultCompressionLevel) + wc.newCompressionWriter = wf.newCompressionWriter + wc.contextTakeover = true + + var rf contextTakeoverReaderFactory + fr := flate.NewReader(nil) + rf.fr = fr + rc.newDecompressionReader = rf.newDeCompressionReader + + rc.contextTakeover = true + case compressCondition.compress: wc.newCompressionWriter = compressNoContextTakeover rc.newDecompressionReader = decompressNoContextTakeover } for _, n := range frameSizes { for _, writer := range writers { - name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name) + name := fmt.Sprintf("z:%v, c:%v, s:%v, r:%s, n:%d w:%s", compressCondition.compress, compressCondition.contextTakeover, isServer, chunker.name, n, writer.name) w, err := wc.NextWriter(TextMessage) if err != nil { diff --git a/server.go b/server.go index e006083e..92630bec 100644 --- a/server.go +++ b/server.go @@ -193,20 +193,24 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade c.subprotocol = subprotocol if compress { + if !isValidCompressionLevel(u.CompressionLevel) { + return nil, errors.New("websocket: invalid compression level") + } + c.compressionLevel = u.CompressionLevel switch { case contextTakeover && u.EnableContextTakeover: c.contextTakeover = contextTakeover - var fwf contextTakeoverWriterFactory - fwf.fw, _ = flate.NewWriter(&fwf.tw, u.CompressionLevel) - c.newCompressionWriter = fwf.newCompressionWriter + var wf contextTakeoverWriterFactory + wf.fw, _ = flate.NewWriter(&wf.tw, u.CompressionLevel) + c.newCompressionWriter = wf.newCompressionWriter - var frf contextTakeoverReaderFactory + var rf contextTakeoverReaderFactory fr := flate.NewReader(nil) - frf.fr = fr - c.newDecompressionReader = frf.newDeCompressionReader + rf.fr = fr + c.newDecompressionReader = rf.newDeCompressionReader default: c.newCompressionWriter = compressNoContextTakeover c.newDecompressionReader = decompressNoContextTakeover From e79bb70823f54af66a3db6c04b30a2610b3adae9 Mon Sep 17 00:00:00 2001 From: misu Date: Fri, 2 Feb 2018 13:48:44 +0900 Subject: [PATCH 31/52] upgrade: add context-takeover test to client_server_test --- client_server_test.go | 100 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 92 insertions(+), 8 deletions(-) diff --git a/client_server_test.go b/client_server_test.go index 266db508..995d737a 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -42,6 +42,8 @@ var cstDialer = Dialer{ type cstHandler struct{ *testing.T } +type cstContextTakeoverHandler struct{ *testing.T } + type cstServer struct { *httptest.Server URL string @@ -61,6 +63,14 @@ func newServer(t *testing.T) *cstServer { return &s } +func newContextTakeoverServer(t *testing.T) *cstServer { + var s cstServer + s.Server = httptest.NewServer(cstContextTakeoverHandler{t}) + s.Server.URL += cstRequestURI + s.URL = makeWsProto(s.Server.URL) + return &s +} + func newTLSServer(t *testing.T) *cstServer { var s cstServer s.Server = httptest.NewTLSServer(cstHandler{t}) @@ -118,6 +128,80 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } +func (t cstContextTakeoverHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != cstPath { + t.Logf("path=%v, want %v", r.URL.Path, cstPath) + http.Error(w, "bad path", 400) + return + } + if r.URL.RawQuery != cstRawQuery { + t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery) + http.Error(w, "bad path", 400) + return + } + subprotos := Subprotocols(r) + if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) { + t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols) + http.Error(w, "bad protocol", 400) + return + } + cu := cstUpgrader + cu.CompressionLevel = defaultCompressionLevel + cu.EnableContextTakeover = true + ws, err := cu.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}) + if err != nil { + t.Logf("Upgrade: %v", err) + return + } + defer ws.Close() + + if ws.Subprotocol() != "p1" { + t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol()) + ws.Close() + return + } + + // first message + op, rd, err := ws.NextReader() + if err != nil { + t.Logf("NextReader: %v", err) + return + } + wr, err := ws.NextWriter(op) + if err != nil { + t.Logf("NextWriter: %v", err) + return + } + if _, err = io.Copy(wr, rd); err != nil { + t.Logf("NextWriter: %v", err) + return + } + if err := wr.Close(); err != nil { + t.Logf("Close: %v", err) + return + } + + // second message + op, rd, err = ws.NextReader() + if err != nil { + t.Logf("NextReader: %v", err) + return + } + wr, err = ws.NextWriter(op) + if err != nil { + t.Logf("NextWriter: %v", err) + return + } + if _, err = io.Copy(wr, rd); err != nil { + t.Logf("NextWriter: %v", err) + return + } + if err := wr.Close(); err != nil { + t.Logf("Close: %v", err) + return + } +} + func makeWsProto(s string) string { return "ws" + strings.TrimPrefix(s, "http") } @@ -161,11 +245,11 @@ func multipleSendRecv(t *testing.T, ws *Conn) { t.Fatalf("message=%s, want %s", p, message) } - message_2 := "Can you read message?" + nextMessage := "Can you read message?" if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil { t.Fatalf("SetWriteDeadline: %v", err) } - if err := ws.WriteMessage(TextMessage, []byte(message_2)); err != nil { + if err := ws.WriteMessage(TextMessage, []byte(nextMessage)); err != nil { t.Fatalf("_WriteMessage: %v", err) } if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil { @@ -174,10 +258,10 @@ func multipleSendRecv(t *testing.T, ws *Conn) { _, p, err = ws.ReadMessage() if err != nil { - t.Fatalf("_ReadMessage: %v", err) // _ReadMessage: websocket: close 1006 (abnormal closure): unexpected EOF + t.Fatalf("_ReadMessage: %v", err) } - if string(p) != message { - t.Fatalf("_message=%s, want %s", p, message_2) + if string(p) != nextMessage { + t.Fatalf("_message=%s, want %s", p, nextMessage) } } @@ -562,20 +646,20 @@ func TestDialCompression(t *testing.T) { } func TestDialCompressionOfContextTakeover(t *testing.T) { - s := newServer(t) + s := newContextTakeoverServer(t) defer s.Close() dialer := cstDialer dialer.EnableCompression = true dialer.EnableContextTakeover = true + dialer.CompressionLevel = 2 ws, _, err := dialer.Dial(s.URL, nil) if err != nil { t.Fatalf("Dial: %v", err) } defer ws.Close() - // Todo multiple send and receive. - sendRecv(t, ws) + multipleSendRecv(t, ws) } func TestSocksProxyDial(t *testing.T) { From 4131adc64e706bb210b74a6b6f32a5ab0590a693 Mon Sep 17 00:00:00 2001 From: misu Date: Fri, 2 Feb 2018 14:30:07 +0900 Subject: [PATCH 32/52] upgrade: remove readBench, it read blank data --- compression_test.go | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/compression_test.go b/compression_test.go index dedc3864..0af6ed84 100644 --- a/compression_test.go +++ b/compression_test.go @@ -82,40 +82,6 @@ func BenchmarkWriteWithCompressionOfContextTakeover(b *testing.B) { b.ReportAllocs() } -func BenchmarkReadWithCompression(b *testing.B) { - w := ioutil.Discard - c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) - c.enableWriteCompression = true - c.newDecompressionReader = decompressNoContextTakeover - messages := textMessages(100) - b.ResetTimer() - for i := 0; i < b.N; i++ { - r := bytes.NewReader(messages[i%len(messages)]) - reader := c.newDecompressionReader(r) - ioutil.ReadAll(reader) - } - b.ReportAllocs() -} - -func BenchmarkReadWithCompressionOfContextTakeover(b *testing.B) { - w := ioutil.Discard - c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) - c.enableWriteCompression = true - c.contextTakeover = true - var frf contextTakeoverReaderFactory - fr := flate.NewReader(nil) - frf.fr = fr - c.newDecompressionReader = frf.newDeCompressionReader - messages := textMessages(100) - b.ResetTimer() - for i := 0; i < b.N; i++ { - r := bytes.NewReader(messages[i%len(messages)]) - reader := c.newDecompressionReader(r) - ioutil.ReadAll(reader) - } - b.ReportAllocs() -} - func TestValidCompressionLevel(t *testing.T) { c := newConn(fakeNetConn{}, false, 1024, 1024) for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { From df2a0ccfb120c9e15e469ceb2aa4594fa754e803 Mon Sep 17 00:00:00 2001 From: misu Date: Wed, 7 Feb 2018 11:02:34 +0900 Subject: [PATCH 33/52] mod: comment --- client.go | 5 ++--- conn.go | 2 ++ server.go | 5 ++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index 1f630965..e1b1e80b 100644 --- a/client.go +++ b/client.go @@ -75,8 +75,7 @@ type Dialer struct { // EnableCompression specifies if the client should attempt to negotiate // per message compression (RFC 7692). Setting this value to true does not - // guarantee that compression will be supported. Currently only "no context - // takeover" modes are supported. + // guarantee that compression will be supported. EnableCompression bool // Jar specifies the cookie jar. @@ -84,7 +83,7 @@ type Dialer struct { // in responses. Jar http.CookieJar - // CompressionLeval is set for contextTakeoer. + // CompressionLevel is passed to conn when the compression setting is true. CompressionLevel int // EnableContextTakeover specifies specifies if the client should attempt to negotiate diff --git a/conn.go b/conn.go index 025d284f..ed4e96d8 100644 --- a/conn.go +++ b/conn.go @@ -1147,6 +1147,8 @@ func (c *Conn) EnableWriteCompression(enable bool) { // binary messages. This function is a noop if compression was not negotiated // with the peer. See the compress/flate package for a description of // compression levels. +// If you use context-takeover, do not specify a compression level from this method. +// Please set it to Dialer or Upgrader in advance. func (c *Conn) SetCompressionLevel(level int) error { if !isValidCompressionLevel(level) { return errors.New("websocket: invalid compression level") diff --git a/server.go b/server.go index 109220af..12c8e1b8 100644 --- a/server.go +++ b/server.go @@ -55,11 +55,10 @@ type Upgrader struct { // EnableCompression specify if the server should attempt to negotiate per // message compression (RFC 7692). Setting this value to true does not - // guarantee that compression will be supported. Currently only "no context - // takeover" modes are supported. + // guarantee that compression will be supported. EnableCompression bool - // CompressionLeval is set for contextTakeoer. + // CompressionLevel is passed to conn when the compression setting is true. CompressionLevel int // EnableContextTakeover specifies specifies if the client should attempt to negotiate From 6282a7bdcf03a544826014bcb7aa9302ffed8348 Mon Sep 17 00:00:00 2001 From: misu Date: Wed, 7 Feb 2018 11:36:55 +0900 Subject: [PATCH 34/52] mod: doc.go --- doc.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/doc.go b/doc.go index dcce1a63..4469dabd 100644 --- a/doc.go +++ b/doc.go @@ -171,10 +171,21 @@ // // conn.EnableWriteCompression(false) // -// Currently this package does not support compression with "context takeover". +// Currently this package support compression with "context takeover". // This means that messages must be compressed and decompressed in isolation, // without retaining sliding window or dictionary state across messages. For // more details refer to RFC 7692. // +// If you want to use it, please do as follows. +// +// var upgrader = websocket.Upgrader{ +// EnableCompression: true, +// EnableContextTakeover: true, +// CompressionLevel: 2, // default 0 +// } +// +// +// Since compression level is passed to Conn, please do not set it later. +// // Use of compression is experimental and may result in decreased performance. package websocket From e9a52af44b4e7d9da189abc18177757e2114b03d Mon Sep 17 00:00:00 2001 From: misu Date: Fri, 9 Feb 2018 10:13:43 +0900 Subject: [PATCH 35/52] mod: unnecessary format change --- conn.go | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/conn.go b/conn.go index ed4e96d8..1d15cd5b 100644 --- a/conn.go +++ b/conn.go @@ -83,14 +83,20 @@ const ( PongMessage = 10 ) -type ( - // netError satisfies the net Error interface. - netError struct { - msg string - temporary bool - timeout bool - } -) +// ErrCloseSent is returned when the application writes a message to the +// connection after sending a close message. +var ErrCloseSent = errors.New("websocket: close sent") + +// ErrReadLi mit is returned when reading a message that is larger than the +// read limit set for the connection. +var ErrReadLimit = errors.New("websocket: read limit exceeded") + +// netError satisfies the net Error interface. +type netError struct { + msg string + temporary bool + timeout bool +} func (e *netError) Error() string { return e.msg } func (e *netError) Temporary() bool { return e.temporary } @@ -174,14 +180,6 @@ var ( errBadWriteOpCode = errors.New("websocket: bad write message type") errWriteClosed = errors.New("websocket: write closed") errInvalidControlFrame = errors.New("websocket: invalid control frame") - - // ErrCloseSent is returned when the application writes a message to the - // connection after sending a close message. - ErrCloseSent = errors.New("websocket: close sent") - - // ErrReadLimit is returned when reading a message that is larger than the - // read limit set for the connection. - ErrReadLimit = errors.New("websocket: read limit exceeded") ) func newMaskKey() [4]byte { From f68770a434393d5164101862146eed747e307294 Mon Sep 17 00:00:00 2001 From: misu Date: Fri, 9 Feb 2018 10:16:19 +0900 Subject: [PATCH 36/52] mod: unnecessary format change --- conn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conn.go b/conn.go index 1d15cd5b..c3f73fe4 100644 --- a/conn.go +++ b/conn.go @@ -87,7 +87,7 @@ const ( // connection after sending a close message. var ErrCloseSent = errors.New("websocket: close sent") -// ErrReadLi mit is returned when reading a message that is larger than the +// ErrReadLimit is returned when reading a message that is larger than the // read limit set for the connection. var ErrReadLimit = errors.New("websocket: read limit exceeded") From 7fa60f84a1ffe62e16cef987ce9d21a0b3d889a7 Mon Sep 17 00:00:00 2001 From: misu Date: Mon, 5 Mar 2018 10:03:17 +0900 Subject: [PATCH 37/52] mod: gitignore --- .gitignore | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index f0a5301e..ac710204 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,4 @@ _testmain.go *.exe .idea/ -*.iml -.vscode/ -*.test \ No newline at end of file +*.iml \ No newline at end of file From d0111e646e1b6ef4c013b22c40fa476b43dfe1b1 Mon Sep 17 00:00:00 2001 From: claudia-jones <36607057+claudia-jones@users.noreply.github.com> Date: Sun, 18 Feb 2018 16:00:50 -0800 Subject: [PATCH 38/52] Simplify echo example client (#349) Use existing `done` channel to signal that reader is done instead of closing the connection. --- examples/echo/client.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/echo/client.go b/examples/echo/client.go index 6578094e..bf0e6573 100644 --- a/examples/echo/client.go +++ b/examples/echo/client.go @@ -38,7 +38,6 @@ func main() { done := make(chan struct{}) go func() { - defer c.Close() defer close(done) for { _, message, err := c.ReadMessage() @@ -55,6 +54,8 @@ func main() { for { select { + case <-done: + return case t := <-ticker.C: err := c.WriteMessage(websocket.TextMessage, []byte(t.String())) if err != nil { @@ -63,8 +64,9 @@ func main() { } case <-interrupt: log.Println("interrupt") - // To cleanly close a connection, a client should send a close - // frame and wait for the server to close the connection. + + // Cleanly close the connection by sending a close message and then + // waiting (with timeout) for the server to close the connection. err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) if err != nil { log.Println("write close:", err) @@ -74,7 +76,6 @@ func main() { case <-done: case <-time.After(time.Second): } - c.Close() return } } From 5241e533fa4d272867f5b24eb0f449d1f6513db0 Mon Sep 17 00:00:00 2001 From: Alexey Palazhchenko Date: Fri, 16 Feb 2018 00:24:17 +0300 Subject: [PATCH 39/52] Use latest patch releases of Go --- .travis.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.travis.yml b/.travis.yml index 9f233f98..ce7ec75e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,12 +3,12 @@ sudo: false matrix: include: - - go: 1.4 - - go: 1.5 - - go: 1.6 - - go: 1.7 - - go: 1.8 - - go: 1.9 + - go: 1.4.x + - go: 1.5.x + - go: 1.6.x + - go: 1.7.x + - go: 1.8.x + - go: 1.9.x - go: tip allow_failures: - go: tip From 3ab5e92ee8f290c0594ca6e647fe509d30d7bffe Mon Sep 17 00:00:00 2001 From: Gary Burd Date: Sun, 18 Feb 2018 23:06:29 -0800 Subject: [PATCH 40/52] Travis config: add Go 1.10.x, revert 1.4.x to 1.4 1.4.x is missing go vet --- .travis.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index ce7ec75e..1f730470 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,12 +3,13 @@ sudo: false matrix: include: - - go: 1.4.x + - go: 1.4 - go: 1.5.x - go: 1.6.x - go: 1.7.x - go: 1.8.x - go: 1.9.x + - go: 1.10.x - go: tip allow_failures: - go: tip From c13439fe37fb7357fb71770ec0dc34261f10f1aa Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 26 Feb 2018 12:30:15 +0900 Subject: [PATCH 41/52] Modify http status code to variable --- client_server_test.go | 14 +++++++------- examples/autobahn/server.go | 4 ++-- examples/chat/main.go | 4 ++-- examples/command/main.go | 4 ++-- examples/filewatch/main.go | 4 ++-- server.go | 2 +- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/client_server_test.go b/client_server_test.go index 995d737a..42196321 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -82,18 +82,18 @@ func newTLSServer(t *testing.T) *cstServer { func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.URL.Path != cstPath { t.Logf("path=%v, want %v", r.URL.Path, cstPath) - http.Error(w, "bad path", 400) + http.Error(w, "bad path", http.StatusBadRequest) return } if r.URL.RawQuery != cstRawQuery { t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery) - http.Error(w, "bad path", 400) + http.Error(w, "bad path", http.StatusBadRequest) return } subprotos := Subprotocols(r) if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) { t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols) - http.Error(w, "bad protocol", 400) + http.Error(w, "bad protocol", http.StatusBadRequest) return } ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}) @@ -283,13 +283,13 @@ func TestProxyDial(t *testing.T) { func(w http.ResponseWriter, r *http.Request) { if r.Method == "CONNECT" { connect = true - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) return } if !connect { t.Log("connect not received") - http.Error(w, "connect not received", 405) + http.Error(w, "connect not received", http.StatusMethodNotAllowed) return } origHandler.ServeHTTP(w, r) @@ -323,13 +323,13 @@ func TestProxyAuthorizationDial(t *testing.T) { expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password")) if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth { connect = true - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) return } if !connect { t.Log("connect with proxy authorization not received") - http.Error(w, "connect with proxy authorization not received", 405) + http.Error(w, "connect with proxy authorization not received", http.StatusMethodNotAllowed) return } origHandler.ServeHTTP(w, r) diff --git a/examples/autobahn/server.go b/examples/autobahn/server.go index 3db880f9..c2d6ee50 100644 --- a/examples/autobahn/server.go +++ b/examples/autobahn/server.go @@ -157,11 +157,11 @@ func echoReadAllWritePreparedMessage(w http.ResponseWriter, r *http.Request) { func serveHome(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { - http.Error(w, "Not found.", 404) + http.Error(w, "Not found.", http.StatusNotFound) return } if r.Method != "GET" { - http.Error(w, "Method not allowed", 405) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } w.Header().Set("Content-Type", "text/html; charset=utf-8") diff --git a/examples/chat/main.go b/examples/chat/main.go index 74615d59..9d4737a6 100644 --- a/examples/chat/main.go +++ b/examples/chat/main.go @@ -15,11 +15,11 @@ var addr = flag.String("addr", ":8080", "http service address") func serveHome(w http.ResponseWriter, r *http.Request) { log.Println(r.URL) if r.URL.Path != "/" { - http.Error(w, "Not found", 404) + http.Error(w, "Not found", http.StatusNotFound) return } if r.Method != "GET" { - http.Error(w, "Method not allowed", 405) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } http.ServeFile(w, r, "home.html") diff --git a/examples/command/main.go b/examples/command/main.go index 239c5c85..304f1a52 100644 --- a/examples/command/main.go +++ b/examples/command/main.go @@ -167,11 +167,11 @@ func serveWs(w http.ResponseWriter, r *http.Request) { func serveHome(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { - http.Error(w, "Not found", 404) + http.Error(w, "Not found", http.StatusNotFound) return } if r.Method != "GET" { - http.Error(w, "Method not allowed", 405) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } http.ServeFile(w, r, "home.html") diff --git a/examples/filewatch/main.go b/examples/filewatch/main.go index f5f9da5c..b834ed39 100644 --- a/examples/filewatch/main.go +++ b/examples/filewatch/main.go @@ -130,11 +130,11 @@ func serveWs(w http.ResponseWriter, r *http.Request) { func serveHome(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { - http.Error(w, "Not found", 404) + http.Error(w, "Not found", http.StatusNotFound) return } if r.Method != "GET" { - http.Error(w, "Method not allowed", 405) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } w.Header().Set("Content-Type", "text/html; charset=utf-8") diff --git a/server.go b/server.go index 12c8e1b8..db4bd7f7 100644 --- a/server.go +++ b/server.go @@ -283,7 +283,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade // of the same origin policy check is: // // if req.Header.Get("Origin") != "http://"+req.Host { -// http.Error(w, "Origin not allowed", 403) +// http.Error(w, "Origin not allowed", http.StatusForbidden) // return // } // From 416a1d5b7b06f6782c25147f9bcbe7b63c0134ca Mon Sep 17 00:00:00 2001 From: Carter Jones Date: Sun, 4 Mar 2018 11:58:07 -0800 Subject: [PATCH 42/52] add newline and remove extra space --- .gitignore | 2 +- README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index ac710204..cd3fcd1e 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,4 @@ _testmain.go *.exe .idea/ -*.iml \ No newline at end of file +*.iml diff --git a/README.md b/README.md index 33c3d2be..20e391f8 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn Write message using io.WriteCloserYesNo, see note 3 -Notes: +Notes: 1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html). 2. The application can get the type of a received data message by implementing From af46d4fe1f0b6f48165ba0a02a81dd79b4eac205 Mon Sep 17 00:00:00 2001 From: misu Date: Mon, 5 Mar 2018 17:57:15 +0900 Subject: [PATCH 43/52] mod: modification of review --- client.go | 28 ++---- client_server_test.go | 207 ++++++++++++++---------------------------- compression.go | 15 ++- compression_test.go | 3 - conn.go | 4 - conn_test.go | 2 - server.go | 27 ++---- 7 files changed, 93 insertions(+), 193 deletions(-) diff --git a/client.go b/client.go index e1b1e80b..6ad67007 100644 --- a/client.go +++ b/client.go @@ -73,9 +73,8 @@ type Dialer struct { // Subprotocols specifies the client's requested subprotocols. Subprotocols []string - // EnableCompression specifies if the client should attempt to negotiate - // per message compression (RFC 7692). Setting this value to true does not - // guarantee that compression will be supported. + // EnableCompression specify if the server should attempt to negotiate per + // message compression (RFC 7692). EnableCompression bool // Jar specifies the cookie jar. @@ -83,13 +82,10 @@ type Dialer struct { // in responses. Jar http.CookieJar - // CompressionLevel is passed to conn when the compression setting is true. - CompressionLevel int - - // EnableContextTakeover specifies specifies if the client should attempt to negotiate - // per message compression with context-takeover (RFC 7692). - // but window bits is allowed only 15, because go's flate library support 15 bits only. - EnableContextTakeover bool + // AllowClientContextTakeover specifies whether the server will negotiate client context + // takeover for per message compression. Context takeover improves compression at the + // the cost of using more memory. + AllowClientContextTakeover bool } var errMalformedURL = errors.New("malformed ws or wss URL") @@ -205,7 +201,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } switch { - case d.EnableCompression && d.EnableContextTakeover: + case d.EnableCompression && d.AllowClientContextTakeover: req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_max_window_bits=15; client_max_window_bits=15") case d.EnableCompression: req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") @@ -286,13 +282,6 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize) - if d.EnableCompression { - if !isValidCompressionLevel(d.CompressionLevel) { - return nil, nil, errors.New("websocket: invalid compression level") - } - conn.compressionLevel = d.CompressionLevel - } - if err := req.Write(netConn); err != nil { return nil, nil, err } @@ -331,10 +320,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re switch { case cmwb && smwb: - conn.contextTakeover = true - var wf contextTakeoverWriterFactory - wf.fw, _ = flate.NewWriter(&wf.tw, d.CompressionLevel) conn.newCompressionWriter = wf.newCompressionWriter var rf contextTakeoverReaderFactory diff --git a/client_server_test.go b/client_server_test.go index 42196321..88fcd3fb 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -40,9 +40,14 @@ var cstDialer = Dialer{ HandshakeTimeout: 30 * time.Second, } -type cstHandler struct{ *testing.T } +type cstHandlerConfig struct { + contextTakeover bool +} -type cstContextTakeoverHandler struct{ *testing.T } +type cstHandler struct { + *testing.T + cstHandlerConfig +} type cstServer struct { *httptest.Server @@ -55,25 +60,17 @@ const ( cstRequestURI = cstPath + "?" + cstRawQuery ) -func newServer(t *testing.T) *cstServer { - var s cstServer - s.Server = httptest.NewServer(cstHandler{t}) - s.Server.URL += cstRequestURI - s.URL = makeWsProto(s.Server.URL) - return &s -} - -func newContextTakeoverServer(t *testing.T) *cstServer { +func newServer(t *testing.T, c cstHandlerConfig) *cstServer { var s cstServer - s.Server = httptest.NewServer(cstContextTakeoverHandler{t}) + s.Server = httptest.NewServer(cstHandler{t, c}) s.Server.URL += cstRequestURI s.URL = makeWsProto(s.Server.URL) return &s } -func newTLSServer(t *testing.T) *cstServer { +func newTLSServer(t *testing.T, c cstHandlerConfig) *cstServer { var s cstServer - s.Server = httptest.NewTLSServer(cstHandler{t}) + s.Server = httptest.NewTLSServer(cstHandler{t, c}) s.Server.URL += cstRequestURI s.URL = makeWsProto(s.Server.URL) return &s @@ -96,59 +93,10 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, "bad protocol", http.StatusBadRequest) return } - ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}) - if err != nil { - t.Logf("Upgrade: %v", err) - return - } - defer ws.Close() - - if ws.Subprotocol() != "p1" { - t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol()) - ws.Close() - return - } - op, rd, err := ws.NextReader() - if err != nil { - t.Logf("NextReader: %v", err) - return - } - wr, err := ws.NextWriter(op) - if err != nil { - t.Logf("NextWriter: %v", err) - return - } - if _, err = io.Copy(wr, rd); err != nil { - t.Logf("NextWriter: %v", err) - return - } - if err := wr.Close(); err != nil { - t.Logf("Close: %v", err) - return - } -} - -func (t cstContextTakeoverHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != cstPath { - t.Logf("path=%v, want %v", r.URL.Path, cstPath) - http.Error(w, "bad path", 400) - return - } - if r.URL.RawQuery != cstRawQuery { - t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery) - http.Error(w, "bad path", 400) - return - } - subprotos := Subprotocols(r) - if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) { - t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols) - http.Error(w, "bad protocol", 400) - return + if t.contextTakeover { + cstUpgrader.AllowServerContextTakeover = true } - cu := cstUpgrader - cu.CompressionLevel = defaultCompressionLevel - cu.EnableContextTakeover = true - ws, err := cu.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}) + ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}) if err != nil { t.Logf("Upgrade: %v", err) return @@ -160,8 +108,6 @@ func (t cstContextTakeoverHandler) ServeHTTP(w http.ResponseWriter, r *http.Requ ws.Close() return } - - // first message op, rd, err := ws.NextReader() if err != nil { t.Logf("NextReader: %v", err) @@ -181,24 +127,26 @@ func (t cstContextTakeoverHandler) ServeHTTP(w http.ResponseWriter, r *http.Requ return } - // second message - op, rd, err = ws.NextReader() - if err != nil { - t.Logf("NextReader: %v", err) - return - } - wr, err = ws.NextWriter(op) - if err != nil { - t.Logf("NextWriter: %v", err) - return - } - if _, err = io.Copy(wr, rd); err != nil { - t.Logf("NextWriter: %v", err) - return - } - if err := wr.Close(); err != nil { - t.Logf("Close: %v", err) - return + // for multipleSendRecv when context takeover. + if t.contextTakeover { + op, rd, err := ws.NextReader() + if err != nil { + t.Logf("NextReader: %v", err) + return + } + wr, err := ws.NextWriter(op) + if err != nil { + t.Logf("NextWriter: %v", err) + return + } + if _, err = io.Copy(wr, rd); err != nil { + t.Logf("NextWriter: %v", err) + return + } + if err := wr.Close(); err != nil { + t.Logf("Close: %v", err) + return + } } } @@ -227,47 +175,29 @@ func sendRecv(t *testing.T, ws *Conn) { } func multipleSendRecv(t *testing.T, ws *Conn) { - message := "Hello World!" - if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil { - t.Fatalf("SetWriteDeadline: %v", err) - } - if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil { - t.Fatalf("WriteMessage: %v", err) - } - if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil { - t.Fatalf("SetReadDeadline: %v", err) - } - _, p, err := ws.ReadMessage() - if err != nil { - t.Fatalf("ReadMessage: %v", err) - } - if string(p) != message { - t.Fatalf("message=%s, want %s", p, message) - } - - nextMessage := "Can you read message?" - if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil { - t.Fatalf("SetWriteDeadline: %v", err) - } - if err := ws.WriteMessage(TextMessage, []byte(nextMessage)); err != nil { - t.Fatalf("_WriteMessage: %v", err) - } - if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil { - t.Fatalf("_SetReadDeadline: %v", err) - } - - _, p, err = ws.ReadMessage() - if err != nil { - t.Fatalf("_ReadMessage: %v", err) - } - if string(p) != nextMessage { - t.Fatalf("_message=%s, want %s", p, nextMessage) + for _, message := range []string{"Hello World", "Can you read message?"} { + if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("SetWriteDeadline: %v", err) + } + if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("WriteMessage: %v", err) + } + if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("SetReadDeadline: %v", err) + } + _, p, err := ws.ReadMessage() + if err != nil { + t.Fatalf("ReadMessage: %v", err) + } + if string(p) != message { + t.Fatalf("message=%s, want %s", p, message) + } } } func TestProxyDial(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() surl, _ := url.Parse(s.Server.URL) @@ -304,7 +234,7 @@ func TestProxyDial(t *testing.T) { } func TestProxyAuthorizationDial(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() surl, _ := url.Parse(s.Server.URL) @@ -344,7 +274,7 @@ func TestProxyAuthorizationDial(t *testing.T) { } func TestDial(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() ws, _, err := cstDialer.Dial(s.URL, nil) @@ -356,7 +286,7 @@ func TestDial(t *testing.T) { } func TestDialCookieJar(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() jar, _ := cookiejar.New(nil) @@ -404,7 +334,7 @@ func TestDialCookieJar(t *testing.T) { } func TestDialTLS(t *testing.T) { - s := newTLSServer(t) + s := newTLSServer(t, cstHandlerConfig{}) defer s.Close() certs := x509.NewCertPool() @@ -430,7 +360,7 @@ func TestDialTLS(t *testing.T) { func xTestDialTLSBadCert(t *testing.T) { // This test is deactivated because of noisy logging from the net/http package. - s := newTLSServer(t) + s := newTLSServer(t, cstHandlerConfig{}) defer s.Close() ws, _, err := cstDialer.Dial(s.URL, nil) @@ -441,7 +371,7 @@ func xTestDialTLSBadCert(t *testing.T) { } func TestDialTLSNoVerify(t *testing.T) { - s := newTLSServer(t) + s := newTLSServer(t, cstHandlerConfig{}) defer s.Close() d := cstDialer @@ -455,7 +385,7 @@ func TestDialTLSNoVerify(t *testing.T) { } func TestDialTimeout(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() d := cstDialer @@ -468,7 +398,7 @@ func TestDialTimeout(t *testing.T) { } func TestDialBadScheme(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() ws, _, err := cstDialer.Dial(s.Server.URL, nil) @@ -479,7 +409,7 @@ func TestDialBadScheme(t *testing.T) { } func TestDialBadOrigin(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}}) @@ -496,7 +426,7 @@ func TestDialBadOrigin(t *testing.T) { } func TestDialBadHeader(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() for _, k := range []string{"Upgrade", @@ -543,7 +473,7 @@ func TestBadMethod(t *testing.T) { } func TestHandshake(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}}) @@ -605,7 +535,7 @@ func TestRespOnBadHandshake(t *testing.T) { // TestHostHeader confirms that the host header provided in the call to Dial is // sent to the server. func TestHostHeader(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() specifiedHost := make(chan string, 1) @@ -632,7 +562,7 @@ func TestHostHeader(t *testing.T) { } func TestDialCompression(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() dialer := cstDialer @@ -646,13 +576,12 @@ func TestDialCompression(t *testing.T) { } func TestDialCompressionOfContextTakeover(t *testing.T) { - s := newContextTakeoverServer(t) + s := newServer(t, cstHandlerConfig{true}) defer s.Close() dialer := cstDialer dialer.EnableCompression = true - dialer.EnableContextTakeover = true - dialer.CompressionLevel = 2 + dialer.AllowClientContextTakeover = true ws, _, err := dialer.Dial(s.URL, nil) if err != nil { t.Fatalf("Dial: %v", err) @@ -663,7 +592,7 @@ func TestDialCompressionOfContextTakeover(t *testing.T) { } func TestSocksProxyDial(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() proxyListener, err := net.Listen("tcp", "127.0.0.1:0") diff --git a/compression.go b/compression.go index 56ee05a4..394f89cf 100644 --- a/compression.go +++ b/compression.go @@ -172,10 +172,17 @@ type ( } ) -func (f *contextTakeoverWriterFactory) newCompressionWriter(w io.WriteCloser, level int) io.WriteCloser { - f.tw.w = w - f.tw.n = 0 - return &flateTakeoverWriteWrapper{f} +func (wf *contextTakeoverWriterFactory) newCompressionWriter(w io.WriteCloser, level int) io.WriteCloser { + // Set writer on first write. + // In order to guarantee the consistency of compression with the client, + // do not reassign later. + if wf.fw == nil { + wf.fw, _ = flate.NewWriter(&wf.tw, level) + } + + wf.tw.w = w + wf.tw.n = 0 + return &flateTakeoverWriteWrapper{wf} } func (w *flateTakeoverWriteWrapper) Write(p []byte) (int, error) { diff --git a/compression_test.go b/compression_test.go index 0af6ed84..3aacf94e 100644 --- a/compression_test.go +++ b/compression_test.go @@ -2,7 +2,6 @@ package websocket import ( "bytes" - "compress/flate" "fmt" "io" "io/ioutil" @@ -71,9 +70,7 @@ func BenchmarkWriteWithCompressionOfContextTakeover(b *testing.B) { c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) messages := textMessages(100) c.enableWriteCompression = true - c.contextTakeover = true var f contextTakeoverWriterFactory - f.fw, _ = flate.NewWriter(&f.tw, 2) // level is specified in Dialer, Upgrader c.newCompressionWriter = f.newCompressionWriter b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/conn.go b/conn.go index c3f73fe4..37cb69ec 100644 --- a/conn.go +++ b/conn.go @@ -263,8 +263,6 @@ type Conn struct { readDecompress bool // whether last read frame had RSV1 set newDecompressionReader func(io.Reader) io.ReadCloser // arges may flateReadWrapper struct - - contextTakeover bool } func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { @@ -1145,8 +1143,6 @@ func (c *Conn) EnableWriteCompression(enable bool) { // binary messages. This function is a noop if compression was not negotiated // with the peer. See the compress/flate package for a description of // compression levels. -// If you use context-takeover, do not specify a compression level from this method. -// Please set it to Dialer or Upgrader in advance. func (c *Conn) SetCompressionLevel(level int) error { if !isValidCompressionLevel(level) { return errors.New("websocket: invalid compression level") diff --git a/conn_test.go b/conn_test.go index 7542a3ae..39ede0eb 100644 --- a/conn_test.go +++ b/conn_test.go @@ -109,14 +109,12 @@ func TestFraming(t *testing.T) { var wf contextTakeoverWriterFactory wf.fw, _ = flate.NewWriter(&wf.tw, defaultCompressionLevel) wc.newCompressionWriter = wf.newCompressionWriter - wc.contextTakeover = true var rf contextTakeoverReaderFactory fr := flate.NewReader(nil) rf.fr = fr rc.newDecompressionReader = rf.newDeCompressionReader - rc.contextTakeover = true case compressCondition.compress: wc.newCompressionWriter = compressNoContextTakeover rc.newDecompressionReader = decompressNoContextTakeover diff --git a/server.go b/server.go index db4bd7f7..fc55fc69 100644 --- a/server.go +++ b/server.go @@ -54,17 +54,13 @@ type Upgrader struct { CheckOrigin func(r *http.Request) bool // EnableCompression specify if the server should attempt to negotiate per - // message compression (RFC 7692). Setting this value to true does not - // guarantee that compression will be supported. + // message compression (RFC 7692). EnableCompression bool - // CompressionLevel is passed to conn when the compression setting is true. - CompressionLevel int - - // EnableContextTakeover specifies specifies if the client should attempt to negotiate - // per message compression with context-takeover (RFC 7692). - // but window bits is allowed only 15, because go's flate library support 15 bits only. - EnableContextTakeover bool + // AllowServerContextTakeover specifies whether the server will negotiate server context + // takeover for per message compression. Context takeover improves compression at the + // cost of using more memory. + AllowServerContextTakeover bool } func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { @@ -196,18 +192,9 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade c.subprotocol = subprotocol if compress { - if !isValidCompressionLevel(u.CompressionLevel) { - return nil, errors.New("websocket: invalid compression level") - } - - c.compressionLevel = u.CompressionLevel - switch { - case contextTakeover && u.EnableContextTakeover: - c.contextTakeover = contextTakeover - + case contextTakeover && u.AllowServerContextTakeover: var wf contextTakeoverWriterFactory - wf.fw, _ = flate.NewWriter(&wf.tw, u.CompressionLevel) c.newCompressionWriter = wf.newCompressionWriter var rf contextTakeoverReaderFactory @@ -231,7 +218,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade } if compress { switch { - case contextTakeover && u.EnableContextTakeover: + case contextTakeover && u.AllowServerContextTakeover: p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_max_window_bits=15; client_max_window_bits=15\r\n"...) default: p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) From 360a199e488c2f19d53995ee6fbb92c10ca79ee3 Mon Sep 17 00:00:00 2001 From: misu Date: Mon, 5 Mar 2018 18:37:27 +0900 Subject: [PATCH 44/52] mod: doc.go --- doc.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/doc.go b/doc.go index 4469dabd..4f7ada2a 100644 --- a/doc.go +++ b/doc.go @@ -180,12 +180,8 @@ // // var upgrader = websocket.Upgrader{ // EnableCompression: true, -// EnableContextTakeover: true, -// CompressionLevel: 2, // default 0 +// AllowServerContextTakeover: true, // } // -// -// Since compression level is passed to Conn, please do not set it later. -// // Use of compression is experimental and may result in decreased performance. package websocket From 658a2fb894f477e0c12f4b2e8be94a7c1111736b Mon Sep 17 00:00:00 2001 From: misu Date: Sun, 1 Apr 2018 15:12:58 +0900 Subject: [PATCH 45/52] temp: autobahn/server.go --- examples/autobahn/server.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/autobahn/server.go b/examples/autobahn/server.go index c2d6ee50..ec4cc7d4 100644 --- a/examples/autobahn/server.go +++ b/examples/autobahn/server.go @@ -14,13 +14,14 @@ import ( "time" "unicode/utf8" - "github.com/gorilla/websocket" + "github.com/smith-30/websocket" ) var upgrader = websocket.Upgrader{ - ReadBufferSize: 4096, - WriteBufferSize: 4096, - EnableCompression: true, + ReadBufferSize: 4096, + WriteBufferSize: 4096, + EnableCompression: true, + AllowServerContextTakeover: true, CheckOrigin: func(r *http.Request) bool { return true }, From 811803b3ed6846597f1651360f216953f3d93d9f Mon Sep 17 00:00:00 2001 From: misu Date: Fri, 6 Apr 2018 13:30:19 +0900 Subject: [PATCH 46/52] add: comment --- compression.go | 1 + 1 file changed, 1 insertion(+) diff --git a/compression.go b/compression.go index 394f89cf..12fadd05 100644 --- a/compression.go +++ b/compression.go @@ -209,6 +209,7 @@ func (w *flateTakeoverWriteWrapper) Close() error { return err2 } +// modules for compression context takeover type ( contextTakeoverReaderFactory struct { fr io.ReadCloser From 6caf089088f3310c23313f33874cda8bbed0eb9f Mon Sep 17 00:00:00 2001 From: misu Date: Fri, 6 Apr 2018 14:01:48 +0900 Subject: [PATCH 47/52] mod: examples/autobahn/server.go to pass build go 1.4, 1.5 --- examples/autobahn/server.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/autobahn/server.go b/examples/autobahn/server.go index ec4cc7d4..c2d6ee50 100644 --- a/examples/autobahn/server.go +++ b/examples/autobahn/server.go @@ -14,14 +14,13 @@ import ( "time" "unicode/utf8" - "github.com/smith-30/websocket" + "github.com/gorilla/websocket" ) var upgrader = websocket.Upgrader{ - ReadBufferSize: 4096, - WriteBufferSize: 4096, - EnableCompression: true, - AllowServerContextTakeover: true, + ReadBufferSize: 4096, + WriteBufferSize: 4096, + EnableCompression: true, CheckOrigin: func(r *http.Request) bool { return true }, From cd973fdcfcd8248af577b6e85f9f28b94304d21e Mon Sep 17 00:00:00 2001 From: misu Date: Fri, 20 Apr 2018 09:55:03 +0900 Subject: [PATCH 48/52] mod: pointed out part --- conn.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/conn.go b/conn.go index 37cb69ec..c5635be8 100644 --- a/conn.go +++ b/conn.go @@ -261,8 +261,8 @@ type Conn struct { readErrCount int messageReader *messageReader // the current low-level reader - readDecompress bool // whether last read frame had RSV1 set - newDecompressionReader func(io.Reader) io.ReadCloser // arges may flateReadWrapper struct + readDecompress bool // whether last read frame had RSV1 set + newDecompressionReader func(io.Reader) io.ReadCloser } func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { @@ -752,7 +752,7 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error { if err != nil { return err } - if _, err := w.Write(data); err != nil { + if _, err = w.Write(data); err != nil { return err } From b27407e5fd6314edca704797b4472815bb1f1abf Mon Sep 17 00:00:00 2001 From: misu Date: Fri, 20 Apr 2018 10:07:26 +0900 Subject: [PATCH 49/52] mod: no need to change line feed --- compression.go | 13 +------------ conn.go | 7 ------- 2 files changed, 1 insertion(+), 19 deletions(-) diff --git a/compression.go b/compression.go index 12fadd05..ecf1ab75 100644 --- a/compression.go +++ b/compression.go @@ -35,7 +35,7 @@ var ( func decompressNoContextTakeover(r io.Reader) io.ReadCloser { fr, _ := flateReaderPool.Get().(io.ReadCloser) fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) - return &flateReadWrapper{fr: fr} + return &flateReadWrapper{fr} } func isValidCompressionLevel(level int) bool { @@ -100,7 +100,6 @@ func (w *flateWriteWrapper) Write(p []byte) (int, error) { if w.fw == nil { return 0, errWriteClosed } - return w.fw.Write(p) } @@ -109,22 +108,17 @@ func (w *flateWriteWrapper) Close() error { return errWriteClosed } err1 := w.fw.Flush() - w.p.Put(w.fw) w.fw = nil - if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { return errors.New("websocket: internal error, unexpected bytes at end of flate stream") } - w.tw.p = [4]byte{} w.tw.n = 0 - err2 := w.tw.w.Close() if err1 != nil { return err1 } - return err2 } @@ -136,16 +130,13 @@ func (r *flateReadWrapper) Read(p []byte) (int, error) { if r.fr == nil { return 0, io.ErrClosedPipe } - n, err := r.fr.Read(p) - if err == io.EOF { // Preemptively place the reader back in the pool. This helps with // scenarios where the application does not call NextReader() soon after // this final read. r.Close() } - return n, err } @@ -154,9 +145,7 @@ func (r *flateReadWrapper) Close() error { return io.ErrClosedPipe } err := r.fr.Close() - flateReaderPool.Put(r.fr) - r.fr = nil return err } diff --git a/conn.go b/conn.go index c5635be8..1fe05133 100644 --- a/conn.go +++ b/conn.go @@ -333,7 +333,6 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in enableWriteCompression: true, compressionLevel: defaultCompressionLevel, } - c.SetCloseHandler(nil) c.SetPingHandler(nil) c.SetPongHandler(nil) @@ -755,7 +754,6 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error { if _, err = w.Write(data); err != nil { return err } - return w.Close() } @@ -952,7 +950,6 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { if c.readDecompress { c.reader = c.newDecompressionReader(c.reader) } - return frameType, c.reader, nil } } @@ -979,11 +976,9 @@ func (r *messageReader) Read(b []byte) (int, error) { for c.readErr == nil { if c.readRemaining > 0 { - // Determine the size of the data to be read. if int64(len(b)) > c.readRemaining { b = b[:c.readRemaining] } - n, err := c.br.Read(b) c.readErr = hideTempErr(err) if c.isServer { @@ -993,7 +988,6 @@ func (r *messageReader) Read(b []byte) (int, error) { if c.readRemaining > 0 && c.readErr == io.EOF { c.readErr = errUnexpectedEOF } - return n, c.readErr } @@ -1031,7 +1025,6 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { return messageType, nil, err } p, err = ioutil.ReadAll(r) - return messageType, p, err } From 89ee8d4670b071c5751e30d4db939acc2f713301 Mon Sep 17 00:00:00 2001 From: misu Date: Tue, 11 Sep 2018 18:21:37 +0900 Subject: [PATCH 50/52] mod: test --- client_server_test.go | 8 ++++---- compression_test.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/client_server_test.go b/client_server_test.go index 11ff7827..bacd0601 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -449,7 +449,7 @@ func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr() func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() } func TestHandshakeTimeout(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() d := cstDialer @@ -465,7 +465,7 @@ func TestHandshakeTimeout(t *testing.T) { } func TestHandshakeTimeoutInContext(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() d := cstDialerWithoutHandshakeTimeout @@ -782,7 +782,7 @@ func TestTracingDialWithContext(t *testing.T) { } ctx := httptrace.WithClientTrace(context.Background(), trace) - s := newTLSServer(t) + s := newTLSServer(t, cstHandlerConfig{}) defer s.Close() certs := x509.NewCertPool() @@ -832,7 +832,7 @@ func TestEmptyTracingDialWithContext(t *testing.T) { trace := &httptrace.ClientTrace{} ctx := httptrace.WithClientTrace(context.Background(), trace) - s := newTLSServer(t) + s := newTLSServer(t, cstHandlerConfig{}) defer s.Close() certs := x509.NewCertPool() diff --git a/compression_test.go b/compression_test.go index c70f490e..9a6f1343 100644 --- a/compression_test.go +++ b/compression_test.go @@ -67,7 +67,7 @@ func BenchmarkWriteWithCompression(b *testing.B) { func BenchmarkWriteWithCompressionOfContextTakeover(b *testing.B) { w := ioutil.Discard - c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + c := newTestConn(nil, w, false) messages := textMessages(100) c.enableWriteCompression = true var f contextTakeoverWriterFactory From 62bb07b760693a4aca5cb1d4940d5b66a347faf9 Mon Sep 17 00:00:00 2001 From: misu Date: Sun, 3 Feb 2019 11:30:02 +0900 Subject: [PATCH 51/52] add: comment --- compression.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/compression.go b/compression.go index ecf1ab75..882b46b1 100644 --- a/compression.go +++ b/compression.go @@ -201,7 +201,12 @@ func (w *flateTakeoverWriteWrapper) Close() error { // modules for compression context takeover type ( contextTakeoverReaderFactory struct { - fr io.ReadCloser + fr io.ReadCloser + + // this window is used in compress/flate.decompressor. + // since there is no interface for updating the dictionary in the structure, + // window is rewritten with this structure. + // although there is a Reset(), it becomes initialization of a dictionary. window []byte } From 985aed2c33cd562a032491fb0d474b3ddee0660d Mon Sep 17 00:00:00 2001 From: misu Date: Sun, 3 Feb 2019 14:49:04 +0900 Subject: [PATCH 52/52] mod: doc --- doc.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc.go b/doc.go index 4f7ada2a..cc9c05b5 100644 --- a/doc.go +++ b/doc.go @@ -171,7 +171,7 @@ // // conn.EnableWriteCompression(false) // -// Currently this package support compression with "context takeover". +// Currently this package supports compression with "context takeover". // This means that messages must be compressed and decompressed in isolation, // without retaining sliding window or dictionary state across messages. For // more details refer to RFC 7692.