Permalink
Browse files

http: Allow selecting a WebSocket subprotocol.

  • Loading branch information...
unknownbrackets committed Apr 12, 2018
1 parent 8b094f8 commit 556a46f9d5250dc1d26a612f35d40db3bb0b7727
Showing with 37 additions and 5 deletions.
  1. +36 −4 ext/native/net/websocket_server.cpp
  2. +1 −1 ext/native/net/websocket_server.h
@@ -45,7 +45,32 @@ enum class Opcode {
CONTROL_MAX = 10,
};
WebSocketServer *WebSocketServer::CreateAsUpgrade(const http::Request &request) {
static inline std::string TrimString(const std::string &s) {
auto wsfront = std::find_if_not(s.begin(), s.end(), [](int c) {
// isspace() expects 0 - 255, so convert any sign-extended value.
return std::isspace(c & 0xFF);
});
auto wsback = std::find_if_not(s.rbegin(), s.rend(), [](int c){
return std::isspace(c & 0xFF);
}).base();
return wsback > wsfront ? std::string(wsfront, wsback) : std::string();
}
static bool ListContainsNoCase(const std::string &list, const std::string value) {
std::vector<std::string> split;
SplitString(list, ',', split);
for (auto item : split) {
std::transform(item.begin(), item.end(), item.begin(), tolower);
if (TrimString(item) == value) {
return true;
}
}
return false;
}
WebSocketServer *WebSocketServer::CreateAsUpgrade(const http::Request &request, const std::string &protocol) {
auto requireHeader = [&](const char *name, const char *expected) {
std::string val;
if (!request.GetHeader(name, &val)) {
@@ -58,8 +83,7 @@ WebSocketServer *WebSocketServer::CreateAsUpgrade(const http::Request &request)
if (!request.GetHeader(name, &val)) {
return false;
}
std::transform(val.begin(), val.end(), val.begin(), tolower);
return strstr(val.c_str(), expected) != 0;
return ListContainsNoCase(val, expected);
};
if (!requireHeader("upgrade", "websocket") || !requireHeaderContains("connection", "upgrade")) {
@@ -73,6 +97,14 @@ WebSocketServer *WebSocketServer::CreateAsUpgrade(const http::Request &request)
return nullptr;
}
std::string requestedProtocols;
std::string obtainedProtocolHeader;
if (!protocol.empty() && request.GetHeader("sec-websocket-protocol", &requestedProtocols)) {
if (ListContainsNoCase(requestedProtocols, protocol)) {
obtainedProtocolHeader = "Sec-WebSocket-Protocol: " + protocol + "\r\n";
}
}
std::string key;
if (!request.GetHeader("sec-websocket-key", &key)) {
request.WriteHttpResponseHeader(400, -1, "text/plain");
@@ -85,7 +117,7 @@ WebSocketServer *WebSocketServer::CreateAsUpgrade(const http::Request &request)
sha1((unsigned char *)key.c_str(), (int)key.size(), accept);
std::string acceptKey = Base64Encode(accept, 20);
std::string otherHeaders = StringFromFormat("Upgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: %s\r\n", acceptKey.c_str());
std::string otherHeaders = StringFromFormat("Upgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: %s\r\n%s", acceptKey.c_str(), obtainedProtocolHeader.c_str());
// Okay, we're good to go then.
request.WriteHttpResponseHeader(101, -1, "websocket", otherHeaders.c_str());
@@ -28,7 +28,7 @@ enum class WebSocketClose : uint16_t {
// RFC 6455
class WebSocketServer {
public:
static WebSocketServer *CreateAsUpgrade(const http::Request &request);
static WebSocketServer *CreateAsUpgrade(const http::Request &request, const std::string &protocol = "");
void Send(const std::string &str);
void Send(const std::vector<uint8_t> &payload);

0 comments on commit 556a46f

Please sign in to comment.