Permalink
Browse files

websocket, feat: support context takeover.

  • Loading branch information...
xicilion committed Oct 26, 2017
1 parent 8b31858 commit e96f23e9fa1ddf4f5ae4e9dd1201bbda2980b197
View
@@ -8,6 +8,7 @@
#include "ifs/WebSocket.h"
#include "ifs/ws.h"
#include "ifs/Stream.h"
#include "zlib.h"
#ifndef WEBSOCKET_H_
#define WEBSOCKET_H_
@@ -28,12 +29,12 @@ class WebSocket : public WebSocket_base {
{
}
WebSocket(Stream_base* stream, bool compress, exlib::string protocol, AsyncEvent* ac)
WebSocket(Stream_base* stream, exlib::string protocol, AsyncEvent* ac)
: m_stream(stream)
, m_ac(ac)
, m_protocol(protocol)
, m_masked(false)
, m_compress(compress)
, m_compress(false)
, m_maxSize(67108864)
, m_readyState(ws_base::_OPEN)
{
@@ -69,11 +70,16 @@ class WebSocket : public WebSocket_base {
void startRecv();
void endConnect(int32_t code, exlib::string reason);
void endConnect(SeekableStream_base* body);
void enableCompress();
public:
obj_ptr<Stream_base> m_stream;
AsyncEvent* m_ac;
obj_ptr<ZlibStream> m_deflate;
obj_ptr<ZlibStream> m_inflate;
obj_ptr<Buffer_base> m_flushTail;
obj_ptr<SeekableStream_base> m_buffer;
exlib::Locker m_lockEncode;
exlib::Locker m_lockBuffer;
@@ -7,6 +7,7 @@
#include "Message.h"
#include "ifs/WebSocketMessage.h"
#include "WebSocket.h"
#ifndef WEBSOCKETMESSAGE_H_
#define WEBSOCKETMESSAGE_H_
@@ -64,6 +65,8 @@ class WebSocketMessage : public WebSocketMessage_base {
public:
static result_t copy(Stream_base* from, Stream_base* to, int64_t bytes, uint32_t mask, AsyncEvent* ac);
result_t sendTo(Stream_base* stm, WebSocket* wss, AsyncEvent* ac);
result_t readFrom(Stream_base* stm, WebSocket* wss, AsyncEvent* ac);
public:
obj_ptr<Stream_base> m_stm;
View
@@ -93,6 +93,11 @@ class ZlibStream : public Stream_base {
strm.next_out = m_outBuffer;
}
void attach(Stream_base* stm)
{
m_stm = stm;
}
public:
// Stream_base
result_t read(int32_t bytes, obj_ptr<Buffer_base>& retVal, AsyncEvent* ac)
@@ -11,6 +11,7 @@
#include "EventInfo.h"
#include "ifs/io.h"
#include "ifs/http.h"
#include "ifs/zlib.h"
#include "Map.h"
#include <mbedtls/mbedtls/sha1.h>
#include "encoding.h"
@@ -101,7 +102,7 @@ class asyncSend : public AsyncState {
pThis->m_this->m_buffer = new MemoryStream();
pThis->set(encode_ok);
return pThis->m_msg->sendTo(pThis->m_this->m_buffer, pThis);
return pThis->m_msg->sendTo(pThis->m_this->m_buffer, pThis->m_this, pThis);
}
static int32_t encode_ok(AsyncState* pState, int32_t n)
@@ -290,7 +291,7 @@ result_t WebSocket_base::_new(exlib::string url, exlib::string protocol, exlib::
return hr;
if (hr != CALL_RETURN_NULL && !qstricmp(v.string().c_str(), "permessage-deflate", 18))
pThis->m_this->m_compress = true;
pThis->m_this->enableCompress();
pThis->m_httprep->get_stream(pThis->m_this->m_stream);
@@ -344,7 +345,7 @@ void WebSocket::startRecv()
pThis->m_msg = new WebSocketMessage(ws_base::_TEXT, false, false, pThis->m_this->m_maxSize);
pThis->set(event);
return pThis->m_msg->readFrom(pThis->m_this->m_stream, pThis);
return pThis->m_msg->readFrom(pThis->m_this->m_stream, pThis->m_this, pThis);
}
static int32_t event(AsyncState* pState, int32_t n)
@@ -464,6 +465,14 @@ void WebSocket::endConnect(SeekableStream_base* body)
endConnect(code, reason);
}
void WebSocket::enableCompress()
{
m_compress = true;
m_deflate = new defraw(NULL);
m_inflate = new infraw(NULL);
m_flushTail = new Buffer("\x0\x0\xff\xff", 4);
}
result_t WebSocket::get_url(exlib::string& retVal)
{
retVal = m_url;
@@ -129,7 +129,9 @@ result_t WebSocketHandler::invoke(object_base* v, obj_ptr<Handler_base>& retVal,
pThis->done(CALL_RETURN_NULL);
obj_ptr<WebSocketHandler> pHandler = pThis->m_pThis;
obj_ptr<WebSocket> sock = new WebSocket(pThis->m_stm, pThis->m_compress, "", pThis);
obj_ptr<WebSocket> sock = new WebSocket(pThis->m_stm, "", pThis);
if (pThis->m_compress)
sock->enableCompress();
Variant v = sock;
pHandler->_emit("accept", &v, 1);
@@ -191,24 +191,33 @@ result_t WebSocketMessage::copy(Stream_base* from, Stream_base* to, int64_t byte
return (new asyncCopy(from, to, bytes, mask, ac))->post(0);
}
result_t WebSocketMessage::sendTo(Stream_base* stm, AsyncEvent* ac)
result_t WebSocketMessage::sendTo(Stream_base* stm, WebSocket* wss, AsyncEvent* ac)
{
class asyncSendTo : public AsyncState {
public:
asyncSendTo(WebSocketMessage* pThis, Stream_base* stm,
AsyncEvent* ac)
asyncSendTo(WebSocketMessage* pThis, Stream_base* stm, WebSocket* wss, AsyncEvent* ac)
: AsyncState(ac)
, m_pThis(pThis)
, m_stm(stm)
, m_wss(wss)
, m_mask(0)
, m_take_over(false)
{
m_pThis->get_body(m_body);
if (m_pThis->m_compress) {
m_zip = new MemoryStream();
m_data = new MemoryStream();
if (m_wss && m_wss->m_compress) {
m_take_over = true;
m_wss->m_deflate->attach(m_data);
m_zip = m_wss->m_deflate;
} else
zlib_base::createDeflateRaw(m_data, m_zip);
set(deflate);
} else {
m_zip = m_body;
m_data = m_body;
set(head);
}
}
@@ -217,15 +226,32 @@ result_t WebSocketMessage::sendTo(Stream_base* stm, AsyncEvent* ac)
{
asyncSendTo* pThis = (asyncSendTo*)pState;
pThis->set(head);
pThis->m_body->rewind();
return zlib_base::deflateRawTo(pThis->m_body, pThis->m_zip, zlib_base::_DEFAULT_COMPRESSION, pThis);
pThis->set(flush);
return pThis->m_body->copyTo(pThis->m_zip, -1, pThis->m_size, pThis);
}
static int32_t flush(AsyncState* pState, int32_t n)
{
asyncSendTo* pThis = (asyncSendTo*)pState;
pThis->set(head);
return pThis->m_zip->flush(pThis);
}
static int32_t head(AsyncState* pState, int32_t n)
{
asyncSendTo* pThis = (asyncSendTo*)pState;
if (pThis->m_take_over)
pThis->m_wss->m_deflate->attach(NULL);
int64_t size;
pThis->m_data->size(size);
if (pThis->m_pThis->m_compress)
size -= 4;
pThis->m_size = size;
uint8_t buf[16];
int32_t pos = 0;
@@ -236,10 +262,6 @@ result_t WebSocketMessage::sendTo(Stream_base* stm, AsyncEvent* ac)
else
buf[0] = 0x80 | (type & 0x0f);
int64_t size;
pThis->m_zip->size(size);
pThis->m_size = size;
if (size < 126) {
buf[1] = (uint8_t)size;
pos = 2;
@@ -287,41 +309,57 @@ result_t WebSocketMessage::sendTo(Stream_base* stm, AsyncEvent* ac)
asyncSendTo* pThis = (asyncSendTo*)pState;
pThis->done();
pThis->m_zip->rewind();
return copy(pThis->m_zip, pThis->m_stm, pThis->m_size, pThis->m_mask, pThis);
pThis->m_data->rewind();
return copy(pThis->m_data, pThis->m_stm, pThis->m_size, pThis->m_mask, pThis);
}
virtual int32_t error(int32_t v)
{
if (m_take_over)
m_wss->m_deflate->attach(NULL);
return v;
}
public:
obj_ptr<SeekableStream_base> m_zip;
WebSocketMessage* m_pThis;
obj_ptr<Stream_base> m_zip;
obj_ptr<SeekableStream_base> m_data;
obj_ptr<WebSocketMessage> m_pThis;
obj_ptr<Stream_base> m_stm;
obj_ptr<WebSocket> m_wss;
obj_ptr<SeekableStream_base> m_body;
int64_t m_size;
uint32_t m_mask;
obj_ptr<Buffer_base> m_buffer;
bool m_take_over;
};
if (ac->isSync())
return CHECK_ERROR(CALL_E_NOSYNC);
return (new asyncSendTo(this, stm, ac))->post(0);
return (new asyncSendTo(this, stm, wss, ac))->post(0);
}
result_t WebSocketMessage::readFrom(Stream_base* stm, AsyncEvent* ac)
result_t WebSocketMessage::sendTo(Stream_base* stm, AsyncEvent* ac)
{
return sendTo(stm, NULL, ac);
}
result_t WebSocketMessage::readFrom(Stream_base* stm, WebSocket* wss, AsyncEvent* ac)
{
class asyncReadFrom : public AsyncState {
public:
asyncReadFrom(WebSocketMessage* pThis, Stream_base* stm,
AsyncEvent* ac)
asyncReadFrom(WebSocketMessage* pThis, Stream_base* stm, WebSocket* wss, AsyncEvent* ac)
: AsyncState(ac)
, m_pThis(pThis)
, m_stm(stm)
, m_wss(wss)
, m_fin(false)
, m_masked(false)
, m_fragmented(false)
, m_size(0)
, m_fullsize(0)
, m_mask(0)
, m_take_over(false)
{
m_pThis->get_body(m_body);
m_zip = m_body;
@@ -354,7 +392,13 @@ result_t WebSocketMessage::readFrom(Stream_base* stm, AsyncEvent* ac)
ch = strBuffer[0];
if (ch & 0x40) {
pThis->m_zip = new MemoryStream();
if (pThis->m_wss && pThis->m_wss->m_compress) {
pThis->m_take_over = true;
pThis->m_wss->m_inflate->attach(pThis->m_body);
pThis->m_zip = pThis->m_wss->m_inflate;
} else
zlib_base::createInflateRaw(pThis->m_body, pThis->m_zip);
pThis->m_pThis->m_compress = true;
} else if (ch & 0x70) {
pThis->m_pThis->m_error = 1007;
@@ -454,34 +498,46 @@ result_t WebSocketMessage::readFrom(Stream_base* stm, AsyncEvent* ac)
return 0;
}
if (pThis->m_pThis->m_compress)
pThis->set(inflate);
else
pThis->set(body_end);
return 0;
if (pThis->m_take_over) {
pThis->set(tail_end);
return pThis->m_zip->write(pThis->m_wss->m_flushTail, pThis);
}
pThis->set(body_end);
return pThis->m_zip->flush(pThis);
}
static int32_t inflate(AsyncState* pState, int32_t n)
static int32_t tail_end(AsyncState* pState, int32_t n)
{
asyncReadFrom* pThis = (asyncReadFrom*)pState;
pThis->set(body_end);
pThis->m_zip->rewind();
return zlib_base::inflateRawTo(pThis->m_zip, pThis->m_body, pThis);
return pThis->m_zip->flush(pThis);
}
static int32_t body_end(AsyncState* pState, int32_t n)
{
asyncReadFrom* pThis = (asyncReadFrom*)pState;
if (pThis->m_take_over)
pThis->m_wss->m_inflate->attach(NULL);
pThis->m_body->rewind();
return pThis->done();
}
virtual int32_t error(int32_t v)
{
if (m_take_over)
m_wss->m_inflate->attach(NULL);
return v;
}
public:
WebSocketMessage* m_pThis;
obj_ptr<WebSocketMessage> m_pThis;
obj_ptr<Stream_base> m_stm;
obj_ptr<SeekableStream_base> m_zip;
obj_ptr<WebSocket> m_wss;
obj_ptr<Stream_base> m_zip;
obj_ptr<SeekableStream_base> m_body;
obj_ptr<Buffer_base> m_buffer;
bool m_fin;
@@ -490,14 +546,20 @@ result_t WebSocketMessage::readFrom(Stream_base* stm, AsyncEvent* ac)
int64_t m_size;
int64_t m_fullsize;
uint32_t m_mask;
bool m_take_over;
};
if (ac->isSync())
return CHECK_ERROR(CALL_E_NOSYNC);
m_stm = stm;
return (new asyncReadFrom(this, stm, ac))->post(0);
return (new asyncReadFrom(this, stm, wss, ac))->post(0);
}
result_t WebSocketMessage::readFrom(Stream_base* stm, AsyncEvent* ac)
{
return readFrom(stm, NULL, ac);
}
result_t WebSocketMessage::get_stream(obj_ptr<Stream_base>& retVal)
View
@@ -366,7 +366,7 @@ describe('ws', () => {
s.stream.close();
});
it("echo compress", () => {
xit("echo compress", () => {
var s = connect(true);
test_msg(s, 10, true);

0 comments on commit e96f23e

Please sign in to comment.