Skip to content

Commit

Permalink
feat(Core/Network): Add Proxy Protocol v2 support. (#18839)
Browse files Browse the repository at this point in the history
* feat(Core/Network): Add Proxy Protocol v2 support.

* Fix codestyle and build.

* Another codestyle fix.

* One more missing include.
  • Loading branch information
walkline committed May 4, 2024
1 parent 715b290 commit 9815025
Show file tree
Hide file tree
Showing 7 changed files with 247 additions and 12 deletions.
9 changes: 8 additions & 1 deletion src/server/apps/authserver/Server/AuthSocketMgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "AuthSession.h"
#include "SocketMgr.h"
#include "Config.h"

class AuthSocketMgr : public SocketMgr<AuthSession>
{
Expand All @@ -44,7 +45,13 @@ class AuthSocketMgr : public SocketMgr<AuthSession>
protected:
NetworkThread<AuthSession>* CreateThreads() const override
{
return new NetworkThread<AuthSession>[1];
NetworkThread<AuthSession>* threads = new NetworkThread<AuthSession>[1];

bool proxyProtocolEnabled = sConfigMgr->GetOption<bool>("EnableProxyProtocol", false, true);
if (proxyProtocolEnabled)
threads[0].EnableProxyProtocol();

return threads;
}

static void OnSocketAccept(tcp::socket&& sock, uint32 threadIndex)
Expand Down
10 changes: 10 additions & 0 deletions src/server/apps/authserver/authserver.conf.dist
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ RealmServerPort = 3724

BindIP = "0.0.0.0"

#
# EnableProxyProtocol
# Description: Enables Proxy Protocol v2. When your server is behind a proxy,
# load balancer, or similar component, you need to enable Proxy Protocol v2 on both
# this server and the proxy/load balancer to track the real IP address of players.
# Example: 1 - (Enabled)
# Default: 0 - (Disabled)

EnableProxyProtocol = 0

#
# PidFile
# Description: Auth server PID file.
Expand Down
10 changes: 10 additions & 0 deletions src/server/apps/worldserver/worldserver.conf.dist
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,16 @@ Network.OutUBuff = 4096

Network.TcpNodelay = 1

#
# Network.EnableProxyProtocol
# Description: Enables Proxy Protocol v2. When your server is behind a proxy,
# load balancer, or similar component, you need to enable Proxy Protocol v2 on both
# this server and the proxy/load balancer to track the real IP address of players.
# Example: 1 - (Enabled)
# Default: 0 - (Disabled)

Network.EnableProxyProtocol = 0

#
###################################################################################################

Expand Down
10 changes: 9 additions & 1 deletion src/server/game/Server/WorldSocketMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,13 @@ void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock, uint32 threadIndex)

NetworkThread<WorldSocket>* WorldSocketMgr::CreateThreads() const
{
return new WorldSocketThread[GetNetworkThreadCount()];

NetworkThread<WorldSocket>* threads = new WorldSocketThread[GetNetworkThreadCount()];

bool proxyProtocolEnabled = sConfigMgr->GetOption<bool>("Network.EnableProxyProtocol", false, true);
if (proxyProtocolEnabled)
for (int i = 0; i < GetNetworkThreadCount(); i++)
threads[i].EnableProxyProtocol();

return threads;
}
70 changes: 64 additions & 6 deletions src/server/shared/Network/NetworkThread.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "DeadlineTimer.h"
#include "Define.h"
#include "Socket.h"
#include "Errors.h"
#include "IoContext.h"
#include "Log.h"
Expand All @@ -39,7 +40,7 @@ class NetworkThread
{
public:
NetworkThread() :
_ioContext(1), _acceptSocket(_ioContext), _updateTimer(_ioContext) { }
_ioContext(1), _acceptSocket(_ioContext), _updateTimer(_ioContext), _proxyHeaderReadingEnabled(false) { }

virtual ~NetworkThread()
{
Expand Down Expand Up @@ -94,6 +95,8 @@ class NetworkThread

tcp::socket* GetSocketForAccept() { return &_acceptSocket; }

void EnableProxyProtocol() { _proxyHeaderReadingEnabled = true; }

protected:
virtual void SocketAdded(std::shared_ptr<SocketType> /*sock*/) { }
virtual void SocketRemoved(std::shared_ptr<SocketType> /*sock*/) { }
Expand All @@ -105,20 +108,73 @@ class NetworkThread
if (_newSockets.empty())
return;

for (std::shared_ptr<SocketType> sock : _newSockets)
if (!_proxyHeaderReadingEnabled)
{
for (std::shared_ptr<SocketType> sock : _newSockets)
{
if (!sock->IsOpen())
{
SocketRemoved(sock);
--_connections;
continue;
}

_sockets.emplace_back(sock);

sock->Start();
}

_newSockets.clear();
}
else
{
HandleNewSocketsProxyReadingOnConnect();
}
}

void HandleNewSocketsProxyReadingOnConnect()
{
size_t index = 0;
std::vector<int> newSocketsToRemoveIndexes;
for (auto sock_iter = _newSockets.begin(); sock_iter != _newSockets.end(); ++sock_iter, ++index)
{
std::shared_ptr<SocketType> sock = *sock_iter;

if (!sock->IsOpen())
{
newSocketsToRemoveIndexes.emplace_back(index);
SocketRemoved(sock);
--_connections;
continue;
}
else
{
_sockets.emplace_back(sock);

const auto proxyHeaderReadingState = sock->GetProxyHeaderReadingState();
if (proxyHeaderReadingState == PROXY_HEADER_READING_STATE_STARTED)
continue;

switch (proxyHeaderReadingState) {
case PROXY_HEADER_READING_STATE_NOT_STARTED:
sock->AsyncReadProxyHeader();
break;

case PROXY_HEADER_READING_STATE_FINISHED:
newSocketsToRemoveIndexes.emplace_back(index);
_sockets.emplace_back(sock);

sock->Start();

break;

default:
newSocketsToRemoveIndexes.emplace_back(index);
SocketRemoved(sock);
--_connections;
break;
}
}

_newSockets.clear();
for (int removeIndex : newSocketsToRemoveIndexes)
_newSockets.erase(_newSockets.begin() + removeIndex);
}

void Run()
Expand Down Expand Up @@ -177,6 +233,8 @@ class NetworkThread
Acore::Asio::IoContext _ioContext;
tcp::socket _acceptSocket;
Acore::Asio::DeadlineTimer _updateTimer;

bool _proxyHeaderReadingEnabled;
};

#endif // NetworkThread_h__
148 changes: 146 additions & 2 deletions src/server/shared/Network/Socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "MessageBuffer.h"
#include <atomic>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio.hpp>
#include <functional>
#include <memory>
#include <queue>
Expand All @@ -34,12 +35,25 @@ using boost::asio::ip::tcp;
#define AC_SOCKET_USE_IOCP
#endif

enum ProxyHeaderReadingState {
PROXY_HEADER_READING_STATE_NOT_STARTED,
PROXY_HEADER_READING_STATE_STARTED,
PROXY_HEADER_READING_STATE_FINISHED,
PROXY_HEADER_READING_STATE_FAILED,
};

enum ProxyHeaderAddressFamilyAndProtocol {
PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V4 = 0x11,
PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V6 = 0x21,
};

template<class T>
class Socket : public std::enable_shared_from_this<T>
{
public:
explicit Socket(tcp::socket&& socket) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()),
_remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false)
_remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false),
_proxyHeaderReadingState(PROXY_HEADER_READING_STATE_NOT_STARTED)
{
_readBuffer.Resize(READ_BLOCK_SIZE);
}
Expand Down Expand Up @@ -92,11 +106,25 @@ class Socket : public std::enable_shared_from_this<T>

_readBuffer.Normalize();
_readBuffer.EnsureFreeSpace();

_socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
std::bind(&Socket<T>::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
}

void AsyncReadProxyHeader()
{
if (!IsOpen())
{
return;
}

_proxyHeaderReadingState = PROXY_HEADER_READING_STATE_STARTED;

_readBuffer.Normalize();
_readBuffer.EnsureFreeSpace();
_socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
std::bind(&Socket<T>::ProxyReadHeaderHandler, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
}

void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code, std::size_t))
{
if (!IsOpen())
Expand All @@ -120,6 +148,8 @@ class Socket : public std::enable_shared_from_this<T>
#endif
}

[[nodiscard]] ProxyHeaderReadingState GetProxyHeaderReadingState() const { return _proxyHeaderReadingState; }

[[nodiscard]] bool IsOpen() const { return !_closed && !_closing; }

void CloseSocket()
Expand Down Expand Up @@ -187,6 +217,118 @@ class Socket : public std::enable_shared_from_this<T>
ReadHandler();
}

// ProxyReadHeaderHandler reads Proxy Protocol v2 header (v1 is not supported).
// See https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt (2.2. Binary header format (version 2)) for more details.
void ProxyReadHeaderHandler(boost::system::error_code error, size_t transferredBytes)
{
if (error)
{
CloseSocket();
return;
}

_readBuffer.WriteCompleted(transferredBytes);

MessageBuffer& packet = GetReadBuffer();

const int minimumProxyProtocolV2Size = 28;
if (packet.GetActiveSize() < minimumProxyProtocolV2Size)
{
AsyncReadProxyHeader();
return;
}

uint8* readPointer = packet.GetReadPointer();

const uint8 signatureSize = 12;
const uint8 expectedSignature[signatureSize] = {0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A};
if (memcmp(packet.GetReadPointer(), expectedSignature, signatureSize) != 0)
{
_proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FAILED;
LOG_ERROR("network", "Socket::ProxyReadHeaderHandler: received bad PROXY Protocol v2 signature for {}", GetRemoteIpAddress().to_string());
return;
}

const uint8 version = (readPointer[signatureSize] & 0xF0) >> 4;
const uint8 command = (readPointer[signatureSize] & 0xF);

if (version != 2)
{
_proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FAILED;
LOG_ERROR("network", "Socket::ProxyReadHeaderHandler: received bad PROXY Protocol v2 signature for {}", GetRemoteIpAddress().to_string());
return;
}

const uint8 addressFamily = readPointer[13];
const uint16 len = (readPointer[14] << 8) | readPointer[15];
if (len+16 > packet.GetActiveSize())
{
AsyncReadProxyHeader();
return;
}

// Connection created by a proxy itself (health checks?), ignore and do nothing.
if (command == 0)
{
packet.ReadCompleted(len+16);
_proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FINISHED;
return;
}

auto remainingLen = packet.GetActiveSize() - 16;
readPointer += 16; // Skip strait to address.

switch (addressFamily) {
case PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V4:
{
if (remainingLen < 12)
{
AsyncReadProxyHeader();
return;
}

boost::asio::ip::address_v4::bytes_type b;
auto addressSize = sizeof(b);

std::copy(readPointer, readPointer+addressSize, b.begin());
_remoteAddress = boost::asio::ip::address_v4(b);

readPointer += 2 * addressSize; // Skip server address.
_remotePort = (readPointer[0] << 8) | readPointer[1];

break;
}

case PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V6:
{
if (remainingLen < 36)
{
AsyncReadProxyHeader();
return;
}

boost::asio::ip::address_v6::bytes_type b;
auto addressSize = sizeof(b);

std::copy(readPointer, readPointer+addressSize, b.begin());
_remoteAddress = boost::asio::ip::address_v6(b);

readPointer += 2 * addressSize; // Skip server address.
_remotePort = (readPointer[0] << 8) | readPointer[1];

break;
}

default:
_proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FAILED;
LOG_ERROR("network", "Socket::ProxyReadHeaderHandler: unsupported address family type {}", GetRemoteIpAddress().to_string());
return;
}

packet.ReadCompleted(len+16);
_proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FINISHED;
}

#ifdef AC_SOCKET_USE_IOCP
void WriteHandler(boost::system::error_code error, std::size_t transferedBytes)
{
Expand Down Expand Up @@ -283,6 +425,8 @@ class Socket : public std::enable_shared_from_this<T>
std::atomic<bool> _closing;

bool _isWritingAsync;

ProxyHeaderReadingState _proxyHeaderReadingState;
};

#endif // __SOCKET_H__
2 changes: 0 additions & 2 deletions src/server/shared/Network/SocketMgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,6 @@ class SocketMgr
try
{
std::shared_ptr<SocketType> newSocket = std::make_shared<SocketType>(std::move(sock));
newSocket->Start();

_threads[threadIndex].AddSocket(newSocket);
}
catch (boost::system::system_error const& err)
Expand Down

0 comments on commit 9815025

Please sign in to comment.