Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Traversal: Use low TTL for probe packet #11382

Merged
merged 2 commits into from Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
143 changes: 136 additions & 7 deletions Source/Core/Common/TraversalClient.cpp
Expand Up @@ -14,8 +14,9 @@

namespace Common
{
TraversalClient::TraversalClient(ENetHost* netHost, const std::string& server, const u16 port)
: m_NetHost(netHost), m_Server(server), m_port(port)
TraversalClient::TraversalClient(ENetHost* netHost, const std::string& server, const u16 port,
const u16 port_alt)
: m_NetHost(netHost), m_Server(server), m_port(port), m_portAlt(port_alt)
{
netHost->intercept = TraversalClient::InterceptCallback;

Expand Down Expand Up @@ -146,6 +147,8 @@ void TraversalClient::HandleServerPacket(TraversalPacket* packet)
{
if (it->packet.requestId == packet->requestId)
{
if (packet->requestId == m_TestRequestId)
HandleTraversalTest();
m_OutgoingTraversalPackets.erase(it);
break;
}
Expand All @@ -161,6 +164,7 @@ void TraversalClient::HandleServerPacket(TraversalPacket* packet)
}
m_HostId = packet->helloFromServer.yourHostId;
m_external_address = packet->helloFromServer.yourAddress;
NewTraversalTest();
m_State = State::Connected;
if (m_Client)
m_Client->OnTraversalStateChanged();
Expand All @@ -175,7 +179,18 @@ void TraversalClient::HandleServerPacket(TraversalPacket* packet)
ENetBuffer buf;
buf.data = message;
buf.dataLength = sizeof(message) - 1;
enet_socket_send(m_NetHost->socket, &addr, &buf, 1);
if (m_ttlReady)
{
int oldttl;
enet_socket_get_option(m_NetHost->socket, ENET_SOCKOPT_TTL, &oldttl);
enet_socket_set_option(m_NetHost->socket, ENET_SOCKOPT_TTL, m_ttl);
enet_socket_send(m_NetHost->socket, &addr, &buf, 1);
enet_socket_set_option(m_NetHost->socket, ENET_SOCKOPT_TTL, oldttl);
}
else
{
enet_socket_send(m_NetHost->socket, &addr, &buf, 1);
}
}
else
{
Expand Down Expand Up @@ -231,12 +246,15 @@ void TraversalClient::OnFailure(FailureReason reason)

void TraversalClient::ResendPacket(OutgoingTraversalPacketInfo* info)
{
bool testPacket =
m_TestSocket != ENET_SOCKET_NULL && info->packet.type == TraversalPacketType::TestPlease;
info->sendTime = enet_time_get();
info->tries++;
ENetBuffer buf;
buf.data = &info->packet;
buf.dataLength = sizeof(info->packet);
if (enet_socket_send(m_NetHost->socket, &m_ServerAddress, &buf, 1) == -1)
if (enet_socket_send(testPacket ? m_TestSocket : m_NetHost->socket, &m_ServerAddress, &buf, 1) ==
-1)
OnFailure(FailureReason::SocketSendError);
}

Expand Down Expand Up @@ -275,6 +293,112 @@ void TraversalClient::HandlePing()
}
}

void TraversalClient::NewTraversalTest()
{
// create test socket
if (m_TestSocket != ENET_SOCKET_NULL)
enet_socket_destroy(m_TestSocket);
m_TestSocket = enet_socket_create(ENET_SOCKET_TYPE_DATAGRAM);
ENetAddress addr = {ENET_HOST_ANY, 0};
if (m_TestSocket == ENET_SOCKET_NULL || enet_socket_bind(m_TestSocket, &addr) < 0)
{
// error, abort
if (m_TestSocket != ENET_SOCKET_NULL)
{
enet_socket_destroy(m_TestSocket);
m_TestSocket = ENET_SOCKET_NULL;
}
return;
}
enet_socket_set_option(m_TestSocket, ENET_SOCKOPT_NONBLOCK, 1);
// create holepunch packet
TraversalPacket packet = {};
packet.type = TraversalPacketType::Ping;
packet.ping.hostId = m_HostId;
packet.requestId = Common::Random::GenerateValue<TraversalRequestId>();
// create buffer
ENetBuffer buf;
buf.data = &packet;
buf.dataLength = sizeof(packet);
// send to alt port
ENetAddress altAddress = m_ServerAddress;
altAddress.port = m_portAlt;
// set up ttl and send
int oldttl;
enet_socket_get_option(m_TestSocket, ENET_SOCKOPT_TTL, &oldttl);
enet_socket_set_option(m_TestSocket, ENET_SOCKOPT_TTL, m_ttl);
if (enet_socket_send(m_TestSocket, &altAddress, &buf, 1) == -1)
{
// error, abort
enet_socket_destroy(m_TestSocket);
m_TestSocket = ENET_SOCKET_NULL;
return;
}
enet_socket_set_option(m_TestSocket, ENET_SOCKOPT_TTL, oldttl);
// send the test request
packet.type = TraversalPacketType::TestPlease;
m_TestRequestId = SendTraversalPacket(packet);
}

void TraversalClient::HandleTraversalTest()
{
if (m_TestSocket != ENET_SOCKET_NULL)
{
// check for packet on test socket (with timeout)
u32 deadline = enet_time_get() + 50;
u32 waitCondition;
do
{
waitCondition = ENET_SOCKET_WAIT_RECEIVE | ENET_SOCKET_WAIT_INTERRUPT;
u32 currentTime = enet_time_get();
if (currentTime > deadline ||
enet_socket_wait(m_TestSocket, &waitCondition, deadline - currentTime) != 0)
{
// error or timeout, exit the loop and assume test failure
waitCondition = 0;
break;
}
else if (waitCondition & ENET_SOCKET_WAIT_RECEIVE)
{
// try reading the packet and see if it's relevant
ENetAddress raddr;
TraversalPacket packet;
ENetBuffer buf;
buf.data = &packet;
buf.dataLength = sizeof(packet);
int rv = enet_socket_receive(m_TestSocket, &raddr, &buf, 1);
if (rv < 0)
{
// error, exit the loop and assume test failure
waitCondition = 0;
break;
}
else if (rv < sizeof(packet) || raddr.host != m_ServerAddress.host ||
raddr.host != m_portAlt || packet.requestId != m_TestRequestId)
{
// irrelevant packet, ignore
continue;
}
}
} while (waitCondition & ENET_SOCKET_WAIT_INTERRUPT);
// regardless of what happens next, we can throw out the socket
enet_socket_destroy(m_TestSocket);
m_TestSocket = ENET_SOCKET_NULL;
if (waitCondition & ENET_SOCKET_WAIT_RECEIVE)
{
// success, we can stop now
m_ttlReady = true;
m_Client->OnTtlDetermined(m_ttl);
}
else
{
// fail, increment and retry
if (++m_ttl < 32)
NewTraversalTest();
}
}
}

TraversalRequestId TraversalClient::SendTraversalPacket(const TraversalPacket& packet)
{
OutgoingTraversalPacketInfo info;
Expand Down Expand Up @@ -313,15 +437,19 @@ ENet::ENetHostPtr g_MainNetHost;
// explicitly requested.
static std::string g_OldServer;
static u16 g_OldServerPort;
static u16 g_OldServerPortAlt;
static u16 g_OldListenPort;

bool EnsureTraversalClient(const std::string& server, u16 server_port, u16 listen_port)
bool EnsureTraversalClient(const std::string& server, u16 server_port, u16 server_port_alt,
u16 listen_port)
{
if (!g_MainNetHost || !g_TraversalClient || server != g_OldServer ||
server_port != g_OldServerPort || listen_port != g_OldListenPort)
server_port != g_OldServerPort || server_port_alt != g_OldServerPortAlt ||
listen_port != g_OldListenPort)
{
g_OldServer = server;
g_OldServerPort = server_port;
g_OldServerPortAlt = server_port_alt;
g_OldListenPort = listen_port;

ENetAddress addr = {ENET_HOST_ANY, listen_port};
Expand All @@ -337,7 +465,8 @@ bool EnsureTraversalClient(const std::string& server, u16 server_port, u16 liste
}
host->mtu = std::min(host->mtu, NetPlay::MAX_ENET_MTU);
g_MainNetHost = std::move(host);
g_TraversalClient.reset(new TraversalClient(g_MainNetHost.get(), server, server_port));
g_TraversalClient.reset(
new TraversalClient(g_MainNetHost.get(), server, server_port, server_port_alt));
}
return true;
}
Expand Down
16 changes: 14 additions & 2 deletions Source/Core/Common/TraversalClient.h
Expand Up @@ -24,6 +24,7 @@ class TraversalClientClient
virtual void OnTraversalStateChanged() = 0;
virtual void OnConnectReady(ENetAddress addr) = 0;
virtual void OnConnectFailed(TraversalConnectFailedReason reason) = 0;
virtual void OnTtlDetermined(u8 ttl) = 0;
};

class TraversalClient
Expand All @@ -43,7 +44,8 @@ class TraversalClient
SocketSendError,
ResendTimeout,
};
TraversalClient(ENetHost* netHost, const std::string& server, const u16 port);
TraversalClient(ENetHost* netHost, const std::string& server, const u16 port,
const u16 port_alt = 0);
~TraversalClient();

TraversalHostId GetHostID() const;
Expand Down Expand Up @@ -79,6 +81,9 @@ class TraversalClient
void HandlePing();
static int ENET_CALLBACK InterceptCallback(ENetHost* host, ENetEvent* event);

void NewTraversalTest();
void HandleTraversalTest();

ENetHost* m_NetHost;
TraversalHostId m_HostId{};
TraversalInetAddress m_external_address{};
Expand All @@ -90,14 +95,21 @@ class TraversalClient
ENetAddress m_ServerAddress{};
std::string m_Server;
u16 m_port;
u16 m_portAlt;
u32 m_PingTime = 0;

ENetSocket m_TestSocket = ENET_SOCKET_NULL;
TraversalRequestId m_TestRequestId = 0;
u8 m_ttl = 2;
bool m_ttlReady = false;
};

extern std::unique_ptr<TraversalClient> g_TraversalClient;
// the NetHost connected to the TraversalClient.
extern ENet::ENetHostPtr g_MainNetHost;

// Create g_TraversalClient and g_MainNetHost if necessary.
bool EnsureTraversalClient(const std::string& server, u16 server_port, u16 listen_port = 0);
bool EnsureTraversalClient(const std::string& server, u16 server_port, u16 server_port_alt = 0,
u16 listen_port = 0);
void ReleaseTraversalClient();
} // namespace Common
8 changes: 8 additions & 0 deletions Source/Core/Common/TraversalProto.h
Expand Up @@ -31,6 +31,10 @@ enum class TraversalPacketType : u8
ConnectReady = 6,
// [s->c] Alternately, the server might not have heard of this host.
ConnectFailed = 7,
// [c->s] Perform a traveral test. This will send two acks:
// one via the server's alt port, and one to the address corresponding to
// the given host ID.
TestPlease = 8,
};

constexpr u8 TraversalProtoVersion = 0;
Expand Down Expand Up @@ -91,6 +95,10 @@ struct TraversalPacket
TraversalRequestId requestId;
TraversalConnectFailedReason reason;
} connectFailed;
struct
{
TraversalHostId hostId;
} testPlease;
};
};
#pragma pack(pop)
Expand Down