diff --git a/src/inspector_socket.cc b/src/inspector_socket.cc index 9e388f5a00ed55..641e83f008ac47 100644 --- a/src/inspector_socket.cc +++ b/src/inspector_socket.cc @@ -393,12 +393,7 @@ static int header_value_cb(http_parser* parser, const char* at, size_t length) { auto inspector = static_cast(parser->data); auto state = inspector->http_parsing_state; state->parsing_value = true; - if (state->current_header.size() == sizeof(SEC_WEBSOCKET_KEY_HEADER) - 1 && - node::StringEqualNoCaseN(state->current_header.data(), - SEC_WEBSOCKET_KEY_HEADER, - sizeof(SEC_WEBSOCKET_KEY_HEADER) - 1)) { - state->ws_key.append(at, length); - } + state->headers[state->current_header].append(at, length); return 0; } @@ -471,10 +466,59 @@ static void handshake_failed(InspectorSocket* inspector) { // init_handshake references message_complete_cb static void init_handshake(InspectorSocket* inspector); +static std::string TrimPort(const std::string& host) { + size_t last_colon_pos = host.rfind(":"); + if (last_colon_pos == std::string::npos) + return host; + size_t bracket = host.rfind("]"); + if (bracket == std::string::npos || last_colon_pos > bracket) + return host.substr(0, last_colon_pos); + return host; +} + +static bool IsIPAddress(const std::string& host) { + if (host.length() >= 4 && host[0] == '[' && host[host.size() - 1] == ']') + return true; + int quads = 0; + for (char c : host) { + if (c == '.') + quads++; + else if (!isdigit(c)) + return false; + } + return quads == 3; +} + +static std::string HeaderValue(const struct http_parsing_state_s* state, + const std::string& header) { + bool header_found = false; + std::string value; + for (const auto& header_value : state->headers) { + if (node::StringEqualNoCaseN(header_value.first.data(), header.data(), + header.length())) { + if (header_found) + return ""; + value = header_value.second; + header_found = true; + } + } + return value; +} + +static bool IsAllowedHost(const std::string& host_with_port) { + std::string host = TrimPort(host_with_port); + return host.empty() || IsIPAddress(host) + || node::StringEqualNoCase(host.data(), "localhost") + || node::StringEqualNoCase(host.data(), "localhost6"); +} + static int message_complete_cb(http_parser* parser) { InspectorSocket* inspector = static_cast(parser->data); struct http_parsing_state_s* state = inspector->http_parsing_state; - if (parser->method != HTTP_GET) { + state->ws_key = HeaderValue(state, "Sec-WebSocket-Key"); + + if (!IsAllowedHost(HeaderValue(state, "Host")) || + parser->method != HTTP_GET) { handshake_failed(inspector); } else if (!parser->upgrade) { if (state->callback(inspector, kInspectorHandshakeHttpGet, state->path)) { diff --git a/src/inspector_socket.h b/src/inspector_socket.h index 46c739c4def207..6a80fdf3c73a06 100644 --- a/src/inspector_socket.h +++ b/src/inspector_socket.h @@ -5,6 +5,7 @@ #include "util-inl.h" #include "uv.h" +#include #include #include @@ -37,6 +38,7 @@ struct http_parsing_state_s { std::string ws_key; std::string path; std::string current_header; + std::map headers; }; struct ws_state_s { diff --git a/test/cctest/test_inspector_socket.cc b/test/cctest/test_inspector_socket.cc index ada3df3d438ce8..48f1fb1be3b23b 100644 --- a/test/cctest/test_inspector_socket.cc +++ b/test/cctest/test_inspector_socket.cc @@ -906,4 +906,35 @@ TEST_F(InspectorSocketTest, ErrorCleansUpTheSocket) { EXPECT_EQ(UV_EPROTO, err); } +static void HostCheckedForGet_handshake(enum inspector_handshake_event state, + const std::string& path, bool* cont) { + EXPECT_EQ(kInspectorHandshakeFailed, state); + EXPECT_TRUE(path.empty()); + *cont = false; +} + +TEST_F(InspectorSocketTest, HostCheckedForGet) { + handshake_delegate = HostCheckedForGet_handshake; + const char WRITE_REQUEST[] = "GET /respond/withtext HTTP/1.1\r\n" + "Host: notlocalhost:9222\r\n\r\n"; + send_in_chunks(WRITE_REQUEST, sizeof(WRITE_REQUEST) - 1); + + expect_handshake_failure(); + assert_both_sockets_closed(); +} + +TEST_F(InspectorSocketTest, HostCheckedForUpgrade) { + handshake_delegate = HostCheckedForGet_handshake; + const char UPGRADE_REQUEST[] = "GET /ws/path HTTP/1.1\r\n" + "Host: nonlocalhost:9229\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: aaa==\r\n" + "Sec-WebSocket-Version: 13\r\n\r\n"; + send_in_chunks(UPGRADE_REQUEST, sizeof(UPGRADE_REQUEST) - 1); + + expect_handshake_failure(); + assert_both_sockets_closed(); +} + } // anonymous namespace