From 9815025341d29f8ee5e83e3bf9d872d7f2317058 Mon Sep 17 00:00:00 2001 From: Anton Popovichenko Date: Sat, 4 May 2024 18:38:32 +0200 Subject: [PATCH] feat(Core/Network): Add Proxy Protocol v2 support. (#18839) * feat(Core/Network): Add Proxy Protocol v2 support. * Fix codestyle and build. * Another codestyle fix. * One more missing include. --- .../apps/authserver/Server/AuthSocketMgr.h | 9 +- .../apps/authserver/authserver.conf.dist | 10 ++ .../apps/worldserver/worldserver.conf.dist | 10 ++ src/server/game/Server/WorldSocketMgr.cpp | 10 +- src/server/shared/Network/NetworkThread.h | 70 ++++++++- src/server/shared/Network/Socket.h | 148 +++++++++++++++++- src/server/shared/Network/SocketMgr.h | 2 - 7 files changed, 247 insertions(+), 12 deletions(-) diff --git a/src/server/apps/authserver/Server/AuthSocketMgr.h b/src/server/apps/authserver/Server/AuthSocketMgr.h index 53f46ff92aee1..c6cb27f4f4748 100644 --- a/src/server/apps/authserver/Server/AuthSocketMgr.h +++ b/src/server/apps/authserver/Server/AuthSocketMgr.h @@ -20,6 +20,7 @@ #include "AuthSession.h" #include "SocketMgr.h" +#include "Config.h" class AuthSocketMgr : public SocketMgr { @@ -44,7 +45,13 @@ class AuthSocketMgr : public SocketMgr protected: NetworkThread* CreateThreads() const override { - return new NetworkThread[1]; + NetworkThread* threads = new NetworkThread[1]; + + bool proxyProtocolEnabled = sConfigMgr->GetOption("EnableProxyProtocol", false, true); + if (proxyProtocolEnabled) + threads[0].EnableProxyProtocol(); + + return threads; } static void OnSocketAccept(tcp::socket&& sock, uint32 threadIndex) diff --git a/src/server/apps/authserver/authserver.conf.dist b/src/server/apps/authserver/authserver.conf.dist index d11634cf76ebc..6cd737219189a 100644 --- a/src/server/apps/authserver/authserver.conf.dist +++ b/src/server/apps/authserver/authserver.conf.dist @@ -64,6 +64,16 @@ RealmServerPort = 3724 BindIP = "0.0.0.0" +# +# EnableProxyProtocol +# Description: Enables Proxy Protocol v2. When your server is behind a proxy, +# load balancer, or similar component, you need to enable Proxy Protocol v2 on both +# this server and the proxy/load balancer to track the real IP address of players. +# Example: 1 - (Enabled) +# Default: 0 - (Disabled) + +EnableProxyProtocol = 0 + # # PidFile # Description: Auth server PID file. diff --git a/src/server/apps/worldserver/worldserver.conf.dist b/src/server/apps/worldserver/worldserver.conf.dist index 59d4708a05571..25e16abcac2e4 100644 --- a/src/server/apps/worldserver/worldserver.conf.dist +++ b/src/server/apps/worldserver/worldserver.conf.dist @@ -378,6 +378,16 @@ Network.OutUBuff = 4096 Network.TcpNodelay = 1 +# +# Network.EnableProxyProtocol +# Description: Enables Proxy Protocol v2. When your server is behind a proxy, +# load balancer, or similar component, you need to enable Proxy Protocol v2 on both +# this server and the proxy/load balancer to track the real IP address of players. +# Example: 1 - (Enabled) +# Default: 0 - (Disabled) + +Network.EnableProxyProtocol = 0 + # ################################################################################################### diff --git a/src/server/game/Server/WorldSocketMgr.cpp b/src/server/game/Server/WorldSocketMgr.cpp index 5c161078ee273..f4afe7ceea7cc 100644 --- a/src/server/game/Server/WorldSocketMgr.cpp +++ b/src/server/game/Server/WorldSocketMgr.cpp @@ -114,5 +114,13 @@ void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock, uint32 threadIndex) NetworkThread* WorldSocketMgr::CreateThreads() const { - return new WorldSocketThread[GetNetworkThreadCount()]; + + NetworkThread* threads = new WorldSocketThread[GetNetworkThreadCount()]; + + bool proxyProtocolEnabled = sConfigMgr->GetOption("Network.EnableProxyProtocol", false, true); + if (proxyProtocolEnabled) + for (int i = 0; i < GetNetworkThreadCount(); i++) + threads[i].EnableProxyProtocol(); + + return threads; } diff --git a/src/server/shared/Network/NetworkThread.h b/src/server/shared/Network/NetworkThread.h index a6d4ff3528e68..b1f18a3aedc2f 100644 --- a/src/server/shared/Network/NetworkThread.h +++ b/src/server/shared/Network/NetworkThread.h @@ -20,6 +20,7 @@ #include "DeadlineTimer.h" #include "Define.h" +#include "Socket.h" #include "Errors.h" #include "IoContext.h" #include "Log.h" @@ -39,7 +40,7 @@ class NetworkThread { public: NetworkThread() : - _ioContext(1), _acceptSocket(_ioContext), _updateTimer(_ioContext) { } + _ioContext(1), _acceptSocket(_ioContext), _updateTimer(_ioContext), _proxyHeaderReadingEnabled(false) { } virtual ~NetworkThread() { @@ -94,6 +95,8 @@ class NetworkThread tcp::socket* GetSocketForAccept() { return &_acceptSocket; } + void EnableProxyProtocol() { _proxyHeaderReadingEnabled = true; } + protected: virtual void SocketAdded(std::shared_ptr /*sock*/) { } virtual void SocketRemoved(std::shared_ptr /*sock*/) { } @@ -105,20 +108,73 @@ class NetworkThread if (_newSockets.empty()) return; - for (std::shared_ptr sock : _newSockets) + if (!_proxyHeaderReadingEnabled) { + for (std::shared_ptr sock : _newSockets) + { + if (!sock->IsOpen()) + { + SocketRemoved(sock); + --_connections; + continue; + } + + _sockets.emplace_back(sock); + + sock->Start(); + } + + _newSockets.clear(); + } + else + { + HandleNewSocketsProxyReadingOnConnect(); + } + } + + void HandleNewSocketsProxyReadingOnConnect() + { + size_t index = 0; + std::vector newSocketsToRemoveIndexes; + for (auto sock_iter = _newSockets.begin(); sock_iter != _newSockets.end(); ++sock_iter, ++index) + { + std::shared_ptr sock = *sock_iter; + if (!sock->IsOpen()) { + newSocketsToRemoveIndexes.emplace_back(index); SocketRemoved(sock); --_connections; + continue; } - else - { - _sockets.emplace_back(sock); + + const auto proxyHeaderReadingState = sock->GetProxyHeaderReadingState(); + if (proxyHeaderReadingState == PROXY_HEADER_READING_STATE_STARTED) + continue; + + switch (proxyHeaderReadingState) { + case PROXY_HEADER_READING_STATE_NOT_STARTED: + sock->AsyncReadProxyHeader(); + break; + + case PROXY_HEADER_READING_STATE_FINISHED: + newSocketsToRemoveIndexes.emplace_back(index); + _sockets.emplace_back(sock); + + sock->Start(); + + break; + + default: + newSocketsToRemoveIndexes.emplace_back(index); + SocketRemoved(sock); + --_connections; + break; } } - _newSockets.clear(); + for (int removeIndex : newSocketsToRemoveIndexes) + _newSockets.erase(_newSockets.begin() + removeIndex); } void Run() @@ -177,6 +233,8 @@ class NetworkThread Acore::Asio::IoContext _ioContext; tcp::socket _acceptSocket; Acore::Asio::DeadlineTimer _updateTimer; + + bool _proxyHeaderReadingEnabled; }; #endif // NetworkThread_h__ diff --git a/src/server/shared/Network/Socket.h b/src/server/shared/Network/Socket.h index 0bf100b2a3501..af948618e408b 100644 --- a/src/server/shared/Network/Socket.h +++ b/src/server/shared/Network/Socket.h @@ -22,6 +22,7 @@ #include "MessageBuffer.h" #include #include +#include #include #include #include @@ -34,12 +35,25 @@ using boost::asio::ip::tcp; #define AC_SOCKET_USE_IOCP #endif +enum ProxyHeaderReadingState { + PROXY_HEADER_READING_STATE_NOT_STARTED, + PROXY_HEADER_READING_STATE_STARTED, + PROXY_HEADER_READING_STATE_FINISHED, + PROXY_HEADER_READING_STATE_FAILED, +}; + +enum ProxyHeaderAddressFamilyAndProtocol { + PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V4 = 0x11, + PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V6 = 0x21, +}; + template class Socket : public std::enable_shared_from_this { public: explicit Socket(tcp::socket&& socket) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()), - _remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false) + _remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false), + _proxyHeaderReadingState(PROXY_HEADER_READING_STATE_NOT_STARTED) { _readBuffer.Resize(READ_BLOCK_SIZE); } @@ -92,11 +106,25 @@ class Socket : public std::enable_shared_from_this _readBuffer.Normalize(); _readBuffer.EnsureFreeSpace(); - _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()), std::bind(&Socket::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); } + void AsyncReadProxyHeader() + { + if (!IsOpen()) + { + return; + } + + _proxyHeaderReadingState = PROXY_HEADER_READING_STATE_STARTED; + + _readBuffer.Normalize(); + _readBuffer.EnsureFreeSpace(); + _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()), + std::bind(&Socket::ProxyReadHeaderHandler, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); + } + void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code, std::size_t)) { if (!IsOpen()) @@ -120,6 +148,8 @@ class Socket : public std::enable_shared_from_this #endif } + [[nodiscard]] ProxyHeaderReadingState GetProxyHeaderReadingState() const { return _proxyHeaderReadingState; } + [[nodiscard]] bool IsOpen() const { return !_closed && !_closing; } void CloseSocket() @@ -187,6 +217,118 @@ class Socket : public std::enable_shared_from_this ReadHandler(); } + // ProxyReadHeaderHandler reads Proxy Protocol v2 header (v1 is not supported). + // See https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt (2.2. Binary header format (version 2)) for more details. + void ProxyReadHeaderHandler(boost::system::error_code error, size_t transferredBytes) + { + if (error) + { + CloseSocket(); + return; + } + + _readBuffer.WriteCompleted(transferredBytes); + + MessageBuffer& packet = GetReadBuffer(); + + const int minimumProxyProtocolV2Size = 28; + if (packet.GetActiveSize() < minimumProxyProtocolV2Size) + { + AsyncReadProxyHeader(); + return; + } + + uint8* readPointer = packet.GetReadPointer(); + + const uint8 signatureSize = 12; + const uint8 expectedSignature[signatureSize] = {0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}; + if (memcmp(packet.GetReadPointer(), expectedSignature, signatureSize) != 0) + { + _proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FAILED; + LOG_ERROR("network", "Socket::ProxyReadHeaderHandler: received bad PROXY Protocol v2 signature for {}", GetRemoteIpAddress().to_string()); + return; + } + + const uint8 version = (readPointer[signatureSize] & 0xF0) >> 4; + const uint8 command = (readPointer[signatureSize] & 0xF); + + if (version != 2) + { + _proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FAILED; + LOG_ERROR("network", "Socket::ProxyReadHeaderHandler: received bad PROXY Protocol v2 signature for {}", GetRemoteIpAddress().to_string()); + return; + } + + const uint8 addressFamily = readPointer[13]; + const uint16 len = (readPointer[14] << 8) | readPointer[15]; + if (len+16 > packet.GetActiveSize()) + { + AsyncReadProxyHeader(); + return; + } + + // Connection created by a proxy itself (health checks?), ignore and do nothing. + if (command == 0) + { + packet.ReadCompleted(len+16); + _proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FINISHED; + return; + } + + auto remainingLen = packet.GetActiveSize() - 16; + readPointer += 16; // Skip strait to address. + + switch (addressFamily) { + case PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V4: + { + if (remainingLen < 12) + { + AsyncReadProxyHeader(); + return; + } + + boost::asio::ip::address_v4::bytes_type b; + auto addressSize = sizeof(b); + + std::copy(readPointer, readPointer+addressSize, b.begin()); + _remoteAddress = boost::asio::ip::address_v4(b); + + readPointer += 2 * addressSize; // Skip server address. + _remotePort = (readPointer[0] << 8) | readPointer[1]; + + break; + } + + case PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V6: + { + if (remainingLen < 36) + { + AsyncReadProxyHeader(); + return; + } + + boost::asio::ip::address_v6::bytes_type b; + auto addressSize = sizeof(b); + + std::copy(readPointer, readPointer+addressSize, b.begin()); + _remoteAddress = boost::asio::ip::address_v6(b); + + readPointer += 2 * addressSize; // Skip server address. + _remotePort = (readPointer[0] << 8) | readPointer[1]; + + break; + } + + default: + _proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FAILED; + LOG_ERROR("network", "Socket::ProxyReadHeaderHandler: unsupported address family type {}", GetRemoteIpAddress().to_string()); + return; + } + + packet.ReadCompleted(len+16); + _proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FINISHED; + } + #ifdef AC_SOCKET_USE_IOCP void WriteHandler(boost::system::error_code error, std::size_t transferedBytes) { @@ -283,6 +425,8 @@ class Socket : public std::enable_shared_from_this std::atomic _closing; bool _isWritingAsync; + + ProxyHeaderReadingState _proxyHeaderReadingState; }; #endif // __SOCKET_H__ diff --git a/src/server/shared/Network/SocketMgr.h b/src/server/shared/Network/SocketMgr.h index 02720b2055514..5cb64d0f180f8 100644 --- a/src/server/shared/Network/SocketMgr.h +++ b/src/server/shared/Network/SocketMgr.h @@ -94,8 +94,6 @@ class SocketMgr try { std::shared_ptr newSocket = std::make_shared(std::move(sock)); - newSocket->Start(); - _threads[threadIndex].AddSocket(newSocket); } catch (boost::system::system_error const& err)