Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(Core/Network): Add Proxy Protocol v2 support. #18839

Merged
merged 4 commits into from
May 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/server/apps/authserver/Server/AuthSocketMgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "AuthSession.h"
#include "SocketMgr.h"
#include "Config.h"

class AuthSocketMgr : public SocketMgr<AuthSession>
{
Expand All @@ -44,7 +45,13 @@ class AuthSocketMgr : public SocketMgr<AuthSession>
protected:
NetworkThread<AuthSession>* CreateThreads() const override
{
return new NetworkThread<AuthSession>[1];
NetworkThread<AuthSession>* threads = new NetworkThread<AuthSession>[1];

bool proxyProtocolEnabled = sConfigMgr->GetOption<bool>("EnableProxyProtocol", false, true);
if (proxyProtocolEnabled)
threads[0].EnableProxyProtocol();

return threads;
}

static void OnSocketAccept(tcp::socket&& sock, uint32 threadIndex)
Expand Down
10 changes: 10 additions & 0 deletions src/server/apps/authserver/authserver.conf.dist
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ RealmServerPort = 3724

BindIP = "0.0.0.0"

#
# EnableProxyProtocol
# Description: Enables Proxy Protocol v2. When your server is behind a proxy,
# load balancer, or similar component, you need to enable Proxy Protocol v2 on both
# this server and the proxy/load balancer to track the real IP address of players.
# Example: 1 - (Enabled)
# Default: 0 - (Disabled)

EnableProxyProtocol = 0

#
# PidFile
# Description: Auth server PID file.
Expand Down
10 changes: 10 additions & 0 deletions src/server/apps/worldserver/worldserver.conf.dist
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,16 @@ Network.OutUBuff = 4096

Network.TcpNodelay = 1

#
# Network.EnableProxyProtocol
# Description: Enables Proxy Protocol v2. When your server is behind a proxy,
# load balancer, or similar component, you need to enable Proxy Protocol v2 on both
# this server and the proxy/load balancer to track the real IP address of players.
# Example: 1 - (Enabled)
# Default: 0 - (Disabled)

Network.EnableProxyProtocol = 0

#
###################################################################################################

Expand Down
10 changes: 9 additions & 1 deletion src/server/game/Server/WorldSocketMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,13 @@ void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock, uint32 threadIndex)

NetworkThread<WorldSocket>* WorldSocketMgr::CreateThreads() const
{
return new WorldSocketThread[GetNetworkThreadCount()];

NetworkThread<WorldSocket>* threads = new WorldSocketThread[GetNetworkThreadCount()];

bool proxyProtocolEnabled = sConfigMgr->GetOption<bool>("Network.EnableProxyProtocol", false, true);
if (proxyProtocolEnabled)
for (int i = 0; i < GetNetworkThreadCount(); i++)
threads[i].EnableProxyProtocol();

return threads;
}
70 changes: 64 additions & 6 deletions src/server/shared/Network/NetworkThread.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "DeadlineTimer.h"
#include "Define.h"
#include "Socket.h"
#include "Errors.h"
#include "IoContext.h"
#include "Log.h"
Expand All @@ -39,7 +40,7 @@ class NetworkThread
{
public:
NetworkThread() :
_ioContext(1), _acceptSocket(_ioContext), _updateTimer(_ioContext) { }
_ioContext(1), _acceptSocket(_ioContext), _updateTimer(_ioContext), _proxyHeaderReadingEnabled(false) { }

virtual ~NetworkThread()
{
Expand Down Expand Up @@ -94,6 +95,8 @@ class NetworkThread

tcp::socket* GetSocketForAccept() { return &_acceptSocket; }

void EnableProxyProtocol() { _proxyHeaderReadingEnabled = true; }

protected:
virtual void SocketAdded(std::shared_ptr<SocketType> /*sock*/) { }
virtual void SocketRemoved(std::shared_ptr<SocketType> /*sock*/) { }
Expand All @@ -105,20 +108,73 @@ class NetworkThread
if (_newSockets.empty())
return;

for (std::shared_ptr<SocketType> sock : _newSockets)
if (!_proxyHeaderReadingEnabled)
{
for (std::shared_ptr<SocketType> sock : _newSockets)
{
if (!sock->IsOpen())
{
SocketRemoved(sock);
--_connections;
continue;
}

_sockets.emplace_back(sock);

sock->Start();
}

_newSockets.clear();
}
else
{
HandleNewSocketsProxyReadingOnConnect();
}
}

void HandleNewSocketsProxyReadingOnConnect()
{
size_t index = 0;
std::vector<int> newSocketsToRemoveIndexes;
for (auto sock_iter = _newSockets.begin(); sock_iter != _newSockets.end(); ++sock_iter, ++index)
{
std::shared_ptr<SocketType> sock = *sock_iter;

if (!sock->IsOpen())
{
newSocketsToRemoveIndexes.emplace_back(index);
SocketRemoved(sock);
--_connections;
continue;
}
else
{
_sockets.emplace_back(sock);

const auto proxyHeaderReadingState = sock->GetProxyHeaderReadingState();
if (proxyHeaderReadingState == PROXY_HEADER_READING_STATE_STARTED)
continue;

switch (proxyHeaderReadingState) {
case PROXY_HEADER_READING_STATE_NOT_STARTED:
sock->AsyncReadProxyHeader();
break;

case PROXY_HEADER_READING_STATE_FINISHED:
newSocketsToRemoveIndexes.emplace_back(index);
_sockets.emplace_back(sock);

sock->Start();

break;

default:
newSocketsToRemoveIndexes.emplace_back(index);
SocketRemoved(sock);
--_connections;
break;
}
}

_newSockets.clear();
for (int removeIndex : newSocketsToRemoveIndexes)
_newSockets.erase(_newSockets.begin() + removeIndex);
}

void Run()
Expand Down Expand Up @@ -177,6 +233,8 @@ class NetworkThread
Acore::Asio::IoContext _ioContext;
tcp::socket _acceptSocket;
Acore::Asio::DeadlineTimer _updateTimer;

bool _proxyHeaderReadingEnabled;
};

#endif // NetworkThread_h__
148 changes: 146 additions & 2 deletions src/server/shared/Network/Socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "MessageBuffer.h"
#include <atomic>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio.hpp>
#include <functional>
#include <memory>
#include <queue>
Expand All @@ -34,12 +35,25 @@ using boost::asio::ip::tcp;
#define AC_SOCKET_USE_IOCP
#endif

enum ProxyHeaderReadingState {
PROXY_HEADER_READING_STATE_NOT_STARTED,
PROXY_HEADER_READING_STATE_STARTED,
PROXY_HEADER_READING_STATE_FINISHED,
PROXY_HEADER_READING_STATE_FAILED,
};

enum ProxyHeaderAddressFamilyAndProtocol {
PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V4 = 0x11,
PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V6 = 0x21,
};

template<class T>
class Socket : public std::enable_shared_from_this<T>
{
public:
explicit Socket(tcp::socket&& socket) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()),
_remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false)
_remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false),
_proxyHeaderReadingState(PROXY_HEADER_READING_STATE_NOT_STARTED)
{
_readBuffer.Resize(READ_BLOCK_SIZE);
}
Expand Down Expand Up @@ -92,11 +106,25 @@ class Socket : public std::enable_shared_from_this<T>

_readBuffer.Normalize();
_readBuffer.EnsureFreeSpace();

_socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
std::bind(&Socket<T>::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
}

void AsyncReadProxyHeader()
{
if (!IsOpen())
{
return;
}

_proxyHeaderReadingState = PROXY_HEADER_READING_STATE_STARTED;

_readBuffer.Normalize();
_readBuffer.EnsureFreeSpace();
_socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
std::bind(&Socket<T>::ProxyReadHeaderHandler, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
}

void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code, std::size_t))
{
if (!IsOpen())
Expand All @@ -120,6 +148,8 @@ class Socket : public std::enable_shared_from_this<T>
#endif
}

[[nodiscard]] ProxyHeaderReadingState GetProxyHeaderReadingState() const { return _proxyHeaderReadingState; }

[[nodiscard]] bool IsOpen() const { return !_closed && !_closing; }

void CloseSocket()
Expand Down Expand Up @@ -187,6 +217,118 @@ class Socket : public std::enable_shared_from_this<T>
ReadHandler();
}

// ProxyReadHeaderHandler reads Proxy Protocol v2 header (v1 is not supported).
// See https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt (2.2. Binary header format (version 2)) for more details.
void ProxyReadHeaderHandler(boost::system::error_code error, size_t transferredBytes)
{
if (error)
{
CloseSocket();
return;
}

_readBuffer.WriteCompleted(transferredBytes);

MessageBuffer& packet = GetReadBuffer();

const int minimumProxyProtocolV2Size = 28;
if (packet.GetActiveSize() < minimumProxyProtocolV2Size)
{
AsyncReadProxyHeader();
return;
}

uint8* readPointer = packet.GetReadPointer();

const uint8 signatureSize = 12;
const uint8 expectedSignature[signatureSize] = {0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A};
if (memcmp(packet.GetReadPointer(), expectedSignature, signatureSize) != 0)
{
_proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FAILED;
LOG_ERROR("network", "Socket::ProxyReadHeaderHandler: received bad PROXY Protocol v2 signature for {}", GetRemoteIpAddress().to_string());
return;
}

const uint8 version = (readPointer[signatureSize] & 0xF0) >> 4;
const uint8 command = (readPointer[signatureSize] & 0xF);

if (version != 2)
{
_proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FAILED;
LOG_ERROR("network", "Socket::ProxyReadHeaderHandler: received bad PROXY Protocol v2 signature for {}", GetRemoteIpAddress().to_string());
return;
}

const uint8 addressFamily = readPointer[13];
const uint16 len = (readPointer[14] << 8) | readPointer[15];
if (len+16 > packet.GetActiveSize())
{
AsyncReadProxyHeader();
return;
}

// Connection created by a proxy itself (health checks?), ignore and do nothing.
if (command == 0)
{
packet.ReadCompleted(len+16);
_proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FINISHED;
return;
}

auto remainingLen = packet.GetActiveSize() - 16;
readPointer += 16; // Skip strait to address.

switch (addressFamily) {
case PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V4:
{
if (remainingLen < 12)
{
AsyncReadProxyHeader();
return;
}

boost::asio::ip::address_v4::bytes_type b;
auto addressSize = sizeof(b);

std::copy(readPointer, readPointer+addressSize, b.begin());
_remoteAddress = boost::asio::ip::address_v4(b);

readPointer += 2 * addressSize; // Skip server address.
_remotePort = (readPointer[0] << 8) | readPointer[1];

break;
}

case PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V6:
{
if (remainingLen < 36)
{
AsyncReadProxyHeader();
return;
}

boost::asio::ip::address_v6::bytes_type b;
auto addressSize = sizeof(b);

std::copy(readPointer, readPointer+addressSize, b.begin());
_remoteAddress = boost::asio::ip::address_v6(b);

readPointer += 2 * addressSize; // Skip server address.
_remotePort = (readPointer[0] << 8) | readPointer[1];

break;
}

default:
_proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FAILED;
LOG_ERROR("network", "Socket::ProxyReadHeaderHandler: unsupported address family type {}", GetRemoteIpAddress().to_string());
return;
}

packet.ReadCompleted(len+16);
_proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FINISHED;
}

#ifdef AC_SOCKET_USE_IOCP
void WriteHandler(boost::system::error_code error, std::size_t transferedBytes)
{
Expand Down Expand Up @@ -283,6 +425,8 @@ class Socket : public std::enable_shared_from_this<T>
std::atomic<bool> _closing;

bool _isWritingAsync;

ProxyHeaderReadingState _proxyHeaderReadingState;
};

#endif // __SOCKET_H__
2 changes: 0 additions & 2 deletions src/server/shared/Network/SocketMgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,6 @@ class SocketMgr
try
{
std::shared_ptr<SocketType> newSocket = std::make_shared<SocketType>(std::move(sock));
newSocket->Start();

_threads[threadIndex].AddSocket(newSocket);
}
catch (boost::system::system_error const& err)
Expand Down
Loading