diff --git a/modules/websocket/wsl_client.cpp b/modules/websocket/wsl_client.cpp index d695395635e84..b2d865a58f836 100644 --- a/modules/websocket/wsl_client.cpp +++ b/modules/websocket/wsl_client.cpp @@ -47,11 +47,16 @@ void WSLClient::_do_handshake() { _requested += sent; } else { - uint8_t byte = 0; int read = 0; - while (true) { - Error err = _connection->get_partial_data(&byte, 1, read); + if (_resp_pos >= WSL_MAX_HEADER_SIZE) { + // Header is too big + disconnect_from_host(); + _on_error(); + ERR_EXPLAIN("Response headers too big"); + ERR_FAIL(); + } + Error err = _connection->get_partial_data(&_resp_buf[_resp_pos], 1, read); if (err == ERR_FILE_EOF) { // We got a disconnect. disconnect_from_host(); @@ -66,16 +71,11 @@ void WSLClient::_do_handshake() { // Busy, wait next poll. break; } - // TODO lots of allocs. Use a buffer. - _response += byte; - if (_response.size() > WSL_MAX_HEADER_SIZE) { - // Header is too big - disconnect_from_host(); - _on_error(); - ERR_EXPLAIN("Response headers too big"); - ERR_FAIL(); - } - if (_response.ends_with("\r\n\r\n")) { + // Check "\r\n\r\n" header terminator + char *r = (char *)_resp_buf; + int l = _resp_pos; + if (l > 3 && r[l] == '\n' && r[l - 1] == '\r' && r[l - 2] == '\n' && r[l - 3] == '\r') { + r[l - 3] = '\0'; String protocol; // Response is over, verify headers and create peer. if (!_verify_headers(protocol)) { @@ -93,12 +93,14 @@ void WSLClient::_do_handshake() { _peer->make_context(data, _in_buf_size, _in_pkt_size, _out_buf_size, _out_pkt_size); _on_connect(protocol); } + _resp_pos += 1; } } } bool WSLClient::_verify_headers(String &r_protocol) { - Vector psa = _response.trim_suffix("\r\n\r\n").split("\r\n"); + String s = (char *)_resp_buf; + Vector psa = s.split("\r\n"); int len = psa.size(); if (len < 4) { ERR_EXPLAIN("Not enough response headers."); @@ -305,12 +307,17 @@ void WSLClient::disconnect_from_host(int p_code, String p_reason) { _peer->close(p_code, p_reason); _connection = Ref(NULL); _tcp = Ref(memnew(StreamPeerTCP)); - _request = ""; - _response = ""; + _key = ""; _host = ""; + _protocols.resize(0); _use_ssl = false; + + _request = ""; _requested = 0; + + memset(_resp_buf, 0, sizeof(_resp_buf)); + _resp_pos = 0; } IP_Address WSLClient::get_connected_host() const { @@ -325,7 +332,7 @@ uint16_t WSLClient::get_connected_port() const { Error WSLClient::set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets) { ERR_EXPLAIN("Buffers sizes can only be set before listening or connecting"); - ERR_FAIL_COND_V(_ctx != NULL, FAILED); + ERR_FAIL_COND_V(_connection.is_valid(), FAILED); _in_buf_size = nearest_shift(p_in_buffer - 1) + 10; _in_pkt_size = nearest_shift(p_in_packets - 1); @@ -340,10 +347,9 @@ WSLClient::WSLClient() { _out_buf_size = nearest_shift((int)GLOBAL_GET(WSC_OUT_BUF) - 1) + 10; _out_pkt_size = nearest_shift((int)GLOBAL_GET(WSC_OUT_PKT) - 1); - _ctx = NULL; _peer.instance(); _tcp.instance(); - _requested = 0; + disconnect_from_host(); } WSLClient::~WSLClient() { diff --git a/modules/websocket/wsl_client.h b/modules/websocket/wsl_client.h index 1ead88f60b1d0..57dfd635b7f95 100644 --- a/modules/websocket/wsl_client.h +++ b/modules/websocket/wsl_client.h @@ -49,17 +49,22 @@ class WSLClient : public WebSocketClient { int _in_pkt_size; int _out_buf_size; int _out_pkt_size; - wslay_event_context_ptr _ctx; + Ref _peer; - // XXX we could use HTTPClient with some hacking instead... Ref _tcp; + Ref _connection; + CharString _request; + int _requested; + + uint8_t _resp_buf[WSL_MAX_HEADER_SIZE]; + int _resp_pos; + String _response; + String _key; String _host; PoolVector _protocols; - Ref _connection; - int _requested; bool _use_ssl; void _do_handshake(); diff --git a/modules/websocket/wsl_server.cpp b/modules/websocket/wsl_server.cpp index 66f6a7e86ea4a..1e140a716f8e2 100644 --- a/modules/websocket/wsl_server.cpp +++ b/modules/websocket/wsl_server.cpp @@ -34,8 +34,16 @@ #include "core/os/os.h" #include "core/project_settings.h" +WSLServer::PendingPeer::PendingPeer() { + time = 0; + has_request = false; + response_sent = 0; + req_pos = 0; + memset(req_buf, 0, sizeof(req_buf)); +} + bool WSLServer::PendingPeer::_parse_request(String &r_key) { - Vector psa = request.trim_suffix("\r\n\r\n").split("\r\n"); + Vector psa = String((char *)req_buf).split("\r\n"); int len = psa.size(); if (len < 4) { ERR_EXPLAIN("Not enough response headers."); @@ -87,34 +95,35 @@ Error WSLServer::PendingPeer::do_handshake() { if (OS::get_singleton()->get_ticks_msec() - time > WSL_SERVER_TIMEOUT) return ERR_TIMEOUT; if (!has_request) { - uint8_t byte = 0; int read = 0; while (true) { - Error err = connection->get_partial_data(&byte, 1, read); + if (req_pos >= WSL_MAX_HEADER_SIZE) { + // Header is too big + ERR_EXPLAIN("Response headers too big"); + ERR_FAIL_V(ERR_OUT_OF_MEMORY); + } + Error err = connection->get_partial_data(&req_buf[req_pos], 1, read); if (err != OK) // Got an error return FAILED; else if (read != 1) // Busy, wait next poll return ERR_BUSY; - request += byte; - - if (request.size() > WSL_MAX_HEADER_SIZE) { - ERR_EXPLAIN("Response headers too big"); - ERR_FAIL_V(ERR_OUT_OF_MEMORY); - } - if (request.ends_with("\r\n\r\n")) { + char *r = (char *)req_buf; + int l = req_pos; + if (l > 3 && r[l] == '\n' && r[l - 1] == '\r' && r[l - 2] == '\n' && r[l - 3] == '\r') { + r[l - 3] = '\0'; if (!_parse_request(key)) { return FAILED; } - String r = "HTTP/1.1 101 Switching Protocols\r\n"; - r += "Upgrade: websocket\r\n"; - r += "Connection: Upgrade\r\n"; - r += "Sec-WebSocket-Accept: " + WSLPeer::compute_key_response(key) + "\r\n"; - r += "\r\n"; - response = r.utf8(); + String s = "HTTP/1.1 101 Switching Protocols\r\n"; + s += "Upgrade: websocket\r\n"; + s += "Connection: Upgrade\r\n"; + s += "Sec-WebSocket-Accept: " + WSLPeer::compute_key_response(key) + "\r\n"; + s += "\r\n"; + response = s.utf8(); has_request = true; - WARN_PRINTS("Parsed, " + key); break; } + req_pos += 1; } } if (has_request && response_sent < response.size() - 1) { diff --git a/modules/websocket/wsl_server.h b/modules/websocket/wsl_server.h index ca627c1fc3de2..b0520bd731cbd 100644 --- a/modules/websocket/wsl_server.h +++ b/modules/websocket/wsl_server.h @@ -55,17 +55,14 @@ class WSLServer : public WebSocketServer { Ref connection; int time; - String request; + uint8_t req_buf[WSL_MAX_HEADER_SIZE]; + int req_pos; String key; bool has_request; CharString response; int response_sent; - PendingPeer() { - time = 0; - has_request = false; - response_sent = 0; - } + PendingPeer(); Error do_handshake(); };